use async_trait::async_trait;
use std::collections::HashMap;
use serde_json::Value;
use super::base::{BaseChain, ChainResult, ChainError};
use crate::language_models::OpenAIChat;
use crate::schema::Message;
use crate::Runnable;
/// A chain that renders a `{key}`-style prompt template from its inputs and
/// forwards the result to an OpenAI chat model.
///
/// The model's reply is returned in the result map under `output_key`.
pub struct LLMChain {
// Chat model used to execute the rendered prompt.
llm: OpenAIChat,
// Template text containing `{key}` placeholders filled at invoke time.
prompt_template: String,
// Name of the input this chain requires (defaults to "question").
input_key: String,
// Key under which the LLM reply is stored (defaults to "text").
output_key: String,
// Human-readable chain name (defaults to "llm_chain").
name: String,
}
impl LLMChain {
    /// Creates a chain with default keys: input `"question"`, output `"text"`,
    /// and name `"llm_chain"`.
    pub fn new(llm: OpenAIChat, prompt_template: impl Into<String>) -> Self {
        Self {
            llm,
            prompt_template: prompt_template.into(),
            input_key: "question".to_string(),
            output_key: "text".to_string(),
            name: "llm_chain".to_string(),
        }
    }

    /// Overrides the required input key.
    pub fn with_input_key(mut self, key: impl Into<String>) -> Self {
        self.input_key = key.into();
        self
    }

    /// Overrides the key under which the LLM reply is returned.
    pub fn with_output_key(mut self, key: impl Into<String>) -> Self {
        self.output_key = key.into();
        self
    }

    /// Overrides the chain's display name.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Renders the prompt template, substituting each `{key}` placeholder with
    /// the matching value from `inputs`.
    ///
    /// String values are inserted verbatim; other JSON values use their
    /// `to_string()` form. Placeholders with no matching input are kept as-is.
    ///
    /// Implementation note: this scans the template left-to-right in a single
    /// pass instead of repeated `str::replace` over a `HashMap` iteration.
    /// The replace-per-key approach was order-dependent (HashMap iteration
    /// order is unspecified) and allowed a substituted value that itself
    /// contained `{other_key}` to be re-substituted by a later iteration.
    fn render_prompt(&self, inputs: &HashMap<String, Value>) -> Result<String, ChainError> {
        let mut rendered = String::with_capacity(self.prompt_template.len());
        let mut rest = self.prompt_template.as_str();
        while let Some(open) = rest.find('{') {
            // Copy the literal text before the placeholder.
            rendered.push_str(&rest[..open]);
            let after_open = &rest[open + 1..];
            match after_open.find('}') {
                Some(close) => {
                    let key = &after_open[..close];
                    match inputs.get(key) {
                        Some(Value::String(s)) => rendered.push_str(s),
                        Some(other) => rendered.push_str(&other.to_string()),
                        // Unknown placeholder: keep it verbatim, matching the
                        // old behavior of replace() with an absent key.
                        None => rendered.push_str(&rest[open..open + 1 + close + 1]),
                    }
                    rest = &after_open[close + 1..];
                }
                None => {
                    // Unterminated '{': emit the remainder unchanged.
                    rendered.push_str(&rest[open..]);
                    rest = "";
                }
            }
        }
        rendered.push_str(rest);
        Ok(rendered)
    }
}
#[async_trait]
impl BaseChain for LLMChain {
    /// The single input key this chain consumes.
    fn input_keys(&self) -> Vec<&str> {
        vec![self.input_key.as_str()]
    }

    /// The single output key this chain produces.
    fn output_keys(&self) -> Vec<&str> {
        vec![self.output_key.as_str()]
    }

    /// Validates inputs, renders the prompt, sends it to the LLM as a single
    /// human message, and returns the reply under `output_key`.
    async fn invoke(&self, inputs: HashMap<String, Value>) -> Result<ChainResult, ChainError> {
        self.validate_inputs(&inputs)?;
        let rendered = self.render_prompt(&inputs)?;
        let response = self
            .llm
            .invoke(vec![Message::human(&rendered)], None)
            .await
            .map_err(|e| ChainError::ExecutionError(format!("LLM 调用失败: {}", e)))?;
        Ok(HashMap::from([(
            self.output_key.clone(),
            Value::String(response.content),
        )]))
    }

    /// The chain's display name.
    fn name(&self) -> &str {
        &self.name
    }
}
/// Fluent builder for [`LLMChain`]; unset options fall back to the chain's
/// defaults when [`LLMChainBuilder::build`] is called.
pub struct LLMChainBuilder {
// Chat model the built chain will use (required).
llm: OpenAIChat,
// Prompt template with `{key}` placeholders (required).
prompt_template: String,
// Optional override for the chain's input key.
input_key: Option<String>,
// Optional override for the chain's output key.
output_key: Option<String>,
// Optional override for the chain's display name.
name: Option<String>,
}
impl LLMChainBuilder {
    /// Starts a builder from the two required pieces: the model and the
    /// prompt template. All other settings begin unset.
    pub fn new(llm: OpenAIChat, prompt_template: impl Into<String>) -> Self {
        Self {
            llm,
            prompt_template: prompt_template.into(),
            input_key: None,
            output_key: None,
            name: None,
        }
    }

