Skip to main content

batuta/agent/
result.rs

1//! Agent loop result and error types.
2//!
3//! Defines the outcome of a complete agent loop invocation
4//! and the error taxonomy for agent failures.
5
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8
9use super::capability::Capability;
10
11/// Outcome of a complete agent loop invocation.
12#[derive(Debug, Clone, Serialize)]
13pub struct AgentLoopResult {
14    /// Final text response from the agent.
15    pub text: String,
16    /// Token usage across all iterations.
17    pub usage: TokenUsage,
18    /// Number of loop iterations executed.
19    pub iterations: u32,
20    /// Number of tool calls made.
21    pub tool_calls: u32,
22}
23
24/// Token usage counters.
25#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26pub struct TokenUsage {
27    /// Total input tokens across all completions.
28    pub input_tokens: u64,
29    /// Total output tokens across all completions.
30    pub output_tokens: u64,
31}
32
33impl TokenUsage {
34    /// Accumulate usage from another completion.
35    pub fn accumulate(&mut self, other: &Self) {
36        self.input_tokens += other.input_tokens;
37        self.output_tokens += other.output_tokens;
38    }
39
40    /// Total tokens (input + output).
41    pub fn total(&self) -> u64 {
42        self.input_tokens + self.output_tokens
43    }
44}
45
46/// Stop reason from a single LLM completion.
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum StopReason {
49    /// Model finished naturally.
50    EndTurn,
51    /// Model wants to use a tool.
52    ToolUse,
53    /// Output truncated at `max_tokens` limit.
54    MaxTokens,
55    /// Hit a stop sequence.
56    StopSequence,
57}
58
59/// Agent error taxonomy.
60///
61/// Classified by recoverability: some errors are retryable,
62/// others are fatal. The agent loop uses this to decide whether
63/// to retry or terminate (Jidoka: stop on defect).
64#[derive(Debug, thiserror::Error)]
65pub enum AgentError {
66    /// LLM driver error (may be retryable).
67    #[error("driver error: {0}")]
68    Driver(#[from] DriverError),
69    /// Tool execution failed.
70    #[error("tool '{tool_name}' failed: {message}")]
71    ToolExecution {
72        /// Name of the failed tool.
73        tool_name: String,
74        /// Error message.
75        message: String,
76    },
77    /// Capability denied (Poka-Yoke).
78    #[error("capability denied for tool '{tool_name}': requires {required:?}")]
79    CapabilityDenied {
80        /// Name of the denied tool.
81        tool_name: String,
82        /// Required capability that was not granted.
83        required: Capability,
84    },
85    /// Loop guard triggered (Jidoka).
86    #[error("circuit break: {0}")]
87    CircuitBreak(String),
88    /// Max iterations reached.
89    #[error("max iterations reached")]
90    MaxIterationsReached,
91    /// Context overflow after truncation.
92    #[error("context overflow: required {required} tokens, available {available}")]
93    ContextOverflow {
94        /// Tokens required.
95        required: usize,
96        /// Tokens available.
97        available: usize,
98    },
99    /// Manifest parsing error.
100    #[error("manifest error: {0}")]
101    ManifestError(String),
102    /// Memory substrate error.
103    #[error("memory error: {0}")]
104    Memory(String),
105}
106
107/// LLM driver-specific errors.
108#[derive(Debug, Clone, thiserror::Error)]
109pub enum DriverError {
110    /// Remote API rate limited. Retryable with backoff.
111    #[error("rate limited, retry after {retry_after_ms}ms")]
112    RateLimited {
113        /// Suggested wait time in milliseconds.
114        retry_after_ms: u64,
115    },
116    /// Remote API overloaded. Retryable with backoff.
117    #[error("overloaded, retry after {retry_after_ms}ms")]
118    Overloaded {
119        /// Suggested wait time in milliseconds.
120        retry_after_ms: u64,
121    },
122    /// Model file not found. Not retryable.
123    #[error("model not found: {0}")]
124    ModelNotFound(PathBuf),
125    /// Inference failed. Not retryable.
126    #[error("inference failed: {0}")]
127    InferenceFailed(String),
128    /// Network error (remote driver). Retryable.
129    #[error("network error: {0}")]
130    Network(String),
131}
132
133impl DriverError {
134    /// Whether this error is retryable with backoff.
135    pub fn is_retryable(&self) -> bool {
136        matches!(self, Self::RateLimited { .. } | Self::Overloaded { .. } | Self::Network(_))
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_usage_accumulate() {
        let mut total = TokenUsage::default();
        total.accumulate(&TokenUsage { input_tokens: 100, output_tokens: 50 });
        total.accumulate(&TokenUsage { input_tokens: 200, output_tokens: 75 });
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 125);
        assert_eq!(total.total(), 425);
    }

    #[test]
    fn test_token_usage_default_zero() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.total(), 0);
    }

    #[test]
    fn test_token_usage_serde_round_trip() {
        let usage = TokenUsage { input_tokens: 7, output_tokens: 3 };
        let json = serde_json::to_string(&usage).expect("serialize failed");
        let back: TokenUsage = serde_json::from_str(&json).expect("deserialize failed");
        assert_eq!(back.input_tokens, 7);
        assert_eq!(back.output_tokens, 3);
    }

    #[test]
    fn test_stop_reason_equality() {
        assert_eq!(StopReason::EndTurn, StopReason::EndTurn);
        assert_ne!(StopReason::EndTurn, StopReason::ToolUse);
    }

    #[test]
    fn test_driver_error_retryable() {
        assert!(DriverError::RateLimited { retry_after_ms: 1000 }.is_retryable());
        assert!(DriverError::Overloaded { retry_after_ms: 500 }.is_retryable());
        assert!(DriverError::Network("timeout".into()).is_retryable());
        assert!(!DriverError::ModelNotFound("/tmp/missing.gguf".into()).is_retryable());
        assert!(!DriverError::InferenceFailed("oom".into()).is_retryable());
    }

    #[test]
    fn test_driver_error_display() {
        assert_eq!(
            DriverError::RateLimited { retry_after_ms: 1000 }.to_string(),
            "rate limited, retry after 1000ms"
        );
        assert_eq!(
            DriverError::Overloaded { retry_after_ms: 500 }.to_string(),
            "overloaded, retry after 500ms"
        );
        assert_eq!(DriverError::InferenceFailed("oom".into()).to_string(), "inference failed: oom");
        assert_eq!(DriverError::Network("timeout".into()).to_string(), "network error: timeout");
    }

    #[test]
    fn test_agent_error_display() {
        let err = AgentError::CircuitBreak("cost exceeded".into());
        assert_eq!(err.to_string(), "circuit break: cost exceeded");

        let err = AgentError::MaxIterationsReached;
        assert_eq!(err.to_string(), "max iterations reached");

        // Pin the full message, not just a substring, so format regressions are caught.
        let err = AgentError::ToolExecution {
            tool_name: "rag".into(),
            message: "index not found".into(),
        };
        assert_eq!(err.to_string(), "tool 'rag' failed: index not found");

        let err = AgentError::ContextOverflow { required: 5000, available: 4096 };
        assert_eq!(err.to_string(), "context overflow: required 5000 tokens, available 4096");

        let err = AgentError::ManifestError("missing name".into());
        assert_eq!(err.to_string(), "manifest error: missing name");

        let err = AgentError::Memory("disk full".into());
        assert_eq!(err.to_string(), "memory error: disk full");
    }

    #[test]
    fn test_agent_error_from_driver_error() {
        // `?` on a DriverError inside the agent loop relies on this `#[from]` conversion.
        let err: AgentError = DriverError::Network("timeout".into()).into();
        assert!(matches!(&err, AgentError::Driver(DriverError::Network(msg)) if msg == "timeout"));
        assert_eq!(err.to_string(), "driver error: network error: timeout");
    }

    #[test]
    fn test_agent_error_driver_source() {
        use std::error::Error as _;

        let err = AgentError::from(DriverError::Overloaded { retry_after_ms: 500 });
        let source = err.source().expect("#[from] should expose the driver error as source");
        assert_eq!(source.to_string(), "overloaded, retry after 500ms");
    }

    #[test]
    fn test_agent_loop_result_serialize() {
        let result = AgentLoopResult {
            text: "hello".into(),
            usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
            iterations: 2,
            tool_calls: 1,
        };
        // Inspect the JSON tree so every field, including the nested usage, is checked.
        let json = serde_json::to_value(&result).expect("serialize failed");
        assert_eq!(json["text"], "hello");
        assert_eq!(json["usage"]["input_tokens"], 10);
        assert_eq!(json["usage"]["output_tokens"], 5);
        assert_eq!(json["iterations"], 2);
        assert_eq!(json["tool_calls"], 1);
    }

    #[test]
    fn test_stop_reason_serialization() {
        // Pin the wire format (unit variants serialize as their name) as well as round-tripping.
        let reasons = [
            (StopReason::EndTurn, "\"EndTurn\""),
            (StopReason::ToolUse, "\"ToolUse\""),
            (StopReason::MaxTokens, "\"MaxTokens\""),
            (StopReason::StopSequence, "\"StopSequence\""),
        ];
        for (reason, expected) in &reasons {
            let json = serde_json::to_string(reason).expect("serialize failed");
            assert_eq!(json, *expected);
            let back: StopReason = serde_json::from_str(&json).expect("deserialize failed");
            assert_eq!(*reason, back);
        }
    }
}