1mod error;
20mod sparse_range;
21
22use futures::{FutureExt, Stream, StreamExt};
23use http_content_range::{ContentRange, ContentRangeBytes};
24use memmap2::MmapMut;
25use reqwest::header::HeaderMap;
26use reqwest::{Response, Url};
27use sparse_range::SparseRange;
28use std::{
29 io::{self, SeekFrom},
30 ops::Range,
31 pin::Pin,
32 sync::Arc,
33 task::{ready, Context, Poll},
34};
35use tokio::{
36 io::{AsyncRead, AsyncSeek, ReadBuf},
37 sync::watch::Sender,
38 sync::{watch, Mutex},
39};
40use tokio_stream::wrappers::WatchStream;
41use tokio_util::sync::PollSender;
42use tracing::{info_span, Instrument};
43
44pub use error::AsyncHttpRangeReaderError;
45
/// A reader over a remote file whose contents are fetched on demand with HTTP
/// range requests.
///
/// The type implements [`AsyncRead`] and [`AsyncSeek`]. Bytes are downloaded
/// by a background task (`run_streamer`) into an anonymous memory map; reads
/// are served from that map once the requested range is resident.
#[derive(Debug)]
pub struct AsyncHttpRangeReader {
    /// Mutable reader state: cursor position, requested ranges and the
    /// channels used to talk to the streamer task.
    inner: Mutex<Inner>,
    /// Length in bytes of the backing buffer, as returned by `len()`.
    len: u64,
}
85
/// Snapshot of the streamer task's progress, published to the reader through a
/// `watch` channel.
#[derive(Default, Clone, Debug)]
struct StreamerState {
    /// Byte ranges that have been copied into the memory map so far.
    resident_range: SparseRange,
    /// Every range the streamer has issued (or was handed) a request for, in
    /// the order the requests were made.
    requested_ranges: Vec<Range<u64>>,
    /// Set when streaming failed; the reader surfaces it from `poll_read`.
    error: Option<AsyncHttpRangeReaderError>,
}
92
/// Reader-side state guarded by the mutex in [`AsyncHttpRangeReader`].
#[derive(Debug)]
struct Inner {
    /// Read-only view of the anonymous memory map the streamer task writes to.
    ///
    /// NOTE(review): the `'static` lifetime is not enforced by the compiler.
    /// The slice is built with `slice::from_raw_parts` from a map that is then
    /// moved into the spawned `run_streamer` task, so it is only valid while
    /// that task still owns the map. `run_streamer` returns (dropping the map)
    /// after reporting an error — confirm that no read through `data` can
    /// happen after that point.
    data: &'static [u8],

    /// Current read cursor, as set by `start_seek` and advanced by `poll_read`.
    pos: u64,

    /// Ranges this reader has already asked the streamer to fetch.
    requested_range: SparseRange,

    /// Most recent streamer state received from `streamer_state_rx`.
    streamer_state: StreamerState,

    /// Stream of state updates published by the streamer task.
    streamer_state_rx: WatchStream<StreamerState>,

    /// Channel used to send range requests to the streamer task.
    request_tx: tokio::sync::mpsc::Sender<Range<u64>>,

