1use std::time::Duration;
4use thiserror::Error;
5
/// Result alias whose error type is [`AzureError`].
pub type Result<T> = std::result::Result<T, AzureError>;

/// Errors surfaced by Azure REST calls, grouped into coarse categories.
///
/// See [`AzureError::is_retryable`] for which categories are worth retrying.
#[derive(Debug, Error, Clone)]
pub enum AzureError {
    /// The request was rejected as unauthenticated (mapped from HTTP 401).
    #[error("Authentication failed: {message}")]
    Auth { message: String },

    /// The caller is authenticated but not allowed to do this (mapped from HTTP 403).
    #[error("Permission denied: {message}")]
    PermissionDenied { message: String },

    /// The target resource does not exist (mapped from HTTP 404); `resource` holds the
    /// service's description of what was missing.
    #[error("Resource not found: {resource}")]
    NotFound { resource: String },

    /// The service asked the client to slow down.
    ///
    /// Displayed as `Throttled: <message>`, followed by ` (retry after <delay>)` only when a
    /// retry hint is known. The hint is rendered by a helper because formatting the
    /// `Option` with `{:?}` would print `Some(..)` / `None` to users.
    #[error("Throttled: {message}{}", retry_after_suffix(.retry_after))]
    Throttled {
        retry_after: Option<Duration>,
        message: String,
    },

    /// The request conflicts with the current state of the resource (mapped from HTTP 409).
    #[error("Resource conflict: {message}")]
    ResourceConflict { message: String },

    /// Any other service-reported failure: `status` is the HTTP status and `code` the
    /// service error code (or a synthetic `HttpError<status>` when none was supplied).
    #[error("Service error ({code}): {message}")]
    ServiceError {
        code: String,
        message: String,
        status: u16,
    },

    /// A transport-level failure, kept only as rendered text so the enum stays `Clone`.
    #[error("Network error: {0}")]
    Network(String),

    /// A response that could not be interpreted; `body` optionally keeps the raw payload.
    #[error("Invalid response: {message}")]
    InvalidResponse {
        message: String,
        body: Option<String>,
    },
}

/// Renders the optional retry hint of [`AzureError::Throttled`] as a Display suffix:
/// ` (retry after <delay>)` when a delay is known, or an empty string otherwise.
fn retry_after_suffix(retry_after: &Option<Duration>) -> String {
    retry_after
        .map(|delay| format!(" (retry after {delay:?})"))
        .unwrap_or_default()
}
54
55impl From<reqwest::Error> for AzureError {
56 fn from(err: reqwest::Error) -> Self {
57 Self::Network(err.to_string())
58 }
59}
60
61impl AzureError {
62 pub fn is_retryable(&self) -> bool {
64 match self {
65 Self::Throttled { .. } | Self::Network(_) => true,
66 Self::ServiceError { status, .. } => matches!(status, 500 | 502 | 503 | 504),
67 _ => false,
68 }
69 }
70
71 pub fn retry_after(&self) -> Option<Duration> {
73 match self {
74 Self::Throttled {
75 retry_after: Some(duration),
76 ..
77 } => Some(*duration),
78 _ => None,
79 }
80 }
81}
82
83pub(crate) fn classify_error(status: u16, code: &str, message: &str) -> AzureError {
85 match status {
86 401 => AzureError::Auth {
87 message: format!("{code}: {message}"),
88 },
89 403 => AzureError::PermissionDenied {
90 message: format!("{code}: {message}"),
91 },
92 404 => AzureError::NotFound {
93 resource: message.to_string(),
94 },
95 409 => AzureError::ResourceConflict {
96 message: message.to_string(),
97 },
98 429 => AzureError::Throttled {
99 retry_after: None,
100 message: message.to_string(),
101 },
102 _ if code == "TooManyRequests" || code == "429" => AzureError::Throttled {
103 retry_after: None,
104 message: message.to_string(),
105 },
106 _ => AzureError::ServiceError {
107 code: code.to_string(),
108 message: message.to_string(),
109 status,
110 },
111 }
112}
113
114pub(crate) fn parse_json_error(status: u16, body: &str) -> AzureError {
121 let parsed: std::result::Result<serde_json::Value, _> = serde_json::from_str(body);
122 let (code, message) = match parsed {
123 Ok(val) => {
124 let error_obj = val.get("error").unwrap_or(&val);
125 let code = error_obj
126 .get("code")
127 .and_then(|v| v.as_str())
128 .unwrap_or_default()
129 .to_string();
130 let message = error_obj
131 .get("message")
132 .and_then(|v| v.as_str())
133 .unwrap_or_default()
134 .to_string();
135 (code, message)
136 }
137 Err(_) => (String::new(), truncate_body(body)),
138 };
139
140 if code.is_empty() {
141 return AzureError::ServiceError {
142 code: format!("HttpError{status}"),
143 message,
144 status,
145 };
146 }
147
148 classify_error(status, &code, &message)
149}
150
/// Returns `body` unchanged if it is at most 200 bytes long.
///
/// Otherwise it keeps the longest prefix of at most 200 bytes that ends on a UTF-8
/// character boundary, so no code point is split, and appends `...`.
fn truncate_body(body: &str) -> String {
    const LIMIT: usize = 200;

    if body.len() <= LIMIT {
        return body.to_string();
    }

    // Back off from the byte limit to the nearest char boundary; index 0 always is one.
    let mut cut = LIMIT;
    while !body.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &body[..cut])
}
160
#[cfg(test)]
mod tests {
    // Unit tests for retry classification and HTTP error-body parsing.

