use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use super::capability::Capability;
/// Summary values produced by an agent loop run (per the type name).
///
/// The type derives `Serialize` only; it cannot be deserialized.
#[derive(Debug, Clone, Serialize)]
pub struct AgentLoopResult {
/// Text carried by the result.
// NOTE(review): presumably the final response text — confirm against the code that builds this value.
pub text: String,
/// Token counters attached to the result.
pub usage: TokenUsage,
/// Iteration count reported for the run.
pub iterations: u32,
/// Tool-call count reported for the run.
pub tool_calls: u32,
}
/// Pair of token counters: one for input, one for output.
///
/// `Default` yields both counters at zero. The type can be serialized and
/// deserialized through serde.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Number of input tokens counted.
pub input_tokens: u64,
/// Number of output tokens counted.
pub output_tokens: u64,
}
impl TokenUsage {
    /// Adds the counters of `other` onto `self`, field by field.
    ///
    /// Each addition saturates at `u64::MAX`. A plain `+=` would panic on
    /// overflow in debug builds and silently wrap to a small number in
    /// release builds; clamping keeps the counter monotonic in both.
    pub fn accumulate(&mut self, other: &Self) {
        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
    }

    /// Returns the sum of input and output tokens, saturating at `u64::MAX`.
    pub fn total(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// Reason a model turn stopped (per the type name).
///
/// Values can be compared with `==` and serialized/deserialized through serde.
// NOTE(review): the precise meaning of each variant is set by whichever driver
// produces it; the names suggest end of turn, tool invocation, token limit and
// stop sequence respectively — confirm against the producer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
EndTurn,
ToolUse,
MaxTokens,
StopSequence,
}
/// Errors surfaced by the agent layer.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
/// Wraps a [`DriverError`]; `#[from]` provides the `From<DriverError>`
/// conversion, so `?` can lift a driver error into this variant.
#[error("driver error: {0}")]
Driver(#[from] DriverError),
/// A named tool failed; `message` carries the failure text.
#[error("tool '{tool_name}' failed: {message}")]
ToolExecution {
tool_name: String,
message: String,
},
/// A tool was refused because it requires the given [`Capability`].
#[error("capability denied for tool '{tool_name}': requires {required:?}")]
CapabilityDenied {
tool_name: String,
required: Capability,
},
/// Circuit-break condition, with a free-form description.
#[error("circuit break: {0}")]
CircuitBreak(String),
/// The iteration limit was reached.
#[error("max iterations reached")]
MaxIterationsReached,
/// More tokens were required than were available.
#[error("context overflow: required {required} tokens, available {available}")]
ContextOverflow {
required: usize,
available: usize,
},
/// Manifest-related failure, described by the string.
#[error("manifest error: {0}")]
ManifestError(String),
/// Memory-related failure, described by the string.
// NOTE(review): "memory" presumably refers to an agent memory subsystem rather
// than process memory — confirm against the call sites.
#[error("memory error: {0}")]
Memory(String),
}
/// Errors reported by a driver.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
/// Unlike `AgentError`, this type is `Clone`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum DriverError {
/// Rate limited; `retry_after_ms` is the suggested wait in milliseconds.
#[error("rate limited, retry after {retry_after_ms}ms")]
RateLimited {
retry_after_ms: u64,
},
/// Overloaded; `retry_after_ms` is the suggested wait in milliseconds.
#[error("overloaded, retry after {retry_after_ms}ms")]
Overloaded {
retry_after_ms: u64,
},
/// No model was found at the given path.
#[error("model not found: {0}")]
ModelNotFound(PathBuf),
/// Inference failed, with a free-form description.
#[error("inference failed: {0}")]
InferenceFailed(String),
/// Network failure, with a free-form description.
#[error("network error: {0}")]
Network(String),
}
impl DriverError {
    /// Reports whether this error belongs to the retryable set.
    ///
    /// `RateLimited`, `Overloaded` and `Network` yield `true`;
    /// `ModelNotFound` and `InferenceFailed` yield `false`.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::RateLimited { .. } | Self::Overloaded { .. } | Self::Network(_) => true,
            Self::ModelNotFound(_) | Self::InferenceFailed(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Successive `accumulate` calls add up per field and in `total()`.
    #[test]
    fn test_token_usage_accumulate() {
        let increments = [
            TokenUsage { input_tokens: 100, output_tokens: 50 },
            TokenUsage { input_tokens: 200, output_tokens: 75 },
        ];
        let mut running = TokenUsage::default();
        for step in increments.iter() {
            running.accumulate(step);
        }
        assert_eq!(running.input_tokens, 300);
        assert_eq!(running.output_tokens, 125);
        assert_eq!(running.total(), 425);
    }

    /// A default `TokenUsage` starts with every counter at zero.
    #[test]
    fn test_token_usage_default_zero() {
        let fresh = TokenUsage::default();
        assert_eq!(fresh.input_tokens, 0);
        assert_eq!(fresh.output_tokens, 0);
        assert_eq!(fresh.total(), 0);
    }

    /// Identical variants compare equal; distinct variants do not.
    #[test]
    fn test_stop_reason_equality() {
        let end_turn = StopReason::EndTurn;
        assert_eq!(end_turn, StopReason::EndTurn);
        assert_ne!(end_turn, StopReason::ToolUse);
    }

    /// Rate limiting, overload and network errors are retryable; the rest are not.
    #[test]
    fn test_driver_error_retryable() {
        let transient = [
            DriverError::RateLimited { retry_after_ms: 1000 },
            DriverError::Overloaded { retry_after_ms: 500 },
            DriverError::Network("timeout".into()),
        ];
        for err in transient.iter() {
            assert!(err.is_retryable());
        }

        let permanent = [
            DriverError::ModelNotFound("/tmp/missing.gguf".into()),
            DriverError::InferenceFailed("oom".into()),
        ];
        for err in permanent.iter() {
            assert!(!err.is_retryable());
        }
    }

    /// `Display` output follows the `#[error]` templates.
    #[test]
    fn test_agent_error_display() {
        let circuit = AgentError::CircuitBreak("cost exceeded".into());
        assert_eq!(circuit.to_string(), "circuit break: cost exceeded");

        let exhausted = AgentError::MaxIterationsReached;
        assert_eq!(exhausted.to_string(), "max iterations reached");

        let tool_failure = AgentError::ToolExecution {
            tool_name: "rag".into(),
            message: "index not found".into(),
        };
        let rendered = tool_failure.to_string();
        assert!(rendered.contains("rag"));
    }

    /// Serialized JSON contains the expected `text` and `iterations` entries.
    #[test]
    fn test_agent_loop_result_serialize() {
        let usage = TokenUsage { input_tokens: 10, output_tokens: 5 };
        let result = AgentLoopResult {
            text: "hello".into(),
            usage,
            iterations: 2,
            tool_calls: 1,
        };
        let encoded = serde_json::to_string(&result).expect("serialize failed");
        assert!(encoded.contains("\"text\":\"hello\""));
        assert!(encoded.contains("\"iterations\":2"));
    }

    /// Every `StopReason` variant survives a JSON round trip unchanged.
    #[test]
    fn test_stop_reason_serialization() {
        let all_reasons = [
            StopReason::EndTurn,
            StopReason::ToolUse,
            StopReason::MaxTokens,
            StopReason::StopSequence,
        ];
        for reason in all_reasons.iter() {
            let encoded = serde_json::to_string(reason).expect("serialize failed");
            let decoded: StopReason =
                serde_json::from_str(&encoded).expect("deserialize failed");
            assert_eq!(*reason, decoded);
        }
    }
}