1use thiserror::Error;
4
5#[derive(Error, Debug)]
7pub enum ProviderError {
8 #[error("Missing API key")]
10 MissingApiKey,
11
12 #[error("Unknown provider: {0}")]
14 UnknownProvider(String),
15
16 #[error("Provider not implemented: {0}")]
18 NotImplemented(String),
19
20 #[error("HTTP error {0}: {1}")]
22 HttpError(u16, String),
23
24 #[error("Request failed: {0}")]
26 RequestFailed(#[from] reqwest::Error),
27
28 #[error("IO error: {0}")]
30 IoError(#[from] std::io::Error),
31
32 #[error("Invalid response: {0}")]
34 InvalidResponse(String),
35
36 #[error("Invalid API key format")]
38 InvalidApiKey,
39
40 #[error("JSON parse error: {0}")]
42 JsonParse(#[from] serde_json::Error),
43
44 #[error("Stream error: {0}")]
46 StreamError(String),
47
48 #[error("Network error: {0}")]
50 NetworkError(String),
51
52 #[error("Context overflow")]
54 ContextOverflow,
55
56 #[error("Request timed out")]
58 Timeout,
59
60 #[error("Rate limited")]
62 RateLimited {
63 retry_after: Option<std::time::Duration>,
65 },
66}
67
68impl ProviderError {
69 pub fn is_retryable(&self) -> bool {
71 match self {
72 Self::HttpError(status, _) => *status == 429 || *status >= 500,
73 Self::NetworkError(_) => true,
74 Self::Timeout => true,
75 Self::RateLimited { .. } => true,
76 _ => false,
77 }
78 }
79
80 pub fn retry_after(&self) -> Option<std::time::Duration> {
82 match self {
83 Self::RateLimited { retry_after } => *retry_after,
84 Self::HttpError(429, _) => Some(std::time::Duration::from_secs(5)),
85 _ => None,
86 }
87 }
88}
89
90#[derive(Error, Debug)]
92pub enum ValidationError {
93 #[error("Invalid JSON: {0}")]
94 InvalidJson(#[from] serde_json::Error),
95
96 #[error("Schema validation failed: {0}")]
97 SchemaValidation(String),
98
99 #[error("Missing required field: {0}")]
100 MissingRequiredField(String),
101}
102
103#[derive(Error, Debug)]
105pub enum Error {
106 #[error("Provider error: {0}")]
108 Provider(#[from] ProviderError),
109
110 #[error("Validation error: {0}")]
112 Validation(#[from] ValidationError),
113
114 #[error("IO error: {0}")]
116 Io(#[from] std::io::Error),
117}
118
119pub type Result<T> = std::result::Result<T, Error>;
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Each variant renders exactly the text declared in its `#[error]`
    /// attribute, with payloads substituted.
    #[test]
    fn provider_error_display() {
        let cases = [
            (ProviderError::MissingApiKey, "Missing API key"),
            (
                ProviderError::UnknownProvider("foo".to_string()),
                "Unknown provider: foo",
            ),
            (
                ProviderError::HttpError(429, "rate limited".to_string()),
                "HTTP error 429: rate limited",
            ),
            (
                ProviderError::InvalidResponse("bad json".to_string()),
                "Invalid response: bad json",
            ),
            (
                ProviderError::StreamError("disconnected".to_string()),
                "Stream error: disconnected",
            ),
            (
                ProviderError::NotImplemented("x".to_string()),
                "Provider not implemented: x",
            ),
        ];

        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    /// A `ProviderError` converts into `Error::Provider` and keeps its message.
    #[test]
    fn error_chain_from_provider_error() {
        let outer = Error::from(ProviderError::MissingApiKey);
        assert!(matches!(
            outer,
            Error::Provider(ProviderError::MissingApiKey)
        ));
        assert!(outer.to_string().contains("Missing API key"));
    }

    #[test]
    fn validation_error_display() {
        let err = ValidationError::MissingRequiredField("model".to_string());
        assert_eq!(err.to_string(), "Missing required field: model");
    }

    /// A `ValidationError` converts into `Error::Validation` and the outer
    /// message embeds the inner one.
    #[test]
    fn error_chain_from_validation_error() {
        let outer = Error::from(ValidationError::MissingRequiredField("model".to_string()));
        assert!(matches!(outer, Error::Validation(_)));
        assert_eq!(
            outer.to_string(),
            "Validation error: Missing required field: model"
        );
    }

    /// A bare `std::io::Error` maps to `Error::Io`.
    #[test]
    fn error_chain_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let outer = Error::from(io_err);
        assert!(matches!(outer, Error::Io(_)));
    }

    /// Pins which errors are classified as retryable and which are not.
    #[test]
    fn retryable_classification() {
        let retryable = [
            ProviderError::HttpError(429, String::new()),
            ProviderError::HttpError(500, String::new()),
            ProviderError::HttpError(503, String::new()),
            ProviderError::NetworkError("connection reset".to_string()),
            ProviderError::Timeout,
            ProviderError::RateLimited { retry_after: None },
        ];
        for err in retryable.iter() {
            assert!(err.is_retryable(), "expected {:?} to be retryable", err);
        }

        let permanent = [
            ProviderError::HttpError(400, String::new()),
            ProviderError::HttpError(404, String::new()),
            ProviderError::MissingApiKey,
            ProviderError::InvalidApiKey,
            ProviderError::UnknownProvider("foo".to_string()),
            ProviderError::NotImplemented("x".to_string()),
            ProviderError::InvalidResponse("bad json".to_string()),
            ProviderError::StreamError("disconnected".to_string()),
            ProviderError::ContextOverflow,
        ];
        for err in permanent.iter() {
            assert!(!err.is_retryable(), "expected {:?} not to be retryable", err);
        }
    }

    /// Pins the wait hint returned for rate-limit signals and its absence
    /// everywhere else.
    #[test]
    fn retry_after_hints() {
        let hint = Duration::from_secs(2);
        assert_eq!(
            ProviderError::RateLimited {
                retry_after: Some(hint)
            }
            .retry_after(),
            Some(hint)
        );
        assert_eq!(
            ProviderError::RateLimited { retry_after: None }.retry_after(),
            None
        );
        // A raw 429 without an explicit hint falls back to five seconds.
        assert_eq!(
            ProviderError::HttpError(429, String::new()).retry_after(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(
            ProviderError::HttpError(500, String::new()).retry_after(),
            None
        );
        assert_eq!(ProviderError::Timeout.retry_after(), None);
    }
}