Skip to main content

aws_lite_rs/
error.rs

1//! Error types for AWS HTTP client operations.
2
3use std::time::Duration;
4use thiserror::Error;
5
6/// Result type alias using AwsError.
7pub type Result<T> = std::result::Result<T, AwsError>;
8
9/// Errors that can occur during AWS API operations.
10#[derive(Debug, Error, Clone)]
11pub enum AwsError {
12    /// Authentication failed (invalid credentials, expired token).
13    #[error("Authentication failed: {message}")]
14    Auth { message: String },
15
16    /// Access denied (insufficient IAM permissions).
17    #[error("Access denied: {message}")]
18    AccessDenied { message: String },
19
20    /// Resource not found.
21    #[error("Resource not found: {resource}")]
22    NotFound { resource: String },
23
24    /// Request throttled.
25    #[error("Throttled (retry after {retry_after:?})")]
26    Throttled {
27        retry_after: Option<Duration>,
28        message: String,
29    },
30
31    /// AWS service error.
32    #[error("Service error ({code}): {message}")]
33    ServiceError {
34        code: String,
35        message: String,
36        status: u16,
37    },
38
39    /// Network error.
40    #[error("Network error: {0}")]
41    Network(String),
42
43    /// Invalid response.
44    #[error("Invalid response: {message}")]
45    InvalidResponse {
46        message: String,
47        body: Option<String>,
48    },
49
50    /// XML parsing error.
51    #[error("XML parse error: {message}")]
52    XmlParse { message: String },
53}
54
55impl From<reqwest::Error> for AwsError {
56    fn from(err: reqwest::Error) -> Self {
57        Self::Network(err.to_string())
58    }
59}
60
61impl AwsError {
62    /// Returns true if this error is retryable.
63    pub fn is_retryable(&self) -> bool {
64        match self {
65            Self::Throttled { .. } | Self::Network(_) => true,
66            Self::ServiceError { status, .. } => matches!(status, 500 | 502 | 503 | 504),
67            _ => false,
68        }
69    }
70
71    /// Extract retry-after duration if present.
72    pub fn retry_after(&self) -> Option<Duration> {
73        match self {
74            Self::Throttled {
75                retry_after: Some(duration),
76                ..
77            } => Some(*duration),
78            _ => None,
79        }
80    }
81}
82
83/// Map an AWS error code + HTTP status to a typed `AwsError`.
84#[allow(dead_code)]
85fn classify_error(status: u16, code: &str, message: &str) -> AwsError {
86    match status {
87        401 => AwsError::Auth {
88            message: format!("{code}: {message}"),
89        },
90        403 if code.contains("ExpiredToken") || code.contains("InvalidClientTokenId") => {
91            AwsError::Auth {
92                message: message.to_string(),
93            }
94        }
95        403 => AwsError::AccessDenied {
96            message: format!("{code}: {message}"),
97        },
98        404 => AwsError::NotFound {
99            resource: message.to_string(),
100        },
101        429 => AwsError::Throttled {
102            retry_after: None,
103            message: message.to_string(),
104        },
105        _ if code == "Throttling"
106            || code == "ThrottlingException"
107            || code == "TooManyRequestsException" =>
108        {
109            AwsError::Throttled {
110                retry_after: None,
111                message: message.to_string(),
112            }
113        }
114        _ => AwsError::ServiceError {
115            code: code.to_string(),
116            message: message.to_string(),
117            status,
118        },
119    }
120}
121
122/// Parse an AWS XML error response (Query/XML protocol).
123///
124/// Expected format:
125/// ```xml
126/// <ErrorResponse>
127///   <Error>
128///     <Code>InvalidParameterValue</Code>
129///     <Message>The value supplied is not valid.</Message>
130///   </Error>
131/// </ErrorResponse>
132/// ```
133#[allow(dead_code)]
134pub(crate) fn parse_xml_error(status: u16, body: &str) -> AwsError {
135    // Try to extract <Code> and <Message> from the XML body
136    let code = extract_xml_tag(body, "Code").unwrap_or_default();
137    let message = extract_xml_tag(body, "Message").unwrap_or_default();
138
139    if code.is_empty() {
140        return AwsError::ServiceError {
141            code: format!("HttpError{status}"),
142            message: truncate_body(body),
143            status,
144        };
145    }
146
147    classify_error(status, &code, &message)
148}
149
150/// Parse an AWS JSON error response (JSON protocol).
151///
152/// Expected format:
153/// ```json
154/// {"__type": "ResourceNotFoundException", "message": "The specified log group does not exist."}
155/// ```
156/// Some services use `Message` (capital M) instead of `message`.
157#[allow(dead_code)]
158pub(crate) fn parse_json_error(status: u16, body: &str) -> AwsError {
159    let parsed: std::result::Result<serde_json::Value, _> = serde_json::from_str(body);
160    let (code, message) = match parsed {
161        Ok(val) => {
162            let code = val
163                .get("__type")
164                .and_then(|v| v.as_str())
165                .map(|s| {
166                    // __type can be a full URI like "com.amazonaws.logs#ResourceNotFoundException"
167                    s.rsplit_once('#').map(|(_, c)| c).unwrap_or(s).to_string()
168                })
169                .or_else(|| {
170                    val.get("code")
171                        .and_then(|v| v.as_str())
172                        .map(|s| s.to_string())
173                })
174                .unwrap_or_default();
175            let message = val
176                .get("message")
177                .or_else(|| val.get("Message"))
178                .and_then(|v| v.as_str())
179                .unwrap_or_default()
180                .to_string();
181            (code, message)
182        }
183        Err(_) => (String::new(), truncate_body(body)),
184    };
185
186    if code.is_empty() {
187        return AwsError::ServiceError {
188            code: format!("HttpError{status}"),
189            message,
190            status,
191        };
192    }
193
194    classify_error(status, &code, &message)
195}
196
/// Truncate a body string for error messages, avoiding unreadable HTML pages.
///
/// Bodies of at most 200 bytes are returned unchanged. Longer bodies are cut
/// at the last UTF-8 character boundary at or before byte 200, and `...` is
/// appended. The result is therefore never longer than 203 bytes and is always
/// valid UTF-8.
fn truncate_body(body: &str) -> String {
    const MAX_BYTES: usize = 200;

    if body.len() <= MAX_BYTES {
        return body.to_string();
    }

    // Walk back to a char boundary by hand instead of using
    // `str::floor_char_boundary`, which is unavailable on many stable
    // toolchains. A UTF-8 scalar is at most 4 bytes, so this steps back at
    // most 3 times. Index 0 is always a boundary, so the loop terminates.
    let mut end = MAX_BYTES;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}
