Skip to main content

matrixcode_core/workflow/executors/
validate.rs

//! Validate executor.
//!
//! Hybrid validation executor: programmatic rule validation plus optional AI validation.
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::node_executor::NodeExecutor;
10use crate::providers::{ChatRequest, ContentBlock, Message, MessageContent, Provider};
11use crate::workflow::context::WorkflowContext;
12use crate::workflow::def::NodeDef;
13use crate::workflow::rule_engine::{Rule, RuleEngine, ValidationResult};
14use crate::workflow::template::TemplateRenderer;
15
/// Configuration for [`ValidateExecutor`].
#[derive(Debug, Clone)]
pub struct ValidateExecutorConfig {
    /// Whether to run AI validation after all programmatic rules have passed.
    ///
    /// Has no effect unless the executor was created with an AI provider
    /// (see `ValidateExecutor::with_ai`).
    pub enable_ai_validation: bool,
    /// Prompt template for AI validation, rendered with the workflow variables.
    ///
    /// When empty, a built-in prompt that embeds the data under validation is used instead.
    pub ai_validation_prompt: String,
    /// Whether the executor returns an error when validation fails.
    ///
    /// Despite its name, this applies to *any* failed validation: programmatic rule
    /// failures as well as AI validation failures. When `false`, failures are reported
    /// only through the `passed` / `errors` fields of the node output.
    pub abort_on_ai_failure: bool,
}
26
27impl Default for ValidateExecutorConfig {
28    fn default() -> Self {
29        Self {
30            enable_ai_validation: false,
31            ai_validation_prompt: String::new(),
32            abort_on_ai_failure: true,
33        }
34    }
35}
36
37/// 验证执行器
38///
39/// 混合验证执行器:程序规则验证 + AI 验证。
40pub struct ValidateExecutor {
41    /// AI Provider(可选)
42    provider: Option<Arc<dyn Provider>>,
43    /// 配置
44    config: ValidateExecutorConfig,
45    /// 模板渲染器
46    template_renderer: TemplateRenderer,
47}
48
49impl ValidateExecutor {
50    /// 创建新的验证执行器(仅程序规则)
51    pub fn new() -> Self {
52        Self {
53            provider: None,
54            config: ValidateExecutorConfig::default(),
55            template_renderer: TemplateRenderer::new(),
56        }
57    }
58
59    /// 创建带 AI 验证的执行器
60    pub fn with_ai(provider: Arc<dyn Provider>, config: ValidateExecutorConfig) -> Self {
61        Self {
62            provider: Some(provider),
63            config,
64            template_renderer: TemplateRenderer::new(),
65        }
66    }
67
68    /// 执行 AI 验证
69    async fn validate_with_ai(
70        &self,
71        data: &serde_json::Value,
72        context: &WorkflowContext,
73    ) -> Result<ValidationResult> {
74        if let Some(provider) = &self.provider {
75            // 构建验证提示
76            let prompt = if self.config.ai_validation_prompt.is_empty() {
77                format!(
78                    "Please validate the following data and return a JSON object with 'passed' (boolean) and 'errors' (array of strings):\n{}",
79                    serde_json::to_string_pretty(data)?
80                )
81            } else {
82                self.template_renderer
83                    .render(&self.config.ai_validation_prompt, &context.variables)?
84            };
85
86            // 构建请求
87            let messages = vec![Message {
88                role: crate::providers::Role::User,
89                content: MessageContent::Text(prompt),
90            }];
91
92            let request = ChatRequest {
93                messages,
94                tools: Vec::new(),
95                system: Some(
96                    "You are a data validator. Return JSON with 'passed' and 'errors' fields."
97                        .to_string(),
98                ),
99                think: false,
100                max_tokens: 1024,
101                server_tools: Vec::new(),
102                enable_caching: false,
103            };
104
105            // 调用 AI
106            let response = provider.chat(request).await?;
107
108            // 解析响应
109            for block in &response.content {
110                if let ContentBlock::Text { text } = block
111                    && let Ok(json) = serde_json::from_str::<serde_json::Value>(text)
112                {
113                    let passed = json
114                        .get("passed")
115                        .and_then(|v| v.as_bool())
116                        .unwrap_or(false);
117                    let errors = json
118                        .get("errors")
119                        .and_then(|v| v.as_array())
120                        .map(|arr| {
121                            arr.iter()
122                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
123                                .collect()
124                        })
125                        .unwrap_or_default();
126
127                    return Ok(ValidationResult { passed, errors });
128                }
129            }
130
131            // 无法解析 AI 响应
132            Ok(ValidationResult::failure(
133                "Failed to parse AI validation response".to_string(),
134            ))
135        } else {
136            // 没有 AI Provider,直接通过
137            Ok(ValidationResult::success())
138        }
139    }
140}
141
142impl Default for ValidateExecutor {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148#[async_trait]
149impl NodeExecutor for ValidateExecutor {
150    async fn execute(
151        &self,
152        node: &NodeDef,
153        context: &mut WorkflowContext,
154    ) -> Result<serde_json::Value> {
155        // 从节点参数中提取验证规则
156        let rules_json = node
157            .params
158            .get("rules")
159            .ok_or_else(|| anyhow::anyhow!("Validate executor requires 'rules' parameter"))?;
160
161        // 解析规则
162        let rules: Vec<Rule> = serde_json::from_value(rules_json.clone())
163            .with_context(|| "Failed to parse validation rules")?;
164
165        // 创建可变副本用于规则验证
166        let mut rule_engine = RuleEngine::new();
167
168        // 执行规则验证
169        let mut result = ValidationResult::success();
170        for rule in &rules {
171            result = result.merge(rule_engine.validate(rule, &context.variables)?);
172        }
173
174        // 如果规则验证通过且有 AI Provider,执行 AI 验证
175        if result.passed && self.config.enable_ai_validation && self.provider.is_some() {
176            // 将 HashMap 转换为 serde_json::Map
177            let context_vars: serde_json::Map<String, serde_json::Value> = context
178                .variables
179                .iter()
180                .map(|(k, v)| (k.clone(), v.clone()))
181                .collect();
182
183            let data_to_validate = node
184                .params
185                .get("data")
186                .cloned()
187                .unwrap_or(serde_json::Value::Object(context_vars));
188
189            let ai_result = self.validate_with_ai(&data_to_validate, context).await?;
190            result = result.merge(ai_result);
191        }
192
193        // 构建输出
194        let output = serde_json::json!({
195            "passed": result.passed,
196            "errors": result.errors,
197            "node_id": node.id,
198        });
199
200        // 如果验证失败且配置为中止,返回错误
201        if !result.passed && self.config.abort_on_ai_failure {
202            return Err(anyhow::anyhow!(
203                "Validation failed: {}",
204                result.errors.join("; ")
205            ));
206        }
207
208        Ok(output)
209    }
210
211    fn name(&self) -> &str {
212        "validate_executor"
213    }
214}