agent_chain_core/prompts/
chat.rs

1//! Chat prompt template.
2//!
3//! This module provides chat prompt templates for chat-based models,
4//! mirroring `langchain_core.prompts.chat` in Python.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12use crate::messages::{AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage};
13use crate::utils::input::get_colored_text;
14use crate::utils::interactive_env::is_interactive_env;
15
16use super::message::{BaseMessagePromptTemplate, get_msg_title_repr};
17use super::prompt::PromptTemplate;
18use super::string::{PromptTemplateFormat, StringPromptTemplate};
19
20/// Prompt template that assumes variable is already a list of messages.
21///
22/// A placeholder which can be used to pass in a list of messages.
23///
24/// # Example
25///
26/// ```ignore
27/// use agent_chain_core::prompts::MessagesPlaceholder;
28///
29/// let placeholder = MessagesPlaceholder::new("history");
30///
31/// // With optional=true, format_messages can be called with no arguments
32/// let placeholder = MessagesPlaceholder::new("history").optional(true);
33/// ```
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct MessagesPlaceholder {
36    /// Name of variable to use as messages.
37    pub variable_name: String,
38
39    /// If `true`, format_messages can be called with no arguments and will return
40    /// an empty list. If `false` then a named argument with name `variable_name`
41    /// must be passed in, even if the value is an empty list.
42    #[serde(default)]
43    pub optional: bool,
44
45    /// Maximum number of messages to include. If `None`, then will include all.
46    #[serde(default)]
47    pub n_messages: Option<usize>,
48}
49
50impl MessagesPlaceholder {
51    /// Create a new messages placeholder.
52    ///
53    /// # Arguments
54    ///
55    /// * `variable_name` - Name of variable to use as messages.
56    pub fn new(variable_name: impl Into<String>) -> Self {
57        Self {
58            variable_name: variable_name.into(),
59            optional: false,
60            n_messages: None,
61        }
62    }
63
64    /// Set whether this placeholder is optional.
65    pub fn optional(mut self, optional: bool) -> Self {
66        self.optional = optional;
67        self
68    }
69
70    /// Set the maximum number of messages to include.
71    pub fn n_messages(mut self, n: usize) -> Self {
72        self.n_messages = Some(n);
73        self
74    }
75
76    /// Format messages from kwargs.
77    ///
78    /// # Arguments
79    ///
80    /// * `messages` - The messages to format, or None if optional.
81    ///
82    /// # Returns
83    ///
84    /// A list of formatted messages.
85    pub fn format_with_messages(
86        &self,
87        messages: Option<Vec<BaseMessage>>,
88    ) -> Result<Vec<BaseMessage>> {
89        let value = if self.optional {
90            messages.unwrap_or_default()
91        } else {
92            messages.ok_or_else(|| {
93                Error::InvalidConfig(format!(
94                    "Variable '{}' is required but was not provided",
95                    self.variable_name
96                ))
97            })?
98        };
99
100        let result = if let Some(n) = self.n_messages {
101            let len = value.len();
102            if len > n {
103                value.into_iter().skip(len - n).collect()
104            } else {
105                value
106            }
107        } else {
108            value
109        };
110
111        Ok(result)
112    }
113}
114
115impl BaseMessagePromptTemplate for MessagesPlaceholder {
116    fn input_variables(&self) -> Vec<String> {
117        if self.optional {
118            Vec::new()
119        } else {
120            vec![self.variable_name.clone()]
121        }
122    }
123
124    fn format_messages(&self, _kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
125        // Note: In the actual implementation, we would extract messages from kwargs.
126        // Since we're using String values in kwargs, this requires special handling.
127        // The Python version expects the value to be a list of messages.
128        // For now, we return an empty list for optional placeholders.
129        if self.optional {
130            Ok(Vec::new())
131        } else {
132            Err(Error::InvalidConfig(format!(
133                "MessagesPlaceholder '{}' requires messages to be passed via format_with_messages",
134                self.variable_name
135            )))
136        }
137    }
138
139    fn pretty_repr(&self, html: bool) -> String {
140        let var = format!("{{{}}}", self.variable_name);
141        let title = get_msg_title_repr("Messages Placeholder", html);
142        let var_display = if html {
143            get_colored_text(&var, "yellow")
144        } else {
145            var
146        };
147        format!("{}\n\n{}", title, var_display)
148    }
149}
150
151/// Base class for message prompt templates that use a string prompt template.
152pub trait BaseStringMessagePromptTemplate: BaseMessagePromptTemplate {
153    /// Get the underlying string prompt template.
154    fn prompt(&self) -> &PromptTemplate;
155
156    /// Get additional kwargs to pass to the message.
157    fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
158        static EMPTY: std::sync::LazyLock<HashMap<String, serde_json::Value>> =
159            std::sync::LazyLock::new(HashMap::new);
160        &EMPTY
161    }
162
163    /// Format the prompt template into a message.
164    fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage>;
165
166    /// Async format the prompt template into a message.
167    fn aformat(
168        &self,
169        kwargs: &HashMap<String, String>,
170    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<BaseMessage>> + Send + '_>> {
171        let result = self.format(kwargs);
172        Box::pin(async move { result })
173    }
174}
175
176/// Chat message prompt template with a specific role.
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct ChatMessagePromptTemplate {
179    /// The underlying string prompt template.
180    pub prompt: PromptTemplate,
181
182    /// Role of the message.
183    pub role: String,
184
185    /// Additional keyword arguments to pass to the message.
186    #[serde(default)]
187    pub additional_kwargs: HashMap<String, serde_json::Value>,
188}
189
190impl ChatMessagePromptTemplate {
191    /// Create a new chat message prompt template.
192    pub fn new(prompt: PromptTemplate, role: impl Into<String>) -> Self {
193        Self {
194            prompt,
195            role: role.into(),
196            additional_kwargs: HashMap::new(),
197        }
198    }
199
200    /// Create from a template string.
201    pub fn from_template(
202        template: impl Into<String>,
203        role: impl Into<String>,
204        template_format: PromptTemplateFormat,
205    ) -> Result<Self> {
206        let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
207        Ok(Self::new(prompt, role))
208    }
209}
210
211impl BaseMessagePromptTemplate for ChatMessagePromptTemplate {
212    fn input_variables(&self) -> Vec<String> {
213        self.prompt.input_variables.clone()
214    }
215
216    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
217        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
218        Ok(vec![BaseMessage::Chat(ChatMessage::new(&self.role, text))])
219    }
220
221    fn pretty_repr(&self, html: bool) -> String {
222        let title = format!("{} Message", self.role);
223        let title = get_msg_title_repr(&title, html);
224        format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
225    }
226}
227
228impl BaseStringMessagePromptTemplate for ChatMessagePromptTemplate {
229    fn prompt(&self) -> &PromptTemplate {
230        &self.prompt
231    }
232
233    fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
234        &self.additional_kwargs
235    }
236
237    fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
238        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
239        Ok(BaseMessage::Chat(ChatMessage::new(&self.role, text)))
240    }
241}
242
243/// Human message prompt template.
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct HumanMessagePromptTemplate {
246    /// The underlying string prompt template.
247    pub prompt: PromptTemplate,
248
249    /// Additional keyword arguments to pass to the message.
250    #[serde(default)]
251    pub additional_kwargs: HashMap<String, serde_json::Value>,
252}
253
254impl HumanMessagePromptTemplate {
255    /// Create a new human message prompt template.
256    pub fn new(prompt: PromptTemplate) -> Self {
257        Self {
258            prompt,
259            additional_kwargs: HashMap::new(),
260        }
261    }
262
263    /// Create from a template string.
264    pub fn from_template(template: impl Into<String>) -> Result<Self> {
265        Self::from_template_with_format(template, PromptTemplateFormat::FString)
266    }
267
268    /// Create from a template string with a specific format.
269    pub fn from_template_with_format(
270        template: impl Into<String>,
271        template_format: PromptTemplateFormat,
272    ) -> Result<Self> {
273        let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
274        Ok(Self::new(prompt))
275    }
276
277    /// Create from a template file.
278    pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
279        let prompt = PromptTemplate::from_file(template_file)?;
280        Ok(Self::new(prompt))
281    }
282}
283
284impl BaseMessagePromptTemplate for HumanMessagePromptTemplate {
285    fn input_variables(&self) -> Vec<String> {
286        self.prompt.input_variables.clone()
287    }
288
289    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
290        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
291        Ok(vec![BaseMessage::Human(HumanMessage::new(text))])
292    }
293
294    fn pretty_repr(&self, html: bool) -> String {
295        let title = get_msg_title_repr("Human Message", html);
296        format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
297    }
298}
299
300impl BaseStringMessagePromptTemplate for HumanMessagePromptTemplate {
301    fn prompt(&self) -> &PromptTemplate {
302        &self.prompt
303    }
304
305    fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
306        &self.additional_kwargs
307    }
308
309    fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
310        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
311        Ok(BaseMessage::Human(HumanMessage::new(text)))
312    }
313}
314
315/// AI message prompt template.
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct AIMessagePromptTemplate {
318    /// The underlying string prompt template.
319    pub prompt: PromptTemplate,
320
321    /// Additional keyword arguments to pass to the message.
322    #[serde(default)]
323    pub additional_kwargs: HashMap<String, serde_json::Value>,
324}
325
326impl AIMessagePromptTemplate {
327    /// Create a new AI message prompt template.
328    pub fn new(prompt: PromptTemplate) -> Self {
329        Self {
330            prompt,
331            additional_kwargs: HashMap::new(),
332        }
333    }
334
335    /// Create from a template string.
336    pub fn from_template(template: impl Into<String>) -> Result<Self> {
337        Self::from_template_with_format(template, PromptTemplateFormat::FString)
338    }
339
340    /// Create from a template string with a specific format.
341    pub fn from_template_with_format(
342        template: impl Into<String>,
343        template_format: PromptTemplateFormat,
344    ) -> Result<Self> {
345        let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
346        Ok(Self::new(prompt))
347    }
348
349    /// Create from a template file.
350    pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
351        let prompt = PromptTemplate::from_file(template_file)?;
352        Ok(Self::new(prompt))
353    }
354}
355
356impl BaseMessagePromptTemplate for AIMessagePromptTemplate {
357    fn input_variables(&self) -> Vec<String> {
358        self.prompt.input_variables.clone()
359    }
360
361    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
362        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
363        Ok(vec![BaseMessage::AI(AIMessage::new(text))])
364    }
365
366    fn pretty_repr(&self, html: bool) -> String {
367        let title = get_msg_title_repr("AI Message", html);
368        format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
369    }
370}
371
372impl BaseStringMessagePromptTemplate for AIMessagePromptTemplate {
373    fn prompt(&self) -> &PromptTemplate {
374        &self.prompt
375    }
376
377    fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
378        &self.additional_kwargs
379    }
380
381    fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
382        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
383        Ok(BaseMessage::AI(AIMessage::new(text)))
384    }
385}
386
387/// System message prompt template.
388#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct SystemMessagePromptTemplate {
390    /// The underlying string prompt template.
391    pub prompt: PromptTemplate,
392
393    /// Additional keyword arguments to pass to the message.
394    #[serde(default)]
395    pub additional_kwargs: HashMap<String, serde_json::Value>,
396}
397
398impl SystemMessagePromptTemplate {
399    /// Create a new system message prompt template.
400    pub fn new(prompt: PromptTemplate) -> Self {
401        Self {
402            prompt,
403            additional_kwargs: HashMap::new(),
404        }
405    }
406
407    /// Create from a template string.
408    pub fn from_template(template: impl Into<String>) -> Result<Self> {
409        Self::from_template_with_format(template, PromptTemplateFormat::FString)
410    }
411
412    /// Create from a template string with a specific format.
413    pub fn from_template_with_format(
414        template: impl Into<String>,
415        template_format: PromptTemplateFormat,
416    ) -> Result<Self> {
417        let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
418        Ok(Self::new(prompt))
419    }
420
421    /// Create from a template file.
422    pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
423        let prompt = PromptTemplate::from_file(template_file)?;
424        Ok(Self::new(prompt))
425    }
426}
427
428impl BaseMessagePromptTemplate for SystemMessagePromptTemplate {
429    fn input_variables(&self) -> Vec<String> {
430        self.prompt.input_variables.clone()
431    }
432
433    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
434        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
435        Ok(vec![BaseMessage::System(SystemMessage::new(text))])
436    }
437
438    fn pretty_repr(&self, html: bool) -> String {
439        let title = get_msg_title_repr("System Message", html);
440        format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
441    }
442}
443
444impl BaseStringMessagePromptTemplate for SystemMessagePromptTemplate {
445    fn prompt(&self) -> &PromptTemplate {
446        &self.prompt
447    }
448
449    fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
450        &self.additional_kwargs
451    }
452
453    fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
454        let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
455        Ok(BaseMessage::System(SystemMessage::new(text)))
456    }
457}
458
459/// A message-like type that can be part of a chat prompt.
460#[derive(Clone)]
461pub enum MessageLike {
462    /// A base message.
463    Message(Box<BaseMessage>),
464    /// A message prompt template.
465    Template(Box<dyn MessageLikeClone + Send + Sync>),
466    /// A messages placeholder.
467    Placeholder(MessagesPlaceholder),
468}
469
470impl std::fmt::Debug for MessageLike {
471    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        match self {
473            MessageLike::Message(m) => f.debug_tuple("Message").field(m).finish(),
474            MessageLike::Template(_) => f.debug_tuple("Template").field(&"<template>").finish(),
475            MessageLike::Placeholder(p) => f.debug_tuple("Placeholder").field(p).finish(),
476        }
477    }
478}
479
480/// Helper trait for cloning message-like templates.
481pub trait MessageLikeClone: BaseMessagePromptTemplate {
482    fn clone_box(&self) -> Box<dyn MessageLikeClone + Send + Sync>;
483}
484
485impl<T> MessageLikeClone for T
486where
487    T: BaseMessagePromptTemplate + Clone + Send + Sync + 'static,
488{
489    fn clone_box(&self) -> Box<dyn MessageLikeClone + Send + Sync> {
490        Box::new(self.clone())
491    }
492}
493
494impl Clone for Box<dyn MessageLikeClone + Send + Sync> {
495    fn clone(&self) -> Self {
496        self.clone_box()
497    }
498}
499
500/// Representation of a message-like that can be converted to MessageLike.
501#[derive(Debug, Clone)]
502pub enum MessageLikeRepresentation {
503    /// A (role, content) tuple.
504    Tuple(String, String),
505    /// A string (shorthand for human message).
506    String(String),
507    /// A base message.
508    Message(Box<BaseMessage>),
509    /// A placeholder configuration.
510    Placeholder {
511        variable_name: String,
512        optional: bool,
513    },
514}
515
516impl MessageLikeRepresentation {
517    /// Create a tuple representation.
518    pub fn tuple(role: impl Into<String>, content: impl Into<String>) -> Self {
519        Self::Tuple(role.into(), content.into())
520    }
521
522    /// Create a string representation (human message).
523    pub fn string(content: impl Into<String>) -> Self {
524        Self::String(content.into())
525    }
526
527    /// Create a placeholder representation.
528    pub fn placeholder(variable_name: impl Into<String>, optional: bool) -> Self {
529        Self::Placeholder {
530            variable_name: variable_name.into(),
531            optional,
532        }
533    }
534}
535
536/// Base trait for chat prompt templates.
537pub trait BaseChatPromptTemplate: Send + Sync {
538    /// Get the input variables for this template.
539    fn input_variables(&self) -> &[String];
540
541    /// Get the optional variables for this template.
542    fn optional_variables(&self) -> &[String] {
543        &[]
544    }
545
546    /// Get partial variables for this template.
547    fn partial_variables(&self) -> &HashMap<String, String> {
548        static EMPTY: std::sync::LazyLock<HashMap<String, String>> =
549            std::sync::LazyLock::new(HashMap::new);
550        &EMPTY
551    }
552
553    /// Format kwargs into a list of messages.
554    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>>;
555
556    /// Async format kwargs into a list of messages.
557    fn aformat_messages(
558        &self,
559        kwargs: &HashMap<String, String>,
560    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<BaseMessage>>> + Send + '_>>
561    {
562        let result = self.format_messages(kwargs);
563        Box::pin(async move { result })
564    }
565
566    /// Format the chat template into a string.
567    fn format(&self, kwargs: &HashMap<String, String>) -> Result<String> {
568        let messages = self.format_messages(kwargs)?;
569        Ok(messages
570            .iter()
571            .map(|m| format!("{}: {}", m.message_type(), m.content()))
572            .collect::<Vec<_>>()
573            .join("\n"))
574    }
575
576    /// Get a pretty representation of the template.
577    fn pretty_repr(&self, html: bool) -> String;
578
579    /// Print a human-readable representation.
580    fn pretty_print(&self) {
581        println!("{}", self.pretty_repr(is_interactive_env()));
582    }
583}
584
585/// Chat prompt template for chat models.
586///
587/// Use to create flexible templated prompts for chat models.
588///
589/// # Example
590///
591/// ```ignore
592/// use agent_chain_core::prompts::ChatPromptTemplate;
593///
594/// let template = ChatPromptTemplate::from_messages(&[
595///     ("system", "You are a helpful AI bot. Your name is {name}."),
596///     ("human", "Hello, how are you doing?"),
597///     ("ai", "I'm doing well, thanks!"),
598///     ("human", "{user_input}"),
599/// ]).unwrap();
600///
601/// let result = template.invoke(&[
602///     ("name", "Bob"),
603///     ("user_input", "What is your name?"),
604/// ].into_iter().collect());
605/// ```
606#[derive(Debug, Clone, Default)]
607pub struct ChatPromptTemplate {
608    /// List of messages or message templates.
609    messages: Vec<ChatPromptMessage>,
610
611    /// Input variables.
612    input_variables: Vec<String>,
613
614    /// Optional variables.
615    optional_variables: Vec<String>,
616
617    /// Partial variables.
618    partial_variables: HashMap<String, String>,
619
620    /// Whether to validate the template.
621    validate_template: bool,
622
623    /// The template format to use.
624    template_format: PromptTemplateFormat,
625}
626
627/// A message in a chat prompt template.
628#[derive(Debug, Clone)]
629pub enum ChatPromptMessage {
630    /// A static message.
631    Message(BaseMessage),
632    /// A human message template.
633    Human(HumanMessagePromptTemplate),
634    /// An AI message template.
635    AI(AIMessagePromptTemplate),
636    /// A system message template.
637    System(SystemMessagePromptTemplate),
638    /// A chat message template with role.
639    Chat(ChatMessagePromptTemplate),
640    /// A messages placeholder.
641    Placeholder(MessagesPlaceholder),
642}
643
644impl ChatPromptMessage {
645    /// Get the input variables for this message.
646    fn input_variables(&self) -> Vec<String> {
647        match self {
648            ChatPromptMessage::Message(_) => Vec::new(),
649            ChatPromptMessage::Human(t) => t.input_variables(),
650            ChatPromptMessage::AI(t) => t.input_variables(),
651            ChatPromptMessage::System(t) => t.input_variables(),
652            ChatPromptMessage::Chat(t) => t.input_variables(),
653            ChatPromptMessage::Placeholder(p) => p.input_variables(),
654        }
655    }
656
657    /// Format this message.
658    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
659        match self {
660            ChatPromptMessage::Message(m) => Ok(vec![m.clone()]),
661            ChatPromptMessage::Human(t) => t.format_messages(kwargs),
662            ChatPromptMessage::AI(t) => t.format_messages(kwargs),
663            ChatPromptMessage::System(t) => t.format_messages(kwargs),
664            ChatPromptMessage::Chat(t) => t.format_messages(kwargs),
665            ChatPromptMessage::Placeholder(p) => p.format_messages(kwargs),
666        }
667    }
668
669    /// Get a pretty representation.
670    fn pretty_repr(&self, html: bool) -> String {
671        match self {
672            ChatPromptMessage::Message(m) => m.pretty_repr(html),
673            ChatPromptMessage::Human(t) => t.pretty_repr(html),
674            ChatPromptMessage::AI(t) => t.pretty_repr(html),
675            ChatPromptMessage::System(t) => t.pretty_repr(html),
676            ChatPromptMessage::Chat(t) => t.pretty_repr(html),
677            ChatPromptMessage::Placeholder(p) => p.pretty_repr(html),
678        }
679    }
680}
681
682impl ChatPromptTemplate {
683    /// Create a new empty chat prompt template.
684    pub fn new() -> Self {
685        Self::default()
686    }
687
688    /// Create a chat prompt template from a variety of message formats.
689    ///
690    /// # Arguments
691    ///
692    /// * `messages` - A slice of (role, template) tuples or strings.
693    ///
694    /// # Returns
695    ///
696    /// A new ChatPromptTemplate.
697    ///
698    /// # Example
699    ///
700    /// ```ignore
701    /// let template = ChatPromptTemplate::from_messages(&[
702    ///     ("system", "You are a helpful assistant."),
703    ///     ("human", "{question}"),
704    /// ]).unwrap();
705    /// ```
706    pub fn from_messages(messages: &[(&str, &str)]) -> Result<Self> {
707        Self::from_messages_with_format(messages, PromptTemplateFormat::FString)
708    }
709
710    /// Create a chat prompt template with a specific template format.
711    pub fn from_messages_with_format(
712        messages: &[(&str, &str)],
713        template_format: PromptTemplateFormat,
714    ) -> Result<Self> {
715        let mut template = Self::new();
716        template.template_format = template_format;
717
718        for (role, content) in messages {
719            let msg = create_template_from_message_type(role, content, template_format)?;
720            template.messages.push(msg);
721        }
722
723        // Infer input variables
724        let mut input_vars = std::collections::HashSet::new();
725        let mut optional_vars = std::collections::HashSet::new();
726
727        for msg in &template.messages {
728            match msg {
729                ChatPromptMessage::Placeholder(p) if p.optional => {
730                    optional_vars.insert(p.variable_name.clone());
731                }
732                _ => {
733                    for var in msg.input_variables() {
734                        input_vars.insert(var);
735                    }
736                }
737            }
738        }
739
740        template.input_variables = input_vars.into_iter().collect();
741        template.input_variables.sort();
742
743        template.optional_variables = optional_vars.into_iter().collect();
744        template.optional_variables.sort();
745
746        Ok(template)
747    }
748
749    /// Create a chat prompt template from a single template string.
750    ///
751    /// Creates a chat template consisting of a single message assumed to be from the human.
752    pub fn from_template(template: &str) -> Result<Self> {
753        let prompt_template = PromptTemplate::from_template(template)?;
754        let message = HumanMessagePromptTemplate::new(prompt_template);
755
756        Ok(Self {
757            messages: vec![ChatPromptMessage::Human(message.clone())],
758            input_variables: message.input_variables(),
759            optional_variables: Vec::new(),
760            partial_variables: HashMap::new(),
761            validate_template: false,
762            template_format: PromptTemplateFormat::FString,
763        })
764    }
765
766    /// Add a message to the template.
767    pub fn append(&mut self, message: ChatPromptMessage) {
768        for var in message.input_variables() {
769            if !self.input_variables.contains(&var) {
770                self.input_variables.push(var);
771            }
772        }
773        self.messages.push(message);
774    }
775
776    /// Add a human message template.
777    pub fn append_human(&mut self, template: &str) -> Result<()> {
778        let msg =
779            HumanMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
780        self.append(ChatPromptMessage::Human(msg));
781        Ok(())
782    }
783
784    /// Add an AI message template.
785    pub fn append_ai(&mut self, template: &str) -> Result<()> {
786        let msg =
787            AIMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
788        self.append(ChatPromptMessage::AI(msg));
789        Ok(())
790    }
791
792    /// Add a system message template.
793    pub fn append_system(&mut self, template: &str) -> Result<()> {
794        let msg =
795            SystemMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
796        self.append(ChatPromptMessage::System(msg));
797        Ok(())
798    }
799
800    /// Add a messages placeholder.
801    pub fn append_placeholder(&mut self, variable_name: &str, optional: bool) {
802        let placeholder = MessagesPlaceholder::new(variable_name).optional(optional);
803        if !optional && !self.input_variables.contains(&variable_name.to_string()) {
804            self.input_variables.push(variable_name.to_string());
805        }
806        if optional {
807            self.optional_variables.push(variable_name.to_string());
808        }
809        self.messages
810            .push(ChatPromptMessage::Placeholder(placeholder));
811    }
812
813    /// Get a partial of the template with some variables filled in.
814    pub fn partial(&self, kwargs: HashMap<String, String>) -> Self {
815        let new_vars: Vec<_> = self
816            .input_variables
817            .iter()
818            .filter(|v| !kwargs.contains_key(*v))
819            .cloned()
820            .collect();
821
822        let mut new_partials = self.partial_variables.clone();
823        new_partials.extend(kwargs);
824
825        Self {
826            messages: self.messages.clone(),
827            input_variables: new_vars,
828            optional_variables: self.optional_variables.clone(),
829            partial_variables: new_partials,
830            validate_template: self.validate_template,
831            template_format: self.template_format,
832        }
833    }
834
835    /// Get the number of messages in the template.
836    pub fn len(&self) -> usize {
837        self.messages.len()
838    }
839
840    /// Check if the template is empty.
841    pub fn is_empty(&self) -> bool {
842        self.messages.is_empty()
843    }
844
845    /// Get a message by index.
846    pub fn get(&self, index: usize) -> Option<&ChatPromptMessage> {
847        self.messages.get(index)
848    }
849
850    /// Merge partial and user variables.
851    fn merge_partial_and_user_variables(
852        &self,
853        kwargs: &HashMap<String, String>,
854    ) -> HashMap<String, String> {
855        let mut merged = self.partial_variables.clone();
856        merged.extend(kwargs.clone());
857        merged
858    }
859}
860
861impl BaseChatPromptTemplate for ChatPromptTemplate {
862    fn input_variables(&self) -> &[String] {
863        &self.input_variables
864    }
865
866    fn optional_variables(&self) -> &[String] {
867        &self.optional_variables
868    }
869
870    fn partial_variables(&self) -> &HashMap<String, String> {
871        &self.partial_variables
872    }
873
874    fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
875        let merged = self.merge_partial_and_user_variables(kwargs);
876        let mut result = Vec::new();
877
878        for message in &self.messages {
879            let formatted = message.format_messages(&merged)?;
880            result.extend(formatted);
881        }
882
883        Ok(result)
884    }
885
886    fn pretty_repr(&self, html: bool) -> String {
887        self.messages
888            .iter()
889            .map(|m| m.pretty_repr(html))
890            .collect::<Vec<_>>()
891            .join("\n\n")
892    }
893}
894
895/// Create a message prompt template from a message type string.
896fn create_template_from_message_type(
897    message_type: &str,
898    template: &str,
899    template_format: PromptTemplateFormat,
900) -> Result<ChatPromptMessage> {
901    match message_type {
902        "human" | "user" => {
903            let t =
904                HumanMessagePromptTemplate::from_template_with_format(template, template_format)?;
905            Ok(ChatPromptMessage::Human(t))
906        }
907        "ai" | "assistant" => {
908            let t = AIMessagePromptTemplate::from_template_with_format(template, template_format)?;
909            Ok(ChatPromptMessage::AI(t))
910        }
911        "system" => {
912            let t =
913                SystemMessagePromptTemplate::from_template_with_format(template, template_format)?;
914            Ok(ChatPromptMessage::System(t))
915        }
916        "placeholder" => {
917            // Parse placeholder: "{variable_name}"
918            if !template.starts_with('{') || !template.ends_with('}') {
919                return Err(Error::InvalidConfig(format!(
920                    "Invalid placeholder template: {}. Expected a variable name surrounded by curly braces.",
921                    template
922                )));
923            }
924            let var_name = &template[1..template.len() - 1];
925            let placeholder = MessagesPlaceholder::new(var_name).optional(true);
926            Ok(ChatPromptMessage::Placeholder(placeholder))
927        }
928        _ => Err(Error::InvalidConfig(format!(
929            "Unexpected message type: {}. Use one of 'human', 'user', 'ai', 'assistant', 'system', or 'placeholder'.",
930            message_type
931        ))),
932    }
933}
934
935impl std::ops::Add for ChatPromptTemplate {
936    type Output = ChatPromptTemplate;
937
938    fn add(self, other: Self) -> Self::Output {
939        let mut messages = self.messages;
940        messages.extend(other.messages);
941
942        let mut input_vars: std::collections::HashSet<_> =
943            self.input_variables.into_iter().collect();
944        input_vars.extend(other.input_variables);
945
946        let mut partial_vars = self.partial_variables;
947        partial_vars.extend(other.partial_variables);
948
949        ChatPromptTemplate {
950            messages,
951            input_variables: input_vars.into_iter().collect(),
952            optional_variables: Vec::new(),
953            partial_variables: partial_vars,
954            validate_template: self.validate_template && other.validate_template,
955            template_format: self.template_format,
956        }
957    }
958}
959
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_messages_placeholder() {
        let placeholder = MessagesPlaceholder::new("history");
        assert_eq!(placeholder.input_variables(), vec!["history"]);

        let optional_placeholder = MessagesPlaceholder::new("history").optional(true);
        assert!(optional_placeholder.input_variables().is_empty());
    }

    #[test]
    fn test_human_message_template() {
        let template = HumanMessagePromptTemplate::from_template("Hello, {name}!").unwrap();

        let mut kwargs = HashMap::new();
        kwargs.insert("name".to_string(), "World".to_string());

        let messages = template.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "Hello, World!");
    }

    #[test]
    fn test_system_message_template() {
        let template = SystemMessagePromptTemplate::from_template("You are {role}").unwrap();

        let mut kwargs = HashMap::new();
        kwargs.insert("role".to_string(), "an assistant".to_string());

        let messages = template.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "You are an assistant");
    }

    #[test]
    fn test_chat_prompt_template() {
        let template = ChatPromptTemplate::from_messages(&[
            ("system", "You are a helpful assistant."),
            ("human", "{question}"),
        ])
        .unwrap();

        assert_eq!(template.input_variables(), &["question"]);

        let mut kwargs = HashMap::new();
        kwargs.insert("question".to_string(), "Hello!".to_string());

        let messages = template.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "You are a helpful assistant.");
        assert_eq!(messages[1].content(), "Hello!");
    }

    #[test]
    fn test_chat_prompt_template_from_template() {
        let template = ChatPromptTemplate::from_template("Hello, {name}!").unwrap();

        let mut kwargs = HashMap::new();
        kwargs.insert("name".to_string(), "World".to_string());

        let messages = template.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "Hello, World!");
    }

    #[test]
    fn test_chat_prompt_add() {
        let template1 =
            ChatPromptTemplate::from_messages(&[("system", "You are a helpful assistant.")])
                .unwrap();

        let template2 = ChatPromptTemplate::from_messages(&[("human", "{question}")]).unwrap();

        let combined = template1 + template2;

        let mut kwargs = HashMap::new();
        kwargs.insert("question".to_string(), "Hello!".to_string());

        let messages = combined.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_partial() {
        let template = ChatPromptTemplate::from_messages(&[
            ("system", "You are {role}."),
            ("human", "{question}"),
        ])
        .unwrap();

        let mut partial_vars = HashMap::new();
        partial_vars.insert("role".to_string(), "an assistant".to_string());

        let partial = template.partial(partial_vars);
        assert_eq!(partial.input_variables(), &["question"]);

        let mut kwargs = HashMap::new();
        kwargs.insert("question".to_string(), "Hello!".to_string());

        let messages = partial.format_messages(&kwargs).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "You are an assistant.");
    }

    /// `len`, `is_empty` and `get` report on message templates, not on
    /// rendered messages.
    #[test]
    fn test_len_is_empty_and_get() {
        let template = ChatPromptTemplate::from_messages(&[
            ("system", "You are a helpful assistant."),
            ("human", "{question}"),
        ])
        .unwrap();

        assert_eq!(template.len(), 2);
        assert!(!template.is_empty());
        assert!(template.get(0).is_some());
        assert!(template.get(1).is_some());
        assert!(template.get(2).is_none());
    }

    /// A placeholder tuple must wrap its variable name in curly braces.
    #[test]
    fn test_placeholder_tuple_requires_braces() {
        let valid = ChatPromptTemplate::from_messages(&[("placeholder", "{history}")]);
        assert!(valid.is_ok());
        assert_eq!(valid.unwrap().len(), 1);

        assert!(ChatPromptTemplate::from_messages(&[("placeholder", "history")]).is_err());
        assert!(ChatPromptTemplate::from_messages(&[("placeholder", "{history")]).is_err());
    }

    /// Roles outside the supported set are rejected rather than silently accepted.
    #[test]
    fn test_unknown_message_type_is_rejected() {
        assert!(ChatPromptTemplate::from_messages(&[("robot", "beep")]).is_err());
    }
}