Skip to main content

ironflow_engine/executor/
agent.rs

//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::{Agent, PermissionMode};
11use ironflow_core::provider::AgentProvider;
12
13use crate::config::AgentStepConfig;
14use crate::error::EngineError;
15
16use super::{StepExecutor, StepOutput};
17
18/// Executor for agent (AI) steps.
19///
20/// Runs an AI agent with the given prompt and configuration, capturing
21/// the response value, cost, and token counts.
22pub struct AgentExecutor<'a> {
23    config: &'a AgentStepConfig,
24}
25
26impl<'a> AgentExecutor<'a> {
27    /// Create a new agent executor from a config reference.
28    pub fn new(config: &'a AgentStepConfig) -> Self {
29        Self { config }
30    }
31}
32
33impl StepExecutor for AgentExecutor<'_> {
34    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
35        let start = Instant::now();
36
37        let mut agent = Agent::new().prompt(&self.config.prompt);
38
39        if let Some(ref sp) = self.config.system_prompt {
40            agent = agent.system_prompt(sp);
41        }
42        if let Some(ref model_str) = self.config.model {
43            agent = agent.model(model_str.clone());
44        }
45        if let Some(budget) = self.config.max_budget_usd {
46            agent = agent.max_budget_usd(budget);
47        }
48        if let Some(turns) = self.config.max_turns {
49            agent = agent.max_turns(turns);
50        }
51        if !self.config.allowed_tools.is_empty() {
52            let tool_refs: Vec<&str> = self
53                .config
54                .allowed_tools
55                .iter()
56                .map(|s| s.as_str())
57                .collect();
58            agent = agent.allowed_tools(&tool_refs);
59        }
60        if let Some(ref dir) = self.config.working_dir {
61            agent = agent.working_dir(dir);
62        }
63        if let Some(ref mode) = self.config.permission_mode {
64            let pm = parse_permission_mode(mode);
65            agent = agent.permission_mode(pm);
66        }
67        if let Some(ref schema) = self.config.output_schema {
68            if self.config.max_turns == Some(1) {
69                warn!(
70                    "structured output (output_schema) requires max_turns >= 2; \
71                     max_turns is set to 1, the agent will likely fail with error_max_turns"
72                );
73            }
74            agent = agent.output_schema_raw(schema);
75        }
76
77        let result = agent.run(provider.as_ref()).await?;
78        let duration_ms = start.elapsed().as_millis() as u64;
79        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
80        let input_tokens = result.input_tokens();
81        let output_tokens = result.output_tokens();
82
83        info!(
84            step_kind = "agent",
85            model = ?self.config.model,
86            cost_usd = %cost,
87            input_tokens = ?input_tokens,
88            output_tokens = ?output_tokens,
89            duration_ms,
90            "agent step completed"
91        );
92
93        Ok(StepOutput {
94            output: json!({
95                "value": result.value(),
96                "model": result.model(),
97            }),
98            duration_ms,
99            cost_usd: cost,
100            input_tokens,
101            output_tokens,
102        })
103    }
104}
105
106/// Parse a permission mode string into a [`PermissionMode`] enum.
107///
108/// Unknown values default to [`PermissionMode::Default`].
109fn parse_permission_mode(s: &str) -> PermissionMode {
110    match s.to_lowercase().as_str() {
111        "auto" => PermissionMode::Auto,
112        "dont_ask" | "dontask" => PermissionMode::DontAsk,
113        "bypass" | "bypass_permissions" => PermissionMode::BypassPermissions,
114        _ => PermissionMode::Default,
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// `auto` selects the automatic mode.
    #[test]
    fn parse_permission_mode_auto() {
        assert!(matches!(parse_permission_mode("auto"), PermissionMode::Auto));
    }

    /// The snake_case spelling selects `DontAsk`.
    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(
            parse_permission_mode("dont_ask"),
            PermissionMode::DontAsk
        ));
    }

    /// The spelling without an underscore selects `DontAsk` as well.
    #[test]
    fn parse_permission_mode_dont_ask_alt() {
        assert!(matches!(
            parse_permission_mode("dontask"),
            PermissionMode::DontAsk
        ));
    }

    /// The short spelling selects `BypassPermissions`.
    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse_permission_mode("bypass"),
            PermissionMode::BypassPermissions
        ));
    }

    /// The long snake_case spelling selects `BypassPermissions` as well.
    #[test]
    fn parse_permission_mode_bypass_alt() {
        assert!(matches!(
            parse_permission_mode("bypass_permissions"),
            PermissionMode::BypassPermissions
        ));
    }

    /// Anything unrecognised falls back to the default mode.
    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert!(matches!(
            parse_permission_mode("unknown"),
            PermissionMode::Default
        ));
    }

    /// Upper-case input is matched the same way as lower-case input.
    #[test]
    fn parse_permission_mode_case_insensitive() {
        let auto = parse_permission_mode("AUTO");
        let dont_ask = parse_permission_mode("DONT_ASK");
        let bypass = parse_permission_mode("BYPASS");

        assert!(matches!(auto, PermissionMode::Auto));
        assert!(matches!(dont_ask, PermissionMode::DontAsk));
        assert!(matches!(bypass, PermissionMode::BypassPermissions));
    }
}