1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use thiserror::Error;
8
/// Errors reported by the prompt-template types in this file.
///
/// `Display` text comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, Error)]
pub enum PromptError {
    /// A template could not be found; the payload is interpolated into the message.
    #[error("Template not found: {0}")]
    TemplateNotFound(String),
    /// One or more variables needed for rendering were not supplied.
    ///
    /// `PromptTemplate::render_with_map` fills this with either a single variable
    /// name or several names joined with `", "`.
    #[error("Required variable not provided: {0}")]
    MissingVariable(String),
    /// A value did not satisfy the declared type of the variable `name`.
    #[error("Variable type mismatch for '{name}': expected {expected}, got {actual}")]
    TypeMismatch {
        /// Name of the variable whose value was rejected.
        name: String,
        /// Description of the expected type.
        expected: String,
        /// Description of what was received instead.
        actual: String,
    },
    /// A value passed its type check but failed a further constraint
    /// (`PromptVariable::validate` uses this for pattern and enum failures).
    #[error("Validation failed for variable '{name}': {reason}")]
    ValidationFailed { name: String, reason: String },
    /// Something could not be parsed; `PromptVariable::validate` uses this when a
    /// variable's regex pattern fails to compile.
    #[error("Parse error: {0}")]
    ParseError(String),
    /// Wrapper around an I/O failure; `#[from]` lets `?` convert `std::io::Error`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// A YAML-related failure described by a message string.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("YAML error: {0}")]
    YamlError(String),
}
38
/// Shorthand for a `Result` whose error type is [`PromptError`].
pub type PromptResult<T> = Result<T, PromptError>;
41
/// The kind of value a prompt variable is declared to hold.
///
/// Values are always supplied as text; the type only selects which textual check
/// is applied. With `rename_all = "lowercase"` the variants (de)serialize as
/// `string`, `integer`, `float`, `boolean`, `list` and `json`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum VariableType {
    /// Arbitrary text; this is the default type.
    #[default]
    String,
    /// Text that parses as an integer.
    Integer,
    /// Text that parses as a floating-point number.
    Float,
    /// Text spelling a boolean.
    Boolean,
    /// Text written in bracketed list form.
    List,
    /// Text containing a JSON document.
    Json,
}
60
61impl VariableType {
62 pub fn validate(&self, value: &str) -> bool {
64 match self {
65 VariableType::String => true,
66 VariableType::Integer => value.parse::<i64>().is_ok(),
67 VariableType::Float => value.parse::<f64>().is_ok(),
68 VariableType::Boolean => {
69 matches!(value.to_lowercase().as_str(), "true" | "false" | "1" | "0")
70 }
71 VariableType::List => value.starts_with('[') && value.ends_with(']'),
72 VariableType::Json => serde_json::from_str::<serde_json::Value>(value).is_ok(),
73 }
74 }
75}
76
/// Declaration of a single variable that a template can reference as `{name}`.
///
/// Every field except `name` may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptVariable {
    /// Variable name; the placeholder looked for in template content is `{name}`.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Declared type; defaults to [`VariableType::String`].
    #[serde(default)]
    pub var_type: VariableType,
    /// Whether rendering fails when no value and no default is available.
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub required: bool,
    /// Text substituted when the caller supplies no value.
    #[serde(default)]
    pub default: Option<String>,
    /// Optional regular expression a supplied value must match.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Optional list of the only values that are accepted.
    #[serde(default)]
    pub enum_values: Option<Vec<String>>,
}
101
/// Serde default for `PromptVariable::required`: variables are required unless
/// the serialized form says otherwise.
fn default_true() -> bool {
    true
}
105
106impl PromptVariable {
107 pub fn new(name: impl Into<String>) -> Self {
109 Self {
110 name: name.into(),
111 description: None,
112 var_type: VariableType::String,
113 required: true,
114 default: None,
115 pattern: None,
116 enum_values: None,
117 }
118 }
119
120 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
122 self.description = Some(desc.into());
123 self
124 }
125
126 pub fn with_type(mut self, var_type: VariableType) -> Self {
128 self.var_type = var_type;
129 self
130 }
131
132 pub fn required(mut self, required: bool) -> Self {
134 self.required = required;
135 self
136 }
137
138 pub fn with_default(mut self, default: impl Into<String>) -> Self {
140 self.default = Some(default.into());
141 self.required = false;
142 self
143 }
144
145 pub fn with_pattern(mut self, pattern: impl Into<String>) -> Self {
147 self.pattern = Some(pattern.into());
148 self
149 }
150
151 pub fn with_enum(mut self, values: Vec<String>) -> Self {
153 self.enum_values = Some(values);
154 self
155 }
156
157 pub fn validate(&self, value: &str) -> PromptResult<()> {
159 if !self.var_type.validate(value) {
161 return Err(PromptError::TypeMismatch {
162 name: self.name.clone(),
163 expected: format!("{:?}", self.var_type),
164 actual: "invalid".to_string(),
165 });
166 }
167
168 if let Some(ref pattern) = self.pattern {
170 let re =
171 regex::Regex::new(pattern).map_err(|e| PromptError::ParseError(e.to_string()))?;
172 if !re.is_match(value) {
173 return Err(PromptError::ValidationFailed {
174 name: self.name.clone(),
175 reason: format!("Value does not match pattern: {}", pattern),
176 });
177 }
178 }
179
180 if let Some(ref enum_values) = self.enum_values
182 && !enum_values.contains(&value.to_string())
183 {
184 return Err(PromptError::ValidationFailed {
185 name: self.name.clone(),
186 reason: format!("Value must be one of: {:?}", enum_values),
187 });
188 }
189
190 Ok(())
191 }
192}
193
/// A named piece of prompt text containing `{variable}` placeholders, together
/// with optional declarations for those variables.
///
/// Only `id` is mandatory when deserializing; every other field has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Identifier of the template.
    pub id: String,
    /// Optional display name.
    #[serde(default)]
    pub name: Option<String>,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Template text; placeholders are written as `{name}`.
    #[serde(default)]
    pub content: String,
    /// Declared variables (type, default, constraints). Placeholders without a
    /// declaration can still be filled at render time.
    #[serde(default)]
    pub variables: Vec<PromptVariable>,
    /// Free-form tags.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Optional version string.
    #[serde(default)]
    pub version: Option<String>,
    /// Arbitrary string key/value metadata.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
