1use crate::{CdnEndpoint, NetworkConfig, NetworkError, Result, RetryConfig};
4use bytes::Bytes;
5use reqwest::{header, Client, Response, StatusCode};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, warn};
10
11#[derive(Debug, Clone)]
13pub struct ClientConfig {
14 pub timeout: Duration,
16 pub connect_timeout: Duration,
18 pub compression: bool,
20 pub user_agent: String,
22 pub retry: RetryConfig,
24}
25
26impl From<&NetworkConfig> for ClientConfig {
27 fn from(config: &NetworkConfig) -> Self {
28 Self {
29 timeout: config.timeout,
30 connect_timeout: config.connect_timeout,
31 compression: config.compression,
32 user_agent: config.user_agent.clone(),
33 retry: config.retry.clone(),
34 }
35 }
36}
37
/// An inclusive byte range `[start, end]`, as used by the HTTP `Range` header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RangeRequest {
    /// Offset of the first byte to fetch.
    pub start: u64,
    /// Offset of the last byte to fetch (inclusive).
    pub end: u64,
}

impl RangeRequest {
    /// Creates a range covering bytes `start..=end`.
    ///
    /// No validation is performed. An inverted range (`end < start`) can be
    /// represented, and it reports a [`content_length`](Self::content_length) of 0.
    pub fn new(start: u64, end: u64) -> Self {
        Self { start, end }
    }

    /// Formats the `Range` header value, e.g. `bytes=0-1023`.
    pub fn header_value(&self) -> String {
        format!("bytes={}-{}", self.start, self.end)
    }

