matrixcode_core/workflow/executors/
ai.rs1use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8
9use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider};
10use crate::workflow::context::WorkflowContext;
11use crate::workflow::def::NodeDef;
12use crate::workflow::template::TemplateRenderer;
13use super::node_executor::NodeExecutor;
14
/// Tunable settings for an [`AiExecutor`].
///
/// Start from [`AiExecutorConfig::default`] and override individual fields
/// with struct-update syntax, e.g.
/// `AiExecutorConfig { max_tokens: 1024, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct AiExecutorConfig {
    /// Optional system prompt, forwarded as `ChatRequest::system`.
    pub system_prompt: Option<String>,
    /// Forwarded as `ChatRequest::max_tokens` on every request.
    pub max_tokens: u32,
    /// Forwarded as `ChatRequest::think` on every request.
    pub enable_thinking: bool,
    /// Streaming toggle.
    ///
    /// NOTE(review): `AiExecutor::execute` does not currently read this flag.
    pub enable_streaming: bool,
}

impl Default for AiExecutorConfig {
    /// No system prompt, `max_tokens` of 4096, thinking and streaming disabled.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS: u32 = 4096;

        AiExecutorConfig {
            max_tokens: DEFAULT_MAX_TOKENS,
            system_prompt: None,
            enable_streaming: false,
            enable_thinking: false,
        }
    }
}
38
39pub struct AiExecutor {
43 provider: Arc<dyn Provider>,
45 config: AiExecutorConfig,
47 template_renderer: TemplateRenderer,
49}
50
51impl AiExecutor {
52 pub fn new(provider: Arc<dyn Provider>) -> Self {
54 Self {
55 provider,
56 config: AiExecutorConfig::default(),
57 template_renderer: TemplateRenderer::new(),
58 }
59 }
60
61 pub fn with_config(provider: Arc<dyn Provider>, config: AiExecutorConfig) -> Self {
63 Self {
64 provider,
65 config,
66 template_renderer: TemplateRenderer::new(),
67 }
68 }
69
70 pub fn extract_text_content(response: &ChatResponse) -> Result<String> {
72 let mut text_parts = Vec::new();
73 for block in &response.content {
74 match block {
75 ContentBlock::Text { text } => {
76 text_parts.push(text.clone());
77 }
78 ContentBlock::Thinking { thinking, .. } => {
79 text_parts.push(format!("[Thinking]\n{}", thinking));
80 }
81 _ => {}
82 }
83 }
84 Ok(text_parts.join("\n"))
85 }
86
87 pub fn extract_structured_output(response: &ChatResponse) -> Result<serde_json::Value> {
89 for block in &response.content {
91 if let ContentBlock::Text { text } = block {
92 if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
94 return Ok(json);
95 }
96 }
97 }
98
99 let text = Self::extract_text_content(response)?;
101 let stop_reason_str = match response.stop_reason {
102 crate::providers::StopReason::EndTurn => "end_turn",
103 crate::providers::StopReason::ToolUse => "tool_use",
104 crate::providers::StopReason::MaxTokens => "max_tokens",
105 };
106 Ok(serde_json::json!({
107 "text": text,
108 "stop_reason": stop_reason_str,
109 "usage": {
110 "input_tokens": response.usage.input_tokens,
111 "output_tokens": response.usage.output_tokens,
112 }
113 }))
114 }
115}
116
117#[async_trait]
118impl NodeExecutor for AiExecutor {
119 async fn execute(
120 &self,
121 node: &NodeDef,
122 context: &mut WorkflowContext,
123 ) -> Result<serde_json::Value> {
124 let task_name = node.task.as_ref()
126 .ok_or_else(|| anyhow::anyhow!("AI executor requires a task name"))?;
127
128 let mut prompt_parts = Vec::new();
130
131 prompt_parts.push(format!("Task: {}", task_name));
133
134 if let Some(desc) = &node.description {
136 prompt_parts.push(format!("Description: {}", desc));
137 }
138
139 for (key, value) in &node.params {
141 let rendered_value = if let serde_json::Value::String(s) = value {
142 self.template_renderer.render(s, &context.variables)?
143 } else {
144 value.to_string()
145 };
146 prompt_parts.push(format!("{}: {}", key, rendered_value));
147 }
148
149 if !context.variables.is_empty() {
151 prompt_parts.push("\nContext:".to_string());
152 for (key, value) in &context.variables {
153 prompt_parts.push(format!(" {}: {}", key, value));
154 }
155 }
156
157 let user_message = prompt_parts.join("\n");
158
159 let messages = vec![Message {
161 role: crate::providers::Role::User,
162 content: MessageContent::Text(user_message),
163 }];
164
165 let request = ChatRequest {
166 messages,
167 tools: Vec::new(),
168 system: self.config.system_prompt.clone(),
169 think: self.config.enable_thinking,
170 max_tokens: self.config.max_tokens,
171 server_tools: Vec::new(),
172 enable_caching: false,
173 };
174
175 let response = self.provider.chat(request)
177 .await
178 .with_context(|| format!("AI executor failed for task '{}'", task_name))?;
179
180 let output = Self::extract_structured_output(&response)?;
182
183 let output_ref = &output;
185 if let serde_json::Value::Object(map) = output_ref {
186 for (key, value) in map {
187 context.set_variable(key.clone(), value.clone());
188 }
189 }
190
191 Ok(output)
192 }
193
194 fn name(&self) -> &str {
195 "ai_executor"
196 }
197}