use async_trait::async_trait;
use std::collections::HashMap;
use serde_json::Value;
use std::sync::Arc;
use super::base::{BaseChain, ChainResult, ChainError};
/// A pipeline that runs a list of chains one after another.
///
/// The steps share a single key/value state. Each step reads its inputs from
/// that state and publishes its outputs back into it, so later steps can
/// consume what earlier steps produced.
pub struct SequentialChain {
    /// Identifier reported by `BaseChain::name`; defaults to `"sequential_chain"`.
    name: String,
    /// Steps in execution order.
    chains: Vec<ChainStep>,
}
/// One stage of a `SequentialChain`.
///
/// Holds the chain to run together with the key translations between the
/// pipeline's shared state and the chain's own input/output names.
struct ChainStep {
    /// Chain-local input key (map key) -> shared-state key it is read from (map value).
    input_mapping: HashMap<String, String>,
    /// Chain output key (map key) -> shared-state key it is published under (map value).
    output_mapping: HashMap<String, String>,
    /// The chain executed for this step.
    chain: Arc<dyn BaseChain>,
}
impl SequentialChain {
    /// Creates an empty pipeline named `"sequential_chain"`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            chains: Vec::new(),
            name: "sequential_chain".to_string(),
        }
    }

    /// Sets the name reported by `BaseChain::name`, consuming and returning the builder.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Appends a step whose chain uses the same key names as the shared state.
    ///
    /// Each key in `input_keys` is read from the shared state under the same
    /// name, and each key in `output_keys` is written back under the same name.
    /// This is [`add_chain_with_mapping`](Self::add_chain_with_mapping) with
    /// identity mappings; duplicate keys collapse into a single entry.
    #[must_use]
    pub fn add_chain(
        self,
        chain: Arc<dyn BaseChain>,
        input_keys: Vec<&str>,
        output_keys: Vec<&str>,
    ) -> Self {
        let input_mapping = Self::identity_mapping(input_keys);
        let output_mapping = Self::identity_mapping(output_keys);
        self.add_chain_with_mapping(chain, input_mapping, output_mapping)
    }

    /// Appends a step with explicit key translations.
    ///
    /// * `input_mapping` maps each chain-local input key to the shared-state
    ///   key its value is read from.
    /// * `output_mapping` maps each chain output key to the shared-state key
    ///   it is published under.
    #[must_use]
    pub fn add_chain_with_mapping(
        mut self,
        chain: Arc<dyn BaseChain>,
        input_mapping: HashMap<String, String>,
        output_mapping: HashMap<String, String>,
    ) -> Self {
        self.chains.push(ChainStep {
            chain,
            input_mapping,
            output_mapping,
        });
        self
    }

    /// Builds a mapping in which every key translates to itself.
    fn identity_mapping(keys: Vec<&str>) -> HashMap<String, String> {
        keys.into_iter()
            .map(|key| (key.to_string(), key.to_string()))
            .collect()
    }
}
impl Default for SequentialChain {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl BaseChain for SequentialChain {
    /// Shared-state keys the caller must supply to `invoke`.
    ///
    /// Walks the steps in order and collects every key a step reads that no
    /// earlier step writes. Inputs needed only by later steps are therefore
    /// reported too, not just the first step's. Keys are deduplicated. Within
    /// a step they are sorted so that the result is deterministic.
    fn input_keys(&self) -> Vec<&str> {
        use std::collections::HashSet;

        let mut produced: HashSet<&str> = HashSet::new();
        let mut seen: HashSet<&str> = HashSet::new();
        let mut required = Vec::new();
        for step in &self.chains {
            let mut reads: Vec<&str> = step.input_mapping.values().map(String::as_str).collect();
            reads.sort_unstable();
            // A key is external only if no *earlier* step writes it, so check
            // reads before recording this step's own writes.
            required.extend(
                reads
                    .into_iter()
                    .filter(|key| !produced.contains(key) && seen.insert(*key)),
            );
            produced.extend(step.output_mapping.values().map(String::as_str));
        }
        required
    }

    /// Shared-state keys that `invoke` may return.
    ///
    /// `invoke` publishes the mapped outputs of *every* step, not only the
    /// last one, so all of those keys are listed here. Keys are listed in step
    /// order, deduplicated, and sorted within each step.
    fn output_keys(&self) -> Vec<&str> {
        use std::collections::HashSet;

        let mut seen: HashSet<&str> = HashSet::new();
        let mut keys = Vec::new();
        for step in &self.chains {
            let mut writes: Vec<&str> = step.output_mapping.values().map(String::as_str).collect();
            writes.sort_unstable();
            keys.extend(writes.into_iter().filter(|key| seen.insert(*key)));
        }
        keys
    }