    /// Poll-based sender kept between `poll_read` calls while a slot in
    /// `request_tx` is being reserved.
    poll_request_tx: Option<PollSender<Range<u64>>>,
}
121
/// Strategy used by [`AsyncHttpRangeReader::new`] to find out whether the
/// server supports range requests and how large the remote file is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckSupportMethod {
    /// Send a suffix range request (`Range: bytes=-N`) for the last `N` bytes.
    /// The `Content-Range` header of the response provides the total length,
    /// and the returned bytes are kept as the first resident range.
    NegativeRangeRequest(u64),

    /// Send a `HEAD` request and rely on the `Accept-Ranges` and
    /// `Content-Length` response headers. No body bytes are fetched up front.
    Head,
}
135
136fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
137 response
138 .error_for_status()
139 .map_err(reqwest_middleware::Error::Reqwest)
140}
141
142impl AsyncHttpRangeReader {
143 pub async fn new(
145 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
146 url: Url,
147 check_method: CheckSupportMethod,
148 extra_headers: HeaderMap,
149 ) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
150 let client = client.into();
151 match check_method {
152 CheckSupportMethod::NegativeRangeRequest(initial_chunk_size) => {
153 let response = Self::initial_tail_request(
154 client.clone(),
155 url.clone(),
156 initial_chunk_size,
157 HeaderMap::default(),
158 )
159 .await?;
160 let response_headers = response.headers().clone();
161 let self_ = Self::from_range_response(client, response, url, extra_headers).await?;
162 Ok((self_, response_headers))
163 }
164 CheckSupportMethod::Head => {
165 let response =
166 Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
167 .await?;
168 let response_headers = response.headers().clone();
169 let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
170 Ok((self_, response_headers))
171 }
172 }
173 }
174
175 pub async fn initial_tail_request(
179 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
180 url: reqwest::Url,
181 initial_chunk_size: u64,
182 extra_headers: HeaderMap,
183 ) -> Result<Response, AsyncHttpRangeReaderError> {
184 let client = client.into();
185 let tail_response = client
186 .get(url)
187 .header(
188 reqwest::header::RANGE,
189 format!("bytes=-{initial_chunk_size}"),
190 )
191 .headers(extra_headers)
192 .send()
193 .await
194 .and_then(error_for_status)
195 .map_err(Arc::new)
196 .map_err(AsyncHttpRangeReaderError::HttpError)?;
197 Ok(tail_response)
198 }
199
200 #[deprecated(note = "use `from_range_response` instead")]
201 pub async fn from_tail_response(
202 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
203 tail_request_response: Response,
204 url: Url,
205 extra_headers: HeaderMap,
206 ) -> Result<Self, AsyncHttpRangeReaderError> {
207 Self::from_range_response(client, tail_request_response, url, extra_headers).await
208 }
209
210 pub async fn from_range_response(
213 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
214 response: Response,
215 url: Url,
216 extra_headers: HeaderMap,
217 ) -> Result<Self, AsyncHttpRangeReaderError> {
218 let client = client.into();
219
220 let content_range_header = response
222 .headers()
223 .get(reqwest::header::CONTENT_RANGE)
224 .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
225 .to_str()
226 .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
227 let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
229 AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
230 })?;
231 let (start, end_inclusive, complete_length) = match content_range {
232 ContentRange::Bytes(ContentRangeBytes {
233 first_byte,
234 last_byte,
235 complete_length,
236 }) => (first_byte, last_byte, complete_length),
237 _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
238 };
239
240 let memory_map = memmap2::MmapOptions::new()
242 .len(complete_length as usize)
243 .map_anon()
244 .map_err(Arc::new)
245 .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
246
247 let memory_map_slice =
250 unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
251
252 let requested_range = SparseRange::from_range(start..end_inclusive + 1);
253
254 let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
259 let (state_tx, state_rx) = watch::channel(StreamerState::default());
260 tokio::spawn(run_streamer(
261 client,
262 url,
263 extra_headers,
264 Some((response, start, end_inclusive + 1)),
265 memory_map,
266 state_tx,
267 request_rx,
268 ));
269
270 let mut streamer_state = StreamerState::default();
272 streamer_state
273 .requested_ranges
274 .push(start..end_inclusive + 1);
275
276 let reader = Self {
277 len: memory_map_slice.len() as u64,
278 inner: Mutex::new(Inner {
279 data: memory_map_slice,
280 pos: 0,
281 requested_range,
282 streamer_state,
283 streamer_state_rx: WatchStream::new(state_rx),
284 request_tx,
285 poll_request_tx: None,
286 }),
287 };
288 Ok(reader)
289 }
290
291 pub async fn initial_head_request(
294 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
295 url: reqwest::Url,
296 extra_headers: HeaderMap,
297 ) -> Result<Response, AsyncHttpRangeReaderError> {
298 let client = client.into();
299
300 let head_response = client
302 .head(url.clone())
303 .headers(extra_headers)
304 .send()
305 .await
306 .and_then(error_for_status)
307 .map_err(Arc::new)
308 .map_err(AsyncHttpRangeReaderError::HttpError)?;
309 Ok(head_response)
310 }
311
312 pub async fn from_head_response(
315 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
316 head_response: Response,
317 url: Url,
318 extra_headers: HeaderMap,
319 ) -> Result<Self, AsyncHttpRangeReaderError> {
320 let client = client.into();
321
322 if head_response
324 .headers()
325 .get(reqwest::header::ACCEPT_RANGES)
326 .and_then(|h| h.to_str().ok())
327 != Some("bytes")
328 {
329 return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
330 }
331
332 let content_length: u64 = head_response
333 .headers()
334 .get(reqwest::header::CONTENT_LENGTH)
335 .ok_or(AsyncHttpRangeReaderError::ContentLengthMissing)?
336 .to_str()
337 .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?
338 .parse()
339 .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?;
340
341 let memory_map = memmap2::MmapOptions::new()
343 .len(content_length as _)
344 .map_anon()
345 .map_err(Arc::new)
346 .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
347
348 let memory_map_slice =
351 unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
352
353 let requested_range = SparseRange::default();
354
355 let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
360 let (state_tx, state_rx) = watch::channel(StreamerState::default());
361 tokio::spawn(run_streamer(
362 client,
363 url,
364 extra_headers,
365 None,
366 memory_map,
367 state_tx,
368 request_rx,
369 ));
370
371 let streamer_state = StreamerState::default();
373
374 let reader = Self {
375 len: memory_map_slice.len() as u64,
376 inner: Mutex::new(Inner {
377 data: memory_map_slice,
378 pos: 0,
379 requested_range,
380 streamer_state,
381 streamer_state_rx: WatchStream::new(state_rx),
382 request_tx,
383 poll_request_tx: None,
384 }),
385 };
386 Ok(reader)
387 }
388
389 pub async fn requested_ranges(&self) -> Vec<Range<u64>> {
391 let mut inner = self.inner.lock().await;
392 if let Some(Some(new_state)) = inner.streamer_state_rx.next().now_or_never() {
393 inner.streamer_state = new_state;
394 }
395 inner.streamer_state.requested_ranges.clone()
396 }
397
398 pub async fn prefetch(&mut self, bytes: Range<u64>) {
401 let inner = self.inner.get_mut();
402
403 let range = bytes.start..(bytes.end.min(inner.data.len() as u64));
405 if range.start >= range.end {
406 return;
407 }
408
409 let inner = self.inner.get_mut();
411 if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) {
412 let _ = inner.request_tx.send(range).await;
413 inner.requested_range = new_range;
414 }
415 }
416
    /// Returns the length in bytes of the backing buffer, which is sized from
    /// the length the server reported when the reader was created.
    // No `is_empty` counterpart is provided, hence the lint exemption.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u64 {
        self.len
    }
