Skip to main content

cognis_core/prompts/
template.rs

1//! String template + the rendering engine shared by all prompt types.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::Serialize;
7use serde_json::Value;
8
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
12/// A typed string prompt template.
13///
14/// Inputs are any `Serialize` type — fields are read via serde reflection,
15/// so plain structs, `HashMap<String, Value>`, and `serde_json::Value` all
16/// work transparently.
17///
18/// Placeholders are `{name}`. Literal braces are `{{` and `}}`. Dotted
19/// paths (`{user.name}`) descend into nested objects.
20#[derive(Debug, Clone)]
21pub struct PromptTemplate<I = Value> {
22    template: String,
23    _input: PhantomData<fn() -> I>,
24}
25
26impl<I> PromptTemplate<I>
27where
28    I: Serialize + Send + Sync + 'static,
29{
30    /// Build a template from a string. Doesn't validate placeholder
31    /// names — invalid placeholders surface at render time.
32    pub fn new(template: impl Into<String>) -> Self {
33        Self {
34            template: template.into(),
35            _input: PhantomData,
36        }
37    }
38
39    /// Render the template against the given input.
40    pub fn render(&self, input: &I) -> Result<String> {
41        let value =
42            serde_json::to_value(input).map_err(|e| CognisError::Serialization(e.to_string()))?;
43        render(&self.template, &value)
44    }
45
46    /// The raw template string.
47    pub fn template_str(&self) -> &str {
48        &self.template
49    }
50
51    /// All `{name}` placeholders in the template, in first-occurrence order.
52    pub fn input_variables(&self) -> Vec<String> {
53        scan_variables(&self.template)
54    }
55}
56
57#[async_trait]
58impl<I> Runnable<I, String> for PromptTemplate<I>
59where
60    I: Serialize + Send + Sync + 'static,
61{
62    async fn invoke(&self, input: I, _: RunnableConfig) -> Result<String> {
63        self.render(&input)
64    }
65
66    fn name(&self) -> &str {
67        "PromptTemplate"
68    }
69}
70
71/// Render `{var}` placeholders against a `serde_json::Value` context.
72///
73/// `{{` and `}}` are literal `{` and `}`. Dotted keys descend into nested
74/// objects. Returns `CognisError::Configuration` for missing variables or
75/// unclosed braces.
76pub(crate) fn render(template: &str, ctx: &Value) -> Result<String> {
77    let mut out = String::with_capacity(template.len());
78    let mut chars = template.chars().peekable();
79    while let Some(c) = chars.next() {
80        match c {
81            '{' if chars.peek() == Some(&'{') => {
82                chars.next();
83                out.push('{');
84            }
85            '}' if chars.peek() == Some(&'}') => {
86                chars.next();
87                out.push('}');
88            }
89            '{' => {
90                let mut name = String::new();
91                let mut closed = false;
92                for nc in chars.by_ref() {
93                    if nc == '}' {
94                        closed = true;
95                        break;
96                    }
97                    name.push(nc);
98                }
99                if !closed {
100                    return Err(CognisError::Configuration(format!(
101                        "unclosed `{{` in template: {template}"
102                    )));
103                }
104                let key = name.trim();
105                let resolved = lookup(ctx, key).ok_or_else(|| {
106                    CognisError::Configuration(format!("missing template variable `{key}`"))
107                })?;
108                out.push_str(&value_to_string(&resolved));
109            }
110            other => out.push(other),
111        }
112    }
113    Ok(out)
114}
115
/// Collect the distinct placeholder names in `template`, in the order they
/// first appear.
///
/// Doubled braces (`{{`, `}}`) are skipped as escapes. Names are trimmed,
/// and empty names are ignored. A `{` with no closing `}` still contributes
/// the remaining text as a name.
pub(crate) fn scan_variables(template: &str) -> Vec<String> {
    let mut names: Vec<String> = Vec::new();
    let mut stream = template.chars().peekable();

    while let Some(current) = stream.next() {
        // True when the next character repeats the current one.
        let doubled = stream.peek() == Some(&current);
        match current {
            '{' | '}' if doubled => {
                // Escaped brace: consume its twin and move on.
                stream.next();
            }
            '{' => {
                // Read up to (and consume) the closing brace, or to the end.
                let mut raw = String::new();
                loop {
                    match stream.next() {
                        Some('}') | None => break,
                        Some(ch) => raw.push(ch),
                    }
                }
                let name = raw.trim();
                let already_seen = names.iter().any(|known| known == name);
                if !name.is_empty() && !already_seen {
                    names.push(name.to_string());
                }
            }
            _ => {}
        }
    }

    names
}
146
147fn lookup(ctx: &Value, key: &str) -> Option<Value> {
148    let mut cur = ctx.clone();
149    for segment in key.split('.') {
150        cur = match cur {
151            Value::Object(mut m) => m.remove(segment)?,
152            _ => return None,
153        };
154    }
155    Some(cur)
156}
157
158fn value_to_string(v: &Value) -> String {
159    match v {
160        Value::String(s) => s.clone(),
161        Value::Null => String::new(),
162        v => v.to_string(),
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a template over untyped JSON input.
    fn json_template(source: &str) -> PromptTemplate<Value> {
        PromptTemplate::new(source)
    }

    // Rendering through the `Runnable` entry point.
    #[tokio::test]
    async fn renders_simple() {
        let template = json_template("hello {name}");
        let context = json!({"name": "world"});
        let rendered = template
            .invoke(context, RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(rendered, "hello world");
    }

    // A plain `Serialize` struct works as input.
    #[test]
    fn renders_typed_struct() {
        #[derive(Serialize)]
        struct Ctx {
            name: String,
        }
        let template: PromptTemplate<Ctx> = PromptTemplate::new("hi {name}");
        let input = Ctx {
            name: "rust".into(),
        };
        let rendered = template.render(&input).unwrap();
        assert_eq!(rendered, "hi rust");
    }

    // Dotted placeholders descend into nested objects.
    #[test]
    fn dotted_paths() {
        let template = json_template("{user.name} aged {user.age}");
        let context = json!({"user": {"name": "Ada", "age": 36}});
        let rendered = template.render(&context).unwrap();
        assert_eq!(rendered, "Ada aged 36");
    }

    // Doubled braces come out as single literal braces.
    #[test]
    fn literal_braces() {
        let template = json_template("{{not a var}} {x}");
        let rendered = template.render(&json!({"x": 1})).unwrap();
        assert_eq!(rendered, "{not a var} 1");
    }

    // A placeholder with no matching key is a configuration error.
    #[test]
    fn missing_variable_errors() {
        let template = json_template("hi {name}");
        let failure = template.render(&json!({})).unwrap_err();
        assert!(matches!(failure, CognisError::Configuration(_)));
    }

    // A `{` that is never closed is a configuration error.
    #[test]
    fn unclosed_brace_errors() {
        let template = json_template("hi {name");
        let failure = template.render(&json!({"name": "x"})).unwrap_err();
        assert!(matches!(failure, CognisError::Configuration(_)));
    }

    // Repeated placeholders are reported once, in first-seen order.
    #[test]
    fn input_variables_returns_unique_in_order() {
        let template = json_template("{a} {b} {a} {c}");
        assert_eq!(template.input_variables(), vec!["a", "b", "c"]);
    }
}
231}