mod multi;
use crate::{error::LLMError, LLMProvider};
use std::collections::HashMap;
pub use multi::{
LLMRegistry, LLMRegistryBuilder, MultiChainStep, MultiChainStepBuilder, MultiChainStepMode,
MultiPromptChain,
};
/// How a [`ChainStep`] is sent to the provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainStepMode {
    /// The expanded template is sent as a single user text message via `chat`.
    Chat,
    /// The expanded template is sent as a completion request via `complete`.
    Completion,
}

/// A single prompt in a prompt chain.
#[derive(Debug, Clone, PartialEq)]
pub struct ChainStep {
    /// Key under which this step's output is stored; later templates reference it as `{{id}}`.
    pub id: String,
    /// Prompt template; `{{other_id}}` placeholders are replaced with earlier steps' outputs.
    pub template: String,
    /// Whether the step is sent as a chat message or as a completion request.
    pub mode: ChainStepMode,
    /// Temperature value; `PromptChain::run` forwards it only for `Completion` steps.
    pub temperature: Option<f32>,
    /// Token limit value; `PromptChain::run` forwards it only for `Completion` steps.
    pub max_tokens: Option<u32>,
    /// Top-p value; currently not forwarded by `PromptChain::run` in either mode.
    pub top_p: Option<f32>,
}
/// Builder for [`ChainStep`]: starts from an id, a template and a mode, with every
/// sampling parameter unset.
#[derive(Debug, Clone)]
#[must_use = "a ChainStepBuilder does nothing unless `build` is called"]
pub struct ChainStepBuilder {
    /// Step id, copied into `ChainStep::id`.
    id: String,
    /// Prompt template, copied into `ChainStep::template`.
    template: String,
    /// Chat or completion mode, copied into `ChainStep::mode`.
    mode: ChainStepMode,
    /// Optional temperature, copied into `ChainStep::temperature`.
    temperature: Option<f32>,
    /// Optional token limit, copied into `ChainStep::max_tokens`.
    max_tokens: Option<u32>,
    /// Optional top-p value, copied into `ChainStep::top_p`.
    top_p: Option<f32>,
    // NOTE(review): recorded by `top_k()` but never copied into `ChainStep`, which has no
    // matching field. Forwarding it would require adding a public field to `ChainStep`,
    // which is a breaking change for struct-literal construction.
    top_k: Option<u32>,
}
/// Fluent construction of [`ChainStep`] values. Every setter consumes the builder and
/// hands back an updated copy.
impl ChainStepBuilder {
    /// Begins a step with the given `id`, prompt `template` and `mode`; all optional
    /// sampling parameters start out as `None`.
    pub fn new(id: impl Into<String>, template: impl Into<String>, mode: ChainStepMode) -> Self {
        let id: String = id.into();
        let template: String = template.into();
        Self {
            id,
            template,
            mode,
            temperature: None,
            max_tokens: None,
            top_p: None,
            top_k: None,
        }
    }

    /// Records the temperature for the step.
    pub fn temperature(self, temp: f32) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Records the maximum token count for the step.
    pub fn max_tokens(self, mt: u32) -> Self {
        Self {
            max_tokens: Some(mt),
            ..self
        }
    }

    /// Records the top-p value for the step.
    pub fn top_p(self, val: f32) -> Self {
        Self {
            top_p: Some(val),
            ..self
        }
    }

    /// Records a top-k value on the builder.
    ///
    /// Note: [`build`](Self::build) does not carry this value over, because [`ChainStep`]
    /// has no `top_k` field, so it currently has no effect on the produced step.
    pub fn top_k(self, val: u32) -> Self {
        Self {
            top_k: Some(val),
            ..self
        }
    }

