// mofa_foundation/prompt/template.rs

//! Prompt template engine.
//!
//! Provides template variable substitution and validation.

use std::collections::{HashMap, HashSet};
use std::sync::OnceLock;

use serde::{Deserialize, Serialize};
use thiserror::Error;

9/// Prompt 模板错误
10#[derive(Debug, Error)]
11pub enum PromptError {
12    /// 模板未找到
13    #[error("Template not found: {0}")]
14    TemplateNotFound(String),
15    /// 变量未提供
16    #[error("Required variable not provided: {0}")]
17    MissingVariable(String),
18    /// 变量类型错误
19    #[error("Variable type mismatch for '{name}': expected {expected}, got {actual}")]
20    TypeMismatch {
21        name: String,
22        expected: String,
23        actual: String,
24    },
25    /// 验证失败
26    #[error("Validation failed for variable '{name}': {reason}")]
27    ValidationFailed { name: String, reason: String },
28    /// 解析错误
29    #[error("Parse error: {0}")]
30    ParseError(String),
31    /// IO 错误
32    #[error("IO error: {0}")]
33    IoError(#[from] std::io::Error),
34    /// YAML 解析错误
35    #[error("YAML error: {0}")]
36    YamlError(String),
37}
38
39/// Prompt 结果类型
40pub type PromptResult<T> = Result<T, PromptError>;
41
42/// 变量类型
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
44#[serde(rename_all = "lowercase")]
45pub enum VariableType {
46    /// 字符串类型
47    #[default]
48    String,
49    /// 整数类型
50    Integer,
51    /// 浮点类型
52    Float,
53    /// 布尔类型
54    Boolean,
55    /// 列表类型
56    List,
57    /// JSON 对象类型
58    Json,
59}
60
61impl VariableType {
62    /// 验证值是否符合类型
63    pub fn validate(&self, value: &str) -> bool {
64        match self {
65            VariableType::String => true,
66            VariableType::Integer => value.parse::<i64>().is_ok(),
67            VariableType::Float => value.parse::<f64>().is_ok(),
68            VariableType::Boolean => {
69                matches!(value.to_lowercase().as_str(), "true" | "false" | "1" | "0")
70            }
71            VariableType::List => value.starts_with('[') && value.ends_with(']'),
72            VariableType::Json => serde_json::from_str::<serde_json::Value>(value).is_ok(),
73        }
74    }
75}
76
77/// Prompt 变量定义
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct PromptVariable {
80    /// 变量名称
81    pub name: String,
82    /// 变量描述
83    #[serde(default)]
84    pub description: Option<String>,
85    /// 变量类型
86    #[serde(default)]
87    pub var_type: VariableType,
88    /// 是否必需
89    #[serde(default = "default_true")]
90    pub required: bool,
91    /// 默认值
92    #[serde(default)]
93    pub default: Option<String>,
94    /// 验证正则表达式
95    #[serde(default)]
96    pub pattern: Option<String>,
97    /// 枚举选项
98    #[serde(default)]
99    pub enum_values: Option<Vec<String>>,
100}
101
/// Serde default for `PromptVariable::required`: a variable is required unless stated otherwise.
const fn default_true() -> bool {
    true
}

106impl PromptVariable {
107    /// 创建新的变量定义
108    pub fn new(name: impl Into<String>) -> Self {
109        Self {
110            name: name.into(),
111            description: None,
112            var_type: VariableType::String,
113            required: true,
114            default: None,
115            pattern: None,
116            enum_values: None,
117        }
118    }
119
120    /// 设置描述
121    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
122        self.description = Some(desc.into());
123        self
124    }
125
126    /// 设置类型
127    pub fn with_type(mut self, var_type: VariableType) -> Self {
128        self.var_type = var_type;
129        self
130    }
131
132    /// 设置是否必需
133    pub fn required(mut self, required: bool) -> Self {
134        self.required = required;
135        self
136    }
137
138    /// 设置默认值
139    pub fn with_default(mut self, default: impl Into<String>) -> Self {
140        self.default = Some(default.into());
141        self.required = false;
142        self
143    }
144
145    /// 设置验证正则
146    pub fn with_pattern(mut self, pattern: impl Into<String>) -> Self {
147        self.pattern = Some(pattern.into());
148        self
149    }
150
151    /// 设置枚举值
152    pub fn with_enum(mut self, values: Vec<String>) -> Self {
153        self.enum_values = Some(values);
154        self
155    }
156
157    /// 验证值
158    pub fn validate(&self, value: &str) -> PromptResult<()> {
159        // 类型验证
160        if !self.var_type.validate(value) {
161            return Err(PromptError::TypeMismatch {
162                name: self.name.clone(),
163                expected: format!("{:?}", self.var_type),
164                actual: "invalid".to_string(),
165            });
166        }
167
168        // 正则验证
169        if let Some(ref pattern) = self.pattern {
170            let re =
171                regex::Regex::new(pattern).map_err(|e| PromptError::ParseError(e.to_string()))?;
172            if !re.is_match(value) {
173                return Err(PromptError::ValidationFailed {
174                    name: self.name.clone(),
175                    reason: format!("Value does not match pattern: {}", pattern),
176                });
177            }
178        }
179
180        // 枚举验证
181        if let Some(ref enum_values) = self.enum_values
182            && !enum_values.contains(&value.to_string())
183        {
184            return Err(PromptError::ValidationFailed {
185                name: self.name.clone(),
186                reason: format!("Value must be one of: {:?}", enum_values),
187            });
188        }
189
190        Ok(())
191    }
192}
193
194/// Prompt 模板
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct PromptTemplate {
197    /// 模板 ID
198    pub id: String,
199    /// 模板名称
200    #[serde(default)]
201    pub name: Option<String>,
202    /// 模板描述
203    #[serde(default)]
204    pub description: Option<String>,
205    /// 模板内容
206    #[serde(default)]
207    pub content: String,
208    /// 变量定义
209    #[serde(default)]
210    pub variables: Vec<PromptVariable>,
211    /// 标签
212    #[serde(default)]
213    pub tags: Vec<String>,
214    /// 版本
215    #[serde(default)]
216    pub version: Option<String>,
217    /// 元数据
218    #[serde(default)]
219    pub metadata: HashMap<String, String>,
220}
221
222impl PromptTemplate {
223    /// 创建新模板
224    pub fn new(id: impl Into<String>) -> Self {
225        Self {
226            id: id.into(),
227            name: None,
228            description: None,
229            content: String::new(),
230            variables: Vec::new(),
231            tags: Vec::new(),
232            version: None,
233            metadata: HashMap::new(),
234        }
235    }
236
237    /// 设置名称
238    pub fn with_name(mut self, name: impl Into<String>) -> Self {
239        self.name = Some(name.into());
240        self
241    }
242
243    /// 设置描述
244    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
245        self.description = Some(desc.into());
246        self
247    }
248
249    /// 设置内容
250    pub fn with_content(mut self, content: impl Into<String>) -> Self {
251        self.content = content.into();
252        // 自动解析变量
253        self.parse_variables();
254        self
255    }
256
257    /// 添加变量定义
258    pub fn with_variable(mut self, variable: PromptVariable) -> Self {
259        self.variables.push(variable);
260        self
261    }
262
263    /// 添加标签
264    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
265        self.tags.push(tag.into());
266        self
267    }
268
269    /// 设置版本
270    pub fn with_version(mut self, version: impl Into<String>) -> Self {
271        self.version = Some(version.into());
272        self
273    }
274
275    /// 添加元数据
276    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
277        self.metadata.insert(key.into(), value.into());
278        self
279    }
280
281    /// 解析模板中的变量(不覆盖已有定义)
282    fn parse_variables(&mut self) {
283        // 不自动解析,让用户手动定义变量
284        // 这样可以保留用户设置的默认值和验证规则
285    }
286
287    /// 获取所有预定义变量名
288    pub fn variable_names(&self) -> Vec<&str> {
289        self.variables.iter().map(|v| v.name.as_str()).collect()
290    }
291
292    /// 获取模板中所有变量名(从内容中解析)
293    pub fn extract_variables(&self) -> Vec<String> {
294        let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
295        let mut vars = std::collections::HashSet::new();
296
297        for cap in re.captures_iter(&self.content) {
298            vars.insert(cap[1].to_string());
299        }
300
301        vars.into_iter().collect()
302    }
303
304    /// 获取必需变量
305    pub fn required_variables(&self) -> Vec<&PromptVariable> {
306        self.variables.iter().filter(|v| v.required).collect()
307    }
308
309    /// 渲染模板
310    ///
311    /// # 参数
312    /// - `vars`: 变量名和值的列表
313    ///
314    /// # 示例
315    /// ```rust,ignore
316    /// let template = PromptTemplate::new("greeting")
317    ///     .with_content("Hello, {name}! Welcome to {place}.");
318    ///
319    /// let result = template.render(&[
320    ///     ("name", "Alice"),
321    ///     ("place", "Wonderland"),
322    /// ])?;
323    /// assert_eq!(result, "Hello, Alice! Welcome to Wonderland.");
324    /// ```
325    pub fn render(&self, vars: &[(&str, &str)]) -> PromptResult<String> {
326        let var_map: HashMap<&str, &str> = vars.iter().copied().collect();
327        self.render_with_map(&var_map)
328    }
329
330    /// 使用 HashMap 渲染模板
331    pub fn render_with_map(&self, vars: &HashMap<&str, &str>) -> PromptResult<String> {
332        let mut result = self.content.clone();
333
334        // 首先处理预定义的变量(带验证和默认值)
335        for var_def in &self.variables {
336            let placeholder = format!("{{{}}}", var_def.name);
337
338            if let Some(&value) = vars.get(var_def.name.as_str()) {
339                // 验证值
340                var_def.validate(value)?;
341                result = result.replace(&placeholder, value);
342            } else if let Some(ref default) = var_def.default {
343                // 使用默认值
344                result = result.replace(&placeholder, default);
345            } else if var_def.required {
346                // 缺少必需变量
347                return Err(PromptError::MissingVariable(var_def.name.clone()));
348            }
349        }
350
351        // 然后处理模板中存在但未在 variables 中预定义的变量
352        let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
353        let defined_vars: std::collections::HashSet<_> =
354            self.variables.iter().map(|v| v.name.as_str()).collect();
355
356        // 收集所有未定义但在模板中出现的变量
357        let mut missing = Vec::new();
358        for cap in re.captures_iter(&result.clone()) {
359            let var_name = &cap[1];
360            if !defined_vars.contains(var_name) {
361                if let Some(&value) = vars.get(var_name) {
362                    let placeholder = format!("{{{}}}", var_name);
363                    result = result.replace(&placeholder, value);
364                } else {
365                    missing.push(var_name.to_string());
366                }
367            }
368        }
369
370        // 如果还有未替换的变量,报错
371        if !missing.is_empty() {
372            return Err(PromptError::MissingVariable(missing.join(", ")));
373        }
374
375        Ok(result)
376    }
377
378    /// 使用 owned HashMap 渲染模板
379    pub fn render_with_owned_map(&self, vars: &HashMap<String, String>) -> PromptResult<String> {
380        let borrowed: HashMap<&str, &str> =
381            vars.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
382        self.render_with_map(&borrowed)
383    }
384
385    /// 部分渲染(只替换提供的变量)
386    pub fn partial_render(&self, vars: &[(&str, &str)]) -> String {
387        let var_map: HashMap<&str, &str> = vars.iter().copied().collect();
388        let mut result = self.content.clone();
389
390        for (name, value) in var_map {
391            let placeholder = format!("{{{}}}", name);
392            result = result.replace(&placeholder, value);
393        }
394
395        result
396    }
397
398    /// 检查模板是否有效(所有必需变量都有默认值或在提供的变量中)
399    pub fn is_valid_with(&self, vars: &[&str]) -> bool {
400        let var_set: std::collections::HashSet<_> = vars.iter().copied().collect();
401
402        // 检查预定义的必需变量
403        for var_def in &self.variables {
404            if var_def.required
405                && var_def.default.is_none()
406                && !var_set.contains(var_def.name.as_str())
407            {
408                return false;
409            }
410        }
411
412        // 检查模板中的未定义变量
413        let re = regex::Regex::new(r"\{(\w+)\}").unwrap();
414        let defined_vars: std::collections::HashSet<_> =
415            self.variables.iter().map(|v| v.name.as_str()).collect();
416
417        for cap in re.captures_iter(&self.content) {
418            let var_name = &cap[1];
419            // 如果变量未在预定义列表中,且未在提供的变量中
420            if !defined_vars.contains(var_name) && !var_set.contains(var_name) {
421                return false;
422            }
423        }
424
425        true
426    }
427}
428
429/// Prompt 组合(多个模板的组合)
430#[derive(Debug, Clone, Serialize, Deserialize)]
431pub struct PromptComposition {
432    /// 组合 ID
433    pub id: String,
434    /// 组合描述
435    #[serde(default)]
436    pub description: Option<String>,
437    /// 模板 ID 列表(按顺序组合)
438    pub template_ids: Vec<String>,
439    /// 分隔符
440    #[serde(default = "default_separator")]
441    pub separator: String,
442}
443
/// Default separator for a [`PromptComposition`]: one blank line between templates.
fn default_separator() -> String {
    String::from("\n\n")
}