221
222impl PromptTemplate {
223 pub fn new(id: impl Into<String>) -> Self {
225 Self {
226 id: id.into(),
227 name: None,
228 description: None,
229 content: String::new(),
230 variables: Vec::new(),
231 tags: Vec::new(),
232 version: None,
233 metadata: HashMap::new(),
234 }
235 }
236
237 pub fn with_name(mut self, name: impl Into<String>) -> Self {
239 self.name = Some(name.into());
240 self
241 }
242
243 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
245 self.description = Some(desc.into());
246 self
247 }
248
249 pub fn with_content(mut self, content: impl Into<String>) -> Self {
251 self.content = content.into();
252 self.parse_variables();
254 self
255 }
256
257 pub fn with_variable(mut self, variable: PromptVariable) -> Self {
259 self.variables.push(variable);
260 self
261 }
262
263 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
265 self.tags.push(tag.into());
266 self
267 }
268
269 pub fn with_version(mut self, version: impl Into<String>) -> Self {
271 self.version = Some(version.into());
272 self
273 }
274
275 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
277 self.metadata.insert(key.into(), value.into());
278 self
279 }
280
281 fn parse_variables(&mut self) {
283 }
286
287 pub fn variable_names(&self) -> Vec<&str> {
289 self.variables.iter().map(|v| v.name.as_str()).collect()
290 }
291
292 pub fn extract_variables(&self) -> Vec<String> {
294 let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
295 let mut vars = std::collections::HashSet::new();
296
297 for cap in re.captures_iter(&self.content) {
298 vars.insert(cap[1].to_string());
299 }
300
301 vars.into_iter().collect()
302 }
303
304 pub fn required_variables(&self) -> Vec<&PromptVariable> {
306 self.variables.iter().filter(|v| v.required).collect()
307 }
308
309 pub fn render(&self, vars: &[(&str, &str)]) -> PromptResult<String> {
326 let var_map: HashMap<&str, &str> = vars.iter().copied().collect();
327 self.render_with_map(&var_map)
328 }
329
330 pub fn render_with_map(&self, vars: &HashMap<&str, &str>) -> PromptResult<String> {
332 let mut result = self.content.clone();
333
334 for var_def in &self.variables {
336 let placeholder = format!("{{{}}}", var_def.name);
337
338 if let Some(&value) = vars.get(var_def.name.as_str()) {
339 var_def.validate(value)?;
341 result = result.replace(&placeholder, value);
342 } else if let Some(ref default) = var_def.default {
343 result = result.replace(&placeholder, default);
345 } else if var_def.required {
346 return Err(PromptError::MissingVariable(var_def.name.clone()));
348 }
349 }
350
351 let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
353 let defined_vars: std::collections::HashSet<_> =
354 self.variables.iter().map(|v| v.name.as_str()).collect();
355
356 let mut missing = Vec::new();
358 for cap in re.captures_iter(&result.clone()) {
359 let var_name = &cap[1];
360 if !defined_vars.contains(var_name) {
361 if let Some(&value) = vars.get(var_name) {
362 let placeholder = format!("{{{}}}", var_name);
363 result = result.replace(&placeholder, value);
364 } else {
365 missing.push(var_name.to_string());
366 }
367 }
368 }
369
370 if !missing.is_empty() {
372 return Err(PromptError::MissingVariable(missing.join(", ")));
373 }
374
375 Ok(result)
376 }
377
378 pub fn render_with_owned_map(&self, vars: &HashMap<String, String>) -> PromptResult<String> {
380 let borrowed: HashMap<&str, &str> =
381 vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
382 self.render_with_map(&borrowed)
383 }
384
385 pub fn partial_render(&self, vars: &[(&str, &str)]) -> String {
387 let var_map: HashMap<&str, &str> = vars.iter().copied().collect();
388 let mut result = self.content.clone();
389
390 for (name, value) in var_map {
391 let placeholder = format!("{{{}}}", name);
392 result = result.replace(&placeholder, value);
393 }
394
395 result
396 }
397
398 pub fn is_valid_with(&self, vars: &[&str]) -> bool {
400 let var_set: std::collections::HashSet<_> = vars.iter().copied().collect();
401
402 for var_def in &self.variables {
404 if var_def.required
405 && var_def.default.is_none()
406 && !var_set.contains(var_def.name.as_str())
407 {
408 return false;
409 }
410 }
411
412 let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
414 let defined_vars: std::collections::HashSet<_> =
415 self.variables.iter().map(|v| v.name.as_str()).collect();
416
417 for cap in re.captures_iter(&self.content) {
418 let var_name = &cap[1];
419 if !defined_vars.contains(var_name) && !var_set.contains(var_name) {
421 return false;
422 }
423 }
424
425 true
426 }
427}
428
/// An ordered list of template identifiers plus the separator text associated
/// with them.
// NOTE(review): the code that resolves `template_ids` and joins the rendered
// templates is not in this part of the file — confirm how `separator` is applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptComposition {
    /// Identifier of the composition.
    pub id: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Identifiers of the member templates, in order. Has no serde default, so it
    /// must be present when deserializing.
    pub template_ids: Vec<String>,
    /// Separator text; defaults to a blank line (`"\n\n"`) when omitted.
    #[serde(default = "default_separator")]
    pub separator: String,
}
443
/// Serde default for `PromptComposition::separator`: two newlines.
fn default_separator() -> String {
    String::from("\n\n")
}
447
448impl PromptComposition {
449 pub fn new(id: impl Into<String>) -> Self {
451 Self {
452 id: id.into(),
453 description: None,
454 template_ids: Vec::new(),
455 separator: "\n\n".to_string(),
456 }
457 }
458
459 pub fn add_template(mut self, template_id: impl Into<String>) -> Self {
461 self.template_ids.push(template_id.into());
462 self
463 }
464
465 pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
467 self.separator = sep.into();
468 self
469 }
470
471 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
473 self.description = Some(desc.into());
474 self
475 }
476}
477
// Unit tests for template rendering and variable validation.
#[cfg(test)]
mod tests {
    use super::*;

