ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use tracing::{info, warn};
8
9use ironflow_core::operations::agent::Agent;
10use ironflow_core::provider::{AgentConfig, AgentProvider, LogSink};
11
12use crate::error::EngineError;
13use crate::log_sender::StepLogSender;
14use crate::notify::LogStream;
15
16use super::{StepExecutor, StepOutput};
17
18pub struct AgentExecutor<'a> {
24 config: &'a AgentConfig,
25 log_sender: Option<StepLogSender>,
26}
27
28impl<'a> AgentExecutor<'a> {
29 pub fn new(config: &'a AgentConfig) -> Self {
31 Self {
32 config,
33 log_sender: None,
34 }
35 }
36
37 pub fn with_log_sender(mut self, sender: StepLogSender) -> Self {
39 self.log_sender = Some(sender);
40 self
41 }
42}
43
44impl StepExecutor for AgentExecutor<'_> {
45 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
46 let start = Instant::now();
47
48 if let Some(ref sender) = self.log_sender {
49 sender.emit(
50 LogStream::System,
51 &format!("agent step started (model={})", self.config.model),
52 );
53 }
54
55 if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
56 warn!(
57 "structured output (json_schema) requires max_turns >= 2; \
58 max_turns is set to 1, the agent will likely fail with error_max_turns"
59 );
60 }
61
62 let mut agent = Agent::from_config(self.config.clone());
63 if let Some(ref sender) = self.log_sender {
64 agent = agent.log_sink(Arc::new(sender.clone()) as Arc<dyn LogSink>);
65 }
66 let result = agent.run(provider.as_ref()).await?;
67
68 let duration_ms = start.elapsed().as_millis() as u64;
69 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
70 let input_tokens = result.input_tokens();
71 let output_tokens = result.output_tokens();
72
73 info!(
74 step_kind = "agent",
75 model = %self.config.model,
76 cost_usd = %cost,
77 input_tokens = ?input_tokens,
78 output_tokens = ?output_tokens,
79 duration_ms,
80 "agent step completed"
81 );
82
83 #[cfg(feature = "prometheus")]
84 {
85 use ironflow_core::metric_names::{
86 AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
87 AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
88 };
89 use metrics::{counter, gauge, histogram};
90 let model_label = self.config.model.clone();
91 counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
92 .increment(1);
93 histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
94 .record(duration_ms as f64 / 1000.0);
95 gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
96 .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
97 if let Some(inp) = input_tokens {
98 counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
99 }
100 if let Some(out) = output_tokens {
101 counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
102 }
103 }
104
105 if let Some(ref sender) = self.log_sender {
106 sender.emit(
107 LogStream::System,
108 &format!(
109 "agent step completed (cost=${cost}, tokens_in={}, tokens_out={})",
110 input_tokens.unwrap_or(0),
111 output_tokens.unwrap_or(0),
112 ),
113 );
114 }
115
116 let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
117
118 Ok(StepOutput {
119 output: result.value().clone(),
120 duration_ms,
121 cost_usd: cost,
122 input_tokens,
123 output_tokens,
124 model: result.model().map(String::from),
125 debug_messages,
126 })
127 }
128}
129
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;

    /// Deserializes `raw` (a JSON string literal, quotes included) into a
    /// [`PermissionMode`], panicking if serde rejects it.
    fn parse(raw: &str) -> PermissionMode {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must map to the same variants as their lower-case forms.
        assert!(matches!(parse(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(parse(r#""DONT_ASK""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // Unrecognised values fall back to `Default` rather than failing to parse.
        assert!(matches!(parse(r#""unknown""#), PermissionMode::Default));
    }
}