Skip to main content

matrixcode_core/workflow/executors/
ai.rs

1//! AI Executor
2//!
3//! AI 模型调用执行器,调用 Provider 执行任务节点。
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::node_executor::NodeExecutor;
10use crate::providers::{
11    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider,
12};
13use crate::workflow::context::WorkflowContext;
14use crate::workflow::def::NodeDef;
15use crate::workflow::template::TemplateRenderer;
16
/// Tuning knobs for [`AiExecutor`].
#[derive(Debug, Clone)]
pub struct AiExecutorConfig {
    /// System prompt template; `None` leaves `ChatRequest::system` unset.
    pub system_prompt: Option<String>,
    /// Upper bound on generated tokens, forwarded as `ChatRequest::max_tokens`.
    pub max_tokens: u32,
    /// Whether thinking mode is requested, forwarded as `ChatRequest::think`.
    pub enable_thinking: bool,
    /// Whether streaming output is enabled.
    ///
    /// NOTE(review): `AiExecutor::execute` in this file does not read this
    /// flag — confirm whether streaming is honoured anywhere else.
    pub enable_streaming: bool,
}

impl Default for AiExecutorConfig {
    /// No system prompt, a 4096-token output cap, thinking and streaming off.
    fn default() -> Self {
        let max_tokens = 4096;
        Self {
            max_tokens,
            system_prompt: None,
            enable_streaming: false,
            enable_thinking: false,
        }
    }
}
40
41/// AI 执行器
42///
43/// 调用 AI Provider 执行任务节点。
44pub struct AiExecutor {
45    /// Provider 实例
46    provider: Arc<dyn Provider>,
47    /// 配置
48    config: AiExecutorConfig,
49    /// 模板渲染器
50    template_renderer: TemplateRenderer,
51}
52
53impl AiExecutor {
54    /// 创建新的 AI 执行器
55    pub fn new(provider: Arc<dyn Provider>) -> Self {
56        Self {
57            provider,
58            config: AiExecutorConfig::default(),
59            template_renderer: TemplateRenderer::new(),
60        }
61    }
62
63    /// 使用配置创建 AI 执行器
64    pub fn with_config(provider: Arc<dyn Provider>, config: AiExecutorConfig) -> Self {
65        Self {
66            provider,
67            config,
68            template_renderer: TemplateRenderer::new(),
69        }
70    }
71
72    /// 从响应中提取文本内容
73    pub fn extract_text_content(response: &ChatResponse) -> Result<String> {
74        let mut text_parts = Vec::new();
75        for block in &response.content {
76            match block {
77                ContentBlock::Text { text } => {
78                    text_parts.push(text.clone());
79                }
80                ContentBlock::Thinking { thinking, .. } => {
81                    text_parts.push(format!("[Thinking]\n{}", thinking));
82                }
83                _ => {}
84            }
85        }
86        Ok(text_parts.join("\n"))
87    }
88
89    /// 从响应中提取结构化输出
90    pub fn extract_structured_output(response: &ChatResponse) -> Result<serde_json::Value> {
91        // 尝试从文本中解析 JSON
92        for block in &response.content {
93            if let ContentBlock::Text { text } = block {
94                // 尝试解析为 JSON
95                if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
96                    return Ok(json);
97                }
98            }
99        }
100
101        // 如果没有找到 JSON,返回文本内容
102        let text = Self::extract_text_content(response)?;
103        let stop_reason_str = match response.stop_reason {
104            crate::providers::StopReason::EndTurn => "end_turn",
105            crate::providers::StopReason::ToolUse => "tool_use",
106            crate::providers::StopReason::MaxTokens => "max_tokens",
107        };
108        Ok(serde_json::json!({
109            "text": text,
110            "stop_reason": stop_reason_str,
111            "usage": {
112                "input_tokens": response.usage.input_tokens,
113                "output_tokens": response.usage.output_tokens,
114            }
115        }))
116    }
117}
118
119#[async_trait]
120impl NodeExecutor for AiExecutor {
121    async fn execute(
122        &self,
123        node: &NodeDef,
124        context: &mut WorkflowContext,
125    ) -> Result<serde_json::Value> {
126        // 获取任务名称
127        let task_name = node
128            .task
129            .as_ref()
130            .ok_or_else(|| anyhow::anyhow!("AI executor requires a task name"))?;
131
132        // 构建用户消息
133        let mut prompt_parts = Vec::new();
134
135        // 添加任务名称
136        prompt_parts.push(format!("Task: {}", task_name));
137
138        // 添加任务描述
139        if let Some(desc) = &node.description {
140            prompt_parts.push(format!("Description: {}", desc));
141        }
142
143        // 渲染并添加参数
144        for (key, value) in &node.params {
145            let rendered_value = if let serde_json::Value::String(s) = value {
146                self.template_renderer.render(s, &context.variables)?
147            } else {
148                value.to_string()
149            };
150            prompt_parts.push(format!("{}: {}", key, rendered_value));
151        }
152
153        // 添加上下文信息
154        if !context.variables.is_empty() {
155            prompt_parts.push("\nContext:".to_string());
156            for (key, value) in &context.variables {
157                prompt_parts.push(format!("  {}: {}", key, value));
158            }
159        }
160
161        let user_message = prompt_parts.join("\n");
162
163        // 构建聊天请求
164        let messages = vec![Message {
165            role: crate::providers::Role::User,
166            content: MessageContent::Text(user_message),
167        }];
168
169        let request = ChatRequest {
170            messages,
171            tools: Vec::new(),
172            system: self.config.system_prompt.clone(),
173            think: self.config.enable_thinking,
174            max_tokens: self.config.max_tokens,
175            server_tools: Vec::new(),
176            enable_caching: false,
177        };
178
179        // 调用 Provider
180        let response = self
181            .provider
182            .chat(request)
183            .await
184            .with_context(|| format!("AI executor failed for task '{}'", task_name))?;
185
186        // 提取输出
187        let output = Self::extract_structured_output(&response)?;
188
189        // 更新上下文
190        let output_ref = &output;
191        if let serde_json::Value::Object(map) = output_ref {
192            for (key, value) in map {
193                context.set_variable(key.clone(), value.clone());
194            }
195        }
196
197        Ok(output)
198    }
199
200    fn name(&self) -> &str {
201        "ai_executor"
202    }
203}