// langchain_rust/prompt/
chat.rs1use crate::schemas::{messages::Message, prompt::PromptValue};
2
3use super::{
4 FormatPrompter, MessageFormatter, PromptArgs, PromptError, PromptFromatter, PromptTemplate,
5};
6
7#[derive(Clone)]
18pub struct HumanMessagePromptTemplate {
19 prompt: PromptTemplate,
20}
21
22impl HumanMessagePromptTemplate {
23 pub fn new(prompt: PromptTemplate) -> Self {
24 Self { prompt }
25 }
26}
27impl MessageFormatter for HumanMessagePromptTemplate {
28 fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
29 let message = Message::new_human_message(self.prompt.format(input_variables)?);
30 log::debug!("message: {:?}", message);
31 Ok(vec![message])
32 }
33 fn input_variables(&self) -> Vec<String> {
34 self.prompt.variables().clone()
35 }
36}
37
38impl FormatPrompter for HumanMessagePromptTemplate {
39 fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
40 let messages = self.format_messages(input_variables)?;
41 Ok(PromptValue::from_messages(messages))
42 }
43 fn get_input_variables(&self) -> Vec<String> {
44 self.input_variables()
45 }
46}
47
48#[derive(Clone)]
60pub struct SystemMessagePromptTemplate {
61 prompt: PromptTemplate,
62}
63
64impl SystemMessagePromptTemplate {
65 pub fn new(prompt: PromptTemplate) -> Self {
66 Self { prompt }
67 }
68}
69
70impl FormatPrompter for SystemMessagePromptTemplate {
71 fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
72 let messages = self.format_messages(input_variables)?;
73 Ok(PromptValue::from_messages(messages))
74 }
75 fn get_input_variables(&self) -> Vec<String> {
76 self.input_variables()
77 }
78}
79
80impl MessageFormatter for SystemMessagePromptTemplate {
81 fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
82 let message = Message::new_system_message(self.prompt.format(input_variables)?);
83 log::debug!("message: {:?}", message);
84 Ok(vec![message])
85 }
86 fn input_variables(&self) -> Vec<String> {
87 self.prompt.variables().clone()
88 }
89}
90
91#[derive(Clone)]
102pub struct AIMessagePromptTemplate {
103 prompt: PromptTemplate,
104}
105
106impl FormatPrompter for AIMessagePromptTemplate {
107 fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
108 let messages = self.format_messages(input_variables)?;
109 Ok(PromptValue::from_messages(messages))
110 }
111 fn get_input_variables(&self) -> Vec<String> {
112 self.input_variables()
113 }
114}
115
116impl MessageFormatter for AIMessagePromptTemplate {
117 fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
118 let message = Message::new_ai_message(self.prompt.format(input_variables)?);
119 log::debug!("message: {:?}", message);
120 Ok(vec![message])
121 }
122 fn input_variables(&self) -> Vec<String> {
123 self.prompt.variables().clone()
124 }
125}
126
127impl AIMessagePromptTemplate {
128 pub fn new(prompt: PromptTemplate) -> Self {
129 Self { prompt }
130 }
131}
132
133pub enum MessageOrTemplate {
134 Message(Message),
135 Template(Box<dyn MessageFormatter>),
136 MessagesPlaceholder(String),
137}
138
/// Turns a `Message` into a `MessageOrTemplate::Message` item, for use inside
/// `message_formatter!`.
#[macro_export]
macro_rules! fmt_message {
    ($message:expr) => {
        $crate::prompt::MessageOrTemplate::Message($message)
    };
}

/// Boxes a `MessageFormatter` implementation into a `MessageOrTemplate::Template`
/// item, for use inside `message_formatter!`.
#[macro_export]
macro_rules! fmt_template {
    ($formatter:expr) => {
        $crate::prompt::MessageOrTemplate::Template(Box::new($formatter))
    };
}

/// Turns a variable name (anything with `.into()` to `String`) into a
/// `MessageOrTemplate::MessagesPlaceholder` item, for use inside `message_formatter!`.
#[macro_export]
macro_rules! fmt_placeholder {
    ($name:expr) => {
        $crate::prompt::MessageOrTemplate::MessagesPlaceholder($name.into())
    };
}

