Skip to main content

ironflow_engine/executor/
agent.rs

1//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::Agent;
11use ironflow_core::provider::{AgentConfig, AgentProvider};
12
13use crate::error::EngineError;
14
15use super::{StepExecutor, StepOutput};
16
17/// Executor for agent (AI) steps.
18///
19/// Runs an AI agent with the given prompt and configuration, capturing
20/// the response value, cost, and token counts.
21pub struct AgentExecutor<'a> {
22    config: &'a AgentConfig,
23}
24
25impl<'a> AgentExecutor<'a> {
26    /// Create a new agent executor from a config reference.
27    pub fn new(config: &'a AgentConfig) -> Self {
28        Self { config }
29    }
30}
31
32impl StepExecutor for AgentExecutor<'_> {
33    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
34        let start = Instant::now();
35
36        if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
37            warn!(
38                "structured output (json_schema) requires max_turns >= 2; \
39                 max_turns is set to 1, the agent will likely fail with error_max_turns"
40            );
41        }
42
43        let result = Agent::from_config(self.config.clone())
44            .run(provider.as_ref())
45            .await?;
46
47        let duration_ms = start.elapsed().as_millis() as u64;
48        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
49        let input_tokens = result.input_tokens();
50        let output_tokens = result.output_tokens();
51
52        info!(
53            step_kind = "agent",
54            model = %self.config.model,
55            cost_usd = %cost,
56            input_tokens = ?input_tokens,
57            output_tokens = ?output_tokens,
58            duration_ms,
59            "agent step completed"
60        );
61
62        #[cfg(feature = "prometheus")]
63        {
64            use ironflow_core::metric_names::{
65                AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
66                AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
67            };
68            use metrics::{counter, gauge, histogram};
69            let model_label = self.config.model.clone();
70            counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
71                .increment(1);
72            histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
73                .record(duration_ms as f64 / 1000.0);
74            gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
75                .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
76            if let Some(inp) = input_tokens {
77                counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
78            }
79            if let Some(out) = output_tokens {
80                counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
81            }
82        }
83
84        let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
85
86        Ok(StepOutput {
87            output: json!({
88                "value": result.value(),
89                "model": result.model(),
90            }),
91            duration_ms,
92            cost_usd: cost,
93            input_tokens,
94            output_tokens,
95            debug_messages,
96        })
97    }
98}
99
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;

    /// Deserialize `raw_json` (a JSON string literal, quotes included) into a
    /// [`PermissionMode`], panicking at the calling test if serde rejects it.
    #[track_caller]
    fn parse(raw_json: &str) -> PermissionMode {
        serde_json::from_str(raw_json).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        // The short spelling "bypass" selects the `BypassPermissions` variant.
        assert!(matches!(parse(r#""bypass""#), PermissionMode::BypassPermissions));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must resolve to the same variants as lower-case ones.
        assert!(matches!(parse(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(parse(r#""DONT_ASK""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // Unrecognised strings deserialize successfully to `Default` instead of erroring.
        assert!(matches!(parse(r#""unknown""#), PermissionMode::Default));
    }
}