1use std::time::Duration;
4use thiserror::Error;
5
6pub type Result<T> = std::result::Result<T, AwsError>;
8
9#[derive(Debug, Error, Clone)]
11pub enum AwsError {
12 #[error("Authentication failed: {message}")]
14 Auth { message: String },
15
16 #[error("Access denied: {message}")]
18 AccessDenied { message: String },
19
20 #[error("Resource not found: {resource}")]
22 NotFound { resource: String },
23
24 #[error("Throttled (retry after {retry_after:?})")]
26 Throttled {
27 retry_after: Option<Duration>,
28 message: String,
29 },
30
31 #[error("Service error ({code}): {message}")]
33 ServiceError {
34 code: String,
35 message: String,
36 status: u16,
37 },
38
39 #[error("Network error: {0}")]
41 Network(String),
42
43 #[error("Invalid response: {message}")]
45 InvalidResponse {
46 message: String,
47 body: Option<String>,
48 },
49
50 #[error("XML parse error: {message}")]
52 XmlParse { message: String },
53}
54
55impl From<reqwest::Error> for AwsError {
56 fn from(err: reqwest::Error) -> Self {
57 Self::Network(err.to_string())
58 }
59}
60
61impl AwsError {
62 pub fn is_retryable(&self) -> bool {
64 match self {
65 Self::Throttled { .. } | Self::Network(_) => true,
66 Self::ServiceError { status, .. } => matches!(status, 500 | 502 | 503 | 504),
67 _ => false,
68 }
69 }
70
71 pub fn retry_after(&self) -> Option<Duration> {
73 match self {
74 Self::Throttled {
75 retry_after: Some(duration),
76 ..
77 } => Some(*duration),
78 _ => None,
79 }
80 }
81}
82
83#[allow(dead_code)]
85fn classify_error(status: u16, code: &str, message: &str) -> AwsError {
86 match status {
87 401 => AwsError::Auth {
88 message: format!("{code}: {message}"),
89 },
90 403 if code.contains("ExpiredToken") || code.contains("InvalidClientTokenId") => {
91 AwsError::Auth {
92 message: message.to_string(),
93 }
94 }
95 403 => AwsError::AccessDenied {
96 message: format!("{code}: {message}"),
97 },
98 404 => AwsError::NotFound {
99 resource: message.to_string(),
100 },
101 429 => AwsError::Throttled {
102 retry_after: None,
103 message: message.to_string(),
104 },
105 _ if code == "Throttling"
106 || code == "ThrottlingException"
107 || code == "TooManyRequestsException" =>
108 {
109 AwsError::Throttled {
110 retry_after: None,
111 message: message.to_string(),
112 }
113 }
114 _ => AwsError::ServiceError {
115 code: code.to_string(),
116 message: message.to_string(),
117 status,
118 },
119 }
120}
121
122#[allow(dead_code)]
134pub(crate) fn parse_xml_error(status: u16, body: &str) -> AwsError {
135 let code = extract_xml_tag(body, "Code").unwrap_or_default();
137 let message = extract_xml_tag(body, "Message").unwrap_or_default();
138
139 if code.is_empty() {
140 return AwsError::ServiceError {
141 code: format!("HttpError{status}"),
142 message: truncate_body(body),
143 status,
144 };
145 }
146
147 classify_error(status, &code, &message)
148}
149
150#[allow(dead_code)]
158pub(crate) fn parse_json_error(status: u16, body: &str) -> AwsError {
159 let parsed: std::result::Result<serde_json::Value, _> = serde_json::from_str(body);
160 let (code, message) = match parsed {
161 Ok(val) => {
162 let code = val
163 .get("__type")
164 .and_then(|v| v.as_str())
165 .map(|s| {
166 s.rsplit_once('#').map(|(_, c)| c).unwrap_or(s).to_string()
168 })
169 .or_else(|| {
170 val.get("code")
171 .and_then(|v| v.as_str())
172 .map(|s| s.to_string())
173 })
174 .unwrap_or_default();
175 let message = val
176 .get("message")
177 .or_else(|| val.get("Message"))
178 .and_then(|v| v.as_str())
179 .unwrap_or_default()
180 .to_string();
181 (code, message)
182 }
183 Err(_) => (String::new(), truncate_body(body)),
184 };
185
186 if code.is_empty() {
187 return AwsError::ServiceError {
188 code: format!("HttpError{status}"),
189 message,
190 status,
191 };
192 }
193
194 classify_error(status, &code, &message)
195}
196
/// Caps `body` at 200 bytes for inclusion in an error message.
///
/// Bodies of at most 200 bytes are returned unchanged. Longer bodies are cut at the last
/// UTF-8 character boundary at or before byte 200 and suffixed with `"..."`. The result is
/// therefore at most 203 bytes and never splits a multi-byte character.
fn truncate_body(body: &str) -> String {
    const LIMIT: usize = 200;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // Step back from the limit to the nearest char boundary (index 0 always is one).
    let cut = (0..=LIMIT)
        .rev()
        .find(|&idx| body.is_char_boundary(idx))
        .unwrap_or(0);
    let mut excerpt = String::with_capacity(cut + 3);
    excerpt.push_str(&body[..cut]);
    excerpt.push_str("...");
    excerpt
}
206
/// Returns the text content of the first `<tag>…</tag>` element in `xml`.
///
/// XML entity and character references in the content are decoded (see
/// `unescape_xml_text`). This is a minimal scanner, not an XML parser. It matches the
/// literal `<tag>` opening, so tags with attributes and self-closing `<tag/>` are not
/// recognised. It returns `None` when either the opening or the closing tag is missing.
#[allow(dead_code)]
fn extract_xml_tag(xml: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let start = xml.find(&open)? + open.len();
    let end = start + xml[start..].find(&close)?;
    Some(unescape_xml_text(&xml[start..end]))
}

