Skip to main content

cognis_core/prompts/
chat.rs

1//! `ChatPromptTemplate` — typed templating that produces `Vec<Message>`.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7use serde_json::Value;
8
9use crate::content::ContentPart;
10use crate::message::Message;
11use crate::prompts::template::{render, scan_variables};
12use crate::runnable::{Runnable, RunnableConfig};
13use crate::{CognisError, Result};
14
15/// One element in a `ChatPromptTemplate`.
16#[derive(Debug, Clone)]
17enum Part {
18    /// A templated message at a fixed role.
19    Templated { role: Role, template: String },
20    /// A templated message that also carries multimodal parts. The text
21    /// template is rendered against the input; the parts are passed
22    /// through as-is.
23    Multimodal {
24        role: Role,
25        template: String,
26        parts: Vec<ContentPart>,
27    },
28    /// Drop in a `Vec<Message>` from a named field of the input.
29    Placeholder { key: String, optional: bool },
30}
31
/// Speaker role attached to a templated message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// Produces a `Message::System`.
    System,
    /// Produces a `Message::Human`.
    Human,
    /// Produces a `Message::Ai`.
    Ai,
}
42
43/// A typed chat prompt — renders to `Vec<Message>` when invoked.
44///
45/// Build via the fluent API:
46///
47/// ```no_run
48/// use cognis_core::prompts::ChatPromptTemplate;
49/// use serde::Serialize;
50///
51/// #[derive(Serialize)]
52/// struct In { name: String }
53///
54/// let prompt: ChatPromptTemplate<In> = ChatPromptTemplate::new()
55///     .system("You are a helpful assistant.")
56///     .placeholder("history")
57///     .human("Hello, my name is {name}.");
58/// ```
59///
60/// Placeholders pull a `Vec<Message>` from a named field of the input
61/// (the field must serialize to a JSON array of `Message` objects).
62#[derive(Debug, Clone)]
63pub struct ChatPromptTemplate<I = Value> {
64    parts: Vec<Part>,
65    _input: PhantomData<fn() -> I>,
66}
67
68impl<I> Default for ChatPromptTemplate<I> {
69    fn default() -> Self {
70        Self {
71            parts: Vec::new(),
72            _input: PhantomData,
73        }
74    }
75}
76
77impl<I> ChatPromptTemplate<I>
78where
79    I: Serialize + Send + Sync + 'static,
80{
81    /// Empty builder.
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    /// Append a system-role templated message.
87    pub fn system(mut self, template: impl Into<String>) -> Self {
88        self.parts.push(Part::Templated {
89            role: Role::System,
90            template: template.into(),
91        });
92        self
93    }
94
95    /// Append a human-role templated message.
96    pub fn human(mut self, template: impl Into<String>) -> Self {
97        self.parts.push(Part::Templated {
98            role: Role::Human,
99            template: template.into(),
100        });
101        self
102    }
103
104    /// Append an AI-role templated message.
105    pub fn ai(mut self, template: impl Into<String>) -> Self {
106        self.parts.push(Part::Templated {
107            role: Role::Ai,
108            template: template.into(),
109        });
110        self
111    }
112
113    /// Append a human-role multimodal message: a text template plus a
114    /// pre-built list of [`ContentPart`]s (images, audio). The text
115    /// template still receives template-variable substitution against
116    /// the call's input.
117    pub fn human_with_parts(
118        mut self,
119        template: impl Into<String>,
120        parts: Vec<ContentPart>,
121    ) -> Self {
122        self.parts.push(Part::Multimodal {
123            role: Role::Human,
124            template: template.into(),
125            parts,
126        });
127        self
128    }
129
130    /// Append an AI-role multimodal message.
131    pub fn ai_with_parts(mut self, template: impl Into<String>, parts: Vec<ContentPart>) -> Self {
132        self.parts.push(Part::Multimodal {
133            role: Role::Ai,
134            template: template.into(),
135            parts,
136        });
137        self
138    }
139
140    /// Convenience: append a human-role message with one image URL part.
141    pub fn human_with_image_url(
142        self,
143        template: impl Into<String>,
144        url: impl Into<String>,
145        mime: impl Into<String>,
146    ) -> Self {
147        self.human_with_parts(
148            template,
149            vec![ContentPart::Image {
150                source: crate::content::ImageSource::url(url),
151                mime: mime.into(),
152            }],
153        )
154    }
155
156    /// Append a `Vec<Message>` field from the input.
157    ///
158    /// Errors at render time if the field is missing.
159    pub fn placeholder(mut self, key: impl Into<String>) -> Self {
160        self.parts.push(Part::Placeholder {
161            key: key.into(),
162            optional: false,
163        });
164        self
165    }
166
167    /// Like [`placeholder`](Self::placeholder), but a missing key resolves
168    /// to an empty list instead of an error.
169    pub fn optional_placeholder(mut self, key: impl Into<String>) -> Self {
170        self.parts.push(Part::Placeholder {
171            key: key.into(),
172            optional: true,
173        });
174        self
175    }
176
177    /// Build from a list of `(role, template)` tuples.
178    pub fn from_messages(messages: Vec<(Role, String)>) -> Self {
179        let parts = messages
180            .into_iter()
181            .map(|(role, template)| Part::Templated { role, template })
182            .collect();
183        Self {
184            parts,
185            _input: PhantomData,
186        }
187    }
188
189    /// All `{name}` placeholders across templated parts (excludes
190    /// placeholder keys, which are field names not template variables).
191    pub fn input_variables(&self) -> Vec<String> {
192        let mut out = Vec::new();
193        for p in &self.parts {
194            let template = match p {
195                Part::Templated { template, .. } | Part::Multimodal { template, .. } => template,
196                Part::Placeholder { .. } => continue,
197            };
198            for v in scan_variables(template) {
199                if !out.contains(&v) {
200                    out.push(v);
201                }
202            }
203        }
204        out
205    }
206
207    /// Render to `Vec<Message>`.
208    pub fn render(&self, input: &I) -> Result<Vec<Message>> {
209        let ctx =
210            serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
211        let mut out = Vec::with_capacity(self.parts.len());
212        for part in &self.parts {
213            match part {
214                Part::Templated { role, template } => {
215                    let text = render(template, &ctx)?;
216                    out.push(make_message(*role, text));
217                }
218                Part::Multimodal {
219                    role,
220                    template,
221                    parts,
222                } => {
223                    let text = render(template, &ctx)?;
224                    out.push(make_multimodal_message(*role, text, parts.clone()));
225                }
226                Part::Placeholder { key, optional } => {
227                    out.extend(pull_messages(&ctx, key, *optional)?);
228                }
229            }
230        }
231        Ok(out)
232    }
233}
234
235#[async_trait]
236impl<I> Runnable<I, Vec<Message>> for ChatPromptTemplate<I>
237where
238    I: Serialize + Send + Sync + 'static,
239{
240    async fn invoke(&self, input: I, _: RunnableConfig) -> Result<Vec<Message>> {
241        self.render(&input)
242    }
243    fn name(&self) -> &str {
244        "ChatPromptTemplate"
245    }
246}
247
248fn make_message(role: Role, text: String) -> Message {
249    match role {
250        Role::System => Message::system(text),
251        Role::Human => Message::human(text),
252        Role::Ai => Message::ai(text),
253    }
254}
255
256fn make_multimodal_message(role: Role, text: String, parts: Vec<ContentPart>) -> Message {
257    match role {
258        // System messages are text-only; if the caller asks for a
259        // multimodal system message we drop the parts and warn via tracing.
260        Role::System => {
261            if !parts.is_empty() {
262                tracing::warn!(
263                    "ChatPromptTemplate: system role doesn't support multimodal parts; dropping"
264                );
265            }
266            Message::system(text)
267        }
268        Role::Human => Message::human_with_parts(text, parts),
269        Role::Ai => Message::ai_with_parts(text, parts),
270    }
271}
272
273fn pull_messages(ctx: &Value, key: &str, optional: bool) -> Result<Vec<Message>> {
274    let v = match ctx.get(key) {
275        Some(v) => v,
276        None => {
277            return if optional {
278                Ok(Vec::new())
279            } else {
280                Err(CognisError::Configuration(format!(
281                    "missing required placeholder field `{key}`"
282                )))
283            };
284        }
285    };
286    serde_json::from_value::<Vec<Message>>(v.clone()).map_err(|e| {
287        CognisError::Serialization(format!(
288            "placeholder `{key}` did not deserialize as Vec<Message>: {e}"
289        ))
290    })
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh `Value`-input template; every builder-style test starts here.
    fn value_prompt() -> ChatPromptTemplate<Value> {
        ChatPromptTemplate::new()
    }

    #[tokio::test]
    async fn renders_simple_chat() {
        let prompt = value_prompt().system("you are {role}").human("hi {name}");
        let input = json!({ "role": "helpful", "name": "ada" });
        let msgs = prompt
            .invoke(input, RunnableConfig::default())
            .await
            .expect("invoke should render");

        assert_eq!(msgs.len(), 2);
        assert!(matches!(msgs[0], Message::System(_)));
        assert_eq!(msgs[0].content(), "you are helpful");
        assert!(matches!(msgs[1], Message::Human(_)));
        assert_eq!(msgs[1].content(), "hi ada");
    }

    #[test]
    fn placeholder_drops_in_messages() {
        let prompt = value_prompt()
            .system("sys")
            .placeholder("history")
            .human("now");
        let input = json!({
            "history": [
                { "role": "human", "content": "before-1" },
                { "role": "ai", "content": "before-2" }
            ]
        });
        let msgs = prompt.render(&input).expect("render should succeed");
        assert_eq!(msgs.len(), 4);

        let tail: Vec<_> = msgs.iter().skip(1).map(|m| m.content()).collect();
        assert_eq!(tail, ["before-1", "before-2", "now"]);
    }

    #[test]
    fn missing_required_placeholder_errors() {
        let err = value_prompt()
            .placeholder("history")
            .render(&json!({}))
            .unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[test]
    fn optional_placeholder_accepts_missing() {
        let msgs = value_prompt()
            .system("hi")
            .optional_placeholder("history")
            .render(&json!({}))
            .unwrap();
        assert_eq!(msgs.len(), 1);
    }

    #[test]
    fn input_variables_collects_unique() {
        let prompt = value_prompt().system("{a} {b}").human("{a} {c}");
        assert_eq!(prompt.input_variables(), ["a", "b", "c"]);
    }

    #[test]
    fn from_messages_constructs_fluently() {
        let spec = vec![
            (Role::System, String::from("sys")),
            (Role::Human, String::from("hi {name}")),
        ];
        let prompt: ChatPromptTemplate<Value> = ChatPromptTemplate::from_messages(spec);
        let msgs = prompt.render(&json!({ "name": "ada" })).unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[1].content(), "hi ada");
    }

    #[test]
    fn human_with_image_url_renders_with_part() {
        let prompt = value_prompt()
            .system("describe images")
            .human_with_image_url("describe {topic}", "https://x/cat.jpg", "image/jpeg");
        let msgs = prompt.render(&json!({ "topic": "this cat" })).unwrap();
        assert_eq!(msgs.len(), 2);

        let image_msg = &msgs[1];
        assert_eq!(image_msg.content(), "describe this cat");
        let parts = image_msg.parts();
        assert_eq!(parts.len(), 1);
        assert!(matches!(
            parts[0],
            crate::content::ContentPart::Image { .. }
        ));
    }

    #[test]
    fn input_variables_includes_multimodal_template_vars() {
        let prompt = value_prompt()
            .human("text {a}")
            .human_with_image_url("multimodal {b}", "https://x", "image/png");
        let mut vars = prompt.input_variables();
        vars.sort_unstable();
        assert_eq!(vars, ["a", "b"]);
    }
}
389        let mut vars = p.input_variables();
390        vars.sort();
391        assert_eq!(vars, vec!["a", "b"]);
392    }
393}