Skip to main content

ironflow_core/operations/
agent.rs

1//! Agent operation - build and execute AI agent calls.
2//!
3//! The [`Agent`] builder lets you configure a single agent invocation (model,
4//! prompt, tools, budget, permissions, etc.) and execute it through any
5//! [`AgentProvider`]. The result is an [`AgentResult`] that provides typed
6//! access to the agent's response, session metadata, and usage statistics.
7//!
8//! # Examples
9//!
10//! ```no_run
11//! use ironflow_core::prelude::*;
12//!
13//! # async fn example() -> Result<(), OperationError> {
14//! let provider = ClaudeCodeProvider::new();
15//!
16//! let result = Agent::new()
17//!     .prompt("Summarize the README.md file")
18//!     .model(Model::Sonnet)
19//!     .max_turns(3)
20//!     .run(&provider)
21//!     .await?;
22//!
23//! println!("{}", result.text());
24//! # Ok(())
25//! # }
26//! ```
27
28use std::fmt::{self, Display};
29
30use schemars::{JsonSchema, schema_for};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, from_value, to_string};
34use tracing::{info, warn};
35
36use crate::error::{AgentError, OperationError};
37#[cfg(feature = "prometheus")]
38use crate::metric_names;
39use crate::provider::{AgentConfig, AgentOutput, AgentProvider};
40use crate::retry::RetryPolicy;
41use crate::utils::estimate_tokens;
42
43/// Available Claude models.
44///
45/// Each variant maps to a model name passed to the Claude CLI.
46/// Use the alias variants (`Sonnet`, `Opus`, `Haiku`) for the latest version,
47/// or the explicit variants for a specific model ID and context window.
48#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
49#[serde(rename_all = "lowercase")]
50pub enum Model {
51    // ── Aliases (latest version, CLI resolves to current) ───────────
52    /// Claude Sonnet - balanced speed and capability (default, 200K context).
53    Sonnet,
54    /// Claude Opus - highest capability (200K context).
55    Opus,
56    /// Claude Haiku - fastest and cheapest (200K context).
57    Haiku,
58
59    // ── Claude 4.5 ─────────────────────────────────────────────────
60    /// Claude Haiku 4.5 - 200K context.
61    #[serde(rename = "claude-haiku-4-5-20251001")]
62    Haiku45,
63
64    // ── Claude 4.6 - 200K context ──────────────────────────────────
65    /// Claude Sonnet 4.6 - 200K context.
66    #[serde(rename = "claude-sonnet-4-6")]
67    Sonnet46,
68    /// Claude Opus 4.6 - 200K context.
69    #[serde(rename = "claude-opus-4-6")]
70    Opus46,
71
72    // ── Claude 4.6 - 1M context ────────────────────────────────────
73    /// Claude Sonnet 4.6 with 1M token context window.
74    #[serde(rename = "claude-sonnet-4-6[1m]")]
75    Sonnet46_1M,
76    /// Claude Opus 4.6 with 1M token context window.
77    #[serde(rename = "claude-opus-4-6[1m]")]
78    Opus46_1M,
79}
80
81impl Model {
82    /// Maximum context window size in tokens.
83    pub fn context_window(self) -> usize {
84        match self {
85            Self::Sonnet46_1M | Self::Opus46_1M => 1_000_000,
86            _ => 200_000,
87        }
88    }
89}
90
91impl Display for Model {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match serde_json::to_value(self) {
94            Ok(Value::String(s)) => f.write_str(&s),
95            _ => f.write_str("unknown"),
96        }
97    }
98}
99
100/// Controls how the agent handles tool-use permission prompts.
101///
102/// These map to the `--permission-mode` and `--dangerously-skip-permissions`
103/// flags in the Claude CLI.
104#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
105pub enum PermissionMode {
106    /// Use the CLI default permission behavior.
107    Default,
108    /// Automatically approve tool-use requests.
109    Auto,
110    /// Suppress all permission prompts (the agent proceeds without asking).
111    DontAsk,
112    /// Skip all permission checks entirely.
113    ///
114    /// **Warning**: the agent will have unrestricted filesystem and shell access.
115    BypassPermissions,
116}
117
118/// Builder for a single agent invocation.
119///
120/// Create with [`Agent::new`], chain configuration methods, then call
121/// [`run`](Agent::run) with an [`AgentProvider`] to execute.
122///
123/// # Examples
124///
125/// ```no_run
126/// use ironflow_core::prelude::*;
127///
128/// # async fn example() -> Result<(), OperationError> {
129/// let provider = ClaudeCodeProvider::new();
130///
131/// let result = Agent::new()
132///     .system_prompt("You are a Rust expert.")
133///     .prompt("Review this code for safety issues.")
134///     .model(Model::Opus)
135///     .allowed_tools(&["Read", "Grep"])
136///     .max_turns(5)
137///     .max_budget_usd(0.50)
138///     .working_dir("/tmp/project")
139///     .permission_mode(PermissionMode::Auto)
140///     .run(&provider)
141///     .await?;
142///
143/// println!("Cost: ${:.4}", result.cost_usd().unwrap_or(0.0));
144/// # Ok(())
145/// # }
146/// ```
147#[must_use = "an Agent does nothing until .run() is awaited"]
148pub struct Agent {
149    config: AgentConfig,
150    dry_run: Option<bool>,
151    retry_policy: Option<RetryPolicy>,
152}
153
154impl Agent {
155    /// Create a new agent builder with default settings.
156    ///
157    /// Defaults: [`Model::Sonnet`], no system prompt, no tool restrictions,
158    /// no budget/turn limits, [`PermissionMode::Default`].
159    pub fn new() -> Self {
160        Self {
161            config: AgentConfig::new(""),
162            dry_run: None,
163            retry_policy: None,
164        }
165    }
166
167    /// Set the system prompt that defines the agent's persona or constraints.
168    pub fn system_prompt(mut self, prompt: &str) -> Self {
169        self.config.system_prompt = Some(prompt.to_string());
170        self
171    }
172
173    /// Set the user prompt - the main instruction sent to the agent.
174    pub fn prompt(mut self, prompt: &str) -> Self {
175        self.config.prompt = prompt.to_string();
176        self
177    }
178
179    /// Set the model tier to use for this invocation.
180    ///
181    /// Defaults to [`Model::Sonnet`] if not called.
182    pub fn model(mut self, model: Model) -> Self {
183        self.config.model = model;
184        self
185    }
186
187    /// Restrict which tools the agent may invoke.
188    ///
189    /// Pass an empty slice (or do not call this method) to allow the provider
190    /// default set of tools.
191    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
192        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
193        self
194    }
195
196    /// Set the maximum number of agentic turns.
197    ///
198    /// # Panics
199    ///
200    /// Panics if `turns` is `0`.
201    pub fn max_turns(mut self, turns: u32) -> Self {
202        assert!(turns > 0, "max_turns must be greater than 0");
203        self.config.max_turns = Some(turns);
204        self
205    }
206
207    /// Set the maximum spend in USD for this invocation.
208    ///
209    /// # Panics
210    ///
211    /// Panics if `budget` is negative, NaN, or infinity.
212    pub fn max_budget_usd(mut self, budget: f64) -> Self {
213        assert!(
214            budget.is_finite() && budget > 0.0,
215            "budget must be a positive finite number, got {budget}"
216        );
217        self.config.max_budget_usd = Some(budget);
218        self
219    }
220
221    /// Set the working directory for the agent process.
222    pub fn working_dir(mut self, dir: &str) -> Self {
223        self.config.working_dir = Some(dir.to_string());
224        self
225    }
226
227    /// Set the path to an MCP (Model Context Protocol) server configuration file.
228    pub fn mcp_config(mut self, config: &str) -> Self {
229        self.config.mcp_config = Some(config.to_string());
230        self
231    }
232
233    /// Set the permission mode controlling tool-use approval behavior.
234    ///
235    /// See [`PermissionMode`] for details on each variant.
236    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
237        self.config.permission_mode = mode;
238        self
239    }
240
241    /// Request structured (typed) output from the agent.
242    ///
243    /// The type `T` must implement [`JsonSchema`]. The generated schema is sent
244    /// to the provider so the model returns JSON conforming to `T`, which can
245    /// then be deserialized with [`AgentResult::json`].
246    ///
247    /// # Examples
248    ///
249    /// ```no_run
250    /// use ironflow_core::prelude::*;
251    ///
252    /// #[derive(Deserialize, JsonSchema)]
253    /// struct Review {
254    ///     score: u8,
255    ///     summary: String,
256    /// }
257    ///
258    /// # async fn example() -> Result<(), OperationError> {
259    /// let provider = ClaudeCodeProvider::new();
260    /// let result = Agent::new()
261    ///     .prompt("Review the codebase")
262    ///     .output::<Review>()
263    ///     .run(&provider)
264    ///     .await?;
265    ///
266    /// let review: Review = result.json().expect("schema-validated output");
267    /// println!("Score: {}/10 - {}", review.score, review.summary);
268    /// # Ok(())
269    /// # }
270    /// ```
271    pub fn output<T: JsonSchema>(mut self) -> Self {
272        let schema = schema_for!(T);
273        self.config.json_schema = match to_string(&schema) {
274            Ok(s) => Some(s),
275            Err(e) => {
276                warn!(error = %e, type_name = std::any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
277                None
278            }
279        };
280        self
281    }
282
283    /// Retry the agent invocation up to `max_retries` times on transient failures.
284    ///
285    /// Uses default exponential backoff settings (200ms initial, 2x multiplier,
286    /// 30s cap). For custom backoff parameters, use [`retry_policy`](Agent::retry_policy).
287    ///
288    /// Only transient errors are retried: process failures and timeouts.
289    /// Deterministic errors (prompt too large, schema validation) are never retried.
290    ///
291    /// # Panics
292    ///
293    /// Panics if `max_retries` is `0`.
294    ///
295    /// # Examples
296    ///
297    /// ```no_run
298    /// use ironflow_core::prelude::*;
299    ///
300    /// # async fn example() -> Result<(), OperationError> {
301    /// let provider = ClaudeCodeProvider::new();
302    /// let result = Agent::new()
303    ///     .prompt("Summarize the codebase")
304    ///     .retry(2)
305    ///     .run(&provider)
306    ///     .await?;
307    /// # Ok(())
308    /// # }
309    /// ```
310    pub fn retry(mut self, max_retries: u32) -> Self {
311        self.retry_policy = Some(RetryPolicy::new(max_retries));
312        self
313    }
314
315    /// Set a custom [`RetryPolicy`] for this agent invocation.
316    ///
317    /// Allows full control over backoff duration, multiplier, and max delay.
318    /// See [`RetryPolicy`] for details.
319    ///
320    /// # Examples
321    ///
322    /// ```no_run
323    /// use std::time::Duration;
324    /// use ironflow_core::prelude::*;
325    /// use ironflow_core::retry::RetryPolicy;
326    ///
327    /// # async fn example() -> Result<(), OperationError> {
328    /// let provider = ClaudeCodeProvider::new();
329    /// let result = Agent::new()
330    ///     .prompt("Analyze the code")
331    ///     .retry_policy(
332    ///         RetryPolicy::new(3)
333    ///             .backoff(Duration::from_secs(1))
334    ///             .max_backoff(Duration::from_secs(60))
335    ///     )
336    ///     .run(&provider)
337    ///     .await?;
338    /// # Ok(())
339    /// # }
340    /// ```
341    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
342        self.retry_policy = Some(policy);
343        self
344    }
345
346    /// Enable or disable dry-run mode for this specific operation.
347    ///
348    /// When dry-run is active, the agent call is logged but not executed.
349    /// A synthetic [`AgentResult`] is returned with a placeholder text,
350    /// zero cost, and zero tokens.
351    ///
352    /// If not set, falls back to the global dry-run setting
353    /// (see [`set_dry_run`](crate::dry_run::set_dry_run)).
354    pub fn dry_run(mut self, enabled: bool) -> Self {
355        self.dry_run = Some(enabled);
356        self
357    }
358
359    /// Resume a previous agent conversation by session ID.
360    ///
361    /// Pass the session ID from a previous [`AgentResult::session_id()`] to
362    /// continue the multi-turn conversation.
363    ///
364    /// # Examples
365    ///
366    /// ```no_run
367    /// use ironflow_core::prelude::*;
368    ///
369    /// # async fn example() -> Result<(), OperationError> {
370    /// let provider = ClaudeCodeProvider::new();
371    ///
372    /// let first = Agent::new()
373    ///     .prompt("Analyze the src/ directory")
374    ///     .max_budget_usd(0.10)
375    ///     .run(&provider)
376    ///     .await?;
377    ///
378    /// let session = first.session_id().expect("provider returned session ID");
379    ///
380    /// let followup = Agent::new()
381    ///     .prompt("Now suggest improvements")
382    ///     .resume(session)
383    ///     .max_budget_usd(0.10)
384    ///     .run(&provider)
385    ///     .await?;
386    /// # Ok(())
387    /// # }
388    /// ```
389    ///
390    /// # Panics
391    ///
392    /// Panics if `session_id` is empty or contains characters other than
393    /// alphanumerics, hyphens, and underscores.
394    pub fn resume(mut self, session_id: &str) -> Self {
395        assert!(!session_id.is_empty(), "session_id must not be empty");
396        assert!(
397            session_id
398                .chars()
399                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
400            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
401        );
402        self.config.resume_session_id = Some(session_id.to_string());
403        self
404    }
405
406    /// Execute the agent invocation using the given [`AgentProvider`].
407    ///
408    /// If a [`retry_policy`](Agent::retry_policy) is configured, transient
409    /// failures (process crashes, timeouts) are retried with exponential
410    /// backoff. Deterministic errors (prompt too large, schema validation)
411    /// are returned immediately without retry.
412    ///
413    /// # Errors
414    ///
415    /// Returns [`OperationError::Agent`] if the provider reports a failure
416    /// (process crash, timeout, or schema validation error).
417    ///
418    /// # Panics
419    ///
420    /// Panics if [`prompt`](Agent::prompt) was never called or the prompt is
421    /// empty (whitespace-only counts as empty).
422    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
423    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
424        assert!(
425            !self.config.prompt.trim().is_empty(),
426            "prompt must not be empty - call .prompt(\"...\") before .run()"
427        );
428
429        // Validate prompt size against the model's context window
430        let total_chars =
431            self.config.prompt.len() + self.config.system_prompt.as_ref().map_or(0, |s| s.len());
432        let estimated_tokens = estimate_tokens(total_chars);
433        let model_limit = self.config.model.context_window();
434        if estimated_tokens > model_limit {
435            return Err(OperationError::Agent(AgentError::PromptTooLarge {
436                chars: total_chars,
437                estimated_tokens,
438                model_limit,
439            }));
440        }
441
442        if crate::dry_run::effective_dry_run(self.dry_run) {
443            info!(
444                prompt_len = self.config.prompt.len(),
445                "[dry-run] agent call skipped"
446            );
447            let mut output =
448                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
449            output.cost_usd = Some(0.0);
450            output.input_tokens = Some(0);
451            output.output_tokens = Some(0);
452            return Ok(AgentResult { output });
453        }
454
455        let result = self.invoke_once(provider).await;
456
457        let policy = match &self.retry_policy {
458            Some(p) => p,
459            None => return result,
460        };
461
462        // Non-retryable errors are returned immediately.
463        if let Err(ref err) = result {
464            if !crate::retry::is_retryable(err) {
465                return result;
466            }
467        } else {
468            return result;
469        }
470
471        let mut last_result = result;
472
473        for attempt in 0..policy.max_retries {
474            let delay = policy.delay_for_attempt(attempt);
475            warn!(
476                attempt = attempt + 1,
477                max_retries = policy.max_retries,
478                delay_ms = delay.as_millis() as u64,
479                "retrying agent invocation"
480            );
481            tokio::time::sleep(delay).await;
482
483            last_result = self.invoke_once(provider).await;
484
485            match &last_result {
486                Ok(_) => return last_result,
487                Err(err) if !crate::retry::is_retryable(err) => return last_result,
488                _ => {}
489            }
490        }
491
492        last_result
493    }
494
495    /// Execute a single agent invocation attempt (no retry logic).
496    async fn invoke_once(
497        &self,
498        provider: &dyn AgentProvider,
499    ) -> Result<AgentResult, OperationError> {
500        #[cfg(feature = "prometheus")]
501        let model_label = self.config.model.to_string();
502
503        let output = match provider.invoke(&self.config).await {
504            Ok(output) => output,
505            Err(e) => {
506                #[cfg(feature = "prometheus")]
507                {
508                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_ERROR).increment(1);
509                }
510                return Err(OperationError::Agent(e));
511            }
512        };
513
514        info!(
515            duration_ms = output.duration_ms,
516            cost_usd = output.cost_usd,
517            input_tokens = output.input_tokens,
518            output_tokens = output.output_tokens,
519            model = output.model,
520            "agent completed"
521        );
522
523        #[cfg(feature = "prometheus")]
524        {
525            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
526            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
527                .record(output.duration_ms as f64 / 1000.0);
528            if let Some(cost) = output.cost_usd {
529                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
530                    .increment(cost);
531            }
532            if let Some(tokens) = output.input_tokens {
533                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
534            }
535            if let Some(tokens) = output.output_tokens {
536                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
537                    .increment(tokens);
538            }
539        }
540
541        Ok(AgentResult { output })
542    }
543}
544
545impl Default for Agent {
546    fn default() -> Self {
547        Self::new()
548    }
549}
550
551/// The result of a successful agent invocation.
552///
553/// Wraps the raw [`AgentOutput`] and provides convenience accessors for the
554/// response text, typed JSON deserialization, session metadata, and usage stats.
555#[derive(Debug)]
556pub struct AgentResult {
557    output: AgentOutput,
558}
559
560impl AgentResult {
561    /// Return the agent's response as a plain text string.
562    ///
563    /// If the underlying value is not a JSON string (e.g. when structured output
564    /// was requested), returns an empty string and logs a warning.
565    pub fn text(&self) -> &str {
566        match self.output.value.as_str() {
567            Some(s) => s,
568            None => {
569                warn!(
570                    value_type = self.output.value.to_string(),
571                    "agent output is not a string, returning empty"
572                );
573                ""
574            }
575        }
576    }
577
578    /// Return the raw JSON [`Value`] of the agent's response.
579    pub fn value(&self) -> &Value {
580        &self.output.value
581    }
582
583    /// Deserialize the agent's response into the given type `T`.
584    ///
585    /// This clones the underlying JSON value. If you no longer need the
586    /// `AgentResult` afterwards, use [`into_json`](AgentResult::into_json)
587    /// instead to avoid the clone.
588    ///
589    /// # Errors
590    ///
591    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
592    pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
593        from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
594    }
595
596    /// Consume the result and deserialize the response into `T` without cloning.
597    ///
598    /// # Errors
599    ///
600    /// Returns [`OperationError::Deserialize`] if the JSON value does not match `T`.
601    pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
602        from_value(self.output.value).map_err(OperationError::deserialize::<T>)
603    }
604
605    /// Build an `AgentResult` from a raw [`AgentOutput`].
606    ///
607    /// This is available only in test builds to simplify test setup without
608    /// going through the full record/replay pipeline.
609    #[cfg(test)]
610    pub(crate) fn from_output(output: AgentOutput) -> Self {
611        Self { output }
612    }
613
614    /// Return the provider-assigned session ID, if available.
615    pub fn session_id(&self) -> Option<&str> {
616        self.output.session_id.as_deref()
617    }
618
619    /// Return the cost of this invocation in USD, if reported by the provider.
620    pub fn cost_usd(&self) -> Option<f64> {
621        self.output.cost_usd
622    }
623
624    /// Return the number of input tokens consumed, if reported.
625    pub fn input_tokens(&self) -> Option<u64> {
626        self.output.input_tokens
627    }
628
629    /// Return the number of output tokens generated, if reported.
630    pub fn output_tokens(&self) -> Option<u64> {
631        self.output.output_tokens
632    }
633
634    /// Return the wall-clock duration of the invocation in milliseconds.
635    pub fn duration_ms(&self) -> u64 {
636        self.output.duration_ms
637    }
638
639    /// Return the concrete model identifier used, if reported by the provider.
640    pub fn model(&self) -> Option<&str> {
641        self.output.model.as_deref()
642    }
643}
644
645#[cfg(test)]
646mod tests {
647    use super::*;
648    use crate::provider::InvokeFuture;
649    use serde_json::json;
650
651    struct TestProvider {
652        output: AgentOutput,
653    }
654
655    impl AgentProvider for TestProvider {
656        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
657            Box::pin(async move {
658                Ok(AgentOutput {
659                    value: self.output.value.clone(),
660                    session_id: self.output.session_id.clone(),
661                    cost_usd: self.output.cost_usd,
662                    input_tokens: self.output.input_tokens,
663                    output_tokens: self.output.output_tokens,
664                    model: self.output.model.clone(),
665                    duration_ms: self.output.duration_ms,
666                })
667            })
668        }
669    }
670
671    struct ConfigCapture {
672        output: AgentOutput,
673    }
674
675    impl AgentProvider for ConfigCapture {
676        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
677            let config_json = serde_json::to_value(config).unwrap();
678            Box::pin(async move {
679                Ok(AgentOutput {
680                    value: config_json,
681                    session_id: self.output.session_id.clone(),
682                    cost_usd: self.output.cost_usd,
683                    input_tokens: self.output.input_tokens,
684                    output_tokens: self.output.output_tokens,
685                    model: self.output.model.clone(),
686                    duration_ms: self.output.duration_ms,
687                })
688            })
689        }
690    }
691
692    fn default_output() -> AgentOutput {
693        AgentOutput {
694            value: json!("test output"),
695            session_id: Some("sess-123".to_string()),
696            cost_usd: Some(0.05),
697            input_tokens: Some(100),
698            output_tokens: Some(50),
699            model: Some("sonnet".to_string()),
700            duration_ms: 1500,
701        }
702    }
703
704    // --- Model Display ---
705
706    #[test]
707    fn model_display_sonnet() {
708        assert_eq!(Model::Sonnet.to_string(), "sonnet");
709    }
710
711    #[test]
712    fn model_display_opus() {
713        assert_eq!(Model::Opus.to_string(), "opus");
714    }
715
716    #[test]
717    fn model_display_haiku() {
718        assert_eq!(Model::Haiku.to_string(), "haiku");
719    }
720
721    // --- Agent::new() defaults via ConfigCapture ---
722
723    #[tokio::test]
724    async fn agent_new_default_values() {
725        let provider = ConfigCapture {
726            output: default_output(),
727        };
728        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
729
730        let config = result.value();
731        assert_eq!(config["system_prompt"], json!(null));
732        assert_eq!(config["prompt"], json!("hi"));
733        assert_eq!(config["model"], json!("sonnet"));
734        assert_eq!(config["allowed_tools"], json!([]));
735        assert_eq!(config["max_turns"], json!(null));
736        assert_eq!(config["max_budget_usd"], json!(null));
737        assert_eq!(config["working_dir"], json!(null));
738        assert_eq!(config["mcp_config"], json!(null));
739        assert_eq!(config["permission_mode"], json!("Default"));
740        assert_eq!(config["json_schema"], json!(null));
741    }
742
743    #[tokio::test]
744    async fn agent_default_matches_new() {
745        let provider = ConfigCapture {
746            output: default_output(),
747        };
748        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
749        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
750
751        assert_eq!(result_new.value(), result_default.value());
752    }
753
754    // --- Builder methods ---
755
756    #[tokio::test]
757    async fn builder_methods_store_values_correctly() {
758        let provider = ConfigCapture {
759            output: default_output(),
760        };
761        let result = Agent::new()
762            .system_prompt("you are a bot")
763            .prompt("do something")
764            .model(Model::Opus)
765            .allowed_tools(&["Read", "Write"])
766            .max_turns(5)
767            .max_budget_usd(1.5)
768            .working_dir("/tmp")
769            .mcp_config("{}")
770            .permission_mode(PermissionMode::Auto)
771            .run(&provider)
772            .await
773            .unwrap();
774
775        let config = result.value();
776        assert_eq!(config["system_prompt"], json!("you are a bot"));
777        assert_eq!(config["prompt"], json!("do something"));
778        assert_eq!(config["model"], json!("opus"));
779        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
780        assert_eq!(config["max_turns"], json!(5));
781        assert_eq!(config["max_budget_usd"], json!(1.5));
782        assert_eq!(config["working_dir"], json!("/tmp"));
783        assert_eq!(config["mcp_config"], json!("{}"));
784        assert_eq!(config["permission_mode"], json!("Auto"));
785    }
786
787    // --- Panics ---
788
789    #[test]
790    #[should_panic(expected = "max_turns must be greater than 0")]
791    fn max_turns_zero_panics() {
792        let _ = Agent::new().max_turns(0);
793    }
794
795    #[test]
796    #[should_panic(expected = "budget must be a positive finite number")]
797    fn max_budget_negative_panics() {
798        let _ = Agent::new().max_budget_usd(-1.0);
799    }
800
801    #[test]
802    #[should_panic(expected = "budget must be a positive finite number")]
803    fn max_budget_nan_panics() {
804        let _ = Agent::new().max_budget_usd(f64::NAN);
805    }
806
807    #[test]
808    #[should_panic(expected = "budget must be a positive finite number")]
809    fn max_budget_infinity_panics() {
810        let _ = Agent::new().max_budget_usd(f64::INFINITY);
811    }
812
813    // --- AgentResult accessors ---
814
815    #[tokio::test]
816    async fn agent_result_text_with_string_value() {
817        let provider = TestProvider {
818            output: AgentOutput {
819                value: json!("hello world"),
820                ..default_output()
821            },
822        };
823        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
824        assert_eq!(result.text(), "hello world");
825    }
826
827    #[tokio::test]
828    async fn agent_result_text_with_non_string_value() {
829        let provider = TestProvider {
830            output: AgentOutput {
831                value: json!(42),
832                ..default_output()
833            },
834        };
835        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
836        assert_eq!(result.text(), "");
837    }
838
839    #[tokio::test]
840    async fn agent_result_text_with_null_value() {
841        let provider = TestProvider {
842            output: AgentOutput {
843                value: json!(null),
844                ..default_output()
845            },
846        };
847        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
848        assert_eq!(result.text(), "");
849    }
850
851    #[tokio::test]
852    async fn agent_result_json_successful_deserialize() {
853        #[derive(Deserialize, PartialEq, Debug)]
854        struct MyOutput {
855            name: String,
856            count: u32,
857        }
858        let provider = TestProvider {
859            output: AgentOutput {
860                value: json!({"name": "test", "count": 7}),
861                ..default_output()
862            },
863        };
864        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
865        let parsed: MyOutput = result.json().unwrap();
866        assert_eq!(parsed.name, "test");
867        assert_eq!(parsed.count, 7);
868    }
869
870    #[tokio::test]
871    async fn agent_result_json_failed_deserialize() {
872        #[derive(Debug, Deserialize)]
873        #[allow(dead_code)]
874        struct MyOutput {
875            name: String,
876        }
877        let provider = TestProvider {
878            output: AgentOutput {
879                value: json!(42),
880                ..default_output()
881            },
882        };
883        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
884        let err = result.json::<MyOutput>().unwrap_err();
885        assert!(matches!(err, OperationError::Deserialize { .. }));
886    }
887
888    #[tokio::test]
889    async fn agent_result_accessors() {
890        let provider = TestProvider {
891            output: AgentOutput {
892                value: json!("v"),
893                session_id: Some("s-1".to_string()),
894                cost_usd: Some(0.123),
895                input_tokens: Some(999),
896                output_tokens: Some(456),
897                model: Some("opus".to_string()),
898                duration_ms: 2000,
899            },
900        };
901        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
902        assert_eq!(result.session_id(), Some("s-1"));
903        assert_eq!(result.cost_usd(), Some(0.123));
904        assert_eq!(result.input_tokens(), Some(999));
905        assert_eq!(result.output_tokens(), Some(456));
906        assert_eq!(result.duration_ms(), 2000);
907        assert_eq!(result.model(), Some("opus"));
908    }
909
910    // --- Session resume ---
911
912    #[tokio::test]
913    async fn resume_passes_session_id_in_config() {
914        let provider = ConfigCapture {
915            output: default_output(),
916        };
917        let result = Agent::new()
918            .prompt("followup")
919            .resume("sess-abc")
920            .run(&provider)
921            .await
922            .unwrap();
923
924        let config = result.value();
925        assert_eq!(config["resume_session_id"], json!("sess-abc"));
926    }
927
928    #[tokio::test]
929    async fn no_resume_has_null_session_id() {
930        let provider = ConfigCapture {
931            output: default_output(),
932        };
933        let result = Agent::new()
934            .prompt("first call")
935            .run(&provider)
936            .await
937            .unwrap();
938
939        let config = result.value();
940        assert_eq!(config["resume_session_id"], json!(null));
941    }
942
943    #[test]
944    #[should_panic(expected = "session_id must not be empty")]
945    fn resume_empty_session_id_panics() {
946        let _ = Agent::new().resume("");
947    }
948
949    #[test]
950    #[should_panic(expected = "session_id must only contain")]
951    fn resume_invalid_chars_panics() {
952        let _ = Agent::new().resume("sess;rm -rf /");
953    }
954
955    #[test]
956    fn resume_valid_formats_accepted() {
957        let _ = Agent::new().resume("sess-abc123");
958        let _ = Agent::new().resume("a1b2c3d4_session");
959        let _ = Agent::new().resume("abc-DEF-123_456");
960    }
961
962    #[tokio::test]
963    #[should_panic(expected = "prompt must not be empty")]
964    async fn run_without_prompt_panics() {
965        let provider = TestProvider {
966            output: default_output(),
967        };
968        let _ = Agent::new().run(&provider).await;
969    }
970
971    #[tokio::test]
972    #[should_panic(expected = "prompt must not be empty")]
973    async fn run_with_whitespace_only_prompt_panics() {
974        let provider = TestProvider {
975            output: default_output(),
976        };
977        let _ = Agent::new().prompt("   ").run(&provider).await;
978    }
979
980    // --- Serde roundtrips ---
981
982    #[test]
983    fn model_serialize_deserialize_roundtrip() {
984        for model in [
985            Model::Sonnet,
986            Model::Opus,
987            Model::Haiku,
988            Model::Haiku45,
989            Model::Sonnet46,
990            Model::Opus46,
991            Model::Sonnet46_1M,
992            Model::Opus46_1M,
993        ] {
994            let json = serde_json::to_string(&model).unwrap();
995            let back: Model = serde_json::from_str(&json).unwrap();
996            assert_eq!(model.to_string(), back.to_string());
997        }
998    }
999
1000    #[test]
1001    fn model_display_explicit_ids() {
1002        assert_eq!(Model::Haiku45.to_string(), "claude-haiku-4-5-20251001");
1003        assert_eq!(Model::Sonnet46.to_string(), "claude-sonnet-4-6");
1004        assert_eq!(Model::Opus46.to_string(), "claude-opus-4-6");
1005        assert_eq!(Model::Sonnet46_1M.to_string(), "claude-sonnet-4-6[1m]");
1006        assert_eq!(Model::Opus46_1M.to_string(), "claude-opus-4-6[1m]");
1007    }
1008
1009    #[test]
1010    fn model_context_window() {
1011        assert_eq!(Model::Sonnet.context_window(), 200_000);
1012        assert_eq!(Model::Opus.context_window(), 200_000);
1013        assert_eq!(Model::Haiku.context_window(), 200_000);
1014        assert_eq!(Model::Sonnet46_1M.context_window(), 1_000_000);
1015        assert_eq!(Model::Opus46_1M.context_window(), 1_000_000);
1016    }
1017
1018    #[tokio::test]
1019    async fn into_json_success() {
1020        #[derive(Deserialize, PartialEq, Debug)]
1021        struct Out {
1022            name: String,
1023        }
1024        let provider = TestProvider {
1025            output: AgentOutput {
1026                value: json!({"name": "test"}),
1027                ..default_output()
1028            },
1029        };
1030        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1031        let parsed: Out = result.into_json().unwrap();
1032        assert_eq!(parsed.name, "test");
1033    }
1034
1035    #[tokio::test]
1036    async fn into_json_failure() {
1037        #[derive(Debug, Deserialize)]
1038        #[allow(dead_code)]
1039        struct Out {
1040            name: String,
1041        }
1042        let provider = TestProvider {
1043            output: AgentOutput {
1044                value: json!(42),
1045                ..default_output()
1046            },
1047        };
1048        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
1049        let err = result.into_json::<Out>().unwrap_err();
1050        assert!(matches!(err, OperationError::Deserialize { .. }));
1051    }
1052
1053    #[test]
1054    fn from_output_creates_result() {
1055        let output = AgentOutput {
1056            value: json!("hello"),
1057            ..default_output()
1058        };
1059        let result = AgentResult::from_output(output);
1060        assert_eq!(result.text(), "hello");
1061        assert_eq!(result.cost_usd(), Some(0.05));
1062    }
1063
1064    #[test]
1065    #[should_panic(expected = "budget must be a positive finite number")]
1066    fn max_budget_zero_panics() {
1067        let _ = Agent::new().max_budget_usd(0.0);
1068    }
1069
1070    #[tokio::test]
1071    async fn run_with_prompt_too_large_returns_error() {
1072        let provider = TestProvider {
1073            output: default_output(),
1074        };
1075        // 200K context = 800K chars max. Create a prompt exceeding that.
1076        let large_prompt = "x".repeat(800_004);
1077        let result = Agent::new()
1078            .prompt(&large_prompt)
1079            .model(Model::Sonnet)
1080            .run(&provider)
1081            .await;
1082
1083        let err = result.unwrap_err();
1084        match err {
1085            OperationError::Agent(crate::error::AgentError::PromptTooLarge {
1086                chars,
1087                estimated_tokens,
1088                model_limit,
1089            }) => {
1090                assert_eq!(chars, 800_004);
1091                assert_eq!(estimated_tokens, 200_001);
1092                assert_eq!(model_limit, 200_000);
1093            }
1094            other => panic!("expected PromptTooLarge, got: {other}"),
1095        }
1096    }
1097
1098    #[tokio::test]
1099    async fn run_with_prompt_at_limit_succeeds() {
1100        let provider = TestProvider {
1101            output: default_output(),
1102        };
1103        // Exactly at the limit (200K tokens = 800K chars)
1104        let prompt = "x".repeat(800_000);
1105        let result = Agent::new()
1106            .prompt(&prompt)
1107            .model(Model::Sonnet)
1108            .run(&provider)
1109            .await;
1110        assert!(result.is_ok());
1111    }
1112
1113    #[tokio::test]
1114    async fn run_with_system_prompt_counts_toward_limit() {
1115        let provider = TestProvider {
1116            output: default_output(),
1117        };
1118        // System prompt + user prompt together exceed the limit
1119        let system = "s".repeat(400_004);
1120        let prompt = "p".repeat(400_000);
1121        let result = Agent::new()
1122            .system_prompt(&system)
1123            .prompt(&prompt)
1124            .model(Model::Sonnet)
1125            .run(&provider)
1126            .await;
1127
1128        let err = result.unwrap_err();
1129        assert!(matches!(
1130            err,
1131            OperationError::Agent(crate::error::AgentError::PromptTooLarge { .. })
1132        ));
1133    }
1134
1135    #[tokio::test]
1136    async fn run_with_1m_model_allows_larger_prompt() {
1137        let provider = TestProvider {
1138            output: default_output(),
1139        };
1140        // 900K chars ~ 225K tokens, exceeds 200K but fits in 1M
1141        let prompt = "x".repeat(900_000);
1142        let result = Agent::new()
1143            .prompt(&prompt)
1144            .model(Model::Sonnet46_1M)
1145            .run(&provider)
1146            .await;
1147        assert!(result.is_ok());
1148    }
1149
1150    #[test]
1151    fn model_equality() {
1152        assert_eq!(Model::Sonnet, Model::Sonnet);
1153        assert_ne!(Model::Sonnet, Model::Opus);
1154    }
1155
1156    #[test]
1157    fn permission_mode_serialize_deserialize_roundtrip() {
1158        for mode in [
1159            PermissionMode::Default,
1160            PermissionMode::Auto,
1161            PermissionMode::DontAsk,
1162            PermissionMode::BypassPermissions,
1163        ] {
1164            let json = serde_json::to_string(&mode).unwrap();
1165            let back: PermissionMode = serde_json::from_str(&json).unwrap();
1166            assert_eq!(format!("{:?}", mode), format!("{:?}", back));
1167        }
1168    }
1169
1170    // --- Retry builder ---
1171
1172    #[test]
1173    fn retry_builder_stores_policy() {
1174        let agent = Agent::new().retry(3);
1175        assert!(agent.retry_policy.is_some());
1176        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
1177    }
1178
1179    #[test]
1180    fn retry_policy_builder_stores_custom_policy() {
1181        use crate::retry::RetryPolicy;
1182        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
1183        let agent = Agent::new().retry_policy(policy);
1184        let p = agent.retry_policy.unwrap();
1185        assert_eq!(p.max_retries(), 5);
1186    }
1187
1188    #[test]
1189    fn no_retry_by_default() {
1190        let agent = Agent::new();
1191        assert!(agent.retry_policy.is_none());
1192    }
1193
1194    // --- Retry behavior ---
1195
1196    use std::sync::Arc;
1197    use std::sync::atomic::{AtomicU32, Ordering};
1198    use std::time::Duration;
1199
1200    struct FailNTimesProvider {
1201        fail_count: AtomicU32,
1202        failures_before_success: u32,
1203        output: AgentOutput,
1204    }
1205
1206    impl AgentProvider for FailNTimesProvider {
1207        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1208            Box::pin(async move {
1209                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
1210                if current < self.failures_before_success {
1211                    Err(AgentError::ProcessFailed {
1212                        exit_code: 1,
1213                        stderr: format!("transient failure #{}", current + 1),
1214                    })
1215                } else {
1216                    Ok(AgentOutput {
1217                        value: self.output.value.clone(),
1218                        session_id: self.output.session_id.clone(),
1219                        cost_usd: self.output.cost_usd,
1220                        input_tokens: self.output.input_tokens,
1221                        output_tokens: self.output.output_tokens,
1222                        model: self.output.model.clone(),
1223                        duration_ms: self.output.duration_ms,
1224                    })
1225                }
1226            })
1227        }
1228    }
1229
1230    #[tokio::test]
1231    async fn retry_succeeds_after_transient_failures() {
1232        let provider = FailNTimesProvider {
1233            fail_count: AtomicU32::new(0),
1234            failures_before_success: 2,
1235            output: default_output(),
1236        };
1237        let result = Agent::new()
1238            .prompt("test")
1239            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1240            .run(&provider)
1241            .await;
1242
1243        assert!(result.is_ok());
1244        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
1245    }
1246
1247    #[tokio::test]
1248    async fn retry_exhausted_returns_last_error() {
1249        let provider = FailNTimesProvider {
1250            fail_count: AtomicU32::new(0),
1251            failures_before_success: 10, // always fails
1252            output: default_output(),
1253        };
1254        let result = Agent::new()
1255            .prompt("test")
1256            .retry_policy(crate::retry::RetryPolicy::new(2).backoff(Duration::from_millis(1)))
1257            .run(&provider)
1258            .await;
1259
1260        assert!(result.is_err());
1261        // 1 initial + 2 retries = 3 total
1262        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
1263    }
1264
1265    #[tokio::test]
1266    async fn retry_does_not_retry_non_retryable_errors() {
1267        let call_count = Arc::new(AtomicU32::new(0));
1268        let count = call_count.clone();
1269
1270        struct CountingNonRetryable {
1271            count: Arc<AtomicU32>,
1272        }
1273        impl AgentProvider for CountingNonRetryable {
1274            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
1275                self.count.fetch_add(1, Ordering::SeqCst);
1276                Box::pin(async move {
1277                    Err(AgentError::SchemaValidation {
1278                        expected: "object".to_string(),
1279                        got: "string".to_string(),
1280                    })
1281                })
1282            }
1283        }
1284
1285        let provider = CountingNonRetryable { count };
1286        let result = Agent::new()
1287            .prompt("test")
1288            .retry_policy(crate::retry::RetryPolicy::new(3).backoff(Duration::from_millis(1)))
1289            .run(&provider)
1290            .await;
1291
1292        assert!(result.is_err());
1293        // Only 1 attempt, no retries for non-retryable errors
1294        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1295    }
1296
1297    #[tokio::test]
1298    async fn no_retry_without_policy() {
1299        let provider = FailNTimesProvider {
1300            fail_count: AtomicU32::new(0),
1301            failures_before_success: 1,
1302            output: default_output(),
1303        };
1304        let result = Agent::new().prompt("test").run(&provider).await;
1305
1306        assert!(result.is_err());
1307        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
1308    }
1309}