1use std::collections::HashMap;
8
/// Errors reported while looking up or rendering prompt templates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptError {
    /// A variable registered as required was not supplied; carries its name.
    MissingVariable(String),
    /// No template is registered under the requested name; carries that name.
    TemplateNotFound(String),
    /// A rendering failure described by a free-form message.
    RenderError(String),
}

impl std::fmt::Display for PromptError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingVariable(v) => {
                f.write_fmt(format_args!("missing required variable: {v}"))
            }
            Self::TemplateNotFound(n) => f.write_fmt(format_args!("template not found: {n}")),
            Self::RenderError(msg) => f.write_fmt(format_args!("render error: {msg}")),
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
impl std::error::Error for PromptError {}
35
36#[derive(Debug, Clone, PartialEq)]
46pub struct PromptTemplate {
47 name: String,
48 template: String,
49 required_vars: Vec<String>,
50 optional_vars: Vec<String>,
51}
52
53impl PromptTemplate {
54 pub fn new(name: impl Into<String>, template: impl Into<String>) -> Self {
56 Self {
57 name: name.into(),
58 template: template.into(),
59 required_vars: Vec::new(),
60 optional_vars: Vec::new(),
61 }
62 }
63
64 pub fn required(mut self, var: impl Into<String>) -> Self {
66 self.required_vars.push(var.into());
67 self
68 }
69
70 pub fn optional(mut self, var: impl Into<String>) -> Self {
72 self.optional_vars.push(var.into());
73 self
74 }
75
76 pub fn name(&self) -> &str {
78 &self.name
79 }
80
81 pub fn raw(&self) -> &str {
83 &self.template
84 }
85
86 pub fn variables(&self) -> Vec<&str> {
89 let mut vars: Vec<&str> = self
90 .required_vars
91 .iter()
92 .chain(self.optional_vars.iter())
93 .map(String::as_str)
94 .collect();
95 vars.sort_unstable();
96 vars.dedup();
97 vars
98 }
99
100 pub fn validate(&self, vars: &HashMap<String, String>) -> Vec<String> {
102 self.required_vars
103 .iter()
104 .filter(|v| !vars.contains_key(v.as_str()))
105 .cloned()
106 .collect()
107 }
108
109 pub fn render(&self, vars: &HashMap<String, String>) -> Result<String, PromptError> {
115 let missing = self.validate(vars);
117 if let Some(first) = missing.into_iter().next() {
118 return Err(PromptError::MissingVariable(first));
119 }
120
121 let mut result = self.template.clone();
122 result = Self::substitute(&result, vars);
125 Ok(result)
126 }
127
128 fn substitute(template: &str, vars: &HashMap<String, String>) -> String {
131 let mut output = String::with_capacity(template.len());
132 let mut chars = template.chars().peekable();
133
134 while let Some(c) = chars.next() {
135 if c == '{' && chars.peek() == Some(&'{') {
136 chars.next(); let mut key = String::new();
139 let mut closed = false;
140 while let Some(k) = chars.next() {
141 if k == '}' && chars.peek() == Some(&'}') {
142 chars.next(); closed = true;
144 break;
145 }
146 key.push(k);
147 }
148 if closed {
149 let key = key.trim().to_owned();
150 let value = vars.get(&key).map(String::as_str).unwrap_or("");
151 output.push_str(value);
152 } else {
153 output.push_str("{{");
155 output.push_str(&key);
156 }
157 } else {
158 output.push(c);
159 }
160 }
161
162 output
163 }
164}
165
166#[derive(Debug, Default, Clone)]
175pub struct PromptBuilder {
176 templates: HashMap<String, PromptTemplate>,
177 global_vars: HashMap<String, String>,
178}
179
180impl PromptBuilder {
181 pub fn new() -> Self {
183 Self {
184 templates: HashMap::new(),
185 global_vars: HashMap::new(),
186 }
187 }
188
189 pub fn add_template(&mut self, template: PromptTemplate) {
192 self.templates.insert(template.name.clone(), template);
193 }
194
195 pub fn set_global(&mut self, key: impl Into<String>, value: impl Into<String>) {
197 self.global_vars.insert(key.into(), value.into());
198 }
199
200 pub fn build(
204 &self,
205 template_name: &str,
206 local_vars: HashMap<String, String>,
207 ) -> Result<String, PromptError> {
208 let template = self
209 .templates
210 .get(template_name)
211 .ok_or_else(|| PromptError::TemplateNotFound(template_name.to_owned()))?;
212
213 let mut merged = self.global_vars.clone();
215 merged.extend(local_vars);
216
217 template.render(&merged)
218 }
219
220 pub fn template_count(&self) -> usize {
222 self.templates.len()
223 }
224
225 pub fn list_templates(&self) -> Vec<&str> {
227 let mut names: Vec<&str> = self.templates.keys().map(String::as_str).collect();
228 names.sort_unstable();
229 names
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned variable map from borrowed `(key, value)` pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    // ---- PromptError ------------------------------------------------------

    #[test]
    fn test_prompt_error_display_missing() {
        let e = PromptError::MissingVariable("x".into());
        assert!(e.to_string().contains("x"));
    }

    #[test]
    fn test_prompt_error_display_not_found() {
        let e = PromptError::TemplateNotFound("tmpl".into());
        assert!(e.to_string().contains("tmpl"));
    }

    #[test]
    fn test_prompt_error_display_render() {
        let e = PromptError::RenderError("oops".into());
        assert!(e.to_string().contains("oops"));
    }

    #[test]
    fn test_prompt_error_equality() {
        assert_eq!(
            PromptError::MissingVariable("a".into()),
            PromptError::MissingVariable("a".into())
        );
        assert_ne!(
            PromptError::MissingVariable("a".into()),
            PromptError::MissingVariable("b".into())
        );
    }

    // ---- PromptTemplate construction --------------------------------------

    #[test]
    fn test_template_new() {
        let t = PromptTemplate::new("greet", "Hello, {{name}}!");
        assert_eq!(t.name(), "greet");
        assert_eq!(t.raw(), "Hello, {{name}}!");
    }

    #[test]
    fn test_template_required() {
        let t = PromptTemplate::new("t", "{{a}} {{b}}")
            .required("a")
            .required("b");
        assert_eq!(t.required_vars, vec!["a", "b"]);
    }

    #[test]
    fn test_template_optional() {
        let t = PromptTemplate::new("t", "{{a}}{{b}}").optional("b");
        assert_eq!(t.optional_vars, vec!["b"]);
    }

    // ---- PromptTemplate::variables ----------------------------------------

    #[test]
    fn test_variables_combined() {
        let t = PromptTemplate::new("t", "{{a}} {{b}} {{c}}")
            .required("a")
            .optional("b")
            .optional("c");
        let v = t.variables();
        assert!(v.contains(&"a"));
        assert!(v.contains(&"b"));
        assert!(v.contains(&"c"));
    }

    #[test]
    fn test_variables_deduplicated() {
        let t = PromptTemplate::new("t", "{{a}}")
            .required("a")
            .optional("a");
        let v = t.variables();
        assert_eq!(v.iter().filter(|&&x| x == "a").count(), 1);
    }

    // ---- PromptTemplate::validate -----------------------------------------

    #[test]
    fn test_validate_no_missing() {
        let t = PromptTemplate::new("t", "{{a}}").required("a");
        let missing = t.validate(&vars(&[("a", "value")]));
        assert!(missing.is_empty());
    }

    #[test]
    fn test_validate_missing_required() {
        let t = PromptTemplate::new("t", "{{a}} {{b}}")
            .required("a")
            .required("b");
        let missing = t.validate(&vars(&[("a", "hello")]));
        assert!(missing.contains(&"b".to_string()));
    }

    #[test]
    fn test_validate_optional_not_missing() {
        let t = PromptTemplate::new("t", "{{a}}").optional("a");
        let missing = t.validate(&HashMap::new());
        assert!(missing.is_empty());
    }

    // ---- PromptTemplate::render -------------------------------------------

    #[test]
    fn test_render_simple_substitution() {
        let t = PromptTemplate::new("greet", "Hello, {{name}}!").required("name");
        let result = t
            .render(&vars(&[("name", "World")]))
            .expect("should succeed");
        assert_eq!(result, "Hello, World!");
    }

    #[test]
    fn test_render_multiple_vars() {
        let t = PromptTemplate::new("t", "{{a}} and {{b}}")
            .required("a")
            .required("b");
        let result = t
            .render(&vars(&[("a", "foo"), ("b", "bar")]))
            .expect("should succeed");
        assert_eq!(result, "foo and bar");
    }

    #[test]
    fn test_render_repeated_var() {
        let t = PromptTemplate::new("t", "{{x}} {{x}} {{x}}").required("x");
        let result = t.render(&vars(&[("x", "go")])).expect("should succeed");
        assert_eq!(result, "go go go");
    }

    #[test]
    fn test_render_optional_missing_is_empty_string() {
        let t = PromptTemplate::new("t", "start {{opt}} end").optional("opt");
        let result = t.render(&HashMap::new()).expect("should succeed");
        // Only the placeholder disappears; the space on each side of it stays,
        // so the expected text has two consecutive spaces.
        assert_eq!(result, format!("start {} end", ""));
    }

    #[test]
    fn test_render_missing_required_returns_error() {
        let t = PromptTemplate::new("t", "{{req}}").required("req");
        let err = t.render(&HashMap::new()).unwrap_err();
        assert!(matches!(err, PromptError::MissingVariable(_)));
    }

    #[test]
    fn test_render_no_placeholders() {
        let t = PromptTemplate::new("t", "Hello, World!");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "Hello, World!");
    }

    #[test]
    fn test_render_whitespace_in_placeholder() {
        let t = PromptTemplate::new("t", "{{ name }}").optional("name");
        let result = t
            .render(&vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Alice");
    }

    #[test]
    fn test_render_empty_template() {
        let t = PromptTemplate::new("t", "");
        let result = t.render(&HashMap::new()).expect("should succeed");
        assert_eq!(result, "");
    }

    #[test]
    fn test_render_unclosed_placeholder_kept_verbatim() {
        let t = PromptTemplate::new("t", "Hi {{name");
        let result = t
            .render(&vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Hi {{name");
    }

    // ---- PromptBuilder ----------------------------------------------------

    #[test]
    fn test_builder_new_empty() {
        let b = PromptBuilder::new();
        assert_eq!(b.template_count(), 0);
        assert!(b.list_templates().is_empty());
    }

    #[test]
    fn test_builder_add_template() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t1", "hello"));
        assert_eq!(b.template_count(), 1);
    }

    #[test]
    fn test_builder_list_templates_sorted() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("c", "c"));
        b.add_template(PromptTemplate::new("a", "a"));
        b.add_template(PromptTemplate::new("b", "b"));
        assert_eq!(b.list_templates(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_builder_build_basic() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("hi", "Hi {{name}}!").required("name"));
        let result = b
            .build("hi", vars(&[("name", "Alice")]))
            .expect("should succeed");
        assert_eq!(result, "Hi Alice!");
    }

    #[test]
    fn test_builder_build_not_found() {
        let b = PromptBuilder::new();
        let err = b.build("missing", HashMap::new()).unwrap_err();
        assert!(matches!(err, PromptError::TemplateNotFound(_)));
    }

    #[test]
    fn test_builder_global_vars() {
        let mut b = PromptBuilder::new();
        b.set_global("lang", "Rust");
        b.add_template(PromptTemplate::new("prog", "I love {{lang}}!").optional("lang"));
        let result = b.build("prog", HashMap::new()).expect("should succeed");
        assert_eq!(result, "I love Rust!");
    }

    #[test]
    fn test_builder_local_overrides_global() {
        let mut b = PromptBuilder::new();
        b.set_global("lang", "Rust");
        b.add_template(PromptTemplate::new("prog", "Language: {{lang}}").optional("lang"));
        let result = b
            .build("prog", vars(&[("lang", "Python")]))
            .expect("should succeed");
        assert_eq!(result, "Language: Python");
    }

    #[test]
    fn test_builder_replace_template() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t", "version 1"));
        b.add_template(PromptTemplate::new("t", "version 2"));
        assert_eq!(b.template_count(), 1);
        let result = b.build("t", HashMap::new()).expect("should succeed");
        assert_eq!(result, "version 2");
    }

    #[test]
    fn test_builder_multiple_templates() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("a", "{{x}}").required("x"));
        b.add_template(PromptTemplate::new("b", "{{y}}").required("y"));

        assert_eq!(
            b.build("a", vars(&[("x", "1")])).expect("should succeed"),
            "1"
        );
        assert_eq!(
            b.build("b", vars(&[("y", "2")])).expect("should succeed"),
            "2"
        );
    }

    #[test]
    fn test_builder_global_plus_local_mix() {
        let mut b = PromptBuilder::new();
        b.set_global("system", "OxiRS");
        b.add_template(
            PromptTemplate::new("intro", "{{system}} welcomes {{user}}")
                .optional("system")
                .required("user"),
        );
        let result = b
            .build("intro", vars(&[("user", "Bob")]))
            .expect("should succeed");
        assert_eq!(result, "OxiRS welcomes Bob");
    }

    #[test]
    fn test_builder_missing_required_error() {
        let mut b = PromptBuilder::new();
        b.add_template(PromptTemplate::new("t", "{{req}}").required("req"));
        let err = b.build("t", HashMap::new()).unwrap_err();
        assert!(matches!(err, PromptError::MissingVariable(_)));
    }

    #[test]
    fn test_builder_build_multiline_template() {
        let tmpl = "Line 1: {{a}}\nLine 2: {{b}}\nLine 3: {{a}}";
        let mut b = PromptBuilder::new();
        b.add_template(
            PromptTemplate::new("multi", tmpl)
                .required("a")
                .required("b"),
        );
        let result = b
            .build("multi", vars(&[("a", "hello"), ("b", "world")]))
            .expect("should succeed");
        assert_eq!(result, "Line 1: hello\nLine 2: world\nLine 3: hello");
    }

    #[test]
    fn test_template_clone() {
        let t = PromptTemplate::new("t", "{{x}}").required("x");
        let t2 = t.clone();
        assert_eq!(t, t2);
    }

    #[test]
    fn test_builder_default() {
        let b = PromptBuilder::default();
        assert_eq!(b.template_count(), 0);
    }

    #[test]
    fn test_builder_set_multiple_globals() {
        let mut b = PromptBuilder::new();
        b.set_global("a", "1");
        b.set_global("b", "2");
        // Setting the same key again overwrites the earlier value.
        b.set_global("a", "3");
        b.add_template(
            PromptTemplate::new("t", "{{a}} {{b}}")
                .optional("a")
                .optional("b"),
        );
        let result = b.build("t", HashMap::new()).expect("should succeed");
        assert_eq!(result, "3 2");
    }
}