    /// Sets the input key the built chain will require.
    pub fn input_key(mut self, key: impl Into<String>) -> Self {
        self.input_key = Some(key.into());
        self
    }

    /// Sets the key the built chain will return its output under.
    pub fn output_key(mut self, key: impl Into<String>) -> Self {
        self.output_key = Some(key.into());
        self
    }

    /// Sets the built chain's display name.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Consumes the builder and produces the configured [`LLMChain`].
    /// Options left unset keep the chain's defaults.
    pub fn build(self) -> LLMChain {
        let chain = LLMChain::new(self.llm, self.prompt_template);
        let chain = match self.input_key {
            Some(key) => chain.with_input_key(key),
            None => chain,
        };
        let chain = match self.output_key {
            Some(key) => chain.with_output_key(key),
            None => chain,
        };
        match self.name {
            Some(name) => chain.with_name(name),
            None => chain,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenAIConfig;

    /// Builds a test configuration.
    ///
    /// SECURITY: the API key was previously hard-coded here, committing a
    /// live-looking credential to source control. It is now read from the
    /// `OPENAI_API_KEY` environment variable (empty fallback), so the
    /// offline `render_prompt` tests still run without any secret and the
    /// `#[ignore]`d network tests work when the variable is exported.
    fn create_test_config() -> OpenAIConfig {
        OpenAIConfig {
            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            base_url: std::env::var("OPENAI_BASE_URL")
                .unwrap_or_else(|_| "https://api.openai-proxy.org/v1".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            streaming: false,
            organization: None,
            frequency_penalty: None,
            max_tokens: None,
            presence_penalty: None,
            temperature: None,
            top_p: None,
            tools: None,
            tool_choice: None,
        }
    }

    // Offline: a single string placeholder is substituted verbatim.
    #[test]
    fn test_render_prompt() {
        let llm = OpenAIChat::new(create_test_config());
        let chain = LLMChain::new(llm, "问题: {question}");
        let inputs = HashMap::from([
            ("question".to_string(), Value::String("什么是Rust?".to_string()))
        ]);
        let prompt = chain.render_prompt(&inputs).unwrap();
        assert_eq!(prompt, "问题: 什么是Rust?");
    }

    // Offline: multiple placeholders; non-string values use their JSON form.
    #[test]
    fn test_render_prompt_multiple_vars() {
        let llm = OpenAIChat::new(create_test_config());
        let chain = LLMChain::new(llm, "名字: {name}, 年龄: {age}");
        let inputs = HashMap::from([
            ("name".to_string(), Value::String("张三".to_string())),
            ("age".to_string(), Value::Number(25.into())),
        ]);
        let prompt = chain.render_prompt(&inputs).unwrap();
        assert_eq!(prompt, "名字: 张三, 年龄: 25");
    }

    // Network: requires OPENAI_API_KEY; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_llm_chain_simple() {
        let llm = OpenAIChat::new(create_test_config());
        let chain = LLMChain::new(llm, "请用一句话回答: {question}");
        let inputs = HashMap::from([
            ("question".to_string(), Value::String("什么是 Rust?".to_string()))
        ]);
        println!("\n=== 测试 LLMChain - 简单问题 ===");
        let result = chain.invoke(inputs).await.unwrap();
        println!("输出: {:?}", result);
        assert!(result.contains_key("text"));
        assert!(!result.get("text").unwrap().as_str().unwrap().is_empty());
    }

    // Network: multi-variable template round trip through the LLM.
    #[tokio::test]
    #[ignore]
    async fn test_llm_chain_template() {
        let llm = OpenAIChat::new(create_test_config());
        let chain = LLMChain::new(llm,
            "请用{style}的语气回答问题: {question}"
        );
        let inputs = HashMap::from([
            ("style".to_string(), Value::String("幽默".to_string())),
            ("question".to_string(), Value::String("什么是编程?".to_string()))
        ]);
        println!("\n=== 测试 LLMChain - 多变量模板 ===");
        let result = chain.invoke(inputs).await.unwrap();
        println!("输出: {:?}", result);
        assert!(result.contains_key("text"));
    }

    // Network: builder-configured chain returns under the custom output key.
    #[tokio::test]
    #[ignore]
    async fn test_llm_chain_builder() {
        let llm = OpenAIChat::new(create_test_config());
        let chain = LLMChainBuilder::new(llm, "翻译以下内容到{language}: {text}")
            .input_key("text")
            .output_key("translation")
            .name("translator")
            .build();
        let inputs = HashMap::from([
            ("language".to_string(), Value::String("英文".to_string())),
            ("text".to_string(), Value::String("你好,世界".to_string()))
        ]);
        println!("\n=== 测试 LLMChain - Builder ===");
        let result = chain.invoke(inputs).await.unwrap();
        println!("输出: {:?}", result);
        assert!(result.contains_key("translation"));
    }
}