Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::SONNET)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::any;
29use std::sync::Arc;
30
31use schemars::{JsonSchema, schema_for};
32use serde::de::DeserializeOwned;
33use serde::{Deserialize, Serialize};
34use serde_json::{Value, from_value, to_string};
35use tokio::time;
36use tracing::{info, warn};
37
38use crate::error::OperationError;
39#[cfg(feature = "prometheus")]
40use crate::metric_names;
41use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, LogSink};
42use crate::retry::RetryPolicy;
43
/// Well-known model identifiers that work with any provider.
///
/// The associated constants name common Claude models. They are ordinary
/// string slices, and [`Agent::model`] accepts any string at all, so a custom
/// [`AgentProvider`] implementation decides for itself how (or whether) to
/// interpret an identifier.
///
/// # Examples
///
/// ```no_run
/// use ironflow_core::prelude::*;
///
/// # async fn example() -> Result<(), OperationError> {
/// let provider = ClaudeCodeProvider::new();
///
/// // A built-in identifier:
/// let _built_in = Agent::new()
///     .prompt("hi")
///     .model(Model::SONNET)
///     .run(&provider)
///     .await?;
///
/// // Any other identifier string:
/// let _custom = Agent::new()
///     .prompt("hi")
///     .model("mistral-large-latest")
///     .run(&provider)
///     .await?;
/// # Ok(())
/// # }
/// ```
pub struct Model;

impl Model {
    /* ---- Family aliases: the CLI resolves each to the current version ---- */

    /// Claude Opus family alias - highest capability.
    pub const OPUS: &str = "opus";
    /// Claude Sonnet family alias - balanced speed and capability; this is the
    /// model an [`Agent`] uses by default.
    pub const SONNET: &str = "sonnet";
    /// Claude Haiku family alias - fastest and cheapest.
    pub const HAIKU: &str = "haiku";

    /* ---- Opus, pinned versions ---- */

    /// Claude Opus 4.8 - latest flagship, with a native 1M token context.
    pub const OPUS_48: &str = "claude-opus-4-8";
    /// Claude Opus 4.8, requesting the 1M token context window explicitly.
    pub const OPUS_48_1M: &str = "claude-opus-4-8[1m]";
    /// Claude Opus 4.7 - previous flagship, with a native 1M token context.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// Claude Opus 4.7, requesting the 1M token context window explicitly.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
    /// Claude Opus 4.6 (200K token context).
    pub const OPUS_46: &str = "claude-opus-4-6";
    /// Claude Opus 4.6 with the 1M token context window.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";

    /* ---- Sonnet, pinned versions ---- */

    /// Claude Sonnet 4.6 (200K token context).
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Claude Sonnet 4.6 with the 1M token context window.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";

    /* ---- Haiku, pinned versions ---- */

