Skip to main content

haagenti_network/
client.rs

1//! HTTP client for CDN communication
2
3use crate::{CdnEndpoint, NetworkConfig, NetworkError, Result, RetryConfig};
4use bytes::Bytes;
5use reqwest::{header, Client, Response, StatusCode};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, warn};
10
11/// Configuration for the HTTP client
12#[derive(Debug, Clone)]
13pub struct ClientConfig {
14    /// Request timeout
15    pub timeout: Duration,
16    /// Connect timeout
17    pub connect_timeout: Duration,
18    /// Enable compression
19    pub compression: bool,
20    /// User agent
21    pub user_agent: String,
22    /// Retry configuration
23    pub retry: RetryConfig,
24}
25
26impl From<&NetworkConfig> for ClientConfig {
27    fn from(config: &NetworkConfig) -> Self {
28        Self {
29            timeout: config.timeout,
30            connect_timeout: config.connect_timeout,
31            compression: config.compression,
32            user_agent: config.user_agent.clone(),
33            retry: config.retry.clone(),
34        }
35    }
36}
37
/// Range request for partial downloads.
///
/// Both bounds are byte offsets and `end` is *inclusive*, matching the HTTP
/// `Range: bytes=start-end` syntax.
#[derive(Debug, Clone)]
pub struct RangeRequest {
    /// First byte to fetch.
    pub start: u64,
    /// Last byte to fetch (inclusive).
    pub end: u64,
}

impl RangeRequest {
    /// Create a new range request covering `start..=end`.
    ///
    /// No validation is performed here; a range with `end < start` is representable and reports
    /// a [`content_length`](Self::content_length) of 0.
    pub fn new(start: u64, end: u64) -> Self {
        Self { start, end }
    }

    /// Get the `Range` header value, e.g. `bytes=0-1023`.
    pub fn header_value(&self) -> String {
        format!("bytes={}-{}", self.start, self.end)
    }

