Skip to main content

ironflow_engine/executor/
agent.rs

//! Agent step executor: builds an AI agent from an [`AgentStepConfig`] and runs it as a workflow step.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::info;
9
10use ironflow_core::operations::agent::{Agent, Model, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18/// Executor for agent (AI) steps.
19///
20/// Runs an AI agent with the given prompt and configuration, capturing
21/// the response value, cost, and token counts.
22pub struct AgentExecutor<'a> {
23    config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27    /// Create a new agent executor from a config reference.
28    pub fn new(config: &'a AgentStepConfig) -> Self {
29        Self { config }
30    }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35        let start = Instant::now();
36
37        let mut agent = Agent::new().prompt(&self.config.prompt);
38
39        if let Some(ref sp) = self.config.system_prompt {
40            agent = agent.system_prompt(sp);
41        }
42        if let Some(ref model_str) = self.config.model {
43            let model = parse_model(model_str)?;
44            agent = agent.model(model);
45        }
46        if let Some(budget) = self.config.max_budget_usd {
47            agent = agent.max_budget_usd(budget);
48        }
49        if let Some(turns) = self.config.max_turns {
50            agent = agent.max_turns(turns);
51        }
52        if !self.config.allowed_tools.is_empty() {
53            let tool_refs: Vec<&str> = self
54                .config
55                .allowed_tools
56                .iter()
57                .map(|s| s.as_str())
58                .collect();
59            agent = agent.allowed_tools(&tool_refs);
60        }
61        if let Some(ref dir) = self.config.working_dir {
62            agent = agent.working_dir(dir);
63        }
64        if let Some(ref mode) = self.config.permission_mode {
65            let pm = parse_permission_mode(mode);
66            agent = agent.permission_mode(pm);
67        }
68
69        let result = agent.run(provider.as_ref()).await?;
70        let duration_ms = start.elapsed().as_millis() as u64;
71        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
72        let input_tokens = result.input_tokens();
73        let output_tokens = result.output_tokens();
74
75        info!(
76            step_kind = "agent",
77            model = ?self.config.model,
78            cost_usd = %cost,
79            input_tokens = ?input_tokens,
80            output_tokens = ?output_tokens,
81            duration_ms,
82            "agent step completed"
83        );
84
85        Ok(StepOutput {
86            output: json!({
87                "value": result.value(),
88                "model": result.model(),
89            }),
90            duration_ms,
91            cost_usd: cost,
92            input_tokens,
93            output_tokens,
94        })
95    }
96}
97
98/// Parse a model string into a [`Model`] enum.
99///
100/// Supports multiple formats for backward compatibility:
101/// - "sonnet", "opus", "haiku"
102/// - "haiku45", "haiku-4.5"
103/// - "sonnet46", "sonnet-4.6"
104/// - "opus46", "opus-4.6"
105fn parse_model(s: &str) -> Result<Model, EngineError> {
106    match s.to_lowercase().as_str() {
107        "sonnet" => Ok(Model::Sonnet),
108        "opus" => Ok(Model::Opus),
109        "haiku" => Ok(Model::Haiku),
110        "haiku45" | "haiku-4.5" => Ok(Model::Haiku45),
111        "sonnet46" | "sonnet-4.6" => Ok(Model::Sonnet46),
112        "opus46" | "opus-4.6" => Ok(Model::Opus46),
113        other => Err(EngineError::StepConfig(format!("unknown model: {other}"))),
114    }
115}
116
117/// Parse a permission mode string into a [`PermissionMode`] enum.
118///
119/// Unknown values default to [`PermissionMode::Default`].
120fn parse_permission_mode(s: &str) -> PermissionMode {
121    match s.to_lowercase().as_str() {
122        "auto" => PermissionMode::Auto,
123        "dont_ask" | "dontask" => PermissionMode::DontAsk,
124        "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
125        _ => PermissionMode::Default,
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Require that `input` is accepted by `parse_model` and yields `expected`.
    fn assert_parses_to(input: &str, expected: Model) {
        match parse_model(input) {
            Ok(model) => assert_eq!(model, expected, "wrong model for {input:?}"),
            Err(_) => panic!("expected {input:?} to be accepted"),
        }
    }

    #[test]
    fn parse_model_sonnet() {
        assert_parses_to("sonnet", Model::Sonnet);
    }

    #[test]
    fn parse_model_opus() {
        assert_parses_to("opus", Model::Opus);
    }

    #[test]
    fn parse_model_haiku() {
        assert_parses_to("haiku", Model::Haiku);
    }

    #[test]
    fn parse_model_haiku45() {
        assert_parses_to("haiku45", Model::Haiku45);
    }

    #[test]
    fn parse_model_haiku_with_dash() {
        assert_parses_to("haiku-4.5", Model::Haiku45);
    }

    #[test]
    fn parse_model_sonnet46() {
        assert_parses_to("sonnet46", Model::Sonnet46);
    }

    #[test]
    fn parse_model_sonnet_with_dash() {
        assert_parses_to("sonnet-4.6", Model::Sonnet46);
    }

    #[test]
    fn parse_model_opus46() {
        assert_parses_to("opus46", Model::Opus46);
    }

    #[test]
    fn parse_model_opus_with_dash() {
        assert_parses_to("opus-4.6", Model::Opus46);
    }

    #[test]
    fn parse_model_unknown_returns_error() {
        let Err(EngineError::StepConfig(msg)) = parse_model("invalid-model") else {
            panic!("expected StepConfig error");
        };
        assert!(msg.contains("unknown model"));
    }

    #[test]
    fn parse_model_case_insensitive() {
        for input in ["SONNET", "OpUs", "HAIKU"] {
            assert!(parse_model(input).is_ok(), "{input:?} should be accepted");
        }
    }

    #[test]
    fn parse_permission_mode_auto() {
        assert!(matches!(parse_permission_mode("auto"), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse_permission_mode("dont_ask"), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_dont_ask_alt() {
        assert!(matches!(parse_permission_mode("dontask"), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse_permission_mode("bypass"),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_bypass_alt() {
        assert!(matches!(
            parse_permission_mode("bypass_permissions"),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert!(matches!(parse_permission_mode("unknown"), PermissionMode::Default));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        assert!(matches!(parse_permission_mode("AUTO"), PermissionMode::Auto));
        assert!(matches!(parse_permission_mode("DONT_ASK"), PermissionMode::DontAsk));
        assert!(matches!(
            parse_permission_mode("BYPASS"),
            PermissionMode::BypassPermissions
        ));
    }
}