Skip to main content

ironflow_engine/executor/
agent.rs

1//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use tracing::{info, warn};
8
9use ironflow_core::operations::agent::Agent;
10use ironflow_core::provider::{AgentConfig, AgentProvider, LogSink};
11
12use crate::error::EngineError;
13use crate::log_sender::StepLogSender;
14use crate::notify::LogStream;
15
16use super::{StepExecutor, StepOutput};
17
18/// Executor for agent (AI) steps.
19///
20/// Runs an AI agent with the given prompt and configuration, capturing
21/// the response value, cost, and token counts. When a [`StepLogSender`]
22/// is attached, emits system log lines for step start/end.
23pub struct AgentExecutor<'a> {
24    config: &'a AgentConfig,
25    log_sender: Option<StepLogSender>,
26}
27
28impl<'a> AgentExecutor<'a> {
29    /// Create a new agent executor from a config reference.
30    pub fn new(config: &'a AgentConfig) -> Self {
31        Self {
32            config,
33            log_sender: None,
34        }
35    }
36
37    /// Attach a log sender for system-level log lines.
38    pub fn with_log_sender(mut self, sender: StepLogSender) -> Self {
39        self.log_sender = Some(sender);
40        self
41    }
42}
43
44impl StepExecutor for AgentExecutor<'_> {
45    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
46        let start = Instant::now();
47
48        if let Some(ref sender) = self.log_sender {
49            sender.emit(
50                LogStream::System,
51                &format!("agent step started (model={})", self.config.model),
52            );
53        }
54
55        if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
56            warn!(
57                "structured output (json_schema) requires max_turns >= 2; \
58                 max_turns is set to 1, the agent will likely fail with error_max_turns"
59            );
60        }
61
62        let mut agent = Agent::from_config(self.config.clone());
63        if let Some(ref sender) = self.log_sender {
64            agent = agent.log_sink(Arc::new(sender.clone()) as Arc<dyn LogSink>);
65        }
66        let result = agent.run(provider.as_ref()).await?;
67
68        let duration_ms = start.elapsed().as_millis() as u64;
69        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
70        let input_tokens = result.input_tokens();
71        let output_tokens = result.output_tokens();
72
73        info!(
74            step_kind = "agent",
75            model = %self.config.model,
76            cost_usd = %cost,
77            input_tokens = ?input_tokens,
78            output_tokens = ?output_tokens,
79            duration_ms,
80            "agent step completed"
81        );
82
83        #[cfg(feature = "prometheus")]
84        {
85            use ironflow_core::metric_names::{
86                AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
87                AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
88            };
89            use metrics::{counter, gauge, histogram};
90            let model_label = self.config.model.clone();
91            counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
92                .increment(1);
93            histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
94                .record(duration_ms as f64 / 1000.0);
95            gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
96                .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
97            if let Some(inp) = input_tokens {
98                counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
99            }
100            if let Some(out) = output_tokens {
101                counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
102            }
103        }
104
105        if let Some(ref sender) = self.log_sender {
106            sender.emit(
107                LogStream::System,
108                &format!(
109                    "agent step completed (cost=${cost}, tokens_in={}, tokens_out={})",
110                    input_tokens.unwrap_or(0),
111                    output_tokens.unwrap_or(0),
112                ),
113            );
114        }
115
116        let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
117
118        Ok(StepOutput {
119            output: result.value().clone(),
120            duration_ms,
121            cost_usd: cost,
122            input_tokens,
123            output_tokens,
124            model: result.model().map(String::from),
125            debug_messages,
126        })
127    }
128}
129
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;

    /// Deserialize a [`PermissionMode`] from a JSON document, panicking on error.
    fn parse(json: &str) -> PermissionMode {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must map to the same variants as lower-case ones.
        assert!(matches!(parse(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(parse(r#""DONT_ASK""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // Unrecognised strings fall back to the default mode instead of erroring.
        assert!(matches!(parse(r#""unknown""#), PermissionMode::Default));
    }
}