Skip to main content

ironflow_engine/executor/
agent.rs

1//! Agent step executor.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use rust_decimal::Decimal;
7use serde_json::{Value, json};
8use tracing::{info, warn};
9
10use ironflow_core::operations::agent::Agent;
11use ironflow_core::provider::{AgentConfig, AgentProvider};
12
13use crate::error::EngineError;
14
15use super::{StepExecutor, StepOutput};
16
17/// Format the agent output value for [`StepOutput`].
18///
19/// When a JSON schema was requested (`has_schema = true`), the structured
20/// value is returned directly so that callers can deserialize it as `T`.
21/// Otherwise the value is wrapped in `{"value": ..., "model": ...}` for
22/// backward compatibility with text-mode consumers.
23///
24/// **Note:** the structured value passed here may not strictly conform to
25/// the requested schema. Claude CLI can non-deterministically flatten
26/// wrapper objects with a single array field, returning a bare array
27/// instead of `{"items": [...]}`. See upstream issues:
28/// - <https://github.com/anthropics/claude-agent-sdk-python/issues/502>
29/// - <https://github.com/anthropics/claude-agent-sdk-python/issues/374>
30///
31/// Callers that deserialize the output should handle both the expected
32/// wrapper and a bare array/value as a fallback.
33fn format_agent_output(value: &Value, model: Option<&str>, has_schema: bool) -> Value {
34    if has_schema {
35        value.clone()
36    } else {
37        json!({
38            "value": value,
39            "model": model,
40        })
41    }
42}
43
44/// Executor for agent (AI) steps.
45///
46/// Runs an AI agent with the given prompt and configuration, capturing
47/// the response value, cost, and token counts.
48pub struct AgentExecutor<'a> {
49    config: &'a AgentConfig,
50}
51
52impl<'a> AgentExecutor<'a> {
53    /// Create a new agent executor from a config reference.
54    pub fn new(config: &'a AgentConfig) -> Self {
55        Self { config }
56    }
57}
58
59impl StepExecutor for AgentExecutor<'_> {
60    async fn execute(&self, provider: &Arc<dyn AgentProvider>) -> Result<StepOutput, EngineError> {
61        let start = Instant::now();
62
63        if self.config.json_schema.is_some() && self.config.max_turns == Some(1) {
64            warn!(
65                "structured output (json_schema) requires max_turns >= 2; \
66                 max_turns is set to 1, the agent will likely fail with error_max_turns"
67            );
68        }
69
70        let result = Agent::from_config(self.config.clone())
71            .run(provider.as_ref())
72            .await?;
73
74        let duration_ms = start.elapsed().as_millis() as u64;
75        let cost = Decimal::try_from(result.cost_usd().unwrap_or(0.0)).unwrap_or(Decimal::ZERO);
76        let input_tokens = result.input_tokens();
77        let output_tokens = result.output_tokens();
78
79        info!(
80            step_kind = "agent",
81            model = %self.config.model,
82            cost_usd = %cost,
83            input_tokens = ?input_tokens,
84            output_tokens = ?output_tokens,
85            duration_ms,
86            "agent step completed"
87        );
88
89        #[cfg(feature = "prometheus")]
90        {
91            use ironflow_core::metric_names::{
92                AGENT_COST_USD_TOTAL, AGENT_DURATION_SECONDS, AGENT_TOKENS_INPUT_TOTAL,
93                AGENT_TOKENS_OUTPUT_TOTAL, AGENT_TOTAL, STATUS_SUCCESS,
94            };
95            use metrics::{counter, gauge, histogram};
96            let model_label = self.config.model.clone();
97            counter!(AGENT_TOTAL, "model" => model_label.clone(), "status" => STATUS_SUCCESS)
98                .increment(1);
99            histogram!(AGENT_DURATION_SECONDS, "model" => model_label.clone())
100                .record(duration_ms as f64 / 1000.0);
101            gauge!(AGENT_COST_USD_TOTAL, "model" => model_label.clone())
102                .increment(cost.to_string().parse::<f64>().unwrap_or(0.0));
103            if let Some(inp) = input_tokens {
104                counter!(AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(inp);
105            }
106            if let Some(out) = output_tokens {
107                counter!(AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label).increment(out);
108            }
109        }
110
111        let debug_messages = result.debug_messages().map(|msgs| msgs.to_vec());
112
113        let output = format_agent_output(
114            result.value(),
115            result.model(),
116            self.config.json_schema.is_some(),
117        );
118
119        Ok(StepOutput {
120            output,
121            duration_ms,
122            cost_usd: cost,
123            input_tokens,
124            output_tokens,
125            debug_messages,
126        })
127    }
128}
129
#[cfg(test)]
mod tests {
    use ironflow_core::operations::agent::PermissionMode;
    use serde::Deserialize;
    use serde_json::json;

    use super::format_agent_output;

    /// Deserialize a JSON string literal (quotes included) into a [`PermissionMode`].
    fn parse_mode(raw: &str) -> PermissionMode {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn parse_permission_mode_via_serde() {
        assert!(matches!(parse_mode(r#""auto""#), PermissionMode::Auto));
    }

    #[test]
    fn parse_permission_mode_dont_ask() {
        assert!(matches!(parse_mode(r#""dont_ask""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_bypass() {
        assert!(matches!(
            parse_mode(r#""bypass""#),
            PermissionMode::BypassPermissions
        ));
    }

    #[test]
    fn parse_permission_mode_case_insensitive() {
        assert!(matches!(parse_mode(r#""AUTO""#), PermissionMode::Auto));
        assert!(matches!(parse_mode(r#""DONT_ASK""#), PermissionMode::DontAsk));
    }

    #[test]
    fn parse_permission_mode_unknown_defaults() {
        assert!(matches!(parse_mode(r#""unknown""#), PermissionMode::Default));
    }

    #[test]
    fn structured_output_not_wrapped_in_value_model() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct TechDigest {
            items: Vec<String>,
        }

        let structured = json!({"items": ["news1", "news2"]});
        let digest: TechDigest =
            serde_json::from_value(format_agent_output(&structured, Some("sonnet"), true))
                .unwrap();

        assert_eq!(digest.items, vec!["news1", "news2"]);
    }

    #[test]
    fn text_output_wrapped_in_value_model() {
        let wrapped = format_agent_output(&json!("Hello, world!"), Some("sonnet"), false);

        assert_eq!(wrapped["value"], "Hello, world!");
        assert_eq!(wrapped["model"], "sonnet");
    }
}