llm/chain/mod.rs

1mod multi;
2
3use crate::{error::LLMError, LLMProvider};
4use std::collections::HashMap;
5
6pub use multi::{
7    LLMRegistry, LLMRegistryBuilder, MultiChainStepBuilder, MultiChainStepMode, MultiPromptChain,
8};
9
/// Selects which provider API a [`ChainStep`] is executed through.
#[derive(Clone, Debug)]
pub enum ChainStepMode {
    /// Send the rendered prompt as a single user message via the provider's chat API.
    Chat,
    /// Send the rendered prompt via the provider's text-completion API.
    Completion,
}
18
19/// Represents a single step in a prompt chain
20#[derive(Debug, Clone)]
21pub struct ChainStep {
22    /// Unique identifier for this step
23    pub id: String,
24    /// Prompt template with {{variable}} placeholders
25    pub template: String,
26    /// Execution mode (chat or completion)
27    pub mode: ChainStepMode,
28    /// Optional temperature parameter (0.0-1.0) controlling randomness
29    pub temperature: Option<f32>,
30    /// Optional maximum tokens to generate in response
31    pub max_tokens: Option<u32>,
32    /// Optional top_p parameter for nucleus sampling
33    pub top_p: Option<f32>,
34}
35
36/// Builder pattern for constructing ChainStep instances
37pub struct ChainStepBuilder {
38    id: String,
39    template: String,
40    mode: ChainStepMode,
41    temperature: Option<f32>,
42    max_tokens: Option<u32>,
43    top_p: Option<f32>,
44    top_k: Option<u32>,
45}
46
47impl ChainStepBuilder {
48    /// Creates a new ChainStepBuilder
49    ///
50    /// # Arguments
51    /// * `id` - Unique identifier for the step
52    /// * `template` - Prompt template with {{variable}} placeholders
53    /// * `mode` - Execution mode (chat or completion)
54    pub fn new(id: impl Into<String>, template: impl Into<String>, mode: ChainStepMode) -> Self {
55        Self {
56            id: id.into(),
57            template: template.into(),
58            mode,
59            temperature: None,
60            max_tokens: None,
61            top_p: None,
62            top_k: None,
63        }
64    }
65
66    /// Sets the temperature parameter
67    pub fn temperature(mut self, temp: f32) -> Self {
68        self.temperature = Some(temp);
69        self
70    }
71
72    /// Sets the maximum tokens parameter
73    pub fn max_tokens(mut self, mt: u32) -> Self {
74        self.max_tokens = Some(mt);
75        self
76    }
77
78    /// Sets the top_p parameter
79    pub fn top_p(mut self, val: f32) -> Self {
80        self.top_p = Some(val);
81        self
82    }
83
84    /// Sets the top_k parameter
85    pub fn top_k(mut self, val: u32) -> Self {
86        self.top_k = Some(val);
87        self
88    }
89
90    /// Builds and returns a ChainStep instance
91    pub fn build(self) -> ChainStep {
92        ChainStep {
93            id: self.id,
94            template: self.template,
95            mode: self.mode,
96            temperature: self.temperature,
97            max_tokens: self.max_tokens,
98            top_p: self.top_p,
99        }
100    }
101}
102
103/// Manages a sequence of prompt steps with variable substitution
104pub struct PromptChain<'a> {
105    llm: &'a dyn LLMProvider,
106    steps: Vec<ChainStep>,
107    memory: HashMap<String, String>,
108}
109
110impl<'a> PromptChain<'a> {
111    /// Creates a new PromptChain with the given LLM provider
112    pub fn new(llm: &'a dyn LLMProvider) -> Self {
113        Self {
114            llm,
115            steps: Vec::new(),
116            memory: HashMap::new(),
117        }
118    }
119
120    /// Adds a step to the chain
121    pub fn step(mut self, step: ChainStep) -> Self {
122        self.steps.push(step);
123        self
124    }
125
126    /// Executes all steps in the chain and returns the results
127    pub async fn run(mut self) -> Result<HashMap<String, String>, LLMError> {
128        for step in &self.steps {
129            let prompt = self.apply_template(&step.template);
130
131            let response_text = match step.mode {
132                ChainStepMode::Chat => {
133                    let messages = vec![crate::chat::ChatMessage {
134                        role: crate::chat::ChatRole::User,
135                        message_type: crate::chat::MessageType::Text,
136                        content: prompt,
137                    }];
138                    self.llm.chat(&messages).await?
139                }
140                ChainStepMode::Completion => {
141                    let mut req = crate::completion::CompletionRequest::new(prompt);
142                    req.max_tokens = step.max_tokens;
143                    req.temperature = step.temperature;
144                    let resp = self.llm.complete(&req).await?;
145                    Box::new(resp)
146                }
147            };
148
149            self.memory
150                .insert(step.id.clone(), response_text.text().unwrap_or_default());
151        }
152
153        Ok(self.memory)
154    }
155
156    /// Replaces {{variable}} placeholders in template with values from memory
157    fn apply_template(&self, input: &str) -> String {
158        let mut result = input.to_string();
159        for (k, v) in &self.memory {
160            let pattern = format!("{{{{{}}}}}", k);
161            result = result.replace(&pattern, v);
162        }
163        result
164    }
165}