matrixcode_core/workflow/executors/
ai.rs1use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use super::node_executor::NodeExecutor;
10use crate::providers::{
11 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider,
12};
13use crate::workflow::context::WorkflowContext;
14use crate::workflow::def::NodeDef;
15use crate::workflow::template::TemplateRenderer;
16
/// Settings applied to every chat request issued by [`AiExecutor`].
///
/// Construct with [`Default::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `AiExecutorConfig { max_tokens: 1024, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AiExecutorConfig {
    /// Optional system prompt; forwarded as the request's `system` field.
    pub system_prompt: Option<String>,
    /// Maximum number of tokens to generate; forwarded as the request's
    /// `max_tokens` field. Defaults to [`AiExecutorConfig::DEFAULT_MAX_TOKENS`].
    pub max_tokens: u32,
    /// Whether to request model "thinking"; forwarded as the request's
    /// `think` field.
    pub enable_thinking: bool,
    /// Streaming toggle.
    ///
    /// NOTE(review): this flag is not read by the executor in this file;
    /// confirm whether another component honours it.
    pub enable_streaming: bool,
}

impl AiExecutorConfig {
    /// Token limit used by [`AiExecutorConfig::default`].
    pub const DEFAULT_MAX_TOKENS: u32 = 4096;
}

impl Default for AiExecutorConfig {
    /// No system prompt, [`AiExecutorConfig::DEFAULT_MAX_TOKENS`] tokens,
    /// thinking and streaming disabled.
    fn default() -> Self {
        Self {
            system_prompt: None,
            max_tokens: Self::DEFAULT_MAX_TOKENS,
            enable_thinking: false,
            enable_streaming: false,
        }
    }
}
40
41pub struct AiExecutor {
45 provider: Arc<dyn Provider>,
47 config: AiExecutorConfig,
49 template_renderer: TemplateRenderer,
51}
52
53impl AiExecutor {
54 pub fn new(provider: Arc<dyn Provider>) -> Self {
56 Self {
57 provider,
58 config: AiExecutorConfig::default(),
59 template_renderer: TemplateRenderer::new(),
60 }
61 }
62
63 pub fn with_config(provider: Arc<dyn Provider>, config: AiExecutorConfig) -> Self {
65 Self {
66 provider,
67 config,
68 template_renderer: TemplateRenderer::new(),
69 }
70 }
71
72 pub fn extract_text_content(response: &ChatResponse) -> Result<String> {
74 let mut text_parts = Vec::new();
75 for block in &response.content {
76 match block {
77 ContentBlock::Text { text } => {
78 text_parts.push(text.clone());
79 }
80 ContentBlock::Thinking { thinking, .. } => {
81 text_parts.push(format!("[Thinking]\n{}", thinking));
82 }
83 _ => {}
84 }
85 }
86 Ok(text_parts.join("\n"))
87 }
88
89 pub fn extract_structured_output(response: &ChatResponse) -> Result<serde_json::Value> {
91 for block in &response.content {
93 if let ContentBlock::Text { text } = block {
94 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
96 return Ok(json);
97 }
98 }
99 }
100
101 let text = Self::extract_text_content(response)?;
103 let stop_reason_str = match response.stop_reason {
104 crate::providers::StopReason::EndTurn => "end_turn",
105 crate::providers::StopReason::ToolUse => "tool_use",
106 crate::providers::StopReason::MaxTokens => "max_tokens",
107 };
108 Ok(serde_json::json!({
109 "text": text,
110 "stop_reason": stop_reason_str,
111 "usage": {
112 "input_tokens": response.usage.input_tokens,
113 "output_tokens": response.usage.output_tokens,
114 }
115 }))
116 }
117}
118
119#[async_trait]
120impl NodeExecutor for AiExecutor {
121 async fn execute(
122 &self,
123 node: &NodeDef,
124 context: &mut WorkflowContext,
125 ) -> Result<serde_json::Value> {
126 let task_name = node
128 .task
129 .as_ref()
130 .ok_or_else(|| anyhow::anyhow!("AI executor requires a task name"))?;
131
132 let mut prompt_parts = Vec::new();
134
135 prompt_parts.push(format!("Task: {}", task_name));
137
138 if let Some(desc) = &node.description {
140 prompt_parts.push(format!("Description: {}", desc));
141 }
142
143 for (key, value) in &node.params {
145 let rendered_value = if let serde_json::Value::String(s) = value {
146 self.template_renderer.render(s, &context.variables)?
147 } else {
148 value.to_string()
149 };
150 prompt_parts.push(format!("{}: {}", key, rendered_value));
151 }
152
153 if !context.variables.is_empty() {
155 prompt_parts.push("\nContext:".to_string());
156 for (key, value) in &context.variables {
157 prompt_parts.push(format!(" {}: {}", key, value));
158 }
159 }
160
161 let user_message = prompt_parts.join("\n");
162
163 let messages = vec![Message {
165 role: crate::providers::Role::User,
166 content: MessageContent::Text(user_message),
167 }];
168
169 let request = ChatRequest {
170 messages,
171 tools: Vec::new(),
172 system: self.config.system_prompt.clone(),
173 think: self.config.enable_thinking,
174 max_tokens: self.config.max_tokens,
175 server_tools: Vec::new(),
176 enable_caching: false,
177 };
178
179 let response = self
181 .provider
182 .chat(request)
183 .await
184 .with_context(|| format!("AI executor failed for task '{}'", task_name))?;
185
186 let output = Self::extract_structured_output(&response)?;
188
189 let output_ref = &output;
191 if let serde_json::Value::Object(map) = output_ref {
192 for (key, value) in map {
193 context.set_variable(key.clone(), value.clone());
194 }
195 }
196
197 Ok(output)
198 }
199
200 fn name(&self) -> &str {
201 "ai_executor"
202 }
203}