    /// Runs every step in order over a shared state seeded with `inputs`.
    ///
    /// Each step receives only the keys in its input mapping, renamed to the
    /// chain's local names. The step's mapped outputs are written back into
    /// the shared state and into the returned map. A later step that writes
    /// the same shared key overwrites the earlier value. The returned map
    /// contains step outputs only; the caller's original inputs are not
    /// echoed back.
    ///
    /// # Errors
    ///
    /// * `ChainError::MissingInput` if a key a step reads is absent from the
    ///   shared state when that step runs.
    /// * `ChainError::ExecutionError` wrapping any error returned by a step.
    ///   The message includes the step's index and name.
    async fn invoke(&self, inputs: HashMap<String, Value>) -> Result<ChainResult, ChainError> {
        // `inputs` is owned, so it can seed the shared state without a copy.
        let mut current_state = inputs;
        let mut final_output = HashMap::new();

        for (step_index, step) in self.chains.iter().enumerate() {
            let chain_inputs = step
                .input_mapping
                .iter()
                .map(|(chain_key, global_key)| {
                    current_state
                        .get(global_key)
                        .map(|value| (chain_key.clone(), value.clone()))
                        .ok_or_else(|| {
                            ChainError::MissingInput(format!(
                                "Step {}: 缺少输入 '{}' (映射自 '{}')",
                                step_index, chain_key, global_key
                            ))
                        })
                })
                .collect::<Result<HashMap<_, _>, _>>()?;

            let chain_output = step.chain.invoke(chain_inputs).await.map_err(|e| {
                ChainError::ExecutionError(format!(
                    "Step {} ({}) 执行失败: {}",
                    step_index,
                    step.chain.name(),
                    e
                ))
            })?;

            // Declared outputs the chain did not actually produce are skipped
            // rather than reported. A later step that needs one of them fails
            // with `MissingInput`.
            for (chain_key, global_key) in &step.output_mapping {
                if let Some(value) = chain_output.get(chain_key) {
                    current_state.insert(global_key.clone(), value.clone());
                    final_output.insert(global_key.clone(), value.clone());
                }
            }
        }
        Ok(final_output)
    }

    /// Returns the name set via `with_name` (`"sequential_chain"` by default).
    fn name(&self) -> &str {
        &self.name
    }
}
/// Prints the pipeline as its step count and name. The wrapped chains
/// themselves are not printed.
impl std::fmt::Debug for SequentialChain {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let step_count = self.chains.len();
        let mut out = f.debug_struct("SequentialChain");
        out.field("steps", &step_count);
        out.field("name", &self.name);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LLMChain, OpenAIConfig, OpenAIChat};

    /// Builds an `OpenAIConfig` for the `#[ignore]`d integration tests.
    ///
    /// Credentials are read from the environment instead of being committed to
    /// source control. `OPENAI_API_KEY` is required; `OPENAI_BASE_URL`
    /// optionally overrides the default proxy endpoint.
    ///
    /// # Panics
    ///
    /// Panics with an explanatory message if `OPENAI_API_KEY` is not set.
    fn create_test_config() -> OpenAIConfig {
        let api_key = std::env::var("OPENAI_API_KEY")
            .expect("set OPENAI_API_KEY to run the ignored SequentialChain integration tests");
        let base_url = std::env::var("OPENAI_BASE_URL")
            .unwrap_or_else(|_| "https://api.openai-proxy.org/v1".to_string());
        OpenAIConfig {
            api_key,
            base_url,
            model: "gpt-3.5-turbo".to_string(),
            streaming: false,
            organization: None,
            frequency_penalty: None,
            max_tokens: None,
            presence_penalty: None,
            temperature: None,
            top_p: None,
            tools: None,
            tool_choice: None,
        }
    }

    /// Deterministic, network-free chain. It reads the string at `input_key`,
    /// applies `transform`, and writes the result to `output_key`.
    struct MockChain {
        name: String,
        input_key: String,
        output_key: String,
        transform: fn(&str) -> String,
    }

    #[async_trait]
    impl BaseChain for MockChain {
        fn input_keys(&self) -> Vec<&str> {
            vec![self.input_key.as_str()]
        }

        fn output_keys(&self) -> Vec<&str> {
            vec![self.output_key.as_str()]
        }

