Skip to main content

ironflow_engine/executor/
agent.rs

1//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::{Agent, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18/// Executor for agent (AI) steps.
19///
20/// Runs an AI agent with the given prompt and configuration, capturing
21/// the response value, cost, and token counts.
22pub struct AgentExecutor<'a> {
23    config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27    /// Create a new agent executor from a config reference.
28    pub fn new(config: &'a AgentStepConfig) -> Self {
29        Self { config }
30    }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35        let start = Instant::now();
36
37        let mut agent = Agent::new().prompt(&self.config.prompt);
38
39        if let Some(ref sp) = self.config.system_prompt {
40            agent = agent.system_prompt(sp);
41        }
42        if let Some(ref model_str) = self.config.model {
43            agent = agent.model(model_str.clone());
44        }
45        if let Some(budget) = self.config.max_budget_usd {
46            agent = agent.max_budget_usd(budget);
47        }
48        if let Some(turns) = self.config.max_turns {
49            agent = agent.max_turns(turns);
50        }
51        if !self.config.allowed_tools.is_empty() {
52            let tool_refs: Vec<&str> = self
53                .config
54                .allowed_tools
55                .iter()
56                .map(|s| s.as_str())
57                .collect();
58            agent = agent.allowed_tools(&tool_refs);
59        }
60        if let Some(ref dir) = self.config.working_dir {
61            agent = agent.working_dir(dir);
62        }
63        if let Some(ref mode) = self.config.permission_mode {
64            let pm = parse_permission_mode(mode);
65            agent = agent.permission_mode(pm);
66        }
67        if let Some(ref schema) = self.config.output_schema {
68            if self.config.max_turns == Some(1) {
69                warn!(
70                    "structured output (output_schema) requires max_turns >= 2; \
71                     max_turns is set to 1, the agent will likely fail with error_max_turns"
72                );
73            }
74            agent = agent.output_schema_raw(schema);
75        }
76
77        let result = agent.run(provider.as_ref()).await?;
78        let duration_ms = start.elapsed().as_millis() as u64;
79        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
80        let input_tokens = result.input_tokens();
81        let output_tokens = result.output_tokens();
82
83        info!(
84            step_kind = "agent",
85            model = ?self.config.model,
86            cost_usd = %cost,
87            input_tokens = ?input_tokens,
88            output_tokens = ?output_tokens,
89            duration_ms,
90            "agent step completed"
91        );
92
93        #[cfg(feature = "prometheus")]
94        {
95            use ironflow_core::metric_names::{
96                AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
97                AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
98            };
99            use metrics::{counter, gauge, histogram};
100            let model_label = self
101                .config
102                .model
103                .clone()
104                .unwrap_or_else(|| "default".to_string());
105            counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
106                .increment(1);
107            histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
108                .record(duration_ms as f64 / 1000.0);
109            gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
110                .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
111            if let Some(inp) = input_tokens {
112                counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
113            }
114            if let Some(out) = output_tokens {
115                counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
116            }
117        }
118
119        Ok(StepOutput {
120            output: json!({
121                "value": result.value(),
122                "model": result.model(),
123            }),
124            duration_ms,
125            cost_usd: cost,
126            input_tokens,
127            output_tokens,
128        })
129    }
130}
131
132/// Parse a permission mode string into a [`PermissionMode`] enum.
133///
134/// Unknown values default to [`PermissionMode::Default`].
135fn parse_permission_mode(s: &str) -> PermissionMode {
136    match s.to_lowercase().as_str() {
137        "auto" => PermissionMode::Auto,
138        "dont_ask" | "dontask" => PermissionMode::DontAsk,
139        "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
140        _ => PermissionMode::Default,
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that parsing `$input` produces a value matching `$expected`,
    /// naming the offending input on failure.
    macro_rules! assert_mode {
        ($input:expr, $expected:pat) => {{
            let parsed = parse_permission_mode($input);
            assert!(
                matches!(parsed, $expected),
                "parse_permission_mode({:?}) returned an unexpected variant",
                $input
            );
        }};
    }

    #[test]
    fn parse_permission_mode_auto() {
        assert_mode!("auto", PermissionMode::Auto);
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert_mode!("dont_ask", PermissionMode::DontAsk);
    }

    #[test]
    fn parse_permission_mode_dont_ask_alt() {
        assert_mode!("dontask", PermissionMode::DontAsk);
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert_mode!("bypass", PermissionMode::BypassPermissions);
    }

    #[test]
    fn parse_permission_mode_bypass_alt() {
        assert_mode!("bypass_permissions", PermissionMode::BypassPermissions);
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert_mode!("unknown", PermissionMode::Default);
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        let upper_cases = [
            ("AUTO", 0_u8),
            ("DONT_ASK", 1),
            ("BYPASS", 2),
        ];
        for (input, kind) in upper_cases {
            match kind {
                0 => assert_mode!(input, PermissionMode::Auto),
                1 => assert_mode!(input, PermissionMode::DontAsk),
                _ => assert_mode!(input, PermissionMode::BypassPermissions),
            }
        }
    }
}