    // A placeholder with no declaration is still discovered in the content and
    // filled from the supplied pairs.
    #[test]
    fn test_template_basic() {
        let template = PromptTemplate::new("test")
            .with_content("Hello, {name}!")
            .with_description("A greeting template");

        assert_eq!(template.id, "test");
        assert_eq!(template.extract_variables(), vec!["name"]);

        let result = template.render(&[("name", "World")]).unwrap();
        assert_eq!(result, "Hello, World!");
    }

    // Several undeclared placeholders are all substituted in one render call.
    #[test]
    fn test_template_multiple_vars() {
        let template = PromptTemplate::new("test")
            .with_content("Hello, {name}! Welcome to {place}. Your role is {role}.");

        let result = template
            .render(&[
                ("name", "Alice"),
                ("place", "Wonderland"),
                ("role", "explorer"),
            ])
            .unwrap();

        assert_eq!(
            result,
            "Hello, Alice! Welcome to Wonderland. Your role is explorer."
        );
    }

    // A declared default is used when no value is supplied and is overridden
    // when one is.
    #[test]
    fn test_template_with_default() {
        let template = PromptTemplate::new("test")
            .with_content("Hello, {name}!")
            .with_variable(PromptVariable::new("name").with_default("World"));

        // No value supplied: the default fills the placeholder.
        let result = template.render(&[]).unwrap();
        assert_eq!(result, "Hello, World!");

        // Supplied value takes precedence over the default.
        let result = template.render(&[("name", "Alice")]).unwrap();
        assert_eq!(result, "Hello, Alice!");
    }

