// matrixcode_core/workflow/executors/ai.rs

//! AI Executor
//!
//! AI model-invocation executor: calls a `Provider` to execute workflow task nodes.

5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider};
10use crate::workflow::context::WorkflowContext;
11use crate::workflow::def::NodeDef;
12use crate::workflow::template::TemplateRenderer;
13use super::node_executor::NodeExecutor;
14
/// Settings that control how [`AiExecutor`] builds its chat requests.
#[derive(Clone, Debug)]
pub struct AiExecutorConfig {
    /// Optional system prompt sent with every request.
    pub system_prompt: Option<String>,
    /// Upper bound on the number of output tokens requested from the model.
    pub max_tokens: u32,
    /// Whether to ask the provider for "thinking" output.
    pub enable_thinking: bool,
    /// Whether streaming output is wanted.
    /// NOTE(review): `AiExecutor::execute` in this file never reads this flag.
    pub enable_streaming: bool,
}

impl Default for AiExecutorConfig {
    /// No system prompt, 4096 max output tokens, thinking and streaming both off.
    fn default() -> Self {
        AiExecutorConfig {
            max_tokens: 4096,
            system_prompt: None,
            enable_streaming: false,
            enable_thinking: false,
        }
    }
}

39/// AI 执行器
40///
41/// 调用 AI Provider 执行任务节点。
42pub struct AiExecutor {
43    /// Provider 实例
44    provider: Arc<dyn Provider>,
45    /// 配置
46    config: AiExecutorConfig,
47    /// 模板渲染器
48    template_renderer: TemplateRenderer,
49}
50
51impl AiExecutor {
52    /// 创建新的 AI 执行器
53    pub fn new(provider: Arc<dyn Provider>) -> Self {
54        Self {
55            provider,
56            config: AiExecutorConfig::default(),
57            template_renderer: TemplateRenderer::new(),
58        }
59    }
60
61    /// 使用配置创建 AI 执行器
62    pub fn with_config(provider: Arc<dyn Provider>, config: AiExecutorConfig) -> Self {
63        Self {
64            provider,
65            config,
66            template_renderer: TemplateRenderer::new(),
67        }
68    }
69
70    /// 从响应中提取文本内容
71    pub fn extract_text_content(response: &ChatResponse) -> Result<String> {
72        let mut text_parts = Vec::new();
73        for block in &response.content {
74            match block {
75                ContentBlock::Text { text } => {
76                    text_parts.push(text.clone());
77                }
78                ContentBlock::Thinking { thinking, .. } => {
79                    text_parts.push(format!("[Thinking]\n{}", thinking));
80                }
81                _ => {}
82            }
83        }
84        Ok(text_parts.join("\n"))
85    }
86
87    /// 从响应中提取结构化输出
88    pub fn extract_structured_output(response: &ChatResponse) -> Result<serde_json::Value> {
89        // 尝试从文本中解析 JSON
90        for block in &response.content {
91            if let ContentBlock::Text { text } = block {
92                // 尝试解析为 JSON
93                if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
94                    return Ok(json);
95                }
96            }
97        }
98
99        // 如果没有找到 JSON,返回文本内容
100        let text = Self::extract_text_content(response)?;
101        let stop_reason_str = match response.stop_reason {
102            crate::providers::StopReason::EndTurn => "end_turn",
103            crate::providers::StopReason::ToolUse => "tool_use",
104            crate::providers::StopReason::MaxTokens => "max_tokens",
105        };
106        Ok(serde_json::json!({
107            "text": text,
108            "stop_reason": stop_reason_str,
109            "usage": {
110                "input_tokens": response.usage.input_tokens,
111                "output_tokens": response.usage.output_tokens,
112            }
113        }))
114    }
115}
116
117#[async_trait]
118impl NodeExecutor for AiExecutor {
119    async fn execute(
120        &self,
121        node: &NodeDef,
122        context: &mut WorkflowContext,
123    ) -> Result<serde_json::Value> {
124        // 获取任务名称
125        let task_name = node.task.as_ref()
126            .ok_or_else(|| anyhow::anyhow!("AI executor requires a task name"))?;
127
128        // 构建用户消息
129        let mut prompt_parts = Vec::new();
130
131        // 添加任务名称
132        prompt_parts.push(format!("Task: {}", task_name));
133
134        // 添加任务描述
135        if let Some(desc) = &node.description {
136            prompt_parts.push(format!("Description: {}", desc));
137        }
138
139        // 渲染并添加参数
140        for (key, value) in &node.params {
141            let rendered_value = if let serde_json::Value::String(s) = value {
142                self.template_renderer.render(s, &context.variables)?
143            } else {
144                value.to_string()
145            };
146            prompt_parts.push(format!("{}: {}", key, rendered_value));
147        }
148
149        // 添加上下文信息
150        if !context.variables.is_empty() {
151            prompt_parts.push("\nContext:".to_string());
152            for (key, value) in &context.variables {
153                prompt_parts.push(format!("  {}: {}", key, value));
154            }
155        }
156
157        let user_message = prompt_parts.join("\n");
158
159        // 构建聊天请求
160        let messages = vec![Message {
161            role: crate::providers::Role::User,
162            content: MessageContent::Text(user_message),
163        }];
164
165        let request = ChatRequest {
166            messages,
167            tools: Vec::new(),
168            system: self.config.system_prompt.clone(),
169            think: self.config.enable_thinking,
170            max_tokens: self.config.max_tokens,
171            server_tools: Vec::new(),
172            enable_caching: false,
173        };
174
175        // 调用 Provider
176        let response = self.provider.chat(request)
177            .await
178            .with_context(|| format!("AI executor failed for task '{}'", task_name))?;
179
180        // 提取输出
181        let output = Self::extract_structured_output(&response)?;
182
183        // 更新上下文
184        let output_ref = &output;
185        if let serde_json::Value::Object(map) = output_ref {
186            for (key, value) in map {
187                context.set_variable(key.clone(), value.clone());
188            }
189        }
190
191        Ok(output)
192    }
193
194    fn name(&self) -> &str {
195        "ai_executor"
196    }
197}