422}
423
424#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
427async fn run_streamer(
428 client: reqwest_middleware::ClientWithMiddleware,
429 url: Url,
430 extra_headers: HeaderMap,
431 response: Option<(Response, u64, u64)>,
432 mut memory_map: MmapMut,
433 mut state_tx: Sender<StreamerState>,
434 mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>,
435) {
436 let mut state = StreamerState::default();
437
438 if let Some((response, start, end_exclusive)) = response {
439 state.requested_ranges.push(start..end_exclusive);
441
442 if !stream_response(
444 response,
445 start,
446 end_exclusive,
447 &mut memory_map,
448 &mut state_tx,
449 &mut state,
450 )
451 .await
452 {
453 return;
454 }
455 }
456
457 'outer: loop {
459 let range = match request_rx.recv().await {
460 Some(range) => range,
461 None => {
462 break 'outer;
463 }
464 };
465
466 let uncovered_ranges = match state.resident_range.cover(range) {
468 None => continue,
469 Some((_, uncovered_ranges)) => uncovered_ranges,
470 };
471
472 for range in uncovered_ranges {
474 state
476 .requested_ranges
477 .push(*range.start()..*range.end() + 1);
478
479 let range_string = format!("bytes={}-{}", range.start(), range.end());
481 let span = info_span!("fetch_range", range = range_string.as_str());
482 let response = match client
483 .get(url.clone())
484 .header(reqwest::header::RANGE, range_string)
485 .headers(extra_headers.clone())
486 .send()
487 .instrument(span)
488 .await
489 .and_then(error_for_status)
490 .map_err(std::io::Error::other)
491 {
492 Err(e) => {
493 state.error = Some(e.into());
494 let _ = state_tx.send(state);
495 break 'outer;
496 }
497 Ok(response) => response,
498 };
499
500 if let Err(err) =
501 validate_content_range(&response, *range.start(), *range.end(), memory_map.len())
502 {
503 state.error = Some(err);
504 let _ = state_tx.send(state);
505 break 'outer;
506 }
507
508 if response.status() != reqwest::StatusCode::PARTIAL_CONTENT {
511 state.error = Some(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
512 let _ = state_tx.send(state);
513 break 'outer;
514 }
515
516 if !stream_response(
517 response,
518 *range.start(),
519 *range.end() + 1,
520 &mut memory_map,
521 &mut state_tx,
522 &mut state,
523 )
524 .await
525 {
526 break 'outer;
527 }
528 }
529 }
530}
531
532fn validate_content_range(
534 response: &Response,
535 expected_start: u64,
536 expected_end_inclusive: u64,
537 expected_complete_length: usize,
538) -> Result<(), AsyncHttpRangeReaderError> {
539 let content_range_header = response
540 .headers()
541 .get(reqwest::header::CONTENT_RANGE)
542 .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
543 .to_str()
544 .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
545 let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
546 AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
547 })?;
548 let (actual_start, actual_end_inclusive, actual_complete_length) = match content_range {
549 ContentRange::Bytes(ContentRangeBytes {
550 first_byte,
551 last_byte,
552 complete_length,
553 }) => (first_byte, last_byte, complete_length),
554 _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
555 };
556 if expected_start != actual_start
557 || expected_end_inclusive != actual_end_inclusive
558 || expected_complete_length as u64 != actual_complete_length
559 {
560 return Err(AsyncHttpRangeReaderError::RangeMismatch {
561 expected_start,
562 expected_end_inclusive,
563 expected_complete_length,
564 actual_start,
565 actual_end_inclusive,
566 actual_complete_length,
567 });
568 }
569
570 Ok(())
571}
572
573async fn stream_response(
579 tail_request_response: Response,
580 start: u64,
581 end_exclusive: u64,
582 memory_map: &mut MmapMut,
583 state_tx: &mut Sender<StreamerState>,
584 state: &mut StreamerState,
585) -> bool {
586 assert!(
588 (end_exclusive as usize) <= memory_map.len(),
589 "end is outside of memory map {} > {}",
590 end_exclusive,
591 memory_map.len()
592 );
593
594 let mut offset = start;
595 let mut byte_stream = tail_request_response.bytes_stream();
596 while let Some(bytes) = byte_stream.next().await {
597 let bytes = match bytes {
598 Err(e) => {
599 state.error = Some(e.into());
600 let _ = state_tx.send(state.clone());
601 return false;
602 }
603 Ok(bytes) => bytes,
604 };
605
606 let byte_range = offset..offset + bytes.len() as u64;
608
609 offset += bytes.len() as u64;
611
612 if offset > end_exclusive {
614 state.error = Some(AsyncHttpRangeReaderError::ResponseTooLong {
615 expected: end_exclusive - start,
616 });
617 let _ = state_tx.send(state.clone());
618 return false;
619 }
620
621 memory_map[byte_range.start as usize..byte_range.end as usize]
623 .copy_from_slice(bytes.as_ref());
624
625 state.resident_range.update(byte_range);
627
628 if state_tx.send(state.clone()).is_err() {
630 return false;
633 }
634 }
635
636 if offset != end_exclusive {
638 state.error = Some(AsyncHttpRangeReaderError::ResponseTooShort {
639 expected: end_exclusive - start,
640 actual: offset - start,
641 });
642 let _ = state_tx.send(state.clone());
643 return false;
644 }
645
646 true
647}
648
649impl AsyncSeek for AsyncHttpRangeReader {
650 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
651 let me = self.get_mut();
652 let inner = me.inner.get_mut();
653
654 inner.pos = match position {
655 SeekFrom::Start(pos) => pos,
656 SeekFrom::End(relative) => (inner.data.len() as i64).saturating_add(relative) as u64,
657 SeekFrom::Current(relative) => (inner.pos as i64).saturating_add(relative) as u64,
658 };
659
660 Ok(())
661 }
662
663 fn poll_complete(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
664 let inner = self.inner.get_mut();
665 Poll::Ready(Ok(inner.pos))
666 }
667}
668
669impl AsyncRead for AsyncHttpRangeReader {
670 fn poll_read(
671 self: Pin<&mut Self>,
672 cx: &mut Context<'_>,
673 buf: &mut ReadBuf<'_>,
674 ) -> Poll<io::Result<()>> {
675 let me = self.get_mut();
676 let inner = me.inner.get_mut();
677
678 if let Some(e) = inner.streamer_state.error.as_ref() {
680 return Poll::Ready(Err(io::Error::other(e.clone())));
681 }
682
683 let range = inner.pos..(inner.pos + buf.remaining() as u64).min(inner.data.len() as u64);
685 if range.start >= range.end {
686 return Poll::Ready(Ok(()));
687 }
688
689 while !inner.requested_range.is_covered(range.clone()) {
691 if let Some(mut poll) = inner.poll_request_tx.take() {
693 match poll.poll_reserve(cx) {
694 Poll::Ready(_) => {
695 let _ = poll.send_item(range.clone());
696 inner.requested_range.update(range.clone());
697 break;
698 }
699 Poll::Pending => {
700 inner.poll_request_tx = Some(poll);
701 return Poll::Pending;
702 }
703 }
704 }
705
706 inner.poll_request_tx = Some(PollSender::new(inner.request_tx.clone()));
708 }
709
710 if let Some(mut poll) = inner.poll_request_tx.take() {
712 poll.abort_send();
713 }
714
715 loop {
716 if inner
718 .streamer_state
719 .resident_range
720 .is_covered(range.clone())
721 {
722 let len = (range.end - range.start) as usize;
723 buf.initialize_unfilled_to(len)
724 .copy_from_slice(&inner.data[range.start as usize..range.end as usize]);
725 buf.advance(len);
726 inner.pos += len as u64;
727 return Poll::Ready(Ok(()));
728 }
729
730 match ready!(Pin::new(&mut inner.streamer_state_rx).poll_next(cx)) {
732 None => unreachable!(),
733 Some(state) => {
734 inner.streamer_state = state;
735 if let Some(e) = inner.streamer_state.error.as_ref() {
736 return Poll::Ready(Err(io::Error::other(e.clone())));
737 }
738 }
739 }
740 }
741 }
742}
743
744#[cfg(test)]
745mod static_directory_server;
746
747#[cfg(test)]
748mod test {
749 use super::*;
750 use crate::static_directory_server::StaticDirectoryServer;
751 use assert_matches::assert_matches;
752 use async_zip::tokio::read::seek::ZipFileReader;
753 use axum::body::Body;
754 use axum::extract::Request;
755 use axum::response::IntoResponse;
756 use futures::AsyncReadExt;
757 use reqwest::header;
758 use reqwest::Method;
759 use reqwest::{Client, StatusCode};
760 use rstest::*;
761 use std::path::Path;
762 use tokio::io::AsyncReadExt as _;
763
    /// Reads a zip archive (a `.conda` package) through the range reader and
    /// checks both the decoded contents and the exact range requests that were
    /// needed, for each way of probing the server.
    #[rstest]
    #[case(CheckSupportMethod::Head)]
    #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
    #[tokio::test]
    async fn async_range_reader_zip(#[case] check_method: CheckSupportMethod) {
        // Serve the test data directory over a local HTTP server.
        let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
        let server = StaticDirectoryServer::new(&path)
            .await
            .expect("could not initialize server");

        // Guard against a missing or placeholder fixture file.
        let filepath = path.join("andes-1.8.3-pyhd8ed1ab_0.conda");
        assert!(
            filepath.exists(),
            "The conda package is not there yet. Did you run `git lfs pull`?"
        );
        let file_size = std::fs::metadata(&filepath).unwrap().len();
        assert_eq!(
            file_size, 2_463_995,
            "The conda package is not there yet. Did you run `git lfs pull`?"
        );

        let (mut range, _) = AsyncHttpRangeReader::new(
            Client::new(),
            server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
            check_method,
            HeaderMap::default(),
        )
        .await
        .expect("Could not download range - did you run `git lfs pull`?");

        // Make sure the last 8192 bytes have been requested regardless of the
        // probing method, so both cases start from the same state.
        range.prefetch(range.len() - 8192..range.len()).await;

        assert_eq!(range.len(), file_size);

        // A zero-capacity buffer is used so the zip reader's reads reach the
        // range reader unchanged — presumably to keep the request pattern
        // deterministic; confirm against `BufReader` behaviour.
        let mut reader = ZipFileReader::with_tokio(tokio::io::BufReader::with_capacity(0, range))
            .await
            .unwrap();

        assert_eq!(
            reader
                .file()
                .entries()
                .iter()
                .map(|e| e.filename().as_str().unwrap_or(""))
                .collect::<Vec<_>>(),
            vec![
                "metadata.json",
                "info-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
                "pkg-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
            ]
        );

        // Listing the entries must not have needed more than the tail request.
        let request_ranges = reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .requested_ranges()
            .await;
        assert_eq!(request_ranges.len(), 1);
        assert_eq!(
            request_ranges[0].end - request_ranges[0].start,
            8192,
            "first request should be the size of the initial chunk size"
        );
        assert_eq!(
            request_ranges[0].end, file_size,
            "first request should be at the end"
        );

        // Work out where the first entry lives in the archive. The constant 30
        // is presumably the fixed size of a zip local file header — confirm.
        let entry = reader.file().entries().first().unwrap();
        let offset = entry.header_offset();
        let size = entry.compressed_size() + 30 + entry.filename().as_bytes().len() as u64;

        // Round the size up to a multiple of the buffer size.
        let buffer_size = 8192;
        let size = size.div_ceil(buffer_size) * buffer_size;

        // Fetch the whole entry with a single request.
        reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .prefetch(offset..offset + size as u64)
            .await;

        // Read the first entry; this must be served from the prefetched bytes.
        let mut contents = String::new();
        reader
            .reader_with_entry(0)
            .await
            .unwrap()
            .read_to_string(&mut contents)
            .await
            .unwrap();

        let request_ranges = reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .requested_ranges()
            .await;

        assert_eq!(contents, r#"{"conda_pkg_format_version": 2}"#);
        assert_eq!(request_ranges.len(), 2);
        assert_eq!(
            request_ranges[1],
            0..size,
            "expected only two range requests"
        );
    }
884
885 #[rstest]
886 #[case(CheckSupportMethod::Head)]
887 #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
888 #[tokio::test]
889 async fn async_range_reader(#[case] check_method: CheckSupportMethod) {
890 let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
892 let server = StaticDirectoryServer::new(&path)
893 .await
894 .expect("could not initialize server");
895
896 let (mut range, _) = AsyncHttpRangeReader::new(
898 Client::new(),
899 server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
900 check_method,
901 HeaderMap::default(),
902 )
903 .await
904 .expect("bla");
905
906 let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
908 .await
909 .unwrap();
910
911 let mut range_read = vec![0; 64 * 1024];
913 let mut file_read = vec![0; 64 * 1024];
914 loop {
915 let range_read_bytes = range.read(&mut range_read).await.unwrap();
917
918 let file_read_bytes = file
920 .read_exact(&mut file_read[0..range_read_bytes])
921 .await
922 .unwrap();
923
924 assert_eq!(range_read_bytes, file_read_bytes);
925 assert_eq!(
926 range_read[0..range_read_bytes],
927 file_read[0..file_read_bytes]
928 );
929
930 if file_read_bytes == 0 && range_read_bytes == 0 {
931 break;
932 }
933 }
934 }
935
936 #[tokio::test]
937 async fn test_not_found() {
938 let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR")))
939 .await
940 .expect("could not initialize server");
941 let err = AsyncHttpRangeReader::new(
942 Client::new(),
943 server.url().join("not-found").unwrap(),
944 CheckSupportMethod::Head,
945 HeaderMap::default(),
946 )
947 .await
948 .expect_err("expected an error");
949
950 assert_matches!(
951 err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND)
952 );
953 }
954
    /// Starts a local server that deliberately misreports sizes and returns
    /// the URL to request.
    ///
    /// * `head_content_length` — the `Content-Length` announced for `HEAD`
    ///   requests, also used as the total length in `Content-Range`.
    /// * `pretend_size` — the number of bytes the `Content-Range` header of a
    ///   `GET` response claims to cover, counted from the requested start.
    /// * `actual_size` — the number of bytes actually sent in the `GET` body.
    ///
    /// The end of the requested range is ignored by the server.
    async fn spawn_mismatch_server(
        head_content_length: usize,
        pretend_size: usize,
        actual_size: usize,
    ) -> Url {
        let app =
            axum::Router::new().fallback(async move |request: Request| match *request.method() {
                Method::HEAD => {
                    // Advertise range support and the configured length.
                    let headers = [
                        (header::CONTENT_LENGTH, head_content_length.to_string()),
                        (header::ACCEPT_RANGES, "bytes".to_string()),
                    ];
                    (StatusCode::OK, headers).into_response()
                }
                Method::GET => {
                    let range_header = request
                        .headers()
                        .get(header::RANGE)
                        .unwrap()
                        .to_str()
                        .unwrap()
                        .to_string();

                    // Only the start of the requested range is honoured; the
                    // end is derived from `pretend_size`.
                    let range_spec = range_header.strip_prefix("bytes=").unwrap();
                    let (start_str, _end_str) = range_spec.split_once('-').unwrap();
                    let start = start_str.parse::<usize>().unwrap();
                    let end = start + pretend_size - 1;

                    axum::response::Response::builder()
                        .status(StatusCode::PARTIAL_CONTENT)
                        .header(
                            header::CONTENT_RANGE,
                            format!("bytes {start}-{end}/{head_content_length}"),
                        )
                        // The body length is independent of the header above.
                        .body(Body::from(vec![1u8; actual_size]))
                        .unwrap()
                        .into_response()
                }
                _ => StatusCode::METHOD_NOT_ALLOWED.into_response(),
            });

        // Bind to an ephemeral port and serve in the background.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });

        // NOTE(review): the listener is bound to 127.0.0.1 while the URL uses
        // `localhost` — this assumes `localhost` resolves to (or falls back to)
        // IPv4 on the test machine.
        Url::parse(&format!("http://localhost:{}/file", local_addr.port())).unwrap()
    }
1011
    /// Checks how the reader reacts when the server's `HEAD` length, the
    /// `Content-Range` header and the actual body size disagree.
    ///
    /// Every case reads bytes `0..512` and is described by the tuple
    /// `(head_content_length, range_header_length, range_actual_length,
    /// expected_error)`.
    #[tokio::test]
    async fn test_content_length_response_beyond_content_length() {
        /// Unwraps the reader's own error type from the `io::Error` that
        /// `read` returns.
        fn into_range_error(err: std::io::Error) -> AsyncHttpRangeReaderError {
            err.into_inner()
                .unwrap()
                .downcast::<AsyncHttpRangeReaderError>()
                .map(|e| *e)
                .unwrap()
        }

        let cases: Vec<(usize, usize, usize, Option<AsyncHttpRangeReaderError>)> = vec![
            // Everything agrees.
            (512, 512, 512, None),
            // The body is longer than the header claims.
            (
                512,
                512,
                1024,
                Some(AsyncHttpRangeReaderError::ResponseTooLong { expected: 512 }),
            ),
            // The header claims a range that ends beyond the total length.
            (
                512,
                1024,
                1024,
                Some(AsyncHttpRangeReaderError::ContentRangeParser(
                    "bytes 0-1023/512".to_string(),
                )),
            ),
            // Same invalid header; the body size does not matter.
            (
                512,
                1024,
                512,
                Some(AsyncHttpRangeReaderError::ContentRangeParser(
                    "bytes 0-1023/512".to_string(),
                )),
            ),
            // The file is larger than the part that is read; all consistent.
            (1024, 512, 512, None),
            // Header matches the request but the body is longer.
            (
                1024,
                512,
                1024,
                Some(AsyncHttpRangeReaderError::ResponseTooLong { expected: 512 }),
            ),
            // The header describes a larger range than was requested.
            (
                1024,
                1024,
                1024,
                Some(AsyncHttpRangeReaderError::RangeMismatch {
                    expected_start: 0,
                    expected_end_inclusive: 511,
                    expected_complete_length: 1024,
                    actual_start: 0,
                    actual_end_inclusive: 1023,
                    actual_complete_length: 1024,
                }),
            ),
            // Same mismatching header; the body size does not matter.
            (
                1024,
                1024,
                512,
                Some(AsyncHttpRangeReaderError::RangeMismatch {
                    expected_start: 0,
                    expected_end_inclusive: 511,
                    expected_complete_length: 1024,
                    actual_start: 0,
                    actual_end_inclusive: 1023,
                    actual_complete_length: 1024,
                }),
            ),
        ];
        for (head_content_length, range_header_length, range_actual_length, expected_error) in cases
        {
            // Each case gets its own misbehaving server.
            let url = spawn_mismatch_server(
                head_content_length,
                range_header_length,
                range_actual_length,
            )
            .await;

            let (mut reader, _) = AsyncHttpRangeReader::new(
                Client::new(),
                url,
                CheckSupportMethod::Head,
                HeaderMap::default(),
            )
            .await
            .unwrap();

            assert_eq!(reader.len(), head_content_length as u64);
            reader.prefetch(0..512).await;

            let mut buf = vec![0u8; 512];
            let result = reader.read(&mut buf).await;
            // Identifies the failing case in assertion messages.
            let label =
                format!("{head_content_length} {range_header_length} {range_actual_length}");
            match expected_error {
                None => {
                    assert_matches!(result, Ok(_), "{label}");
                }
                Some(expected) => {
                    // Errors are compared by their rendered message.
                    assert_eq!(
                        into_range_error(result.unwrap_err()).to_string(),
                        expected.to_string(),
                        "{label}"
                    );
                }
            }
        }
    }
1131}