1use crate::error::{EngineError, NetworkErrorKind, ProtocolErrorKind, Result};
7use reqwest::{Client, StatusCode};
8use std::path::Path;
9use tokio::fs;
10
11use super::ACCEPT_ENCODING_IDENTITY;
12
/// Validator information about a ranged request and the response it received.
///
/// Consumed by [`validate_ranged_response`] to explain, in the returned error,
/// why a `200 OK` reply to a ranged request requires restarting from byte 0.
#[derive(Debug, Clone, Copy, Default)]
pub struct RangedResponseContext<'a> {
    /// Set by the caller when the request carried an `If-Range` header.
    pub sent_if_range: bool,
    /// `ETag` the client expected the remote resource to still have.
    pub expected_etag: Option<&'a str>,
    /// `Last-Modified` value the client expected the remote resource to still have.
    pub expected_last_modified: Option<&'a str>,
    /// `ETag` reported by the server in this response, if any.
    pub response_etag: Option<&'a str>,
    /// `Last-Modified` reported by the server in this response, if any.
    pub response_last_modified: Option<&'a str>,
}
21
/// Result of [`check_resume`]: what the server reported about the resource and
/// whether an existing partial download may be continued.
#[derive(Debug, Clone)]
pub struct ResumeInfo {
    /// Whether the server's `Accept-Ranges` header advertised the `bytes` unit.
    pub supports_range: bool,
    /// `ETag` header from the server, if present and representable as a string.
    pub etag: Option<String>,
    /// `Last-Modified` header from the server, if present and representable as a string.
    pub last_modified: Option<String>,
    /// `Content-Length` header parsed as a byte count, if present and numeric.
    pub content_length: Option<u64>,
    /// Whether the partial file may be resumed rather than downloaded from scratch.
    pub can_resume: bool,
    /// Size in bytes of the existing partial file (0 when there is none).
    pub existing_size: u64,
}
38
39pub async fn check_resume(
41 client: &Client,
42 url: &str,
43 user_agent: &str,
44 part_path: &Path,
45 saved_etag: Option<&str>,
46 saved_last_modified: Option<&str>,
47) -> Result<ResumeInfo> {
48 let existing_size = if part_path.exists() {
50 fs::metadata(part_path).await.map(|m| m.len()).unwrap_or(0)
51 } else {
52 0
53 };
54
55 let response = client
57 .head(url)
58 .header("User-Agent", user_agent)
59 .header("Accept-Encoding", ACCEPT_ENCODING_IDENTITY)
60 .send()
61 .await
62 .map_err(|e| {
63 EngineError::protocol(
64 ProtocolErrorKind::InvalidResponse,
65 format!("HEAD request failed: {}", e),
66 )
67 })?;
68
69 if !response.status().is_success() {
70 return Err(EngineError::protocol(
71 ProtocolErrorKind::InvalidResponse,
72 format!("HEAD request returned: {}", response.status()),
73 ));
74 }
75
76 let headers = response.headers();
77
78 let supports_range = headers
80 .get("accept-ranges")
81 .and_then(|v| v.to_str().ok())
82 .map(|v| v.contains("bytes"))
83 .unwrap_or(false);
84
85 let etag = headers
87 .get("etag")
88 .and_then(|v| v.to_str().ok())
89 .map(|s| s.to_string());
90
91 let last_modified = headers
93 .get("last-modified")
94 .and_then(|v| v.to_str().ok())
95 .map(|s| s.to_string());
96
97 let content_length = headers
99 .get("content-length")
100 .and_then(|v| v.to_str().ok())
101 .and_then(|s| s.parse::<u64>().ok());
102
103 let can_resume = if existing_size == 0 {
105 false
107 } else if !supports_range {
108 false
110 } else {
111 let etag_valid = match (saved_etag, &etag) {
113 (Some(saved), Some(current)) => saved == current,
114 (Some(_), None) => false, (None, _) => true, };
117
118 let last_modified_valid = match (saved_last_modified, &last_modified) {
119 (Some(saved), Some(current)) => saved == current,
120 (Some(_), None) => false,
121 (None, _) => true,
122 };
123
124 etag_valid && last_modified_valid
126 };
127
128 Ok(ResumeInfo {
129 supports_range,
130 etag,
131 last_modified,
132 content_length,
133 can_resume,
134 existing_size,
135 })
136}
137
138pub async fn verify_range_support(client: &Client, url: &str, user_agent: &str) -> Result<bool> {
140 let response = client
142 .get(url)
143 .header("User-Agent", user_agent)
144 .header("Accept-Encoding", ACCEPT_ENCODING_IDENTITY)
145 .header("Range", "bytes=0-0")
146 .send()
147 .await
148 .map_err(|e| {
149 EngineError::protocol(
150 ProtocolErrorKind::InvalidResponse,
151 format!("Range request failed: {}", e),
152 )
153 })?;
154
155 Ok(response.status() == reqwest::StatusCode::PARTIAL_CONTENT)
157}
158
159pub fn validate_ranged_response(
161 expected_start: u64,
162 expected_end: Option<u64>,
163 status: StatusCode,
164 content_range: Option<&str>,
165 context: RangedResponseContext<'_>,
166) -> Result<()> {
167 let restart_required = |message: String| {
168 EngineError::protocol(
169 ProtocolErrorKind::RangeNotSupported,
170 format!("{message}. Restart from byte 0 required"),
171 )
172 };
173
174 if status != StatusCode::PARTIAL_CONTENT {
175 if status == StatusCode::OK {
176 if let (Some(expected), Some(actual)) = (context.expected_etag, context.response_etag) {
177 if expected != actual {
178 return Err(restart_required(format!(
179 "Server returned 200 OK to a ranged request after ETag changed from {} to {}",
180 expected, actual
181 )));
182 }
183 }
184
185 if let (Some(expected), Some(actual)) = (
186 context.expected_last_modified,
187 context.response_last_modified,
188 ) {
189 if expected != actual {
190 return Err(restart_required(format!(
191 "Server returned 200 OK to a ranged request after Last-Modified changed from {} to {}",
192 expected, actual
193 )));
194 }
195 }
196
197 if context.sent_if_range {
198 return Err(restart_required(
199 "Server returned 200 OK to a ranged request after If-Range validation; the remote file may have changed or the server ignored Range".to_string(),
200 ));
201 }
202 }
203
204 return Err(EngineError::protocol(
205 ProtocolErrorKind::RangeNotSupported,
206 format!(
207 "Server ignored Range request starting at byte {} and returned {}. Restart from byte 0 required",
208 expected_start, status
209 ),
210 ));
211 }
212
213 let content_range = content_range.ok_or_else(|| {
214 restart_required("Missing Content-Range header on ranged response".to_string())
215 })?;
216
217 if let Err(err) = validate_resumed_position(expected_start, content_range) {
218 return Err(restart_required(format!(
219 "Server returned mismatched Content-Range for ranged request starting at byte {}: {}",
220 expected_start, err
221 )));
222 }
223
224 if let Some(expected_end) = expected_end {
225 let (_, actual_end, _) = parse_content_range(content_range).ok_or_else(|| {
226 restart_required(format!("Invalid Content-Range header: {}", content_range))
227 })?;
228
229 if actual_end != expected_end {
230 return Err(EngineError::protocol(
231 ProtocolErrorKind::RangeNotSupported,
232 format!(
233 "Range end mismatch: expected {}, got {}. Restart from byte 0 required",
234 expected_end, actual_end
235 ),
236 ));
237 }
238 }
239
240 Ok(())
241}
242
243pub fn should_restart_without_ranges(err: &EngineError) -> bool {
244 match err {
245 EngineError::Protocol {
246 kind: ProtocolErrorKind::RangeNotSupported,
247 ..
248 } => true,
249 EngineError::Network {
250 kind: NetworkErrorKind::HttpStatus(416),
251 ..
252 } => true,
253 _ => false,
254 }
255}
256
/// Builds the value of an HTTP `Range` request header starting at byte `start`.
///
/// `end` is the inclusive last byte. `None` yields an open-ended range
/// (`bytes=<start>-`) that runs to the end of the resource.
pub fn calculate_range_header(start: u64, end: Option<u64>) -> String {
    let last = end.map(|byte| byte.to_string()).unwrap_or_default();
    format!("bytes={start}-{last}")
}
264
/// Parses a `Content-Range` header of the form `bytes <start>-<end>/<total>`.
///
/// Returns `(start, end, total)`, where `end` is inclusive and `total` is
/// `None` when the complete length is reported as `*` (unknown).
///
/// Returns `None` for anything that is not a valid satisfied byte range:
/// - a unit other than `bytes` (compared case-insensitively),
/// - the unsatisfied form `bytes */<total>`,
/// - positions that are not plain decimal digits (e.g. a leading `+`),
/// - a complete length that is neither `*` nor a number,
/// - `end < start`, or a known `total` that is not greater than `end`; RFC 9110
///   declares both of these invalid.
pub fn parse_content_range(header: &str) -> Option<(u64, u64, Option<u64>)> {
    let (unit, spec) = header.trim().split_once(' ')?;
    if !unit.eq_ignore_ascii_case("bytes") {
        return None;
    }

    let (range, total) = spec.trim_start().split_once('/')?;
    let (start, end) = range.split_once('-')?;
    let start = parse_range_digits(start)?;
    let end = parse_range_digits(end)?;
    let total = match total {
        "*" => None,
        // A malformed complete length makes the whole header invalid rather
        // than being silently treated as "unknown".
        digits => Some(parse_range_digits(digits)?),
    };

    if end < start || matches!(total, Some(len) if len <= end) {
        return None;
    }

    Some((start, end, total))
}

