ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::{Agent, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18pub struct AgentExecutor<'a> {
23 config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27 pub fn new(config: &'a AgentStepConfig) -> Self {
29 Self { config }
30 }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35 let start = Instant::now();
36
37 let mut agent = Agent::new().prompt(&self.config.prompt);
38
39 if let Some(ref sp) = self.config.system_prompt {
40 agent = agent.system_prompt(sp);
41 }
42 if let Some(ref model_str) = self.config.model {
43 agent = agent.model(model_str.clone());
44 }
45 if let Some(budget) = self.config.max_budget_usd {
46 agent = agent.max_budget_usd(budget);
47 }
48 if let Some(turns) = self.config.max_turns {
49 agent = agent.max_turns(turns);
50 }
51 if !self.config.allowed_tools.is_empty() {
52 let tool_refs: Vec<&str> = self
53 .config
54 .allowed_tools
55 .iter()
56 .map(|s| s.as_str())
57 .collect();
58 agent = agent.allowed_tools(&tool_refs);
59 }
60 if let Some(ref dir) = self.config.working_dir {
61 agent = agent.working_dir(dir);
62 }
63 if let Some(ref mode) = self.config.permission_mode {
64 let pm = parse_permission_mode(mode);
65 agent = agent.permission_mode(pm);
66 }
67 if let Some(ref schema) = self.config.output_schema {
68 if self.config.max_turns == Some(1) {
69 warn!(
70 "structured output (output_schema) requires max_turns >= 2; \
71 max_turns is set to 1, the agent will likely fail with error_max_turns"
72 );
73 }
74 agent = agent.output_schema_raw(schema);
75 }
76
77 let result = agent.run(provider.as_ref()).await?;
78 let duration_ms = start.elapsed().as_millis() as u64;
79 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
80 let input_tokens = result.input_tokens();
81 let output_tokens = result.output_tokens();
82
83 info!(
84 step_kind = "agent",
85 model = ?self.config.model,
86 cost_usd = %cost,
87 input_tokens = ?input_tokens,
88 output_tokens = ?output_tokens,
89 duration_ms,
90 "agent step completed"
91 );
92
93 Ok(StepOutput {
94 output: json!({
95 "value": result.value(),
96 "model": result.model(),
97 }),
98 duration_ms,
99 cost_usd: cost,
100 input_tokens,
101 output_tokens,
102 })
103 }
104}
105
106fn parse_permission_mode(s: &str) -> PermissionMode {
110 match s.to_lowercase().as_str() {
111 "auto" => PermissionMode::Auto,
112 "dont_ask" | "dontask" => PermissionMode::DontAsk,
113 "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
114 _ => PermissionMode::Default,
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` and asserts that the resulting mode satisfies `is_expected`.
    fn assert_parses(input: &str, is_expected: fn(&PermissionMode) -> bool) {
        let mode = parse_permission_mode(input);
        assert!(
            is_expected(&mode),
            "parse_permission_mode({input:?}) returned an unexpected variant"
        );
    }

    #[test]
    fn parse_permission_mode_auto() {
        assert_parses("auto", |m| matches!(m, PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert_parses("dont_ask", |m| matches!(m, PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_dont_ask_alt() {
        assert_parses("dontask", |m| matches!(m, PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert_parses("bypass", |m| matches!(m, PermissionMode::BypassPermissions));
    }

    #[test]
    fn parse_permission_mode_bypass_alt() {
        assert_parses("bypass_permissions", |m| {
            matches!(m, PermissionMode::BypassPermissions)
        });
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert_parses("unknown", |m| matches!(m, PermissionMode::Default));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        assert_parses("AUTO", |m| matches!(m, PermissionMode::Auto));
        assert_parses("DONT_ASK", |m| matches!(m, PermissionMode::DontAsk));
        assert_parses("BYPASS", |m| matches!(m, PermissionMode::BypassPermissions));
    }
}