1use std::collections::HashMap;
7use std::path::Path;
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12use crate::messages::{AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage};
13use crate::utils::input::get_colored_text;
14use crate::utils::interactive_env::is_interactive_env;
15
16use super::message::{BaseMessagePromptTemplate, get_msg_title_repr};
17use super::prompt::PromptTemplate;
18use super::string::{PromptTemplateFormat, StringPromptTemplate};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct MessagesPlaceholder {
36 pub variable_name: String,
38
39 #[serde(default)]
43 pub optional: bool,
44
45 #[serde(default)]
47 pub n_messages: Option<usize>,
48}
49
50impl MessagesPlaceholder {
51 pub fn new(variable_name: impl Into<String>) -> Self {
57 Self {
58 variable_name: variable_name.into(),
59 optional: false,
60 n_messages: None,
61 }
62 }
63
64 pub fn optional(mut self, optional: bool) -> Self {
66 self.optional = optional;
67 self
68 }
69
70 pub fn n_messages(mut self, n: usize) -> Self {
72 self.n_messages = Some(n);
73 self
74 }
75
76 pub fn format_with_messages(
86 &self,
87 messages: Option<Vec<BaseMessage>>,
88 ) -> Result<Vec<BaseMessage>> {
89 let value = if self.optional {
90 messages.unwrap_or_default()
91 } else {
92 messages.ok_or_else(|| {
93 Error::InvalidConfig(format!(
94 "Variable '{}' is required but was not provided",
95 self.variable_name
96 ))
97 })?
98 };
99
100 let result = if let Some(n) = self.n_messages {
101 let len = value.len();
102 if len > n {
103 value.into_iter().skip(len - n).collect()
104 } else {
105 value
106 }
107 } else {
108 value
109 };
110
111 Ok(result)
112 }
113}
114
115impl BaseMessagePromptTemplate for MessagesPlaceholder {
116 fn input_variables(&self) -> Vec<String> {
117 if self.optional {
118 Vec::new()
119 } else {
120 vec![self.variable_name.clone()]
121 }
122 }
123
124 fn format_messages(&self, _kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
125 if self.optional {
130 Ok(Vec::new())
131 } else {
132 Err(Error::InvalidConfig(format!(
133 "MessagesPlaceholder '{}' requires messages to be passed via format_with_messages",
134 self.variable_name
135 )))
136 }
137 }
138
139 fn pretty_repr(&self, html: bool) -> String {
140 let var = format!("{{{}}}", self.variable_name);
141 let title = get_msg_title_repr("Messages Placeholder", html);
142 let var_display = if html {
143 get_colored_text(&var, "yellow")
144 } else {
145 var
146 };
147 format!("{}\n\n{}", title, var_display)
148 }
149}
150
151pub trait BaseStringMessagePromptTemplate: BaseMessagePromptTemplate {
153 fn prompt(&self) -> &PromptTemplate;
155
156 fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
158 static EMPTY: std::sync::LazyLock<HashMap<String, serde_json::Value>> =
159 std::sync::LazyLock::new(HashMap::new);
160 &EMPTY
161 }
162
163 fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage>;
165
166 fn aformat(
168 &self,
169 kwargs: &HashMap<String, String>,
170 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<BaseMessage>> + Send + '_>> {
171 let result = self.format(kwargs);
172 Box::pin(async move { result })
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct ChatMessagePromptTemplate {
179 pub prompt: PromptTemplate,
181
182 pub role: String,
184
185 #[serde(default)]
187 pub additional_kwargs: HashMap<String, serde_json::Value>,
188}
189
190impl ChatMessagePromptTemplate {
191 pub fn new(prompt: PromptTemplate, role: impl Into<String>) -> Self {
193 Self {
194 prompt,
195 role: role.into(),
196 additional_kwargs: HashMap::new(),
197 }
198 }
199
200 pub fn from_template(
202 template: impl Into<String>,
203 role: impl Into<String>,
204 template_format: PromptTemplateFormat,
205 ) -> Result<Self> {
206 let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
207 Ok(Self::new(prompt, role))
208 }
209}
210
211impl BaseMessagePromptTemplate for ChatMessagePromptTemplate {
212 fn input_variables(&self) -> Vec<String> {
213 self.prompt.input_variables.clone()
214 }
215
216 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
217 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
218 Ok(vec![BaseMessage::Chat(ChatMessage::new(&self.role, text))])
219 }
220
221 fn pretty_repr(&self, html: bool) -> String {
222 let title = format!("{} Message", self.role);
223 let title = get_msg_title_repr(&title, html);
224 format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
225 }
226}
227
228impl BaseStringMessagePromptTemplate for ChatMessagePromptTemplate {
229 fn prompt(&self) -> &PromptTemplate {
230 &self.prompt
231 }
232
233 fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
234 &self.additional_kwargs
235 }
236
237 fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
238 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
239 Ok(BaseMessage::Chat(ChatMessage::new(&self.role, text)))
240 }
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct HumanMessagePromptTemplate {
246 pub prompt: PromptTemplate,
248
249 #[serde(default)]
251 pub additional_kwargs: HashMap<String, serde_json::Value>,
252}
253
254impl HumanMessagePromptTemplate {
255 pub fn new(prompt: PromptTemplate) -> Self {
257 Self {
258 prompt,
259 additional_kwargs: HashMap::new(),
260 }
261 }
262
263 pub fn from_template(template: impl Into<String>) -> Result<Self> {
265 Self::from_template_with_format(template, PromptTemplateFormat::FString)
266 }
267
268 pub fn from_template_with_format(
270 template: impl Into<String>,
271 template_format: PromptTemplateFormat,
272 ) -> Result<Self> {
273 let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
274 Ok(Self::new(prompt))
275 }
276
277 pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
279 let prompt = PromptTemplate::from_file(template_file)?;
280 Ok(Self::new(prompt))
281 }
282}
283
284impl BaseMessagePromptTemplate for HumanMessagePromptTemplate {
285 fn input_variables(&self) -> Vec<String> {
286 self.prompt.input_variables.clone()
287 }
288
289 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
290 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
291 Ok(vec![BaseMessage::Human(HumanMessage::new(text))])
292 }
293
294 fn pretty_repr(&self, html: bool) -> String {
295 let title = get_msg_title_repr("Human Message", html);
296 format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
297 }
298}
299
300impl BaseStringMessagePromptTemplate for HumanMessagePromptTemplate {
301 fn prompt(&self) -> &PromptTemplate {
302 &self.prompt
303 }
304
305 fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
306 &self.additional_kwargs
307 }
308
309 fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
310 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
311 Ok(BaseMessage::Human(HumanMessage::new(text)))
312 }
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct AIMessagePromptTemplate {
318 pub prompt: PromptTemplate,
320
321 #[serde(default)]
323 pub additional_kwargs: HashMap<String, serde_json::Value>,
324}
325
326impl AIMessagePromptTemplate {
327 pub fn new(prompt: PromptTemplate) -> Self {
329 Self {
330 prompt,
331 additional_kwargs: HashMap::new(),
332 }
333 }
334
335 pub fn from_template(template: impl Into<String>) -> Result<Self> {
337 Self::from_template_with_format(template, PromptTemplateFormat::FString)
338 }
339
340 pub fn from_template_with_format(
342 template: impl Into<String>,
343 template_format: PromptTemplateFormat,
344 ) -> Result<Self> {
345 let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
346 Ok(Self::new(prompt))
347 }
348
349 pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
351 let prompt = PromptTemplate::from_file(template_file)?;
352 Ok(Self::new(prompt))
353 }
354}
355
356impl BaseMessagePromptTemplate for AIMessagePromptTemplate {
357 fn input_variables(&self) -> Vec<String> {
358 self.prompt.input_variables.clone()
359 }
360
361 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
362 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
363 Ok(vec![BaseMessage::AI(AIMessage::new(text))])
364 }
365
366 fn pretty_repr(&self, html: bool) -> String {
367 let title = get_msg_title_repr("AI Message", html);
368 format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
369 }
370}
371
372impl BaseStringMessagePromptTemplate for AIMessagePromptTemplate {
373 fn prompt(&self) -> &PromptTemplate {
374 &self.prompt
375 }
376
377 fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
378 &self.additional_kwargs
379 }
380
381 fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
382 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
383 Ok(BaseMessage::AI(AIMessage::new(text)))
384 }
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct SystemMessagePromptTemplate {
390 pub prompt: PromptTemplate,
392
393 #[serde(default)]
395 pub additional_kwargs: HashMap<String, serde_json::Value>,
396}
397
398impl SystemMessagePromptTemplate {
399 pub fn new(prompt: PromptTemplate) -> Self {
401 Self {
402 prompt,
403 additional_kwargs: HashMap::new(),
404 }
405 }
406
407 pub fn from_template(template: impl Into<String>) -> Result<Self> {
409 Self::from_template_with_format(template, PromptTemplateFormat::FString)
410 }
411
412 pub fn from_template_with_format(
414 template: impl Into<String>,
415 template_format: PromptTemplateFormat,
416 ) -> Result<Self> {
417 let prompt = PromptTemplate::from_template_with_format(template, template_format)?;
418 Ok(Self::new(prompt))
419 }
420
421 pub fn from_template_file(template_file: impl AsRef<Path>) -> Result<Self> {
423 let prompt = PromptTemplate::from_file(template_file)?;
424 Ok(Self::new(prompt))
425 }
426}
427
428impl BaseMessagePromptTemplate for SystemMessagePromptTemplate {
429 fn input_variables(&self) -> Vec<String> {
430 self.prompt.input_variables.clone()
431 }
432
433 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
434 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
435 Ok(vec![BaseMessage::System(SystemMessage::new(text))])
436 }
437
438 fn pretty_repr(&self, html: bool) -> String {
439 let title = get_msg_title_repr("System Message", html);
440 format!("{}\n\n{}", title, self.prompt.pretty_repr(html))
441 }
442}
443
444impl BaseStringMessagePromptTemplate for SystemMessagePromptTemplate {
445 fn prompt(&self) -> &PromptTemplate {
446 &self.prompt
447 }
448
449 fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
450 &self.additional_kwargs
451 }
452
453 fn format(&self, kwargs: &HashMap<String, String>) -> Result<BaseMessage> {
454 let text = StringPromptTemplate::format(&self.prompt, kwargs)?;
455 Ok(BaseMessage::System(SystemMessage::new(text)))
456 }
457}
458
459#[derive(Clone)]
461pub enum MessageLike {
462 Message(Box<BaseMessage>),
464 Template(Box<dyn MessageLikeClone + Send + Sync>),
466 Placeholder(MessagesPlaceholder),
468}
469
470impl std::fmt::Debug for MessageLike {
471 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 match self {
473 MessageLike::Message(m) => f.debug_tuple("Message").field(m).finish(),
474 MessageLike::Template(_) => f.debug_tuple("Template").field(&"<template>").finish(),
475 MessageLike::Placeholder(p) => f.debug_tuple("Placeholder").field(p).finish(),
476 }
477 }
478}
479
480pub trait MessageLikeClone: BaseMessagePromptTemplate {
482 fn clone_box(&self) -> Box<dyn MessageLikeClone + Send + Sync>;
483}
484
485impl<T> MessageLikeClone for T
486where
487 T: BaseMessagePromptTemplate + Clone + Send + Sync + 'static,
488{
489 fn clone_box(&self) -> Box<dyn MessageLikeClone + Send + Sync> {
490 Box::new(self.clone())
491 }
492}
493
494impl Clone for Box<dyn MessageLikeClone + Send + Sync> {
495 fn clone(&self) -> Self {
496 self.clone_box()
497 }
498}
499
500#[derive(Debug, Clone)]
502pub enum MessageLikeRepresentation {
503 Tuple(String, String),
505 String(String),
507 Message(Box<BaseMessage>),
509 Placeholder {
511 variable_name: String,
512 optional: bool,
513 },
514}
515
516impl MessageLikeRepresentation {
517 pub fn tuple(role: impl Into<String>, content: impl Into<String>) -> Self {
519 Self::Tuple(role.into(), content.into())
520 }
521
522 pub fn string(content: impl Into<String>) -> Self {
524 Self::String(content.into())
525 }
526
527 pub fn placeholder(variable_name: impl Into<String>, optional: bool) -> Self {
529 Self::Placeholder {
530 variable_name: variable_name.into(),
531 optional,
532 }
533 }
534}
535
536pub trait BaseChatPromptTemplate: Send + Sync {
538 fn input_variables(&self) -> &[String];
540
541 fn optional_variables(&self) -> &[String] {
543 &[]
544 }
545
546 fn partial_variables(&self) -> &HashMap<String, String> {
548 static EMPTY: std::sync::LazyLock<HashMap<String, String>> =
549 std::sync::LazyLock::new(HashMap::new);
550 &EMPTY
551 }
552
553 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>>;
555
556 fn aformat_messages(
558 &self,
559 kwargs: &HashMap<String, String>,
560 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<BaseMessage>>> + Send + '_>>
561 {
562 let result = self.format_messages(kwargs);
563 Box::pin(async move { result })
564 }
565
566 fn format(&self, kwargs: &HashMap<String, String>) -> Result<String> {
568 let messages = self.format_messages(kwargs)?;
569 Ok(messages
570 .iter()
571 .map(|m| format!("{}: {}", m.message_type(), m.content()))
572 .collect::<Vec<_>>()
573 .join("\n"))
574 }
575
576 fn pretty_repr(&self, html: bool) -> String;
578
579 fn pretty_print(&self) {
581 println!("{}", self.pretty_repr(is_interactive_env()));
582 }
583}
584
585#[derive(Debug, Clone, Default)]
607pub struct ChatPromptTemplate {
608 messages: Vec<ChatPromptMessage>,
610
611 input_variables: Vec<String>,
613
614 optional_variables: Vec<String>,
616
617 partial_variables: HashMap<String, String>,
619
620 validate_template: bool,
622
623 template_format: PromptTemplateFormat,
625}
626
627#[derive(Debug, Clone)]
629pub enum ChatPromptMessage {
630 Message(BaseMessage),
632 Human(HumanMessagePromptTemplate),
634 AI(AIMessagePromptTemplate),
636 System(SystemMessagePromptTemplate),
638 Chat(ChatMessagePromptTemplate),
640 Placeholder(MessagesPlaceholder),
642}
643
644impl ChatPromptMessage {
645 fn input_variables(&self) -> Vec<String> {
647 match self {
648 ChatPromptMessage::Message(_) => Vec::new(),
649 ChatPromptMessage::Human(t) => t.input_variables(),
650 ChatPromptMessage::AI(t) => t.input_variables(),
651 ChatPromptMessage::System(t) => t.input_variables(),
652 ChatPromptMessage::Chat(t) => t.input_variables(),
653 ChatPromptMessage::Placeholder(p) => p.input_variables(),
654 }
655 }
656
657 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
659 match self {
660 ChatPromptMessage::Message(m) => Ok(vec![m.clone()]),
661 ChatPromptMessage::Human(t) => t.format_messages(kwargs),
662 ChatPromptMessage::AI(t) => t.format_messages(kwargs),
663 ChatPromptMessage::System(t) => t.format_messages(kwargs),
664 ChatPromptMessage::Chat(t) => t.format_messages(kwargs),
665 ChatPromptMessage::Placeholder(p) => p.format_messages(kwargs),
666 }
667 }
668
669 fn pretty_repr(&self, html: bool) -> String {
671 match self {
672 ChatPromptMessage::Message(m) => m.pretty_repr(html),
673 ChatPromptMessage::Human(t) => t.pretty_repr(html),
674 ChatPromptMessage::AI(t) => t.pretty_repr(html),
675 ChatPromptMessage::System(t) => t.pretty_repr(html),
676 ChatPromptMessage::Chat(t) => t.pretty_repr(html),
677 ChatPromptMessage::Placeholder(p) => p.pretty_repr(html),
678 }
679 }
680}
681
682impl ChatPromptTemplate {
683 pub fn new() -> Self {
685 Self::default()
686 }
687
688 pub fn from_messages(messages: &[(&str, &str)]) -> Result<Self> {
707 Self::from_messages_with_format(messages, PromptTemplateFormat::FString)
708 }
709
710 pub fn from_messages_with_format(
712 messages: &[(&str, &str)],
713 template_format: PromptTemplateFormat,
714 ) -> Result<Self> {
715 let mut template = Self::new();
716 template.template_format = template_format;
717
718 for (role, content) in messages {
719 let msg = create_template_from_message_type(role, content, template_format)?;
720 template.messages.push(msg);
721 }
722
723 let mut input_vars = std::collections::HashSet::new();
725 let mut optional_vars = std::collections::HashSet::new();
726
727 for msg in &template.messages {
728 match msg {
729 ChatPromptMessage::Placeholder(p) if p.optional => {
730 optional_vars.insert(p.variable_name.clone());
731 }
732 _ => {
733 for var in msg.input_variables() {
734 input_vars.insert(var);
735 }
736 }
737 }
738 }
739
740 template.input_variables = input_vars.into_iter().collect();
741 template.input_variables.sort();
742
743 template.optional_variables = optional_vars.into_iter().collect();
744 template.optional_variables.sort();
745
746 Ok(template)
747 }
748
749 pub fn from_template(template: &str) -> Result<Self> {
753 let prompt_template = PromptTemplate::from_template(template)?;
754 let message = HumanMessagePromptTemplate::new(prompt_template);
755
756 Ok(Self {
757 messages: vec![ChatPromptMessage::Human(message.clone())],
758 input_variables: message.input_variables(),
759 optional_variables: Vec::new(),
760 partial_variables: HashMap::new(),
761 validate_template: false,
762 template_format: PromptTemplateFormat::FString,
763 })
764 }
765
766 pub fn append(&mut self, message: ChatPromptMessage) {
768 for var in message.input_variables() {
769 if !self.input_variables.contains(&var) {
770 self.input_variables.push(var);
771 }
772 }
773 self.messages.push(message);
774 }
775
776 pub fn append_human(&mut self, template: &str) -> Result<()> {
778 let msg =
779 HumanMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
780 self.append(ChatPromptMessage::Human(msg));
781 Ok(())
782 }
783
784 pub fn append_ai(&mut self, template: &str) -> Result<()> {
786 let msg =
787 AIMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
788 self.append(ChatPromptMessage::AI(msg));
789 Ok(())
790 }
791
792 pub fn append_system(&mut self, template: &str) -> Result<()> {
794 let msg =
795 SystemMessagePromptTemplate::from_template_with_format(template, self.template_format)?;
796 self.append(ChatPromptMessage::System(msg));
797 Ok(())
798 }
799
800 pub fn append_placeholder(&mut self, variable_name: &str, optional: bool) {
802 let placeholder = MessagesPlaceholder::new(variable_name).optional(optional);
803 if !optional && !self.input_variables.contains(&variable_name.to_string()) {
804 self.input_variables.push(variable_name.to_string());
805 }
806 if optional {
807 self.optional_variables.push(variable_name.to_string());
808 }
809 self.messages
810 .push(ChatPromptMessage::Placeholder(placeholder));
811 }
812
813 pub fn partial(&self, kwargs: HashMap<String, String>) -> Self {
815 let new_vars: Vec<_> = self
816 .input_variables
817 .iter()
818 .filter(|v| !kwargs.contains_key(*v))
819 .cloned()
820 .collect();
821
822 let mut new_partials = self.partial_variables.clone();
823 new_partials.extend(kwargs);
824
825 Self {
826 messages: self.messages.clone(),
827 input_variables: new_vars,
828 optional_variables: self.optional_variables.clone(),
829 partial_variables: new_partials,
830 validate_template: self.validate_template,
831 template_format: self.template_format,
832 }
833 }
834
835 pub fn len(&self) -> usize {
837 self.messages.len()
838 }
839
840 pub fn is_empty(&self) -> bool {
842 self.messages.is_empty()
843 }
844
845 pub fn get(&self, index: usize) -> Option<&ChatPromptMessage> {
847 self.messages.get(index)
848 }
849
850 fn merge_partial_and_user_variables(
852 &self,
853 kwargs: &HashMap<String, String>,
854 ) -> HashMap<String, String> {
855 let mut merged = self.partial_variables.clone();
856 merged.extend(kwargs.clone());
857 merged
858 }
859}
860
861impl BaseChatPromptTemplate for ChatPromptTemplate {
862 fn input_variables(&self) -> &[String] {
863 &self.input_variables
864 }
865
866 fn optional_variables(&self) -> &[String] {
867 &self.optional_variables
868 }
869
870 fn partial_variables(&self) -> &HashMap<String, String> {
871 &self.partial_variables
872 }
873
874 fn format_messages(&self, kwargs: &HashMap<String, String>) -> Result<Vec<BaseMessage>> {
875 let merged = self.merge_partial_and_user_variables(kwargs);
876 let mut result = Vec::new();
877
878 for message in &self.messages {
879 let formatted = message.format_messages(&merged)?;
880 result.extend(formatted);
881 }
882
883 Ok(result)
884 }
885
886 fn pretty_repr(&self, html: bool) -> String {
887 self.messages
888 .iter()
889 .map(|m| m.pretty_repr(html))
890 .collect::<Vec<_>>()
891 .join("\n\n")
892 }
893}
894
895fn create_template_from_message_type(
897 message_type: &str,
898 template: &str,
899 template_format: PromptTemplateFormat,
900) -> Result<ChatPromptMessage> {
901 match message_type {
902 "human" | "user" => {
903 let t =
904 HumanMessagePromptTemplate::from_template_with_format(template, template_format)?;
905 Ok(ChatPromptMessage::Human(t))
906 }
907 "ai" | "assistant" => {
908 let t = AIMessagePromptTemplate::from_template_with_format(template, template_format)?;
909 Ok(ChatPromptMessage::AI(t))
910 }
911 "system" => {
912 let t =
913 SystemMessagePromptTemplate::from_template_with_format(template, template_format)?;
914 Ok(ChatPromptMessage::System(t))
915 }
916 "placeholder" => {
917 if !template.starts_with('{') || !template.ends_with('}') {
919 return Err(Error::InvalidConfig(format!(
920 "Invalid placeholder template: {}. Expected a variable name surrounded by curly braces.",
921 template
922 )));
923 }
924 let var_name = &template[1..template.len() - 1];
925 let placeholder = MessagesPlaceholder::new(var_name).optional(true);
926 Ok(ChatPromptMessage::Placeholder(placeholder))
927 }
928 _ => Err(Error::InvalidConfig(format!(
929 "Unexpected message type: {}. Use one of 'human', 'user', 'ai', 'assistant', 'system', or 'placeholder'.",
930 message_type
931 ))),
932 }
933}
934
935impl std::ops::Add for ChatPromptTemplate {
936 type Output = ChatPromptTemplate;
937
938 fn add(self, other: Self) -> Self::Output {
939 let mut messages = self.messages;
940 messages.extend(other.messages);
941
942 let mut input_vars: std::collections::HashSet<_> =
943 self.input_variables.into_iter().collect();
944 input_vars.extend(other.input_variables);
945
946 let mut partial_vars = self.partial_variables;
947 partial_vars.extend(other.partial_variables);
948
949 ChatPromptTemplate {
950 messages,
951 input_variables: input_vars.into_iter().collect(),
952 optional_variables: Vec::new(),
953 partial_variables: partial_vars,
954 validate_template: self.validate_template && other.validate_template,
955 template_format: self.template_format,
956 }
957 }
958}
959
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a kwargs map from borrowed key/value pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_messages_placeholder() {
        let required = MessagesPlaceholder::new("history");
        assert_eq!(required.input_variables(), vec!["history"]);

        let optional = MessagesPlaceholder::new("history").optional(true);
        assert!(optional.input_variables().is_empty());
    }

    #[test]
    fn test_human_message_template() {
        let template = HumanMessagePromptTemplate::from_template("Hello, {name}!").unwrap();
        let messages = template
            .format_messages(&vars(&[("name", "World")]))
            .unwrap();

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "Hello, World!");
    }

    #[test]
    fn test_system_message_template() {
        let template = SystemMessagePromptTemplate::from_template("You are {role}").unwrap();
        let messages = template
            .format_messages(&vars(&[("role", "an assistant")]))
            .unwrap();

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "You are an assistant");
    }

    #[test]
    fn test_chat_prompt_template() {
        let template = ChatPromptTemplate::from_messages(&[
            ("system", "You are a helpful assistant."),
            ("human", "{question}"),
        ])
        .unwrap();
        assert_eq!(template.input_variables(), &["question"]);

        let messages = template
            .format_messages(&vars(&[("question", "Hello!")]))
            .unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "You are a helpful assistant.");
        assert_eq!(messages[1].content(), "Hello!");
    }

    #[test]
    fn test_chat_prompt_template_from_template() {
        let template = ChatPromptTemplate::from_template("Hello, {name}!").unwrap();
        let messages = template
            .format_messages(&vars(&[("name", "World")]))
            .unwrap();

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "Hello, World!");
    }

    #[test]
    fn test_chat_prompt_add() {
        let system_only =
            ChatPromptTemplate::from_messages(&[("system", "You are a helpful assistant.")])
                .unwrap();
        let human_only = ChatPromptTemplate::from_messages(&[("human", "{question}")]).unwrap();

        let combined = system_only + human_only;
        let messages = combined
            .format_messages(&vars(&[("question", "Hello!")]))
            .unwrap();
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_partial() {
        let template = ChatPromptTemplate::from_messages(&[
            ("system", "You are {role}."),
            ("human", "{question}"),
        ])
        .unwrap();

        let partial = template.partial(vars(&[("role", "an assistant")]));
        assert_eq!(partial.input_variables(), &["question"]);

        let messages = partial
            .format_messages(&vars(&[("question", "Hello!")]))
            .unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "You are an assistant.");
    }
}