Skip to main content

oxirs_chat/
prompt_builder.rs

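//! Prompt template builder with variable substitution and validation.
//!
//! Templates use `{{variable_name}}` as the substitution syntax; whitespace
//! around the name inside the braces is ignored.  Required variables must be
//! supplied at render time.  Any other placeholder (declared optional or not
//! declared at all) renders as an empty string when no value is supplied, and
//! a placeholder that is never closed with `}}` is left in the output verbatim.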
6
7use std::collections::HashMap;
8
9// ──────────────────────────────────────────────────────────────────────────────
10// Error type
11// ──────────────────────────────────────────────────────────────────────────────
12
/// Errors produced while looking up or rendering a prompt template.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptError {
    /// A required variable was not supplied; the payload is the variable name.
    MissingVariable(String),
    /// The requested template was not found in the builder; the payload is the
    /// name that was looked up.
    TemplateNotFound(String),
    /// An internal rendering error (e.g. malformed template syntax); the
    /// payload is a free-form message.
    RenderError(String),
}

impl std::fmt::Display for PromptError {
    /// Write the error as `<category>: <payload>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingVariable(var) => write!(f, "missing required variable: {var}"),
            Self::TemplateNotFound(name) => write!(f, "template not found: {name}"),
            Self::RenderError(msg) => write!(f, "render error: {msg}"),
        }
    }
}

impl std::error::Error for PromptError {}
35
36// ──────────────────────────────────────────────────────────────────────────────
37// PromptTemplate
38// ──────────────────────────────────────────────────────────────────────────────
39
40/// A reusable prompt template with named variable placeholders.
41///
42/// Placeholders use the `{{variable_name}}` syntax.  Required variables must
43/// be present in the variable map supplied to `render`; optional variables
44/// default to an empty string when absent.
45#[derive(Debug, Clone, PartialEq)]
46pub struct PromptTemplate {
47    name: String,
48    template: String,
49    required_vars: Vec<String>,
50    optional_vars: Vec<String>,
51}
52
53impl PromptTemplate {
54    /// Create a new template.
55    pub fn new(name: impl Into<String>, template: impl Into<String>) -> Self {
56        Self {
57            name: name.into(),
58            template: template.into(),
59            required_vars: Vec::new(),
60            optional_vars: Vec::new(),
61        }
62    }
63
64    /// Declare a variable as required (builder pattern).
65    pub fn required(mut self, var: impl Into<String>) -> Self {
66        self.required_vars.push(var.into());
67        self
68    }
69
70    /// Declare a variable as optional (builder pattern).
71    pub fn optional(mut self, var: impl Into<String>) -> Self {
72        self.optional_vars.push(var.into());
73        self
74    }
75
76    /// Return the template name.
77    pub fn name(&self) -> &str {
78        &self.name
79    }
80
81    /// Return the raw template string.
82    pub fn raw(&self) -> &str {
83        &self.template
84    }
85
86    /// Return all variable names (required and optional) mentioned in the
87    /// declared lists.
88    pub fn variables(&self) -> Vec<&str> {
89        let mut vars: Vec<&str> = self
90            .required_vars
91            .iter()
92            .chain(self.optional_vars.iter())
93            .map(String::as_str)
94            .collect();
95        vars.sort_unstable();
96        vars.dedup();
97        vars
98    }
99
100    /// Return a list of required variables that are missing from `vars`.
101    pub fn validate(&self, vars: &HashMap<String, String>) -> Vec<String> {
102        self.required_vars
103            .iter()
104            .filter(|v| !vars.contains_key(v.as_str()))
105            .cloned()
106            .collect()
107    }
108
109    /// Render the template by substituting `{{key}}` placeholders.
110    ///
111    /// Returns `Err(PromptError::MissingVariable)` if any required variable is
112    /// absent from `vars`.  Optional variables that are absent are replaced
113    /// with an empty string.
114    pub fn render(&self, vars: &HashMap<String, String>) -> Result<String, PromptError> {
115        // Check required variables.
116        let missing = self.validate(vars);
117        if let Some(first) = missing.into_iter().next() {
118            return Err(PromptError::MissingVariable(first));
119        }
120
121        let mut result = self.template.clone();
122        // Substitute all `{{key}}` occurrences.
123        // We do a single linear scan with a simple state machine.
124        result = Self::substitute(&result, vars);
125        Ok(result)
126    }
127
128    // ── Private helpers ───────────────────────────────────────────────────────
129
130    fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
131        let mut output = String::with_capacity(template.len());
132        let mut chars = template.chars().peekable();
133
134        while let Some(c) = chars.next() {
135            if c == '{' && chars.peek() == Some(&'{') {
136                chars.next(); // consume second '{'
137                              // Collect the key until "}}"
138                let mut key = String::new();
139                let mut closed = false;
140                while let Some(k) = chars.next() {
141                    if k == '}' && chars.peek() == Some(&'}') {
142                        chars.next(); // consume second '}'
143                        closed = true;
144                        break;
145                    }
146                    key.push(k);
147                }
148                if closed {
149                    let key = key.trim().to_owned();
150                    let value = vars.get(&key).map(String::as_str).unwrap_or("");
151                    output.push_str(value);
152                } else {
153                    // Unclosed placeholder — emit as-is.
154                    output.push_str("{{");
155                    output.push_str(&key);
156                }
157            } else {
158                output.push(c);
159            }
160        }
161
162        output
163    }
164}
165
166// ──────────────────────────────────────────────────────────────────────────────
167// PromptBuilder
168// ──────────────────────────────────────────────────────────────────────────────
169
170/// A registry of named prompt templates with optional global variables.
171///
172/// Global variables are merged with local variables at render time; local
173/// variables take precedence over globals with the same key.
174#[derive(Debug, Default, Clone)]
175pub struct PromptBuilder {
176    templates: HashMap<String, PromptTemplate>,
177    global_vars: HashMap<String, String>,
178}
179
180impl PromptBuilder {
181    /// Create an empty builder.
182    pub fn new() -> Self {
183        Self {
184            templates: HashMap::new(),
185            global_vars: HashMap::new(),
186        }
187    }
188
189    /// Register a template.  If a template with the same name already exists
190    /// it is replaced.
191    pub fn add_template(&mut self, template: PromptTemplate) {
192        self.templates.insert(template.name.clone(), template);
193    }
194
195    /// Set a global variable available to all templates.
196    pub fn set_global(&mut self, key: impl Into<String>, value: impl Into<String>) {
197        self.global_vars.insert(key.into(), value.into());
198    }
199
200    /// Render a template by name with the supplied local variables.
201    ///
202    /// Local variables shadow global variables with the same key.
203    pub fn build(
204        &self,
205        template_name: &str,
206        local_vars: HashMap<String, String>,
207    ) -> Result<String, PromptError> {
208        let template = self
209            .templates
210            .get(template_name)
211            .ok_or_else(|| PromptError::TemplateNotFound(template_name.to_owned()))?;
212
213        // Merge: global vars first, then local vars override.
214        let mut merged = self.global_vars.clone();
215        merged.extend(local_vars);
216
217        template.render(&merged)
218    }
219
220    /// Return the number of registered templates.
221    pub fn template_count(&self) -> usize {
222        self.templates.len()
223    }
224
225    /// Return the names of all registered templates, sorted.
226    pub fn list_templates(&self) -> Vec<&str> {
227        let mut names: Vec<&str> = self.templates.keys().map(String::as_str).collect();
228        names.sort_unstable();
229        names
230    }
231}
232
233// ──────────────────────────────────────────────────────────────────────────────
234// Tests
235// ──────────────────────────────────────────────────────────────────────────────
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned variable map from borrowed `(key, value)` pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    // ── PromptError ───────────────────────────────────────────────────────────

    #[test]
    fn test_prompt_error_display_missing() {
        let e = PromptError::MissingVariable("x".into());
        assert_eq!(e.to_string(), "missing required variable: x");
    }

    #[test]
    fn test_prompt_error_display_not_found() {
        let e = PromptError::TemplateNotFound("tmpl".into());
        assert_eq!(e.to_string(), "template not found: tmpl");
    }

    #[test]
    fn test_prompt_error_display_render() {
        let e = PromptError::RenderError("oops".into());
        assert_eq!(e.to_string(), "render error: oops");
    }

    #[test]
    fn test_prompt_error_equality() {
        assert_eq!(
            PromptError::MissingVariable("a".into()),
            PromptError::MissingVariable("a".into())
        );
        assert_ne!(
            PromptError::MissingVariable("a".into()),
            PromptError::MissingVariable("b".into())
        );
    }

    // ── PromptTemplate construction ───────────────────────────────────────────

    #[test]
    fn test_template_new() {
        let t = PromptTemplate::new("greet", "Hello, {{name}}!");
        assert_eq!(t.name(), "greet");
        assert_eq!(t.raw(), "Hello, {{name}}!");
    }

    #[test]
    fn test_template_required() {
        let t = PromptTemplate::new("t", "{{a}} {{b}}")
            .required("a")
            .required("b");
        assert_eq!(t.required_vars, vec!["a", "b"]);
    }

    #[test]
    fn test_template_optional() {
        let t = PromptTemplate::new("t", "{{a}}{{b}}").optional("b");
        assert_eq!(t.optional_vars, vec!["b"]);
    }

    // ── variables() ───────────────────────────────────────────────────────────

    #[test]
    fn test_variables_combined() {
        let t = PromptTemplate::new("t", "{{a}} {{b}} {{c}}")
            .required("a")
            .optional("b")
            .optional("c");
        // `variables()` sorts its result, so the exact order is known.
        assert_eq!(t.variables(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_variables_deduplicated() {
        let t = PromptTemplate::new("t", "{{a}}")
            .required("a")
            .optional("a");
        assert_eq!(t.variables(), vec!["a"]);
    }

    // ── validate() ────────────────────────────────────────────────────────────

    #[test]
    fn test_validate_no_missing() {
        let t = PromptTemplate::new("t", "{{a}}").required("a");
        let missing = t.validate(&vars(&[("a", "value")]));
        assert!(missing.is_empty());
    }

    #[test]
    fn test_validate_missing_required() {
        let t = PromptTemplate::new("t", "{{a}} {{b}}")
            .required("a")
            .required("b");
        let missing = t.validate(&vars(&[("a", "hello")]));
        assert_eq!(missing, vec!["b".to_string()]);
    }

    #[test]
    fn test_validate_optional_not_missing() {
        let t = PromptTemplate::new("t", "{{a}}").optional("a");
        // Optional vars not in `vars` should NOT appear in missing
        let missing = t.validate(&HashMap::new());
        assert!(missing.is_empty());
    }

    // ── render() ──────────────────────────────────────────────────────────────

    #[test]
    fn test_render_simple_substitution() {
        let t = PromptTemplate::new("greet", "Hello, {{name}}!").required("name");
        let result = t
            .render(&vars(&[("name", "World")]))
            .expect("should succeed");
        assert_eq!(result, "Hello, World!");
    }

    #[test]
    fn test_render_multiple_vars() {
        let t = PromptTemplate::new("t", "{{a}} and {{b}}")
            .required("a")
            .required("b");
        let result = t
            .render(&vars(&[("a", "foo"), ("b", "bar")]))
            .expect("should succeed");
        assert_eq!(result, "foo and bar");
    }

    #[test]
    fn test_render_repeated_var() {
        let t = PromptTemplate::new("t", "{{x}} {{x}} {{x}}").required("x");
        let result = t.render(&vars(&[("x", "go")])).expect("should succeed");
        assert_eq!(result, "go go go");
    }

    #[test]
    fn test_render_optional_missing_is_empty_string() {
        let t = PromptTemplate::new("t", "start {{opt}} end").optional("opt");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "start  end");
    }

    #[test]
    fn test_render_missing_required_returns_error() {
        let t = PromptTemplate::new("t", "{{req}}").required("req");
        let err = t.render(&HashMap::new()).unwrap_err();
        assert_eq!(err, PromptError::MissingVariable("req".into()));
    }

    #[test]
    fn test_render_reports_first_missing_required() {
        let t = PromptTemplate::new("t", "{{a}} {{b}}")
            .required("a")
            .required("b");
        // Both are missing; the one declared first is the one reported.
        let err = t.render(&HashMap::new()).unwrap_err();
        assert_eq!(err, PromptError::MissingVariable("a".into()));
    }

    #[test]
    fn test_render_no_placeholders() {
        let t = PromptTemplate::new("t", "Hello, World!");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "Hello, World!");
    }

    #[test]
    fn test_render_whitespace_in_placeholder() {
        let t = PromptTemplate::new("t", "{{ name }}").optional("name");
        let result = t
            .render(&vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Alice");
    }

    #[test]
    fn test_render_empty_template() {
        let t = PromptTemplate::new("t", "");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "");
    }

    #[test]
    fn test_render_undeclared_placeholder_is_empty_string() {
        // A placeholder that was never declared behaves like an absent optional.
        let t = PromptTemplate::new("t", "a{{x}}b");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "ab");
    }

    #[test]
    fn test_render_unclosed_placeholder_is_literal() {
        let t = PromptTemplate::new("t", "Hello {{name");
        let result = t
            .render(&vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Hello {{name");
    }

    #[test]
    fn test_render_stray_braces() {
        let t = PromptTemplate::new("t", "{ {x} } {{x}}}");
        let result = t.render(&vars(&[("x", "1")])).expect("should succeed");
        assert_eq!(result, "{ {x} } 1}");
    }

    #[test]
    fn test_render_is_single_pass() {
        // Placeholder syntax inside a substituted value is not expanded again.
        let t = PromptTemplate::new("t", "{{a}}");
        let result = t
            .render(&vars(&[("a", "{{b}}"), ("b", "x")]))
            .expect("should succeed");
        assert_eq!(result, "{{b}}");
    }

    // ── PromptBuilder ─────────────────────────────────────────────────────────

    #[test]
    fn test_builder_new_empty() {
        let b = PromptBuilder::new();
        assert_eq!(b.template_count(), 0);
        assert!(b.list_templates().is_empty());
    }

    #[test]
    fn test_builder_add_template() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t1", "hello"));
        assert_eq!(b.template_count(), 1);
    }

    #[test]
    fn test_builder_list_templates_sorted() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("c", "c"));
        b.add_template(PromptTemplate::new("a", "a"));
        b.add_template(PromptTemplate::new("b", "b"));
        assert_eq!(b.list_templates(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_builder_build_basic() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("hi", "Hi {{name}}!").required("name"));
        let result = b
            .build("hi", vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Hi Alice!");
    }

    #[test]
    fn test_builder_build_not_found() {
        let b = PromptBuilder::new();
        let err = b.build("missing", HashMap::new()).unwrap_err();
        assert_eq!(err, PromptError::TemplateNotFound("missing".into()));
    }

    #[test]
    fn test_builder_global_vars() {
        let mut b = PromptBuilder::new();
        b.set_global("lang", "Rust");
        b.add_template(PromptTemplate::new("prog", "I love {{lang}}!").optional("lang"));
        let result = b.build("prog", HashMap::new()).expect("should succeed");
        assert_eq!(result, "I love Rust!");
    }

    #[test]
    fn test_builder_local_overrides_global() {
        let mut b = PromptBuilder::new();
        b.set_global("lang", "Rust");
        b.add_template(PromptTemplate::new("prog", "Language: {{lang}}").optional("lang"));
        let result = b
            .build("prog", vars(&[("lang", "Python")]))
            .expect("should succeed");
        assert_eq!(result, "Language: Python");
    }

    #[test]
    fn test_builder_replace_template() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t", "version 1"));
        b.add_template(PromptTemplate::new("t", "version 2"));
        assert_eq!(b.template_count(), 1);
        let result = b.build("t", HashMap::new()).expect("should succeed");
        assert_eq!(result, "version 2");
    }

    #[test]
    fn test_builder_multiple_templates() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("a", "{{x}}").required("x"));
        b.add_template(PromptTemplate::new("b", "{{y}}").required("y"));

        assert_eq!(
            b.build("a", vars(&[("x", "1")])).expect("should succeed"),
            "1"
        );
        assert_eq!(
            b.build("b", vars(&[("y", "2")])).expect("should succeed"),
            "2"
        );
    }

    #[test]
    fn test_builder_global_plus_local_mix() {
        let mut b = PromptBuilder::new();
        b.set_global("system", "OxiRS");
        b.add_template(
            PromptTemplate::new("intro", "{{system}} welcomes {{user}}")
                .optional("system")
                .required("user"),
        );
        let result = b
            .build("intro", vars(&[("user", "Bob")]))
            .expect("should succeed");
        assert_eq!(result, "OxiRS welcomes Bob");
    }

    #[test]
    fn test_builder_missing_required_error() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t", "{{req}}").required("req"));
        let err = b.build("t", HashMap::new()).unwrap_err();
        assert_eq!(err, PromptError::MissingVariable("req".into()));
    }

    #[test]
    fn test_builder_build_multiline_template() {
        let tmpl = "Line 1: {{a}}\nLine 2: {{b}}\nLine 3: {{a}}";
        let mut b = PromptBuilder::new();
        b.add_template(
            PromptTemplate::new("multi", tmpl)
                .required("a")
                .required("b"),
        );
        let result = b
            .build("multi", vars(&[("a", "hello"), ("b", "world")]))
            .expect("should succeed");
        assert_eq!(result, "Line 1: hello\nLine 2: world\nLine 3: hello");
    }

    #[test]
    fn test_template_clone() {
        let t = PromptTemplate::new("t", "{{x}}").required("x");
        let t2 = t.clone();
        assert_eq!(t, t2);
    }

    #[test]
    fn test_builder_default() {
        let b = PromptBuilder::default();
        assert_eq!(b.template_count(), 0);
    }

    #[test]
    fn test_builder_set_multiple_globals() {
        let mut b = PromptBuilder::new();
        b.set_global("a", "1");
        b.set_global("b", "2");
        b.set_global("a", "3"); // override
        b.add_template(
            PromptTemplate::new("t", "{{a}} {{b}}")
                .optional("a")
                .optional("b"),
        );
        let result = b.build("t", HashMap::new()).expect("should succeed");
        assert_eq!(result, "3 2");
    }
}