Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::SONNET)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage};
41use crate::retry::RetryPolicy;
42
/// Namespace of model identifier strings that work with any provider.
///
/// The associated constants cover well-known Claude models. Nothing restricts
/// callers to them, though: [`Agent::model`] takes any string, and each
/// [`AgentProvider`] implementation decides how to interpret the identifier.
///
/// # Examples
///
/// ```no_run
/// use ironflow_core::prelude::*;
///
/// # async fn example() -> Result<(), OperationError> {
/// let provider = ClaudeCodeProvider::new();
///
/// // Using a built-in constant
/// let r = Agent::new()
///     .prompt("hi")
///     .model(Model::SONNET)
///     .run(&provider)
///     .await?;
///
/// // Using a custom model string
/// let r = Agent::new()
///     .prompt("hi")
///     .model("mistral-large-latest")
///     .run(&provider)
///     .await?;
/// # Ok(())
/// # }
/// ```
pub struct Model;

impl Model {
    // Short aliases: the CLI maps each one to the newest release of that tier.

    /// Claude Sonnet - the default tier, trading off speed and capability.
    pub const SONNET: &'static str = "sonnet";
    /// Claude Opus - the most capable tier.
    pub const OPUS: &'static str = "opus";
    /// Claude Haiku - the fastest and least expensive tier.
    pub const HAIKU: &'static str = "haiku";

    // Pinned Claude 4.5 release.

    /// Claude Haiku 4.5 (dated snapshot).
    pub const HAIKU_45: &'static str = "claude-haiku-4-5-20251001";

    // Claude 4.6 with the standard 200K-token context window.

    /// Claude Sonnet 4.6.
    pub const SONNET_46: &'static str = "claude-sonnet-4-6";
    /// Claude Opus 4.6.
    pub const OPUS_46: &'static str = "claude-opus-4-6";

    // Claude 4.6 with the extended 1M-token context window (`[1m]` suffix).

