// langchain_rust/prompt/chat.rs

1use crate::schemas::{messages::Message, prompt::PromptValue};
2
3use super::{
4    FormatPrompter, MessageFormatter, PromptArgs, PromptError, PromptFromatter, PromptTemplate,
5};
6
7/// Struct `HumanMessagePromptTemplate` defines a template for creating human (user) messages.
8/// `PromptTemplate` is used to generate the message template.
9///
10/// # Usage
11/// ```rust,ignore
12/// let human_message_prompt = HumanMessagePromptTemplate::new(template_fstring!(
13///    "User says: {content}",
14///    "content",
15/// ));
16/// ```
17#[derive(Clone)]
18pub struct HumanMessagePromptTemplate {
19    prompt: PromptTemplate,
20}
21
22impl HumanMessagePromptTemplate {
23    pub fn new(prompt: PromptTemplate) -> Self {
24        Self { prompt }
25    }
26}
27impl MessageFormatter for HumanMessagePromptTemplate {
28    fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
29        let message = Message::new_human_message(self.prompt.format(input_variables)?);
30        log::debug!("message: {:?}", message);
31        Ok(vec![message])
32    }
33    fn input_variables(&self) -> Vec<String> {
34        self.prompt.variables().clone()
35    }
36}
37
38impl FormatPrompter for HumanMessagePromptTemplate {
39    fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
40        let messages = self.format_messages(input_variables)?;
41        Ok(PromptValue::from_messages(messages))
42    }
43    fn get_input_variables(&self) -> Vec<String> {
44        self.input_variables()
45    }
46}
47
48/// Struct `SystemMessagePromptTemplate` defines a template for creating system-level messages.
49/// `PromptTemplate` is used to generate the message template.
50///
51/// # Usage
52/// ```rust,ignore
53/// let system_message_prompt = SystemMessagePromptTemplate::new(template_fstring!(
54///    "System alert: {alert_type} {alert_detail}",
55///    "alert_type",
56///    "alert_detail"
57/// ));
58/// ```
59#[derive(Clone)]
60pub struct SystemMessagePromptTemplate {
61    prompt: PromptTemplate,
62}
63
64impl SystemMessagePromptTemplate {
65    pub fn new(prompt: PromptTemplate) -> Self {
66        Self { prompt }
67    }
68}
69
70impl FormatPrompter for SystemMessagePromptTemplate {
71    fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
72        let messages = self.format_messages(input_variables)?;
73        Ok(PromptValue::from_messages(messages))
74    }
75    fn get_input_variables(&self) -> Vec<String> {
76        self.input_variables()
77    }
78}
79
80impl MessageFormatter for SystemMessagePromptTemplate {
81    fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
82        let message = Message::new_system_message(self.prompt.format(input_variables)?);
83        log::debug!("message: {:?}", message);
84        Ok(vec![message])
85    }
86    fn input_variables(&self) -> Vec<String> {
87        self.prompt.variables().clone()
88    }
89}
90
91/// Struct `AIMessagePromptTemplate` defines a template for creating AI (assistant) messages.
92/// `PromptTemplate` is used to generate the message template.
93///
94/// # Usage
95/// ```rust,ignore
96/// let ai_message_prompt = AIMessagePromptTemplate::new(template_fstring!(
97///    "AI response: {content} {additional_info}",
98///    "content",
99///    "additional_info"
100/// ));
101#[derive(Clone)]
102pub struct AIMessagePromptTemplate {
103    prompt: PromptTemplate,
104}
105
106impl FormatPrompter for AIMessagePromptTemplate {
107    fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
108        let messages = self.format_messages(input_variables)?;
109        Ok(PromptValue::from_messages(messages))
110    }
111    fn get_input_variables(&self) -> Vec<String> {
112        self.input_variables()
113    }
114}
115
116impl MessageFormatter for AIMessagePromptTemplate {
117    fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
118        let message = Message::new_ai_message(self.prompt.format(input_variables)?);
119        log::debug!("message: {:?}", message);
120        Ok(vec![message])
121    }
122    fn input_variables(&self) -> Vec<String> {
123        self.prompt.variables().clone()
124    }
125}
126
127impl AIMessagePromptTemplate {
128    pub fn new(prompt: PromptTemplate) -> Self {
129        Self { prompt }
130    }
131}
132
133pub enum MessageOrTemplate {
134    Message(Message),
135    Template(Box<dyn MessageFormatter>),
136    MessagesPlaceholder(String),
137}
138
/// Wraps a `Message` in the `MessageOrTemplate::Message` variant.
///
/// # Usage
/// ```rust,ignore
/// let message = Message::new_human_message("Hello World");
/// let item = fmt_message!(message); // MessageOrTemplate::Message(message)
/// ```
#[macro_export]
macro_rules! fmt_message {
    ($message:expr) => {
        $crate::prompt::MessageOrTemplate::Message($message)
    };
}
153
/// Boxes a `MessageFormatter` and wraps it in the `MessageOrTemplate::Template` variant.
///
/// The argument may be a `HumanMessagePromptTemplate`, `SystemMessagePromptTemplate`,
/// `AIMessagePromptTemplate`, or any other implementation of `MessageFormatter`.
///
/// # Usage
/// ```rust,ignore
/// let prompt_template = HumanMessagePromptTemplate::new(template);
/// let item = fmt_template!(prompt_template); // MessageOrTemplate::Template(Box::new(..))
/// ```
#[macro_export]
macro_rules! fmt_template {
    ($template:expr) => {
        // Fully qualified `Box` so the expansion does not depend on whatever the
        // calling crate has bound to the name `Box`.
        $crate::prompt::MessageOrTemplate::Template(::std::boxed::Box::new($template))
    };
}
171
/// Wraps a placeholder name in the `MessageOrTemplate::MessagesPlaceholder` variant.
///
/// The argument is the name of the input variable that will hold the messages. It may be
/// anything convertible into a `String` (for example a string literal or a `String`);
/// it is converted with `.into()`.
///
/// # Usage
/// ```rust,ignore
/// let item = fmt_placeholder!("history"); // MessageOrTemplate::MessagesPlaceholder("history".into())
/// ```
#[macro_export]
macro_rules! fmt_placeholder {
    ($name:expr) => {
        $crate::prompt::MessageOrTemplate::MessagesPlaceholder($name.into())
    };
}
186
187pub struct MessageFormatterStruct {
188    items: Vec<MessageOrTemplate>,
189}
190
191impl MessageFormatterStruct {
192    pub fn new() -> Self {
193        Self { items: Vec::new() }
194    }
195
196    pub fn add_message(&mut self, message: Message) {
197        self.items.push(MessageOrTemplate::Message(message));
198    }
199
200    pub fn add_template(&mut self, template: Box<dyn MessageFormatter>) {
201        self.items.push(MessageOrTemplate::Template(template));
202    }
203
204    pub fn add_messages_placeholder(&mut self, placeholder: &str) {
205        self.items.push(MessageOrTemplate::MessagesPlaceholder(
206            placeholder.to_string(),
207        ));
208    }
209
210    fn format(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
211        let mut result: Vec<Message> = Vec::new();
212        for item in &self.items {
213            match item {
214                MessageOrTemplate::Message(msg) => result.push(msg.clone()),
215                MessageOrTemplate::Template(tmpl) => {
216                    result.extend(tmpl.format_messages(input_variables.clone())?)
217                }
218                MessageOrTemplate::MessagesPlaceholder(placeholder) => {
219                    let messages = input_variables[placeholder].clone();
220                    result.extend(Message::messages_from_value(&messages)?);
221                }
222            }
223        }
224        Ok(result)
225    }
226}
227
228impl MessageFormatter for MessageFormatterStruct {
229    fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError> {
230        self.format(input_variables)
231    }
232    fn input_variables(&self) -> Vec<String> {
233        let mut variables = Vec::new();
234        for item in &self.items {
235            match item {
236                MessageOrTemplate::Message(_) => {}
237                MessageOrTemplate::Template(tmpl) => {
238                    variables.extend(tmpl.input_variables());
239                }
240                MessageOrTemplate::MessagesPlaceholder(placeholder) => {
241                    variables.extend(vec![placeholder.clone()]);
242                }
243            }
244        }
245        variables
246    }
247}
248
249impl FormatPrompter for MessageFormatterStruct {
250    fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError> {
251        let messages = self.format(input_variables)?;
252        Ok(PromptValue::from_messages(messages))
253    }
254    fn get_input_variables(&self) -> Vec<String> {
255        self.input_variables()
256    }
257}
258
/// Builds a `MessageFormatterStruct` from a comma-separated list of entries.
///
/// Each argument must evaluate to a `MessageOrTemplate`, typically produced with
/// `fmt_message!`, `fmt_template!` or `fmt_placeholder!`. Entries are added in the order
/// given, which is the order their messages appear when formatting. A trailing comma is
/// accepted.
///
/// # Example
/// ```rust,ignore
/// // Create an AI message prompt template
/// let ai_message_prompt = AIMessagePromptTemplate::new(
/// template_fstring!(
///     "AI response: {content} {test}",
///     "content",
///     "test"
/// ));
///
/// let human_msg = Message::new_human_message("Hello from user");
///
/// // Use the `message_formatter` macro to construct the formatter.
/// let formatter = message_formatter![
///     fmt_message!(human_msg),
///     fmt_template!(ai_message_prompt),
///     fmt_placeholder!("history")
/// ];
/// ```
#[macro_export]
macro_rules! message_formatter {
    ($($item:expr),* $(,)?) => {{
        // With zero entries the binding is never mutated; keep the expansion warning-free.
        #[allow(unused_mut)]
        let mut formatter = $crate::prompt::MessageFormatterStruct::new();
        $(
            match $item {
                $crate::prompt::MessageOrTemplate::Message(msg) => formatter.add_message(msg),
                $crate::prompt::MessageOrTemplate::Template(tmpl) => formatter.add_template(tmpl),
                $crate::prompt::MessageOrTemplate::MessagesPlaceholder(placeholder) => {
                    // `&String` derefs to `&str`; no clone needed.
                    formatter.add_messages_placeholder(&placeholder)
                }
            }
        )*
        formatter
    }};
}
295
#[cfg(test)]
mod tests {
    use crate::{
        message_formatter,
        prompt::{chat::AIMessagePromptTemplate, FormatPrompter},
        prompt_args,
        schemas::messages::Message,
        template_fstring,
    };