    /// Number of bytes covered by the range.
    ///
    /// Returns 0 for an inverted range instead of underflowing. Saturates at
    /// `u64::MAX` for `0..=u64::MAX`, whose true length (2^64) does not fit in a `u64`.
    pub fn content_length(&self) -> u64 {
        match self.end.checked_sub(self.start) {
            Some(span) => span.saturating_add(1),
            None => 0,
        }
    }
}
63
64pub struct HttpClient {
66 client: Client,
67 config: ClientConfig,
68 endpoint: CdnEndpoint,
69 semaphore: Arc<Semaphore>,
70}
71
72impl HttpClient {
73 pub fn new(endpoint: CdnEndpoint, config: ClientConfig) -> Result<Self> {
75 let mut builder = Client::builder()
76 .timeout(config.timeout)
77 .connect_timeout(config.connect_timeout)
78 .user_agent(&config.user_agent)
79 .pool_max_idle_per_host(endpoint.max_connections);
80
81 if config.compression {
82 builder = builder.gzip(true).brotli(true);
83 }
84
85 let client = builder
86 .build()
87 .map_err(|e| NetworkError::Configuration(e.to_string()))?;
88 let semaphore = Arc::new(Semaphore::new(endpoint.max_connections));
89
90 Ok(Self {
91 client,
92 config,
93 endpoint,
94 semaphore,
95 })
96 }
97
98 pub async fn fetch(&self, path: &str) -> Result<Bytes> {
100 let _permit = self
101 .semaphore
102 .acquire()
103 .await
104 .map_err(|_| NetworkError::Cancelled)?;
105
106 let url = format!(
107 "{}/{}",
108 self.endpoint.url.trim_end_matches('/'),
109 path.trim_start_matches('/')
110 );
111 debug!("Fetching: {}", url);
112
113 self.fetch_with_retry(&url, None).await
114 }
115
116 pub async fn fetch_range(&self, path: &str, range: RangeRequest) -> Result<Bytes> {
118 if !self.endpoint.supports_range {
119 return Err(NetworkError::Configuration(
120 "Endpoint does not support range requests".into(),
121 ));
122 }
123
124 let _permit = self
125 .semaphore
126 .acquire()
127 .await
128 .map_err(|_| NetworkError::Cancelled)?;
129
130 let url = format!(
131 "{}/{}",
132 self.endpoint.url.trim_end_matches('/'),
133 path.trim_start_matches('/')
134 );
135 debug!("Fetching range {}-{}: {}", range.start, range.end, url);
136
137 self.fetch_with_retry(&url, Some(range)).await
138 }
139
140 async fn fetch_with_retry(&self, url: &str, range: Option<RangeRequest>) -> Result<Bytes> {
142 let mut last_error = NetworkError::Connection("No attempts made".into());
143 let mut backoff = self.config.retry.initial_backoff;
144
145 for attempt in 0..=self.config.retry.max_retries {
146 if attempt > 0 {
147 debug!("Retry attempt {} after {:?}", attempt, backoff);
148 tokio::time::sleep(backoff).await;
149
150 backoff = Duration::from_secs_f64(
152 (backoff.as_secs_f64() * self.config.retry.multiplier)
153 .min(self.config.retry.max_backoff.as_secs_f64()),
154 );
155
156 if self.config.retry.jitter {
157 let jitter = rand::random::<f64>() * 0.3;
158 backoff = Duration::from_secs_f64(backoff.as_secs_f64() * (1.0 + jitter));
159 }
160 }
161
162 match self.fetch_once(url, range.clone()).await {
163 Ok(bytes) => return Ok(bytes),
164 Err(e) => {
165 if !e.is_retryable() {
166 return Err(e);
167 }
168
169 if let Some(retry_after) = e.retry_after() {
171 backoff = retry_after;
172 }
173
174 warn!("Request failed (attempt {}): {:?}", attempt + 1, e);
175 last_error = e;
176 }
177 }
178 }
179
180 Err(NetworkError::RetriesExhausted(last_error.to_string()))
181 }
182
183 async fn fetch_once(&self, url: &str, range: Option<RangeRequest>) -> Result<Bytes> {
185 let mut request = self.client.get(url);
186
187 for (key, value) in &self.endpoint.headers {
189 request = request.header(key, value);
190 }
191
192 if let Some(ref range) = range {
194 request = request.header(header::RANGE, range.header_value());
195 }
196
197 let response = request.send().await?;
198 self.handle_response(response, range).await
199 }
200
201 async fn handle_response(
203 &self,
204 response: Response,
205 range: Option<RangeRequest>,
206 ) -> Result<Bytes> {
207 let status = response.status();
208
209 match status {
210 StatusCode::OK | StatusCode::PARTIAL_CONTENT => {
211 if let Some(ref range) = range {
213 if let Some(len) = response.content_length() {
214 if len != range.content_length() {
215 warn!(
216 "Content length mismatch: expected {}, got {}",
217 range.content_length(),
218 len
219 );
220 }
221 }
222 }
223
224 response.bytes().await.map_err(|e| e.into())
225 }
226
227 StatusCode::NOT_FOUND => Err(NetworkError::NotFound("Fragment not found".into())),
228
229 StatusCode::TOO_MANY_REQUESTS => {
230 let retry_after = response
231 .headers()
232 .get(header::RETRY_AFTER)
233 .and_then(|v| v.to_str().ok())
234 .and_then(|v| v.parse::<u64>().ok())
235 .unwrap_or(60)
236 * 1000;
237
238 Err(NetworkError::RateLimited {
239 retry_after_ms: retry_after,
240 })
241 }
242
243 _ => Err(NetworkError::Http {
244 status: status.as_u16(),
245 message: response.text().await.unwrap_or_default(),
246 }),
247 }
248 }
249
250 pub async fn head(&self, path: &str) -> Result<HeadInfo> {
252 let _permit = self
253 .semaphore
254 .acquire()
255 .await
256 .map_err(|_| NetworkError::Cancelled)?;
257
258 let url = format!(
259 "{}/{}",
260 self.endpoint.url.trim_end_matches('/'),
261 path.trim_start_matches('/')
262 );
263
264 let mut request = self.client.head(&url);
265 for (key, value) in &self.endpoint.headers {
266 request = request.header(key, value);
267 }
268
269 let response = request.send().await?;
270
271 if !response.status().is_success() {
272 return Err(NetworkError::Http {
273 status: response.status().as_u16(),
274 message: "HEAD request failed".into(),
275 });
276 }
277
278 let headers = response.headers();
279
280 Ok(HeadInfo {
281 content_length: response.content_length(),
282 etag: headers
283 .get(header::ETAG)
284 .and_then(|v| v.to_str().ok())
285 .map(String::from),
286 last_modified: headers
287 .get(header::LAST_MODIFIED)
288 .and_then(|v| v.to_str().ok())
289 .map(String::from),
290 accepts_ranges: headers
291 .get(header::ACCEPT_RANGES)
292 .and_then(|v| v.to_str().ok())
293 .map(|v| v == "bytes")
294 .unwrap_or(false),
295 })
296 }
297}
298
/// Metadata returned by [`HttpClient::head`].
#[derive(Debug, Clone)]
pub struct HeadInfo {
    /// Resource size in bytes, if the response reported one.
    pub content_length: Option<u64>,
    /// Raw `ETag` header value, if present and valid visible ASCII.
    pub etag: Option<String>,
    /// Raw `Last-Modified` header value, if present and valid visible ASCII.
    pub last_modified: Option<String>,
    /// Whether the `Accept-Ranges` header advertised byte ranges.
    pub accepts_ranges: bool,
}
311
/// Minimal, dependency-free random source used for retry jitter.
///
/// Not cryptographically secure; it only needs to decorrelate retry timing.
mod rand {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a value in `[0, 1)`, converted into `T`.
    ///
    /// Each call hashes the current time with a freshly keyed SipHash
    /// ([`RandomState::new`]). Near-simultaneous calls therefore still produce
    /// unrelated values, and the whole unit interval is covered. Raw sub-second
    /// nanoseconds are always below 1e9, so dividing them by `u32::MAX` would only
    /// span roughly `[0, 0.233)`.
    pub fn random<T: From<f64>>() -> T {
        // A clock set before the Unix epoch is not fatal here; any seed will do.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or_default();

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);

        // Keep the top 53 bits, the precision of an f64 mantissa. The quotient is
        // then exact and strictly below 1.0.
        let bits = hasher.finish() >> 11;
        T::from(bits as f64 / (1u64 << 53) as f64)
    }
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// A range starting at zero formats as `bytes=0-N` and covers `N + 1` bytes.
    #[test]
    fn test_range_header() {
        let range = RangeRequest::new(0, 1023);
        assert_eq!(range.header_value(), "bytes=0-1023");
        assert_eq!(range.content_length(), 1024);
    }

    /// A non-zero start is reflected in the header, and the length is still inclusive.
    #[test]
    fn test_range_header_with_offset() {
        let range = RangeRequest::new(4096, 8191);
        assert_eq!(range.header_value(), "bytes=4096-8191");
        assert_eq!(range.content_length(), 4096);
    }

    /// A range whose start equals its end covers exactly one byte.
    #[test]
    fn test_single_byte_range() {
        let range = RangeRequest::new(7, 7);
        assert_eq!(range.header_value(), "bytes=7-7");
        assert_eq!(range.content_length(), 1);
    }
}
335}