1use std::path::PathBuf;
7use std::time::Duration;
8
9use reqwest::header::RETRY_AFTER;
10
11pub type SourceResult<T> = std::result::Result<T, SourceError>;
13
14#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum SourceError {
18 #[error("unauthorized while accessing {url}")]
20 Unauthorized { url: String },
21
22 #[error("rate limited while accessing {url}")]
25 RateLimited {
26 url: String,
27 retry_after: Option<Duration>,
28 },
29
30 #[error("not found while accessing {url}")]
32 NotFound { url: String },
33
34 #[error("hash verification failed: {source}")]
36 HashMismatch {
37 #[source]
38 source: modde_core::CoreError,
39 },
40
41 #[error("network error: {0}")]
43 Network(#[from] reqwest::Error),
44
45 #[error("I/O error: {0}")]
47 Io(#[from] std::io::Error),
48
49 #[error(transparent)]
51 Other(#[from] anyhow::Error),
52}
53
54impl SourceError {
55 pub fn other(error: impl Into<anyhow::Error>) -> Self {
57 Self::Other(error.into())
58 }
59
60 pub fn hash_mismatch(path: impl Into<PathBuf>, expected: u64, actual: u64) -> Self {
62 Self::HashMismatch {
63 source: modde_core::CoreError::HashMismatch {
64 path: path.into(),
65 expected: format!("{expected:016x}"),
66 actual: format!("{actual:016x}"),
67 },
68 }
69 }
70
71 pub(crate) fn is_retryable(&self) -> bool {
72 matches!(self, Self::Network(_) | Self::Other(_))
73 }
74}
75
76pub fn status_error(response: reqwest::Response) -> SourceResult<reqwest::Response> {
85 let status = response.status();
86 if status.is_success() {
87 return Ok(response);
88 }
89
90 let url = response.url().to_string();
91 match status {
92 reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
93 Err(SourceError::Unauthorized { url })
94 }
95 reqwest::StatusCode::TOO_MANY_REQUESTS => Err(SourceError::RateLimited {
96 retry_after: retry_after(response.headers()),
97 url,
98 }),
99 reqwest::StatusCode::NOT_FOUND => Err(SourceError::NotFound { url }),
100 _ => match response.error_for_status() {
101 Ok(response) => Ok(response),
102 Err(error) => Err(SourceError::Network(error)),
103 },
104 }
105}
106
107fn retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
108 let value = headers.get(RETRY_AFTER)?.to_str().ok()?.trim();
109 if let Ok(seconds) = value.parse::<u64>() {
110 return Some(Duration::from_secs(seconds));
111 }
112
113 httpdate::parse_http_date(value).ok().map(|deadline| {
114 deadline
115 .duration_since(std::time::SystemTime::now())
116 .unwrap_or(Duration::ZERO)
117 })
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers `GET /archive` with `status` (and an
    /// optional `Retry-After` header), sends one request to it, and returns
    /// the error that `status_error` produces for the response.
    ///
    /// Panics if the request fails or if `status_error` returns `Ok`.
    async fn status_error_for(status: u16, retry_after: Option<&str>) -> SourceError {
        let server = MockServer::start().await;
        let mut template = ResponseTemplate::new(status);
        if let Some(retry_after) = retry_after {
            template = template.insert_header("Retry-After", retry_after);
        }
        Mock::given(method("GET"))
            .and(path("/archive"))
            .respond_with(template)
            .mount(&server)
            .await;

        let response = reqwest::Client::new()
            .get(format!("{}/archive", server.uri()))
            .send()
            .await
            .unwrap();
        status_error(response).unwrap_err()
    }

    #[tokio::test]
    async fn maps_unauthorized_status() {
        let error = status_error_for(401, None).await;
        assert!(matches!(error, SourceError::Unauthorized { .. }));
    }

    #[tokio::test]
    async fn maps_forbidden_status() {
        let error = status_error_for(403, None).await;
        assert!(matches!(error, SourceError::Unauthorized { .. }));
    }

    #[tokio::test]
    async fn maps_not_found_status() {
        let error = status_error_for(404, None).await;
        assert!(matches!(error, SourceError::NotFound { .. }));
    }

    #[tokio::test]
    async fn maps_other_error_status_to_network() {
        let error = status_error_for(500, None).await;
        assert!(matches!(error, SourceError::Network(_)));
    }

    #[tokio::test]
    async fn maps_rate_limit_status_with_retry_after_seconds() {
        let error = status_error_for(429, Some("17")).await;
        assert!(matches!(
            error,
            SourceError::RateLimited {
                retry_after: Some(duration),
                ..
            } if duration == Duration::from_secs(17)
        ));
    }

    #[tokio::test]
    async fn maps_rate_limit_status_with_retry_after_http_date() {
        // Build the deadline relative to the current clock so the test does
        // not start failing once a hard-coded calendar date has passed.
        let horizon = Duration::from_secs(3600);
        let deadline = httpdate::fmt_http_date(std::time::SystemTime::now() + horizon);

        let error = status_error_for(429, Some(&deadline)).await;
        assert!(matches!(
            error,
            SourceError::RateLimited {
                retry_after: Some(duration),
                ..
            } if duration > Duration::ZERO && duration <= horizon
        ));
    }

    #[tokio::test]
    async fn clamps_past_retry_after_http_date_to_zero() {
        let error = status_error_for(429, Some("Wed, 21 Oct 2015 07:28:00 GMT")).await;
        assert!(matches!(
            error,
            SourceError::RateLimited {
                retry_after: Some(duration),
                ..
            } if duration == Duration::ZERO
        ));
    }

    #[tokio::test]
    async fn ignores_unparsable_retry_after() {
        let error = status_error_for(429, Some("soon")).await;
        assert!(matches!(
            error,
            SourceError::RateLimited {
                retry_after: None,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn maps_rate_limit_status_without_retry_after() {
        let error = status_error_for(429, None).await;
        assert!(matches!(
            error,
            SourceError::RateLimited {
                retry_after: None,
                ..
            }
        ));
    }
}