Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::SONNET)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::any;
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tokio::time;
35use tracing::{info, warn};
36
37use crate::error::OperationError;
38#[cfg(feature = "prometheus")]
39use crate::metric_names;
40use crate::provider::{AgentConfig, AgentOutput, AgentProvider};
41use crate::retry::RetryPolicy;
42
43/// Provider-agnostic model identifiers.
44///
45/// Constants are provided for well-known Claude models, but any string
46/// is accepted - custom [`AgentProvider`] implementations interpret the
47/// model identifier however they wish.
48///
49/// # Examples
50///
51/// ```no_run
52/// use ironflow_core::prelude::*;
53///
54/// # async fn example() -> Result<(), OperationError> {
55/// let provider = ClaudeCodeProvider::new();
56///
57/// // Using a built-in constant
58/// let r = Agent::new()
59///     .prompt("hi")
60///     .model(Model::SONNET)
61///     .run(&provider)
62///     .await?;
63///
64/// // Using a custom model string
65/// let r = Agent::new()
66///     .prompt("hi")
67///     .model("mistral-large-latest")
68///     .run(&provider)
69///     .await?;
70/// # Ok(())
71/// # }
72/// ```
pub struct Model;

impl Model {
    // Short aliases: the CLI resolves each one to the current release of that family.

    /// Alias for the latest Claude Sonnet (the default model).
    pub const SONNET: &'static str = "sonnet";
    /// Alias for the latest Claude Opus, the most capable family.
    pub const OPUS: &'static str = "opus";
    /// Alias for the latest Claude Haiku, the fastest and cheapest family.
    pub const HAIKU: &'static str = "haiku";

    // Pinned identifiers: Claude 4.5.

    /// Pinned identifier for Claude Haiku 4.5.
    pub const HAIKU_45: &'static str = "claude-haiku-4-5-20251001";

    // Pinned identifiers: Claude 4.6 with the 200K context window.

    /// Pinned identifier for Claude Sonnet 4.6.
    pub const SONNET_46: &'static str = "claude-sonnet-4-6";
    /// Pinned identifier for Claude Opus 4.6.
    pub const OPUS_46: &'static str = "claude-opus-4-6";

    // Pinned identifiers: Claude 4.6 with the 1M context window.

    /// Pinned identifier for Claude Sonnet 4.6 with a 1M token context window.
    pub const SONNET_46_1M: &'static str = "claude-sonnet-4-6[1m]";
    /// Pinned identifier for Claude Opus 4.6 with a 1M token context window.
    pub const OPUS_46_1M: &'static str = "claude-opus-4-6[1m]";
}
104
105/// Controls how the agent handles tool-use permission prompts.
106///
107/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
108/// flags in the Claude CLI.
109#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
110pub enum PermissionMode {
111    /// Use the CLI default permission behavior.
112    Default,
113    /// Automatically approve tool-use requests.
114    Auto,
115    /// Suppress all permission prompts (the agent proceeds without asking).
116    DontAsk,
117    /// Skip all permission checks entirely.
118    ///
119    /// **Warning**: the agent will have unrestricted filesystem and shell access.
120    BypassPermissions,
121}
122
123/// Builder for a single agent invocation.
124///
125/// Create with [`Agent::new`], chain configuration methods, then call
126/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
127///
128/// # Examples
129///
130/// ```no_run
131/// use ironflow_core::prelude::*;
132///
133/// # async fn example() -> Result<(), OperationError> {
134/// let provider = ClaudeCodeProvider::new();
135///
136/// let result = Agent::new()
137///     .system_prompt("You are a Rust expert.")
138///     .prompt("Review this code for safety issues.")
139///     .model(Model::OPUS)
140///     .allowed_tools(&["Read", "Grep"])
141///     .max_turns(5)
142///     .max_budget_usd(0.50)
143///     .working_dir("/tmp/project")
144///     .permission_mode(PermissionMode::Auto)
145///     .run(&provider)
146///     .await?;
147///
148/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
149/// # Ok(())
150/// # }
151/// ```
#[must_use = "an Agent does nothing until .run() is awaited"]
pub struct Agent {
    /// Provider-facing settings; each builder method writes one field of it.
    config: AgentConfig,
    /// Per-invocation dry-run override, passed to
    /// `crate::dry_run::effective_dry_run` when the agent runs.
    dry_run: Option<bool>,
    /// Retry behaviour used by `run`; `None` means a single attempt.
    retry_policy: Option<RetryPolicy>,
}
158
159impl Agent {
160    /// Create a new agent builder with default settings.
161    ///
162    /// Defaults: [`Model::SONNET`], no system prompt, no tool restrictions,
163    /// no budget/turn limits, [`PermissionMode::Default`].
164    pub fn new() -> Self {
165        Self {
166            config: AgentConfig::new(""),
167            dry_run: None,
168            retry_policy: None,
169        }
170    }
171
172    /// Set the system prompt that defines the agent's persona or constraints.
173    pub fn system_prompt(mut self, prompt: &str) -> Self {
174        self.config.system_prompt = Some(prompt.to_string());
175        self
176    }
177
178    /// Set the user prompt - the main instruction sent to the agent.
179    pub fn prompt(mut self, prompt: &str) -> Self {
180        self.config.prompt = prompt.to_string();
181        self
182    }
183
184    /// Set the model to use for this invocation.
185    ///
186    /// Accepts any string-like value. Use [`Model`] constants for well-known
187    /// Claude models, or pass an arbitrary string for custom providers.
188    ///
189    /// Defaults to [`Model::SONNET`] if not called.
190    pub fn model(mut self, model: impl Into<String>) -> Self {
191        self.config.model = model.into();
192        self
193    }
194
195    /// Restrict which tools the agent may invoke.
196    ///
197    /// Pass an empty slice (or do not call this method) to allow the provider
198    /// default set of tools.
199    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
200        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
201        self
202    }
203
204    /// Set the maximum number of agentic turns.
205    ///
206    /// # Panics
207    ///
208    /// Panics if `turns` is `0`.
209    pub fn max_turns(mut self, turns: u32) -> Self {
210        assert!(turns > 0, "max_turns must be greater than 0");
211        self.config.max_turns = Some(turns);
212        self
213    }
214
    /// Set the maximum spend in USD for this invocation.
    ///
    /// The budget must be a finite number strictly greater than zero.
    ///
    /// # Panics
    ///
    /// Panics if `budget` is zero, negative, NaN, or infinite.
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        // `is_finite` rules out NaN and both infinities; `> 0.0` rules out
        // zero and negative values.
        assert!(
            budget.is_finite() && budget > 0.0,
            "budget must be a positive finite number, got {budget}"
        );
        self.config.max_budget_usd = Some(budget);
        self
    }
