ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::info;
9
10use ironflow_core::operations::agent::{Agent, Model, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18pub struct AgentExecutor<'a> {
23 config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27 pub fn new(config: &'a AgentStepConfig) -> Self {
29 Self { config }
30 }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35 let start = Instant::now();
36
37 let mut agent = Agent::new().prompt(&self.config.prompt);
38
39 if let Some(ref sp) = self.config.system_prompt {
40 agent = agent.system_prompt(sp);
41 }
42 if let Some(ref model_str) = self.config.model {
43 let model = parse_model(model_str)?;
44 agent = agent.model(model);
45 }
46 if let Some(budget) = self.config.max_budget_usd {
47 agent = agent.max_budget_usd(budget);
48 }
49 if let Some(turns) = self.config.max_turns {
50 agent = agent.max_turns(turns);
51 }
52 if !self.config.allowed_tools.is_empty() {
53 let tool_refs: Vec<&str> = self
54 .config
55 .allowed_tools
56 .iter()
57 .map(|s| s.as_str())
58 .collect();
59 agent = agent.allowed_tools(&tool_refs);
60 }
61 if let Some(ref dir) = self.config.working_dir {
62 agent = agent.working_dir(dir);
63 }
64 if let Some(ref mode) = self.config.permission_mode {
65 let pm = parse_permission_mode(mode);
66 agent = agent.permission_mode(pm);
67 }
68
69 let result = agent.run(provider.as_ref()).await?;
70 let duration_ms = start.elapsed().as_millis() as u64;
71 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
72 let input_tokens = result.input_tokens();
73 let output_tokens = result.output_tokens();
74
75 info!(
76 step_kind = "agent",
77 model = ?self.config.model,
78 cost_usd = %cost,
79 input_tokens = ?input_tokens,
80 output_tokens = ?output_tokens,
81 duration_ms,
82 "agent step completed"
83 );
84
85 Ok(StepOutput {
86 output: json!({
87 "value": result.value(),
88 "model": result.model(),
89 }),
90 duration_ms,
91 cost_usd: cost,
92 input_tokens,
93 output_tokens,
94 })
95 }
96}
97
98fn parse_model(s: &str) -> Result<Model, EngineError> {
106 match s.to_lowercase().as_str() {
107 "sonnet" => Ok(Model::Sonnet),
108 "opus" => Ok(Model::Opus),
109 "haiku" => Ok(Model::Haiku),
110 "haiku45" | "haiku-4.5" => Ok(Model::Haiku45),
111 "sonnet46" | "sonnet-4.6" => Ok(Model::Sonnet46),
112 "opus46" | "opus-4.6" => Ok(Model::Opus46),
113 other => Err(EngineError::StepConfig(format!("unknown model: {other}"))),
114 }
115}
116
117fn parse_permission_mode(s: &str) -> PermissionMode {
121 match s.to_lowercase().as_str() {
122 "auto" => PermissionMode::Auto,
123 "dont_ask" | "dontask" => PermissionMode::DontAsk,
124 "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
125 _ => PermissionMode::Default,
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` and fails the test, naming the input, if it is rejected.
    fn parse_ok(input: &str) -> Model {
        parse_model(input).unwrap_or_else(|err| panic!("{input:?} was rejected: {err:?}"))
    }

    #[test]
    fn parse_model_accepts_every_alias() {
        let cases = [
            ("sonnet", Model::Sonnet),
            ("opus", Model::Opus),
            ("haiku", Model::Haiku),
            ("haiku45", Model::Haiku45),
            ("haiku-4.5", Model::Haiku45),
            ("sonnet46", Model::Sonnet46),
            ("sonnet-4.6", Model::Sonnet46),
            ("opus46", Model::Opus46),
            ("opus-4.6", Model::Opus46),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ok(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn parse_model_is_case_insensitive_for_every_alias() {
        // Pin the exact variant, not just success, so a mis-mapped alias is caught.
        let cases = [
            ("SONNET", Model::Sonnet),
            ("OpUs", Model::Opus),
            ("HAIKU", Model::Haiku),
            ("HAIKU45", Model::Haiku45),
            ("Haiku-4.5", Model::Haiku45),
            ("Sonnet46", Model::Sonnet46),
            ("SONNET-4.6", Model::Sonnet46),
            ("OPUS46", Model::Opus46),
            ("Opus-4.6", Model::Opus46),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_ok(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn parse_model_unknown_returns_step_config_error() {
        for input in ["invalid-model", "", "sonnet 4.6", "gpt-4"] {
            match parse_model(input) {
                Err(EngineError::StepConfig(msg)) => {
                    assert!(msg.contains("unknown model"), "input {input:?}: {msg}");
                }
                other => {
                    panic!("input {input:?}: expected StepConfig error, got {other:?}")
                }
            }
        }
    }

    #[test]
    fn parse_permission_mode_recognised_aliases() {
        for input in ["auto", "AUTO", "Auto"] {
            assert!(
                matches!(parse_permission_mode(input), PermissionMode::Auto),
                "input {input:?}"
            );
        }
        for input in ["dont_ask", "dontask", "DONT_ASK", "DontAsk"] {
            assert!(
                matches!(parse_permission_mode(input), PermissionMode::DontAsk),
                "input {input:?}"
            );
        }
        for input in ["bypass", "bypass_permissions", "BYPASS", "Bypass_Permissions"] {
            assert!(
                matches!(parse_permission_mode(input), PermissionMode::BypassPermissions),
                "input {input:?}"
            );
        }
    }

    #[test]
    fn parse_permission_mode_unrecognised_falls_back_to_default() {
        for input in ["unknown", "", "default", "bypas"] {
            assert!(
                matches!(parse_permission_mode(input), PermissionMode::Default),
                "input {input:?}"
            );
        }
    }
}