Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::SONNET)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::any;
29use std::sync::Arc;
30
31use schemars::{JsonSchema, schema_for};
32use serde::de::DeserializeOwned;
33use serde::{Deserialize, Serialize};
34use serde_json::{Value, from_value, to_string};
35use tokio::time;
36use tracing::{info, warn};
37
38use crate::error::OperationError;
39#[cfg(feature = "prometheus")]
40use crate::metric_names;
41use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage, LogSink};
42use crate::retry::RetryPolicy;
43
/// Well-known model identifiers that are not tied to any single provider.
///
/// The associated constants name common Claude models, but [`Agent::model`]
/// accepts any string: custom [`AgentProvider`] implementations are free to
/// interpret the identifier in whatever way suits them.
///
/// # Examples
///
/// ```no_run
/// use ironflow_core::prelude::*;
///
/// # async fn example() -> Result<(), OperationError> {
/// let provider = ClaudeCodeProvider::new();
///
/// // Using a built-in constant
/// let r = Agent::new()
///     .prompt("hi")
///     .model(Model::SONNET)
///     .run(&provider)
///     .await?;
///
/// // Using a custom model string
/// let r = Agent::new()
///     .prompt("hi")
///     .model("mistral-large-latest")
///     .run(&provider)
///     .await?;
/// # Ok(())
/// # }
/// ```
pub struct Model;

impl Model {
    // Version-less aliases: the CLI resolves each one to the current release.

    /// Claude Sonnet, the default: a balance of speed and capability.
    pub const SONNET: &str = "sonnet";
    /// Claude Opus, the most capable family.
    pub const OPUS: &str = "opus";
    /// Claude Haiku, the fastest and cheapest family.
    pub const HAIKU: &str = "haiku";

    // Pinned Claude 4.5 release.

    /// Claude Haiku 4.5.
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";

    // Claude 4.6 with the standard 200K-token context window.

    /// Claude Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Claude Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";

    // Claude 4.6 with the extended 1M-token context window (`[1m]` suffix).

    /// Claude Sonnet 4.6, 1M-token context window.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// Claude Opus 4.6, 1M-token context window.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";

    // Claude 4.7, which has a native 1M-token context window.

