Skip to main content

ironflow_engine/executor/
agent.rs

1//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use tracing::{info, warn};
8
9use ironflow_core::operations::agent::Agent;
10use ironflow_core::provider::{AgentConfig, AgentProvider};
11
12use crate::error::EngineError;
13use crate::log_sender::StepLogSender;
14use crate::notify::LogStream;
15
16use super::{StepExecutor, StepOutput};
17
18/// Executor for agent (AI) steps.
19///
20/// Runs an AI agent with the given prompt and configuration, capturing
21/// the response value, cost, and token counts. When a [`StepLogSender`]
22/// is attached, emits system log lines for step start/end.
23pub struct AgentExecutor<'a> {
24    config: &'a AgentConfig,
25    log_sender: Option<StepLogSender>,
26}
27
28impl<'a> AgentExecutor<'a> {
29    /// Create a new agent executor from a config reference.
30    pub fn new(config: &'a AgentConfig) -> Self {
31        Self {
32            config,
33            log_sender: None,
34        }
35    }
36
37    /// Attach a log sender for system-level log lines.
38    pub fn with_log_sender(mut self, sender: StepLogSender) -> Self {
39        self.log_sender = Some(sender);
40        self
41    }
42}
43
44impl StepExecutor for AgentExecutor<'_> {
45    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
46        let start = Instant::now();
47
48        if let Some(ref sender) = self.log_sender {
49            sender.emit(
50                LogStream::System,
51                &format!("agent step started (model={})", self.config.model),
52            );
53        }
54
55        if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
56            warn!(
57                "structured output (json_schema) requires max_turns >= 2; \
58                 max_turns is set to 1, the agent will likely fail with error_max_turns"
59            );
60        }
61
62        let result = Agent::from_config(self.config.clone())
63            .run(provider.as_ref())
64            .await?;
65
66        let duration_ms = start.elapsed().as_millis() as u64;
67        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
68        let input_tokens = result.input_tokens();
69        let output_tokens = result.output_tokens();
70
71        info!(
72            step_kind = "agent",
73            model = %self.config.model,
74            cost_usd = %cost,
75            input_tokens = ?input_tokens,
76            output_tokens = ?output_tokens,
77            duration_ms,
78            "agent step completed"
79        );
80
81        #[cfg(feature = "prometheus")]
82        {
83            use ironflow_core::metric_names::{
84                AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
85                AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
86            };
87            use metrics::{counter, gauge, histogram};
88            let model_label = self.config.model.clone();
89            counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
90                .increment(1);
91            histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
92                .record(duration_ms as f64 / 1000.0);
93            gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
94                .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
95            if let Some(inp) = input_tokens {
96                counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
97            }
98            if let Some(out) = output_tokens {
99                counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
100            }
101        }
102
103        if let Some(ref sender) = self.log_sender {
104            sender.emit(
105                LogStream::System,
106                &format!(
107                    "agent step completed (cost=${cost}, tokens_in={}, tokens_out={})",
108                    input_tokens.unwrap_or(0),
109                    output_tokens.unwrap_or(0),
110                ),
111            );
112        }
113
114        let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
115
116        Ok(StepOutput {
117            output: result.value().clone(),
118            duration_ms,
119            cost_usd: cost,
120            input_tokens,
121            output_tokens,
122            model: result.model().map(String::from),
123            debug_messages,
124        })
125    }
126}
127
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;

    /// Deserialize a JSON string literal into a [`PermissionMode`], panicking
    /// if serde rejects it.
    fn from_json(raw: &str) -> PermissionMode {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(from_json(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(from_json(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            from_json(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must map to the same variants as lower-case ones.
        assert!(matches!(from_json(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(from_json(r#""DONT_ASK""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // Unrecognised strings deserialize successfully and fall back to `Default`.
        assert!(matches!(from_json(r#""unknown""#), PermissionMode::Default));
    }
}