langchain_rust/chain/
chain_trait.rs

use std::{collections::HashMap, pin::Pin};

use async_trait::async_trait;
use futures::Stream;
use serde_json::{json, Value};

use crate::{language_models::GenerateResult, prompt::PromptArgs, schemas::StreamData};

use super::ChainError;

/// Default key under which `execute` stores the plain generation text.
pub(crate) const DEFAULT_OUTPUT_KEY: &str = "output";
/// Default key under which `execute` stores the full `GenerateResult`.
pub(crate) const DEFAULT_RESULT_KEY: &str = "generate_result";

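/// The core abstraction for executable chains. Implementors only need to
/// provide [`call`](Chain::call); `invoke`, `execute`, and the key accessors
/// have default implementations built on top of it (`stream`'s default panics).
///
/// A minimal sketch of a custom implementation; `EchoChain` is a hypothetical
/// type used purely for illustration, and it assumes `GenerateResult`
/// implements `Default`:
///
/// ```rust,ignore
/// # use async_trait::async_trait;
/// # use crate::my_crate::{Chain, ChainError, GenerateResult, PromptArgs};
/// struct EchoChain;
///
/// #[async_trait]
/// impl Chain for EchoChain {
///     async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
///         // Echo the "input" variable back as the generation.
///         let mut result = GenerateResult::default();
///         result.generation = format!("{:?}", input_variables.get("input"));
///         Ok(result)
///     }
/// }
/// ```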
#[async_trait]
pub trait Chain: Sync + Send {
    /// Call the `Chain` and receive as output the result of the generation process along with
    /// additional information like token consumption. The input is a set of variables passed
    /// as a `PromptArgs` hashmap.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
    /// # async {
    /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
    /// let memory = SimpleMemory::new();
    ///
    /// let chain = ConversationalChainBuilder::new()
    ///     .llm(llm)
    ///     .memory(memory.into())
    ///     .build().expect("Error building ConversationalChain");
    ///
    /// let input_variables = prompt_args! {
    ///     "input" => "I'm from Peru",
    /// };
    ///
    /// match chain.call(input_variables).await {
    ///     Ok(result) => {
    ///         println!("Result: {:?}", result);
    ///     },
    ///     Err(e) => panic!("Error calling Chain: {:?}", e),
    /// };
    /// # };
    /// ```
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError>;

    /// Invoke the `Chain` and receive just the generation result as a `String`.
    /// The input is a set of variables passed as a `PromptArgs` hashmap.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
    /// # async {
    /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
    /// let memory = SimpleMemory::new();
    ///
    /// let chain = ConversationalChainBuilder::new()
    ///     .llm(llm)
    ///     .memory(memory.into())
    ///     .build().expect("Error building ConversationalChain");
    ///
    /// let input_variables = prompt_args! {
    ///     "input" => "I'm from Peru",
    /// };
    ///
    /// match chain.invoke(input_variables).await {
    ///     Ok(result) => {
    ///         println!("Result: {:?}", result);
    ///     },
    ///     Err(e) => panic!("Error invoking Chain: {:?}", e),
    /// };
    /// # };
    /// ```
    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
        self.call(input_variables)
            .await
            .map(|result| result.generation)
    }

    /// Execute the `Chain` and return the result of the generation process
    /// along with additional information like token consumption, formatted as a `HashMap`.
    /// The input is a set of variables passed as a `PromptArgs` hashmap.
    /// The key for the generated output is specified by the `get_output_keys`
    /// method (default key is `output`).
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
    /// # async {
    /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
    /// let memory = SimpleMemory::new();
    ///
    /// let chain = ConversationalChainBuilder::new()
    ///     .llm(llm)
    ///     .memory(memory.into())
    ///     .output_key("name")
    ///     .build().expect("Error building ConversationalChain");
    ///
    /// let input_variables = prompt_args! {
    ///     "input" => "I'm from Peru",
    /// };
    ///
    /// match chain.execute(input_variables).await {
    ///     Ok(result) => {
    ///         println!("Result: {:?}", result);
    ///     },
    ///     Err(e) => panic!("Error executing Chain: {:?}", e),
    /// };
    /// # };
    /// ```
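    ///
    /// With the configuration above, the returned map stores the generation text
    /// under the configured output key (`"name"` here; `"output"` by default) and
    /// the full `GenerateResult` under the `"generate_result"` key.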
    async fn execute(
        &self,
        input_variables: PromptArgs,
    ) -> Result<HashMap<String, Value>, ChainError> {
        log::info!("Using default implementation");
        let result = self.call(input_variables.clone()).await?;
        let mut output = HashMap::new();
        let output_key = self
            .get_output_keys()
            .first()
            .unwrap_or(&DEFAULT_OUTPUT_KEY.to_string())
            .clone();
        output.insert(output_key, json!(result.generation));
        output.insert(DEFAULT_RESULT_KEY.to_string(), json!(result));
        Ok(output)
    }

    /// Stream the `Chain` and get an asynchronous stream of chain generations.
    /// The input is a set of variables passed as a `PromptArgs` hashmap.
    /// If the chain has memory, the `stream` method will not be able to update
    /// the memory automatically, because it cannot know how to extract the
    /// output message from the stream.
    ///
    /// The default implementation panics via `unimplemented!()`; chains that
    /// support streaming must override it.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// # use futures::StreamExt;
    /// # use crate::my_crate::{Chain, LLMChainBuilder, OpenAI, fmt_message, fmt_template,
    /// #                      message_formatter, HumanMessagePromptTemplate, prompt_args,
    /// #                      Message, template_fstring};
    /// # async {
    /// let open_ai = OpenAI::default();
    ///
    /// let prompt = message_formatter![
    ///     fmt_message!(Message::new_system_message(
    ///         "You are a world class technical documentation writer."
    ///     )),
    ///     fmt_template!(HumanMessagePromptTemplate::new(template_fstring!(
    ///         "{input}", "input"
    ///     )))
    /// ];
    ///
    /// let chain = LLMChainBuilder::new()
    ///     .prompt(prompt)
    ///     .llm(open_ai.clone())
    ///     .build()
    ///     .unwrap();
    ///
    /// let mut stream = chain
    ///     .stream(prompt_args! {
    ///         "input" => "Who is the writer of 20,000 Leagues Under the Sea?",
    ///     })
    ///     .await
    ///     .unwrap();
    ///
    /// while let Some(result) = stream.next().await {
    ///     match result {
    ///         Ok(value) => {
    ///             println!("Content: {}", value.content);
    ///         },
    ///         Err(e) => panic!("Error invoking LLMChain: {:?}", e),
    ///     }
    /// }
    /// # };
    /// ```
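    ///
    /// A sketch of an override that adapts `call` into a single-item stream
    /// (illustrative only, and it assumes a `StreamData::new(value, content)`
    /// constructor; real streaming chains would forward the LLM's token stream):
    ///
    /// ```rust,ignore
    /// async fn stream(
    ///     &self,
    ///     input_variables: PromptArgs,
    /// ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError> {
    ///     let result = self.call(input_variables).await?;
    ///     let content = result.generation.clone();
    ///     // Emit the whole generation as one StreamData item.
    ///     let item = StreamData::new(serde_json::json!(content), content.clone());
    ///     let items: Vec<Result<StreamData, ChainError>> = vec![Ok(item)];
    ///     Ok(Box::pin(futures::stream::iter(items)))
    /// }
    /// ```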
    async fn stream(
        &self,
        _input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        log::warn!("stream not implemented for this chain");
        unimplemented!()
    }

    /// Returns the input keys expected by the chain's prompt.
    fn get_input_keys(&self) -> Vec<String> {
        log::info!("Using default implementation");
        vec![]
    }

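    /// Returns the output keys of the chain. The default `execute` stores the
    /// generation under the first key returned here (falling back to `"output"`).
    /// A sketch of an override; the `"answer"` key is purely illustrative:
    ///
    /// ```rust,ignore
    /// fn get_output_keys(&self) -> Vec<String> {
    ///     // `execute` will place the generation under this key.
    ///     vec!["answer".to_string()]
    /// }
    /// ```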
    fn get_output_keys(&self) -> Vec<String> {
        log::info!("Using default implementation");
        vec![
            String::from(DEFAULT_OUTPUT_KEY),
            String::from(DEFAULT_RESULT_KEY),
        ]
    }
}

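/// Any concrete `Chain` can be boxed into a trait object via `.into()`, e.g.
/// `let boxed: Box<dyn Chain> = my_chain.into();`, where `my_chain` is any
/// value implementing `Chain`.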
impl<C> From<C> for Box<dyn Chain>
where
    C: Chain + 'static,
{
    fn from(chain: C) -> Self {
        Box::new(chain)
    }
}