    /// Consumes the builder and produces the configured [`ChainStep`].
    pub fn build(self) -> ChainStep {
        // `top_k` is deliberately discarded: `ChainStep` has nowhere to store it.
        let Self {
            id,
            template,
            mode,
            temperature,
            max_tokens,
            top_p,
            top_k: _,
        } = self;
        ChainStep {
            id,
            template,
            mode,
            temperature,
            max_tokens,
            top_p,
        }
    }
}
/// An ordered sequence of [`ChainStep`]s executed against one LLM provider.
///
/// Each step's output is stored under its id, so later steps' templates can reference it
/// as `{{step_id}}`.
#[must_use = "a PromptChain does nothing until `run` is called and awaited"]
pub struct PromptChain<'a> {
    /// Provider that executes every step.
    llm: &'a dyn LLMProvider,
    /// Steps in the order they will be executed.
    steps: Vec<ChainStep>,
    /// Outputs of completed steps, keyed by step id.
    memory: HashMap<String, String>,
}
impl<'a> PromptChain<'a> {
    /// Creates an empty chain bound to `llm`, with no steps and no stored outputs.
    pub fn new(llm: &'a dyn LLMProvider) -> Self {
        Self {
            llm,
            steps: Vec::new(),
            memory: HashMap::new(),
        }
    }

    /// Appends `step` to the chain; steps run in the order they were added.
    pub fn step(mut self, step: ChainStep) -> Self {
        self.steps.push(step);
        self
    }

    /// Executes every step in order and returns all outputs keyed by step id.
    ///
    /// Before a step runs, `{{step_id}}` placeholders in its template are replaced with
    /// the outputs of steps that have already run.
    ///
    /// * `ChainStepMode::Chat` sends the expanded prompt as a single user text message.
    ///   The step's `temperature`, `max_tokens` and `top_p` are not forwarded on this path.
    /// * `ChainStepMode::Completion` builds a completion request that carries the step's
    ///   `max_tokens` and `temperature`. `top_p` is not forwarded.
    ///
    /// If a response yields no text, an empty string is stored for that step.
    ///
    /// # Errors
    ///
    /// Returns the first [`LLMError`] produced by the provider. The remaining steps are not
    /// run, and the outputs gathered so far are discarded.
    pub async fn run(mut self) -> Result<HashMap<String, String>, LLMError> {
        for step in &self.steps {
            let prompt = self.apply_template(&step.template);
            let response_text = match step.mode {
                ChainStepMode::Chat => {
                    let messages = vec![crate::chat::ChatMessage {
                        role: crate::chat::ChatRole::User,
                        message_type: crate::chat::MessageType::Text,
                        content: prompt,
                    }];
                    self.llm.chat(&messages).await?
                }
                ChainStepMode::Completion => {
                    let mut req = crate::completion::CompletionRequest::new(prompt);
                    req.max_tokens = step.max_tokens;
                    req.temperature = step.temperature;
                    let resp = self.llm.complete(&req).await?;
                    // Boxed so both arms produce the same type as `chat`'s return value.
                    Box::new(resp)
                }
            };
            self.memory
                .insert(step.id.clone(), response_text.text().unwrap_or_default());
        }
        Ok(self.memory)
    }

    /// Expands `{{name}}` placeholders in `input` using the outputs stored in `memory`.
    ///
    /// The template is scanned once, left to right:
    /// * Each `{{name}}` whose `name` is a key in `memory` is replaced by the stored output.
    /// * Anything else is copied through unchanged. This includes placeholders for steps
    ///   that have not run and unterminated `{{`.
    ///
    /// Substituted text is never rescanned. An earlier response that itself contains
    /// `{{...}}` is therefore inserted verbatim, and the result does not depend on
    /// `HashMap` iteration order.
    ///
    /// A placeholder ends at the first `}}` after its opening braces, so step ids that
    /// contain `}}` cannot be referenced.
    fn apply_template(&self, input: &str) -> String {
        let mut result = String::with_capacity(input.len());
        let mut rest = input;
        while let Some(open) = rest.find("{{") {
            result.push_str(&rest[..open]);
            let after_open = &rest[open + 2..];
            let hit = after_open.find("}}").and_then(|close| {
                self.memory
                    .get(&after_open[..close])
                    .map(|value| (close, value))
            });
            match hit {
                Some((close, value)) => {
                    result.push_str(value);
                    rest = &after_open[close + 2..];
                }
                None => {
                    // Not a known placeholder. Emit a single `{` and resume just after it,
                    // so a placeholder preceded by extra braces (e.g. `{{{id}}}`) still
                    // matches. `{` is one byte, so `open + 1` is a char boundary.
                    result.push('{');
                    rest = &rest[open + 1..];
                }
            }
        }
        result.push_str(rest);
        result
    }
}