    /// Builds a formatter from a fixed message, a template and a placeholder, then checks
    /// that formatting yields the four expected messages in insertion order.
    #[test]
    fn test_message_formatter_macro() {
        // Fixed human message, emitted unchanged.
        let user_greeting = Message::new_human_message("Hello from user");

        // AI template with two variables.
        let ai_template = AIMessagePromptTemplate::new(template_fstring!(
            "AI response: {content} {test}",
            "content",
            "test"
        ));

        let formatter = message_formatter![
            fmt_message!(user_greeting),
            fmt_template!(ai_template),
            fmt_placeholder!("history")
        ];

        // Values for the template variables plus the messages for the placeholder.
        let args = prompt_args! {
            "content" => "This is a test",
            "test" => "test2",
            "history" => vec![
                Message::new_human_message("Placeholder message 1"),
                Message::new_ai_message("Placeholder message 2"),
            ],
        };

        let messages = formatter
            .format_prompt(args)
            .unwrap()
            .to_chat_messages();

        let expected = [
            "Hello from user",
            "AI response: This is a test test2",
            "Placeholder message 1",
            "Placeholder message 2",
        ];
        assert_eq!(messages.len(), expected.len());
        for (message, want) in messages.iter().zip(expected.iter()) {
            assert_eq!(message.content, *want);
        }
    }
}