    /// Claude Sonnet 4.6, 1M-token context.
    pub const SONNET_46_1M: &'static str = "claude-sonnet-4-6[1m]";
    /// Claude Opus 4.6, 1M-token context.
    pub const OPUS_46_1M: &'static str = "claude-opus-4-6[1m]";
}
104
105/// Controls how the agent handles tool-use permission prompts.
106///
107/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
108/// flags in the Claude CLI.
109#[derive(Debug, Default, Clone, Copy, Serialize)]
110pub enum PermissionMode {
111    /// Use the CLI default permission behavior.
112    #[default]
113    Default,
114    /// Automatically approve tool-use requests.
115    Auto,
116    /// Suppress all permission prompts (the agent proceeds without asking).
117    DontAsk,
118    /// Skip all permission checks entirely.
119    ///
120    /// **Warning**: the agent will have unrestricted filesystem and shell access.
121    BypassPermissions,
122}
123
124impl<'de> Deserialize<'de> for PermissionMode {
125    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126    where
127        D: serde::Deserializer<'de>,
128    {
129        let s = String::deserialize(deserializer)?;
130        Ok(match s.to_lowercase().replace('_', "").as_str() {
131            "auto" => Self::Auto,
132            "dontask" => Self::DontAsk,
133            "bypass" | "bypasspermissions" => Self::BypassPermissions,
134            _ => Self::Default,
135        })
136    }
137}
138
139/// Builder for a single agent invocation.
140///
141/// Create with [`Agent::new`], chain configuration methods, then call
142/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
143///
144/// # Examples
145///
146/// ```no_run
147/// use ironflow_core::prelude::*;
148///
149/// # async fn example() -> Result<(), OperationError> {
150/// let provider = ClaudeCodeProvider::new();
151///
152/// let result = Agent::new()
153///     .system_prompt("You are a Rust expert.")
154///     .prompt("Review this code for safety issues.")
155///     .model(Model::OPUS)
156///     .allowed_tools(&["Read", "Grep"])
157///     .max_turns(5)
158///     .max_budget_usd(0.50)
159///     .working_dir("/tmp/project")
160///     .permission_mode(PermissionMode::Auto)
161///     .run(&provider)
162///     .await?;
163///
164/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
165/// # Ok(())
166/// # }
167/// ```
168#[must_use = "an Agent does nothing until .run() is awaited"]
169pub struct Agent {
170    config: AgentConfig,
171    dry_run: Option<bool>,
172    retry_policy: Option<RetryPolicy>,
173}
174
175impl Agent {
176    /// Create a new agent builder with default settings.
177    ///
178    /// Defaults: [`Model::SONNET`], no system prompt, no tool restrictions,
179    /// no budget/turn limits, [`PermissionMode::Default`].
180    pub fn new() -> Self {
181        Self {
182            config: AgentConfig::new(""),
183            dry_run: None,
184            retry_policy: None,
185        }
186    }
187
188    /// Create an agent builder from an existing [`AgentConfig`].
189    ///
190    /// Useful when the config comes from a serialized workflow definition
191    /// rather than being built programmatically.
192    ///
193    /// # Examples
194    ///
195    /// ```no_run
196    /// use ironflow_core::prelude::*;
197    /// use ironflow_core::provider::AgentConfig;
198    ///
199    /// # async fn example() -> Result<(), OperationError> {
200    /// let provider = ClaudeCodeProvider::new();
201    /// let config = AgentConfig::new("Summarize the README");
202    /// let result = Agent::from_config(config).run(&provider).await?;
203    /// # Ok(())
204    /// # }
205    /// ```
206    pub fn from_config(config: AgentConfig) -> Self {
207        Self {
208            config,
209            dry_run: None,
210            retry_policy: None,
211        }
212    }
213
214    /// Set the system prompt that defines the agent's persona or constraints.
215    pub fn system_prompt(mut self, prompt: &str) -> Self {
216        self.config.system_prompt = Some(prompt.to_string());
217        self
218    }
219
220    /// Set the user prompt - the main instruction sent to the agent.
221    pub fn prompt(mut self, prompt: &str) -> Self {
222        self.config.prompt = prompt.to_string();
223        self
224    }
225
226    /// Set the model to use for this invocation.
227    ///
228    /// Accepts any string-like value. Use [`Model`] constants for well-known
229    /// Claude models, or pass an arbitrary string for custom providers.
230    ///
231    /// Defaults to [`Model::SONNET`] if not called.
232    pub fn model(mut self, model: impl Into<String>) -> Self {
233        self.config.model = model.into();
234        self
235    }
236
237    /// Restrict which tools the agent may invoke.
238    ///
239    /// Pass an empty slice (or do not call this method) to allow the provider
240    /// default set of tools.
241    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
242        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
243        self
244    }
245
246    /// Set the maximum number of agentic turns.
247    ///
248    /// # Panics
249    ///
250    /// Panics if `turns` is `0`.
251    pub fn max_turns(mut self, turns: u32) -> Self {
252        assert!(turns > 0, "max_turns must be greater than 0");
253        self.config.max_turns = Some(turns);
254        self
255    }
256
257    /// Set the maximum spend in USD for this invocation.
258    ///
259    /// # Panics
260    ///
261    /// Panics if `budget` is negative, NaN, or infinity.
262    pub fn max_budget_usd(mut self, budget: f64) -> Self {
263        assert!(
264            budget.is_finite() && budget > 0.0,
265            "budget must be a positive finite number, got {budget}"
266        );
267        self.config.max_budget_usd = Some(budget);
268        self
269    }
270
271    /// Set the working directory for the agent process.
272    pub fn working_dir(mut self, dir: &str) -> Self {
273        self.config.working_dir = Some(dir.to_string());
274        self
275    }
276
277    /// Set the path to an MCP (Model Context Protocol) server configuration file.
278    pub fn mcp_config(mut self, config: &str) -> Self {
279        self.config.mcp_config = Some(config.to_string());
280        self
281    }
282
283    /// Set the permission mode controlling tool-use approval behavior.
284    ///
285    /// See [`PermissionMode`] for details on each variant.
286    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
287        self.config.permission_mode = mode;
288        self
289    }
290
291    /// Request structured (typed) output from the agent.
292    ///
293    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
294    /// to the provider so the model returns JSON conforming to `T`, which can
295    /// then be deserialized with [`AgentResult::json`].
296    ///
297    /// # Examples
298    ///
299    /// ```no_run
300    /// use ironflow_core::prelude::*;
301    ///
302    /// #[derive(Deserialize, JsonSchema)]
303    /// struct Review {
304    ///     score: u8,
305    ///     summary: String,
306    /// }
307    ///
308    /// # async fn example() -> Result<(), OperationError> {
309    /// let provider = ClaudeCodeProvider::new();
310    /// let result = Agent::new()
311    ///     .prompt("Review the codebase")
312    ///     .output::<Review>()
313    ///     .run(&provider)
314    ///     .await?;
315    ///
316    /// let review: Review = result.json().expect("schema-validated output");
317    /// println!("Score: {}/10 - {}", review.score, review.summary);
318    /// # Ok(())
319    /// # }
320    /// ```
321    pub fn output<T: JsonSchema>(mut self) -> Self {
322        let schema = schema_for!(T);
323        self.config.json_schema = match to_string(&schema) {
324            Ok(s) => Some(s),
325            Err(e) => {
326                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
327                None
328            }
329        };
330        self
331    }
332
333    /// Set structured output from a pre-serialized JSON Schema string.
334    ///
335    /// Use this when the schema comes from configuration or another source
336    /// rather than a Rust type. For type-safe schema generation, prefer
337    /// [`output`](Agent::output).
338    ///
339    /// **Important:** structured output requires `max_turns >= 2`. The Claude CLI
340    /// uses the first turn for reasoning and a second turn to produce the
341    /// schema-conforming JSON.
342    ///
343    /// # Examples
344    ///
345    /// ```no_run
346    /// use ironflow_core::prelude::*;
347    ///
348    /// # async fn example() -> Result<(), OperationError> {
349    /// let schema = r#"{"type":"object","properties":{"labels":{"type":"array","items":{"type":"string"}}}}"#;
350    /// let agent = Agent::new()
351    ///     .prompt("Classify this email")
352    ///     .output_schema_raw(schema);
353    /// # Ok(())
354    /// # }
355    /// ```
356    pub fn output_schema_raw(mut self, schema: &str) -> Self {
357        self.config.json_schema = Some(schema.to_string());
358        self
359    }
360
361    /// Retry the agent invocation up to `max_retries` times on transient failures.
362    ///
363    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
364    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
365    ///
366    /// Only transient errors are retried: process failures and timeouts.
367    /// Deterministic errors (prompt too large, schema validation) are never retried.
368    ///
369    /// # Panics
370    ///
371    /// Panics if `max_retries` is `0`.
372    ///
373    /// # Examples
374    ///
375    /// ```no_run
376    /// use ironflow_core::prelude::*;
377    ///
378    /// # async fn example() -> Result<(), OperationError> {
379    /// let provider = ClaudeCodeProvider::new();
380    /// let result = Agent::new()
381    ///     .prompt("Summarize the codebase")
382    ///     .retry(2)
383    ///     .run(&provider)
384    ///     .await?;
385    /// # Ok(())
386    /// # }
387    /// ```
388    pub fn retry(mut self, max_retries: u32) -> Self {
389        self.retry_policy = Some(RetryPolicy::new(max_retries));
390        self
391    }
392
393    /// Set a custom [`RetryPolicy`] for this agent invocation.
394    ///
395    /// Allows full control over backoff duration, multiplier, and max delay.
396    /// See [`RetryPolicy`] for details.
397    ///
398    /// # Examples
399    ///
400    /// ```no_run
401    /// use std::time::Duration;
402    /// use ironflow_core::prelude::*;
403    /// use ironflow_core::retry::RetryPolicy;
404    ///
405    /// # async fn example() -> Result<(), OperationError> {
406    /// let provider = ClaudeCodeProvider::new();
407    /// let result = Agent::new()
408    ///     .prompt("Analyze the code")
409    ///     .retry_policy(
410    ///         RetryPolicy::new(3)
411    ///             .backoff(Duration::from_secs(1))
412    ///             .max_backoff(Duration::from_secs(60))
413    ///     )
414    ///     .run(&provider)
415    ///     .await?;
416    /// # Ok(())
417    /// # }
418    /// ```
419    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
420        self.retry_policy = Some(policy);
421        self
422    }
423
424    /// Enable or disable dry-run mode for this specific operation.
425    ///
426    /// When dry-run is active, the agent call is logged but not executed.
427    /// A synthetic [`AgentResult`] is returned with a placeholder text,
428    /// zero cost, and zero tokens.
429    ///
430    /// If not set, falls back to the global dry-run setting
431    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
432    pub fn dry_run(mut self, enabled: bool) -> Self {
433        self.dry_run = Some(enabled);
434        self
435    }
436
437    /// Enable verbose/debug mode to capture the full conversation trace.
438    ///
439    /// When enabled, the provider captures every assistant message and tool
440    /// call into [`AgentResult::debug_messages`]. Useful for understanding
441    /// why the agent returned an unexpected result.
442    ///
443    /// # Examples
444    ///
445    /// ```no_run
446    /// use ironflow_core::prelude::*;
447    ///
448    /// # async fn example() -> Result<(), OperationError> {
449    /// let provider = ClaudeCodeProvider::new();
450    ///
451    /// let result = Agent::new()
452    ///     .prompt("Analyze src/")
453    ///     .verbose()
454    ///     .max_budget_usd(0.10)
455    ///     .run(&provider)
456    ///     .await?;
457    ///
458    /// if let Some(messages) = result.debug_messages() {
459    ///     for msg in messages {
460    ///         println!("{msg}");
461    ///     }
462    /// }
463    /// # Ok(())
464    /// # }
465    /// ```
466    pub fn verbose(mut self) -> Self {
467        self.config.verbose = true;
468        self
469    }
470
471    /// Resume a previous agent conversation by session ID.
472    ///
473    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
474    /// continue the multi-turn conversation.
475    ///
476    /// # Examples
477    ///
478    /// ```no_run
479    /// use ironflow_core::prelude::*;
480    ///
481    /// # async fn example() -> Result<(), OperationError> {
482    /// let provider = ClaudeCodeProvider::new();
483    ///
484    /// let first = Agent::new()
485    ///     .prompt("Analyze the src/ directory")
486    ///     .max_budget_usd(0.10)
487    ///     .run(&provider)
488    ///     .await?;
489    ///
490    /// let session = first.session_id().expect("provider returned session ID");
491    ///
492    /// let followup = Agent::new()
493    ///     .prompt("Now suggest improvements")
494    ///     .resume(session)
495    ///     .max_budget_usd(0.10)
496    ///     .run(&provider)
497    ///     .await?;
498    /// # Ok(())
499    /// # }
500    /// ```
501    ///
502    /// # Panics
503    ///
504    /// Panics if `session_id` is empty or contains characters other than
505    /// alphanumerics, hyphens, and underscores.
506    pub fn resume(mut self, session_id: &str) -> Self {
507        assert!(!session_id.is_empty(), "session_id must not be empty");
508        assert!(
509            session_id
510                .chars()
511                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
512            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
513        );
514        self.config.resume_session_id = Some(session_id.to_string());
515        self
516    }
517
518    /// Execute the agent invocation using the given [`AgentProvider`].
519    ///
520    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
521    /// failures (process crashes, timeouts) are retried with exponential
522    /// backoff. Deterministic errors (prompt too large, schema validation)
523    /// are returned immediately without retry.
524    ///
525    /// # Errors
526    ///
527    /// Returns [`OperationError::Agent`] if the provider reports a failure
528    /// (process crash, timeout, or schema validation error).
529    ///
530    /// # Panics
531    ///
532    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
533    /// empty (whitespace-only counts as empty).
534    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
535    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
536        assert!(
537            !self.config.prompt.trim().is_empty(),
538            "prompt must not be empty - call .prompt(\"...\") before .run()"
539        );
540
541        if crate::dry_run::effective_dry_run(self.dry_run) {
542            info!(
543                prompt_len = self.config.prompt.len(),
544                "[dry-run] agent call skipped"
545            );
546            let mut output =
547                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
548            output.cost_usd = Some(0.0);
549            output.input_tokens = Some(0);
550            output.output_tokens = Some(0);
551            return Ok(AgentResult { output });
552        }
553
554        let result = self.invoke_once(provider).await;
555
556        let policy = match &self.retry_policy {
557            Some(p) => p,
558            None => return result,
559        };
560
561        // Non-retryable errors are returned immediately.
562        if let Err(ref err) = result {
563            if !crate::retry::is_retryable(err) {
564                return result;
565            }
566        } else {
567            return result;
568        }
569
570        let mut last_result = result;
571
572        for attempt in 0..policy.max_retries {
573            let delay = policy.delay_for_attempt(attempt);
574            warn!(
575                attempt = attempt + 1,
576                max_retries = policy.max_retries,
577                delay_ms = delay.as_millis() as u64,
578                "retrying agent invocation"
579            );
580            time::sleep(delay).await;
581
582            last_result = self.invoke_once(provider).await;
583
584            match &last_result {
585                Ok(_) => return last_result,
586                Err(err) if !crate::retry::is_retryable(err) => return last_result,
587                _ => {}
588            }
589        }
590
591        last_result
592    }
593
594    /// Execute a single agent invocation attempt (no retry logic).
595    async fn invoke_once(
596        &self,
597        provider: &dyn AgentProvider,
598    ) -> Result<AgentResult, OperationError> {
599        #[cfg(feature = "prometheus")]
600        let model_label = self.config.model.to_string();
601
602        let output = match provider.invoke(&self.config).await {
603            Ok(output) => output,
604            Err(e) => {
605                #[cfg(feature = "prometheus")]
606                {
607                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
608                }
609                return Err(OperationError::Agent(e));
610            }
611        };
612
613        info!(
614            duration_ms = output.duration_ms,
615            cost_usd = output.cost_usd,
616            input_tokens = output.input_tokens,
617            output_tokens = output.output_tokens,
618            model = output.model,
619            "agent completed"
620        );
621
622        #[cfg(feature = "prometheus")]
623        {
624            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
625            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
626                .record(output.duration_ms as f64 / 1000.0);
627            if let Some(cost) = output.cost_usd {
628                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
629                    .increment(cost);
630            }
631            if let Some(tokens) = output.input_tokens {
632                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
633            }
634            if let Some(tokens) = output.output_tokens {
635                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
636                    .increment(tokens);
637            }
638        }
639
640        Ok(AgentResult { output })
641    }
642}
643
644impl Default for Agent {
645    fn default() -> Self {
646        Self::new()
647    }
648}
649
650/// The result of a successful agent invocation.
651///
652/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
653/// response text, typed JSON deserialization, session metadata, and usage stats.
654#[derive(Debug)]
655pub struct AgentResult {
656    output: AgentOutput,
657}
658
659impl AgentResult {
660    /// Return the agent's response as a plain text string.
661    ///
662    /// If the underlying value is not a JSON string (e.g. when structured output
663    /// was requested), returns an empty string and logs a warning.
664    pub fn text(&self) -> &str {
665        match self.output.value.as_str() {
666            Some(s) => s,
667            None => {
668                warn!(
669                    value_type = self.output.value.to_string(),
670                    "agent output is not a string, returning empty"
671                );
672                ""
673            }
674        }
675    }
676
677    /// Return the raw JSON [`Value`] of the agent's response.
678    pub fn value(&self) -> &Value {
679        &self.output.value
680    }
681
682    /// Deserialize the agent's response into the given type `T`.
683    ///
684    /// This clones the underlying JSON value. If you no longer need the
685    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
686    /// instead to avoid the clone.
687    ///
688    /// # Errors
689    ///
690    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
691    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
692        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
693    }
694
695    /// Consume the result and deserialize the response into `T` without cloning.
696    ///
697    /// # Errors
698    ///
699    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
700    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
701        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
702    }
703
704    /// Build an `AgentResult` from a raw [`AgentOutput`].
705    ///
706    /// This is available only in test builds to simplify test setup without
707    /// going through the full record/replay pipeline.
708    #[cfg(test)]
709    pub(crate) fn from_output(output: AgentOutput) -> Self {
710        Self { output }
711    }
712
713    /// Return the provider-assigned session ID, if available.
714    pub fn session_id(&self) -> Option<&str> {
715        self.output.session_id.as_deref()
716    }
717
718    /// Return the cost of this invocation in USD, if reported by the provider.
719    pub fn cost_usd(&self) -> Option<f64> {
720        self.output.cost_usd
721    }
722
723    /// Return the number of input tokens consumed, if reported.
724    pub fn input_tokens(&self) -> Option<u64> {
725        self.output.input_tokens
726    }
727
728    /// Return the number of output tokens generated, if reported.
729    pub fn output_tokens(&self) -> Option<u64> {
730        self.output.output_tokens
731    }
732
733    /// Return the wall-clock duration of the invocation in milliseconds.
734    pub fn duration_ms(&self) -> u64 {
735        self.output.duration_ms
736    }
737
738    /// Return the concrete model identifier used, if reported by the provider.
739    pub fn model(&self) -> Option<&str> {
740        self.output.model.as_deref()
741    }
742
743    /// Return the conversation trace captured during a verbose invocation.
744    ///
745    /// Returns `None` when [`Agent::verbose`] was not called. When present,
746    /// each [`DebugMessage`] contains the
747    /// assistant's text and tool calls for one conversation turn.
748    pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
749        self.output.debug_messages.as_deref()
750    }
751}
752
753#[cfg(test)]
754mod tests {
755    use super::*;
756    use crate::error::AgentError;
757    use crate::provider::InvokeFuture;
758    use serde_json::json;
759
760    struct TestProvider {
761        output: AgentOutput,
762    }
763
764    impl AgentProvider for TestProvider {
765        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
766            Box::pin(async move {
767                Ok(AgentOutput {
768                    value: self.output.value.clone(),
769                    session_id: self.output.session_id.clone(),
770                    cost_usd: self.output.cost_usd,
771                    input_tokens: self.output.input_tokens,
772                    output_tokens: self.output.output_tokens,
773                    model: self.output.model.clone(),
774                    duration_ms: self.output.duration_ms,
775                    debug_messages: None,
776                })
777            })
778        }
779    }
780
781    struct ConfigCapture {
782        output: AgentOutput,
783    }
784
785    impl AgentProvider for ConfigCapture {
786        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
787            let config_json = serde_json::to_value(config).unwrap();
788            Box::pin(async move {
789                Ok(AgentOutput {
790                    value: config_json,
791                    session_id: self.output.session_id.clone(),
792                    cost_usd: self.output.cost_usd,
793                    input_tokens: self.output.input_tokens,
794                    output_tokens: self.output.output_tokens,
795                    model: self.output.model.clone(),
796                    duration_ms: self.output.duration_ms,
797                    debug_messages: None,
798                })
799            })
800        }
801    }
802
803    fn default_output() -> AgentOutput {
804        AgentOutput {
805            value: json!("test output"),
806            session_id: Some("sess-123".to_string()),
807            cost_usd: Some(0.05),
808            input_tokens: Some(100),
809            output_tokens: Some(50),
810            model: Some("sonnet".to_string()),
811            duration_ms: 1500,
812            debug_messages: None,
813        }
814    }
815
816    // --- Model constants ---
817
818    #[test]
819    fn model_constants_have_expected_values() {
820        assert_eq!(Model::SONNET, "sonnet");
821        assert_eq!(Model::OPUS, "opus");
822        assert_eq!(Model::HAIKU, "haiku");
823        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
824        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
825        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
826        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
827        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
828    }
829
830    // --- Agent::new() defaults via ConfigCapture ---
831
832    #[tokio::test]
833    async fn agent_new_default_values() {
834        let provider = ConfigCapture {
835            output: default_output(),
836        };
837        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
838
839        let config = result.value();
840        assert_eq!(config["system_prompt"], json!(null));
841        assert_eq!(config["prompt"], json!("hi"));
842        assert_eq!(config["model"], json!("sonnet"));
843        assert_eq!(config["allowed_tools"], json!([]));
844        assert_eq!(config["max_turns"], json!(null));
845        assert_eq!(config["max_budget_usd"], json!(null));
846        assert_eq!(config["working_dir"], json!(null));
847        assert_eq!(config["mcp_config"], json!(null));
848        assert_eq!(config["permission_mode"], json!("Default"));
849        assert_eq!(config["json_schema"], json!(null));
850    }
851
852    #[tokio::test]
853    async fn agent_default_matches_new() {
854        let provider = ConfigCapture {
855            output: default_output(),
856        };
857        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
858        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
859
860        assert_eq!(result_new.value(), result_default.value());
861    }
862
863    // --- Builder methods ---
864
865    #[tokio::test]
866    async fn builder_methods_store_values_correctly() {
867        let provider = ConfigCapture {
868            output: default_output(),
869        };
870        let result = Agent::new()
871            .system_prompt("you are a bot")
872            .prompt("do something")
873            .model(Model::OPUS)
874            .allowed_tools(&["Read", "Write"])
875            .max_turns(5)
876            .max_budget_usd(1.5)
877            .working_dir("/tmp")
878            .mcp_config("{}")
879            .permission_mode(PermissionMode::Auto)
880            .run(&provider)
881            .await
882            .unwrap();
883
884        let config = result.value();
885        assert_eq!(config["system_prompt"], json!("you are a bot"));
886        assert_eq!(config["prompt"], json!("do something"));
887        assert_eq!(config["model"], json!("opus"));
888        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
889        assert_eq!(config["max_turns"], json!(5));
890        assert_eq!(config["max_budget_usd"], json!(1.5));
891        assert_eq!(config["working_dir"], json!("/tmp"));
892        assert_eq!(config["mcp_config"], json!("{}"));
893        assert_eq!(config["permission_mode"], json!("Auto"));
894    }
895
896    // --- Panics ---
897
898    #[test]
899    #[should_panic(expected = "max_turns must be greater than 0")]
900    fn max_turns_zero_panics() {
901        let _ = Agent::new().max_turns(0);
902    }
903
904    #[test]
905    #[should_panic(expected = "budget must be a positive finite number")]
906    fn max_budget_negative_panics() {
907        let _ = Agent::new().max_budget_usd(-1.0);
908    }
909
910    #[test]
911    #[should_panic(expected = "budget must be a positive finite number")]
912    fn max_budget_nan_panics() {
913        let _ = Agent::new().max_budget_usd(f64::NAN);
914    }
915
916    #[test]
917    #[should_panic(expected = "budget must be a positive finite number")]
918    fn max_budget_infinity_panics() {
919        let _ = Agent::new().max_budget_usd(f64::INFINITY);
920    }
921
922    // --- AgentResult accessors ---
923
924    #[tokio::test]
925    async fn agent_result_text_with_string_value() {
926        let provider = TestProvider {
927            output: AgentOutput {
928                value: json!("hello world"),
929                ..default_output()
930            },
931        };
932        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
933        assert_eq!(result.text(), "hello world");
934    }
935
936    #[tokio::test]
937    async fn agent_result_text_with_non_string_value() {
938        let provider = TestProvider {
939            output: AgentOutput {
940                value: json!(42),
941                ..default_output()
942            },
943        };
944        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
945        assert_eq!(result.text(), "");
946    }
947
948    #[tokio::test]
949    async fn agent_result_text_with_null_value() {
950        let provider = TestProvider {
951            output: AgentOutput {
952                value: json!(null),
953                ..default_output()
954            },
955        };
956        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
957        assert_eq!(result.text(), "");
958    }
959
960    #[tokio::test]
961    async fn agent_result_json_successful_deserialize() {
962        #[derive(Deserialize, PartialEq, Debug)]
963        struct MyOutput {
964            name: String,
965            count: u32,
966        }
967        let provider = TestProvider {
968            output: AgentOutput {
969                value: json!({"name": "test", "count": 7}),
970                ..default_output()
971            },
972        };
973        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
974        let parsed: MyOutput = result.json().unwrap();
975        assert_eq!(parsed.name, "test");
976        assert_eq!(parsed.count, 7);
977    }
978
979    #[tokio::test]
980    async fn agent_result_json_failed_deserialize() {
981        #[derive(Debug, Deserialize)]
982        #[allow(dead_code)]
983        struct MyOutput {
984            name: String,
985        }
986        let provider = TestProvider {
987            output: AgentOutput {
988                value: json!(42),
989                ..default_output()
990            },
991        };
992        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
993        let err = result.json::<MyOutput>().unwrap_err();
994        assert!(matches!(err, OperationError::Deserialize { .. }));
995    }
996
997    #[tokio::test]
998    async fn agent_result_accessors() {
999        let provider = TestProvider {
1000            output: AgentOutput {
1001                value: json!("v"),
1002                session_id: Some("s-1".to_string()),
1003                cost_usd: Some(0.123),
1004                input_tokens: Some(999),
1005                output_tokens: Some(456),
1006                model: Some("opus".to_string()),
1007                duration_ms: 2000,
1008                debug_messages: None,
1009            },
1010        };
1011        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1012        assert_eq!(result.session_id(), Some("s-1"));
1013        assert_eq!(result.cost_usd(), Some(0.123));
1014        assert_eq!(result.input_tokens(), Some(999));
1015        assert_eq!(result.output_tokens(), Some(456));
1016        assert_eq!(result.duration_ms(), 2000);
1017        assert_eq!(result.model(), Some("opus"));
1018    }
1019
1020    // --- Session resume ---
1021
1022    #[tokio::test]
1023    async fn resume_passes_session_id_in_config() {
1024        let provider = ConfigCapture {
1025            output: default_output(),
1026        };
1027        let result = Agent::new()
1028            .prompt("followup")
1029            .resume("sess-abc")
1030            .run(&provider)
1031            .await
1032            .unwrap();
1033
1034        let config = result.value();
1035        assert_eq!(config["resume_session_id"], json!("sess-abc"));
1036    }
1037
1038    #[tokio::test]
1039    async fn no_resume_has_null_session_id() {
1040        let provider = ConfigCapture {
1041            output: default_output(),
1042        };
1043        let result = Agent::new()
1044            .prompt("first call")
1045            .run(&provider)
1046            .await
1047            .unwrap();
1048
1049        let config = result.value();
1050        assert_eq!(config["resume_session_id"], json!(null));
1051    }
1052
1053    #[test]
1054    #[should_panic(expected = "session_id must not be empty")]
1055    fn resume_empty_session_id_panics() {
1056        let _ = Agent::new().resume("");
1057    }
1058
1059    #[test]
1060    #[should_panic(expected = "session_id must only contain")]
1061    fn resume_invalid_chars_panics() {
1062        let _ = Agent::new().resume("sess;rm -rf /");
1063    }
1064
1065    #[test]
1066    fn resume_valid_formats_accepted() {
1067        let _ = Agent::new().resume("sess-abc123");
1068        let _ = Agent::new().resume("a1b2c3d4_session");
1069        let _ = Agent::new().resume("abc-DEF-123_456");
1070    }
1071
1072    #[tokio::test]
1073    #[should_panic(expected = "prompt must not be empty")]
1074    async fn run_without_prompt_panics() {
1075        let provider = TestProvider {
1076            output: default_output(),
1077        };
1078        let _ = Agent::new().run(&provider).await;
1079    }
1080
1081    #[tokio::test]
1082    #[should_panic(expected = "prompt must not be empty")]
1083    async fn run_with_whitespace_only_prompt_panics() {
1084        let provider = TestProvider {
1085            output: default_output(),
1086        };
1087        let _ = Agent::new().prompt("   ").run(&provider).await;
1088    }
1089
1090    // --- Model accepts arbitrary strings ---
1091
1092    #[tokio::test]
1093    async fn model_accepts_custom_string() {
1094        let provider = ConfigCapture {
1095            output: default_output(),
1096        };
1097        let result = Agent::new()
1098            .prompt("hi")
1099            .model("mistral-large-latest")
1100            .run(&provider)
1101            .await
1102            .unwrap();
1103        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
1104    }
1105
1106    #[tokio::test]
1107    async fn verbose_sets_config_flag() {
1108        let provider = ConfigCapture {
1109            output: default_output(),
1110        };
1111        let result = Agent::new()
1112            .prompt("hi")
1113            .verbose()
1114            .run(&provider)
1115            .await
1116            .unwrap();
1117        assert_eq!(result.value()["verbose"], json!(true));
1118    }
1119
1120    #[tokio::test]
1121    async fn verbose_not_set_by_default() {
1122        let provider = ConfigCapture {
1123            output: default_output(),
1124        };
1125        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
1126        assert_eq!(result.value()["verbose"], json!(false));
1127    }
1128
1129    #[tokio::test]
1130    async fn debug_messages_none_without_verbose() {
1131        let provider = TestProvider {
1132            output: default_output(),
1133        };
1134        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1135        assert!(result.debug_messages().is_none());
1136    }
1137
1138    #[tokio::test]
1139    async fn model_accepts_owned_string() {
1140        let provider = ConfigCapture {
1141            output: default_output(),
1142        };
1143        let model_name = String::from("gpt-4o");
1144        let result = Agent::new()
1145            .prompt("hi")
1146            .model(model_name)
1147            .run(&provider)
1148            .await
1149            .unwrap();
1150        assert_eq!(result.value()["model"], json!("gpt-4o"));
1151    }
1152
1153    #[tokio::test]
1154    async fn into_json_success() {
1155        #[derive(Deserialize, PartialEq, Debug)]
1156        struct Out {
1157            name: String,
1158        }
1159        let provider = TestProvider {
1160            output: AgentOutput {
1161                value: json!({"name": "test"}),
1162                ..default_output()
1163            },
1164        };
1165        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1166        let parsed: Out = result.into_json().unwrap();
1167        assert_eq!(parsed.name, "test");
1168    }
1169
1170    #[tokio::test]
1171    async fn into_json_failure() {
1172        #[derive(Debug, Deserialize)]
1173        #[allow(dead_code)]
1174        struct Out {
1175            name: String,
1176        }
1177        let provider = TestProvider {
1178            output: AgentOutput {
1179                value: json!(42),
1180                ..default_output()
1181            },
1182        };
1183        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1184        let err = result.into_json::<Out>().unwrap_err();
1185        assert!(matches!(err, OperationError::Deserialize { .. }));
1186    }
1187
1188    #[test]
1189    fn from_output_creates_result() {
1190        let output = AgentOutput {
1191            value: json!("hello"),
1192            ..default_output()
1193        };
1194        let result = AgentResult::from_output(output);
1195        assert_eq!(result.text(), "hello");
1196        assert_eq!(result.cost_usd(), Some(0.05));
1197    }
1198
1199    #[test]
1200    #[should_panic(expected = "budget must be a positive finite number")]
1201    fn max_budget_zero_panics() {
1202        let _ = Agent::new().max_budget_usd(0.0);
1203    }
1204
1205    #[test]
1206    fn model_constant_equality() {
1207        assert_eq!(Model::SONNET, "sonnet");
1208        assert_ne!(Model::SONNET, Model::OPUS);
1209    }
1210
1211    #[test]
1212    fn permission_mode_serialize_deserialize_roundtrip() {
1213        for mode in [
1214            PermissionMode::Default,
1215            PermissionMode::Auto,
1216            PermissionMode::DontAsk,
1217            PermissionMode::BypassPermissions,
1218        ] {
1219            let json = to_string(&mode).unwrap();
1220            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1221            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1222        }
1223    }
1224
1225    // --- Retry builder ---
1226
1227    #[test]
1228    fn retry_builder_stores_policy() {
1229        let agent = Agent::new().retry(3);
1230        assert!(agent.retry_policy.is_some());
1231        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1232    }
1233
1234    #[test]
1235    fn retry_policy_builder_stores_custom_policy() {
1236        use crate::retry::RetryPolicy;
1237        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1238        let agent = Agent::new().retry_policy(policy);
1239        let p = agent.retry_policy.unwrap();
1240        assert_eq!(p.max_retries(), 5);
1241    }
1242
1243    #[test]
1244    fn no_retry_by_default() {
1245        let agent = Agent::new();
1246        assert!(agent.retry_policy.is_none());
1247    }
1248
1249    // --- Retry behavior ---
1250
1251    use std::sync::Arc;
1252    use std::sync::atomic::{AtomicU32, Ordering};
1253    use std::time::Duration;
1254
1255    struct FailNTimesProvider {
1256        fail_count: AtomicU32,
1257        failures_before_success: u32,
1258        output: AgentOutput,
1259    }
1260
1261    impl AgentProvider for FailNTimesProvider {
1262        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1263            Box::pin(async move {
1264                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1265                if current < self.failures_before_success {
1266                    Err(AgentError::ProcessFailed {
1267                        exit_code: 1,
1268                        stderr: format!("transient failure #{}", current + 1),
1269                    })
1270                } else {
1271                    Ok(AgentOutput {
1272                        value: self.output.value.clone(),
1273                        session_id: self.output.session_id.clone(),
1274                        cost_usd: self.output.cost_usd,
1275                        input_tokens: self.output.input_tokens,
1276                        output_tokens: self.output.output_tokens,
1277                        model: self.output.model.clone(),
1278                        duration_ms: self.output.duration_ms,
1279                        debug_messages: None,
1280                    })
1281                }
1282            })
1283        }
1284    }
1285
1286    #[tokio::test]
1287    async fn retry_succeeds_after_transient_failures() {
1288        let provider = FailNTimesProvider {
1289            fail_count: AtomicU32::new(0),
1290            failures_before_success: 2,
1291            output: default_output(),
1292        };
1293        let result = Agent::new()
1294            .prompt("test")
1295            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1296            .run(&provider)
1297            .await;
1298
1299        assert!(result.is_ok());
1300        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1301    }
1302
1303    #[tokio::test]
1304    async fn retry_exhausted_returns_last_error() {
1305        let provider = FailNTimesProvider {
1306            fail_count: AtomicU32::new(0),
1307            failures_before_success: 10, // always fails
1308            output: default_output(),
1309        };
1310        let result = Agent::new()
1311            .prompt("test")
1312            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1313            .run(&provider)
1314            .await;
1315
1316        assert!(result.is_err());
1317        // 1 initial + 2 retries = 3 total
1318        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1319    }
1320
1321    #[tokio::test]
1322    async fn retry_does_not_retry_non_retryable_errors() {
1323        let call_count = Arc::new(AtomicU32::new(0));
1324        let count = call_count.clone();
1325
1326        struct CountingNonRetryable {
1327            count: Arc<AtomicU32>,
1328        }
1329        impl AgentProvider for CountingNonRetryable {
1330            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1331                self.count.fetch_add(1, Ordering::SeqCst);
1332                Box::pin(async move {
1333                    Err(AgentError::SchemaValidation {
1334                        expected: "object".to_string(),
1335                        got: "string".to_string(),
1336                    })
1337                })
1338            }
1339        }
1340
1341        let provider = CountingNonRetryable { count };
1342        let result = Agent::new()
1343            .prompt("test")
1344            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1345            .run(&provider)
1346            .await;
1347
1348        assert!(result.is_err());
1349        // Only 1 attempt, no retries for non-retryable errors
1350        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1351    }
1352
1353    #[tokio::test]
1354    async fn no_retry_without_policy() {
1355        let provider = FailNTimesProvider {
1356            fail_count: AtomicU32::new(0),
1357            failures_before_success: 1,
1358            output: default_output(),
1359        };
1360        let result = Agent::new().prompt("test").run(&provider).await;
1361
1362        assert!(result.is_err());
1363        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1364    }
1365}