1use thiserror::Error;
4
5#[derive(Error, Debug)]
7pub enum ProviderError {
8 #[error("Missing API key")]
10 MissingApiKey,
11
12 #[error("Unknown provider: {0}")]
14 UnknownProvider(String),
15
16 #[error("Provider not implemented: {0}")]
18 NotImplemented(String),
19
20 #[error("HTTP error {0}: {1}")]
22 HttpError(u16, String),
23
24 #[error("Request failed: {0}")]
26 RequestFailed(#[from] reqwest::Error),
27
28 #[error("IO error: {0}")]
30 IoError(#[from] std::io::Error),
31
32 #[error("Invalid response: {0}")]
34 InvalidResponse(String),
35
36 #[error("Invalid API key format")]
38 InvalidApiKey,
39
40 #[error("JSON parse error: {0}")]
42 JsonParse(#[from] serde_json::Error),
43
44 #[error("Stream error: {0}")]
46 StreamError(String),
47
48 #[error("Network error: {0}")]
50 NetworkError(String),
51
52 #[error("Request timed out")]
54 Timeout,
55
56 #[error("Rate limited")]
58 RateLimited {
59 retry_after: Option<std::time::Duration>,
61 },
62}
63
64impl ProviderError {
65 pub fn is_retryable(&self) -> bool {
67 match self {
68 Self::HttpError(status, _) => *status == 429 || *status >= 500,
69 Self::NetworkError(_) => true,
70 Self::Timeout => true,
71 Self::RateLimited { .. } => true,
72 _ => false,
73 }
74 }
75
76 pub fn retry_after(&self) -> Option<std::time::Duration> {
78 match self {
79 Self::RateLimited { retry_after } => *retry_after,
80 Self::HttpError(429, _) => Some(std::time::Duration::from_secs(5)),
81 _ => None,
82 }
83 }
84}
85
86#[derive(Error, Debug)]
88pub enum ValidationError {
89 #[error("Invalid JSON: {0}")]
90 InvalidJson(#[from] serde_json::Error),
91
92 #[error("Schema validation failed: {0}")]
93 SchemaValidation(String),
94
95 #[error("Missing required field: {0}")]
96 MissingRequiredField(String),
97}
98
99#[derive(Error, Debug)]
101pub enum Error {
102 #[error("Provider error: {0}")]
104 Provider(#[from] ProviderError),
105
106 #[error("Validation error: {0}")]
108 Validation(#[from] ValidationError),
109
110 #[error("IO error: {0}")]
112 Io(#[from] std::io::Error),
113}
114
115pub type Result<T> = std::result::Result<T, Error>;
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed `ProviderError` must render its `#[error]` template exactly.
    #[test]
    fn provider_error_display() {
        let cases = [
            (ProviderError::MissingApiKey, "Missing API key"),
            (
                ProviderError::UnknownProvider("foo".to_string()),
                "Unknown provider: foo",
            ),
            (
                ProviderError::HttpError(429, "rate limited".to_string()),
                "HTTP error 429: rate limited",
            ),
            (
                ProviderError::InvalidResponse("bad json".to_string()),
                "Invalid response: bad json",
            ),
            (
                ProviderError::StreamError("disconnected".to_string()),
                "Stream error: disconnected",
            ),
            (
                ProviderError::NotImplemented("x".to_string()),
                "Provider not implemented: x",
            ),
        ];

        for (err, expected) in cases.iter() {
            assert_eq!(err.to_string(), *expected);
        }
    }

    /// A `ProviderError` converts into `Error::Provider` and keeps its message.
    #[test]
    fn error_chain_from_provider_error() {
        let outer = Error::from(ProviderError::MissingApiKey);
        assert!(matches!(
            outer,
            Error::Provider(ProviderError::MissingApiKey)
        ));

        let rendered = outer.to_string();
        assert!(rendered.contains("Missing API key"));
    }

    /// `MissingRequiredField` interpolates the field name into its message.
    #[test]
    fn validation_error_display() {
        let rendered = ValidationError::MissingRequiredField("model".to_string()).to_string();
        assert_eq!(rendered, "Missing required field: model");
    }

    /// A `std::io::Error` converts into `Error::Io`.
    #[test]
    fn error_chain_from_io() {
        let source = std::io::Error::new(std::io::ErrorKind::NotFound, "file missing");
        let outer = Error::from(source);
        assert!(matches!(outer, Error::Io(_)));
    }
}