Skip to main content

synaptic_prompts/
chat_template.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use synaptic_core::{Message, RunnableConfig, SynapticError};
6use synaptic_runnables::Runnable;
7
8use crate::PromptTemplate;
9
10/// A template component that produces one or more Messages.
11pub enum MessageTemplate {
12    /// Renders a system message from a template string.
13    System(PromptTemplate),
14    /// Renders a human message from a template string.
15    Human(PromptTemplate),
16    /// Renders an AI message from a template string.
17    AI(PromptTemplate),
18    /// Injects messages from the input map under the given key.
19    /// The value at that key must be a JSON array of Message objects.
20    Placeholder(String),
21}
22
23/// A chat prompt template that renders a sequence of messages.
24///
25/// ```ignore
26/// let prompt = ChatPromptTemplate::from_messages(vec![
27///     MessageTemplate::System(PromptTemplate::new("You are a helpful assistant.")),
28///     MessageTemplate::Placeholder("history".to_string()),
29///     MessageTemplate::Human(PromptTemplate::new("{{ input }}")),
30/// ]);
31/// ```
32pub struct ChatPromptTemplate {
33    templates: Vec<MessageTemplate>,
34}
35
36impl ChatPromptTemplate {
37    pub fn new(templates: Vec<MessageTemplate>) -> Self {
38        Self { templates }
39    }
40
41    /// Alias for `new`, matching LangChain's factory method name.
42    pub fn from_messages(templates: Vec<MessageTemplate>) -> Self {
43        Self::new(templates)
44    }
45
46    /// Render the templates against the given variables, producing a list of messages.
47    pub fn format(&self, values: &HashMap<String, Value>) -> Result<Vec<Message>, SynapticError> {
48        let mut messages = Vec::new();
49
50        // Build a string map for PromptTemplate rendering
51        let string_values: HashMap<String, String> = values
52            .iter()
53            .filter_map(|(k, v)| {
54                if let Value::String(s) = v {
55                    Some((k.clone(), s.clone()))
56                } else {
57                    None
58                }
59            })
60            .collect();
61
62        for template in &self.templates {
63            match template {
64                MessageTemplate::System(pt) => {
65                    let content = pt
66                        .render(&string_values)
67                        .map_err(|e| SynapticError::Prompt(e.to_string()))?;
68                    messages.push(Message::system(content));
69                }
70                MessageTemplate::Human(pt) => {
71                    let content = pt
72                        .render(&string_values)
73                        .map_err(|e| SynapticError::Prompt(e.to_string()))?;
74                    messages.push(Message::human(content));
75                }
76                MessageTemplate::AI(pt) => {
77                    let content = pt
78                        .render(&string_values)
79                        .map_err(|e| SynapticError::Prompt(e.to_string()))?;
80                    messages.push(Message::ai(content));
81                }
82                MessageTemplate::Placeholder(key) => {
83                    let value = values.get(key).ok_or_else(|| {
84                        SynapticError::Prompt(format!("missing placeholder: {key}"))
85                    })?;
86                    let msgs: Vec<Message> =
87                        serde_json::from_value(value.clone()).map_err(|e| {
88                            SynapticError::Prompt(format!(
89                                "invalid messages for placeholder '{key}': {e}"
90                            ))
91                        })?;
92                    messages.extend(msgs);
93                }
94            }
95        }
96
97        Ok(messages)
98    }
99}
100
101#[async_trait]
102impl Runnable<HashMap<String, Value>, Vec<Message>> for ChatPromptTemplate {
103    async fn invoke(
104        &self,
105        input: HashMap<String, Value>,
106        _config: &RunnableConfig,
107    ) -> Result<Vec<Message>, SynapticError> {
108        self.format(&input)
109    }
110}