    /// Claude Opus 4.7, the latest flagship with a native 1M-token context.
    pub const OPUS_47: &str = "claude-opus-4-7";
    /// Claude Opus 4.7 with the 1M-token context window requested explicitly.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
112
113/// Controls how the agent handles tool-use permission prompts.
114///
115/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
116/// flags in the Claude CLI.
117#[derive(Debug, Default, Clone, Copy, Serialize)]
118pub enum PermissionMode {
119    /// Use the CLI default permission behavior.
120    #[default]
121    Default,
122    /// Automatically approve tool-use requests.
123    Auto,
124    /// Suppress all permission prompts (the agent proceeds without asking).
125    DontAsk,
126    /// Skip all permission checks entirely.
127    ///
128    /// **Warning**: the agent will have unrestricted filesystem and shell access.
129    BypassPermissions,
130}
131
132impl<'de> Deserialize<'de> for PermissionMode {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: serde::Deserializer<'de>,
136    {
137        let s = String::deserialize(deserializer)?;
138        Ok(match s.to_lowercase().replace('_', "").as_str() {
139            "auto" => Self::Auto,
140            "dontask" => Self::DontAsk,
141            "bypass" | "bypasspermissions" => Self::BypassPermissions,
142            _ => Self::Default,
143        })
144    }
145}
146
147/// Builder for a single agent invocation.
148///
149/// Create with [`Agent::new`], chain configuration methods, then call
150/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
151///
152/// # Examples
153///
154/// ```no_run
155/// use ironflow_core::prelude::*;
156///
157/// # async fn example() -> Result<(), OperationError> {
158/// let provider = ClaudeCodeProvider::new();
159///
160/// let result = Agent::new()
161///     .system_prompt("You are a Rust expert.")
162///     .prompt("Review this code for safety issues.")
163///     .model(Model::OPUS)
164///     .allowed_tools(&["Read", "Grep"])
165///     .max_turns(5)
166///     .max_budget_usd(0.50)
167///     .working_dir("/tmp/project")
168///     .permission_mode(PermissionMode::Auto)
169///     .run(&provider)
170///     .await?;
171///
172/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
173/// # Ok(())
174/// # }
175/// ```
176#[must_use = "an Agent does nothing until .run() is awaited"]
177pub struct Agent {
178    config: AgentConfig,
179    dry_run: Option<bool>,
180    retry_policy: Option<RetryPolicy>,
181    log_sink: Option<Arc<dyn LogSink>>,
182}
183
184impl Agent {
185    /// Create a new agent builder with default settings.
186    ///
187    /// Defaults: [`Model::SONNET`], no system prompt, no tool restrictions,
188    /// no budget/turn limits, [`PermissionMode::Default`].
189    pub fn new() -> Self {
190        Self {
191            config: AgentConfig::new(""),
192            dry_run: None,
193            retry_policy: None,
194            log_sink: None,
195        }
196    }
197
198    /// Create an agent builder from an existing [`AgentConfig`].
199    ///
200    /// Useful when the config comes from a serialized workflow definition
201    /// rather than being built programmatically.
202    ///
203    /// # Examples
204    ///
205    /// ```no_run
206    /// use ironflow_core::prelude::*;
207    /// use ironflow_core::provider::AgentConfig;
208    ///
209    /// # async fn example() -> Result<(), OperationError> {
210    /// let provider = ClaudeCodeProvider::new();
211    /// let config = AgentConfig::new("Summarize the README");
212    /// let result = Agent::from_config(config).run(&provider).await?;
213    /// # Ok(())
214    /// # }
215    /// ```
216    pub fn from_config(config: impl Into<AgentConfig>) -> Self {
217        Self {
218            config: config.into(),
219            dry_run: None,
220            retry_policy: None,
221            log_sink: None,
222        }
223    }
224
225    /// Set the system prompt that defines the agent's persona or constraints.
226    pub fn system_prompt(mut self, prompt: &str) -> Self {
227        self.config.system_prompt = Some(prompt.to_string());
228        self
229    }
230
231    /// Set the user prompt - the main instruction sent to the agent.
232    pub fn prompt(mut self, prompt: &str) -> Self {
233        self.config.prompt = prompt.to_string();
234        self
235    }
236
237    /// Set the model to use for this invocation.
238    ///
239    /// Accepts any string-like value. Use [`Model`] constants for well-known
240    /// Claude models, or pass an arbitrary string for custom providers.
241    ///
242    /// Defaults to [`Model::SONNET`] if not called.
243    pub fn model(mut self, model: impl Into<String>) -> Self {
244        self.config.model = model.into();
245        self
246    }
247
248    /// Restrict which tools the agent may invoke.
249    ///
250    /// Pass an empty slice (or do not call this method) to allow the provider
251    /// default set of tools.
252    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
253        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
254        self
255    }
256
257    /// Set the maximum number of agentic turns.
258    ///
259    /// # Panics
260    ///
261    /// Panics if `turns` is `0`.
262    pub fn max_turns(mut self, turns: u32) -> Self {
263        assert!(turns > 0, "max_turns must be greater than 0");
264        self.config.max_turns = Some(turns);
265        self
266    }
267
268    /// Set the maximum spend in USD for this invocation.
269    ///
270    /// # Panics
271    ///
272    /// Panics if `budget` is negative, NaN, or infinity.
273    pub fn max_budget_usd(mut self, budget: f64) -> Self {
274        assert!(
275            budget.is_finite() && budget > 0.0,
276            "budget must be a positive finite number, got {budget}"
277        );
278        self.config.max_budget_usd = Some(budget);
279        self
280    }
281
282    /// Set the working directory for the agent process.
283    pub fn working_dir(mut self, dir: &str) -> Self {
284        self.config.working_dir = Some(dir.to_string());
285        self
286    }
287
288    /// Set the path to an MCP (Model Context Protocol) server configuration file.
289    pub fn mcp_config(mut self, config: &str) -> Self {
290        self.config.mcp_config = Some(config.to_string());
291        self
292    }
293
294    /// Set the permission mode controlling tool-use approval behavior.
295    ///
296    /// See [`PermissionMode`] for details on each variant.
297    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
298        self.config.permission_mode = mode;
299        self
300    }
301
302    /// Request structured (typed) output from the agent.
303    ///
304    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
305    /// to the provider so the model returns JSON conforming to `T`, which can
306    /// then be deserialized with [`AgentResult::json`].
307    ///
308    /// # Examples
309    ///
310    /// ```no_run
311    /// use ironflow_core::prelude::*;
312    ///
313    /// #[derive(Deserialize, JsonSchema)]
314    /// struct Review {
315    ///     score: u8,
316    ///     summary: String,
317    /// }
318    ///
319    /// # async fn example() -> Result<(), OperationError> {
320    /// let provider = ClaudeCodeProvider::new();
321    /// let result = Agent::new()
322    ///     .prompt("Review the codebase")
323    ///     .output::<Review>()
324    ///     .run(&provider)
325    ///     .await?;
326    ///
327    /// let review: Review = result.json().expect("schema-validated output");
328    /// println!("Score: {}/10 - {}", review.score, review.summary);
329    /// # Ok(())
330    /// # }
331    /// ```
332    pub fn output<T: JsonSchema>(mut self) -> Self {
333        let schema = schema_for!(T);
334        self.config.json_schema = match to_string(&schema) {
335            Ok(s) => Some(s),
336            Err(e) => {
337                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
338                None
339            }
340        };
341        self
342    }
343
344    /// Set structured output from a pre-serialized JSON Schema string.
345    ///
346    /// Use this when the schema comes from configuration or another source
347    /// rather than a Rust type. For type-safe schema generation, prefer
348    /// [`output`](Agent::output).
349    ///
350    /// **Important:** structured output requires `max_turns >= 2`. The Claude CLI
351    /// uses the first turn for reasoning and a second turn to produce the
352    /// schema-conforming JSON.
353    ///
354    /// # Examples
355    ///
356    /// ```no_run
357    /// use ironflow_core::prelude::*;
358    ///
359    /// # async fn example() -> Result<(), OperationError> {
360    /// let schema = r#"{"type":"object","properties":{"labels":{"type":"array","items":{"type":"string"}}}}"#;
361    /// let agent = Agent::new()
362    ///     .prompt("Classify this email")
363    ///     .output_schema_raw(schema);
364    /// # Ok(())
365    /// # }
366    /// ```
367    pub fn output_schema_raw(mut self, schema: &str) -> Self {
368        self.config.json_schema = Some(schema.to_string());
369        self
370    }
371
372    /// Retry the agent invocation up to `max_retries` times on transient failures.
373    ///
374    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
375    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
376    ///
377    /// Only transient errors are retried: process failures and timeouts.
378    /// Deterministic errors (prompt too large, schema validation) are never retried.
379    ///
380    /// # Panics
381    ///
382    /// Panics if `max_retries` is `0`.
383    ///
384    /// # Examples
385    ///
386    /// ```no_run
387    /// use ironflow_core::prelude::*;
388    ///
389    /// # async fn example() -> Result<(), OperationError> {
390    /// let provider = ClaudeCodeProvider::new();
391    /// let result = Agent::new()
392    ///     .prompt("Summarize the codebase")
393    ///     .retry(2)
394    ///     .run(&provider)
395    ///     .await?;
396    /// # Ok(())
397    /// # }
398    /// ```
399    pub fn retry(mut self, max_retries: u32) -> Self {
400        self.retry_policy = Some(RetryPolicy::new(max_retries));
401        self
402    }
403
404    /// Set a custom [`RetryPolicy`] for this agent invocation.
405    ///
406    /// Allows full control over backoff duration, multiplier, and max delay.
407    /// See [`RetryPolicy`] for details.
408    ///
409    /// # Examples
410    ///
411    /// ```no_run
412    /// use std::time::Duration;
413    /// use ironflow_core::prelude::*;
414    /// use ironflow_core::retry::RetryPolicy;
415    ///
416    /// # async fn example() -> Result<(), OperationError> {
417    /// let provider = ClaudeCodeProvider::new();
418    /// let result = Agent::new()
419    ///     .prompt("Analyze the code")
420    ///     .retry_policy(
421    ///         RetryPolicy::new(3)
422    ///             .backoff(Duration::from_secs(1))
423    ///             .max_backoff(Duration::from_secs(60))
424    ///     )
425    ///     .run(&provider)
426    ///     .await?;
427    /// # Ok(())
428    /// # }
429    /// ```
430    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
431        self.retry_policy = Some(policy);
432        self
433    }
434
435    /// Enable or disable dry-run mode for this specific operation.
436    ///
437    /// When dry-run is active, the agent call is logged but not executed.
438    /// A synthetic [`AgentResult`] is returned with a placeholder text,
439    /// zero cost, and zero tokens.
440    ///
441    /// If not set, falls back to the global dry-run setting
442    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
443    pub fn dry_run(mut self, enabled: bool) -> Self {
444        self.dry_run = Some(enabled);
445        self
446    }
447
448    /// Attach a [`LogSink`] for real-time log streaming.
449    ///
450    /// When set, [`invoke_with_logs`](AgentProvider::invoke_with_logs) is called
451    /// instead of [`invoke`](AgentProvider::invoke), allowing providers that
452    /// support streaming to pipe output lines in real time.
453    ///
454    /// # Examples
455    ///
456    /// ```no_run
457    /// use std::sync::Arc;
458    /// use ironflow_core::prelude::*;
459    ///
460    /// # async fn example() -> Result<(), OperationError> {
461    /// # struct MySink;
462    /// # impl LogSink for MySink { fn log(&self, _: &str, _: &str) {} }
463    /// let provider = ClaudeCodeProvider::new();
464    /// let sink: Arc<dyn LogSink> = Arc::new(MySink);
465    ///
466    /// let result = Agent::new()
467    ///     .prompt("Analyze src/")
468    ///     .log_sink(sink)
469    ///     .run(&provider)
470    ///     .await?;
471    /// # Ok(())
472    /// # }
473    /// ```
474    pub fn log_sink(mut self, sink: Arc<dyn LogSink>) -> Self {
475        self.log_sink = Some(sink);
476        self
477    }
478
479    /// Enable verbose/debug mode to capture the full conversation trace.
480    ///
481    /// When enabled, the provider captures every assistant message and tool
482    /// call into [`AgentResult::debug_messages`]. Useful for understanding
483    /// why the agent returned an unexpected result.
484    ///
485    /// # Examples
486    ///
487    /// ```no_run
488    /// use ironflow_core::prelude::*;
489    ///
490    /// # async fn example() -> Result<(), OperationError> {
491    /// let provider = ClaudeCodeProvider::new();
492    ///
493    /// let result = Agent::new()
494    ///     .prompt("Analyze src/")
495    ///     .verbose()
496    ///     .max_budget_usd(0.10)
497    ///     .run(&provider)
498    ///     .await?;
499    ///
500    /// if let Some(messages) = result.debug_messages() {
501    ///     for msg in messages {
502    ///         println!("{msg}");
503    ///     }
504    /// }
505    /// # Ok(())
506    /// # }
507    /// ```
508    pub fn verbose(mut self) -> Self {
509        self.config.verbose = true;
510        self
511    }
512
513    /// Resume a previous agent conversation by session ID.
514    ///
515    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
516    /// continue the multi-turn conversation.
517    ///
518    /// # Examples
519    ///
520    /// ```no_run
521    /// use ironflow_core::prelude::*;
522    ///
523    /// # async fn example() -> Result<(), OperationError> {
524    /// let provider = ClaudeCodeProvider::new();
525    ///
526    /// let first = Agent::new()
527    ///     .prompt("Analyze the src/ directory")
528    ///     .max_budget_usd(0.10)
529    ///     .run(&provider)
530    ///     .await?;
531    ///
532    /// let session = first.session_id().expect("provider returned session ID");
533    ///
534    /// let followup = Agent::new()
535    ///     .prompt("Now suggest improvements")
536    ///     .resume(session)
537    ///     .max_budget_usd(0.10)
538    ///     .run(&provider)
539    ///     .await?;
540    /// # Ok(())
541    /// # }
542    /// ```
543    ///
544    /// # Panics
545    ///
546    /// Panics if `session_id` is empty or contains characters other than
547    /// alphanumerics, hyphens, and underscores.
548    pub fn resume(mut self, session_id: &str) -> Self {
549        assert!(!session_id.is_empty(), "session_id must not be empty");
550        assert!(
551            session_id
552                .chars()
553                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
554            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
555        );
556        self.config.resume_session_id = Some(session_id.to_string());
557        self
558    }
559
560    /// Execute the agent invocation using the given [`AgentProvider`].
561    ///
562    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
563    /// failures (process crashes, timeouts, schema validation) are retried
564    /// with exponential backoff. When structured output is requested
565    /// (`json_schema` is set) and no explicit retry policy is configured,
566    /// an automatic retry policy of 2 retries is applied to handle
567    /// non-deterministic `structured_output: null` responses from the CLI.
568    /// Deterministic errors (prompt too large) are returned immediately
569    /// without retry.
570    ///
571    /// # Errors
572    ///
573    /// Returns [`OperationError::Agent`] if the provider reports a failure
574    /// (process crash, timeout, or schema validation error).
575    ///
576    /// # Panics
577    ///
578    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
579    /// empty (whitespace-only counts as empty).
580    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
581    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
582        assert!(
583            !self.config.prompt.trim().is_empty(),
584            "prompt must not be empty - call .prompt(\"...\") before .run()"
585        );
586
587        if crate::dry_run::effective_dry_run(self.dry_run) {
588            info!(
589                prompt_len = self.config.prompt.len(),
590                "[dry-run] agent call skipped"
591            );
592            let mut output =
593                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
594            output.cost_usd = Some(0.0);
595            output.input_tokens = Some(0);
596            output.output_tokens = Some(0);
597            return Ok(AgentResult { output });
598        }
599
600        let result = self.invoke_once(provider).await;
601
602        let default_schema_retry = RetryPolicy::new(2);
603        let policy = match &self.retry_policy {
604            Some(p) => p,
605            None if self.config.json_schema.is_some() => &default_schema_retry,
606            None => return result,
607        };
608
609        // Non-retryable errors are returned immediately.
610        if let Err(ref err) = result {
611            if !crate::retry::is_retryable(err) {
612                return result;
613            }
614        } else {
615            return result;
616        }
617
618        let mut last_result = result;
619
620        for attempt in 0..policy.max_retries {
621            let delay = policy.delay_for_attempt(attempt);
622            let retry_reason = if matches!(
623                &last_result,
624                Err(OperationError::Agent(
625                    crate::error::AgentError::SchemaValidation { .. }
626                ))
627            ) {
628                "structured_output was null (CLI non-determinism)"
629            } else {
630                "transient failure"
631            };
632            warn!(
633                attempt = attempt + 1,
634                max_retries = policy.max_retries,
635                delay_ms = delay.as_millis() as u64,
636                reason = retry_reason,
637                "retrying agent invocation"
638            );
639            time::sleep(delay).await;
640
641            last_result = self.invoke_once(provider).await;
642
643            match &last_result {
644                Ok(_) => return last_result,
645                Err(err) if !crate::retry::is_retryable(err) => return last_result,
646                _ => {}
647            }
648        }
649
650        last_result
651    }
652
653    /// Execute a single agent invocation attempt (no retry logic).
654    async fn invoke_once(
655        &self,
656        provider: &dyn AgentProvider,
657    ) -> Result<AgentResult, OperationError> {
658        #[cfg(feature = "prometheus")]
659        let model_label = self.config.model.to_string();
660
661        let invoke_result = match self.log_sink {
662            Some(ref sink) => provider.invoke_with_logs(&self.config, sink.clone()).await,
663            None => provider.invoke(&self.config).await,
664        };
665        let output = match invoke_result {
666            Ok(output) => output,
667            Err(e) => {
668                #[cfg(feature = "prometheus")]
669                {
670                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
671                }
672                return Err(OperationError::Agent(e));
673            }
674        };
675
676        info!(
677            duration_ms = output.duration_ms,
678            cost_usd = output.cost_usd,
679            input_tokens = output.input_tokens,
680            output_tokens = output.output_tokens,
681            model = output.model,
682            "agent completed"
683        );
684
685        #[cfg(feature = "prometheus")]
686        {
687            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
688            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
689                .record(output.duration_ms as f64 / 1000.0);
690            if let Some(cost) = output.cost_usd {
691                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
692                    .increment(cost);
693            }
694            if let Some(tokens) = output.input_tokens {
695                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
696            }
697            if let Some(tokens) = output.output_tokens {
698                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
699                    .increment(tokens);
700            }
701        }
702
703        Ok(AgentResult { output })
704    }
705}
706
707impl Default for Agent {
708    fn default() -> Self {
709        Self::new()
710    }
711}
712
713/// The result of a successful agent invocation.
714///
715/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
716/// response text, typed JSON deserialization, session metadata, and usage stats.
717#[derive(Debug)]
718pub struct AgentResult {
719    output: AgentOutput,
720}
721
722impl AgentResult {
723    /// Return the agent's response as a plain text string.
724    ///
725    /// If the underlying value is not a JSON string (e.g. when structured output
726    /// was requested), returns an empty string and logs a warning.
727    pub fn text(&self) -> &str {
728        match self.output.value.as_str() {
729            Some(s) => s,
730            None => {
731                warn!(
732                    value_type = self.output.value.to_string(),
733                    "agent output is not a string, returning empty"
734                );
735                ""
736            }
737        }
738    }
739
740    /// Return the raw JSON [`Value`] of the agent's response.
741    pub fn value(&self) -> &Value {
742        &self.output.value
743    }
744
745    /// Deserialize the agent's response into the given type `T`.
746    ///
747    /// This clones the underlying JSON value. If you no longer need the
748    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
749    /// instead to avoid the clone.
750    ///
751    /// # Errors
752    ///
753    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
754    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
755        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
756    }
757
758    /// Consume the result and deserialize the response into `T` without cloning.
759    ///
760    /// # Errors
761    ///
762    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
763    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
764        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
765    }
766
767    /// Build an `AgentResult` from a raw [`AgentOutput`].
768    ///
769    /// This is available only in test builds to simplify test setup without
770    /// going through the full record/replay pipeline.
771    #[cfg(test)]
772    pub(crate) fn from_output(output: AgentOutput) -> Self {
773        Self { output }
774    }
775
776    /// Return the provider-assigned session ID, if available.
777    pub fn session_id(&self) -> Option<&str> {
778        self.output.session_id.as_deref()
779    }
780
781    /// Return the cost of this invocation in USD, if reported by the provider.
782    pub fn cost_usd(&self) -> Option<f64> {
783        self.output.cost_usd
784    }
785
786    /// Return the number of input tokens consumed, if reported.
787    pub fn input_tokens(&self) -> Option<u64> {
788        self.output.input_tokens
789    }
790
791    /// Return the number of output tokens generated, if reported.
792    pub fn output_tokens(&self) -> Option<u64> {
793        self.output.output_tokens
794    }
795
796    /// Return the wall-clock duration of the invocation in milliseconds.
797    pub fn duration_ms(&self) -> u64 {
798        self.output.duration_ms
799    }
800
801    /// Return the concrete model identifier used, if reported by the provider.
802    pub fn model(&self) -> Option<&str> {
803        self.output.model.as_deref()
804    }
805
806    /// Return the conversation trace captured during a verbose invocation.
807    ///
808    /// Returns `None` when [`Agent::verbose`] was not called. When present,
809    /// each [`DebugMessage`] contains the
810    /// assistant's text and tool calls for one conversation turn.
811    pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
812        self.output.debug_messages.as_deref()
813    }
814}
815
816#[cfg(test)]
817mod tests {
818    use super::*;
819    use crate::error::AgentError;
820    use crate::provider::InvokeFuture;
821    use serde_json::json;
822
823    struct TestProvider {
824        output: AgentOutput,
825    }
826
827    impl AgentProvider for TestProvider {
828        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
829            Box::pin(async move {
830                Ok(AgentOutput {
831                    value: self.output.value.clone(),
832                    session_id: self.output.session_id.clone(),
833                    cost_usd: self.output.cost_usd,
834                    input_tokens: self.output.input_tokens,
835                    output_tokens: self.output.output_tokens,
836                    model: self.output.model.clone(),
837                    duration_ms: self.output.duration_ms,
838                    debug_messages: None,
839                })
840            })
841        }
842    }
843
844    struct ConfigCapture {
845        output: AgentOutput,
846    }
847
848    impl AgentProvider for ConfigCapture {
849        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
850            let config_json = serde_json::to_value(config).unwrap();
851            Box::pin(async move {
852                Ok(AgentOutput {
853                    value: config_json,
854                    session_id: self.output.session_id.clone(),
855                    cost_usd: self.output.cost_usd,
856                    input_tokens: self.output.input_tokens,
857                    output_tokens: self.output.output_tokens,
858                    model: self.output.model.clone(),
859                    duration_ms: self.output.duration_ms,
860                    debug_messages: None,
861                })
862            })
863        }
864    }
865
866    fn default_output() -> AgentOutput {
867        AgentOutput {
868            value: json!("test output"),
869            session_id: Some("sess-123".to_string()),
870            cost_usd: Some(0.05),
871            input_tokens: Some(100),
872            output_tokens: Some(50),
873            model: Some("sonnet".to_string()),
874            duration_ms: 1500,
875            debug_messages: None,
876        }
877    }
878
879    // --- Model constants ---
880
881    #[test]
882    fn model_constants_have_expected_values() {
883        assert_eq!(Model::SONNET, "sonnet");
884        assert_eq!(Model::OPUS, "opus");
885        assert_eq!(Model::HAIKU, "haiku");
886        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
887        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
888        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
889        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
890        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
891        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
892        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
893    }
894
895    // --- Agent::new() defaults via ConfigCapture ---
896
897    #[tokio::test]
898    async fn agent_new_default_values() {
899        let provider = ConfigCapture {
900            output: default_output(),
901        };
902        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
903
904        let config = result.value();
905        assert_eq!(config["system_prompt"], json!(null));
906        assert_eq!(config["prompt"], json!("hi"));
907        assert_eq!(config["model"], json!("sonnet"));
908        assert_eq!(config["allowed_tools"], json!([]));
909        assert_eq!(config["max_turns"], json!(null));
910        assert_eq!(config["max_budget_usd"], json!(null));
911        assert_eq!(config["working_dir"], json!(null));
912        assert_eq!(config["mcp_config"], json!(null));
913        assert_eq!(config["permission_mode"], json!("Default"));
914        assert_eq!(config["json_schema"], json!(null));
915    }
916
917    #[tokio::test]
918    async fn agent_default_matches_new() {
919        let provider = ConfigCapture {
920            output: default_output(),
921        };
922        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
923        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
924
925        assert_eq!(result_new.value(), result_default.value());
926    }
927
928    // --- Builder methods ---
929
930    #[tokio::test]
931    async fn builder_methods_store_values_correctly() {
932        let provider = ConfigCapture {
933            output: default_output(),
934        };
935        let result = Agent::new()
936            .system_prompt("you are a bot")
937            .prompt("do something")
938            .model(Model::OPUS)
939            .allowed_tools(&["Read", "Write"])
940            .max_turns(5)
941            .max_budget_usd(1.5)
942            .working_dir("/tmp")
943            .mcp_config("{}")
944            .permission_mode(PermissionMode::Auto)
945            .run(&provider)
946            .await
947            .unwrap();
948
949        let config = result.value();
950        assert_eq!(config["system_prompt"], json!("you are a bot"));
951        assert_eq!(config["prompt"], json!("do something"));
952        assert_eq!(config["model"], json!("opus"));
953        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
954        assert_eq!(config["max_turns"], json!(5));
955        assert_eq!(config["max_budget_usd"], json!(1.5));
956        assert_eq!(config["working_dir"], json!("/tmp"));
957        assert_eq!(config["mcp_config"], json!("{}"));
958        assert_eq!(config["permission_mode"], json!("Auto"));
959    }
960
961    // --- Panics ---
962
963    #[test]
964    #[should_panic(expected = "max_turns must be greater than 0")]
965    fn max_turns_zero_panics() {
966        let _ = Agent::new().max_turns(0);
967    }
968
969    #[test]
970    #[should_panic(expected = "budget must be a positive finite number")]
971    fn max_budget_negative_panics() {
972        let _ = Agent::new().max_budget_usd(-1.0);
973    }
974
975    #[test]
976    #[should_panic(expected = "budget must be a positive finite number")]
977    fn max_budget_nan_panics() {
978        let _ = Agent::new().max_budget_usd(f64::NAN);
979    }
980
981    #[test]
982    #[should_panic(expected = "budget must be a positive finite number")]
983    fn max_budget_infinity_panics() {
984        let _ = Agent::new().max_budget_usd(f64::INFINITY);
985    }
986
987    // --- AgentResult accessors ---
988
989    #[tokio::test]
990    async fn agent_result_text_with_string_value() {
991        let provider = TestProvider {
992            output: AgentOutput {
993                value: json!("hello world"),
994                ..default_output()
995            },
996        };
997        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
998        assert_eq!(result.text(), "hello world");
999    }
1000
1001    #[tokio::test]
1002    async fn agent_result_text_with_non_string_value() {
1003        let provider = TestProvider {
1004            output: AgentOutput {
1005                value: json!(42),
1006                ..default_output()
1007            },
1008        };
1009        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1010        assert_eq!(result.text(), "");
1011    }
1012
1013    #[tokio::test]
1014    async fn agent_result_text_with_null_value() {
1015        let provider = TestProvider {
1016            output: AgentOutput {
1017                value: json!(null),
1018                ..default_output()
1019            },
1020        };
1021        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1022        assert_eq!(result.text(), "");
1023    }
1024
1025    #[tokio::test]
1026    async fn agent_result_json_successful_deserialize() {
1027        #[derive(Deserialize, PartialEq, Debug)]
1028        struct MyOutput {
1029            name: String,
1030            count: u32,
1031        }
1032        let provider = TestProvider {
1033            output: AgentOutput {
1034                value: json!({"name": "test", "count": 7}),
1035                ..default_output()
1036            },
1037        };
1038        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1039        let parsed: MyOutput = result.json().unwrap();
1040        assert_eq!(parsed.name, "test");
1041        assert_eq!(parsed.count, 7);
1042    }
1043
1044    #[tokio::test]
1045    async fn agent_result_json_failed_deserialize() {
1046        #[derive(Debug, Deserialize)]
1047        #[allow(dead_code)]
1048        struct MyOutput {
1049            name: String,
1050        }
1051        let provider = TestProvider {
1052            output: AgentOutput {
1053                value: json!(42),
1054                ..default_output()
1055            },
1056        };
1057        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1058        let err = result.json::<MyOutput>().unwrap_err();
1059        assert!(matches!(err, OperationError::Deserialize { .. }));
1060    }
1061
1062    #[tokio::test]
1063    async fn agent_result_accessors() {
1064        let provider = TestProvider {
1065            output: AgentOutput {
1066                value: json!("v"),
1067                session_id: Some("s-1".to_string()),
1068                cost_usd: Some(0.123),
1069                input_tokens: Some(999),
1070                output_tokens: Some(456),
1071                model: Some("opus".to_string()),
1072                duration_ms: 2000,
1073                debug_messages: None,
1074            },
1075        };
1076        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1077        assert_eq!(result.session_id(), Some("s-1"));
1078        assert_eq!(result.cost_usd(), Some(0.123));
1079        assert_eq!(result.input_tokens(), Some(999));
1080        assert_eq!(result.output_tokens(), Some(456));
1081        assert_eq!(result.duration_ms(), 2000);
1082        assert_eq!(result.model(), Some("opus"));
1083    }
1084
1085    // --- Session resume ---
1086
1087    #[tokio::test]
1088    async fn resume_passes_session_id_in_config() {
1089        let provider = ConfigCapture {
1090            output: default_output(),
1091        };
1092        let result = Agent::new()
1093            .prompt("followup")
1094            .resume("sess-abc")
1095            .run(&provider)
1096            .await
1097            .unwrap();
1098
1099        let config = result.value();
1100        assert_eq!(config["resume_session_id"], json!("sess-abc"));
1101    }
1102
1103    #[tokio::test]
1104    async fn no_resume_has_null_session_id() {
1105        let provider = ConfigCapture {
1106            output: default_output(),
1107        };
1108        let result = Agent::new()
1109            .prompt("first call")
1110            .run(&provider)
1111            .await
1112            .unwrap();
1113
1114        let config = result.value();
1115        assert_eq!(config["resume_session_id"], json!(null));
1116    }
1117
1118    #[test]
1119    #[should_panic(expected = "session_id must not be empty")]
1120    fn resume_empty_session_id_panics() {
1121        let _ = Agent::new().resume("");
1122    }
1123
1124    #[test]
1125    #[should_panic(expected = "session_id must only contain")]
1126    fn resume_invalid_chars_panics() {
1127        let _ = Agent::new().resume("sess;rm -rf /");
1128    }
1129
1130    #[test]
1131    fn resume_valid_formats_accepted() {
1132        let _ = Agent::new().resume("sess-abc123");
1133        let _ = Agent::new().resume("a1b2c3d4_session");
1134        let _ = Agent::new().resume("abc-DEF-123_456");
1135    }
1136
1137    #[tokio::test]
1138    #[should_panic(expected = "prompt must not be empty")]
1139    async fn run_without_prompt_panics() {
1140        let provider = TestProvider {
1141            output: default_output(),
1142        };
1143        let _ = Agent::new().run(&provider).await;
1144    }
1145
1146    #[tokio::test]
1147    #[should_panic(expected = "prompt must not be empty")]
1148    async fn run_with_whitespace_only_prompt_panics() {
1149        let provider = TestProvider {
1150            output: default_output(),
1151        };
1152        let _ = Agent::new().prompt("   ").run(&provider).await;
1153    }
1154
1155    // --- Model accepts arbitrary strings ---
1156
1157    #[tokio::test]
1158    async fn model_accepts_custom_string() {
1159        let provider = ConfigCapture {
1160            output: default_output(),
1161        };
1162        let result = Agent::new()
1163            .prompt("hi")
1164            .model("mistral-large-latest")
1165            .run(&provider)
1166            .await
1167            .unwrap();
1168        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
1169    }
1170
1171    #[tokio::test]
1172    async fn verbose_sets_config_flag() {
1173        let provider = ConfigCapture {
1174            output: default_output(),
1175        };
1176        let result = Agent::new()
1177            .prompt("hi")
1178            .verbose()
1179            .run(&provider)
1180            .await
1181            .unwrap();
1182        assert_eq!(result.value()["verbose"], json!(true));
1183    }
1184
1185    #[tokio::test]
1186    async fn verbose_not_set_by_default() {
1187        let provider = ConfigCapture {
1188            output: default_output(),
1189        };
1190        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
1191        assert_eq!(result.value()["verbose"], json!(false));
1192    }
1193
1194    #[tokio::test]
1195    async fn debug_messages_none_without_verbose() {
1196        let provider = TestProvider {
1197            output: default_output(),
1198        };
1199        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1200        assert!(result.debug_messages().is_none());
1201    }
1202
1203    #[tokio::test]
1204    async fn model_accepts_owned_string() {
1205        let provider = ConfigCapture {
1206            output: default_output(),
1207        };
1208        let model_name = String::from("gpt-4o");
1209        let result = Agent::new()
1210            .prompt("hi")
1211            .model(model_name)
1212            .run(&provider)
1213            .await
1214            .unwrap();
1215        assert_eq!(result.value()["model"], json!("gpt-4o"));
1216    }
1217
1218    #[tokio::test]
1219    async fn into_json_success() {
1220        #[derive(Deserialize, PartialEq, Debug)]
1221        struct Out {
1222            name: String,
1223        }
1224        let provider = TestProvider {
1225            output: AgentOutput {
1226                value: json!({"name": "test"}),
1227                ..default_output()
1228            },
1229        };
1230        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1231        let parsed: Out = result.into_json().unwrap();
1232        assert_eq!(parsed.name, "test");
1233    }
1234
1235    #[tokio::test]
1236    async fn into_json_failure() {
1237        #[derive(Debug, Deserialize)]
1238        #[allow(dead_code)]
1239        struct Out {
1240            name: String,
1241        }
1242        let provider = TestProvider {
1243            output: AgentOutput {
1244                value: json!(42),
1245                ..default_output()
1246            },
1247        };
1248        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1249        let err = result.into_json::<Out>().unwrap_err();
1250        assert!(matches!(err, OperationError::Deserialize { .. }));
1251    }
1252
1253    #[test]
1254    fn from_output_creates_result() {
1255        let output = AgentOutput {
1256            value: json!("hello"),
1257            ..default_output()
1258        };
1259        let result = AgentResult::from_output(output);
1260        assert_eq!(result.text(), "hello");
1261        assert_eq!(result.cost_usd(), Some(0.05));
1262    }
1263
1264    #[test]
1265    #[should_panic(expected = "budget must be a positive finite number")]
1266    fn max_budget_zero_panics() {
1267        let _ = Agent::new().max_budget_usd(0.0);
1268    }
1269
1270    #[test]
1271    fn model_constant_equality() {
1272        assert_eq!(Model::SONNET, "sonnet");
1273        assert_ne!(Model::SONNET, Model::OPUS);
1274    }
1275
1276    #[test]
1277    fn permission_mode_serialize_deserialize_roundtrip() {
1278        for mode in [
1279            PermissionMode::Default,
1280            PermissionMode::Auto,
1281            PermissionMode::DontAsk,
1282            PermissionMode::BypassPermissions,
1283        ] {
1284            let json = to_string(&mode).unwrap();
1285            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1286            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1287        }
1288    }
1289
1290    // --- Retry builder ---
1291
1292    #[test]
1293    fn retry_builder_stores_policy() {
1294        let agent = Agent::new().retry(3);
1295        assert!(agent.retry_policy.is_some());
1296        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1297    }
1298
1299    #[test]
1300    fn retry_policy_builder_stores_custom_policy() {
1301        use crate::retry::RetryPolicy;
1302        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1303        let agent = Agent::new().retry_policy(policy);
1304        let p = agent.retry_policy.unwrap();
1305        assert_eq!(p.max_retries(), 5);
1306    }
1307
1308    #[test]
1309    fn no_retry_by_default() {
1310        let agent = Agent::new();
1311        assert!(agent.retry_policy.is_none());
1312    }
1313
1314    // --- Retry behavior ---
1315
1316    use std::sync::Arc;
1317    use std::sync::atomic::{AtomicU32, Ordering};
1318    use std::time::Duration;
1319
1320    struct FailNTimesProvider {
1321        fail_count: AtomicU32,
1322        failures_before_success: u32,
1323        output: AgentOutput,
1324    }
1325
1326    impl AgentProvider for FailNTimesProvider {
1327        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1328            Box::pin(async move {
1329                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1330                if current < self.failures_before_success {
1331                    Err(AgentError::ProcessFailed {
1332                        exit_code: 1,
1333                        stderr: format!("transient failure #{}", current + 1),
1334                    })
1335                } else {
1336                    Ok(AgentOutput {
1337                        value: self.output.value.clone(),
1338                        session_id: self.output.session_id.clone(),
1339                        cost_usd: self.output.cost_usd,
1340                        input_tokens: self.output.input_tokens,
1341                        output_tokens: self.output.output_tokens,
1342                        model: self.output.model.clone(),
1343                        duration_ms: self.output.duration_ms,
1344                        debug_messages: None,
1345                    })
1346                }
1347            })
1348        }
1349    }
1350
1351    #[tokio::test]
1352    async fn retry_succeeds_after_transient_failures() {
1353        let provider = FailNTimesProvider {
1354            fail_count: AtomicU32::new(0),
1355            failures_before_success: 2,
1356            output: default_output(),
1357        };
1358        let result = Agent::new()
1359            .prompt("test")
1360            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1361            .run(&provider)
1362            .await;
1363
1364        assert!(result.is_ok());
1365        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1366    }
1367
1368    #[tokio::test]
1369    async fn retry_exhausted_returns_last_error() {
1370        let provider = FailNTimesProvider {
1371            fail_count: AtomicU32::new(0),
1372            failures_before_success: 10, // always fails
1373            output: default_output(),
1374        };
1375        let result = Agent::new()
1376            .prompt("test")
1377            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1378            .run(&provider)
1379            .await;
1380
1381        assert!(result.is_err());
1382        // 1 initial + 2 retries = 3 total
1383        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1384    }
1385
1386    #[tokio::test]
1387    async fn retry_does_not_retry_prompt_too_large() {
1388        let call_count = Arc::new(AtomicU32::new(0));
1389        let count = call_count.clone();
1390
1391        struct CountingNonRetryable {
1392            count: Arc<AtomicU32>,
1393        }
1394        impl AgentProvider for CountingNonRetryable {
1395            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1396                self.count.fetch_add(1, Ordering::SeqCst);
1397                Box::pin(async move {
1398                    Err(AgentError::PromptTooLarge {
1399                        chars: 1_000_000,
1400                        estimated_tokens: 250_000,
1401                        model_limit: 200_000,
1402                    })
1403                })
1404            }
1405        }
1406
1407        let provider = CountingNonRetryable { count };
1408        let result = Agent::new()
1409            .prompt("test")
1410            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1411            .run(&provider)
1412            .await;
1413
1414        assert!(result.is_err());
1415        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1416    }
1417
1418    #[tokio::test]
1419    async fn retry_retries_schema_validation_errors() {
1420        let call_count = Arc::new(AtomicU32::new(0));
1421        let count = call_count.clone();
1422
1423        struct SchemaFailProvider {
1424            count: Arc<AtomicU32>,
1425        }
1426        impl AgentProvider for SchemaFailProvider {
1427            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1428                self.count.fetch_add(1, Ordering::SeqCst);
1429                Box::pin(async move {
1430                    Err(AgentError::SchemaValidation {
1431                        expected: "object".to_string(),
1432                        got: "null".to_string(),
1433                        debug_messages: Vec::new(),
1434                        partial_usage: Box::default(),
1435                        raw_response: None,
1436                    })
1437                })
1438            }
1439        }
1440
1441        let provider = SchemaFailProvider { count };
1442        let result = Agent::new()
1443            .prompt("test")
1444            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1445            .run(&provider)
1446            .await;
1447
1448        assert!(result.is_err());
1449        // 1 initial + 2 retries = 3 total
1450        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1451    }
1452
1453    #[tokio::test]
1454    async fn schema_validation_succeeds_on_retry() {
1455        let call_count = Arc::new(AtomicU32::new(0));
1456        let count = call_count.clone();
1457
1458        struct SchemaFailThenSucceed {
1459            count: Arc<AtomicU32>,
1460            output: AgentOutput,
1461        }
1462        impl AgentProvider for SchemaFailThenSucceed {
1463            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1464                let current = self.count.fetch_add(1, Ordering::SeqCst);
1465                let output = self.output.clone();
1466                Box::pin(async move {
1467                    if current == 0 {
1468                        Err(AgentError::SchemaValidation {
1469                            expected: "structured_output field".to_string(),
1470                            got: "null".to_string(),
1471                            debug_messages: Vec::new(),
1472                            partial_usage: Box::default(),
1473                            raw_response: None,
1474                        })
1475                    } else {
1476                        Ok(output)
1477                    }
1478                })
1479            }
1480        }
1481
1482        let provider = SchemaFailThenSucceed {
1483            count,
1484            output: default_output(),
1485        };
1486        let result = Agent::new()
1487            .prompt("test")
1488            .retry_policy(crate::retry::RetryPolicy::new(1).backoff(Duration::from_millis(1)))
1489            .run(&provider)
1490            .await;
1491
1492        assert!(result.is_ok());
1493        assert_eq!(call_count.load(Ordering::SeqCst), 2);
1494    }
1495
1496    #[tokio::test]
1497    async fn auto_retry_applied_when_json_schema_set() {
1498        let call_count = Arc::new(AtomicU32::new(0));
1499        let count = call_count.clone();
1500
1501        struct AlwaysSchemaFail {
1502            count: Arc<AtomicU32>,
1503        }
1504        impl AgentProvider for AlwaysSchemaFail {
1505            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1506                self.count.fetch_add(1, Ordering::SeqCst);
1507                Box::pin(async move {
1508                    Err(AgentError::SchemaValidation {
1509                        expected: "object".to_string(),
1510                        got: "null".to_string(),
1511                        debug_messages: Vec::new(),
1512                        partial_usage: Box::default(),
1513                        raw_response: None,
1514                    })
1515                })
1516            }
1517        }
1518
1519        let provider = AlwaysSchemaFail { count };
1520        let result = Agent::new()
1521            .prompt("test")
1522            .output_schema_raw(r#"{"type":"object"}"#)
1523            .run(&provider)
1524            .await;
1525
1526        assert!(result.is_err());
1527        // auto-retry(2) : 1 initial + 2 retries = 3 total
1528        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1529    }
1530
1531    #[tokio::test]
1532    async fn no_retry_without_policy() {
1533        let provider = FailNTimesProvider {
1534            fail_count: AtomicU32::new(0),
1535            failures_before_success: 1,
1536            output: default_output(),
1537        };
1538        let result = Agent::new().prompt("test").run(&provider).await;
1539
1540        assert!(result.is_err());
1541        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1542    }
1543
1544    // ── log_sink tests ────────────────────────────────────────────
1545
1546    use crate::test_support::VecSink;
1547
1548    struct SinkCapture {
1549        output: AgentOutput,
1550        saw_logs: Arc<AtomicU32>,
1551    }
1552
1553    impl AgentProvider for SinkCapture {
1554        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1555            Box::pin(async {
1556                Ok(AgentOutput {
1557                    value: self.output.value.clone(),
1558                    session_id: self.output.session_id.clone(),
1559                    cost_usd: self.output.cost_usd,
1560                    input_tokens: self.output.input_tokens,
1561                    output_tokens: self.output.output_tokens,
1562                    model: self.output.model.clone(),
1563                    duration_ms: self.output.duration_ms,
1564                    debug_messages: None,
1565                })
1566            })
1567        }
1568
1569        fn invoke_with_logs<'a>(
1570            &'a self,
1571            config: &'a AgentConfig,
1572            log_sink: Arc<dyn LogSink>,
1573        ) -> InvokeFuture<'a> {
1574            self.saw_logs.fetch_add(1, Ordering::SeqCst);
1575            log_sink.log("stdout", "streaming line");
1576            self.invoke(config)
1577        }
1578    }
1579
1580    #[tokio::test]
1581    async fn log_sink_routes_to_invoke_with_logs() {
1582        let saw_logs = Arc::new(AtomicU32::new(0));
1583        let provider = SinkCapture {
1584            output: default_output(),
1585            saw_logs: saw_logs.clone(),
1586        };
1587        let sink: Arc<dyn LogSink> = VecSink::new();
1588
1589        let result = Agent::new()
1590            .prompt("test")
1591            .log_sink(sink)
1592            .run(&provider)
1593            .await;
1594
1595        assert!(result.is_ok());
1596        assert_eq!(saw_logs.load(Ordering::SeqCst), 1);
1597    }
1598
1599    #[tokio::test]
1600    async fn no_log_sink_routes_to_invoke() {
1601        let saw_logs = Arc::new(AtomicU32::new(0));
1602        let provider = SinkCapture {
1603            output: default_output(),
1604            saw_logs: saw_logs.clone(),
1605        };
1606
1607        let result = Agent::new().prompt("test").run(&provider).await;
1608
1609        assert!(result.is_ok());
1610        assert_eq!(saw_logs.load(Ordering::SeqCst), 0);
1611    }
1612
1613    #[tokio::test]
1614    async fn log_sink_receives_provider_lines() {
1615        let saw_logs = Arc::new(AtomicU32::new(0));
1616        let provider = SinkCapture {
1617            output: default_output(),
1618            saw_logs: saw_logs.clone(),
1619        };
1620        let sink = VecSink::new();
1621
1622        let _ = Agent::new()
1623            .prompt("test")
1624            .log_sink(sink.clone() as Arc<dyn LogSink>)
1625            .run(&provider)
1626            .await;
1627
1628        let lines = sink.0.lock().unwrap();
1629        assert_eq!(lines.len(), 1);
1630        assert_eq!(lines[0].0, "stdout");
1631        assert_eq!(lines[0].1, "streaming line");
1632    }
1633}