206
/// Extract the text content of the first `<tag>…</tag>` element in `xml`.
///
/// This is a minimal scanner, not an XML parser:
/// * it looks for the literal `<tag>`, so an element written with attributes
///   is not found;
/// * it takes everything up to the next `</tag>`;
/// * it does not handle nesting or CDATA.
///
/// Returns `None` when either the opening or the closing tag is missing.
///
/// The predefined XML entities (`&lt;`, `&gt;`, `&amp;`, `&quot;`, `&apos;`)
/// and numeric character references (`&#65;`, `&#x41;`) in the content are
/// decoded, so messages read as the service wrote them.
#[allow(dead_code)]
fn extract_xml_tag(xml: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}>");
    let close = format!("</{tag}>");
    let start = xml.find(&open)? + open.len();
    let end = xml[start..].find(&close)? + start;
    Some(decode_xml_entities(&xml[start..end]))
}

/// Decode XML entity and character references in element text in a single pass.
///
/// Unrecognized or malformed references (including a bare `&`) are kept
/// verbatim. Decoding is not recursive: `&amp;lt;` becomes `&lt;`.
// Only reachable through `extract_xml_tag`, which carries the same allowance.
#[allow(dead_code)]
fn decode_xml_entities(text: &str) -> String {
    // The longest reference decoded here is `&#x10FFFF;` (10 bytes). Capping
    // the `;` search keeps a stray `&` from scanning the rest of the input.
    const MAX_REFERENCE_LEN: usize = 12;

    /// Map the text between `&` and `;` to the character it denotes.
    fn decode_reference(name: &str) -> Option<char> {
        match name {
            "lt" => Some('<'),
            "gt" => Some('>'),
            "amp" => Some('&'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            _ => {
                let number = name.strip_prefix('#')?;
                let (digits, radix) = match number.strip_prefix('x') {
                    Some(hex) => (hex, 16),
                    None => (number, 10),
                };
                // `from_str_radix` tolerates a leading `+`, which XML does not.
                if digits.starts_with('+') {
                    return None;
                }
                u32::from_str_radix(digits, radix)
                    .ok()
                    .and_then(char::from_u32)
            }
        }
    }

    let mut decoded = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        decoded.push_str(&rest[..amp]);
        let tail = &rest[amp..];
        // `;` is ASCII, so its byte position is always a char boundary.
        let reference = tail
            .bytes()
            .take(MAX_REFERENCE_LEN)
            .position(|b| b == b';')
            .and_then(|semi| decode_reference(&tail[1..semi]).map(|ch| (ch, semi + 1)));
        match reference {
            Some((ch, consumed)) => {
                decoded.push(ch);
                rest = &tail[consumed..];
            }
            None => {
                decoded.push('&');
                rest = &tail[1..];
            }
        }
    }
    decoded.push_str(rest);
    decoded
}
216
#[cfg(test)]
mod tests {
    //! Unit tests for error classification, retry policy and response parsing.