        async fn invoke(&self, inputs: HashMap<String, Value>) -> Result<ChainResult, ChainError> {
            let input = inputs
                .get(&self.input_key)
                .and_then(|v| v.as_str())
                .ok_or_else(|| ChainError::MissingInput(self.input_key.clone()))?;
            let output = (self.transform)(input);
            let mut result = HashMap::new();
            result.insert(self.output_key.clone(), Value::String(output));
            Ok(result)
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Wraps a new `MockChain` in an `Arc`, ready to be added to a pipeline.
    fn mock(
        name: &str,
        input_key: &str,
        output_key: &str,
        transform: fn(&str) -> String,
    ) -> Arc<MockChain> {
        Arc::new(MockChain {
            name: name.to_string(),
            input_key: input_key.to_string(),
            output_key: output_key.to_string(),
            transform,
        })
    }

    /// Builds a one-entry input map holding a string value.
    fn text_input(key: &str, value: &str) -> HashMap<String, Value> {
        HashMap::from([(key.to_string(), Value::String(value.to_string()))])
    }

    #[tokio::test]
    async fn test_sequential_chain_mock() {
        let chain1 = mock("uppercase", "text", "upper", |s| s.to_uppercase());
        let chain2 = mock("reverse", "upper", "result", |s| s.chars().rev().collect());
        let seq_chain = SequentialChain::new()
            .add_chain(chain1, vec!["text"], vec!["upper"])
            .add_chain(chain2, vec!["upper"], vec!["result"]);
        let result = seq_chain.invoke(text_input("text", "hello")).await.unwrap();
        assert_eq!(result.get("result").unwrap().as_str().unwrap(), "OLLEH");
    }

    #[tokio::test]
    async fn test_sequential_chain_with_mapping() {
        let reverse = mock("reverse", "in", "out", |s| s.chars().rev().collect());
        let seq_chain = SequentialChain::new().add_chain_with_mapping(
            reverse,
            HashMap::from([("in".to_string(), "text".to_string())]),
            HashMap::from([("out".to_string(), "reversed".to_string())]),
        );
        let result = seq_chain.invoke(text_input("text", "abc")).await.unwrap();
        assert_eq!(result.get("reversed").and_then(Value::as_str), Some("cba"));
        // Chain-local names never leak into the shared output.
        assert!(!result.contains_key("out"));
    }

    #[tokio::test]
    async fn test_sequential_chain_missing_input() {
        let seq_chain = SequentialChain::new().add_chain(
            mock("uppercase", "text", "upper", |s| s.to_uppercase()),
            vec!["text"],
            vec!["upper"],
        );
        let err = seq_chain.invoke(HashMap::new()).await.unwrap_err();
        assert!(matches!(err, ChainError::MissingInput(_)));
    }

    #[tokio::test]
    async fn test_sequential_chain_empty_returns_empty_output() {
        let seq_chain = SequentialChain::new();
        let result = seq_chain.invoke(text_input("text", "hello")).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_sequential_chain_name_and_debug() {
        let seq_chain = SequentialChain::new().with_name("pipeline");
        assert_eq!(seq_chain.name(), "pipeline");
        assert_eq!(
            format!("{:?}", seq_chain),
            r#"SequentialChain { steps: 0, name: "pipeline" }"#
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_sequential_chain_real() {
        let llm1 = OpenAIChat::new(create_test_config());
        let llm2 = OpenAIChat::new(create_test_config());
        let chain1 = Arc::new(
            LLMChain::new(llm1, "只回复一个与'{topic}'相关的词语,不要其他内容")
                .with_input_key("topic")
                .with_output_key("word"),
        );
        let chain2 = Arc::new(
            LLMChain::new(llm2, "用词语'{word}'造一个简单的句子")
                .with_input_key("word")
                .with_output_key("sentence"),
        );
        let seq_chain = SequentialChain::new()
            .add_chain(chain1, vec!["topic"], vec!["word"])
            .add_chain(chain2, vec!["word"], vec!["sentence"]);
        let inputs = HashMap::from([
            ("topic".to_string(), Value::String("编程".to_string()))
        ]);
        println!("\n=== 测试 SequentialChain - 两步 LLM ===");
        let result = seq_chain.invoke(inputs).await.unwrap();
        println!("生成的词: {:?}", result.get("word"));
        println!("造句: {:?}", result.get("sentence"));
        assert!(result.contains_key("word"));
        assert!(result.contains_key("sentence"));
    }

    #[tokio::test]
    #[ignore]
    async fn test_sequential_chain_three_steps() {
        let llm1 = OpenAIChat::new(create_test_config());
        let llm2 = OpenAIChat::new(create_test_config());
        let llm3 = OpenAIChat::new(create_test_config());
        let chain1 = Arc::new(
            LLMChain::new(llm1, "从以下文本提取一个关键词,只回复关键词: {text}")
                .with_input_key("text")
                .with_output_key("keyword"),
        );
        let chain2 = Arc::new(
            LLMChain::new(llm2, "用一句话解释'{keyword}'是什么")
                .with_input_key("keyword")
                .with_output_key("explanation"),
        );
        let chain3 = Arc::new(
            LLMChain::new(llm3, "为'{keyword}'生成一个简单示例")
                .with_input_key("keyword")
                .with_output_key("example"),
        );
        let seq_chain = SequentialChain::new()
            .add_chain(chain1, vec!["text"], vec!["keyword"])
            .add_chain(chain2, vec!["keyword"], vec!["explanation"])
            .add_chain(chain3, vec!["keyword"], vec!["example"]);
        let inputs = HashMap::from([
            ("text".to_string(), Value::String("Rust是一门系统编程语言,注重安全和性能".to_string()))
        ]);
        println!("\n=== 测试 SequentialChain - 三步管道 ===");
        let result = seq_chain.invoke(inputs).await.unwrap();
        println!("关键词: {:?}", result.get("keyword"));
        println!("解释: {:?}", result.get("explanation"));
        println!("示例: {:?}", result.get("example"));
        assert!(result.contains_key("keyword"));
        assert!(result.contains_key("explanation"));
        assert!(result.contains_key("example"));
    }
}