    use super::*;

    /// Builds a `Throttled` error with the given retry hint and a fixed message.
    fn throttled(retry_after: Option<Duration>) -> AzureError {
        AzureError::Throttled {
            retry_after,
            message: "slow down".into(),
        }
    }

    /// Builds a `ServiceError` from its parts.
    fn service_error(code: &str, message: &str, status: u16) -> AzureError {
        AzureError::ServiceError {
            code: code.to_owned(),
            message: message.to_owned(),
            status,
        }
    }

    /// Builds the `Auth` error shared by the non-retryable and no-hint cases.
    fn auth_error() -> AzureError {
        AzureError::Auth {
            message: "bad creds".into(),
        }
    }

    #[test]
    fn throttled_is_retryable() {
        assert!(throttled(None).is_retryable());
    }

    #[test]
    fn network_is_retryable() {
        assert!(AzureError::Network(String::from("timeout")).is_retryable());
    }

    #[test]
    fn auth_is_not_retryable() {
        assert!(!auth_error().is_retryable());
    }

    #[test]
    fn service_error_500_is_retryable() {
        assert!(service_error("InternalError", "internal", 500).is_retryable());
    }

    #[test]
    fn service_error_400_is_not_retryable() {
        assert!(!service_error("ValidationError", "bad param", 400).is_retryable());
    }

    #[test]
    fn parse_json_error_arm_format() {
        let body = r#"{"error": {"code": "ResourceNotFound", "message": "Resource not found"}}"#;
        assert!(matches!(
            parse_json_error(404, body),
            AzureError::NotFound { .. }
        ));
    }

    #[test]
    fn parse_json_error_flat_format() {
        let body = r#"{"code": "Unauthorized", "message": "Token expired"}"#;
        assert!(matches!(parse_json_error(401, body), AzureError::Auth { .. }));
    }

    #[test]
    fn parse_json_error_fallback_on_invalid() {
        let err = parse_json_error(500, "not json");
        if let AzureError::ServiceError { code, status, .. } = &err {
            assert_eq!(code.as_str(), "HttpError500");
            assert_eq!(*status, 500);
        } else {
            panic!("expected ServiceError, got: {err}");
        }
    }

    #[test]
    fn classify_409_as_conflict() {
        assert!(matches!(
            classify_error(409, "Conflict", "already exists"),
            AzureError::ResourceConflict { .. }
        ));
    }

    #[test]
    fn retry_after_returns_duration_for_throttled() {
        let hint = Duration::from_secs(5);
        assert_eq!(throttled(Some(hint)).retry_after(), Some(hint));
    }

    #[test]
    fn retry_after_returns_none_for_non_throttled() {
        assert_eq!(auth_error().retry_after(), None);
    }
}