ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::{Agent, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18pub struct AgentExecutor<'a> {
23 config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27 pub fn new(config: &'a AgentStepConfig) -> Self {
29 Self { config }
30 }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35 let start = Instant::now();
36
37 let mut agent = Agent::new().prompt(&self.config.prompt);
38
39 if let Some(ref sp) = self.config.system_prompt {
40 agent = agent.system_prompt(sp);
41 }
42 if let Some(ref model_str) = self.config.model {
43 agent = agent.model(model_str.clone());
44 }
45 if let Some(budget) = self.config.max_budget_usd {
46 agent = agent.max_budget_usd(budget);
47 }
48 if let Some(turns) = self.config.max_turns {
49 agent = agent.max_turns(turns);
50 }
51 if !self.config.allowed_tools.is_empty() {
52 let tool_refs: Vec<&str> = self
53 .config
54 .allowed_tools
55 .iter()
56 .map(|s| s.as_str())
57 .collect();
58 agent = agent.allowed_tools(&tool_refs);
59 }
60 if let Some(ref dir) = self.config.working_dir {
61 agent = agent.working_dir(dir);
62 }
63 if let Some(ref mode) = self.config.permission_mode {
64 let pm = parse_permission_mode(mode);
65 agent = agent.permission_mode(pm);
66 }
67 if let Some(ref schema) = self.config.output_schema {
68 if self.config.max_turns == Some(1) {
69 warn!(
70 "structured output (output_schema) requires max_turns >= 2; \
71 max_turns is set to 1, the agent will likely fail with error_max_turns"
72 );
73 }
74 agent = agent.output_schema_raw(schema);
75 }
76
77 let result = agent.run(provider.as_ref()).await?;
78 let duration_ms = start.elapsed().as_millis() as u64;
79 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
80 let input_tokens = result.input_tokens();
81 let output_tokens = result.output_tokens();
82
83 info!(
84 step_kind = "agent",
85 model = ?self.config.model,
86 cost_usd = %cost,
87 input_tokens = ?input_tokens,
88 output_tokens = ?output_tokens,
89 duration_ms,
90 "agent step completed"
91 );
92
93 #[cfg(feature = "prometheus")]
94 {
95 use ironflow_core::metric_names::{
96 AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
97 AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
98 };
99 use metrics::{counter, gauge, histogram};
100 let model_label = self
101 .config
102 .model
103 .clone()
104 .unwrap_or_else(|| "default".to_string());
105 counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
106 .increment(1);
107 histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
108 .record(duration_ms as f64 / 1000.0);
109 gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
110 .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
111 if let Some(inp) = input_tokens {
112 counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
113 }
114 if let Some(out) = output_tokens {
115 counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
116 }
117 }
118
119 Ok(StepOutput {
120 output: json!({
121 "value": result.value(),
122 "model": result.model(),
123 }),
124 duration_ms,
125 cost_usd: cost,
126 input_tokens,
127 output_tokens,
128 })
129 }
130}
131
132fn parse_permission_mode(s: &str) -> PermissionMode {
136 match s.to_lowercase().as_str() {
137 "auto" => PermissionMode::Auto,
138 "dont_ask" | "dontask" => PermissionMode::DontAsk,
139 "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
140 _ => PermissionMode::Default,
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that parsing `$input` yields a mode matching `$expected`.
    macro_rules! assert_mode {
        ($input:expr, $expected:pat) => {
            assert!(matches!(parse_permission_mode($input), $expected));
        };
    }

    #[test]
    fn parse_permission_mode_auto() {
        assert_mode!("auto", PermissionMode::Auto);
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert_mode!("dont_ask", PermissionMode::DontAsk);
    }

    #[test]
    fn parse_permission_mode_dont_ask_alt() {
        assert_mode!("dontask", PermissionMode::DontAsk);
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert_mode!("bypass", PermissionMode::BypassPermissions);
    }

    #[test]
    fn parse_permission_mode_bypass_alt() {
        assert_mode!("bypass_permissions", PermissionMode::BypassPermissions);
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert_mode!("unknown", PermissionMode::Default);
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        assert_mode!("AUTO", PermissionMode::Auto);
        assert_mode!("DONT_ASK", PermissionMode::DontAsk);
        assert_mode!("BYPASS", PermissionMode::BypassPermissions);
    }
}