/// Decodes XML entity and character references in element text.
///
/// Handles the five predefined entities (`&lt;`, `&gt;`, `&amp;`, `&quot;`, `&apos;`) and
/// numeric references (`&#NN;`, `&#xHH;`) in a single pass, so `&amp;lt;` becomes `&lt;`
/// rather than `<`. Any `&` that does not start a recognised reference is kept verbatim.
fn unescape_xml_text(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        out.push_str(&rest[..amp]);
        let tail = &rest[amp..];
        let decoded = tail
            .find(';')
            .and_then(|semi| Some((decode_xml_reference(&tail[1..semi])?, semi + 1)));
        match decoded {
            Some((ch, consumed)) => {
                out.push(ch);
                rest = &tail[consumed..];
            }
            None => {
                // Not a reference we understand: emit the `&` literally and move past it.
                out.push('&');
                rest = &tail[1..];
            }
        }
    }
    out.push_str(rest);
    out
}

/// Decodes the body of one reference (the text between `&` and `;`), or `None` if it is
/// not a predefined entity or a valid numeric character reference.
fn decode_xml_reference(name: &str) -> Option<char> {
    match name {
        "lt" => Some('<'),
        "gt" => Some('>'),
        "amp" => Some('&'),
        "quot" => Some('"'),
        "apos" => Some('\''),
        _ => {
            let number = name.strip_prefix('#')?;
            let (digits, radix) = match number.strip_prefix('x').or_else(|| number.strip_prefix('X')) {
                Some(hex) => (hex, 16),
                None => (number, 10),
            };
            // Reject empty or signed digit strings, which `from_str_radix` would otherwise accept.
            if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
                return None;
            }
            char::from_u32(u32::from_str_radix(digits, radix).ok()?)
        }
    }
}
216
// Unit tests for retry classification, status/code classification, XML and JSON
// error-body parsing, and body truncation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn throttled_is_retryable() {
        let err = AwsError::Throttled {
            retry_after: None,
            message: "slow down".into(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn network_is_retryable() {
        let err = AwsError::Network("timeout".into());
        assert!(err.is_retryable());
    }

    #[test]
    fn auth_is_not_retryable() {
        let err = AwsError::Auth {
            message: "bad creds".into(),
        };
        assert!(!err.is_retryable());
    }

    // A 4xx `ServiceError` with an ordinary code must not be retried.
    #[test]
    fn service_error_4xx_is_not_retryable() {
        let err = AwsError::ServiceError {
            code: "ValidationError".into(),
            message: "bad param".into(),
            status: 400,
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn service_error_500_is_retryable() {
        let err = AwsError::ServiceError {
            code: "InternalError".into(),
            message: "internal".into(),
            status: 500,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn service_error_503_is_retryable() {
        let err = AwsError::ServiceError {
            code: "ServiceUnavailable".into(),
            message: "unavailable".into(),
            status: 503,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn service_error_502_504_are_retryable() {
        for status in [502, 504] {
            let err = AwsError::ServiceError {
                code: "ServerError".into(),
                message: "error".into(),
                status,
            };
            assert!(err.is_retryable(), "status {status} should be retryable");
        }
    }

    #[test]
    fn parse_xml_error_extracts_code_and_message() {
        let body = r#"<ErrorResponse><Error><Code>InvalidParameterValue</Code><Message>Bad param</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(400, body);
        match err {
            AwsError::ServiceError {
                code,
                message,
                status,
            } => {
                assert_eq!(code, "InvalidParameterValue");
                assert_eq!(message, "Bad param");
                assert_eq!(status, 400);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_access_denied() {
        let body = r#"<ErrorResponse><Error><Code>AccessDenied</Code><Message>not allowed</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(403, body);
        assert!(matches!(err, AwsError::AccessDenied { .. }));
    }

    // A body with no `<Code>` element falls back to a synthetic `HttpError{status}` code.
    #[test]
    fn parse_xml_error_fallback_on_invalid_xml() {
        let err = parse_xml_error(500, "not xml at all");
        match err {
            AwsError::ServiceError { code, status, .. } => {
                assert_eq!(code, "HttpError500");
                assert_eq!(status, 500);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_extracts_type_and_message() {
        let body = r#"{"__type": "ResourceNotFoundException", "message": "Log group not found"}"#;
        let err = parse_json_error(404, body);
        assert!(matches!(err, AwsError::NotFound { .. }));
    }

    // `__type` may be namespace-qualified (`namespace#Name`); only the part after `#` is the code.
    #[test]
    fn parse_json_error_strips_uri_prefix() {
        let body =
            r#"{"__type": "com.amazonaws.logs#ResourceNotFoundException", "message": "not found"}"#;
        let err = parse_json_error(404, body);
        assert!(matches!(err, AwsError::NotFound { .. }));
    }

    // The message key is also accepted with a capital letter (`Message`).
    #[test]
    fn parse_json_error_handles_capital_message() {
        let body = r#"{"__type": "ThrottlingException", "Message": "Rate exceeded"}"#;
        let err = parse_json_error(429, body);
        match err {
            AwsError::Throttled { message, .. } => {
                assert_eq!(message, "Rate exceeded");
            }
            other => panic!("expected Throttled, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_fallback_on_invalid_json() {
        let err = parse_json_error(500, "not json");
        match err {
            AwsError::ServiceError { code, status, .. } => {
                assert_eq!(code, "HttpError500");
                assert_eq!(status, 500);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    // Status 400 is not special-cased, so this is classified purely by the `Throttling` code.
    #[test]
    fn parse_xml_error_throttling() {
        let body = r#"<ErrorResponse><Error><Code>Throttling</Code><Message>Rate exceeded</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(400, body);
        assert!(matches!(err, AwsError::Throttled { .. }));
    }

    // 401 maps to `Auth` whatever the error code is.
    #[test]
    fn classify_401_unconditionally_as_auth() {
        let err = classify_error(401, "SignatureDoesNotMatch", "bad sig");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");

        let err = classify_error(401, "MissingAuthenticationToken", "no token");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");
    }

    // A 403 carrying a credential-problem code is `Auth`, not `AccessDenied`.
    #[test]
    fn classify_403_expired_token_as_auth() {
        let err = classify_error(403, "ExpiredToken", "token expired");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");

        let err = classify_error(403, "InvalidClientTokenId", "bad token");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");
    }

    #[test]
    fn classify_403_other_as_access_denied() {
        let err = classify_error(403, "AccessDenied", "not allowed");
        assert!(matches!(err, AwsError::AccessDenied { .. }), "got: {err}");
    }

    // An HTML page has no `<Code>`, so the fallback path stores a truncated body;
    // 203 = 200 kept bytes + the 3-byte "..." suffix.
    #[test]
    fn parse_xml_error_truncates_html_body() {
        let html = "<html><body>".to_string() + &"x".repeat(500) + "</body></html>";
        let err = parse_xml_error(502, &html);
        match err {
            AwsError::ServiceError { message, .. } => {
                assert!(
                    message.len() <= 203,
                    "message should be truncated, got {} chars",
                    message.len()
                );
                assert!(message.ends_with("..."));
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn retry_after_returns_duration_for_throttled() {
        let err = AwsError::Throttled {
            retry_after: Some(Duration::from_secs(5)),
            message: "slow down".into(),
        };
        assert_eq!(err.retry_after(), Some(Duration::from_secs(5)));
    }

    #[test]
    fn retry_after_returns_none_for_non_throttled() {
        let err = AwsError::Auth {
            message: "bad creds".into(),
        };
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn retry_after_returns_none_for_throttled_without_duration() {
        let err = AwsError::Throttled {
            retry_after: None,
            message: "slow down".into(),
        };
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn truncate_body_handles_multibyte_utf8() {
        // "é" is 2 bytes in UTF-8 and occupies bytes 199..201, so the 200-byte cut point
        // falls inside it; truncation must back off to a char boundary instead of panicking.
        let body = "a".repeat(199) + "é" + &"b".repeat(100);
        let truncated = truncate_body(&body);
        assert!(truncated.ends_with("..."));
        assert!(truncated.len() <= 203);
    }

    // Same truncation guarantee on the JSON path when the body is not JSON at all.
    #[test]
    fn parse_json_error_truncates_html_body() {
        let html = "<html><body>".to_string() + &"x".repeat(500) + "</body></html>";
        let err = parse_json_error(502, &html);
        match err {
            AwsError::ServiceError { message, .. } => {
                assert!(
                    message.len() <= 203,
                    "message should be truncated, got {} chars",
                    message.len()
                );
                assert!(message.ends_with("..."));
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }
}