ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::{Value, json};
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::Agent;
11use ironflow_core::provider::{AgentConfig, AgentProvider};
12
13use crate::error::EngineError;
14use crate::log_sender::StepLogSender;
15use crate::notify::LogStream;
16
17use super::{StepExecutor, StepOutput};
18
19fn format_agent_output(value: &Value, model: Option<&str>, has_schema: bool) -> Value {
36 if has_schema {
37 value.clone()
38 } else {
39 json!({
40 "value": value,
41 "model": model,
42 })
43 }
44}
45
46pub struct AgentExecutor<'a> {
52 config: &'a AgentConfig,
53 log_sender: Option<StepLogSender>,
54}
55
56impl<'a> AgentExecutor<'a> {
57 pub fn new(config: &'a AgentConfig) -> Self {
59 Self {
60 config,
61 log_sender: None,
62 }
63 }
64
65 pub fn with_log_sender(mut self, sender: StepLogSender) -> Self {
67 self.log_sender = Some(sender);
68 self
69 }
70}
71
72impl StepExecutor for AgentExecutor<'_> {
73 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
74 let start = Instant::now();
75
76 if let Some(ref sender) = self.log_sender {
77 sender.emit(
78 LogStream::System,
79 &format!("agent step started (model={})", self.config.model),
80 );
81 }
82
83 if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
84 warn!(
85 "structured output (json_schema) requires max_turns >= 2; \
86 max_turns is set to 1, the agent will likely fail with error_max_turns"
87 );
88 }
89
90 let result = Agent::from_config(self.config.clone())
91 .run(provider.as_ref())
92 .await?;
93
94 let duration_ms = start.elapsed().as_millis() as u64;
95 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
96 let input_tokens = result.input_tokens();
97 let output_tokens = result.output_tokens();
98
99 info!(
100 step_kind = "agent",
101 model = %self.config.model,
102 cost_usd = %cost,
103 input_tokens = ?input_tokens,
104 output_tokens = ?output_tokens,
105 duration_ms,
106 "agent step completed"
107 );
108
109 #[cfg(feature = "prometheus")]
110 {
111 use ironflow_core::metric_names::{
112 AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
113 AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
114 };
115 use metrics::{counter, gauge, histogram};
116 let model_label = self.config.model.clone();
117 counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
118 .increment(1);
119 histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
120 .record(duration_ms as f64 / 1000.0);
121 gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
122 .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
123 if let Some(inp) = input_tokens {
124 counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
125 }
126 if let Some(out) = output_tokens {
127 counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
128 }
129 }
130
131 if let Some(ref sender) = self.log_sender {
132 sender.emit(
133 LogStream::System,
134 &format!(
135 "agent step completed (cost=${cost}, tokens_in={}, tokens_out={})",
136 input_tokens.unwrap_or(0),
137 output_tokens.unwrap_or(0),
138 ),
139 );
140 }
141
142 let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
143
144 let output = format_agent_output(
145 result.value(),
146 result.model(),
147 self.config.json_schema.is_some(),
148 );
149
150 Ok(StepOutput {
151 output,
152 duration_ms,
153 cost_usd: cost,
154 input_tokens,
155 output_tokens,
156 debug_messages,
157 })
158 }
159}
160
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;
    use serde::Deserialize;
    use serde_json::json;

    use super::format_agent_output;

    /// Deserializes a raw JSON string literal (quotes included) into a
    /// `PermissionMode`, panicking if deserialization fails.
    fn mode_from(raw_json: &str) -> PermissionMode {
        serde_json::from_str(raw_json).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(mode_from(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(mode_from(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            mode_from(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        let upper_auto = mode_from(r#""AUTO""#);
        assert!(matches!(upper_auto, PermissionMode::Auto));

        let upper_dont_ask = mode_from(r#""DONT_ASK""#);
        assert!(matches!(upper_dont_ask, PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert!(matches!(mode_from(r#""unknown""#), PermissionMode::Default));
    }

    #[test]
    fn structured_output_not_wrapped_in_value_model() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct TechDigest {
            items: Vec<String>,
        }

        let payload = json!({"items": ["news1", "news2"]});
        let shaped = format_agent_output(&payload, Some("sonnet"), true);

        // Deserializing straight into the target type only succeeds when the
        // structured value was not wrapped in a {"value", "model"} envelope.
        let digest: TechDigest = serde_json::from_value(shaped).unwrap();
        assert_eq!(digest.items, vec!["news1", "news2"]);
    }

    #[test]
    fn text_output_wrapped_in_value_model() {
        let wrapped = format_agent_output(&json!("Hello, world!"), Some("sonnet"), false);

        assert_eq!(wrapped["model"], "sonnet");
        assert_eq!(wrapped["value"], "Hello, world!");
    }
}