langchain_rust/chain/chain_trait.rs
1use std::{collections::HashMap, pin::Pin};
2
3use async_trait::async_trait;
4use futures::Stream;
5use serde_json::{json, Value};
6
7use crate::{language_models::GenerateResult, prompt::PromptArgs, schemas::StreamData};
8
9use super::ChainError;
10
/// Default map key for the generated text. It is returned first by the default
/// `Chain::get_output_keys`, and `Chain::execute` falls back to it when the chain
/// reports no output keys.
pub(crate) const DEFAULT_OUTPUT_KEY: &str = "output";
/// Map key under which the default `Chain::execute` stores the full serialized
/// `GenerateResult` (generation plus extra information such as token usage).
pub(crate) const DEFAULT_RESULT_KEY: &str = "generate_result";
13
14#[async_trait]
15pub trait Chain: Sync + Send {
16 /// Call the `Chain` and receive as output the result of the generation process along with
17 /// additional information like token consumption. The input is a set of variables passed
18 /// as a `PromptArgs` hashmap.
19 ///
20 /// # Example
21 ///
22 /// ```rust,ignore
23 /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
24 /// # async {
25 /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
26 /// let memory = SimpleMemory::new();
27 ///
28 /// let chain = ConversationalChainBuilder::new()
29 /// .llm(llm)
30 /// .memory(memory.into())
31 /// .build().expect("Error building ConversationalChain");
32 ///
33 /// let input_variables = prompt_args! {
34 /// "input" => "Im from Peru",
35 /// };
36 ///
37 /// match chain.call(input_variables).await {
38 /// Ok(result) => {
39 /// println!("Result: {:?}", result);
40 /// },
41 /// Err(e) => panic!("Error calling Chain: {:?}", e),
42 /// };
43 /// # };
44 /// ```
45 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError>;
46
47 /// Invoke the `Chain` and receive just the generation result as a String.
48 /// The input is a set of variables passed as a `PromptArgs` hashmap.
49 ///
50 /// # Example
51 ///
52 /// ```rust,ignore
53 /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
54 /// # async {
55 /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
56 /// let memory = SimpleMemory::new();
57 ///
58 /// let chain = ConversationalChainBuilder::new()
59 /// .llm(llm)
60 /// .memory(memory.into())
61 /// .build().expect("Error building ConversationalChain");
62 ///
63 /// let input_variables = prompt_args! {
64 /// "input" => "Im from Peru",
65 /// };
66 ///
67 /// match chain.invoke(input_variables).await {
68 /// Ok(result) => {
69 /// println!("Result: {:?}", result);
70 /// },
71 /// Err(e) => panic!("Error invoking Chain: {:?}", e),
72 /// };
73 /// # };
74 /// ```
75 async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
76 self.call(input_variables)
77 .await
78 .map(|result| result.generation)
79 }
80
81 /// Execute the `Chain` and return the result of the generation process
82 /// along with additional information like token consumption formatted as a `HashMap`.
83 /// The input is a set of variables passed as a `PromptArgs` hashmap.
84 /// The key for the generated output is specified by the `get_output_keys`
85 /// method (default key is `output`).
86 ///
87 /// # Example
88 ///
89 /// ```rust,ignore
90 /// # use crate::my_crate::{Chain, ConversationalChainBuilder, OpenAI, OpenAIModel, SimpleMemory, PromptArgs, prompt_args};
91 /// # async {
92 /// let llm = OpenAI::default().with_model(OpenAIModel::Gpt35);
93 /// let memory = SimpleMemory::new();
94 ///
95 /// let chain = ConversationalChainBuilder::new()
96 /// .llm(llm)
97 /// .memory(memory.into())
98 /// .output_key("name")
99 /// .build().expect("Error building ConversationalChain");
100 ///
101 /// let input_variables = prompt_args! {
102 /// "input" => "Im from Peru",
103 /// };
104 ///
105 /// match chain.execute(input_variables).await {
106 /// Ok(result) => {
107 /// println!("Result: {:?}", result);
108 /// },
109 /// Err(e) => panic!("Error executing Chain: {:?}", e),
110 /// };
111 /// # };
112 /// ```
113 async fn execute(
114 &self,
115 input_variables: PromptArgs,
116 ) -> Result<HashMap<String, Value>, ChainError> {
117 log::info!("Using default implementation");
118 let result = self.call(input_variables.clone()).await?;
119 let mut output = HashMap::new();
120 let output_key = self
121 .get_output_keys()
122 .first()
123 .unwrap_or(&DEFAULT_OUTPUT_KEY.to_string())
124 .clone();
125 output.insert(output_key, json!(result.generation));
126 output.insert(DEFAULT_RESULT_KEY.to_string(), json!(result));
127 Ok(output)
128 }
129 /// Stream the `Chain` and get an asynchronous stream of chain generations.
130 /// The input is a set of variables passed as a `PromptArgs` hashmap.
131 /// If the chain have memroy, the tream method will not be able to automaticaly
132 /// set the memroy, bocause it will not know if the how to extract the output message
133 /// out of the stram
134 /// # Example
135 ///
136 /// ```rust,ignore
137 /// # use futures::StreamExt;
138 /// # use crate::my_crate::{Chain, LLMChainBuilder, OpenAI, fmt_message, fmt_template,
139 /// # HumanMessagePromptTemplate, prompt_args, Message, template_fstring};
140 /// # async {
141 /// let open_ai = OpenAI::default();
142 ///
143 ///let prompt = message_formatter![
144 ///fmt_message!(Message::new_system_message(
145 ///"You are world class technical documentation writer."
146 ///)),
147 ///fmt_template!(HumanMessagePromptTemplate::new(template_fstring!(
148 /// "{input}", "input"
149 ///)))
150 ///];
151 ///
152 /// let chain = LLMChainBuilder::new()
153 /// .prompt(prompt)
154 /// .llm(open_ai.clone())
155 /// .build()
156 /// .unwrap();
157 ///
158 /// let mut stream = chain.stream(
159 /// prompt_args! {
160 /// "input" => "Who is the writer of 20,000 Leagues Under the Sea?"
161 /// }).await.unwrap();
162 ///
163 /// while let Some(result) = stream.next().await {
164 /// match result {
165 /// Ok(value) => {
166 /// println!("Content: {}", value.content);
167 /// },
168 /// Err(e) => panic!("Error invoking LLMChain: {:?}", e),
169 /// }
170 /// };
171 /// # };
172 /// ```
173 ///
174 async fn stream(
175 &self,
176 _input_variables: PromptArgs,
177 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
178 {
179 log::warn!("stream not implemented for this chain");
180 unimplemented!()
181 }
182
183 // Get the input keys of the prompt
184 fn get_input_keys(&self) -> Vec<String> {
185 log::info!("Using default implementation");
186 vec![]
187 }
188
189 fn get_output_keys(&self) -> Vec<String> {
190 log::info!("Using default implementation");
191 vec![
192 String::from(DEFAULT_OUTPUT_KEY),
193 String::from(DEFAULT_RESULT_KEY),
194 ]
195 }
196}
197
198impl<C> From<C> for Box<dyn Chain>
199where
200 C: Chain + 'static,
201{
202 fn from(chain: C) -> Self {
203 Box::new(chain)
204 }
205}