langchain_rust/prompt/
prompt.rs1use crate::schemas::{messages::Message, prompt::PromptValue};
2
3use super::{FormatPrompter, PromptArgs, PromptError, PromptFromatter};
4
/// Placeholder syntax understood by [`PromptTemplate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TemplateFormat {
    /// Python f-string style placeholders, e.g. `{name}`.
    FString,
    /// Jinja2-style expression placeholders, e.g. `{{name}}`.
    Jinja2,
}
10
11#[derive(Clone)]
12pub struct PromptTemplate {
13 template: String,
14 variables: Vec<String>,
15 format: TemplateFormat,
16}
17
18impl PromptTemplate {
19 pub fn new(template: String, variables: Vec<String>, format: TemplateFormat) -> Self {
20 Self {
21 template,
22 variables,
23 format,
24 }
25 }
26}
27
28impl FormatPrompter for PromptTemplate {
30 fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
31 let messages = vec![Message::new_human_message(self.format(input_variables)?)];
32 Ok(PromptValue::from_messages(messages))
33 }
34 fn get_input_variables(&self) -> Vec<String> {
35 self.variables.clone()
36 }
37}
38
39impl PromptFromatter for PromptTemplate {
40 fn template(&self) -> String {
41 self.template.clone()
42 }
43
44 fn variables(&self) -> Vec<String> {
45 self.variables.clone()
46 }
47
48 fn format(&self, input_variables: PromptArgs) -> Result<String, PromptError> {
49 let mut prompt = self.template();
50
51 for key in self.variables() {
53 if !input_variables.contains_key(key.as_str()) {
54 return Err(PromptError::MissingVariable(key));
55 }
56 }
57
58 for (key, value) in input_variables {
59 let key = match self.format {
60 TemplateFormat::FString => format!("{{{}}}", key),
61 TemplateFormat::Jinja2 => format!("{{{{{}}}}}", key),
62 };
63 let value_str = match &value {
64 serde_json::Value::String(s) => s.clone(),
65 _ => value.to_string(),
66 };
67 prompt = prompt.replace(&key, &value_str);
68 }
69
70 log::debug!("Formatted prompt: {}", prompt);
71 Ok(prompt)
72 }
73}
74
/// Builds a `HashMap<String, serde_json::Value>` of prompt arguments from
/// `key => value` pairs.
///
/// Each key is converted with `to_string()`, and each value is passed through
/// `serde_json::json!`, so any expression `json!` accepts can be used as a
/// value. A trailing comma is allowed, and `prompt_args! {}` yields an empty map.
///
/// ```ignore
/// let args = prompt_args! {
///     "name" => "world",
///     "age" => 18,
/// };
/// ```
#[macro_export]
macro_rules! prompt_args {
    ( $($key:expr => $value:expr),* $(,)? ) => {
        std::collections::HashMap::<String, serde_json::Value>::from([
            $( ($key.to_string(), serde_json::json!($value)) ),*
        ])
    };
}
110
/// Creates a [`PromptTemplate`](crate::prompt::PromptTemplate) that uses
/// f-string style `{variable}` placeholders.
///
/// The first argument is the template text. Any following arguments are the
/// names of the variables the template requires. Every argument is converted
/// with `to_string()`. The variable list may be empty, and a trailing comma is
/// accepted.
///
/// ```ignore
/// let greeting = template_fstring!("Hello {name}!", "name");
/// let fixed = template_fstring!("You are a helpful assistant.");
/// ```
#[macro_export]
macro_rules! template_fstring {
    // The comma belongs to each variable, so a template with no variables
    // (`template_fstring!("text")`) is accepted as well.
    ($template:expr $(, $var:expr)* $(,)?) => {
        $crate::prompt::PromptTemplate::new(
            $template.to_string(),
            vec![$($var.to_string()),*],
            $crate::prompt::TemplateFormat::FString,
        )
    };
}
132
/// Creates a [`PromptTemplate`](crate::prompt::PromptTemplate) that uses
/// Jinja2-style `{{variable}}` placeholders.
///
/// The first argument is the template text. Any following arguments are the
/// names of the variables the template requires. Every argument is converted
/// with `to_string()`. The variable list may be empty, and a trailing comma is
/// accepted.
///
/// ```ignore
/// let greeting = template_jinja2!("Hello {{name}}!", "name");
/// let fixed = template_jinja2!("You are a helpful assistant.");
/// ```
#[macro_export]
macro_rules! template_jinja2 {
    // The comma belongs to each variable, so a template with no variables
    // (`template_jinja2!("text")`) is accepted as well.
    ($template:expr $(, $var:expr)* $(,)?) => {
        $crate::prompt::PromptTemplate::new(
            $template.to_string(),
            vec![$($var.to_string()),*],
            $crate::prompt::TemplateFormat::Jinja2,
        )
    };
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt_args;

    /// Asserts that `result` failed because the variable `name` was not supplied.
    fn assert_missing_variable(result: Result<String, PromptError>, name: &str) {
        assert!(
            matches!(&result, Err(PromptError::MissingVariable(key)) if key == name),
            "expected MissingVariable({:?}), got {:?}",
            name,
            result
        );
    }

    #[test]
    fn should_format_jinja2_template() {
        let template = PromptTemplate::new(
            "Hello {{name}}!".to_string(),
            vec!["name".to_string()],
            TemplateFormat::Jinja2,
        );

        assert_missing_variable(template.format(prompt_args! {}), "name");

        let formatted = template
            .format(prompt_args! { "name" => "world" })
            .expect("all declared variables are supplied");
        assert_eq!(formatted, "Hello world!");
    }

    #[test]
    fn should_format_fstring_template() {
        let template = PromptTemplate::new(
            "Hello {name}!".to_string(),
            vec!["name".to_string()],
            TemplateFormat::FString,
        );

        assert_missing_variable(template.format(prompt_args! {}), "name");

        let formatted = template
            .format(prompt_args! { "name" => "world" })
            .expect("all declared variables are supplied");
        assert_eq!(formatted, "Hello world!");
    }

    #[test]
    fn should_prompt_macro_work() {
        let args = prompt_args! {};
        assert!(args.is_empty());

        let args = prompt_args! {
            "name" => "world",
        };
        assert_eq!(args.len(), 1);
        assert_eq!(args.get("name").unwrap(), &"world");

        let args = prompt_args! {
            "name" => "world",
            "age" => "18",
        };
        assert_eq!(args.len(), 2);
        assert_eq!(args.get("name").unwrap(), &"world");
        assert_eq!(args.get("age").unwrap(), &"18");
    }

    #[test]
    fn test_chat_template_macros() {
        let fstring_template = template_fstring!(
            "FString Chat: {user} says {message} {test}",
            "user",
            "message",
            "test"
        );
        let jinja2_template =
            template_jinja2!("Jinja2 Chat: {{user}} says {{message}}", "user", "message");

        let input_variables_fstring = prompt_args! {
            "user" => "Alice",
            "message" => "Hello, Bob!",
            "test" => "test2",
        };
        let input_variables_jinja2 = prompt_args! {
            "user" => "Bob",
            "message" => "Hi, Alice!",
        };

        let formatted_fstring = fstring_template
            .format(input_variables_fstring)
            .expect("all declared variables are supplied");
        assert_eq!(
            formatted_fstring,
            "FString Chat: Alice says Hello, Bob! test2"
        );

        let formatted_jinja2 = jinja2_template
            .format(input_variables_jinja2)
            .expect("all declared variables are supplied");
        assert_eq!(formatted_jinja2, "Jinja2 Chat: Bob says Hi, Alice!");
    }
}