448impl PromptComposition {
449    /// 创建新的组合
450    pub fn new(id: impl Into<String>) -> Self {
451        Self {
452            id: id.into(),
453            description: None,
454            template_ids: Vec::new(),
455            separator: "\n\n".to_string(),
456        }
457    }
458
459    /// 添加模板
460    pub fn add_template(mut self, template_id: impl Into<String>) -> Self {
461        self.template_ids.push(template_id.into());
462        self
463    }
464
465    /// 设置分隔符
466    pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
467        self.separator = sep.into();
468        self
469    }
470
471    /// 设置描述
472    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
473        self.description = Some(desc.into());
474        self
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_basic() {
        let tpl = PromptTemplate::new("test")
            .with_content("Hello, {name}!")
            .with_description("A greeting template");

        assert_eq!(tpl.id, "test");
        assert_eq!(tpl.extract_variables(), vec!["name"]);

        let rendered = tpl.render(&[("name", "World")]).expect("render should succeed");
        assert_eq!(rendered, "Hello, World!");
    }

    #[test]
    fn test_template_multiple_vars() {
        let tpl = PromptTemplate::new("test")
            .with_content("Hello, {name}! Welcome to {place}. Your role is {role}.");
        let bindings = [("name", "Alice"), ("place", "Wonderland"), ("role", "explorer")];

        let rendered = tpl.render(&bindings).expect("all variables are provided");
        assert_eq!(
            rendered,
            "Hello, Alice! Welcome to Wonderland. Your role is explorer."
        );
    }

    #[test]
    fn test_template_with_default() {
        let tpl = PromptTemplate::new("test")
            .with_content("Hello, {name}!")
            .with_variable(PromptVariable::new("name").with_default("World"));

        // With no value supplied, the default is used.
        assert_eq!(tpl.render(&[]).unwrap(), "Hello, World!");
        // A supplied value takes precedence over the default.
        assert_eq!(tpl.render(&[("name", "Alice")]).unwrap(), "Hello, Alice!");
    }

    #[test]
    fn test_template_missing_required() {
        let tpl = PromptTemplate::new("test").with_content("Hello, {name}!");

        let outcome = tpl.render(&[]);
        assert!(
            matches!(&outcome, Err(PromptError::MissingVariable(_))),
            "expected MissingVariable, got {outcome:?}"
        );
    }

    #[test]
    fn test_variable_type_validation() {
        let accepted = [
            (VariableType::String, "anything"),
            (VariableType::Integer, "123"),
            (VariableType::Float, "3.14"),
            (VariableType::Boolean, "true"),
            (VariableType::Boolean, "false"),
            (VariableType::Json, r#"{"key": "value"}"#),
        ];
        for (ty, value) in &accepted {
            assert!(ty.validate(value), "{ty:?} should accept {value:?}");
        }

        assert!(!VariableType::Integer.validate("abc"));
    }

    #[test]
    fn test_variable_enum() {
        let var = PromptVariable::new("language")
            .with_enum(["rust", "python"].iter().map(|s| s.to_string()).collect());

        for allowed in ["rust", "python"] {
            assert!(var.validate(allowed).is_ok(), "{allowed} should be accepted");
        }
        assert!(var.validate("java").is_err());
    }

    #[test]
    fn test_partial_render() {
        let tpl =
            PromptTemplate::new("test").with_content("Hello, {name}! Your {item} is ready.");

        assert_eq!(
            tpl.partial_render(&[("name", "Alice")]),
            "Hello, Alice! Your {item} is ready."
        );
    }

    #[test]
    fn test_is_valid_with() {
        let tpl = PromptTemplate::new("test")
            .with_content("{required_var} and {optional_var}")
            .with_variable(PromptVariable::new("required_var"))
            .with_variable(PromptVariable::new("optional_var").with_default("default"));

        assert!(tpl.is_valid_with(&["required_var"]));
        assert!(!tpl.is_valid_with(&[]));
        assert!(!tpl.is_valid_with(&["optional_var"]));
    }
}