    use super::*;

    // ---- retry policy ------------------------------------------------------

    #[test]
    fn throttled_is_retryable() {
        let err = AwsError::Throttled {
            retry_after: None,
            message: "slow down".into(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn network_is_retryable() {
        let err = AwsError::Network("timeout".into());
        assert!(err.is_retryable());
    }

    #[test]
    fn auth_is_not_retryable() {
        let err = AwsError::Auth {
            message: "bad creds".into(),
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn permanent_failures_are_not_retryable() {
        let errors = [
            AwsError::AccessDenied {
                message: "nope".into(),
            },
            AwsError::NotFound {
                resource: "bucket".into(),
            },
            AwsError::InvalidResponse {
                message: "garbled".into(),
                body: None,
            },
            AwsError::XmlParse {
                message: "unexpected eof".into(),
            },
        ];
        for err in errors {
            assert!(!err.is_retryable(), "{err} should not be retryable");
        }
    }

    #[test]
    fn service_error_4xx_is_not_retryable() {
        let err = AwsError::ServiceError {
            code: "ValidationError".into(),
            message: "bad param".into(),
            status: 400,
        };
        assert!(!err.is_retryable());
    }

    #[test]
    fn service_error_500_is_retryable() {
        let err = AwsError::ServiceError {
            code: "InternalError".into(),
            message: "internal".into(),
            status: 500,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn service_error_503_is_retryable() {
        let err = AwsError::ServiceError {
            code: "ServiceUnavailable".into(),
            message: "unavailable".into(),
            status: 503,
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn service_error_502_504_are_retryable() {
        for status in [502, 504] {
            let err = AwsError::ServiceError {
                code: "ServerError".into(),
                message: "error".into(),
                status,
            };
            assert!(err.is_retryable(), "status {status} should be retryable");
        }
    }

    // ---- XML (Query/XML protocol) parsing ----------------------------------

    #[test]
    fn parse_xml_error_extracts_code_and_message() {
        let body = r#"<ErrorResponse><Error><Code>InvalidParameterValue</Code><Message>Bad param</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(400, body);
        match err {
            AwsError::ServiceError {
                code,
                message,
                status,
            } => {
                assert_eq!(code, "InvalidParameterValue");
                assert_eq!(message, "Bad param");
                assert_eq!(status, 400);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_access_denied() {
        let body = r#"<ErrorResponse><Error><Code>AccessDenied</Code><Message>not allowed</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(403, body);
        match err {
            AwsError::AccessDenied { message } => {
                assert_eq!(message, "AccessDenied: not allowed");
            }
            other => panic!("expected AccessDenied, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_not_found_uses_message_as_resource() {
        let body = r#"<ErrorResponse><Error><Code>NoSuchEntity</Code><Message>user bob</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(404, body);
        match err {
            AwsError::NotFound { resource } => assert_eq!(resource, "user bob"),
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_fallback_on_invalid_xml() {
        let err = parse_xml_error(500, "not xml at all");
        match err {
            AwsError::ServiceError { code, status, .. } => {
                assert_eq!(code, "HttpError500");
                assert_eq!(status, 500);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_throttling() {
        let body = r#"<ErrorResponse><Error><Code>Throttling</Code><Message>Rate exceeded</Message></Error></ErrorResponse>"#;
        let err = parse_xml_error(400, body);
        match err {
            AwsError::Throttled { message, .. } => assert_eq!(message, "Rate exceeded"),
            other => panic!("expected Throttled, got: {other}"),
        }
    }

    #[test]
    fn parse_xml_error_truncates_html_body() {
        let html = "<html><body>".to_string() + &"x".repeat(500) + "</body></html>";
        let err = parse_xml_error(502, &html);
        match err {
            AwsError::ServiceError { message, .. } => {
                // The body is ASCII, so the cut lands exactly on byte 200.
                assert_eq!(message, format!("{}...", &html[..200]));
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    // ---- JSON protocol parsing ---------------------------------------------

    #[test]
    fn parse_json_error_extracts_type_and_message() {
        let body = r#"{"__type": "ResourceNotFoundException", "message": "Log group not found"}"#;
        let err = parse_json_error(404, body);
        match err {
            AwsError::NotFound { resource } => assert_eq!(resource, "Log group not found"),
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_strips_uri_prefix() {
        let body =
            r#"{"__type": "com.amazonaws.logs#ResourceNotFoundException", "message": "not found"}"#;
        let err = parse_json_error(404, body);
        match err {
            AwsError::NotFound { resource } => assert_eq!(resource, "not found"),
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_handles_capital_message() {
        let body = r#"{"__type": "ThrottlingException", "Message": "Rate exceeded"}"#;
        let err = parse_json_error(429, body);
        match err {
            AwsError::Throttled { message, .. } => {
                assert_eq!(message, "Rate exceeded");
            }
            other => panic!("expected Throttled, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_falls_back_to_code_key() {
        let body = r#"{"code": "ValidationException", "message": "bad input"}"#;
        let err = parse_json_error(400, body);
        match err {
            AwsError::ServiceError {
                code,
                message,
                status,
            } => {
                assert_eq!(code, "ValidationException");
                assert_eq!(message, "bad input");
                assert_eq!(status, 400);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_without_code_keeps_message() {
        let err = parse_json_error(403, r#"{"message": "Forbidden"}"#);
        match err {
            AwsError::ServiceError {
                code,
                message,
                status,
            } => {
                assert_eq!(code, "HttpError403");
                assert_eq!(message, "Forbidden");
                assert_eq!(status, 403);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_fallback_on_invalid_json() {
        let err = parse_json_error(500, "not json");
        match err {
            AwsError::ServiceError { code, status, .. } => {
                assert_eq!(code, "HttpError500");
                assert_eq!(status, 500);
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    #[test]
    fn parse_json_error_truncates_html_body() {
        let html = "<html><body>".to_string() + &"x".repeat(500) + "</body></html>";
        let err = parse_json_error(502, &html);
        match err {
            AwsError::ServiceError { message, .. } => {
                // The body is ASCII, so the cut lands exactly on byte 200.
                assert_eq!(message, format!("{}...", &html[..200]));
            }
            other => panic!("expected ServiceError, got: {other}"),
        }
    }

    // ---- classification ----------------------------------------------------

    #[test]
    fn classify_401_unconditionally_as_auth() {
        // 401 with any code should map to Auth, not ServiceError
        let err = classify_error(401, "SignatureDoesNotMatch", "bad sig");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");

        let err = classify_error(401, "MissingAuthenticationToken", "no token");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");
    }

    #[test]
    fn classify_403_expired_token_as_auth() {
        let err = classify_error(403, "ExpiredToken", "token expired");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");

        let err = classify_error(403, "InvalidClientTokenId", "bad token");
        assert!(matches!(err, AwsError::Auth { .. }), "got: {err}");
    }

    #[test]
    fn classify_403_other_as_access_denied() {
        let err = classify_error(403, "AccessDenied", "not allowed");
        assert!(matches!(err, AwsError::AccessDenied { .. }), "got: {err}");
    }

    #[test]
    fn classify_throttling_code_with_non_429_status() {
        for code in ["ThrottlingException", "TooManyRequestsException"] {
            match classify_error(400, code, "Rate exceeded") {
                AwsError::Throttled {
                    message,
                    retry_after,
                } => {
                    assert_eq!(message, "Rate exceeded");
                    assert_eq!(retry_after, None);
                }
                other => panic!("expected Throttled for {code}, got: {other}"),
            }
        }
    }

    // ---- helpers -----------------------------------------------------------

    #[test]
    fn extract_xml_tag_returns_first_element_text() {
        let xml = "<Error><Code>Throttling</Code><Message>slow</Message></Error>";
        assert_eq!(extract_xml_tag(xml, "Code").as_deref(), Some("Throttling"));
        assert_eq!(extract_xml_tag(xml, "Message").as_deref(), Some("slow"));
    }

    #[test]
    fn extract_xml_tag_missing_tag_returns_none() {
        assert_eq!(extract_xml_tag("<Code>open only", "Code"), None);
        assert_eq!(extract_xml_tag("no tags at all", "Code"), None);
    }

    #[test]
    fn retry_after_returns_duration_for_throttled() {
        let err = AwsError::Throttled {
            retry_after: Some(Duration::from_secs(5)),
            message: "slow down".into(),
        };
        assert_eq!(err.retry_after(), Some(Duration::from_secs(5)));
    }

    #[test]
    fn retry_after_returns_none_for_non_throttled() {
        let err = AwsError::Auth {
            message: "bad creds".into(),
        };
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn retry_after_returns_none_for_throttled_without_duration() {
        let err = AwsError::Throttled {
            retry_after: None,
            message: "slow down".into(),
        };
        assert_eq!(err.retry_after(), None);
    }

    #[test]
    fn truncate_body_handles_multibyte_utf8() {
        // 'é' is 2 bytes in UTF-8; build a string where byte 200 falls mid-character.
        let body = "a".repeat(199) + "é" + &"b".repeat(100);
        // byte len: 199 + 2 + 100 = 301, so truncation triggers. Byte 200 is
        // inside 'é', so the cut must back up to the char boundary at 199.
        let truncated = truncate_body(&body);
        assert_eq!(truncated, "a".repeat(199) + "...");
    }

    #[test]
    fn truncate_body_keeps_short_bodies() {
        assert_eq!(truncate_body("short"), "short");
        let exact = "x".repeat(200);
        assert_eq!(truncate_body(&exact), exact);
    }
}