Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::SONNET)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage};
41use crate::retry::RetryPolicy;
42
/// Model identifiers that are not tied to any particular provider.
///
/// The associated constants name well-known Claude models. Any other string
/// is accepted as well: a custom [`AgentProvider`] implementation decides
/// for itself how to interpret the identifier it receives.
///
/// # Examples
///
/// ```no_run
/// use ironflow_core::prelude::*;
///
/// # async fn example() -> Result<(), OperationError> {
/// let provider = ClaudeCodeProvider::new();
///
/// // Using a built-in constant
/// let r = Agent::new()
///     .prompt("hi")
///     .model(Model::SONNET)
///     .run(&provider)
///     .await?;
///
/// // Using a custom model string
/// let r = Agent::new()
///     .prompt("hi")
///     .model("mistral-large-latest")
///     .run(&provider)
///     .await?;
/// # Ok(())
/// # }
/// ```
pub struct Model;

impl Model {
    // Short aliases. These always point at the newest release of a family;
    // the CLI resolves them to a concrete version.

    /// Claude Sonnet: the default, balancing speed and capability.
    pub const SONNET: &str = "sonnet";
    /// Claude Opus: the most capable family.
    pub const OPUS: &str = "opus";
    /// Claude Haiku: the fastest and cheapest family.
    pub const HAIKU: &str = "haiku";

    // Pinned Claude 4.5 release.

    /// Claude Haiku 4.5.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";

    // Claude 4.6 with the standard 200K-token context window.

    /// Claude Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Claude Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";

    // Claude 4.6 with the extended 1M-token context window (`[1m]` suffix).

    /// Claude Sonnet 4.6, 1M-token context window.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// Claude Opus 4.6, 1M-token context window.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";

    // Claude 4.7, which ships with a native 1M-token context window.