    // An undeclared placeholder without a supplied value is a MissingVariable error.
    #[test]
    fn test_template_missing_required() {
        let template = PromptTemplate::new("test").with_content("Hello, {name}!");

        let result = template.render(&[]);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PromptError::MissingVariable(_)
        ));
    }

    // Spot checks of the per-type textual validation.
    #[test]
    fn test_variable_type_validation() {
        assert!(VariableType::String.validate("anything"));
        assert!(VariableType::Integer.validate("123"));
        assert!(!VariableType::Integer.validate("abc"));
        assert!(VariableType::Float.validate("3.14"));
        assert!(VariableType::Boolean.validate("true"));
        assert!(VariableType::Boolean.validate("false"));
        assert!(VariableType::Json.validate(r#"{"key": "value"}"#));
    }

    // Values outside the enum list are rejected.
    #[test]
    fn test_variable_enum() {
        let var = PromptVariable::new("language")
            .with_enum(vec!["rust".to_string(), "python".to_string()]);

        assert!(var.validate("rust").is_ok());
        assert!(var.validate("python").is_ok());
        assert!(var.validate("java").is_err());
    }

    // partial_render leaves placeholders without a supplied value in the text.
    #[test]
    fn test_partial_render() {
        let template =
            PromptTemplate::new("test").with_content("Hello, {name}! Your {item} is ready.");

        let result = template.partial_render(&[("name", "Alice")]);
        assert_eq!(result, "Hello, Alice! Your {item} is ready.");
    }

    // Only the required variable without a default has to be listed;
    // `with_default` makes `optional_var` optional.
    #[test]
    fn test_is_valid_with() {
        let template = PromptTemplate::new("test")
            .with_content("{required_var} and {optional_var}")
            .with_variable(PromptVariable::new("required_var"))
            .with_variable(PromptVariable::new("optional_var").with_default("default"));

        assert!(template.is_valid_with(&["required_var"]));
        assert!(!template.is_valid_with(&[]));
        assert!(!template.is_valid_with(&["optional_var"]));
    }
}