cognis_core/prompts/
template.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7use serde_json::Value;
8
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
12#[derive(Debug, Clone)]
21pub struct PromptTemplate<I = Value> {
22 template: String,
23 _input: PhantomData<fn() -> I>,
24}
25
26impl<I> PromptTemplate<I>
27where
28 I: Serialize + Send + Sync + 'static,
29{
30 pub fn new(template: impl Into<String>) -> Self {
33 Self {
34 template: template.into(),
35 _input: PhantomData,
36 }
37 }
38
39 pub fn render(&self, input: &I) -> Result<String> {
41 let value =
42 serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
43 render(&self.template, &value)
44 }
45
46 pub fn template_str(&self) -> &str {
48 &self.template
49 }
50
51 pub fn input_variables(&self) -> Vec<String> {
53 scan_variables(&self.template)
54 }
55}
56
57#[async_trait]
58impl<I> Runnable<I, String> for PromptTemplate<I>
59where
60 I: Serialize + Send + Sync + 'static,
61{
62 async fn invoke(&self, input: I, _: RunnableConfig) -> Result<String> {
63 self.render(&input)
64 }
65
66 fn name(&self) -> &str {
67 "PromptTemplate"
68 }
69}
70
71pub(crate) fn render(template: &str, ctx: &Value) -> Result<String> {
77 let mut out = String::with_capacity(template.len());
78 let mut chars = template.chars().peekable();
79 while let Some(c) = chars.next() {
80 match c {
81 '{' if chars.peek() == Some(&'{') => {
82 chars.next();
83 out.push('{');
84 }
85 '}' if chars.peek() == Some(&'}') => {
86 chars.next();
87 out.push('}');
88 }
89 '{' => {
90 let mut name = String::new();
91 let mut closed = false;
92 for nc in chars.by_ref() {
93 if nc == '}' {
94 closed = true;
95 break;
96 }
97 name.push(nc);
98 }
99 if !closed {
100 return Err(CognisError::Configuration(format!(
101 "unclosed `{{` in template: {template}"
102 )));
103 }
104 let key = name.trim();
105 let resolved = lookup(ctx, key).ok_or_else(|| {
106 CognisError::Configuration(format!("missing template variable `{key}`"))
107 })?;
108 out.push_str(&value_to_string(&resolved));
109 }
110 other => out.push(other),
111 }
112 }
113 Ok(out)
114}
115
/// Lists the distinct placeholder names in `template`, in order of first appearance.
///
/// Names are trimmed of surrounding whitespace. Empty placeholders (`{}`) and escaped
/// braces (`{{`, `}}`) are skipped. A trailing `{` with no matching `}` is not reported:
/// `render` rejects such a template as unclosed, so it does not name a variable.
pub(crate) fn scan_variables(template: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut chars = template.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '{' if chars.peek() == Some(&'{') => {
                chars.next();
            }
            '}' if chars.peek() == Some(&'}') => {
                chars.next();
            }
            '{' => {
                let mut name = String::new();
                let mut closed = false;
                for nc in chars.by_ref() {
                    if nc == '}' {
                        closed = true;
                        break;
                    }
                    name.push(nc);
                }
                if !closed {
                    // Unterminated placeholder: the rest of the input was consumed above,
                    // and `render` would reject this template, so report nothing for it.
                    break;
                }
                let trimmed = name.trim();
                // Compare as `&str` so duplicates do not cost an allocation.
                if !trimmed.is_empty() && !out.iter().any(|seen| seen == trimmed) {
                    out.push(trimmed.to_string());
                }
            }
            _ => {}
        }
    }
    out
}
146
147fn lookup(ctx: &Value, key: &str) -> Option<Value> {
148 let mut cur = ctx.clone();
149 for segment in key.split('.') {
150 cur = match cur {
151 Value::Object(mut m) => m.remove(segment)?,
152 _ => return None,
153 };
154 }
155 Some(cur)
156}
157
158fn value_to_string(v: &Value) -> String {
159 match v {
160 Value::String(s) => s.clone(),
161 Value::Null => String::new(),
162 v => v.to_string(),
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Rendering through the `Runnable` entry point.
    #[tokio::test]
    async fn renders_simple() {
        let prompt = PromptTemplate::<Value>::new("hello {name}");
        let input = json!({"name": "world"});
        let rendered = prompt
            .invoke(input, RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(rendered, "hello world");
    }

    // A plain `Serialize` struct works as typed input.
    #[test]
    fn renders_typed_struct() {
        #[derive(Serialize)]
        struct Ctx {
            name: String,
        }

        let ctx = Ctx {
            name: String::from("rust"),
        };
        let prompt: PromptTemplate<Ctx> = PromptTemplate::new("hi {name}");
        assert_eq!(prompt.render(&ctx).unwrap(), "hi rust");
    }

    #[test]
    fn dotted_paths() {
        let input = json!({"user": {"name": "Ada", "age": 36}});
        let prompt: PromptTemplate<Value> = PromptTemplate::new("{user.name} aged {user.age}");
        assert_eq!(prompt.render(&input).unwrap(), "Ada aged 36");
    }

    #[test]
    fn literal_braces() {
        let prompt: PromptTemplate<Value> = PromptTemplate::new("{{not a var}} {x}");
        let rendered = prompt.render(&json!({"x": 1})).unwrap();
        assert_eq!(rendered, "{not a var} 1");
    }

    #[test]
    fn missing_variable_errors() {
        let prompt: PromptTemplate<Value> = PromptTemplate::new("hi {name}");
        let error = prompt.render(&json!({})).unwrap_err();
        assert!(matches!(error, CognisError::Configuration(_)));
    }

    #[test]
    fn unclosed_brace_errors() {
        let prompt: PromptTemplate<Value> = PromptTemplate::new("hi {name");
        let error = prompt.render(&json!({"name": "x"})).unwrap_err();
        assert!(matches!(error, CognisError::Configuration(_)));
    }

    #[test]
    fn input_variables_returns_unique_in_order() {
        let prompt: PromptTemplate<Value> = PromptTemplate::new("{a} {b} {a} {c}");
        let expected = ["a", "b", "c"];
        assert_eq!(prompt.input_variables(), expected);
    }
}