use anyhow::{Context, Result};
use async_trait::async_trait;
use std::sync::Arc;
use crate::providers::{ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider};
use crate::workflow::context::WorkflowContext;
use crate::workflow::def::NodeDef;
use crate::workflow::template::TemplateRenderer;
use super::node_executor::NodeExecutor;
/// Tunable settings applied to every chat request issued by `AiExecutor`.
///
/// Start from `AiExecutorConfig::default()` and override individual fields with
/// struct-update syntax, e.g. `AiExecutorConfig { max_tokens: 1024, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct AiExecutorConfig {
    /// Optional system prompt; forwarded as the request's `system` field.
    pub system_prompt: Option<String>,
    /// Token budget; forwarded as the request's `max_tokens` field.
    pub max_tokens: u32,
    /// Forwarded as the request's `think` flag.
    pub enable_thinking: bool,
    /// Streaming toggle.
    /// NOTE(review): `AiExecutor` in this file never reads this flag — confirm
    /// whether anything else consumes it.
    pub enable_streaming: bool,
}

impl Default for AiExecutorConfig {
    /// No system prompt, a 4096-token budget, and thinking/streaming disabled.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS: u32 = 4096;

        Self {
            max_tokens: DEFAULT_MAX_TOKENS,
            enable_thinking: false,
            enable_streaming: false,
            system_prompt: None,
        }
    }
}
pub struct AiExecutor {
provider: Arc<dyn Provider>,
config: AiExecutorConfig,
template_renderer: TemplateRenderer,
}
impl AiExecutor {
pub fn new(provider: Arc<dyn Provider>) -> Self {
Self {
provider,
config: AiExecutorConfig::default(),
template_renderer: TemplateRenderer::new(),
}
}
pub fn with_config(provider: Arc<dyn Provider>, config: AiExecutorConfig) -> Self {
Self {
provider,
config,
template_renderer: TemplateRenderer::new(),
}
}
pub fn extract_text_content(response: &ChatResponse) -> Result<String> {
let mut text_parts = Vec::new();
for block in &response.content {
match block {
ContentBlock::Text { text } => {
text_parts.push(text.clone());
}
ContentBlock::Thinking { thinking, .. } => {
text_parts.push(format!("[Thinking]\n{}", thinking));
}
_ => {}
}
}
Ok(text_parts.join("\n"))
}
pub fn extract_structured_output(response: &ChatResponse) -> Result<serde_json::Value> {
for block in &response.content {
if let ContentBlock::Text { text } = block {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(text) {
return Ok(json);
}
}
}
let text = Self::extract_text_content(response)?;
let stop_reason_str = match response.stop_reason {
crate::providers::StopReason::EndTurn => "end_turn",
crate::providers::StopReason::ToolUse => "tool_use",
crate::providers::StopReason::MaxTokens => "max_tokens",
};
Ok(serde_json::json!({
"text": text,
"stop_reason": stop_reason_str,
"usage": {
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens,
}
}))
}
}
#[async_trait]
impl NodeExecutor for AiExecutor {
    /// Runs `node` as a single-turn chat request and merges the result into `context`.
    ///
    /// The user prompt is built from these lines, joined with `"\n"`, in order:
    /// 1. `Task: <task>`;
    /// 2. `Description: <description>`, when the node has one;
    /// 3. `<key>: <value>` for each param. String params are rendered as templates
    ///    against `context.variables`; other values use their JSON text.
    /// 4. When `context.variables` is non-empty, a blank line, `Context:`, and one
    ///    ` <key>: <value>` line per variable.
    ///
    /// The request uses the configured system prompt, token budget and thinking flag.
    /// It sends no tools and has caching disabled. `enable_streaming` is not consulted here.
    ///
    /// The response is converted with [`AiExecutor::extract_structured_output`].
    /// If that yields a JSON object, each top-level key/value pair is written to
    /// `context` through `set_variable`. Arrays and scalars are returned but not stored.
    ///
    /// # Errors
    /// - the node has no `task` name;
    /// - rendering a string param fails (the error names the param and task);
    /// - the provider call fails (the error names the task);
    /// - response extraction fails.
    async fn execute(
        &self,
        node: &NodeDef,
        context: &mut WorkflowContext,
    ) -> Result<serde_json::Value> {
        let task_name = node
            .task
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("AI executor requires a task name"))?;

        let mut prompt_parts = vec![format!("Task: {}", task_name)];
        if let Some(desc) = &node.description {
            prompt_parts.push(format!("Description: {}", desc));
        }

        // NOTE(review): prompt line order follows the iteration order of `params`
        // and `variables`; if those are hash maps the prompt text is not
        // deterministic across runs — confirm against their definitions.
        for (key, value) in &node.params {
            let rendered_value = if let serde_json::Value::String(s) = value {
                // Name the failing param/task so template errors are diagnosable.
                self.template_renderer
                    .render(s, &context.variables)
                    .with_context(|| {
                        format!("failed to render param '{}' for task '{}'", key, task_name)
                    })?
            } else {
                value.to_string()
            };
            prompt_parts.push(format!("{}: {}", key, rendered_value));
        }

        if !context.variables.is_empty() {
            // The leading "\n" leaves a blank line before this section once joined.
            prompt_parts.push("\nContext:".to_string());
            for (key, value) in &context.variables {
                prompt_parts.push(format!(" {}: {}", key, value));
            }
        }

        let request = ChatRequest {
            messages: vec![Message {
                role: crate::providers::Role::User,
                content: MessageContent::Text(prompt_parts.join("\n")),
            }],
            tools: Vec::new(),
            system: self.config.system_prompt.clone(),
            think: self.config.enable_thinking,
            max_tokens: self.config.max_tokens,
            server_tools: Vec::new(),
            enable_caching: false,
        };

        let response = self
            .provider
            .chat(request)
            .await
            .with_context(|| format!("AI executor failed for task '{}'", task_name))?;

        let output = Self::extract_structured_output(&response)?;
        if let serde_json::Value::Object(map) = &output {
            for (key, value) in map {
                context.set_variable(key.clone(), value.clone());
            }
        }
        Ok(output)
    }

    /// Identifier of this executor kind.
    fn name(&self) -> &str {
        "ai_executor"
    }
}