1use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8
9use super::capability::Capability;
10
11#[derive(Debug, Clone, Serialize)]
13pub struct AgentLoopResult {
14 pub text: String,
16 pub usage: TokenUsage,
18 pub iterations: u32,
20 pub tool_calls: u32,
22}
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize)]
26pub struct TokenUsage {
27 pub input_tokens: u64,
29 pub output_tokens: u64,
31}
32
33impl TokenUsage {
34 pub fn accumulate(&mut self, other: &Self) {
36 self.input_tokens += other.input_tokens;
37 self.output_tokens += other.output_tokens;
38 }
39
40 pub fn total(&self) -> u64 {
42 self.input_tokens + self.output_tokens
43 }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum StopReason {
49 EndTurn,
51 ToolUse,
53 MaxTokens,
55 StopSequence,
57}
58
59#[derive(Debug, thiserror::Error)]
65pub enum AgentError {
66 #[error("driver error: {0}")]
68 Driver(#[from] DriverError),
69 #[error("tool '{tool_name}' failed: {message}")]
71 ToolExecution {
72 tool_name: String,
74 message: String,
76 },
77 #[error("capability denied for tool '{tool_name}': requires {required:?}")]
79 CapabilityDenied {
80 tool_name: String,
82 required: Capability,
84 },
85 #[error("circuit break: {0}")]
87 CircuitBreak(String),
88 #[error("max iterations reached")]
90 MaxIterationsReached,
91 #[error("context overflow: required {required} tokens, available {available}")]
93 ContextOverflow {
94 required: usize,
96 available: usize,
98 },
99 #[error("manifest error: {0}")]
101 ManifestError(String),
102 #[error("memory error: {0}")]
104 Memory(String),
105}
106
107#[derive(Debug, Clone, thiserror::Error)]
109pub enum DriverError {
110 #[error("rate limited, retry after {retry_after_ms}ms")]
112 RateLimited {
113 retry_after_ms: u64,
115 },
116 #[error("overloaded, retry after {retry_after_ms}ms")]
118 Overloaded {
119 retry_after_ms: u64,
121 },
122 #[error("model not found: {0}")]
124 ModelNotFound(PathBuf),
125 #[error("inference failed: {0}")]
127 InferenceFailed(String),
128 #[error("network error: {0}")]
130 Network(String),
131}
132
133impl DriverError {
134 pub fn is_retryable(&self) -> bool {
136 matches!(self, Self::RateLimited { .. } | Self::Overloaded { .. } | Self::Network(_))
137 }
138}
139
#[cfg(test)]
mod tests {
    // Display assertions compare complete strings, so any change to an
    // `#[error(...)]` message is caught, not just the loss of a substring.
    use super::*;

    #[test]
    fn test_token_usage_accumulate() {
        let mut total = TokenUsage::default();
        total.accumulate(&TokenUsage { input_tokens: 100, output_tokens: 50 });
        total.accumulate(&TokenUsage { input_tokens: 200, output_tokens: 75 });
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 125);
        assert_eq!(total.total(), 425);
    }

    #[test]
    fn test_token_usage_default_zero() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.total(), 0);
    }

    #[test]
    fn test_token_usage_deserialize() {
        let usage: TokenUsage = serde_json::from_str(r#"{"input_tokens":3,"output_tokens":4}"#)
            .expect("deserialize failed");
        assert_eq!(usage.input_tokens, 3);
        assert_eq!(usage.output_tokens, 4);
        assert_eq!(usage.total(), 7);
    }

    #[test]
    fn test_stop_reason_equality() {
        assert_eq!(StopReason::EndTurn, StopReason::EndTurn);
        assert_ne!(StopReason::EndTurn, StopReason::ToolUse);
    }

    #[test]
    fn test_driver_error_retryable() {
        assert!(DriverError::RateLimited { retry_after_ms: 1000 }.is_retryable());
        assert!(DriverError::Overloaded { retry_after_ms: 500 }.is_retryable());
        assert!(DriverError::Network("timeout".into()).is_retryable());
        assert!(!DriverError::ModelNotFound("/tmp/missing.gguf".into()).is_retryable());
        assert!(!DriverError::InferenceFailed("oom".into()).is_retryable());
    }

    #[test]
    fn test_driver_error_display() {
        assert_eq!(
            DriverError::RateLimited { retry_after_ms: 1000 }.to_string(),
            "rate limited, retry after 1000ms"
        );
        assert_eq!(
            DriverError::Overloaded { retry_after_ms: 500 }.to_string(),
            "overloaded, retry after 500ms"
        );
        assert_eq!(
            DriverError::ModelNotFound("/tmp/missing.gguf".into()).to_string(),
            "model not found: /tmp/missing.gguf"
        );
        assert_eq!(
            DriverError::InferenceFailed("oom".into()).to_string(),
            "inference failed: oom"
        );
        assert_eq!(
            DriverError::Network("timeout".into()).to_string(),
            "network error: timeout"
        );
    }

    #[test]
    fn test_agent_error_display() {
        let err = AgentError::CircuitBreak("cost exceeded".into());
        assert_eq!(err.to_string(), "circuit break: cost exceeded");

        let err = AgentError::MaxIterationsReached;
        assert_eq!(err.to_string(), "max iterations reached");

        let err = AgentError::ToolExecution {
            tool_name: "rag".into(),
            message: "index not found".into(),
        };
        assert_eq!(err.to_string(), "tool 'rag' failed: index not found");

        let err = AgentError::ContextOverflow { required: 9000, available: 8192 };
        assert_eq!(
            err.to_string(),
            "context overflow: required 9000 tokens, available 8192"
        );

        let err = AgentError::ManifestError("bad toml".into());
        assert_eq!(err.to_string(), "manifest error: bad toml");

        let err = AgentError::Memory("store offline".into());
        assert_eq!(err.to_string(), "memory error: store offline");
    }

    #[test]
    fn test_agent_error_from_driver_error() {
        // `#[from]` provides the conversion and also marks the field as the source.
        let err: AgentError = DriverError::Network("timeout".into()).into();
        assert!(matches!(err, AgentError::Driver(DriverError::Network(_))));
        assert_eq!(err.to_string(), "driver error: network error: timeout");
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_agent_loop_result_serialize() {
        let result = AgentLoopResult {
            text: "hello".into(),
            usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
            iterations: 2,
            tool_calls: 1,
        };
        let json = serde_json::to_string(&result).expect("serialize failed");
        // Re-parse so that every field, including the nested usage, is
        // checked structurally rather than by substring search.
        let value: serde_json::Value = serde_json::from_str(&json).expect("re-parse failed");
        assert_eq!(value["text"], "hello");
        assert_eq!(value["iterations"], 2);
        assert_eq!(value["tool_calls"], 1);
        assert_eq!(value["usage"]["input_tokens"], 10);
        assert_eq!(value["usage"]["output_tokens"], 5);
    }

    #[test]
    fn test_stop_reason_serialization() {
        let reasons = [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::StopSequence,
        ];
        for reason in &reasons {
            let json = serde_json::to_string(reason).expect("serialize failed");
            let back: StopReason = serde_json::from_str(&json).expect("deserialize failed");
            assert_eq!(*reason, back);
        }
    }
}