187pub struct MessageFormatterStruct {
188 items: Vec<MessageOrTemplate>,
189}
190
191impl MessageFormatterStruct {
192 pub fn new() -> Self {
193 Self { items: Vec::new() }
194 }
195
196 pub fn add_message(&mut self, message: Message) {
197 self.items.push(MessageOrTemplate::Message(message));
198 }
199
200 pub fn add_template(&mut self, template: Box<dyn MessageFormatter>) {
201 self.items.push(MessageOrTemplate::Template(template));
202 }
203
204 pub fn add_messages_placeholder(&mut self, placeholder: &str) {
205 self.items.push(MessageOrTemplate::MessagesPlaceholder(
206 placeholder.to_string(),
207 ));
208 }
209
210 fn format(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
211 let mut result: Vec<Message> = Vec::new();
212 for item in &self.items {
213 match item {
214 MessageOrTemplate::Message(msg) => result.push(msg.clone()),
215 MessageOrTemplate::Template(tmpl) => {
216 result.extend(tmpl.format_messages(input_variables.clone())?)
217 }
218 MessageOrTemplate::MessagesPlaceholder(placeholder) => {
219 let messages = input_variables[placeholder].clone();
220 result.extend(Message::messages_from_value(&messages)?);
221 }
222 }
223 }
224 Ok(result)
225 }
226}
227
228impl MessageFormatter for MessageFormatterStruct {
229 fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
230 self.format(input_variables)
231 }
232 fn input_variables(&self) -> Vec<String> {
233 let mut variables = Vec::new();
234 for item in &self.items {
235 match item {
236 MessageOrTemplate::Message(_) => {}
237 MessageOrTemplate::Template(tmpl) => {
238 variables.extend(tmpl.input_variables());
239 }
240 MessageOrTemplate::MessagesPlaceholder(placeholder) => {
241 variables.extend(vec![placeholder.clone()]);
242 }
243 }
244 }
245 variables
246 }
247}
248
249impl FormatPrompter for MessageFormatterStruct {
250 fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
251 let messages = self.format(input_variables)?;
252 Ok(PromptValue::from_messages(messages))
253 }
254 fn get_input_variables(&self) -> Vec<String> {
255 self.input_variables()
256 }
257}
258
/// Builds a `MessageFormatterStruct` from a comma-separated list of items made
/// with `fmt_message!`, `fmt_template!` and `fmt_placeholder!`.
///
/// Items keep the order given; a trailing comma is accepted.
#[macro_export]
macro_rules! message_formatter {
    ($($item:expr),* $(,)?) => {{
        // `mut` is unused when the macro is invoked with no items.
        #[allow(unused_mut)]
        let mut formatter = $crate::prompt::MessageFormatterStruct::new();
        $(
            match $item {
                $crate::prompt::MessageOrTemplate::Message(msg) => formatter.add_message(msg),
                $crate::prompt::MessageOrTemplate::Template(tmpl) => formatter.add_template(tmpl),
                // `placeholder` is an owned String here; borrowing it is enough.
                $crate::prompt::MessageOrTemplate::MessagesPlaceholder(placeholder) => {
                    formatter.add_messages_placeholder(&placeholder)
                }
            }
        )*
        formatter
    }};
}

#[cfg(test)]
mod tests {
    use crate::{
        message_formatter,
        prompt::{chat::AIMessagePromptTemplate, FormatPrompter},
        prompt_args,
        schemas::messages::Message,
        template_fstring,
    };

    /// A fixed message, an AI template and a history placeholder must be rendered
    /// in exactly that order.
    #[test]
    fn test_message_formatter_macro() {
        let greeting = Message::new_human_message("Hello from user");

        let ai_template = AIMessagePromptTemplate::new(template_fstring!(
            "AI response: {content} {test}",
            "content",
            "test"
        ));

        let formatter = message_formatter![
            fmt_message!(greeting),
            fmt_template!(ai_template),
            fmt_placeholder!("history")
        ];

        let args = prompt_args! {
            "content" => "This is a test",
            "test" => "test2",
            "history" => vec![
                Message::new_human_message("Placeholder message 1"),
                Message::new_ai_message("Placeholder message 2"),
            ],
        };

        let rendered = formatter.format_prompt(args).unwrap().to_chat_messages();

        let expected = [
            "Hello from user",
            "AI response: This is a test test2",
            "Placeholder message 1",
            "Placeholder message 2",
        ];
        assert_eq!(rendered.len(), expected.len());
        for (message, want) in rendered.iter().zip(expected) {
            assert_eq!(message.content, want);
        }
    }
}