    /// Claude Opus 4.7: the latest flagship, with a native 1M-token context.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// Claude Opus 4.7 with the 1M-token context window requested explicitly.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
111
112/// Controls how the agent handles tool-use permission prompts.
113///
114/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
115/// flags in the Claude CLI.
116#[derive(Debug, Default, Clone, Copy, Serialize)]
117pub enum PermissionMode {
118    /// Use the CLI default permission behavior.
119    #[default]
120    Default,
121    /// Automatically approve tool-use requests.
122    Auto,
123    /// Suppress all permission prompts (the agent proceeds without asking).
124    DontAsk,
125    /// Skip all permission checks entirely.
126    ///
127    /// **Warning**: the agent will have unrestricted filesystem and shell access.
128    BypassPermissions,
129}
130
131impl<'de> Deserialize<'de> for PermissionMode {
132    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
133    where
134        D: serde::Deserializer<'de>,
135    {
136        let s = String::deserialize(deserializer)?;
137        Ok(match s.to_lowercase().replace('_', "").as_str() {
138            "auto" => Self::Auto,
139            "dontask" => Self::DontAsk,
140            "bypass" | "bypasspermissions" => Self::BypassPermissions,
141            _ => Self::Default,
142        })
143    }
144}
145
146/// Builder for a single agent invocation.
147///
148/// Create with [`Agent::new`], chain configuration methods, then call
149/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
150///
151/// # Examples
152///
153/// ```no_run
154/// use ironflow_core::prelude::*;
155///
156/// # async fn example() -> Result<(), OperationError> {
157/// let provider = ClaudeCodeProvider::new();
158///
159/// let result = Agent::new()
160///     .system_prompt("You are a Rust expert.")
161///     .prompt("Review this code for safety issues.")
162///     .model(Model::OPUS)
163///     .allowed_tools(&["Read", "Grep"])
164///     .max_turns(5)
165///     .max_budget_usd(0.50)
166///     .working_dir("/tmp/project")
167///     .permission_mode(PermissionMode::Auto)
168///     .run(&provider)
169///     .await?;
170///
171/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
172/// # Ok(())
173/// # }
174/// ```
175#[must_use = "an Agent does nothing until .run() is awaited"]
176pub struct Agent {
177    config: AgentConfig,
178    dry_run: Option<bool>,
179    retry_policy: Option<RetryPolicy>,
180}
181
182impl Agent {
183    /// Create a new agent builder with default settings.
184    ///
185    /// Defaults: [`Model::SONNET`], no system prompt, no tool restrictions,
186    /// no budget/turn limits, [`PermissionMode::Default`].
187    pub fn new() -> Self {
188        Self {
189            config: AgentConfig::new(""),
190            dry_run: None,
191            retry_policy: None,
192        }
193    }
194
195    /// Create an agent builder from an existing [`AgentConfig`].
196    ///
197    /// Useful when the config comes from a serialized workflow definition
198    /// rather than being built programmatically.
199    ///
200    /// # Examples
201    ///
202    /// ```no_run
203    /// use ironflow_core::prelude::*;
204    /// use ironflow_core::provider::AgentConfig;
205    ///
206    /// # async fn example() -> Result<(), OperationError> {
207    /// let provider = ClaudeCodeProvider::new();
208    /// let config = AgentConfig::new("Summarize the README");
209    /// let result = Agent::from_config(config).run(&provider).await?;
210    /// # Ok(())
211    /// # }
212    /// ```
213    pub fn from_config(config: impl Into<AgentConfig>) -> Self {
214        Self {
215            config: config.into(),
216            dry_run: None,
217            retry_policy: None,
218        }
219    }
220
221    /// Set the system prompt that defines the agent's persona or constraints.
222    pub fn system_prompt(mut self, prompt: &str) -> Self {
223        self.config.system_prompt = Some(prompt.to_string());
224        self
225    }
226
227    /// Set the user prompt - the main instruction sent to the agent.
228    pub fn prompt(mut self, prompt: &str) -> Self {
229        self.config.prompt = prompt.to_string();
230        self
231    }
232
233    /// Set the model to use for this invocation.
234    ///
235    /// Accepts any string-like value. Use [`Model`] constants for well-known
236    /// Claude models, or pass an arbitrary string for custom providers.
237    ///
238    /// Defaults to [`Model::SONNET`] if not called.
239    pub fn model(mut self, model: impl Into<String>) -> Self {
240        self.config.model = model.into();
241        self
242    }
243
244    /// Restrict which tools the agent may invoke.
245    ///
246    /// Pass an empty slice (or do not call this method) to allow the provider
247    /// default set of tools.
248    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
249        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
250        self
251    }
252
253    /// Set the maximum number of agentic turns.
254    ///
255    /// # Panics
256    ///
257    /// Panics if `turns` is `0`.
258    pub fn max_turns(mut self, turns: u32) -> Self {
259        assert!(turns > 0, "max_turns must be greater than 0");
260        self.config.max_turns = Some(turns);
261        self
262    }
263
264    /// Set the maximum spend in USD for this invocation.
265    ///
266    /// # Panics
267    ///
268    /// Panics if `budget` is negative, NaN, or infinity.
269    pub fn max_budget_usd(mut self, budget: f64) -> Self {
270        assert!(
271            budget.is_finite() && budget > 0.0,
272            "budget must be a positive finite number, got {budget}"
273        );
274        self.config.max_budget_usd = Some(budget);
275        self
276    }
277
278    /// Set the working directory for the agent process.
279    pub fn working_dir(mut self, dir: &str) -> Self {
280        self.config.working_dir = Some(dir.to_string());
281        self
282    }
283
284    /// Set the path to an MCP (Model Context Protocol) server configuration file.
285    pub fn mcp_config(mut self, config: &str) -> Self {
286        self.config.mcp_config = Some(config.to_string());
287        self
288    }
289
290    /// Set the permission mode controlling tool-use approval behavior.
291    ///
292    /// See [`PermissionMode`] for details on each variant.
293    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
294        self.config.permission_mode = mode;
295        self
296    }
297
298    /// Request structured (typed) output from the agent.
299    ///
300    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
301    /// to the provider so the model returns JSON conforming to `T`, which can
302    /// then be deserialized with [`AgentResult::json`].
303    ///
304    /// # Examples
305    ///
306    /// ```no_run
307    /// use ironflow_core::prelude::*;
308    ///
309    /// #[derive(Deserialize, JsonSchema)]
310    /// struct Review {
311    ///     score: u8,
312    ///     summary: String,
313    /// }
314    ///
315    /// # async fn example() -> Result<(), OperationError> {
316    /// let provider = ClaudeCodeProvider::new();
317    /// let result = Agent::new()
318    ///     .prompt("Review the codebase")
319    ///     .output::<Review>()
320    ///     .run(&provider)
321    ///     .await?;
322    ///
323    /// let review: Review = result.json().expect("schema-validated output");
324    /// println!("Score: {}/10 - {}", review.score, review.summary);
325    /// # Ok(())
326    /// # }
327    /// ```
328    pub fn output<T: JsonSchema>(mut self) -> Self {
329        let schema = schema_for!(T);
330        self.config.json_schema = match to_string(&schema) {
331            Ok(s) => Some(s),
332            Err(e) => {
333                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
334                None
335            }
336        };
337        self
338    }
339
340    /// Set structured output from a pre-serialized JSON Schema string.
341    ///
342    /// Use this when the schema comes from configuration or another source
343    /// rather than a Rust type. For type-safe schema generation, prefer
344    /// [`output`](Agent::output).
345    ///
346    /// **Important:** structured output requires `max_turns >= 2`. The Claude CLI
347    /// uses the first turn for reasoning and a second turn to produce the
348    /// schema-conforming JSON.
349    ///
350    /// # Examples
351    ///
352    /// ```no_run
353    /// use ironflow_core::prelude::*;
354    ///
355    /// # async fn example() -> Result<(), OperationError> {
356    /// let schema = r#"{"type":"object","properties":{"labels":{"type":"array","items":{"type":"string"}}}}"#;
357    /// let agent = Agent::new()
358    ///     .prompt("Classify this email")
359    ///     .output_schema_raw(schema);
360    /// # Ok(())
361    /// # }
362    /// ```
363    pub fn output_schema_raw(mut self, schema: &str) -> Self {
364        self.config.json_schema = Some(schema.to_string());
365        self
366    }
367
368    /// Retry the agent invocation up to `max_retries` times on transient failures.
369    ///
370    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
371    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
372    ///
373    /// Only transient errors are retried: process failures and timeouts.
374    /// Deterministic errors (prompt too large, schema validation) are never retried.
375    ///
376    /// # Panics
377    ///
378    /// Panics if `max_retries` is `0`.
379    ///
380    /// # Examples
381    ///
382    /// ```no_run
383    /// use ironflow_core::prelude::*;
384    ///
385    /// # async fn example() -> Result<(), OperationError> {
386    /// let provider = ClaudeCodeProvider::new();
387    /// let result = Agent::new()
388    ///     .prompt("Summarize the codebase")
389    ///     .retry(2)
390    ///     .run(&provider)
391    ///     .await?;
392    /// # Ok(())
393    /// # }
394    /// ```
395    pub fn retry(mut self, max_retries: u32) -> Self {
396        self.retry_policy = Some(RetryPolicy::new(max_retries));
397        self
398    }
399
400    /// Set a custom [`RetryPolicy`] for this agent invocation.
401    ///
402    /// Allows full control over backoff duration, multiplier, and max delay.
403    /// See [`RetryPolicy`] for details.
404    ///
405    /// # Examples
406    ///
407    /// ```no_run
408    /// use std::time::Duration;
409    /// use ironflow_core::prelude::*;
410    /// use ironflow_core::retry::RetryPolicy;
411    ///
412    /// # async fn example() -> Result<(), OperationError> {
413    /// let provider = ClaudeCodeProvider::new();
414    /// let result = Agent::new()
415    ///     .prompt("Analyze the code")
416    ///     .retry_policy(
417    ///         RetryPolicy::new(3)
418    ///             .backoff(Duration::from_secs(1))
419    ///             .max_backoff(Duration::from_secs(60))
420    ///     )
421    ///     .run(&provider)
422    ///     .await?;
423    /// # Ok(())
424    /// # }
425    /// ```
426    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
427        self.retry_policy = Some(policy);
428        self
429    }
430
431    /// Enable or disable dry-run mode for this specific operation.
432    ///
433    /// When dry-run is active, the agent call is logged but not executed.
434    /// A synthetic [`AgentResult`] is returned with a placeholder text,
435    /// zero cost, and zero tokens.
436    ///
437    /// If not set, falls back to the global dry-run setting
438    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
439    pub fn dry_run(mut self, enabled: bool) -> Self {
440        self.dry_run = Some(enabled);
441        self
442    }
443
444    /// Enable verbose/debug mode to capture the full conversation trace.
445    ///
446    /// When enabled, the provider captures every assistant message and tool
447    /// call into [`AgentResult::debug_messages`]. Useful for understanding
448    /// why the agent returned an unexpected result.
449    ///
450    /// # Examples
451    ///
452    /// ```no_run
453    /// use ironflow_core::prelude::*;
454    ///
455    /// # async fn example() -> Result<(), OperationError> {
456    /// let provider = ClaudeCodeProvider::new();
457    ///
458    /// let result = Agent::new()
459    ///     .prompt("Analyze src/")
460    ///     .verbose()
461    ///     .max_budget_usd(0.10)
462    ///     .run(&provider)
463    ///     .await?;
464    ///
465    /// if let Some(messages) = result.debug_messages() {
466    ///     for msg in messages {
467    ///         println!("{msg}");
468    ///     }
469    /// }
470    /// # Ok(())
471    /// # }
472    /// ```
473    pub fn verbose(mut self) -> Self {
474        self.config.verbose = true;
475        self
476    }
477
478    /// Resume a previous agent conversation by session ID.
479    ///
480    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
481    /// continue the multi-turn conversation.
482    ///
483    /// # Examples
484    ///
485    /// ```no_run
486    /// use ironflow_core::prelude::*;
487    ///
488    /// # async fn example() -> Result<(), OperationError> {
489    /// let provider = ClaudeCodeProvider::new();
490    ///
491    /// let first = Agent::new()
492    ///     .prompt("Analyze the src/ directory")
493    ///     .max_budget_usd(0.10)
494    ///     .run(&provider)
495    ///     .await?;
496    ///
497    /// let session = first.session_id().expect("provider returned session ID");
498    ///
499    /// let followup = Agent::new()
500    ///     .prompt("Now suggest improvements")
501    ///     .resume(session)
502    ///     .max_budget_usd(0.10)
503    ///     .run(&provider)
504    ///     .await?;
505    /// # Ok(())
506    /// # }
507    /// ```
508    ///
509    /// # Panics
510    ///
511    /// Panics if `session_id` is empty or contains characters other than
512    /// alphanumerics, hyphens, and underscores.
513    pub fn resume(mut self, session_id: &str) -> Self {
514        assert!(!session_id.is_empty(), "session_id must not be empty");
515        assert!(
516            session_id
517                .chars()
518                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
519            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
520        );
521        self.config.resume_session_id = Some(session_id.to_string());
522        self
523    }
524
525    /// Execute the agent invocation using the given [`AgentProvider`].
526    ///
527    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
528    /// failures (process crashes, timeouts) are retried with exponential
529    /// backoff. Deterministic errors (prompt too large, schema validation)
530    /// are returned immediately without retry.
531    ///
532    /// # Errors
533    ///
534    /// Returns [`OperationError::Agent`] if the provider reports a failure
535    /// (process crash, timeout, or schema validation error).
536    ///
537    /// # Panics
538    ///
539    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
540    /// empty (whitespace-only counts as empty).
541    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
542    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
543        assert!(
544            !self.config.prompt.trim().is_empty(),
545            "prompt must not be empty - call .prompt(\"...\") before .run()"
546        );
547
548        if crate::dry_run::effective_dry_run(self.dry_run) {
549            info!(
550                prompt_len = self.config.prompt.len(),
551                "[dry-run] agent call skipped"
552            );
553            let mut output =
554                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
555            output.cost_usd = Some(0.0);
556            output.input_tokens = Some(0);
557            output.output_tokens = Some(0);
558            return Ok(AgentResult { output });
559        }
560
561        let result = self.invoke_once(provider).await;
562
563        let policy = match &self.retry_policy {
564            Some(p) => p,
565            None => return result,
566        };
567
568        // Non-retryable errors are returned immediately.
569        if let Err(ref err) = result {
570            if !crate::retry::is_retryable(err) {
571                return result;
572            }
573        } else {
574            return result;
575        }
576
577        let mut last_result = result;
578
579        for attempt in 0..policy.max_retries {
580            let delay = policy.delay_for_attempt(attempt);
581            warn!(
582                attempt = attempt + 1,
583                max_retries = policy.max_retries,
584                delay_ms = delay.as_millis() as u64,
585                "retrying agent invocation"
586            );
587            time::sleep(delay).await;
588
589            last_result = self.invoke_once(provider).await;
590
591            match &last_result {
592                Ok(_) => return last_result,
593                Err(err) if !crate::retry::is_retryable(err) => return last_result,
594                _ => {}
595            }
596        }
597
598        last_result
599    }
600
601    /// Execute a single agent invocation attempt (no retry logic).
602    async fn invoke_once(
603        &self,
604        provider: &dyn AgentProvider,
605    ) -> Result<AgentResult, OperationError> {
606        #[cfg(feature = "prometheus")]
607        let model_label = self.config.model.to_string();
608
609        let output = match provider.invoke(&self.config).await {
610            Ok(output) => output,
611            Err(e) => {
612                #[cfg(feature = "prometheus")]
613                {
614                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
615                }
616                return Err(OperationError::Agent(e));
617            }
618        };
619
620        info!(
621            duration_ms = output.duration_ms,
622            cost_usd = output.cost_usd,
623            input_tokens = output.input_tokens,
624            output_tokens = output.output_tokens,
625            model = output.model,
626            "agent completed"
627        );
628
629        #[cfg(feature = "prometheus")]
630        {
631            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
632            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
633                .record(output.duration_ms as f64 / 1000.0);
634            if let Some(cost) = output.cost_usd {
635                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
636                    .increment(cost);
637            }
638            if let Some(tokens) = output.input_tokens {
639                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
640            }
641            if let Some(tokens) = output.output_tokens {
642                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
643                    .increment(tokens);
644            }
645        }
646
647        Ok(AgentResult { output })
648    }
649}
650
651impl Default for Agent {
652    fn default() -> Self {
653        Self::new()
654    }
655}
656
657/// The result of a successful agent invocation.
658///
659/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
660/// response text, typed JSON deserialization, session metadata, and usage stats.
661#[derive(Debug)]
662pub struct AgentResult {
663    output: AgentOutput,
664}
665
666impl AgentResult {
667    /// Return the agent's response as a plain text string.
668    ///
669    /// If the underlying value is not a JSON string (e.g. when structured output
670    /// was requested), returns an empty string and logs a warning.
671    pub fn text(&self) -> &str {
672        match self.output.value.as_str() {
673            Some(s) => s,
674            None => {
675                warn!(
676                    value_type = self.output.value.to_string(),
677                    "agent output is not a string, returning empty"
678                );
679                ""
680            }
681        }
682    }
683
684    /// Return the raw JSON [`Value`] of the agent's response.
685    pub fn value(&self) -> &Value {
686        &self.output.value
687    }
688
689    /// Deserialize the agent's response into the given type `T`.
690    ///
691    /// This clones the underlying JSON value. If you no longer need the
692    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
693    /// instead to avoid the clone.
694    ///
695    /// # Errors
696    ///
697    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
698    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
699        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
700    }
701
702    /// Consume the result and deserialize the response into `T` without cloning.
703    ///
704    /// # Errors
705    ///
706    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
707    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
708        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
709    }
710
711    /// Build an `AgentResult` from a raw [`AgentOutput`].
712    ///
713    /// This is available only in test builds to simplify test setup without
714    /// going through the full record/replay pipeline.
715    #[cfg(test)]
716    pub(crate) fn from_output(output: AgentOutput) -> Self {
717        Self { output }
718    }
719
720    /// Return the provider-assigned session ID, if available.
721    pub fn session_id(&self) -> Option<&str> {
722        self.output.session_id.as_deref()
723    }
724
725    /// Return the cost of this invocation in USD, if reported by the provider.
726    pub fn cost_usd(&self) -> Option<f64> {
727        self.output.cost_usd
728    }
729
730    /// Return the number of input tokens consumed, if reported.
731    pub fn input_tokens(&self) -> Option<u64> {
732        self.output.input_tokens
733    }
734
735    /// Return the number of output tokens generated, if reported.
736    pub fn output_tokens(&self) -> Option<u64> {
737        self.output.output_tokens
738    }
739
740    /// Return the wall-clock duration of the invocation in milliseconds.
741    pub fn duration_ms(&self) -> u64 {
742        self.output.duration_ms
743    }
744
745    /// Return the concrete model identifier used, if reported by the provider.
746    pub fn model(&self) -> Option<&str> {
747        self.output.model.as_deref()
748    }
749
750    /// Return the conversation trace captured during a verbose invocation.
751    ///
752    /// Returns `None` when [`Agent::verbose`] was not called. When present,
753    /// each [`DebugMessage`] contains the
754    /// assistant's text and tool calls for one conversation turn.
755    pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
756        self.output.debug_messages.as_deref()
757    }
758}
759
760#[cfg(test)]
761mod tests {
762    use super::*;
763    use crate::error::AgentError;
764    use crate::provider::InvokeFuture;
765    use serde_json::json;
766
767    struct TestProvider {
768        output: AgentOutput,
769    }
770
771    impl AgentProvider for TestProvider {
772        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
773            Box::pin(async move {
774                Ok(AgentOutput {
775                    value: self.output.value.clone(),
776                    session_id: self.output.session_id.clone(),
777                    cost_usd: self.output.cost_usd,
778                    input_tokens: self.output.input_tokens,
779                    output_tokens: self.output.output_tokens,
780                    model: self.output.model.clone(),
781                    duration_ms: self.output.duration_ms,
782                    debug_messages: None,
783                })
784            })
785        }
786    }
787
788    struct ConfigCapture {
789        output: AgentOutput,
790    }
791
792    impl AgentProvider for ConfigCapture {
793        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
794            let config_json = serde_json::to_value(config).unwrap();
795            Box::pin(async move {
796                Ok(AgentOutput {
797                    value: config_json,
798                    session_id: self.output.session_id.clone(),
799                    cost_usd: self.output.cost_usd,
800                    input_tokens: self.output.input_tokens,
801                    output_tokens: self.output.output_tokens,
802                    model: self.output.model.clone(),
803                    duration_ms: self.output.duration_ms,
804                    debug_messages: None,
805                })
806            })
807        }
808    }
809
810    fn default_output() -> AgentOutput {
811        AgentOutput {
812            value: json!("test output"),
813            session_id: Some("sess-123".to_string()),
814            cost_usd: Some(0.05),
815            input_tokens: Some(100),
816            output_tokens: Some(50),
817            model: Some("sonnet".to_string()),
818            duration_ms: 1500,
819            debug_messages: None,
820        }
821    }
822
823    // --- Model constants ---
824
825    #[test]
826    fn model_constants_have_expected_values() {
827        assert_eq!(Model::SONNET, "sonnet");
828        assert_eq!(Model::OPUS, "opus");
829        assert_eq!(Model::HAIKU, "haiku");
830        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
831        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
832        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
833        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
834        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
835        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
836        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
837    }
838
839    // --- Agent::new() defaults via ConfigCapture ---
840
841    #[tokio::test]
842    async fn agent_new_default_values() {
843        let provider = ConfigCapture {
844            output: default_output(),
845        };
846        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
847
848        let config = result.value();
849        assert_eq!(config["system_prompt"], json!(null));
850        assert_eq!(config["prompt"], json!("hi"));
851        assert_eq!(config["model"], json!("sonnet"));
852        assert_eq!(config["allowed_tools"], json!([]));
853        assert_eq!(config["max_turns"], json!(null));
854        assert_eq!(config["max_budget_usd"], json!(null));
855        assert_eq!(config["working_dir"], json!(null));
856        assert_eq!(config["mcp_config"], json!(null));
857        assert_eq!(config["permission_mode"], json!("Default"));
858        assert_eq!(config["json_schema"], json!(null));
859    }
860
861    #[tokio::test]
862    async fn agent_default_matches_new() {
863        let provider = ConfigCapture {
864            output: default_output(),
865        };
866        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
867        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
868
869        assert_eq!(result_new.value(), result_default.value());
870    }
871
872    // --- Builder methods ---
873
874    #[tokio::test]
875    async fn builder_methods_store_values_correctly() {
876        let provider = ConfigCapture {
877            output: default_output(),
878        };
879        let result = Agent::new()
880            .system_prompt("you are a bot")
881            .prompt("do something")
882            .model(Model::OPUS)
883            .allowed_tools(&["Read", "Write"])
884            .max_turns(5)
885            .max_budget_usd(1.5)
886            .working_dir("/tmp")
887            .mcp_config("{}")
888            .permission_mode(PermissionMode::Auto)
889            .run(&provider)
890            .await
891            .unwrap();
892
893        let config = result.value();
894        assert_eq!(config["system_prompt"], json!("you are a bot"));
895        assert_eq!(config["prompt"], json!("do something"));
896        assert_eq!(config["model"], json!("opus"));
897        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
898        assert_eq!(config["max_turns"], json!(5));
899        assert_eq!(config["max_budget_usd"], json!(1.5));
900        assert_eq!(config["working_dir"], json!("/tmp"));
901        assert_eq!(config["mcp_config"], json!("{}"));
902        assert_eq!(config["permission_mode"], json!("Auto"));
903    }
904
905    // --- Panics ---
906
907    #[test]
908    #[should_panic(expected = "max_turns must be greater than 0")]
909    fn max_turns_zero_panics() {
910        let _ = Agent::new().max_turns(0);
911    }
912
913    #[test]
914    #[should_panic(expected = "budget must be a positive finite number")]
915    fn max_budget_negative_panics() {
916        let _ = Agent::new().max_budget_usd(-1.0);
917    }
918
919    #[test]
920    #[should_panic(expected = "budget must be a positive finite number")]
921    fn max_budget_nan_panics() {
922        let _ = Agent::new().max_budget_usd(f64::NAN);
923    }
924
925    #[test]
926    #[should_panic(expected = "budget must be a positive finite number")]
927    fn max_budget_infinity_panics() {
928        let _ = Agent::new().max_budget_usd(f64::INFINITY);
929    }
930
931    // --- AgentResult accessors ---
932
933    #[tokio::test]
934    async fn agent_result_text_with_string_value() {
935        let provider = TestProvider {
936            output: AgentOutput {
937                value: json!("hello world"),
938                ..default_output()
939            },
940        };
941        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
942        assert_eq!(result.text(), "hello world");
943    }
944
945    #[tokio::test]
946    async fn agent_result_text_with_non_string_value() {
947        let provider = TestProvider {
948            output: AgentOutput {
949                value: json!(42),
950                ..default_output()
951            },
952        };
953        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
954        assert_eq!(result.text(), "");
955    }
956
957    #[tokio::test]
958    async fn agent_result_text_with_null_value() {
959        let provider = TestProvider {
960            output: AgentOutput {
961                value: json!(null),
962                ..default_output()
963            },
964        };
965        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
966        assert_eq!(result.text(), "");
967    }
968
969    #[tokio::test]
970    async fn agent_result_json_successful_deserialize() {
971        #[derive(Deserialize, PartialEq, Debug)]
972        struct MyOutput {
973            name: String,
974            count: u32,
975        }
976        let provider = TestProvider {
977            output: AgentOutput {
978                value: json!({"name": "test", "count": 7}),
979                ..default_output()
980            },
981        };
982        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
983        let parsed: MyOutput = result.json().unwrap();
984        assert_eq!(parsed.name, "test");
985        assert_eq!(parsed.count, 7);
986    }
987
988    #[tokio::test]
989    async fn agent_result_json_failed_deserialize() {
990        #[derive(Debug, Deserialize)]
991        #[allow(dead_code)]
992        struct MyOutput {
993            name: String,
994        }
995        let provider = TestProvider {
996            output: AgentOutput {
997                value: json!(42),
998                ..default_output()
999            },
1000        };
1001        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1002        let err = result.json::<MyOutput>().unwrap_err();
1003        assert!(matches!(err, OperationError::Deserialize { .. }));
1004    }
1005
1006    #[tokio::test]
1007    async fn agent_result_accessors() {
1008        let provider = TestProvider {
1009            output: AgentOutput {
1010                value: json!("v"),
1011                session_id: Some("s-1".to_string()),
1012                cost_usd: Some(0.123),
1013                input_tokens: Some(999),
1014                output_tokens: Some(456),
1015                model: Some("opus".to_string()),
1016                duration_ms: 2000,
1017                debug_messages: None,
1018            },
1019        };
1020        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1021        assert_eq!(result.session_id(), Some("s-1"));
1022        assert_eq!(result.cost_usd(), Some(0.123));
1023        assert_eq!(result.input_tokens(), Some(999));
1024        assert_eq!(result.output_tokens(), Some(456));
1025        assert_eq!(result.duration_ms(), 2000);
1026        assert_eq!(result.model(), Some("opus"));
1027    }
1028
1029    // --- Session resume ---
1030
1031    #[tokio::test]
1032    async fn resume_passes_session_id_in_config() {
1033        let provider = ConfigCapture {
1034            output: default_output(),
1035        };
1036        let result = Agent::new()
1037            .prompt("followup")
1038            .resume("sess-abc")
1039            .run(&provider)
1040            .await
1041            .unwrap();
1042
1043        let config = result.value();
1044        assert_eq!(config["resume_session_id"], json!("sess-abc"));
1045    }
1046
1047    #[tokio::test]
1048    async fn no_resume_has_null_session_id() {
1049        let provider = ConfigCapture {
1050            output: default_output(),
1051        };
1052        let result = Agent::new()
1053            .prompt("first call")
1054            .run(&provider)
1055            .await
1056            .unwrap();
1057
1058        let config = result.value();
1059        assert_eq!(config["resume_session_id"], json!(null));
1060    }
1061
1062    #[test]
1063    #[should_panic(expected = "session_id must not be empty")]
1064    fn resume_empty_session_id_panics() {
1065        let _ = Agent::new().resume("");
1066    }
1067
1068    #[test]
1069    #[should_panic(expected = "session_id must only contain")]
1070    fn resume_invalid_chars_panics() {
1071        let _ = Agent::new().resume("sess;rm -rf /");
1072    }
1073
1074    #[test]
1075    fn resume_valid_formats_accepted() {
1076        let _ = Agent::new().resume("sess-abc123");
1077        let _ = Agent::new().resume("a1b2c3d4_session");
1078        let _ = Agent::new().resume("abc-DEF-123_456");
1079    }
1080
1081    #[tokio::test]
1082    #[should_panic(expected = "prompt must not be empty")]
1083    async fn run_without_prompt_panics() {
1084        let provider = TestProvider {
1085            output: default_output(),
1086        };
1087        let _ = Agent::new().run(&provider).await;
1088    }
1089
1090    #[tokio::test]
1091    #[should_panic(expected = "prompt must not be empty")]
1092    async fn run_with_whitespace_only_prompt_panics() {
1093        let provider = TestProvider {
1094            output: default_output(),
1095        };
1096        let _ = Agent::new().prompt("   ").run(&provider).await;
1097    }
1098
1099    // --- Model accepts arbitrary strings ---
1100
1101    #[tokio::test]
1102    async fn model_accepts_custom_string() {
1103        let provider = ConfigCapture {
1104            output: default_output(),
1105        };
1106        let result = Agent::new()
1107            .prompt("hi")
1108            .model("mistral-large-latest")
1109            .run(&provider)
1110            .await
1111            .unwrap();
1112        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
1113    }
1114
1115    #[tokio::test]
1116    async fn verbose_sets_config_flag() {
1117        let provider = ConfigCapture {
1118            output: default_output(),
1119        };
1120        let result = Agent::new()
1121            .prompt("hi")
1122            .verbose()
1123            .run(&provider)
1124            .await
1125            .unwrap();
1126        assert_eq!(result.value()["verbose"], json!(true));
1127    }
1128
1129    #[tokio::test]
1130    async fn verbose_not_set_by_default() {
1131        let provider = ConfigCapture {
1132            output: default_output(),
1133        };
1134        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
1135        assert_eq!(result.value()["verbose"], json!(false));
1136    }
1137
1138    #[tokio::test]
1139    async fn debug_messages_none_without_verbose() {
1140        let provider = TestProvider {
1141            output: default_output(),
1142        };
1143        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1144        assert!(result.debug_messages().is_none());
1145    }
1146
1147    #[tokio::test]
1148    async fn model_accepts_owned_string() {
1149        let provider = ConfigCapture {
1150            output: default_output(),
1151        };
1152        let model_name = String::from("gpt-4o");
1153        let result = Agent::new()
1154            .prompt("hi")
1155            .model(model_name)
1156            .run(&provider)
1157            .await
1158            .unwrap();
1159        assert_eq!(result.value()["model"], json!("gpt-4o"));
1160    }
1161
1162    #[tokio::test]
1163    async fn into_json_success() {
1164        #[derive(Deserialize, PartialEq, Debug)]
1165        struct Out {
1166            name: String,
1167        }
1168        let provider = TestProvider {
1169            output: AgentOutput {
1170                value: json!({"name": "test"}),
1171                ..default_output()
1172            },
1173        };
1174        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1175        let parsed: Out = result.into_json().unwrap();
1176        assert_eq!(parsed.name, "test");
1177    }
1178
1179    #[tokio::test]
1180    async fn into_json_failure() {
1181        #[derive(Debug, Deserialize)]
1182        #[allow(dead_code)]
1183        struct Out {
1184            name: String,
1185        }
1186        let provider = TestProvider {
1187            output: AgentOutput {
1188                value: json!(42),
1189                ..default_output()
1190            },
1191        };
1192        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1193        let err = result.into_json::<Out>().unwrap_err();
1194        assert!(matches!(err, OperationError::Deserialize { .. }));
1195    }
1196
1197    #[test]
1198    fn from_output_creates_result() {
1199        let output = AgentOutput {
1200            value: json!("hello"),
1201            ..default_output()
1202        };
1203        let result = AgentResult::from_output(output);
1204        assert_eq!(result.text(), "hello");
1205        assert_eq!(result.cost_usd(), Some(0.05));
1206    }
1207
1208    #[test]
1209    #[should_panic(expected = "budget must be a positive finite number")]
1210    fn max_budget_zero_panics() {
1211        let _ = Agent::new().max_budget_usd(0.0);
1212    }
1213
1214    #[test]
1215    fn model_constant_equality() {
1216        assert_eq!(Model::SONNET, "sonnet");
1217        assert_ne!(Model::SONNET, Model::OPUS);
1218    }
1219
1220    #[test]
1221    fn permission_mode_serialize_deserialize_roundtrip() {
1222        for mode in [
1223            PermissionMode::Default,
1224            PermissionMode::Auto,
1225            PermissionMode::DontAsk,
1226            PermissionMode::BypassPermissions,
1227        ] {
1228            let json = to_string(&mode).unwrap();
1229            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1230            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1231        }
1232    }
1233
1234    // --- Retry builder ---
1235
1236    #[test]
1237    fn retry_builder_stores_policy() {
1238        let agent = Agent::new().retry(3);
1239        assert!(agent.retry_policy.is_some());
1240        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1241    }
1242
1243    #[test]
1244    fn retry_policy_builder_stores_custom_policy() {
1245        use crate::retry::RetryPolicy;
1246        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1247        let agent = Agent::new().retry_policy(policy);
1248        let p = agent.retry_policy.unwrap();
1249        assert_eq!(p.max_retries(), 5);
1250    }
1251
1252    #[test]
1253    fn no_retry_by_default() {
1254        let agent = Agent::new();
1255        assert!(agent.retry_policy.is_none());
1256    }
1257
1258    // --- Retry behavior ---
1259
1260    use std::sync::Arc;
1261    use std::sync::atomic::{AtomicU32, Ordering};
1262    use std::time::Duration;
1263
1264    struct FailNTimesProvider {
1265        fail_count: AtomicU32,
1266        failures_before_success: u32,
1267        output: AgentOutput,
1268    }
1269
1270    impl AgentProvider for FailNTimesProvider {
1271        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1272            Box::pin(async move {
1273                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1274                if current < self.failures_before_success {
1275                    Err(AgentError::ProcessFailed {
1276                        exit_code: 1,
1277                        stderr: format!("transient failure #{}", current + 1),
1278                    })
1279                } else {
1280                    Ok(AgentOutput {
1281                        value: self.output.value.clone(),
1282                        session_id: self.output.session_id.clone(),
1283                        cost_usd: self.output.cost_usd,
1284                        input_tokens: self.output.input_tokens,
1285                        output_tokens: self.output.output_tokens,
1286                        model: self.output.model.clone(),
1287                        duration_ms: self.output.duration_ms,
1288                        debug_messages: None,
1289                    })
1290                }
1291            })
1292        }
1293    }
1294
1295    #[tokio::test]
1296    async fn retry_succeeds_after_transient_failures() {
1297        let provider = FailNTimesProvider {
1298            fail_count: AtomicU32::new(0),
1299            failures_before_success: 2,
1300            output: default_output(),
1301        };
1302        let result = Agent::new()
1303            .prompt("test")
1304            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1305            .run(&provider)
1306            .await;
1307
1308        assert!(result.is_ok());
1309        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1310    }
1311
1312    #[tokio::test]
1313    async fn retry_exhausted_returns_last_error() {
1314        let provider = FailNTimesProvider {
1315            fail_count: AtomicU32::new(0),
1316            failures_before_success: 10, // always fails
1317            output: default_output(),
1318        };
1319        let result = Agent::new()
1320            .prompt("test")
1321            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1322            .run(&provider)
1323            .await;
1324
1325        assert!(result.is_err());
1326        // 1 initial + 2 retries = 3 total
1327        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1328    }
1329
1330    #[tokio::test]
1331    async fn retry_does_not_retry_non_retryable_errors() {
1332        let call_count = Arc::new(AtomicU32::new(0));
1333        let count = call_count.clone();
1334
1335        struct CountingNonRetryable {
1336            count: Arc<AtomicU32>,
1337        }
1338        impl AgentProvider for CountingNonRetryable {
1339            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1340                self.count.fetch_add(1, Ordering::SeqCst);
1341                Box::pin(async move {
1342                    Err(AgentError::SchemaValidation {
1343                        expected: "object".to_string(),
1344                        got: "string".to_string(),
1345                        debug_messages: Vec::new(),
1346                        partial_usage: Box::default(),
1347                    })
1348                })
1349            }
1350        }
1351
1352        let provider = CountingNonRetryable { count };
1353        let result = Agent::new()
1354            .prompt("test")
1355            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1356            .run(&provider)
1357            .await;
1358
1359        assert!(result.is_err());
1360        // Only 1 attempt, no retries for non-retryable errors
1361        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1362    }
1363
1364    #[tokio::test]
1365    async fn no_retry_without_policy() {
1366        let provider = FailNTimesProvider {
1367            fail_count: AtomicU32::new(0),
1368            failures_before_success: 1,
1369            output: default_output(),
1370        };
1371        let result = Agent::new().prompt("test").run(&provider).await;
1372
1373        assert!(result.is_err());
1374        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1375    }
1376}