1mod multi;
2
3use crate::{error::LLMError, LLMProvider};
4use std::collections::HashMap;
5
6pub use multi::{
7 LLMRegistry, LLMRegistryBuilder, MultiChainStepBuilder, MultiChainStepMode, MultiPromptChain,
8};
9
/// Selects which provider API a [`ChainStep`] is executed with.
#[derive(Clone, Debug)]
pub enum ChainStepMode {
    /// The rendered prompt is sent as a single user chat message.
    Chat,
    /// The rendered prompt is sent as a text completion request.
    Completion,
}
18
19#[derive(Debug, Clone)]
21pub struct ChainStep {
22 pub id: String,
24 pub template: String,
26 pub mode: ChainStepMode,
28 pub temperature: Option<f32>,
30 pub max_tokens: Option<u32>,
32 pub top_p: Option<f32>,
34}
35
36pub struct ChainStepBuilder {
38 id: String,
39 template: String,
40 mode: ChainStepMode,
41 temperature: Option<f32>,
42 max_tokens: Option<u32>,
43 top_p: Option<f32>,
44 top_k: Option<u32>,
45}
46
47impl ChainStepBuilder {
48 pub fn new(id: impl Into<String>, template: impl Into<String>, mode: ChainStepMode) -> Self {
55 Self {
56 id: id.into(),
57 template: template.into(),
58 mode,
59 temperature: None,
60 max_tokens: None,
61 top_p: None,
62 top_k: None,
63 }
64 }
65
66 pub fn temperature(mut self, temp: f32) -> Self {
68 self.temperature = Some(temp);
69 self
70 }
71
72 pub fn max_tokens(mut self, mt: u32) -> Self {
74 self.max_tokens = Some(mt);
75 self
76 }
77
78 pub fn top_p(mut self, val: f32) -> Self {
80 self.top_p = Some(val);
81 self
82 }
83
84 pub fn top_k(mut self, val: u32) -> Self {
86 self.top_k = Some(val);
87 self
88 }
89
90 pub fn build(self) -> ChainStep {
92 ChainStep {
93 id: self.id,
94 template: self.template,
95 mode: self.mode,
96 temperature: self.temperature,
97 max_tokens: self.max_tokens,
98 top_p: self.top_p,
99 }
100 }
101}
102
103pub struct PromptChain<'a> {
105 llm: &'a dyn LLMProvider,
106 steps: Vec<ChainStep>,
107 memory: HashMap<String, String>,
108}
109
110impl<'a> PromptChain<'a> {
111 pub fn new(llm: &'a dyn LLMProvider) -> Self {
113 Self {
114 llm,
115 steps: Vec::new(),
116 memory: HashMap::new(),
117 }
118 }
119
120 pub fn step(mut self, step: ChainStep) -> Self {
122 self.steps.push(step);
123 self
124 }
125
126 pub async fn run(mut self) -> Result<HashMap<String, String>, LLMError> {
128 for step in &self.steps {
129 let prompt = self.apply_template(&step.template);
130
131 let response_text = match step.mode {
132 ChainStepMode::Chat => {
133 let messages = vec![crate::chat::ChatMessage {
134 role: crate::chat::ChatRole::User,
135 message_type: crate::chat::MessageType::Text,
136 content: prompt,
137 }];
138 self.llm.chat(&messages).await?
139 }
140 ChainStepMode::Completion => {
141 let mut req = crate::completion::CompletionRequest::new(prompt);
142 req.max_tokens = step.max_tokens;
143 req.temperature = step.temperature;
144 let resp = self.llm.complete(&req).await?;
145 Box::new(resp)
146 }
147 };
148
149 self.memory
150 .insert(step.id.clone(), response_text.text().unwrap_or_default());
151 }
152
153 Ok(self.memory)
154 }
155
156 fn apply_template(&self, input: &str) -> String {
158 let mut result = input.to_string();
159 for (k, v) in &self.memory {
160 let pattern = format!("{{{{{}}}}}", k);
161 result = result.replace(&pattern, v);
162 }
163 result
164 }
165}