// langchain_rust/chain/sequential/chain.rs
use std::collections::{HashMap, HashSet};
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5
6use crate::{
7 chain::{Chain, ChainError, DEFAULT_OUTPUT_KEY, DEFAULT_RESULT_KEY},
8 language_models::{GenerateResult, TokenUsage},
9 prompt::PromptArgs,
10};
11
/// Runs a pipeline of chains in order, feeding each chain's generation into
/// the input variables of the chains that follow it.
pub struct SequentialChain {
    // Chains executed in insertion order.
    pub(crate) chains: Vec<Box<dyn Chain>>,
    // Keys the caller must supply before the first chain runs.
    // NOTE(review): populated outside this file (builder/macro) — confirm there.
    pub(crate) input_keys: HashSet<String>,
    // Union of output keys produced across the pipeline.
    pub(crate) outputs: HashSet<String>,
}
18
19#[async_trait]
20impl Chain for SequentialChain {
21 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
22 let output = self.execute(input_variables).await?;
23 let result = output
24 .get(DEFAULT_RESULT_KEY)
25 .ok_or_else(|| ChainError::MissingInputVariable(DEFAULT_RESULT_KEY.to_string()))?
26 .clone();
27 let result: GenerateResult = serde_json::from_value(result)?;
28 Ok(result)
29 }
30 async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
31 self.call(input_variables.clone())
32 .await
33 .map(|result| result.generation)
34 }
35 fn get_input_keys(&self) -> Vec<String> {
36 self.outputs.iter().cloned().collect()
37 }
38
39 async fn execute(
40 &self,
41 input_variables: PromptArgs,
42 ) -> Result<HashMap<String, Value>, ChainError> {
43 let mut input_variables = input_variables;
44 let mut final_token_usage: Option<TokenUsage> = None;
45 let mut output_result = HashMap::new();
46 let mut final_result = GenerateResult::default();
47 for chain in self.chains.iter() {
48 let output = chain.execute(input_variables.clone()).await?;
49 let output_key = chain
51 .get_output_keys()
52 .first()
53 .unwrap_or(&DEFAULT_OUTPUT_KEY.to_string())
54 .clone();
55 let result = output
57 .get(DEFAULT_RESULT_KEY)
58 .unwrap_or(&json!(GenerateResult::default()))
59 .clone();
60 let result: GenerateResult = serde_json::from_value(result)?;
61 log::debug!("{}", result.generation);
62 output_result.insert(output_key.clone(), json!(result.generation.clone()));
64 input_variables.insert(output_key, json!(result.generation.clone()));
65
66 final_result.generation = result.generation;
68 if let Some(token) = &result.tokens {
70 match final_token_usage {
71 Some(token_usage) => {
72 final_token_usage = Some(token_usage.sum(token));
73 }
74 None => {
75 final_token_usage = Some(token.clone());
76 }
77 }
78 }
79 }
80
81 final_result.tokens = final_token_usage;
83 output_result.insert(DEFAULT_RESULT_KEY.to_string(), json!(final_result));
84 Ok(output_result)
85 }
86}
87
#[cfg(test)]
mod tests {
    use crate::{
        chain::{Chain, LLMChainBuilder},
        llm::openai::OpenAI,
        prompt_args, sequential_chain, template_fstring,
    };

    /// End-to-end run of two LLM chains piped together. Hits the OpenAI API,
    /// so it is `#[ignore]`d by default.
    #[tokio::test]
    #[ignore]
    async fn test_sequential() {
        let model = OpenAI::default();

        // First link: turns {input} into a store name published as "nombre".
        let name_chain = LLMChainBuilder::new()
            .prompt(template_fstring!(
                "dame un nombre para una tienda de {input}",
                "input"
            ))
            .llm(model.clone())
            .output_key("nombre")
            .build()
            .expect("Failed to build LLMChain");

        // Second link: consumes "nombre" plus the caller-supplied "palabra".
        let slogan_chain = LLMChainBuilder::new()
            .prompt(template_fstring!(
                "dame un slogan para una tienda llamada {nombre},tiene que incluir la palabra {palabra}",
                "nombre",
                "palabra"
            ))
            .llm(model.clone())
            .output_key("slogan")
            .build()
            .expect("Failed to build LLMChain");

        let pipeline = sequential_chain!(name_chain, slogan_chain);
        let outcome = pipeline
            .execute(prompt_args! {"input"=>"medias","palabra"=>"arroz"})
            .await;

        assert!(
            outcome.is_ok(),
            "Expected `chain.call` to succeed, but it failed with error: {:?}",
            outcome.err()
        );

        if let Ok(output) = outcome {
            println!("{:?}", output);
        }
    }
}
135}