Skip to main content

gosh_dl/http/
resume.rs

1//! Resume Detection and Validation
2//!
3//! This module handles detecting resume capability and validating
4//! that a partially downloaded file can be safely resumed.
5
6use crate::error::{EngineError, NetworkErrorKind, ProtocolErrorKind, Result};
7use reqwest::{Client, StatusCode};
8use std::path::Path;
9use tokio::fs;
10
11use super::ACCEPT_ENCODING_IDENTITY;
12
/// Validator values surrounding a ranged request, consumed by
/// [`validate_ranged_response`] when the server answers `200 OK`.
///
/// The `Default` value carries no validators and `sent_if_range == false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct RangedResponseContext<'a> {
    /// Set when the ranged request went out with `If-Range` validation.
    pub sent_if_range: bool,
    /// ETag the caller expects the remote resource to still have.
    pub expected_etag: Option<&'a str>,
    /// Last-Modified value the caller expects the remote resource to still have.
    pub expected_last_modified: Option<&'a str>,
    /// ETag reported on the response being validated.
    pub response_etag: Option<&'a str>,
    /// Last-Modified value reported on the response being validated.
    pub response_last_modified: Option<&'a str>,
}
21
/// Outcome of probing a URL for resume support, as produced by
/// [`check_resume`].
#[derive(Clone, Debug)]
pub struct ResumeInfo {
    /// `true` when the server advertised byte-range support.
    pub supports_range: bool,
    /// Current `ETag` header value, if the server sent a readable one.
    pub etag: Option<String>,
    /// Current `Last-Modified` header value, if the server sent a readable one.
    pub last_modified: Option<String>,
    /// Parsed `Content-Length` header value, if present and numeric.
    pub content_length: Option<u64>,
    /// `true` when the existing partial file may be continued safely.
    pub can_resume: bool,
    /// Length in bytes of the partial file on disk (0 when there is none).
    pub existing_size: u64,
}
38
39/// Check if a download can be resumed
40pub async fn check_resume(
41    client: &Client,
42    url: &str,
43    user_agent: &str,
44    part_path: &Path,
45    saved_etag: Option<&str>,
46    saved_last_modified: Option<&str>,
47) -> Result<ResumeInfo> {
48    // Check if partial file exists
49    let existing_size = if part_path.exists() {
50        fs::metadata(part_path).await.map(|m| m.len()).unwrap_or(0)
51    } else {
52        0
53    };
54
55    // Send HEAD request to check server capabilities
56    let response = client
57        .head(url)
58        .header("User-Agent", user_agent)
59        .header("Accept-Encoding", ACCEPT_ENCODING_IDENTITY)
60        .send()
61        .await
62        .map_err(|e| {
63            EngineError::protocol(
64                ProtocolErrorKind::InvalidResponse,
65                format!("HEAD request failed: {}", e),
66            )
67        })?;
68
69    if !response.status().is_success() {
70        return Err(EngineError::protocol(
71            ProtocolErrorKind::InvalidResponse,
72            format!("HEAD request returned: {}", response.status()),
73        ));
74    }
75
76    let headers = response.headers();
77
78    // Check Accept-Ranges header
79    let supports_range = headers
80        .get("accept-ranges")
81        .and_then(|v| v.to_str().ok())
82        .map(|v| v.contains("bytes"))
83        .unwrap_or(false);
84
85    // Get ETag
86    let etag = headers
87        .get("etag")
88        .and_then(|v| v.to_str().ok())
89        .map(|s| s.to_string());
90
91    // Get Last-Modified
92    let last_modified = headers
93        .get("last-modified")
94        .and_then(|v| v.to_str().ok())
95        .map(|s| s.to_string());
96
97    // Get Content-Length
98    let content_length = headers
99        .get("content-length")
100        .and_then(|v| v.to_str().ok())
101        .and_then(|s| s.parse::<u64>().ok());
102
103    // Determine if we can resume
104    let can_resume = if existing_size == 0 {
105        // No partial file, nothing to resume
106        false
107    } else if !supports_range {
108        // Server doesn't support ranges
109        false
110    } else {
111        // Validate ETag or Last-Modified if we have saved values
112        let etag_valid = match (saved_etag, &etag) {
113            (Some(saved), Some(current)) => saved == current,
114            (Some(_), None) => false, // Had ETag, now missing
115            (None, _) => true,        // Didn't have ETag, can't validate
116        };
117
118        let last_modified_valid = match (saved_last_modified, &last_modified) {
119            (Some(saved), Some(current)) => saved == current,
120            (Some(_), None) => false,
121            (None, _) => true,
122        };
123
124        // Must pass both validations
125        etag_valid && last_modified_valid
126    };
127
128    Ok(ResumeInfo {
129        supports_range,
130        etag,
131        last_modified,
132        content_length,
133        can_resume,
134        existing_size,
135    })
136}
137
138/// Verify that a Range request returns the expected response
139pub async fn verify_range_support(client: &Client, url: &str, user_agent: &str) -> Result<bool> {
140    // Request just the first byte
141    let response = client
142        .get(url)
143        .header("User-Agent", user_agent)
144        .header("Accept-Encoding", ACCEPT_ENCODING_IDENTITY)
145        .header("Range", "bytes=0-0")
146        .send()
147        .await
148        .map_err(|e| {
149            EngineError::protocol(
150                ProtocolErrorKind::InvalidResponse,
151                format!("Range request failed: {}", e),
152            )
153        })?;
154
155    // Should get 206 Partial Content
156    Ok(response.status() == reqwest::StatusCode::PARTIAL_CONTENT)
157}
158
159/// Validate that a response to a ranged request honors the requested byte span.
160pub fn validate_ranged_response(
161    expected_start: u64,
162    expected_end: Option<u64>,
163    status: StatusCode,
164    content_range: Option<&str>,
165    context: RangedResponseContext<'_>,
166) -> Result<()> {
167    let restart_required = |message: String| {
168        EngineError::protocol(
169            ProtocolErrorKind::RangeNotSupported,
170            format!("{message}. Restart from byte 0 required"),
171        )
172    };
173
174    if status != StatusCode::PARTIAL_CONTENT {
175        if status == StatusCode::OK {
176            if let (Some(expected), Some(actual)) = (context.expected_etag, context.response_etag) {
177                if expected != actual {
178                    return Err(restart_required(format!(
179                        "Server returned 200 OK to a ranged request after ETag changed from {} to {}",
180                        expected, actual
181                    )));
182                }
183            }
184
185            if let (Some(expected), Some(actual)) = (
186                context.expected_last_modified,
187                context.response_last_modified,
188            ) {
189                if expected != actual {
190                    return Err(restart_required(format!(
191                        "Server returned 200 OK to a ranged request after Last-Modified changed from {} to {}",
192                        expected, actual
193                    )));
194                }
195            }
196
197            if context.sent_if_range {
198                return Err(restart_required(
199                    "Server returned 200 OK to a ranged request after If-Range validation; the remote file may have changed or the server ignored Range".to_string(),
200                ));
201            }
202        }
203
204        return Err(EngineError::protocol(
205            ProtocolErrorKind::RangeNotSupported,
206            format!(
207                "Server ignored Range request starting at byte {} and returned {}. Restart from byte 0 required",
208                expected_start, status
209            ),
210        ));
211    }
212
213    let content_range = content_range.ok_or_else(|| {
214        restart_required("Missing Content-Range header on ranged response".to_string())
215    })?;
216
217    if let Err(err) = validate_resumed_position(expected_start, content_range) {
218        return Err(restart_required(format!(
219            "Server returned mismatched Content-Range for ranged request starting at byte {}: {}",
220            expected_start, err
221        )));
222    }
223
224    if let Some(expected_end) = expected_end {
225        let (_, actual_end, _) = parse_content_range(content_range).ok_or_else(|| {
226            restart_required(format!("Invalid Content-Range header: {}", content_range))
227        })?;
228
229        if actual_end != expected_end {
230            return Err(EngineError::protocol(
231                ProtocolErrorKind::RangeNotSupported,
232                format!(
233                    "Range end mismatch: expected {}, got {}. Restart from byte 0 required",
234                    expected_end, actual_end
235                ),
236            ));
237        }
238    }
239
240    Ok(())
241}
242
243pub fn should_restart_without_ranges(err: &EngineError) -> bool {
244    match err {
245        EngineError::Protocol {
246            kind: ProtocolErrorKind::RangeNotSupported,
247            ..
248        } => true,
249        EngineError::Network {
250            kind: NetworkErrorKind::HttpStatus(416),
251            ..
252        } => true,
253        _ => false,
254    }
255}
256
/// Build the value of a `Range` request header.
///
/// With `end` set the result is the bounded form `bytes=start-end`;
/// without it the result is the open-ended form `bytes=start-`.
pub fn calculate_range_header(start: u64, end: Option<u64>) -> String {
    if let Some(last) = end {
        return format!("bytes={}-{}", start, last);
    }
    format!("bytes={}-", start)
}
264
/// Parse a `Content-Range` header into `(start, end, total)`.
///
/// Accepted formats are `bytes start-end/total` and `bytes start-end/*`
/// (total length unknown, reported as `None`). The unit name is matched
/// case-insensitively.
///
/// Returns `None` when the header is malformed or self-contradictory:
///
/// * the unit is not `bytes`, or a separator is missing,
/// * a position is not a plain run of ASCII digits (so `+5` is rejected) or
///   does not fit in a `u64`,
/// * the total is neither `*` nor a valid number,
/// * `start > end`, or a known total is not greater than `end`.
pub fn parse_content_range(header: &str) -> Option<(u64, u64, Option<u64>)> {
    // Strict decimal parse: `str::parse::<u64>` alone would accept a leading
    // `+`, which the header grammar does not allow.
    fn parse_position(text: &str) -> Option<u64> {
        if text.is_empty() || !text.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        text.parse().ok()
    }

    let (unit, spec) = header.split_once(' ')?;
    if !unit.eq_ignore_ascii_case("bytes") {
        return None;
    }

    let (range, total) = spec.split_once('/')?;
    let (start, end) = range.split_once('-')?;

    let start = parse_position(start)?;
    let end = parse_position(end)?;
    if start > end {
        return None;
    }

    let total = match total {
        "*" => None,
        length => {
            // An unparsable total is a malformed header, not an unknown length.
            let length = parse_position(length)?;
            // Byte positions are zero-based, so the last position must be
            // strictly below the complete length.
            if end >= length {
                return None;
            }
            Some(length)
        }
    };

    Some((start, end, total))
}
290
291/// Validate that a resumed download starts at the expected position
292pub fn validate_resumed_position(expected_start: u64, content_range: &str) -> Result<()> {
293    let (actual_start, _, _) = parse_content_range(content_range).ok_or_else(|| {
294        EngineError::protocol(
295            ProtocolErrorKind::InvalidResponse,
296            format!("Invalid Content-Range header: {}", content_range),
297        )
298    })?;
299
300    if actual_start != expected_start {
301        return Err(EngineError::protocol(
302            ProtocolErrorKind::InvalidResponse,
303            format!(
304                "Resume position mismatch: expected {}, got {}",
305                expected_start, actual_start
306            ),
307        ));
308    }
309
310    Ok(())
311}
312
313/// Determine if a partial file should be deleted and restarted
314pub async fn should_restart(
315    part_path: &Path,
316    expected_size: Option<u64>,
317    saved_etag: Option<&str>,
318    current_etag: Option<&str>,
319) -> bool {
320    // If file doesn't exist, no need to restart
321    if !part_path.exists() {
322        return false;
323    }
324
325    // If ETag changed, must restart
326    if let (Some(saved), Some(current)) = (saved_etag, current_etag) {
327        if saved != current {
328            return true;
329        }
330    }
331
332    // If we have expected size and partial is larger, restart
333    if let Some(expected) = expected_size {
334        if let Ok(metadata) = fs::metadata(part_path).await {
335            if metadata.len() > expected {
336                return true;
337            }
338        }
339    }
340
341    false
342}
343
344/// Clean up a partial file that can't be resumed
345pub async fn cleanup_partial(part_path: &Path) -> Result<()> {
346    if part_path.exists() {
347        fs::remove_file(part_path)
348            .await
349            .map_err(|e| EngineError::Internal(format!("Failed to remove partial file: {}", e)))?;
350    }
351    Ok(())
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Open-ended and bounded spans render in `Range` header syntax.
    #[test]
    fn test_calculate_range_header() {
        let cases = [
            (0, None, "bytes=0-"),
            (100, None, "bytes=100-"),
            (0, Some(99), "bytes=0-99"),
            (1000, Some(1999), "bytes=1000-1999"),
        ];

        for (start, end, expected) in cases {
            assert_eq!(calculate_range_header(start, end), expected);
        }
    }

    /// Well-formed headers parse; headers without a byte range do not.
    #[test]
    fn test_parse_content_range() {
        let accepted = [
            ("bytes 0-99/100", (0, 99, Some(100))),
            ("bytes 100-199/1000", (100, 199, Some(1000))),
            ("bytes 0-99/*", (0, 99, None)),
        ];
        for (header, expected) in accepted {
            assert_eq!(parse_content_range(header), Some(expected));
        }

        for header in ["invalid", "bytes invalid"] {
            assert_eq!(parse_content_range(header), None);
        }
    }

    /// The start offset must match, and the header must be parsable.
    #[test]
    fn test_validate_resumed_position() {
        // Valid cases
        for (start, header) in [(0, "bytes 0-99/100"), (100, "bytes 100-199/1000")] {
            assert!(validate_resumed_position(start, header).is_ok());
        }

        // Invalid cases
        for (start, header) in [(50, "bytes 0-99/100"), (0, "invalid header")] {
            assert!(validate_resumed_position(start, header).is_err());
        }
    }

    /// Covers the accept path and each rejection path of the range validator.
    #[test]
    fn test_validate_ranged_response() {
        // Every case asks for a range starting at byte 100 with no validators.
        let check = |end: Option<u64>, status: StatusCode, header: Option<&str>| {
            validate_ranged_response(100, end, status, header, RangedResponseContext::default())
        };

        assert!(check(
            Some(199),
            StatusCode::PARTIAL_CONTENT,
            Some("bytes 100-199/1000")
        )
        .is_ok());

        assert!(
            check(None, StatusCode::OK, None).is_err(),
            "200 OK must be rejected for ranged requests"
        );

        assert!(
            check(Some(200), StatusCode::PARTIAL_CONTENT, None).is_err(),
            "Missing Content-Range must be rejected"
        );

        assert!(
            check(
                Some(200),
                StatusCode::PARTIAL_CONTENT,
                Some("bytes 100-199/1000")
            )
            .is_err(),
            "Mismatched end offset must be rejected"
        );

        let changed_validator = RangedResponseContext {
            sent_if_range: true,
            expected_etag: Some("\"old\""),
            response_etag: Some("\"new\""),
            ..RangedResponseContext::default()
        };
        let err = validate_ranged_response(100, None, StatusCode::OK, None, changed_validator)
            .expect_err("Changed validator must trigger restart classification");
        assert!(
            matches!(
                err,
                EngineError::Protocol {
                    kind: ProtocolErrorKind::RangeNotSupported,
                    ..
                }
            ),
            "Expected restart-worthy RangeNotSupported error, got {err:?}"
        );
    }
}
466}