/// Parses a non-empty run of ASCII digits as `u64`. Unlike `str::parse`, this
/// rejects a leading `+`.
fn parse_range_digits(digits: &str) -> Option<u64> {
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
290
291pub fn validate_resumed_position(expected_start: u64, content_range: &str) -> Result<()> {
293 let (actual_start, _, _) = parse_content_range(content_range).ok_or_else(|| {
294 EngineError::protocol(
295 ProtocolErrorKind::InvalidResponse,
296 format!("Invalid Content-Range header: {}", content_range),
297 )
298 })?;
299
300 if actual_start != expected_start {
301 return Err(EngineError::protocol(
302 ProtocolErrorKind::InvalidResponse,
303 format!(
304 "Resume position mismatch: expected {}, got {}",
305 expected_start, actual_start
306 ),
307 ));
308 }
309
310 Ok(())
311}
312
313pub async fn should_restart(
315 part_path: &Path,
316 expected_size: Option<u64>,
317 saved_etag: Option<&str>,
318 current_etag: Option<&str>,
319) -> bool {
320 if !part_path.exists() {
322 return false;
323 }
324
325 if let (Some(saved), Some(current)) = (saved_etag, current_etag) {
327 if saved != current {
328 return true;
329 }
330 }
331
332 if let Some(expected) = expected_size {
334 if let Ok(metadata) = fs::metadata(part_path).await {
335 if metadata.len() > expected {
336 return true;
337 }
338 }
339 }
340
341 false
342}
343
344pub async fn cleanup_partial(part_path: &Path) -> Result<()> {
346 if part_path.exists() {
347 fs::remove_file(part_path)
348 .await
349 .map_err(|e| EngineError::Internal(format!("Failed to remove partial file: {}", e)))?;
350 }
351 Ok(())
352}
353
// Unit tests for the synchronous helpers in this module: range-header
// construction, Content-Range parsing, and ranged-response validation. The
// async/network functions are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_range_header() {
        // Open-ended ranges omit the end offset.
        assert_eq!(calculate_range_header(0, None), "bytes=0-");
        assert_eq!(calculate_range_header(100, None), "bytes=100-");
        // Bounded ranges carry an inclusive end offset.
        assert_eq!(calculate_range_header(0, Some(99)), "bytes=0-99");
        assert_eq!(calculate_range_header(1000, Some(1999)), "bytes=1000-1999");
    }

    #[test]
    fn test_parse_content_range() {
        assert_eq!(
            parse_content_range("bytes 0-99/100"),
            Some((0, 99, Some(100)))
        );

        assert_eq!(
            parse_content_range("bytes 100-199/1000"),
            Some((100, 199, Some(1000)))
        );

        // `*` as the complete length is reported as an unknown total.
        assert_eq!(parse_content_range("bytes 0-99/*"), Some((0, 99, None)));

        // Headers without a parseable byte range are rejected.
        assert_eq!(parse_content_range("invalid"), None);
        assert_eq!(parse_content_range("bytes invalid"), None);
    }

    #[test]
    fn test_validate_resumed_position() {
        // The range starts exactly where the resume was requested.
        assert!(validate_resumed_position(0, "bytes 0-99/100").is_ok());
        assert!(validate_resumed_position(100, "bytes 100-199/1000").is_ok());

        // The range starts somewhere other than the requested position.
        assert!(validate_resumed_position(50, "bytes 0-99/100").is_err());
        // The header cannot be parsed at all.
        assert!(validate_resumed_position(0, "invalid header").is_err());
    }

    #[test]
    fn test_validate_ranged_response() {
        // A 206 whose Content-Range matches both the requested start and end is accepted.
        assert!(validate_ranged_response(
            100,
            Some(199),
            StatusCode::PARTIAL_CONTENT,
            Some("bytes 100-199/1000"),
            RangedResponseContext::default(),
        )
        .is_ok());

        assert!(
            validate_ranged_response(
                100,
                None,
                StatusCode::OK,
                None,
                RangedResponseContext::default(),
            )
            .is_err(),
            "200 OK must be rejected for ranged requests"
        );

        assert!(
            validate_ranged_response(
                100,
                Some(200),
                StatusCode::PARTIAL_CONTENT,
                None,
                RangedResponseContext::default(),
            )
            .is_err(),
            "Missing Content-Range must be rejected"
        );

        assert!(
            validate_ranged_response(
                100,
                Some(200),
                StatusCode::PARTIAL_CONTENT,
                Some("bytes 100-199/1000"),
                RangedResponseContext::default(),
            )
            .is_err(),
            "Mismatched end offset must be rejected"
        );

        // A 200 OK after the ETag changed must carry the RangeNotSupported kind,
        // which `should_restart_without_ranges` treats as a restart signal.
        let err = validate_ranged_response(
            100,
            None,
            StatusCode::OK,
            None,
            RangedResponseContext {
                sent_if_range: true,
                expected_etag: Some("\"old\""),
                response_etag: Some("\"new\""),
                ..RangedResponseContext::default()
            },
        )
        .expect_err("Changed validator must trigger restart classification");
        assert!(
            matches!(
                err,
                EngineError::Protocol {
                    kind: ProtocolErrorKind::RangeNotSupported,
                    ..
                }
            ),
            "Expected restart-worthy RangeNotSupported error, got {err:?}"
        );
    }
}
466}