    /// Claude Haiku 4.5.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";
}
119
120/// Controls how the agent handles tool-use permission prompts.
121///
122/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
123/// flags in the Claude CLI.
124#[derive(Debug, Default, Clone, Copy, Serialize)]
125pub enum PermissionMode {
126    /// Use the CLI default permission behavior.
127    #[default]
128    Default,
129    /// Automatically approve tool-use requests.
130    Auto,
131    /// Suppress all permission prompts (the agent proceeds without asking).
132    DontAsk,
133    /// Skip all permission checks entirely.
134    ///
135    /// **Warning**: the agent will have unrestricted filesystem and shell access.
136    BypassPermissions,
137}
138
139impl<'de> Deserialize<'de> for PermissionMode {
140    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
141    where
142        D: serde::Deserializer<'de>,
143    {
144        let s = String::deserialize(deserializer)?;
145        Ok(match s.to_lowercase().replace('_', "").as_str() {
146            "auto" => Self::Auto,
147            "dontask" => Self::DontAsk,
148            "bypass" | "bypasspermissions" => Self::BypassPermissions,
149            _ => Self::Default,
150        })
151    }
152}
153
154/// Builder for a single agent invocation.
155///
156/// Create with [`Agent::new`], chain configuration methods, then call
157/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
158///
159/// # Examples
160///
161/// ```no_run
162/// use ironflow_core::prelude::*;
163///
164/// # async fn example() -> Result<(), OperationError> {
165/// let provider = ClaudeCodeProvider::new();
166///
167/// let result = Agent::new()
168///     .system_prompt("You are a Rust expert.")
169///     .prompt("Review this code for safety issues.")
170///     .model(Model::OPUS)
171///     .allowed_tools(&["Read", "Grep"])
172///     .max_turns(5)
173///     .max_budget_usd(0.50)
174///     .working_dir("/tmp/project")
175///     .permission_mode(PermissionMode::Auto)
176///     .run(&provider)
177///     .await?;
178///
179/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
180/// # Ok(())
181/// # }
182/// ```
183#[must_use = "an Agent does nothing until .run() is awaited"]
184pub struct Agent {
185    config: AgentConfig,
186    dry_run: Option<bool>,
187    retry_policy: Option<RetryPolicy>,
188    log_sink: Option<Arc<dyn LogSink>>,
189}
190
191impl Agent {
192    /// Create a new agent builder with default settings.
193    ///
194    /// Defaults: [`Model::SONNET`], no system prompt, no tool restrictions,
195    /// no budget/turn limits, [`PermissionMode::Default`].
196    pub fn new() -> Self {
197        Self {
198            config: AgentConfig::new(""),
199            dry_run: None,
200            retry_policy: None,
201            log_sink: None,
202        }
203    }
204
205    /// Create an agent builder from an existing [`AgentConfig`].
206    ///
207    /// Useful when the config comes from a serialized workflow definition
208    /// rather than being built programmatically.
209    ///
210    /// # Examples
211    ///
212    /// ```no_run
213    /// use ironflow_core::prelude::*;
214    /// use ironflow_core::provider::AgentConfig;
215    ///
216    /// # async fn example() -> Result<(), OperationError> {
217    /// let provider = ClaudeCodeProvider::new();
218    /// let config = AgentConfig::new("Summarize the README");
219    /// let result = Agent::from_config(config).run(&provider).await?;
220    /// # Ok(())
221    /// # }
222    /// ```
223    pub fn from_config(config: impl Into<AgentConfig>) -> Self {
224        Self {
225            config: config.into(),
226            dry_run: None,
227            retry_policy: None,
228            log_sink: None,
229        }
230    }
231
232    /// Set the system prompt that defines the agent's persona or constraints.
233    pub fn system_prompt(mut self, prompt: &str) -> Self {
234        self.config.system_prompt = Some(prompt.to_string());
235        self
236    }
237
238    /// Set the user prompt - the main instruction sent to the agent.
239    pub fn prompt(mut self, prompt: &str) -> Self {
240        self.config.prompt = prompt.to_string();
241        self
242    }
243
244    /// Set the model to use for this invocation.
245    ///
246    /// Accepts any string-like value. Use [`Model`] constants for well-known
247    /// Claude models, or pass an arbitrary string for custom providers.
248    ///
249    /// Defaults to [`Model::SONNET`] if not called.
250    pub fn model(mut self, model: impl Into<String>) -> Self {
251        self.config.model = model.into();
252        self
253    }
254
255    /// Restrict which tools the agent may invoke.
256    ///
257    /// Pass an empty slice (or do not call this method) to allow the provider
258    /// default set of tools.
259    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
260        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
261        self
262    }
263
264    /// Set the maximum number of agentic turns.
265    ///
266    /// # Panics
267    ///
268    /// Panics if `turns` is `0`.
269    pub fn max_turns(mut self, turns: u32) -> Self {
270        assert!(turns > 0, "max_turns must be greater than 0");
271        self.config.max_turns = Some(turns);
272        self
273    }
274
275    /// Set the maximum spend in USD for this invocation.
276    ///
277    /// # Panics
278    ///
279    /// Panics if `budget` is negative, NaN, or infinity.
280    pub fn max_budget_usd(mut self, budget: f64) -> Self {
281        assert!(
282            budget.is_finite() && budget > 0.0,
283            "budget must be a positive finite number, got {budget}"
284        );
285        self.config.max_budget_usd = Some(budget);
286        self
287    }
288
289    /// Set the working directory for the agent process.
290    pub fn working_dir(mut self, dir: &str) -> Self {
291        self.config.working_dir = Some(dir.to_string());
292        self
293    }
294
295    /// Set the path to an MCP (Model Context Protocol) server configuration file.
296    pub fn mcp_config(mut self, config: &str) -> Self {
297        self.config.mcp_config = Some(config.to_string());
298        self
299    }
300
301    /// Set the permission mode controlling tool-use approval behavior.
302    ///
303    /// See [`PermissionMode`] for details on each variant.
304    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
305        self.config.permission_mode = mode;
306        self
307    }
308
309    /// Request structured (typed) output from the agent.
310    ///
311    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
312    /// to the provider so the model returns JSON conforming to `T`, which can
313    /// then be deserialized with [`AgentResult::json`].
314    ///
315    /// # Examples
316    ///
317    /// ```no_run
318    /// use ironflow_core::prelude::*;
319    ///
320    /// #[derive(Deserialize, JsonSchema)]
321    /// struct Review {
322    ///     score: u8,
323    ///     summary: String,
324    /// }
325    ///
326    /// # async fn example() -> Result<(), OperationError> {
327    /// let provider = ClaudeCodeProvider::new();
328    /// let result = Agent::new()
329    ///     .prompt("Review the codebase")
330    ///     .output::<Review>()
331    ///     .run(&provider)
332    ///     .await?;
333    ///
334    /// let review: Review = result.json().expect("schema-validated output");
335    /// println!("Score: {}/10 - {}", review.score, review.summary);
336    /// # Ok(())
337    /// # }
338    /// ```
339    pub fn output<T: JsonSchema>(mut self) -> Self {
340        let schema = schema_for!(T);
341        self.config.json_schema = match to_string(&schema) {
342            Ok(s) => Some(s),
343            Err(e) => {
344                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
345                None
346            }
347        };
348        self
349    }
350
351    /// Set structured output from a pre-serialized JSON Schema string.
352    ///
353    /// Use this when the schema comes from configuration or another source
354    /// rather than a Rust type. For type-safe schema generation, prefer
355    /// [`output`](Agent::output).
356    ///
357    /// **Important:** structured output requires `max_turns >= 2`. The Claude CLI
358    /// uses the first turn for reasoning and a second turn to produce the
359    /// schema-conforming JSON.
360    ///
361    /// # Examples
362    ///
363    /// ```no_run
364    /// use ironflow_core::prelude::*;
365    ///
366    /// # async fn example() -> Result<(), OperationError> {
367    /// let schema = r#"{"type":"object","properties":{"labels":{"type":"array","items":{"type":"string"}}}}"#;
368    /// let agent = Agent::new()
369    ///     .prompt("Classify this email")
370    ///     .output_schema_raw(schema);
371    /// # Ok(())
372    /// # }
373    /// ```
374    pub fn output_schema_raw(mut self, schema: &str) -> Self {
375        self.config.json_schema = Some(schema.to_string());
376        self
377    }
378
379    /// Retry the agent invocation up to `max_retries` times on transient failures.
380    ///
381    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
382    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
383    ///
384    /// Only transient errors are retried: process failures and timeouts.
385    /// Deterministic errors (prompt too large, schema validation) are never retried.
386    ///
387    /// # Panics
388    ///
389    /// Panics if `max_retries` is `0`.
390    ///
391    /// # Examples
392    ///
393    /// ```no_run
394    /// use ironflow_core::prelude::*;
395    ///
396    /// # async fn example() -> Result<(), OperationError> {
397    /// let provider = ClaudeCodeProvider::new();
398    /// let result = Agent::new()
399    ///     .prompt("Summarize the codebase")
400    ///     .retry(2)
401    ///     .run(&provider)
402    ///     .await?;
403    /// # Ok(())
404    /// # }
405    /// ```
406    pub fn retry(mut self, max_retries: u32) -> Self {
407        self.retry_policy = Some(RetryPolicy::new(max_retries));
408        self
409    }
410
411    /// Set a custom [`RetryPolicy`] for this agent invocation.
412    ///
413    /// Allows full control over backoff duration, multiplier, and max delay.
414    /// See [`RetryPolicy`] for details.
415    ///
416    /// # Examples
417    ///
418    /// ```no_run
419    /// use std::time::Duration;
420    /// use ironflow_core::prelude::*;
421    /// use ironflow_core::retry::RetryPolicy;
422    ///
423    /// # async fn example() -> Result<(), OperationError> {
424    /// let provider = ClaudeCodeProvider::new();
425    /// let result = Agent::new()
426    ///     .prompt("Analyze the code")
427    ///     .retry_policy(
428    ///         RetryPolicy::new(3)
429    ///             .backoff(Duration::from_secs(1))
430    ///             .max_backoff(Duration::from_secs(60))
431    ///     )
432    ///     .run(&provider)
433    ///     .await?;
434    /// # Ok(())
435    /// # }
436    /// ```
437    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
438        self.retry_policy = Some(policy);
439        self
440    }
441
442    /// Enable or disable dry-run mode for this specific operation.
443    ///
444    /// When dry-run is active, the agent call is logged but not executed.
445    /// A synthetic [`AgentResult`] is returned with a placeholder text,
446    /// zero cost, and zero tokens.
447    ///
448    /// If not set, falls back to the global dry-run setting
449    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
450    pub fn dry_run(mut self, enabled: bool) -> Self {
451        self.dry_run = Some(enabled);
452        self
453    }
454
455    /// Attach a [`LogSink`] for real-time log streaming.
456    ///
457    /// When set, [`invoke_with_logs`](AgentProvider::invoke_with_logs) is called
458    /// instead of [`invoke`](AgentProvider::invoke), allowing providers that
459    /// support streaming to pipe output lines in real time.
460    ///
461    /// # Examples
462    ///
463    /// ```no_run
464    /// use std::sync::Arc;
465    /// use ironflow_core::prelude::*;
466    ///
467    /// # async fn example() -> Result<(), OperationError> {
468    /// # struct MySink;
469    /// # impl LogSink for MySink { fn log(&self, _: &str, _: &str) {} }
470    /// let provider = ClaudeCodeProvider::new();
471    /// let sink: Arc<dyn LogSink> = Arc::new(MySink);
472    ///
473    /// let result = Agent::new()
474    ///     .prompt("Analyze src/")
475    ///     .log_sink(sink)
476    ///     .run(&provider)
477    ///     .await?;
478    /// # Ok(())
479    /// # }
480    /// ```
481    pub fn log_sink(mut self, sink: Arc<dyn LogSink>) -> Self {
482        self.log_sink = Some(sink);
483        self
484    }
485
486    /// Enable verbose/debug mode to capture the full conversation trace.
487    ///
488    /// When enabled, the provider captures every assistant message and tool
489    /// call into [`AgentResult::debug_messages`]. Useful for understanding
490    /// why the agent returned an unexpected result.
491    ///
492    /// # Examples
493    ///
494    /// ```no_run
495    /// use ironflow_core::prelude::*;
496    ///
497    /// # async fn example() -> Result<(), OperationError> {
498    /// let provider = ClaudeCodeProvider::new();
499    ///
500    /// let result = Agent::new()
501    ///     .prompt("Analyze src/")
502    ///     .verbose()
503    ///     .max_budget_usd(0.10)
504    ///     .run(&provider)
505    ///     .await?;
506    ///
507    /// if let Some(messages) = result.debug_messages() {
508    ///     for msg in messages {
509    ///         println!("{msg}");
510    ///     }
511    /// }
512    /// # Ok(())
513    /// # }
514    /// ```
515    pub fn verbose(mut self) -> Self {
516        self.config.verbose = true;
517        self
518    }
519
520    /// Resume a previous agent conversation by session ID.
521    ///
522    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
523    /// continue the multi-turn conversation.
524    ///
525    /// # Examples
526    ///
527    /// ```no_run
528    /// use ironflow_core::prelude::*;
529    ///
530    /// # async fn example() -> Result<(), OperationError> {
531    /// let provider = ClaudeCodeProvider::new();
532    ///
533    /// let first = Agent::new()
534    ///     .prompt("Analyze the src/ directory")
535    ///     .max_budget_usd(0.10)
536    ///     .run(&provider)
537    ///     .await?;
538    ///
539    /// let session = first.session_id().expect("provider returned session ID");
540    ///
541    /// let followup = Agent::new()
542    ///     .prompt("Now suggest improvements")
543    ///     .resume(session)
544    ///     .max_budget_usd(0.10)
545    ///     .run(&provider)
546    ///     .await?;
547    /// # Ok(())
548    /// # }
549    /// ```
550    ///
551    /// # Panics
552    ///
553    /// Panics if `session_id` is empty or contains characters other than
554    /// alphanumerics, hyphens, and underscores.
555    pub fn resume(mut self, session_id: &str) -> Self {
556        assert!(!session_id.is_empty(), "session_id must not be empty");
557        assert!(
558            session_id
559                .chars()
560                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
561            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
562        );
563        self.config.resume_session_id = Some(session_id.to_string());
564        self
565    }
566
567    /// Execute the agent invocation using the given [`AgentProvider`].
568    ///
569    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
570    /// failures (process crashes, timeouts, schema validation) are retried
571    /// with exponential backoff. When structured output is requested
572    /// (`json_schema` is set) and no explicit retry policy is configured,
573    /// an automatic retry policy of 2 retries is applied to handle
574    /// non-deterministic `structured_output: null` responses from the CLI.
575    /// Deterministic errors (prompt too large) are returned immediately
576    /// without retry.
577    ///
578    /// # Errors
579    ///
580    /// Returns [`OperationError::Agent`] if the provider reports a failure
581    /// (process crash, timeout, or schema validation error).
582    ///
583    /// # Panics
584    ///
585    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
586    /// empty (whitespace-only counts as empty).
587    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
588    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
589        assert!(
590            !self.config.prompt.trim().is_empty(),
591            "prompt must not be empty - call .prompt(\"...\") before .run()"
592        );
593
594        if crate::dry_run::effective_dry_run(self.dry_run) {
595            info!(
596                prompt_len = self.config.prompt.len(),
597                "[dry-run] agent call skipped"
598            );
599            let mut output =
600                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
601            output.cost_usd = Some(0.0);
602            output.input_tokens = Some(0);
603            output.output_tokens = Some(0);
604            return Ok(AgentResult { output });
605        }
606
607        let result = self.invoke_once(provider).await;
608
609        let default_schema_retry = RetryPolicy::new(2);
610        let policy = match &self.retry_policy {
611            Some(p) => p,
612            None if self.config.json_schema.is_some() => &default_schema_retry,
613            None => return result,
614        };
615
616        // Non-retryable errors are returned immediately.
617        if let Err(ref err) = result {
618            if !crate::retry::is_retryable(err) {
619                return result;
620            }
621        } else {
622            return result;
623        }
624
625        let mut last_result = result;
626
627        for attempt in 0..policy.max_retries {
628            let delay = policy.delay_for_attempt(attempt);
629            let retry_reason = if matches!(
630                &last_result,
631                Err(OperationError::Agent(
632                    crate::error::AgentError::SchemaValidation { .. }
633                ))
634            ) {
635                "structured_output was null (CLI non-determinism)"
636            } else {
637                "transient failure"
638            };
639            warn!(
640                attempt = attempt + 1,
641                max_retries = policy.max_retries,
642                delay_ms = delay.as_millis() as u64,
643                reason = retry_reason,
644                "retrying agent invocation"
645            );
646            time::sleep(delay).await;
647
648            last_result = self.invoke_once(provider).await;
649
650            match &last_result {
651                Ok(_) => return last_result,
652                Err(err) if !crate::retry::is_retryable(err) => return last_result,
653                _ => {}
654            }
655        }
656
657        last_result
658    }
659
660    /// Execute a single agent invocation attempt (no retry logic).
661    async fn invoke_once(
662        &self,
663        provider: &dyn AgentProvider,
664    ) -> Result<AgentResult, OperationError> {
665        #[cfg(feature = "prometheus")]
666        let model_label = self.config.model.to_string();
667
668        let invoke_result = match self.log_sink {
669            Some(ref sink) => provider.invoke_with_logs(&self.config, sink.clone()).await,
670            None => provider.invoke(&self.config).await,
671        };
672        let output = match invoke_result {
673            Ok(output) => output,
674            Err(e) => {
675                #[cfg(feature = "prometheus")]
676                {
677                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
678                }
679                return Err(OperationError::Agent(e));
680            }
681        };
682
683        info!(
684            duration_ms = output.duration_ms,
685            cost_usd = output.cost_usd,
686            input_tokens = output.input_tokens,
687            output_tokens = output.output_tokens,
688            model = output.model,
689            "agent completed"
690        );
691
692        #[cfg(feature = "prometheus")]
693        {
694            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
695            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
696                .record(output.duration_ms as f64 / 1000.0);
697            if let Some(cost) = output.cost_usd {
698                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
699                    .increment(cost);
700            }
701            if let Some(tokens) = output.input_tokens {
702                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
703            }
704            if let Some(tokens) = output.output_tokens {
705                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
706                    .increment(tokens);
707            }
708        }
709
710        Ok(AgentResult { output })
711    }
712}
713
714impl Default for Agent {
715    fn default() -> Self {
716        Self::new()
717    }
718}
719
720/// The result of a successful agent invocation.
721///
722/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
723/// response text, typed JSON deserialization, session metadata, and usage stats.
724#[derive(Debug)]
725pub struct AgentResult {
726    output: AgentOutput,
727}
728
729impl AgentResult {
730    /// Return the agent's response as a plain text string.
731    ///
732    /// If the underlying value is not a JSON string (e.g. when structured output
733    /// was requested), returns an empty string and logs a warning.
734    pub fn text(&self) -> &str {
735        match self.output.value.as_str() {
736            Some(s) => s,
737            None => {
738                warn!(
739                    value_type = self.output.value.to_string(),
740                    "agent output is not a string, returning empty"
741                );
742                ""
743            }
744        }
745    }
746
747    /// Return the raw JSON [`Value`] of the agent's response.
748    pub fn value(&self) -> &Value {
749        &self.output.value
750    }
751
752    /// Deserialize the agent's response into the given type `T`.
753    ///
754    /// This clones the underlying JSON value. If you no longer need the
755    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
756    /// instead to avoid the clone.
757    ///
758    /// # Errors
759    ///
760    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
761    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
762        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
763    }
764
765    /// Consume the result and deserialize the response into `T` without cloning.
766    ///
767    /// # Errors
768    ///
769    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
770    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
771        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
772    }
773
774    /// Build an `AgentResult` from a raw [`AgentOutput`].
775    ///
776    /// This is available only in test builds to simplify test setup without
777    /// going through the full record/replay pipeline.
778    #[cfg(test)]
779    pub(crate) fn from_output(output: AgentOutput) -> Self {
780        Self { output }
781    }
782
783    /// Return the provider-assigned session ID, if available.
784    pub fn session_id(&self) -> Option<&str> {
785        self.output.session_id.as_deref()
786    }
787
788    /// Return the cost of this invocation in USD, if reported by the provider.
789    pub fn cost_usd(&self) -> Option<f64> {
790        self.output.cost_usd
791    }
792
793    /// Return the number of input tokens consumed, if reported.
794    pub fn input_tokens(&self) -> Option<u64> {
795        self.output.input_tokens
796    }
797
798    /// Return the number of output tokens generated, if reported.
799    pub fn output_tokens(&self) -> Option<u64> {
800        self.output.output_tokens
801    }
802
803    /// Return the wall-clock duration of the invocation in milliseconds.
804    pub fn duration_ms(&self) -> u64 {
805        self.output.duration_ms
806    }
807
808    /// Return the concrete model identifier used, if reported by the provider.
809    pub fn model(&self) -> Option<&str> {
810        self.output.model.as_deref()
811    }
812
813    /// Return the conversation trace captured during a verbose invocation.
814    ///
815    /// Returns `None` when [`Agent::verbose`] was not called. When present,
816    /// each [`DebugMessage`] contains the
817    /// assistant's text and tool calls for one conversation turn.
818    pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
819        self.output.debug_messages.as_deref()
820    }
821}
822
823#[cfg(test)]
824mod tests {
825    use super::*;
826    use crate::error::AgentError;
827    use crate::provider::InvokeFuture;
828    use serde_json::json;
829
830    struct TestProvider {
831        output: AgentOutput,
832    }
833
834    impl AgentProvider for TestProvider {
835        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
836            Box::pin(async move {
837                Ok(AgentOutput {
838                    value: self.output.value.clone(),
839                    session_id: self.output.session_id.clone(),
840                    cost_usd: self.output.cost_usd,
841                    input_tokens: self.output.input_tokens,
842                    output_tokens: self.output.output_tokens,
843                    model: self.output.model.clone(),
844                    duration_ms: self.output.duration_ms,
845                    debug_messages: None,
846                })
847            })
848        }
849    }
850
851    struct ConfigCapture {
852        output: AgentOutput,
853    }
854
855    impl AgentProvider for ConfigCapture {
856        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
857            let config_json = serde_json::to_value(config).unwrap();
858            Box::pin(async move {
859                Ok(AgentOutput {
860                    value: config_json,
861                    session_id: self.output.session_id.clone(),
862                    cost_usd: self.output.cost_usd,
863                    input_tokens: self.output.input_tokens,
864                    output_tokens: self.output.output_tokens,
865                    model: self.output.model.clone(),
866                    duration_ms: self.output.duration_ms,
867                    debug_messages: None,
868                })
869            })
870        }
871    }
872
873    fn default_output() -> AgentOutput {
874        AgentOutput {
875            value: json!("test output"),
876            session_id: Some("sess-123".to_string()),
877            cost_usd: Some(0.05),
878            input_tokens: Some(100),
879            output_tokens: Some(50),
880            model: Some("sonnet".to_string()),
881            duration_ms: 1500,
882            debug_messages: None,
883        }
884    }
885
886    // --- Model constants ---
887
888    #[test]
889    fn model_constants_have_expected_values() {
890        assert_eq!(Model::SONNET, "sonnet");
891        assert_eq!(Model::OPUS, "opus");
892        assert_eq!(Model::HAIKU, "haiku");
893        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
894        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
895        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
896        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
897        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
898        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
899        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
900        assert_eq!(Model::OPUS_48, "claude-opus-4-8");
901        assert_eq!(Model::OPUS_48_1M, "claude-opus-4-8[1m]");
902    }
903
904    // --- Agent::new() defaults via ConfigCapture ---
905
906    #[tokio::test]
907    async fn agent_new_default_values() {
908        let provider = ConfigCapture {
909            output: default_output(),
910        };
911        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
912
913        let config = result.value();
914        assert_eq!(config["system_prompt"], json!(null));
915        assert_eq!(config["prompt"], json!("hi"));
916        assert_eq!(config["model"], json!("sonnet"));
917        assert_eq!(config["allowed_tools"], json!([]));
918        assert_eq!(config["max_turns"], json!(null));
919        assert_eq!(config["max_budget_usd"], json!(null));
920        assert_eq!(config["working_dir"], json!(null));
921        assert_eq!(config["mcp_config"], json!(null));
922        assert_eq!(config["permission_mode"], json!("Default"));
923        assert_eq!(config["json_schema"], json!(null));
924    }
925
926    #[tokio::test]
927    async fn agent_default_matches_new() {
928        let provider = ConfigCapture {
929            output: default_output(),
930        };
931        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
932        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
933
934        assert_eq!(result_new.value(), result_default.value());
935    }
936
937    // --- Builder methods ---
938
939    #[tokio::test]
940    async fn builder_methods_store_values_correctly() {
941        let provider = ConfigCapture {
942            output: default_output(),
943        };
944        let result = Agent::new()
945            .system_prompt("you are a bot")
946            .prompt("do something")
947            .model(Model::OPUS)
948            .allowed_tools(&["Read", "Write"])
949            .max_turns(5)
950            .max_budget_usd(1.5)
951            .working_dir("/tmp")
952            .mcp_config("{}")
953            .permission_mode(PermissionMode::Auto)
954            .run(&provider)
955            .await
956            .unwrap();
957
958        let config = result.value();
959        assert_eq!(config["system_prompt"], json!("you are a bot"));
960        assert_eq!(config["prompt"], json!("do something"));
961        assert_eq!(config["model"], json!("opus"));
962        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
963        assert_eq!(config["max_turns"], json!(5));
964        assert_eq!(config["max_budget_usd"], json!(1.5));
965        assert_eq!(config["working_dir"], json!("/tmp"));
966        assert_eq!(config["mcp_config"], json!("{}"));
967        assert_eq!(config["permission_mode"], json!("Auto"));
968    }
969
970    // --- Panics ---
971
972    #[test]
973    #[should_panic(expected = "max_turns must be greater than 0")]
974    fn max_turns_zero_panics() {
975        let _ = Agent::new().max_turns(0);
976    }
977
978    #[test]
979    #[should_panic(expected = "budget must be a positive finite number")]
980    fn max_budget_negative_panics() {
981        let _ = Agent::new().max_budget_usd(-1.0);
982    }
983
984    #[test]
985    #[should_panic(expected = "budget must be a positive finite number")]
986    fn max_budget_nan_panics() {
987        let _ = Agent::new().max_budget_usd(f64::NAN);
988    }
989
990    #[test]
991    #[should_panic(expected = "budget must be a positive finite number")]
992    fn max_budget_infinity_panics() {
993        let _ = Agent::new().max_budget_usd(f64::INFINITY);
994    }
995
996    // --- AgentResult accessors ---
997
998    #[tokio::test]
999    async fn agent_result_text_with_string_value() {
1000        let provider = TestProvider {
1001            output: AgentOutput {
1002                value: json!("hello world"),
1003                ..default_output()
1004            },
1005        };
1006        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1007        assert_eq!(result.text(), "hello world");
1008    }
1009
1010    #[tokio::test]
1011    async fn agent_result_text_with_non_string_value() {
1012        let provider = TestProvider {
1013            output: AgentOutput {
1014                value: json!(42),
1015                ..default_output()
1016            },
1017        };
1018        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1019        assert_eq!(result.text(), "");
1020    }
1021
1022    #[tokio::test]
1023    async fn agent_result_text_with_null_value() {
1024        let provider = TestProvider {
1025            output: AgentOutput {
1026                value: json!(null),
1027                ..default_output()
1028            },
1029        };
1030        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1031        assert_eq!(result.text(), "");
1032    }
1033
1034    #[tokio::test]
1035    async fn agent_result_json_successful_deserialize() {
1036        #[derive(Deserialize, PartialEq, Debug)]
1037        struct MyOutput {
1038            name: String,
1039            count: u32,
1040        }
1041        let provider = TestProvider {
1042            output: AgentOutput {
1043                value: json!({"name": "test", "count": 7}),
1044                ..default_output()
1045            },
1046        };
1047        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1048        let parsed: MyOutput = result.json().unwrap();
1049        assert_eq!(parsed.name, "test");
1050        assert_eq!(parsed.count, 7);
1051    }
1052
1053    #[tokio::test]
1054    async fn agent_result_json_failed_deserialize() {
1055        #[derive(Debug, Deserialize)]
1056        #[allow(dead_code)]
1057        struct MyOutput {
1058            name: String,
1059        }
1060        let provider = TestProvider {
1061            output: AgentOutput {
1062                value: json!(42),
1063                ..default_output()
1064            },
1065        };
1066        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1067        let err = result.json::<MyOutput>().unwrap_err();
1068        assert!(matches!(err, OperationError::Deserialize { .. }));
1069    }
1070
1071    #[tokio::test]
1072    async fn agent_result_accessors() {
1073        let provider = TestProvider {
1074            output: AgentOutput {
1075                value: json!("v"),
1076                session_id: Some("s-1".to_string()),
1077                cost_usd: Some(0.123),
1078                input_tokens: Some(999),
1079                output_tokens: Some(456),
1080                model: Some("opus".to_string()),
1081                duration_ms: 2000,
1082                debug_messages: None,
1083            },
1084        };
1085        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1086        assert_eq!(result.session_id(), Some("s-1"));
1087        assert_eq!(result.cost_usd(), Some(0.123));
1088        assert_eq!(result.input_tokens(), Some(999));
1089        assert_eq!(result.output_tokens(), Some(456));
1090        assert_eq!(result.duration_ms(), 2000);
1091        assert_eq!(result.model(), Some("opus"));
1092    }
1093
1094    // --- Session resume ---
1095
1096    #[tokio::test]
1097    async fn resume_passes_session_id_in_config() {
1098        let provider = ConfigCapture {
1099            output: default_output(),
1100        };
1101        let result = Agent::new()
1102            .prompt("followup")
1103            .resume("sess-abc")
1104            .run(&provider)
1105            .await
1106            .unwrap();
1107
1108        let config = result.value();
1109        assert_eq!(config["resume_session_id"], json!("sess-abc"));
1110    }
1111
1112    #[tokio::test]
1113    async fn no_resume_has_null_session_id() {
1114        let provider = ConfigCapture {
1115            output: default_output(),
1116        };
1117        let result = Agent::new()
1118            .prompt("first call")
1119            .run(&provider)
1120            .await
1121            .unwrap();
1122
1123        let config = result.value();
1124        assert_eq!(config["resume_session_id"], json!(null));
1125    }
1126
1127    #[test]
1128    #[should_panic(expected = "session_id must not be empty")]
1129    fn resume_empty_session_id_panics() {
1130        let _ = Agent::new().resume("");
1131    }
1132
1133    #[test]
1134    #[should_panic(expected = "session_id must only contain")]
1135    fn resume_invalid_chars_panics() {
1136        let _ = Agent::new().resume("sess;rm -rf /");
1137    }
1138
1139    #[test]
1140    fn resume_valid_formats_accepted() {
1141        let _ = Agent::new().resume("sess-abc123");
1142        let _ = Agent::new().resume("a1b2c3d4_session");
1143        let _ = Agent::new().resume("abc-DEF-123_456");
1144    }
1145
1146    #[tokio::test]
1147    #[should_panic(expected = "prompt must not be empty")]
1148    async fn run_without_prompt_panics() {
1149        let provider = TestProvider {
1150            output: default_output(),
1151        };
1152        let _ = Agent::new().run(&provider).await;
1153    }
1154
1155    #[tokio::test]
1156    #[should_panic(expected = "prompt must not be empty")]
1157    async fn run_with_whitespace_only_prompt_panics() {
1158        let provider = TestProvider {
1159            output: default_output(),
1160        };
1161        let _ = Agent::new().prompt("   ").run(&provider).await;
1162    }
1163
1164    // --- Model accepts arbitrary strings ---
1165
1166    #[tokio::test]
1167    async fn model_accepts_custom_string() {
1168        let provider = ConfigCapture {
1169            output: default_output(),
1170        };
1171        let result = Agent::new()
1172            .prompt("hi")
1173            .model("mistral-large-latest")
1174            .run(&provider)
1175            .await
1176            .unwrap();
1177        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
1178    }
1179
1180    #[tokio::test]
1181    async fn verbose_sets_config_flag() {
1182        let provider = ConfigCapture {
1183            output: default_output(),
1184        };
1185        let result = Agent::new()
1186            .prompt("hi")
1187            .verbose()
1188            .run(&provider)
1189            .await
1190            .unwrap();
1191        assert_eq!(result.value()["verbose"], json!(true));
1192    }
1193
1194    #[tokio::test]
1195    async fn verbose_not_set_by_default() {
1196        let provider = ConfigCapture {
1197            output: default_output(),
1198        };
1199        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
1200        assert_eq!(result.value()["verbose"], json!(false));
1201    }
1202
1203    #[tokio::test]
1204    async fn debug_messages_none_without_verbose() {
1205        let provider = TestProvider {
1206            output: default_output(),
1207        };
1208        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1209        assert!(result.debug_messages().is_none());
1210    }
1211
1212    #[tokio::test]
1213    async fn model_accepts_owned_string() {
1214        let provider = ConfigCapture {
1215            output: default_output(),
1216        };
1217        let model_name = String::from("gpt-4o");
1218        let result = Agent::new()
1219            .prompt("hi")
1220            .model(model_name)
1221            .run(&provider)
1222            .await
1223            .unwrap();
1224        assert_eq!(result.value()["model"], json!("gpt-4o"));
1225    }
1226
1227    #[tokio::test]
1228    async fn into_json_success() {
1229        #[derive(Deserialize, PartialEq, Debug)]
1230        struct Out {
1231            name: String,
1232        }
1233        let provider = TestProvider {
1234            output: AgentOutput {
1235                value: json!({"name": "test"}),
1236                ..default_output()
1237            },
1238        };
1239        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1240        let parsed: Out = result.into_json().unwrap();
1241        assert_eq!(parsed.name, "test");
1242    }
1243
1244    #[tokio::test]
1245    async fn into_json_failure() {
1246        #[derive(Debug, Deserialize)]
1247        #[allow(dead_code)]
1248        struct Out {
1249            name: String,
1250        }
1251        let provider = TestProvider {
1252            output: AgentOutput {
1253                value: json!(42),
1254                ..default_output()
1255            },
1256        };
1257        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1258        let err = result.into_json::<Out>().unwrap_err();
1259        assert!(matches!(err, OperationError::Deserialize { .. }));
1260    }
1261
1262    #[test]
1263    fn from_output_creates_result() {
1264        let output = AgentOutput {
1265            value: json!("hello"),
1266            ..default_output()
1267        };
1268        let result = AgentResult::from_output(output);
1269        assert_eq!(result.text(), "hello");
1270        assert_eq!(result.cost_usd(), Some(0.05));
1271    }
1272
1273    #[test]
1274    #[should_panic(expected = "budget must be a positive finite number")]
1275    fn max_budget_zero_panics() {
1276        let _ = Agent::new().max_budget_usd(0.0);
1277    }
1278
1279    #[test]
1280    fn model_constant_equality() {
1281        assert_eq!(Model::SONNET, "sonnet");
1282        assert_ne!(Model::SONNET, Model::OPUS);
1283    }
1284
1285    #[test]
1286    fn permission_mode_serialize_deserialize_roundtrip() {
1287        for mode in [
1288            PermissionMode::Default,
1289            PermissionMode::Auto,
1290            PermissionMode::DontAsk,
1291            PermissionMode::BypassPermissions,
1292        ] {
1293            let json = to_string(&mode).unwrap();
1294            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1295            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1296        }
1297    }
1298
1299    // --- Retry builder ---
1300
1301    #[test]
1302    fn retry_builder_stores_policy() {
1303        let agent = Agent::new().retry(3);
1304        assert!(agent.retry_policy.is_some());
1305        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1306    }
1307
1308    #[test]
1309    fn retry_policy_builder_stores_custom_policy() {
1310        use crate::retry::RetryPolicy;
1311        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1312        let agent = Agent::new().retry_policy(policy);
1313        let p = agent.retry_policy.unwrap();
1314        assert_eq!(p.max_retries(), 5);
1315    }
1316
1317    #[test]
1318    fn no_retry_by_default() {
1319        let agent = Agent::new();
1320        assert!(agent.retry_policy.is_none());
1321    }
1322
1323    // --- Retry behavior ---
1324
1325    use std::sync::Arc;
1326    use std::sync::atomic::{AtomicU32, Ordering};
1327    use std::time::Duration;
1328
1329    struct FailNTimesProvider {
1330        fail_count: AtomicU32,
1331        failures_before_success: u32,
1332        output: AgentOutput,
1333    }
1334
1335    impl AgentProvider for FailNTimesProvider {
1336        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1337            Box::pin(async move {
1338                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1339                if current < self.failures_before_success {
1340                    Err(AgentError::ProcessFailed {
1341                        exit_code: 1,
1342                        stderr: format!("transient failure #{}", current + 1),
1343                    })
1344                } else {
1345                    Ok(AgentOutput {
1346                        value: self.output.value.clone(),
1347                        session_id: self.output.session_id.clone(),
1348                        cost_usd: self.output.cost_usd,
1349                        input_tokens: self.output.input_tokens,
1350                        output_tokens: self.output.output_tokens,
1351                        model: self.output.model.clone(),
1352                        duration_ms: self.output.duration_ms,
1353                        debug_messages: None,
1354                    })
1355                }
1356            })
1357        }
1358    }
1359
1360    #[tokio::test]
1361    async fn retry_succeeds_after_transient_failures() {
1362        let provider = FailNTimesProvider {
1363            fail_count: AtomicU32::new(0),
1364            failures_before_success: 2,
1365            output: default_output(),
1366        };
1367        let result = Agent::new()
1368            .prompt("test")
1369            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1370            .run(&provider)
1371            .await;
1372
1373        assert!(result.is_ok());
1374        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1375    }
1376
1377    #[tokio::test]
1378    async fn retry_exhausted_returns_last_error() {
1379        let provider = FailNTimesProvider {
1380            fail_count: AtomicU32::new(0),
1381            failures_before_success: 10, // always fails
1382            output: default_output(),
1383        };
1384        let result = Agent::new()
1385            .prompt("test")
1386            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1387            .run(&provider)
1388            .await;
1389
1390        assert!(result.is_err());
1391        // 1 initial + 2 retries = 3 total
1392        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1393    }
1394
1395    #[tokio::test]
1396    async fn retry_does_not_retry_prompt_too_large() {
1397        let call_count = Arc::new(AtomicU32::new(0));
1398        let count = call_count.clone();
1399
1400        struct CountingNonRetryable {
1401            count: Arc<AtomicU32>,
1402        }
1403        impl AgentProvider for CountingNonRetryable {
1404            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1405                self.count.fetch_add(1, Ordering::SeqCst);
1406                Box::pin(async move {
1407                    Err(AgentError::PromptTooLarge {
1408                        chars: 1_000_000,
1409                        estimated_tokens: 250_000,
1410                        model_limit: 200_000,
1411                    })
1412                })
1413            }
1414        }
1415
1416        let provider = CountingNonRetryable { count };
1417        let result = Agent::new()
1418            .prompt("test")
1419            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1420            .run(&provider)
1421            .await;
1422
1423        assert!(result.is_err());
1424        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1425    }
1426
1427    #[tokio::test]
1428    async fn retry_retries_schema_validation_errors() {
1429        let call_count = Arc::new(AtomicU32::new(0));
1430        let count = call_count.clone();
1431
1432        struct SchemaFailProvider {
1433            count: Arc<AtomicU32>,
1434        }
1435        impl AgentProvider for SchemaFailProvider {
1436            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1437                self.count.fetch_add(1, Ordering::SeqCst);
1438                Box::pin(async move {
1439                    Err(AgentError::SchemaValidation {
1440                        expected: "object".to_string(),
1441                        got: "null".to_string(),
1442                        debug_messages: Vec::new(),
1443                        partial_usage: Box::default(),
1444                        raw_response: None,
1445                    })
1446                })
1447            }
1448        }
1449
1450        let provider = SchemaFailProvider { count };
1451        let result = Agent::new()
1452            .prompt("test")
1453            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1454            .run(&provider)
1455            .await;
1456
1457        assert!(result.is_err());
1458        // 1 initial + 2 retries = 3 total
1459        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1460    }
1461
1462    #[tokio::test]
1463    async fn schema_validation_succeeds_on_retry() {
1464        let call_count = Arc::new(AtomicU32::new(0));
1465        let count = call_count.clone();
1466
1467        struct SchemaFailThenSucceed {
1468            count: Arc<AtomicU32>,
1469            output: AgentOutput,
1470        }
1471        impl AgentProvider for SchemaFailThenSucceed {
1472            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1473                let current = self.count.fetch_add(1, Ordering::SeqCst);
1474                let output = self.output.clone();
1475                Box::pin(async move {
1476                    if current == 0 {
1477                        Err(AgentError::SchemaValidation {
1478                            expected: "structured_output field".to_string(),
1479                            got: "null".to_string(),
1480                            debug_messages: Vec::new(),
1481                            partial_usage: Box::default(),
1482                            raw_response: None,
1483                        })
1484                    } else {
1485                        Ok(output)
1486                    }
1487                })
1488            }
1489        }
1490
1491        let provider = SchemaFailThenSucceed {
1492            count,
1493            output: default_output(),
1494        };
1495        let result = Agent::new()
1496            .prompt("test")
1497            .retry_policy(crate::retry::RetryPolicy::new(1).backoff(Duration::from_millis(1)))
1498            .run(&provider)
1499            .await;
1500
1501        assert!(result.is_ok());
1502        assert_eq!(call_count.load(Ordering::SeqCst), 2);
1503    }
1504
1505    #[tokio::test]
1506    async fn auto_retry_applied_when_json_schema_set() {
1507        let call_count = Arc::new(AtomicU32::new(0));
1508        let count = call_count.clone();
1509
1510        struct AlwaysSchemaFail {
1511            count: Arc<AtomicU32>,
1512        }
1513        impl AgentProvider for AlwaysSchemaFail {
1514            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1515                self.count.fetch_add(1, Ordering::SeqCst);
1516                Box::pin(async move {
1517                    Err(AgentError::SchemaValidation {
1518                        expected: "object".to_string(),
1519                        got: "null".to_string(),
1520                        debug_messages: Vec::new(),
1521                        partial_usage: Box::default(),
1522                        raw_response: None,
1523                    })
1524                })
1525            }
1526        }
1527
1528        let provider = AlwaysSchemaFail { count };
1529        let result = Agent::new()
1530            .prompt("test")
1531            .output_schema_raw(r#"{"type":"object"}"#)
1532            .run(&provider)
1533            .await;
1534
1535        assert!(result.is_err());
1536        // auto-retry(2) : 1 initial + 2 retries = 3 total
1537        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1538    }
1539
1540    #[tokio::test]
1541    async fn no_retry_without_policy() {
1542        let provider = FailNTimesProvider {
1543            fail_count: AtomicU32::new(0),
1544            failures_before_success: 1,
1545            output: default_output(),
1546        };
1547        let result = Agent::new().prompt("test").run(&provider).await;
1548
1549        assert!(result.is_err());
1550        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1551    }
1552
1553    // ── log_sink tests ────────────────────────────────────────────
1554
1555    use crate::test_support::VecSink;
1556
1557    struct SinkCapture {
1558        output: AgentOutput,
1559        saw_logs: Arc<AtomicU32>,
1560    }
1561
1562    impl AgentProvider for SinkCapture {
1563        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1564            Box::pin(async {
1565                Ok(AgentOutput {
1566                    value: self.output.value.clone(),
1567                    session_id: self.output.session_id.clone(),
1568                    cost_usd: self.output.cost_usd,
1569                    input_tokens: self.output.input_tokens,
1570                    output_tokens: self.output.output_tokens,
1571                    model: self.output.model.clone(),
1572                    duration_ms: self.output.duration_ms,
1573                    debug_messages: None,
1574                })
1575            })
1576        }
1577
1578        fn invoke_with_logs<'a>(
1579            &'a self,
1580            config: &'a AgentConfig,
1581            log_sink: Arc<dyn LogSink>,
1582        ) -> InvokeFuture<'a> {
1583            self.saw_logs.fetch_add(1, Ordering::SeqCst);
1584            log_sink.log("stdout", "streaming line");
1585            self.invoke(config)
1586        }
1587    }
1588
1589    #[tokio::test]
1590    async fn log_sink_routes_to_invoke_with_logs() {
1591        let saw_logs = Arc::new(AtomicU32::new(0));
1592        let provider = SinkCapture {
1593            output: default_output(),
1594            saw_logs: saw_logs.clone(),
1595        };
1596        let sink: Arc<dyn LogSink> = VecSink::new();
1597
1598        let result = Agent::new()
1599            .prompt("test")
1600            .log_sink(sink)
1601            .run(&provider)
1602            .await;
1603
1604        assert!(result.is_ok());
1605        assert_eq!(saw_logs.load(Ordering::SeqCst), 1);
1606    }
1607
1608    #[tokio::test]
1609    async fn no_log_sink_routes_to_invoke() {
1610        let saw_logs = Arc::new(AtomicU32::new(0));
1611        let provider = SinkCapture {
1612            output: default_output(),
1613            saw_logs: saw_logs.clone(),
1614        };
1615
1616        let result = Agent::new().prompt("test").run(&provider).await;
1617
1618        assert!(result.is_ok());
1619        assert_eq!(saw_logs.load(Ordering::SeqCst), 0);
1620    }
1621
1622    #[tokio::test]
1623    async fn log_sink_receives_provider_lines() {
1624        let saw_logs = Arc::new(AtomicU32::new(0));
1625        let provider = SinkCapture {
1626            output: default_output(),
1627            saw_logs: saw_logs.clone(),
1628        };
1629        let sink = VecSink::new();
1630
1631        let _ = Agent::new()
1632            .prompt("test")
1633            .log_sink(sink.clone() as Arc<dyn LogSink>)
1634            .run(&provider)
1635            .await;
1636
1637        let lines = sink.0.lock().unwrap();
1638        assert_eq!(lines.len(), 1);
1639        assert_eq!(lines[0].0, "stdout");
1640        assert_eq!(lines[0].1, "streaming line");
1641    }
1642}