1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7use serde_json::Value;
8
9use crate::content::ContentPart;
10use crate::message::Message;
11use crate::prompts::template::{render, scan_variables};
12use crate::runnable::{Runnable, RunnableConfig};
13use crate::{CognisError, Result};
14
15#[derive(Debug, Clone)]
17enum Part {
18 Templated { role: Role, template: String },
20 Multimodal {
24 role: Role,
25 template: String,
26 parts: Vec<ContentPart>,
27 },
28 Placeholder { key: String, optional: bool },
30}
31
/// Speaker role assigned to a templated message.
///
/// Derives `Hash` (in addition to `Eq`) so roles can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// Rendered with `Message::system`.
    System,
    /// Rendered with `Message::human`.
    Human,
    /// Rendered with `Message::ai`.
    Ai,
}
42
43#[derive(Debug, Clone)]
63pub struct ChatPromptTemplate<I = Value> {
64 parts: Vec<Part>,
65 _input: PhantomData<fn() -> I>,
66}
67
68impl<I> Default for ChatPromptTemplate<I> {
69 fn default() -> Self {
70 Self {
71 parts: Vec::new(),
72 _input: PhantomData,
73 }
74 }
75}
76
77impl<I> ChatPromptTemplate<I>
78where
79 I: Serialize + Send + Sync + 'static,
80{
81 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn system(mut self, template: impl Into<String>) -> Self {
88 self.parts.push(Part::Templated {
89 role: Role::System,
90 template: template.into(),
91 });
92 self
93 }
94
95 pub fn human(mut self, template: impl Into<String>) -> Self {
97 self.parts.push(Part::Templated {
98 role: Role::Human,
99 template: template.into(),
100 });
101 self
102 }
103
104 pub fn ai(mut self, template: impl Into<String>) -> Self {
106 self.parts.push(Part::Templated {
107 role: Role::Ai,
108 template: template.into(),
109 });
110 self
111 }
112
113 pub fn human_with_parts(
118 mut self,
119 template: impl Into<String>,
120 parts: Vec<ContentPart>,
121 ) -> Self {
122 self.parts.push(Part::Multimodal {
123 role: Role::Human,
124 template: template.into(),
125 parts,
126 });
127 self
128 }
129
130 pub fn ai_with_parts(mut self, template: impl Into<String>, parts: Vec<ContentPart>) -> Self {
132 self.parts.push(Part::Multimodal {
133 role: Role::Ai,
134 template: template.into(),
135 parts,
136 });
137 self
138 }
139
140 pub fn human_with_image_url(
142 self,
143 template: impl Into<String>,
144 url: impl Into<String>,
145 mime: impl Into<String>,
146 ) -> Self {
147 self.human_with_parts(
148 template,
149 vec![ContentPart::Image {
150 source: crate::content::ImageSource::url(url),
151 mime: mime.into(),
152 }],
153 )
154 }
155
156 pub fn placeholder(mut self, key: impl Into<String>) -> Self {
160 self.parts.push(Part::Placeholder {
161 key: key.into(),
162 optional: false,
163 });
164 self
165 }
166
167 pub fn optional_placeholder(mut self, key: impl Into<String>) -> Self {
170 self.parts.push(Part::Placeholder {
171 key: key.into(),
172 optional: true,
173 });
174 self
175 }
176
177 pub fn from_messages(messages: Vec<(Role, String)>) -> Self {
179 let parts = messages
180 .into_iter()
181 .map(|(role, template)| Part::Templated { role, template })
182 .collect();
183 Self {
184 parts,
185 _input: PhantomData,
186 }
187 }
188
189 pub fn input_variables(&self) -> Vec<String> {
192 let mut out = Vec::new();
193 for p in &self.parts {
194 let template = match p {
195 Part::Templated { template, .. } | Part::Multimodal { template, .. } => template,
196 Part::Placeholder { .. } => continue,
197 };
198 for v in scan_variables(template) {
199 if !out.contains(&v) {
200 out.push(v);
201 }
202 }
203 }
204 out
205 }
206
207 pub fn render(&self, input: &I) -> Result<Vec<Message>> {
209 let ctx =
210 serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
211 let mut out = Vec::with_capacity(self.parts.len());
212 for part in &self.parts {
213 match part {
214 Part::Templated { role, template } => {
215 let text = render(template, &ctx)?;
216 out.push(make_message(*role, text));
217 }
218 Part::Multimodal {
219 role,
220 template,
221 parts,
222 } => {
223 let text = render(template, &ctx)?;
224 out.push(make_multimodal_message(*role, text, parts.clone()));
225 }
226 Part::Placeholder { key, optional } => {
227 out.extend(pull_messages(&ctx, key, *optional)?);
228 }
229 }
230 }
231 Ok(out)
232 }
233}
234
235#[async_trait]
236impl<I> Runnable<I, Vec<Message>> for ChatPromptTemplate<I>
237where
238 I: Serialize + Send + Sync + 'static,
239{
240 async fn invoke(&self, input: I, _: RunnableConfig) -> Result<Vec<Message>> {
241 self.render(&input)
242 }
243 fn name(&self) -> &str {
244 "ChatPromptTemplate"
245 }
246}
247
248fn make_message(role: Role, text: String) -> Message {
249 match role {
250 Role::System => Message::system(text),
251 Role::Human => Message::human(text),
252 Role::Ai => Message::ai(text),
253 }
254}
255
256fn make_multimodal_message(role: Role, text: String, parts: Vec<ContentPart>) -> Message {
257 match role {
258 Role::System => {
261 if !parts.is_empty() {
262 tracing::warn!(
263 "ChatPromptTemplate: system role doesn't support multimodal parts; dropping"
264 );
265 }
266 Message::system(text)
267 }
268 Role::Human => Message::human_with_parts(text, parts),
269 Role::Ai => Message::ai_with_parts(text, parts),
270 }
271}
272
273fn pull_messages(ctx: &Value, key: &str, optional: bool) -> Result<Vec<Message>> {
274 let v = match ctx.get(key) {
275 Some(v) => v,
276 None => {
277 return if optional {
278 Ok(Vec::new())
279 } else {
280 Err(CognisError::Configuration(format!(
281 "missing required placeholder field `{key}`"
282 )))
283 };
284 }
285 };
286 serde_json::from_value::<Vec<Message>>(v.clone()).map_err(|e| {
287 CognisError::Serialization(format!(
288 "placeholder `{key}` did not deserialize as Vec<Message>: {e}"
289 ))
290 })
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh, empty `Value`-typed template shared by the tests below.
    fn prompt() -> ChatPromptTemplate<Value> {
        ChatPromptTemplate::new()
    }

    #[tokio::test]
    async fn renders_simple_chat() {
        let template = prompt().system("you are {role}").human("hi {name}");
        let input = json!({"role": "helpful", "name": "ada"});
        let messages = template
            .invoke(input, RunnableConfig::default())
            .await
            .expect("rendering should succeed");

        assert_eq!(messages.len(), 2);
        assert!(matches!(messages[0], Message::System(_)));
        assert_eq!(messages[0].content(), "you are helpful");
        assert!(matches!(messages[1], Message::Human(_)));
        assert_eq!(messages[1].content(), "hi ada");
    }

    #[test]
    fn placeholder_drops_in_messages() {
        let template = prompt().system("sys").placeholder("history").human("now");
        let input = json!({
            "history": [
                {"role": "human", "content": "before-1"},
                {"role": "ai", "content": "before-2"}
            ]
        });
        let messages = template.render(&input).expect("render");

        assert_eq!(messages.len(), 4);
        for (message, expected) in messages[1..].iter().zip(["before-1", "before-2", "now"]) {
            assert_eq!(message.content(), expected);
        }
    }

    #[test]
    fn missing_required_placeholder_errors() {
        let outcome = prompt().placeholder("history").render(&json!({}));
        assert!(matches!(outcome, Err(CognisError::Configuration(_))));
    }

    #[test]
    fn optional_placeholder_accepts_missing() {
        let messages = prompt()
            .system("hi")
            .optional_placeholder("history")
            .render(&json!({}))
            .expect("render");
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn input_variables_collects_unique() {
        let vars = prompt().system("{a} {b}").human("{a} {c}").input_variables();
        assert_eq!(vars, ["a", "b", "c"]);
    }

    #[test]
    fn from_messages_constructs_fluently() {
        let template = ChatPromptTemplate::<Value>::from_messages(vec![
            (Role::System, String::from("sys")),
            (Role::Human, String::from("hi {name}")),
        ]);
        let messages = template.render(&json!({"name": "ada"})).expect("render");

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[1].content(), "hi ada");
    }

    #[test]
    fn human_with_image_url_renders_with_part() {
        let template = prompt().system("describe images").human_with_image_url(
            "describe {topic}",
            "https://x/cat.jpg",
            "image/jpeg",
        );
        let messages = template.render(&json!({"topic": "this cat"})).expect("render");

        assert_eq!(messages.len(), 2);
        let last = &messages[1];
        assert_eq!(last.content(), "describe this cat");
        let parts = last.parts();
        assert_eq!(parts.len(), 1);
        assert!(matches!(parts[0], ContentPart::Image { .. }));
    }

    #[test]
    fn input_variables_includes_multimodal_template_vars() {
        let mut vars = prompt()
            .human("text {a}")
            .human_with_image_url("multimodal {b}", "https://x", "image/png")
            .input_variables();
        vars.sort_unstable();
        assert_eq!(vars, ["a", "b"]);
    }
}