langchain_rust/prompt/
prompt.rs

1use crate::schemas::{messages::Message, prompt::PromptValue};
2
3use super::{FormatPrompter, PromptArgs, PromptError, PromptFromatter};
4
/// Placeholder syntax understood by a `PromptTemplate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TemplateFormat {
    /// Single-brace placeholders: `{name}`.
    FString,
    /// Double-brace placeholders: `{{name}}`.
    Jinja2,
}

/// A prompt template: raw text with placeholders, the variable names that
/// must be supplied when it is formatted, and the placeholder syntax it uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptTemplate {
    /// Raw template text, including its placeholders.
    template: String,
    /// Names of the variables that must be supplied when formatting.
    variables: Vec<String>,
    /// Placeholder syntax used inside `template`.
    format: TemplateFormat,
}

impl PromptTemplate {
    /// Creates a template from its raw text, the names of the variables it
    /// requires, and the placeholder syntax it uses.
    ///
    /// The arguments are stored as given; no check is made here that each
    /// variable actually appears in `template`.
    pub fn new(template: String, variables: Vec<String>, format: TemplateFormat) -> Self {
        Self {
            template,
            variables,
            format,
        }
    }
}
27
28//PromptTemplate will be default transformed to an Human Input when used as FromatPrompter
29impl FormatPrompter for PromptTemplate {
30    fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
31        let messages = vec![Message::new_human_message(self.format(input_variables)?)];
32        Ok(PromptValue::from_messages(messages))
33    }
34    fn get_input_variables(&self) -> Vec<String> {
35        self.variables.clone()
36    }
37}
38
39impl PromptFromatter for PromptTemplate {
40    fn template(&self) -> String {
41        self.template.clone()
42    }
43
44    fn variables(&self) -> Vec<String> {
45        self.variables.clone()
46    }
47
48    fn format(&self, input_variables: PromptArgs) -> Result<String, PromptError> {
49        let mut prompt = self.template();
50
51        // check if all variables are in the input variables
52        for key in self.variables() {
53            if !input_variables.contains_key(key.as_str()) {
54                return Err(PromptError::MissingVariable(key));
55            }
56        }
57
58        for (key, value) in input_variables {
59            let key = match self.format {
60                TemplateFormat::FString => format!("{{{}}}", key),
61                TemplateFormat::Jinja2 => format!("{{{{{}}}}}", key),
62            };
63            let value_str = match &value {
64                serde_json::Value::String(s) => s.clone(),
65                _ => value.to_string(),
66            };
67            prompt = prompt.replace(&key, &value_str);
68        }
69
70        log::debug!("Formatted prompt: {}", prompt);
71        Ok(prompt)
72    }
73}
74
/// `prompt_args!` builds a `std::collections::HashMap<String, serde_json::Value>`
/// that can be passed as the input variables of a prompt.
///
/// # Usage
/// Each entry is written `key => value`. Keys are converted with `to_string()`
/// (typically they are `&str` literals). Values are converted with
/// `serde_json::json!`, so any serializable value works:
/// ```rust,ignore
/// prompt_args! {
///     "input" => "Who is the writer of 20,000 Leagues Under the Sea, and what is my name?",
///     "history" => vec![
///         Message::new_human_message("My name is: Luis"),
///         Message::new_ai_message("Hi Luis"),
///     ],
/// }
/// ```
///
/// # Arguments
/// * `key` - Expression converted with `to_string()` and used as the map key.
/// * `value` - Expression converted into a `serde_json::Value` with `serde_json::json!`
///   and stored under `key`.
///
/// Which keys you need depends on the prompt being formatted. In the example
/// above, `"input"` and `"history"` name variables the target prompt expects.
/// Each value is the data supplied for that variable.
///
/// A trailing comma is accepted, and `prompt_args! {}` yields an empty map. If
/// a key is repeated, the last value wins.
#[macro_export]
macro_rules! prompt_args {
    ( $($key:expr => $value:expr),* $(,)? ) => {
        {
            // `mut` is unused when the macro is invoked with no entries.
            #[allow(unused_mut)]
            let mut args = std::collections::HashMap::<String, serde_json::Value>::new();
            $(
                // Convert the value to serde_json::Value before inserting
                args.insert($key.to_string(), serde_json::json!($value));
            )*
            args
        }
    };
}
110
/// `template_fstring!` builds a `PromptTemplate` whose placeholders use
/// single braces (`{name}`), i.e. `TemplateFormat::FString`.
///
/// # Usage
/// The first argument is the template text. The remaining arguments are the
/// names of the variables the template requires. The variable list may be
/// empty, and a trailing comma is accepted:
/// ```rust,ignore
/// let greeting = template_fstring!("Hello {name}", "name");
/// let fixed = template_fstring!("Hello there");
/// ```
/// `greeting` has the template `"Hello {name}"`, the variables `["name"]`, and
/// the format `TemplateFormat::FString`. `fixed` has no variables.
#[macro_export]
macro_rules! template_fstring {
    // `$(, $var)*` rather than `, $($var),*` so a template with no variables
    // can be written without a dangling comma.
    ($template:expr $(, $var:expr)* $(,)?) => {
        $crate::prompt::PromptTemplate::new(
            $template.to_string(),
            vec![$($var.to_string()),*],
            $crate::prompt::TemplateFormat::FString,
        )
    };
}
132
/// `template_jinja2!` builds a `PromptTemplate` whose placeholders use
/// double braces (`{{name}}`), i.e. `TemplateFormat::Jinja2`.
///
/// # Usage
/// The first argument is the template text. The remaining arguments are the
/// names of the variables the template requires. The variable list may be
/// empty, and a trailing comma is accepted:
/// ```rust,ignore
/// let greeting = template_jinja2!("Hello {{name}}", "name");
/// let fixed = template_jinja2!("Hello there");
/// ```
/// `greeting` has the template `"Hello {{name}}"`, the variables `["name"]`,
/// and the format `TemplateFormat::Jinja2`. `fixed` has no variables.
#[macro_export]
macro_rules! template_jinja2 {
    // `$(, $var)*` rather than `, $($var),*` so a template with no variables
    // can be written without a dangling comma.
    ($template:expr $(, $var:expr)* $(,)?) => {
        $crate::prompt::PromptTemplate::new(
            $template.to_string(),
            vec![$($var.to_string()),*],
            $crate::prompt::TemplateFormat::Jinja2,
        )
    };
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt_args;

    #[test]
    fn should_format_jinja2_template() {
        let template = PromptTemplate::new(
            "Hello {{name}}!".to_string(),
            vec!["name".to_string()],
            TemplateFormat::Jinja2,
        );

        // Omitting a declared variable must be reported, naming that variable.
        let result = template.format(prompt_args! {});
        assert!(matches!(
            &result,
            Err(PromptError::MissingVariable(key)) if key == "name"
        ));

        let result = template.format(prompt_args! {
            "name" => "world",
        });
        assert_eq!(result.unwrap(), "Hello world!");
    }

    #[test]
    fn should_format_fstring_template() {
        let template = PromptTemplate::new(
            "Hello {name}!".to_string(),
            vec!["name".to_string()],
            TemplateFormat::FString,
        );

        // Omitting a declared variable must be reported, naming that variable.
        let result = template.format(prompt_args! {});
        assert!(matches!(
            &result,
            Err(PromptError::MissingVariable(key)) if key == "name"
        ));

        let result = template.format(prompt_args! {
            "name" => "world",
        });
        assert_eq!(result.unwrap(), "Hello world!");
    }

    #[test]
    fn should_prompt_macro_work() {
        let args = prompt_args! {};
        assert!(args.is_empty());

        let args = prompt_args! {
            "name" => "world",
        };
        assert_eq!(args.len(), 1);
        assert_eq!(args.get("name").unwrap(), &"world");

        let args = prompt_args! {
            "name" => "world",
            "age" => "18",
        };
        assert_eq!(args.len(), 2);
        assert_eq!(args.get("name").unwrap(), &"world");
        assert_eq!(args.get("age").unwrap(), &"18");
    }

    #[test]
    fn test_chat_template_macros() {
        // Creating an FString chat template
        let fstring_template = template_fstring!(
            "FString Chat: {user} says {message} {test}",
            "user",
            "message",
            "test"
        );

        // Creating a Jinja2 chat template
        let jinja2_template =
            template_jinja2!("Jinja2 Chat: {{user}} says {{message}}", "user", "message");

        // Define input variables for the templates
        let input_variables_fstring = prompt_args! {
            "user" => "Alice",
            "message" => "Hello, Bob!",
            "test" => "test2"
        };

        let input_variables_jinja2 = prompt_args! {
            "user" => "Bob",
            "message" => "Hi, Alice!",
        };

        // Format the FString chat template
        let formatted_fstring = fstring_template.format(input_variables_fstring).unwrap();
        assert_eq!(
            formatted_fstring,
            "FString Chat: Alice says Hello, Bob! test2"
        );

        // Format the Jinja2 chat template
        let formatted_jinja2 = jinja2_template.format(input_variables_jinja2).unwrap();
        assert_eq!(formatted_jinja2, "Jinja2 Chat: Bob says Hi, Alice!");
    }
}