ironflow_engine/executor/
agent.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::{Value, json};
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::Agent;
11use ironflow_core::provider::{AgentConfig, AgentProvider};
12
13use crate::error::EngineError;
14
15use super::{StepExecutor, StepOutput};
16
17fn format_agent_output(value: &Value, model: Option<&str>, has_schema: bool) -> Value {
34 if has_schema {
35 value.clone()
36 } else {
37 json!({
38 "value": value,
39 "model": model,
40 })
41 }
42}
43
44pub struct AgentExecutor<'a> {
49 config: &'a AgentConfig,
50}
51
52impl<'a> AgentExecutor<'a> {
53 pub fn new(config: &'a AgentConfig) -> Self {
55 Self { config }
56 }
57}
58
59impl StepExecutor for AgentExecutor<'_> {
60 async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
61 let start = Instant::now();
62
63 if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
64 warn!(
65 "structured output (json_schema) requires max_turns >= 2; \
66 max_turns is set to 1, the agent will likely fail with error_max_turns"
67 );
68 }
69
70 let result = Agent::from_config(self.config.clone())
71 .run(provider.as_ref())
72 .await?;
73
74 let duration_ms = start.elapsed().as_millis() as u64;
75 let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
76 let input_tokens = result.input_tokens();
77 let output_tokens = result.output_tokens();
78
79 info!(
80 step_kind = "agent",
81 model = %self.config.model,
82 cost_usd = %cost,
83 input_tokens = ?input_tokens,
84 output_tokens = ?output_tokens,
85 duration_ms,
86 "agent step completed"
87 );
88
89 #[cfg(feature = "prometheus")]
90 {
91 use ironflow_core::metric_names::{
92 AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
93 AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
94 };
95 use metrics::{counter, gauge, histogram};
96 let model_label = self.config.model.clone();
97 counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
98 .increment(1);
99 histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
100 .record(duration_ms as f64 / 1000.0);
101 gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
102 .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
103 if let Some(inp) = input_tokens {
104 counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
105 }
106 if let Some(out) = output_tokens {
107 counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
108 }
109 }
110
111 let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
112
113 let output = format_agent_output(
114 result.value(),
115 result.model(),
116 self.config.json_schema.is_some(),
117 );
118
119 Ok(StepOutput {
120 output,
121 duration_ms,
122 cost_usd: cost,
123 input_tokens,
124 output_tokens,
125 debug_messages,
126 })
127 }
128}
129
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;
    use serde::Deserialize;
    use serde_json::json;

    use super::format_agent_output;

    /// Deserializes a JSON string literal into a [`PermissionMode`], panicking
    /// if deserialization fails.
    fn parse_mode(raw: &str) -> PermissionMode {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse_mode(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(
            parse_mode(r#""dont_ask""#),
            PermissionMode::DontAsk
        ));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse_mode(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        // Upper-case spellings must map to the same variants.
        assert!(matches!(parse_mode(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(
            parse_mode(r#""DONT_ASK""#),
            PermissionMode::DontAsk
        ));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        // An unrecognized name is expected to fall back to the default variant.
        assert!(matches!(
            parse_mode(r#""unknown""#),
            PermissionMode::Default
        ));
    }

    #[test]
    fn structured_output_not_wrapped_in_value_model() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct TechDigest {
            items: Vec<String>,
        }

        let payload = json!({"items": ["news1", "news2"]});
        let formatted = format_agent_output(&payload, Some("sonnet"), true);

        // With a schema the payload must deserialize directly, with no
        // `{ value, model }` envelope around it.
        let parsed: TechDigest = serde_json::from_value(formatted).unwrap();
        assert_eq!(parsed.items, vec!["news1", "news2"]);
    }

    #[test]
    fn text_output_wrapped_in_value_model() {
        let plain = json!("Hello, world!");
        let formatted = format_agent_output(&plain, Some("sonnet"), false);

        // Without a schema the text sits under "value" next to the model name.
        assert_eq!(formatted["value"], "Hello, world!");
        assert_eq!(formatted["model"], "sonnet");
    }
}