    /// Get the number of bytes this range covers.
    ///
    /// Returns 0 for an inverted range (`end < start`) instead of underflowing, and saturates at
    /// `u64::MAX` for the full `0..=u64::MAX` span instead of overflowing.
    pub fn content_length(&self) -> u64 {
        match self.end.checked_sub(self.start) {
            Some(span) => span.saturating_add(1),
            None => 0,
        }
    }
}
63
64/// HTTP client for CDN requests
65pub struct HttpClient {
66    client: Client,
67    config: ClientConfig,
68    endpoint: CdnEndpoint,
69    semaphore: Arc<Semaphore>,
70}
71
72impl HttpClient {
73    /// Create a new HTTP client
74    pub fn new(endpoint: CdnEndpoint, config: ClientConfig) -> Result<Self> {
75        let mut builder = Client::builder()
76            .timeout(config.timeout)
77            .connect_timeout(config.connect_timeout)
78            .user_agent(&config.user_agent)
79            .pool_max_idle_per_host(endpoint.max_connections);
80
81        if config.compression {
82            builder = builder.gzip(true).brotli(true);
83        }
84
85        let client = builder
86            .build()
87            .map_err(|e| NetworkError::Configuration(e.to_string()))?;
88        let semaphore = Arc::new(Semaphore::new(endpoint.max_connections));
89
90        Ok(Self {
91            client,
92            config,
93            endpoint,
94            semaphore,
95        })
96    }
97
98    /// Fetch a fragment by path
99    pub async fn fetch(&self, path: &str) -> Result<Bytes> {
100        let _permit = self
101            .semaphore
102            .acquire()
103            .await
104            .map_err(|_| NetworkError::Cancelled)?;
105
106        let url = format!(
107            "{}/{}",
108            self.endpoint.url.trim_end_matches('/'),
109            path.trim_start_matches('/')
110        );
111        debug!("Fetching: {}", url);
112
113        self.fetch_with_retry(&url, None).await
114    }
115
116    /// Fetch a range of bytes
117    pub async fn fetch_range(&self, path: &str, range: RangeRequest) -> Result<Bytes> {
118        if !self.endpoint.supports_range {
119            return Err(NetworkError::Configuration(
120                "Endpoint does not support range requests".into(),
121            ));
122        }
123
124        let _permit = self
125            .semaphore
126            .acquire()
127            .await
128            .map_err(|_| NetworkError::Cancelled)?;
129
130        let url = format!(
131            "{}/{}",
132            self.endpoint.url.trim_end_matches('/'),
133            path.trim_start_matches('/')
134        );
135        debug!("Fetching range {}-{}: {}", range.start, range.end, url);
136
137        self.fetch_with_retry(&url, Some(range)).await
138    }
139
140    /// Fetch with retry logic
141    async fn fetch_with_retry(&self, url: &str, range: Option<RangeRequest>) -> Result<Bytes> {
142        let mut last_error = NetworkError::Connection("No attempts made".into());
143        let mut backoff = self.config.retry.initial_backoff;
144
145        for attempt in 0..=self.config.retry.max_retries {
146            if attempt > 0 {
147                debug!("Retry attempt {} after {:?}", attempt, backoff);
148                tokio::time::sleep(backoff).await;
149
150                // Exponential backoff with jitter
151                backoff = Duration::from_secs_f64(
152                    (backoff.as_secs_f64() * self.config.retry.multiplier)
153                        .min(self.config.retry.max_backoff.as_secs_f64()),
154                );
155
156                if self.config.retry.jitter {
157                    let jitter = rand::random::<f64>() * 0.3;
158                    backoff = Duration::from_secs_f64(backoff.as_secs_f64() * (1.0 + jitter));
159                }
160            }
161
162            match self.fetch_once(url, range.clone()).await {
163                Ok(bytes) => return Ok(bytes),
164                Err(e) => {
165                    if !e.is_retryable() {
166                        return Err(e);
167                    }
168
169                    // Check for rate limiting
170                    if let Some(retry_after) = e.retry_after() {
171                        backoff = retry_after;
172                    }
173
174                    warn!("Request failed (attempt {}): {:?}", attempt + 1, e);
175                    last_error = e;
176                }
177            }
178        }
179
180        Err(NetworkError::RetriesExhausted(last_error.to_string()))
181    }
182
183    /// Single fetch attempt
184    async fn fetch_once(&self, url: &str, range: Option<RangeRequest>) -> Result<Bytes> {
185        let mut request = self.client.get(url);
186
187        // Add custom headers
188        for (key, value) in &self.endpoint.headers {
189            request = request.header(key, value);
190        }
191
192        // Add range header if specified
193        if let Some(ref range) = range {
194            request = request.header(header::RANGE, range.header_value());
195        }
196
197        let response = request.send().await?;
198        self.handle_response(response, range).await
199    }
200
201    /// Handle HTTP response
202    async fn handle_response(
203        &self,
204        response: Response,
205        range: Option<RangeRequest>,
206    ) -> Result<Bytes> {
207        let status = response.status();
208
209        match status {
210            StatusCode::OK | StatusCode::PARTIAL_CONTENT => {
211                // Validate content length for range requests
212                if let Some(ref range) = range {
213                    if let Some(len) = response.content_length() {
214                        if len != range.content_length() {
215                            warn!(
216                                "Content length mismatch: expected {}, got {}",
217                                range.content_length(),
218                                len
219                            );
220                        }
221                    }
222                }
223
224                response.bytes().await.map_err(|e| e.into())
225            }
226
227            StatusCode::NOT_FOUND => Err(NetworkError::NotFound("Fragment not found".into())),
228
229            StatusCode::TOO_MANY_REQUESTS => {
230                let retry_after = response
231                    .headers()
232                    .get(header::RETRY_AFTER)
233                    .and_then(|v| v.to_str().ok())
234                    .and_then(|v| v.parse::<u64>().ok())
235                    .unwrap_or(60)
236                    * 1000;
237
238                Err(NetworkError::RateLimited {
239                    retry_after_ms: retry_after,
240                })
241            }
242
243            _ => Err(NetworkError::Http {
244                status: status.as_u16(),
245                message: response.text().await.unwrap_or_default(),
246            }),
247        }
248    }
249
250    /// Get HEAD information (for cache validation)
251    pub async fn head(&self, path: &str) -> Result<HeadInfo> {
252        let _permit = self
253            .semaphore
254            .acquire()
255            .await
256            .map_err(|_| NetworkError::Cancelled)?;
257
258        let url = format!(
259            "{}/{}",
260            self.endpoint.url.trim_end_matches('/'),
261            path.trim_start_matches('/')
262        );
263
264        let mut request = self.client.head(&url);
265        for (key, value) in &self.endpoint.headers {
266            request = request.header(key, value);
267        }
268
269        let response = request.send().await?;
270
271        if !response.status().is_success() {
272            return Err(NetworkError::Http {
273                status: response.status().as_u16(),
274                message: "HEAD request failed".into(),
275            });
276        }
277
278        let headers = response.headers();
279
280        Ok(HeadInfo {
281            content_length: response.content_length(),
282            etag: headers
283                .get(header::ETAG)
284                .and_then(|v| v.to_str().ok())
285                .map(String::from),
286            last_modified: headers
287                .get(header::LAST_MODIFIED)
288                .and_then(|v| v.to_str().ok())
289                .map(String::from),
290            accepts_ranges: headers
291                .get(header::ACCEPT_RANGES)
292                .and_then(|v| v.to_str().ok())
293                .map(|v| v == "bytes")
294                .unwrap_or(false),
295        })
296    }
297}
298
/// Metadata gathered from a HEAD response, used for cache validation.
#[derive(Debug, Clone)]
pub struct HeadInfo {
    /// Content length reported for the resource, if known.
    pub content_length: Option<u64>,
    /// Raw `ETag` header value, if present.
    pub etag: Option<String>,
    /// Raw (unparsed) `Last-Modified` header value, if present.
    pub last_modified: Option<String>,
    /// `true` when the server advertises byte-range support via `Accept-Ranges`.
    pub accepts_ranges: bool,
}
311
// Random number helper (dependency-free; good enough for retry jitter, not for cryptography).
mod rand {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Return a pseudo-random value uniformly distributed in `[0.0, 1.0)`.
    ///
    /// Each call hashes the current time with a freshly keyed `RandomState` hasher, so
    /// back-to-back calls are uncorrelated even when the clock has not advanced. A clock set
    /// before the Unix epoch is tolerated (treated as 0) instead of panicking.
    pub fn random<T: From<f64>>() -> T {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);

        // Keep 53 bits so the value is exactly representable in an f64; dividing by 2^53 maps it
        // onto [0, 1). (The previous `subsec_nanos / u32::MAX` only reached ~0.233.)
        let bits = hasher.finish() >> 11;
        T::from(bits as f64 / (1u64 << 53) as f64)
    }
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// The first KiB of a resource maps to `bytes=0-1023` and spans 1024 bytes.
    #[test]
    fn test_range_header() {
        let first_kib = RangeRequest::new(0, 1023);
        let expected_len: u64 = 1024;

        assert_eq!(first_kib.header_value(), String::from("bytes=0-1023"));
        assert_eq!(first_kib.content_length(), expected_len);
    }
}
335}