ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::json;
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::Agent;
11use ironflow_core::provider::{AgentConfig, AgentProvider};
12
13use crate::error::EngineError;
14
15use super::{StepExecutor, StepOutput};
16
17pub struct AgentExecutor<'a> {
22 config: &'a AgentConfig,
23}
24
25impl<'a> AgentExecutor<'a> {
26 pub fn new(config: &'a AgentConfig) -> Self {
28 Self { config }
29 }
30}
31
32impl StepExecutor for AgentExecutor<'_> {
33 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
34 let start = Instant::now();
35
36 if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
37 warn!(
38 "structured output (json_schema) requires max_turns >= 2; \
39 max_turns is set to 1, the agent will likely fail with error_max_turns"
40 );
41 }
42
43 let result = Agent::from_config(self.config.clone())
44 .run(provider.as_ref())
45 .await?;
46
47 let duration_ms = start.elapsed().as_millis() as u64;
48 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
49 let input_tokens = result.input_tokens();
50 let output_tokens = result.output_tokens();
51
52 info!(
53 step_kind = "agent",
54 model = %self.config.model,
55 cost_usd = %cost,
56 input_tokens = ?input_tokens,
57 output_tokens = ?output_tokens,
58 duration_ms,
59 "agent step completed"
60 );
61
62 #[cfg(feature = "prometheus")]
63 {
64 use ironflow_core::metric_names::{
65 AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
66 AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
67 };
68 use metrics::{counter, gauge, histogram};
69 let model_label = self.config.model.clone();
70 counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
71 .increment(1);
72 histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
73 .record(duration_ms as f64 / 1000.0);
74 gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
75 .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
76 if let Some(inp) = input_tokens {
77 counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
78 }
79 if let Some(out) = output_tokens {
80 counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
81 }
82 }
83
84 let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
85
86 Ok(StepOutput {
87 output: json!({
88 "value": result.value(),
89 "model": result.model(),
90 }),
91 duration_ms,
92 cost_usd: cost,
93 input_tokens,
94 output_tokens,
95 debug_messages,
96 })
97 }
98}
99
#[cfg(test)]
mod tests {
    //! Checks how `PermissionMode` strings deserialize through serde_json.

    use ironflow_core::operations::agent::PermissionMode;

    /// Deserializes a JSON string literal (quotes included) into a [`PermissionMode`].
    ///
    /// Panics if deserialization fails.
    fn parse(raw_json: &str) -> PermissionMode {
        serde_json::from_str(raw_json).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(parse(r#""bypass""#), PermissionMode::BypassPermissions));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must land on the same variants as lower-case ones.
        let upper_auto = parse(r#""AUTO""#);
        assert!(matches!(upper_auto, PermissionMode::Auto));

        let upper_dont_ask = parse(r#""DONT_ASK""#);
        assert!(matches!(upper_dont_ask, PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // Unrecognised values are expected to fall back to `Default` instead of erroring.
        assert!(matches!(parse(r#""unknown""#), PermissionMode::Default));
    }
}