228
229    /// Set the working directory for the agent process.
230    pub fn working_dir(mut self, dir: &str) -> Self {
231        self.config.working_dir = Some(dir.to_string());
232        self
233    }
234
235    /// Set the path to an MCP (Model Context Protocol) server configuration file.
236    pub fn mcp_config(mut self, config: &str) -> Self {
237        self.config.mcp_config = Some(config.to_string());
238        self
239    }
240
241    /// Set the permission mode controlling tool-use approval behavior.
242    ///
243    /// See [`PermissionMode`] for details on each variant.
244    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
245        self.config.permission_mode = mode;
246        self
247    }
248
249    /// Request structured (typed) output from the agent.
250    ///
251    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
252    /// to the provider so the model returns JSON conforming to `T`, which can
253    /// then be deserialized with [`AgentResult::json`].
254    ///
255    /// # Examples
256    ///
257    /// ```no_run
258    /// use ironflow_core::prelude::*;
259    ///
260    /// #[derive(Deserialize, JsonSchema)]
261    /// struct Review {
262    ///     score: u8,
263    ///     summary: String,
264    /// }
265    ///
266    /// # async fn example() -> Result<(), OperationError> {
267    /// let provider = ClaudeCodeProvider::new();
268    /// let result = Agent::new()
269    ///     .prompt("Review the codebase")
270    ///     .output::<Review>()
271    ///     .run(&provider)
272    ///     .await?;
273    ///
274    /// let review: Review = result.json().expect("schema-validated output");
275    /// println!("Score: {}/10 - {}", review.score, review.summary);
276    /// # Ok(())
277    /// # }
278    /// ```
279    pub fn output<T: JsonSchema>(mut self) -> Self {
280        let schema = schema_for!(T);
281        self.config.json_schema = match to_string(&schema) {
282            Ok(s) => Some(s),
283            Err(e) => {
284                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
285                None
286            }
287        };
288        self
289    }
290
291    /// Set structured output from a pre-serialized JSON Schema string.
292    ///
293    /// Use this when the schema comes from configuration or another source
294    /// rather than a Rust type. For type-safe schema generation, prefer
295    /// [`output`](Agent::output).
296    ///
297    /// **Important:** structured output requires `max_turns >= 2`. The Claude CLI
298    /// uses the first turn for reasoning and a second turn to produce the
299    /// schema-conforming JSON.
300    ///
301    /// # Examples
302    ///
303    /// ```no_run
304    /// use ironflow_core::prelude::*;
305    ///
306    /// # async fn example() -> Result<(), OperationError> {
307    /// let schema = r#"{"type":"object","properties":{"labels":{"type":"array","items":{"type":"string"}}}}"#;
308    /// let agent = Agent::new()
309    ///     .prompt("Classify this email")
310    ///     .output_schema_raw(schema);
311    /// # Ok(())
312    /// # }
313    /// ```
314    pub fn output_schema_raw(mut self, schema: &str) -> Self {
315        self.config.json_schema = Some(schema.to_string());
316        self
317    }
318
319    /// Retry the agent invocation up to `max_retries` times on transient failures.
320    ///
321    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
322    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
323    ///
324    /// Only transient errors are retried: process failures and timeouts.
325    /// Deterministic errors (prompt too large, schema validation) are never retried.
326    ///
327    /// # Panics
328    ///
329    /// Panics if `max_retries` is `0`.
330    ///
331    /// # Examples
332    ///
333    /// ```no_run
334    /// use ironflow_core::prelude::*;
335    ///
336    /// # async fn example() -> Result<(), OperationError> {
337    /// let provider = ClaudeCodeProvider::new();
338    /// let result = Agent::new()
339    ///     .prompt("Summarize the codebase")
340    ///     .retry(2)
341    ///     .run(&provider)
342    ///     .await?;
343    /// # Ok(())
344    /// # }
345    /// ```
346    pub fn retry(mut self, max_retries: u32) -> Self {
347        self.retry_policy = Some(RetryPolicy::new(max_retries));
348        self
349    }
350
351    /// Set a custom [`RetryPolicy`] for this agent invocation.
352    ///
353    /// Allows full control over backoff duration, multiplier, and max delay.
354    /// See [`RetryPolicy`] for details.
355    ///
356    /// # Examples
357    ///
358    /// ```no_run
359    /// use std::time::Duration;
360    /// use ironflow_core::prelude::*;
361    /// use ironflow_core::retry::RetryPolicy;
362    ///
363    /// # async fn example() -> Result<(), OperationError> {
364    /// let provider = ClaudeCodeProvider::new();
365    /// let result = Agent::new()
366    ///     .prompt("Analyze the code")
367    ///     .retry_policy(
368    ///         RetryPolicy::new(3)
369    ///             .backoff(Duration::from_secs(1))
370    ///             .max_backoff(Duration::from_secs(60))
371    ///     )
372    ///     .run(&provider)
373    ///     .await?;
374    /// # Ok(())
375    /// # }
376    /// ```
377    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
378        self.retry_policy = Some(policy);
379        self
380    }
381
382    /// Enable or disable dry-run mode for this specific operation.
383    ///
384    /// When dry-run is active, the agent call is logged but not executed.
385    /// A synthetic [`AgentResult`] is returned with a placeholder text,
386    /// zero cost, and zero tokens.
387    ///
388    /// If not set, falls back to the global dry-run setting
389    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
390    pub fn dry_run(mut self, enabled: bool) -> Self {
391        self.dry_run = Some(enabled);
392        self
393    }
394
395    /// Resume a previous agent conversation by session ID.
396    ///
397    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
398    /// continue the multi-turn conversation.
399    ///
400    /// # Examples
401    ///
402    /// ```no_run
403    /// use ironflow_core::prelude::*;
404    ///
405    /// # async fn example() -> Result<(), OperationError> {
406    /// let provider = ClaudeCodeProvider::new();
407    ///
408    /// let first = Agent::new()
409    ///     .prompt("Analyze the src/ directory")
410    ///     .max_budget_usd(0.10)
411    ///     .run(&provider)
412    ///     .await?;
413    ///
414    /// let session = first.session_id().expect("provider returned session ID");
415    ///
416    /// let followup = Agent::new()
417    ///     .prompt("Now suggest improvements")
418    ///     .resume(session)
419    ///     .max_budget_usd(0.10)
420    ///     .run(&provider)
421    ///     .await?;
422    /// # Ok(())
423    /// # }
424    /// ```
425    ///
426    /// # Panics
427    ///
428    /// Panics if `session_id` is empty or contains characters other than
429    /// alphanumerics, hyphens, and underscores.
430    pub fn resume(mut self, session_id: &str) -> Self {
431        assert!(!session_id.is_empty(), "session_id must not be empty");
432        assert!(
433            session_id
434                .chars()
435                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
436            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
437        );
438        self.config.resume_session_id = Some(session_id.to_string());
439        self
440    }
441
442    /// Execute the agent invocation using the given [`AgentProvider`].
443    ///
444    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
445    /// failures (process crashes, timeouts) are retried with exponential
446    /// backoff. Deterministic errors (prompt too large, schema validation)
447    /// are returned immediately without retry.
448    ///
449    /// # Errors
450    ///
451    /// Returns [`OperationError::Agent`] if the provider reports a failure
452    /// (process crash, timeout, or schema validation error).
453    ///
454    /// # Panics
455    ///
456    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
457    /// empty (whitespace-only counts as empty).
458    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
459    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
460        assert!(
461            !self.config.prompt.trim().is_empty(),
462            "prompt must not be empty - call .prompt(\"...\") before .run()"
463        );
464
465        if crate::dry_run::effective_dry_run(self.dry_run) {
466            info!(
467                prompt_len = self.config.prompt.len(),
468                "[dry-run] agent call skipped"
469            );
470            let mut output =
471                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
472            output.cost_usd = Some(0.0);
473            output.input_tokens = Some(0);
474            output.output_tokens = Some(0);
475            return Ok(AgentResult { output });
476        }
477
478        let result = self.invoke_once(provider).await;
479
480        let policy = match &self.retry_policy {
481            Some(p) => p,
482            None => return result,
483        };
484
485        // Non-retryable errors are returned immediately.
486        if let Err(ref err) = result {
487            if !crate::retry::is_retryable(err) {
488                return result;
489            }
490        } else {
491            return result;
492        }
493
494        let mut last_result = result;
495
496        for attempt in 0..policy.max_retries {
497            let delay = policy.delay_for_attempt(attempt);
498            warn!(
499                attempt = attempt + 1,
500                max_retries = policy.max_retries,
501                delay_ms = delay.as_millis() as u64,
502                "retrying agent invocation"
503            );
504            time::sleep(delay).await;
505
506            last_result = self.invoke_once(provider).await;
507
508            match &last_result {
509                Ok(_) => return last_result,
510                Err(err) if !crate::retry::is_retryable(err) => return last_result,
511                _ => {}
512            }
513        }
514
515        last_result
516    }
517
518    /// Execute a single agent invocation attempt (no retry logic).
519    async fn invoke_once(
520        &self,
521        provider: &dyn AgentProvider,
522    ) -> Result<AgentResult, OperationError> {
523        #[cfg(feature = "prometheus")]
524        let model_label = self.config.model.to_string();
525
526        let output = match provider.invoke(&self.config).await {
527            Ok(output) => output,
528            Err(e) => {
529                #[cfg(feature = "prometheus")]
530                {
531                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
532                }
533                return Err(OperationError::Agent(e));
534            }
535        };
536
537        info!(
538            duration_ms = output.duration_ms,
539            cost_usd = output.cost_usd,
540            input_tokens = output.input_tokens,
541            output_tokens = output.output_tokens,
542            model = output.model,
543            "agent completed"
544        );
545
546        #[cfg(feature = "prometheus")]
547        {
548            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
549            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
550                .record(output.duration_ms as f64 / 1000.0);
551            if let Some(cost) = output.cost_usd {
552                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
553                    .increment(cost);
554            }
555            if let Some(tokens) = output.input_tokens {
556                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
557            }
558            if let Some(tokens) = output.output_tokens {
559                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
560                    .increment(tokens);
561            }
562        }
563
564        Ok(AgentResult { output })
565    }
566}
567
568impl Default for Agent {
569    fn default() -> Self {
570        Self::new()
571    }
572}
573
574/// The result of a successful agent invocation.
575///
576/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
577/// response text, typed JSON deserialization, session metadata, and usage stats.
578#[derive(Debug)]
579pub struct AgentResult {
580    output: AgentOutput,
581}
582
583impl AgentResult {
584    /// Return the agent's response as a plain text string.
585    ///
586    /// If the underlying value is not a JSON string (e.g. when structured output
587    /// was requested), returns an empty string and logs a warning.
588    pub fn text(&self) -> &str {
589        match self.output.value.as_str() {
590            Some(s) => s,
591            None => {
592                warn!(
593                    value_type = self.output.value.to_string(),
594                    "agent output is not a string, returning empty"
595                );
596                ""
597            }
598        }
599    }
600
601    /// Return the raw JSON [`Value`] of the agent's response.
602    pub fn value(&self) -> &Value {
603        &self.output.value
604    }
605
606    /// Deserialize the agent's response into the given type `T`.
607    ///
608    /// This clones the underlying JSON value. If you no longer need the
609    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
610    /// instead to avoid the clone.
611    ///
612    /// # Errors
613    ///
614    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
615    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
616        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
617    }
618
619    /// Consume the result and deserialize the response into `T` without cloning.
620    ///
621    /// # Errors
622    ///
623    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
624    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
625        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
626    }
627
628    /// Build an `AgentResult` from a raw [`AgentOutput`].
629    ///
630    /// This is available only in test builds to simplify test setup without
631    /// going through the full record/replay pipeline.
632    #[cfg(test)]
633    pub(crate) fn from_output(output: AgentOutput) -> Self {
634        Self { output }
635    }
636
637    /// Return the provider-assigned session ID, if available.
638    pub fn session_id(&self) -> Option<&str> {
639        self.output.session_id.as_deref()
640    }
641
642    /// Return the cost of this invocation in USD, if reported by the provider.
643    pub fn cost_usd(&self) -> Option<f64> {
644        self.output.cost_usd
645    }
646
647    /// Return the number of input tokens consumed, if reported.
648    pub fn input_tokens(&self) -> Option<u64> {
649        self.output.input_tokens
650    }
651
652    /// Return the number of output tokens generated, if reported.
653    pub fn output_tokens(&self) -> Option<u64> {
654        self.output.output_tokens
655    }
656
657    /// Return the wall-clock duration of the invocation in milliseconds.
658    pub fn duration_ms(&self) -> u64 {
659        self.output.duration_ms
660    }
661
662    /// Return the concrete model identifier used, if reported by the provider.
663    pub fn model(&self) -> Option<&str> {
664        self.output.model.as_deref()
665    }
666}
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671    use crate::error::AgentError;
672    use crate::provider::InvokeFuture;
673    use serde_json::json;
674
675    struct TestProvider {
676        output: AgentOutput,
677    }
678
679    impl AgentProvider for TestProvider {
680        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
681            Box::pin(async move {
682                Ok(AgentOutput {
683                    value: self.output.value.clone(),
684                    session_id: self.output.session_id.clone(),
685                    cost_usd: self.output.cost_usd,
686                    input_tokens: self.output.input_tokens,
687                    output_tokens: self.output.output_tokens,
688                    model: self.output.model.clone(),
689                    duration_ms: self.output.duration_ms,
690                })
691            })
692        }
693    }
694
695    struct ConfigCapture {
696        output: AgentOutput,
697    }
698
699    impl AgentProvider for ConfigCapture {
700        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
701            let config_json = serde_json::to_value(config).unwrap();
702            Box::pin(async move {
703                Ok(AgentOutput {
704                    value: config_json,
705                    session_id: self.output.session_id.clone(),
706                    cost_usd: self.output.cost_usd,
707                    input_tokens: self.output.input_tokens,
708                    output_tokens: self.output.output_tokens,
709                    model: self.output.model.clone(),
710                    duration_ms: self.output.duration_ms,
711                })
712            })
713        }
714    }
715
716    fn default_output() -> AgentOutput {
717        AgentOutput {
718            value: json!("test output"),
719            session_id: Some("sess-123".to_string()),
720            cost_usd: Some(0.05),
721            input_tokens: Some(100),
722            output_tokens: Some(50),
723            model: Some("sonnet".to_string()),
724            duration_ms: 1500,
725        }
726    }
727
728    // --- Model constants ---
729
730    #[test]
731    fn model_constants_have_expected_values() {
732        assert_eq!(Model::SONNET, "sonnet");
733        assert_eq!(Model::OPUS, "opus");
734        assert_eq!(Model::HAIKU, "haiku");
735        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
736        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
737        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
738        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
739        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
740    }
741
742    // --- Agent::new() defaults via ConfigCapture ---
743
744    #[tokio::test]
745    async fn agent_new_default_values() {
746        let provider = ConfigCapture {
747            output: default_output(),
748        };
749        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
750
751        let config = result.value();
752        assert_eq!(config["system_prompt"], json!(null));
753        assert_eq!(config["prompt"], json!("hi"));
754        assert_eq!(config["model"], json!("sonnet"));
755        assert_eq!(config["allowed_tools"], json!([]));
756        assert_eq!(config["max_turns"], json!(null));
757        assert_eq!(config["max_budget_usd"], json!(null));
758        assert_eq!(config["working_dir"], json!(null));
759        assert_eq!(config["mcp_config"], json!(null));
760        assert_eq!(config["permission_mode"], json!("Default"));
761        assert_eq!(config["json_schema"], json!(null));
762    }
763
764    #[tokio::test]
765    async fn agent_default_matches_new() {
766        let provider = ConfigCapture {
767            output: default_output(),
768        };
769        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
770        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
771
772        assert_eq!(result_new.value(), result_default.value());
773    }
774
775    // --- Builder methods ---
776
777    #[tokio::test]
778    async fn builder_methods_store_values_correctly() {
779        let provider = ConfigCapture {
780            output: default_output(),
781        };
782        let result = Agent::new()
783            .system_prompt("you are a bot")
784            .prompt("do something")
785            .model(Model::OPUS)
786            .allowed_tools(&["Read", "Write"])
787            .max_turns(5)
788            .max_budget_usd(1.5)
789            .working_dir("/tmp")
790            .mcp_config("{}")
791            .permission_mode(PermissionMode::Auto)
792            .run(&provider)
793            .await
794            .unwrap();
795
796        let config = result.value();
797        assert_eq!(config["system_prompt"], json!("you are a bot"));
798        assert_eq!(config["prompt"], json!("do something"));
799        assert_eq!(config["model"], json!("opus"));
800        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
801        assert_eq!(config["max_turns"], json!(5));
802        assert_eq!(config["max_budget_usd"], json!(1.5));
803        assert_eq!(config["working_dir"], json!("/tmp"));
804        assert_eq!(config["mcp_config"], json!("{}"));
805        assert_eq!(config["permission_mode"], json!("Auto"));
806    }
807
808    // --- Panics ---
809
810    #[test]
811    #[should_panic(expected = "max_turns must be greater than 0")]
812    fn max_turns_zero_panics() {
813        let _ = Agent::new().max_turns(0);
814    }
815
816    #[test]
817    #[should_panic(expected = "budget must be a positive finite number")]
818    fn max_budget_negative_panics() {
819        let _ = Agent::new().max_budget_usd(-1.0);
820    }
821
822    #[test]
823    #[should_panic(expected = "budget must be a positive finite number")]
824    fn max_budget_nan_panics() {
825        let _ = Agent::new().max_budget_usd(f64::NAN);
826    }
827
828    #[test]
829    #[should_panic(expected = "budget must be a positive finite number")]
830    fn max_budget_infinity_panics() {
831        let _ = Agent::new().max_budget_usd(f64::INFINITY);
832    }
833
834    // --- AgentResult accessors ---
835
836    #[tokio::test]
837    async fn agent_result_text_with_string_value() {
838        let provider = TestProvider {
839            output: AgentOutput {
840                value: json!("hello world"),
841                ..default_output()
842            },
843        };
844        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
845        assert_eq!(result.text(), "hello world");
846    }
847
848    #[tokio::test]
849    async fn agent_result_text_with_non_string_value() {
850        let provider = TestProvider {
851            output: AgentOutput {
852                value: json!(42),
853                ..default_output()
854            },
855        };
856        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
857        assert_eq!(result.text(), "");
858    }
859
860    #[tokio::test]
861    async fn agent_result_text_with_null_value() {
862        let provider = TestProvider {
863            output: AgentOutput {
864                value: json!(null),
865                ..default_output()
866            },
867        };
868        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
869        assert_eq!(result.text(), "");
870    }
871
872    #[tokio::test]
873    async fn agent_result_json_successful_deserialize() {
874        #[derive(Deserialize, PartialEq, Debug)]
875        struct MyOutput {
876            name: String,
877            count: u32,
878        }
879        let provider = TestProvider {
880            output: AgentOutput {
881                value: json!({"name": "test", "count": 7}),
882                ..default_output()
883            },
884        };
885        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
886        let parsed: MyOutput = result.json().unwrap();
887        assert_eq!(parsed.name, "test");
888        assert_eq!(parsed.count, 7);
889    }
890
891    #[tokio::test]
892    async fn agent_result_json_failed_deserialize() {
893        #[derive(Debug, Deserialize)]
894        #[allow(dead_code)]
895        struct MyOutput {
896            name: String,
897        }
898        let provider = TestProvider {
899            output: AgentOutput {
900                value: json!(42),
901                ..default_output()
902            },
903        };
904        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
905        let err = result.json::<MyOutput>().unwrap_err();
906        assert!(matches!(err, OperationError::Deserialize { .. }));
907    }
908
909    #[tokio::test]
910    async fn agent_result_accessors() {
911        let provider = TestProvider {
912            output: AgentOutput {
913                value: json!("v"),
914                session_id: Some("s-1".to_string()),
915                cost_usd: Some(0.123),
916                input_tokens: Some(999),
917                output_tokens: Some(456),
918                model: Some("opus".to_string()),
919                duration_ms: 2000,
920            },
921        };
922        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
923        assert_eq!(result.session_id(), Some("s-1"));
924        assert_eq!(result.cost_usd(), Some(0.123));
925        assert_eq!(result.input_tokens(), Some(999));
926        assert_eq!(result.output_tokens(), Some(456));
927        assert_eq!(result.duration_ms(), 2000);
928        assert_eq!(result.model(), Some("opus"));
929    }
930
931    // --- Session resume ---
932
933    #[tokio::test]
934    async fn resume_passes_session_id_in_config() {
935        let provider = ConfigCapture {
936            output: default_output(),
937        };
938        let result = Agent::new()
939            .prompt("followup")
940            .resume("sess-abc")
941            .run(&provider)
942            .await
943            .unwrap();
944
945        let config = result.value();
946        assert_eq!(config["resume_session_id"], json!("sess-abc"));
947    }
948
949    #[tokio::test]
950    async fn no_resume_has_null_session_id() {
951        let provider = ConfigCapture {
952            output: default_output(),
953        };
954        let result = Agent::new()
955            .prompt("first call")
956            .run(&provider)
957            .await
958            .unwrap();
959
960        let config = result.value();
961        assert_eq!(config["resume_session_id"], json!(null));
962    }
963
964    #[test]
965    #[should_panic(expected = "session_id must not be empty")]
966    fn resume_empty_session_id_panics() {
967        let _ = Agent::new().resume("");
968    }
969
970    #[test]
971    #[should_panic(expected = "session_id must only contain")]
972    fn resume_invalid_chars_panics() {
973        let _ = Agent::new().resume("sess;rm -rf /");
974    }
975
976    #[test]
977    fn resume_valid_formats_accepted() {
978        let _ = Agent::new().resume("sess-abc123");
979        let _ = Agent::new().resume("a1b2c3d4_session");
980        let _ = Agent::new().resume("abc-DEF-123_456");
981    }
982
983    #[tokio::test]
984    #[should_panic(expected = "prompt must not be empty")]
985    async fn run_without_prompt_panics() {
986        let provider = TestProvider {
987            output: default_output(),
988        };
989        let _ = Agent::new().run(&provider).await;
990    }
991
992    #[tokio::test]
993    #[should_panic(expected = "prompt must not be empty")]
994    async fn run_with_whitespace_only_prompt_panics() {
995        let provider = TestProvider {
996            output: default_output(),
997        };
998        let _ = Agent::new().prompt("   ").run(&provider).await;
999    }
1000
1001    // --- Model accepts arbitrary strings ---
1002
1003    #[tokio::test]
1004    async fn model_accepts_custom_string() {
1005        let provider = ConfigCapture {
1006            output: default_output(),
1007        };
1008        let result = Agent::new()
1009            .prompt("hi")
1010            .model("mistral-large-latest")
1011            .run(&provider)
1012            .await
1013            .unwrap();
1014        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
1015    }
1016
1017    #[tokio::test]
1018    async fn model_accepts_owned_string() {
1019        let provider = ConfigCapture {
1020            output: default_output(),
1021        };
1022        let model_name = String::from("gpt-4o");
1023        let result = Agent::new()
1024            .prompt("hi")
1025            .model(model_name)
1026            .run(&provider)
1027            .await
1028            .unwrap();
1029        assert_eq!(result.value()["model"], json!("gpt-4o"));
1030    }
1031
1032    #[tokio::test]
1033    async fn into_json_success() {
1034        #[derive(Deserialize, PartialEq, Debug)]
1035        struct Out {
1036            name: String,
1037        }
1038        let provider = TestProvider {
1039            output: AgentOutput {
1040                value: json!({"name": "test"}),
1041                ..default_output()
1042            },
1043        };
1044        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1045        let parsed: Out = result.into_json().unwrap();
1046        assert_eq!(parsed.name, "test");
1047    }
1048
1049    #[tokio::test]
1050    async fn into_json_failure() {
1051        #[derive(Debug, Deserialize)]
1052        #[allow(dead_code)]
1053        struct Out {
1054            name: String,
1055        }
1056        let provider = TestProvider {
1057            output: AgentOutput {
1058                value: json!(42),
1059                ..default_output()
1060            },
1061        };
1062        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1063        let err = result.into_json::<Out>().unwrap_err();
1064        assert!(matches!(err, OperationError::Deserialize { .. }));
1065    }
1066
1067    #[test]
1068    fn from_output_creates_result() {
1069        let output = AgentOutput {
1070            value: json!("hello"),
1071            ..default_output()
1072        };
1073        let result = AgentResult::from_output(output);
1074        assert_eq!(result.text(), "hello");
1075        assert_eq!(result.cost_usd(), Some(0.05));
1076    }
1077
1078    #[test]
1079    #[should_panic(expected = "budget must be a positive finite number")]
1080    fn max_budget_zero_panics() {
1081        let _ = Agent::new().max_budget_usd(0.0);
1082    }
1083
1084    #[test]
1085    fn model_constant_equality() {
1086        assert_eq!(Model::SONNET, "sonnet");
1087        assert_ne!(Model::SONNET, Model::OPUS);
1088    }
1089
1090    #[test]
1091    fn permission_mode_serialize_deserialize_roundtrip() {
1092        for mode in [
1093            PermissionMode::Default,
1094            PermissionMode::Auto,
1095            PermissionMode::DontAsk,
1096            PermissionMode::BypassPermissions,
1097        ] {
1098            let json = to_string(&mode).unwrap();
1099            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1100            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1101        }
1102    }
1103
1104    // --- Retry builder ---
1105
1106    #[test]
1107    fn retry_builder_stores_policy() {
1108        let agent = Agent::new().retry(3);
1109        assert!(agent.retry_policy.is_some());
1110        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1111    }
1112
1113    #[test]
1114    fn retry_policy_builder_stores_custom_policy() {
1115        use crate::retry::RetryPolicy;
1116        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1117        let agent = Agent::new().retry_policy(policy);
1118        let p = agent.retry_policy.unwrap();
1119        assert_eq!(p.max_retries(), 5);
1120    }
1121
1122    #[test]
1123    fn no_retry_by_default() {
1124        let agent = Agent::new();
1125        assert!(agent.retry_policy.is_none());
1126    }
1127
1128    // --- Retry behavior ---
1129
1130    use std::sync::Arc;
1131    use std::sync::atomic::{AtomicU32, Ordering};
1132    use std::time::Duration;
1133
1134    struct FailNTimesProvider {
1135        fail_count: AtomicU32,
1136        failures_before_success: u32,
1137        output: AgentOutput,
1138    }
1139
1140    impl AgentProvider for FailNTimesProvider {
1141        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1142            Box::pin(async move {
1143                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1144                if current < self.failures_before_success {
1145                    Err(AgentError::ProcessFailed {
1146                        exit_code: 1,
1147                        stderr: format!("transient failure #{}", current + 1),
1148                    })
1149                } else {
1150                    Ok(AgentOutput {
1151                        value: self.output.value.clone(),
1152                        session_id: self.output.session_id.clone(),
1153                        cost_usd: self.output.cost_usd,
1154                        input_tokens: self.output.input_tokens,
1155                        output_tokens: self.output.output_tokens,
1156                        model: self.output.model.clone(),
1157                        duration_ms: self.output.duration_ms,
1158                    })
1159                }
1160            })
1161        }
1162    }
1163
1164    #[tokio::test]
1165    async fn retry_succeeds_after_transient_failures() {
1166        let provider = FailNTimesProvider {
1167            fail_count: AtomicU32::new(0),
1168            failures_before_success: 2,
1169            output: default_output(),
1170        };
1171        let result = Agent::new()
1172            .prompt("test")
1173            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1174            .run(&provider)
1175            .await;
1176
1177        assert!(result.is_ok());
1178        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1179    }
1180
1181    #[tokio::test]
1182    async fn retry_exhausted_returns_last_error() {
1183        let provider = FailNTimesProvider {
1184            fail_count: AtomicU32::new(0),
1185            failures_before_success: 10, // always fails
1186            output: default_output(),
1187        };
1188        let result = Agent::new()
1189            .prompt("test")
1190            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1191            .run(&provider)
1192            .await;
1193
1194        assert!(result.is_err());
1195        // 1 initial + 2 retries = 3 total
1196        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1197    }
1198
1199    #[tokio::test]
1200    async fn retry_does_not_retry_non_retryable_errors() {
1201        let call_count = Arc::new(AtomicU32::new(0));
1202        let count = call_count.clone();
1203
1204        struct CountingNonRetryable {
1205            count: Arc<AtomicU32>,
1206        }
1207        impl AgentProvider for CountingNonRetryable {
1208            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1209                self.count.fetch_add(1, Ordering::SeqCst);
1210                Box::pin(async move {
1211                    Err(AgentError::SchemaValidation {
1212                        expected: "object".to_string(),
1213                        got: "string".to_string(),
1214                    })
1215                })
1216            }
1217        }
1218
1219        let provider = CountingNonRetryable { count };
1220        let result = Agent::new()
1221            .prompt("test")
1222            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1223            .run(&provider)
1224            .await;
1225
1226        assert!(result.is_err());
1227        // Only 1 attempt, no retries for non-retryable errors
1228        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1229    }
1230
1231    #[tokio::test]
1232    async fn no_retry_without_policy() {
1233        let provider = FailNTimesProvider {
1234            fail_count: AtomicU32::new(0),
1235            failures_before_success: 1,
1236            output: default_output(),
1237        };
1238        let result = Agent::new().prompt("test").run(&provider).await;
1239
1240        assert!(result.is_err());
1241        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1242    }
1243}