langchain_rust/chain/sequential/chain.rs

1use std::collections::{HashMap, HashSet};
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5
6use crate::{
7    chain::{Chain, ChainError, DEFAULT_OUTPUT_KEY, DEFAULT_RESULT_KEY},
8    language_models::{GenerateResult, TokenUsage},
9    prompt::PromptArgs,
10};
11
12//THIS IS EXPERIMENTAL
13pub struct SequentialChain {
14    pub(crate) chains: Vec<Box<dyn Chain>>,
15    pub(crate) input_keys: HashSet<String>,
16    pub(crate) outputs: HashSet<String>,
17}
18
19#[async_trait]
20impl Chain for SequentialChain {
21    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
22        let output = self.execute(input_variables).await?;
23        let result = output
24            .get(DEFAULT_RESULT_KEY)
25            .ok_or_else(|| ChainError::MissingInputVariable(DEFAULT_RESULT_KEY.to_string()))?
26            .clone();
27        let result: GenerateResult = serde_json::from_value(result)?;
28        Ok(result)
29    }
30    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
31        self.call(input_variables.clone())
32            .await
33            .map(|result| result.generation)
34    }
35    fn get_input_keys(&self) -> Vec<String> {
36        self.outputs.iter().cloned().collect()
37    }
38
39    async fn execute(
40        &self,
41        input_variables: PromptArgs,
42    ) -> Result<HashMap<String, Value>, ChainError> {
43        let mut input_variables = input_variables;
44        let mut final_token_usage: Option<TokenUsage> = None;
45        let mut output_result = HashMap::new();
46        let mut final_result = GenerateResult::default();
47        for chain in self.chains.iter() {
48            let output = chain.execute(input_variables.clone()).await?;
49            //Get the oput key for the chain result
50            let output_key = chain
51                .get_output_keys()
52                .first()
53                .unwrap_or(&DEFAULT_OUTPUT_KEY.to_string())
54                .clone();
55            //Get the ouput complete result
56            let result = output
57                .get(DEFAULT_RESULT_KEY)
58                .unwrap_or(&json!(GenerateResult::default()))
59                .clone();
60            let result: GenerateResult = serde_json::from_value(result)?;
61            log::debug!("{}", result.generation);
62            //Insert the output chain to the final output
63            output_result.insert(output_key.clone(), json!(result.generation.clone()));
64            input_variables.insert(output_key, json!(result.generation.clone()));
65
66            //add the generation to keep track of the final generation
67            final_result.generation = result.generation;
68            //Add to the token if it exist
69            if let Some(token) = &result.tokens {
70                match final_token_usage {
71                    Some(token_usage) => {
72                        final_token_usage = Some(token_usage.sum(token));
73                    }
74                    None => {
75                        final_token_usage = Some(token.clone());
76                    }
77                }
78            }
79        }
80
81        //add the filan token count to the result
82        final_result.tokens = final_token_usage;
83        output_result.insert(DEFAULT_RESULT_KEY.to_string(), json!(final_result));
84        Ok(output_result)
85    }
86}
87
#[cfg(test)]
mod tests {
    use crate::{
        chain::{Chain, LLMChainBuilder},
        llm::openai::OpenAI,
        prompt_args, sequential_chain, template_fstring,
    };

    /// End-to-end check of a two-step sequential chain. The first chain
    /// produces `nombre`, and the second chain's prompt consumes it.
    ///
    /// Ignored by default: it builds an `OpenAI` client and runs real LLM
    /// chains. It presumably needs network access and API credentials.
    #[tokio::test]
    #[ignore]
    async fn test_sequential() {
        let llm = OpenAI::default();
        let chain1 = LLMChainBuilder::new()
            .prompt(template_fstring!(
                "dame un nombre para una tienda de {input}",
                "input"
            ))
            .llm(llm.clone())
            .output_key("nombre")
            .build()
            .expect("Failed to build LLMChain");

        let chain2 = LLMChainBuilder::new()
            .prompt(template_fstring!(
                "dame un slogan para una tienda llamada {nombre},tiene que incluir la palabra {palabra}",
                "nombre",
                "palabra"
            ))
            .llm(llm.clone())
            .output_key("slogan")
            .build()
            .expect("Failed to build LLMChain");

        let chain = sequential_chain!(chain1, chain2);
        // The message names the method actually under test (`execute`).
        let output = chain
            .execute(prompt_args! {"input"=>"medias","palabra"=>"arroz"})
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "Expected `chain.execute` to succeed, but it failed with error: {:?}",
                    err
                )
            });

        println!("{:?}", output);
    }
}
135}