1mod error;
20mod sparse_range;
21
22use futures::{FutureExt, Stream, StreamExt};
23use http_content_range::{ContentRange, ContentRangeBytes};
24use memmap2::MmapMut;
25use reqwest::header::HeaderMap;
26use reqwest::{Response, Url};
27use sparse_range::SparseRange;
28use std::{
29 io::{self, SeekFrom},
30 ops::Range,
31 pin::Pin,
32 sync::Arc,
33 task::{ready, Context, Poll},
34};
35use tokio::{
36 io::{AsyncRead, AsyncSeek, ReadBuf},
37 sync::watch::Sender,
38 sync::{watch, Mutex},
39};
40use tokio_stream::wrappers::WatchStream;
41use tokio_util::sync::PollSender;
42use tracing::{info_span, Instrument};
43
44pub use error::AsyncHttpRangeReaderError;
45
/// An [`AsyncRead`] + [`AsyncSeek`] reader over a remote HTTP resource.
///
/// Bytes are fetched on demand with HTTP range requests by a background streamer task and
/// written into an anonymous memory map sized to the whole resource; reads wait until the
/// bytes they need have been reported as resident.
#[derive(Debug)]
pub struct AsyncHttpRangeReader {
    /// Reader state. Locked asynchronously by `requested_ranges` and accessed through
    /// `Mutex::get_mut` by the `&mut self` / `Pin<&mut Self>` methods.
    inner: Mutex<Inner>,
    /// Total length of the resource in bytes (the size of the memory map).
    len: u64,
}
85
/// Snapshot of the streamer task's progress, published to the reader through a `watch`
/// channel after every chunk that is written.
#[derive(Default, Clone, Debug)]
struct StreamerState {
    /// Byte ranges whose contents have already been copied into the memory map.
    resident_range: SparseRange,
    /// Every (half-open) range the streamer has requested from the server, in the order the
    /// requests were made.
    requested_ranges: Vec<Range<u64>>,
    /// Set when the streamer failed; the streamer stops fetching once this is set.
    error: Option<AsyncHttpRangeReaderError>,
}
92
/// Mutable state of an [`AsyncHttpRangeReader`], kept behind its mutex.
#[derive(Debug)]
struct Inner {
    /// Read-only view of the anonymous memory map that the streamer task writes into.
    ///
    /// NOTE(review): the `'static` lifetime is not real. The slice is built with
    /// `slice::from_raw_parts` over a map that is moved into the spawned streamer task, so it
    /// is only valid while that task keeps the map alive.
    data: &'static [u8],

    /// Current read position; may lie past the end of `data` after a seek.
    pos: u64,

    /// Ranges that were already handed to the streamer (or covered by the initial tail
    /// request), so they are not requested twice.
    requested_range: SparseRange,

    /// Most recent streamer state observed by the reader.
    streamer_state: StreamerState,

    /// Stream of state updates published by the streamer task.
    streamer_state_rx: WatchStream<StreamerState>,

    /// Channel used to ask the streamer task to fetch a range.
    request_tx: tokio::sync::mpsc::Sender<Range<u64>>,

    /// Sender used from `poll_read`; kept across polls while it waits for channel capacity.
    poll_request_tx: Option<PollSender<Range<u64>>>,
}
119
/// How [`AsyncHttpRangeReader::new`] probes the server to learn the size of the resource and
/// whether it supports byte-range requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckSupportMethod {
    /// Send a `GET` with a suffix range (`Range: bytes=-N`) for the last `N` bytes.
    ///
    /// The total size is read from the `Content-Range` header of the response, and the
    /// returned bytes are kept as the first chunk of the resource.
    NegativeRangeRequest(u64),

    /// Send a `HEAD` request and require both an `Accept-Ranges: bytes` and a
    /// `Content-Length` header in the response.
    Head,
}
133
134fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
135 response
136 .error_for_status()
137 .map_err(reqwest_middleware::Error::Reqwest)
138}
139
140impl AsyncHttpRangeReader {
141 pub async fn new(
143 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
144 url: Url,
145 check_method: CheckSupportMethod,
146 extra_headers: HeaderMap,
147 ) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
148 let client = client.into();
149 match check_method {
150 CheckSupportMethod::NegativeRangeRequest(initial_chunk_size) => {
151 let response = Self::initial_tail_request(
152 client.clone(),
153 url.clone(),
154 initial_chunk_size,
155 HeaderMap::default(),
156 )
157 .await?;
158 let response_headers = response.headers().clone();
159 let self_ = Self::from_tail_response(client, response, url, extra_headers).await?;
160 Ok((self_, response_headers))
161 }
162 CheckSupportMethod::Head => {
163 let response =
164 Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
165 .await?;
166 let response_headers = response.headers().clone();
167 let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
168 Ok((self_, response_headers))
169 }
170 }
171 }
172
173 pub async fn initial_tail_request(
177 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
178 url: reqwest::Url,
179 initial_chunk_size: u64,
180 extra_headers: HeaderMap,
181 ) -> Result<Response, AsyncHttpRangeReaderError> {
182 let client = client.into();
183 let tail_response = client
184 .get(url)
185 .header(
186 reqwest::header::RANGE,
187 format!("bytes=-{initial_chunk_size}"),
188 )
189 .headers(extra_headers)
190 .send()
191 .await
192 .and_then(error_for_status)
193 .map_err(Arc::new)
194 .map_err(AsyncHttpRangeReaderError::HttpError)?;
195 Ok(tail_response)
196 }
197
198 pub async fn from_tail_response(
201 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
202 tail_request_response: Response,
203 url: Url,
204 extra_headers: HeaderMap,
205 ) -> Result<Self, AsyncHttpRangeReaderError> {
206 let client = client.into();
207
208 let content_range_header = tail_request_response
210 .headers()
211 .get(reqwest::header::CONTENT_RANGE)
212 .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
213 .to_str()
214 .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
215 let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
216 AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
217 })?;
218 let (start, finish, complete_length) = match content_range {
219 ContentRange::Bytes(ContentRangeBytes {
220 first_byte,
221 last_byte,
222 complete_length,
223 }) => (first_byte, last_byte, complete_length),
224 _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
225 };
226
227 let memory_map = memmap2::MmapOptions::new()
229 .len(complete_length as usize)
230 .map_anon()
231 .map_err(Arc::new)
232 .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
233
234 let memory_map_slice =
237 unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
238
239 let requested_range =
240 SparseRange::from_range(complete_length - (finish - start)..complete_length);
241
242 let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
247 let (state_tx, state_rx) = watch::channel(StreamerState::default());
248 tokio::spawn(run_streamer(
249 client,
250 url,
251 extra_headers,
252 Some((tail_request_response, start)),
253 memory_map,
254 state_tx,
255 request_rx,
256 ));
257
258 let mut streamer_state = StreamerState::default();
260 streamer_state
261 .requested_ranges
262 .push(complete_length - (finish - start)..complete_length);
263
264 let reader = Self {
265 len: memory_map_slice.len() as u64,
266 inner: Mutex::new(Inner {
267 data: memory_map_slice,
268 pos: 0,
269 requested_range,
270 streamer_state,
271 streamer_state_rx: WatchStream::new(state_rx),
272 request_tx,
273 poll_request_tx: None,
274 }),
275 };
276 Ok(reader)
277 }
278
279 pub async fn initial_head_request(
282 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
283 url: reqwest::Url,
284 extra_headers: HeaderMap,
285 ) -> Result<Response, AsyncHttpRangeReaderError> {
286 let client = client.into();
287
288 let head_response = client
290 .head(url.clone())
291 .headers(extra_headers)
292 .send()
293 .await
294 .and_then(error_for_status)
295 .map_err(Arc::new)
296 .map_err(AsyncHttpRangeReaderError::HttpError)?;
297 Ok(head_response)
298 }
299
300 pub async fn from_head_response(
303 client: impl Into<reqwest_middleware::ClientWithMiddleware>,
304 head_response: Response,
305 url: Url,
306 extra_headers: HeaderMap,
307 ) -> Result<Self, AsyncHttpRangeReaderError> {
308 let client = client.into();
309
310 if head_response
312 .headers()
313 .get(reqwest::header::ACCEPT_RANGES)
314 .and_then(|h| h.to_str().ok())
315 != Some("bytes")
316 {
317 return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
318 }
319
320 let content_length: u64 = head_response
321 .headers()
322 .get(reqwest::header::CONTENT_LENGTH)
323 .ok_or(AsyncHttpRangeReaderError::ContentLengthMissing)?
324 .to_str()
325 .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?
326 .parse()
327 .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?;
328
329 let memory_map = memmap2::MmapOptions::new()
331 .len(content_length as _)
332 .map_anon()
333 .map_err(Arc::new)
334 .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
335
336 let memory_map_slice =
339 unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
340
341 let requested_range = SparseRange::default();
342
343 let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
348 let (state_tx, state_rx) = watch::channel(StreamerState::default());
349 tokio::spawn(run_streamer(
350 client,
351 url,
352 extra_headers,
353 None,
354 memory_map,
355 state_tx,
356 request_rx,
357 ));
358
359 let streamer_state = StreamerState::default();
361
362 let reader = Self {
363 len: memory_map_slice.len() as u64,
364 inner: Mutex::new(Inner {
365 data: memory_map_slice,
366 pos: 0,
367 requested_range,
368 streamer_state,
369 streamer_state_rx: WatchStream::new(state_rx),
370 request_tx,
371 poll_request_tx: None,
372 }),
373 };
374 Ok(reader)
375 }
376
377 pub async fn requested_ranges(&self) -> Vec<Range<u64>> {
379 let mut inner = self.inner.lock().await;
380 if let Some(Some(new_state)) = inner.streamer_state_rx.next().now_or_never() {
381 inner.streamer_state = new_state;
382 }
383 inner.streamer_state.requested_ranges.clone()
384 }
385
386 pub async fn prefetch(&mut self, bytes: Range<u64>) {
389 let inner = self.inner.get_mut();
390
391 let range = bytes.start..(bytes.end.min(inner.data.len() as u64));
393 if range.start >= range.end {
394 return;
395 }
396
397 let inner = self.inner.get_mut();
399 if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) {
400 let _ = inner.request_tx.send(range).await;
401 inner.requested_range = new_range;
402 }
403 }
404
405 #[allow(clippy::len_without_is_empty)]
407 pub fn len(&self) -> u64 {
408 self.len
409 }
410}
411
412#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
415async fn run_streamer(
416 client: reqwest_middleware::ClientWithMiddleware,
417 url: Url,
418 extra_headers: HeaderMap,
419 initial_tail_response: Option<(Response, u64)>,
420 mut memory_map: MmapMut,
421 mut state_tx: Sender<StreamerState>,
422 mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>,
423) {
424 let mut state = StreamerState::default();
425
426 if let Some((response, response_start)) = initial_tail_response {
427 state
429 .requested_ranges
430 .push(response_start..memory_map.len() as u64);
431
432 if !stream_response(
434 response,
435 response_start,
436 &mut memory_map,
437 &mut state_tx,
438 &mut state,
439 )
440 .await
441 {
442 return;
443 }
444 }
445
446 'outer: loop {
448 let range = match request_rx.recv().await {
449 Some(range) => range,
450 None => {
451 break 'outer;
452 }
453 };
454
455 let uncovered_ranges = match state.resident_range.cover(range) {
457 None => continue,
458 Some((_, uncovered_ranges)) => uncovered_ranges,
459 };
460
461 for range in uncovered_ranges {
463 state
465 .requested_ranges
466 .push(*range.start()..*range.end() + 1);
467
468 let range_string = format!("bytes={}-{}", range.start(), range.end());
470 let span = info_span!("fetch_range", range = range_string.as_str());
471 let response = match client
472 .get(url.clone())
473 .header(reqwest::header::RANGE, range_string)
474 .headers(extra_headers.clone())
475 .send()
476 .instrument(span)
477 .await
478 .and_then(error_for_status)
479 .map_err(std::io::Error::other)
480 {
481 Err(e) => {
482 state.error = Some(e.into());
483 let _ = state_tx.send(state);
484 break 'outer;
485 }
486 Ok(response) => response,
487 };
488
489 if response.status() != reqwest::StatusCode::PARTIAL_CONTENT {
492 state.error = Some(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
493 let _ = state_tx.send(state);
494 break 'outer;
495 }
496
497 if !stream_response(
498 response,
499 *range.start(),
500 &mut memory_map,
501 &mut state_tx,
502 &mut state,
503 )
504 .await
505 {
506 break 'outer;
507 }
508 }
509 }
510}
511
512async fn stream_response(
516 tail_request_response: Response,
517 mut offset: u64,
518 memory_map: &mut MmapMut,
519 state_tx: &mut Sender<StreamerState>,
520 state: &mut StreamerState,
521) -> bool {
522 let mut byte_stream = tail_request_response.bytes_stream();
523 while let Some(bytes) = byte_stream.next().await {
524 let bytes = match bytes {
525 Err(e) => {
526 state.error = Some(e.into());
527 let _ = state_tx.send(state.clone());
528 return false;
529 }
530 Ok(bytes) => bytes,
531 };
532
533 let byte_range = offset..offset + bytes.len() as u64;
535
536 offset = byte_range.end;
538
539 memory_map[byte_range.start as usize..byte_range.end as usize]
541 .copy_from_slice(bytes.as_ref());
542
543 state.resident_range.update(byte_range);
545
546 if state_tx.send(state.clone()).is_err() {
548 return false;
551 }
552 }
553
554 true
555}
556
557impl AsyncSeek for AsyncHttpRangeReader {
558 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
559 let me = self.get_mut();
560 let inner = me.inner.get_mut();
561
562 inner.pos = match position {
563 SeekFrom::Start(pos) => pos,
564 SeekFrom::End(relative) => (inner.data.len() as i64).saturating_add(relative) as u64,
565 SeekFrom::Current(relative) => (inner.pos as i64).saturating_add(relative) as u64,
566 };
567
568 Ok(())
569 }
570
571 fn poll_complete(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
572 let inner = self.inner.get_mut();
573 Poll::Ready(Ok(inner.pos))
574 }
575}
576
577impl AsyncRead for AsyncHttpRangeReader {
578 fn poll_read(
579 self: Pin<&mut Self>,
580 cx: &mut Context<'_>,
581 buf: &mut ReadBuf<'_>,
582 ) -> Poll<io::Result<()>> {
583 let me = self.get_mut();
584 let inner = me.inner.get_mut();
585
586 if let Some(e) = inner.streamer_state.error.as_ref() {
588 return Poll::Ready(Err(io::Error::other(e.clone())));
589 }
590
591 let range = inner.pos..(inner.pos + buf.remaining() as u64).min(inner.data.len() as u64);
593 if range.start >= range.end {
594 return Poll::Ready(Ok(()));
595 }
596
597 while !inner.requested_range.is_covered(range.clone()) {
599 if let Some(mut poll) = inner.poll_request_tx.take() {
601 match poll.poll_reserve(cx) {
602 Poll::Ready(_) => {
603 let _ = poll.send_item(range.clone());
604 inner.requested_range.update(range.clone());
605 break;
606 }
607 Poll::Pending => {
608 inner.poll_request_tx = Some(poll);
609 return Poll::Pending;
610 }
611 }
612 }
613
614 inner.poll_request_tx = Some(PollSender::new(inner.request_tx.clone()));
616 }
617
618 if let Some(mut poll) = inner.poll_request_tx.take() {
620 poll.abort_send();
621 }
622
623 loop {
624 if inner
626 .streamer_state
627 .resident_range
628 .is_covered(range.clone())
629 {
630 let len = (range.end - range.start) as usize;
631 buf.initialize_unfilled_to(len)
632 .copy_from_slice(&inner.data[range.start as usize..range.end as usize]);
633 buf.advance(len);
634 inner.pos += len as u64;
635 return Poll::Ready(Ok(()));
636 }
637
638 match ready!(Pin::new(&mut inner.streamer_state_rx).poll_next(cx)) {
640 None => unreachable!(),
641 Some(state) => {
642 inner.streamer_state = state;
643 if let Some(e) = inner.streamer_state.error.as_ref() {
644 return Poll::Ready(Err(io::Error::other(e.clone())));
645 }
646 }
647 }
648 }
649 }
650}
651
652#[cfg(test)]
653mod static_directory_server;
654
// Integration tests: a local static file server serves `test-data/`, and the reader is
// compared against reading the same file from disk.
#[cfg(test)]
mod test {
    use super::*;
    use crate::static_directory_server::StaticDirectoryServer;
    use assert_matches::assert_matches;
    use async_zip::tokio::read::seek::ZipFileReader;
    use futures::AsyncReadExt;
    use reqwest::{Client, StatusCode};
    use rstest::*;
    use std::path::Path;
    use tokio::io::AsyncReadExt as _;

    /// Opens the conda package (a zip archive) through the range reader. Listing the entries
    /// and reading the first entry must take exactly two range requests: the 8 KiB tail
    /// chunk, and one request starting at offset 0 that covers the first entry.
    #[rstest]
    #[case(CheckSupportMethod::Head)]
    #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
    #[tokio::test]
    async fn async_range_reader_zip(#[case] check_method: CheckSupportMethod) {
        let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
        let server = StaticDirectoryServer::new(&path)
            .await
            .expect("could not initialize server");

        // The fixture is stored with git LFS; make sure the real file (not a pointer) exists.
        let filepath = path.join("andes-1.8.3-pyhd8ed1ab_0.conda");
        assert!(
            filepath.exists(),
            "The conda package is not there yet. Did you run `git lfs pull`?"
        );
        let file_size = std::fs::metadata(&filepath).unwrap().len();
        assert_eq!(
            file_size, 2_463_995,
            "The conda package is not there yet. Did you run `git lfs pull`?"
        );

        let (mut range, _) = AsyncHttpRangeReader::new(
            Client::new(),
            server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
            check_method,
            HeaderMap::default(),
        )
        .await
        .expect("Could not download range - did you run `git lfs pull`?");

        // Fetch the last 8 KiB up front, presumably because the zip reader starts at the end
        // of the archive (central directory). With the negative-range probe, these bytes are
        // already requested and nothing new is sent.
        range.prefetch(range.len() - 8192..range.len()).await;

        assert_eq!(range.len(), file_size);

        // Capacity 0 — presumably so the `BufReader` never reads ahead and triggers extra
        // range requests.
        let mut reader = ZipFileReader::with_tokio(tokio::io::BufReader::with_capacity(0, range))
            .await
            .unwrap();

        assert_eq!(
            reader
                .file()
                .entries()
                .iter()
                .map(|e| e.filename().as_str().unwrap_or(""))
                .collect::<Vec<_>>(),
            vec![
                "metadata.json",
                "info-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
                "pkg-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
            ]
        );

        // Unwrap ZipFileReader -> BufReader -> Compat layer to reach the range reader.
        let request_ranges = reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .requested_ranges()
            .await;
        assert_eq!(request_ranges.len(), 1);
        assert_eq!(
            request_ranges[0].end - request_ranges[0].start,
            8192,
            "first request should be the size of the initial chunk size"
        );
        assert_eq!(
            request_ranges[0].end, file_size,
            "first request should be at the end"
        );

        // Extent of the first entry. The 30 is presumably the fixed part of a zip local file
        // header; extra-field bytes are not accounted for here.
        let entry = reader.file().entries().first().unwrap();
        let offset = entry.header_offset();
        let size = entry.compressed_size() + 30 + entry.filename().as_bytes().len() as u64;

        // Round the size up to a multiple of 8 KiB.
        let buffer_size = 8192;
        let size = ((size + buffer_size - 1) / buffer_size) * buffer_size;

        reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .prefetch(offset..offset + size as u64)
            .await;

        let mut contents = String::new();
        reader
            .reader_with_entry(0)
            .await
            .unwrap()
            .read_to_string(&mut contents)
            .await
            .unwrap();

        let request_ranges = reader
            .inner_mut()
            .get_mut()
            .get_mut()
            .requested_ranges()
            .await;

        // The second request is expected to start at 0, i.e. the first entry's header is
        // assumed to sit at offset 0.
        assert_eq!(contents, r#"{"conda_pkg_format_version": 2}"#);
        assert_eq!(request_ranges.len(), 2);
        assert_eq!(
            request_ranges[1],
            0..size,
            "expected only two range requests"
        );
    }

    /// Reads the whole file sequentially through the range reader in 64 KiB chunks and checks
    /// every chunk against the same bytes read from disk, until both report EOF.
    #[rstest]
    #[case(CheckSupportMethod::Head)]
    #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
    #[tokio::test]
    async fn async_range_reader(#[case] check_method: CheckSupportMethod) {
        let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
        let server = StaticDirectoryServer::new(&path)
            .await
            .expect("could not initialize server");

        let (mut range, _) = AsyncHttpRangeReader::new(
            Client::new(),
            server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
            check_method,
            HeaderMap::default(),
        )
        .await
        .expect("bla");

        let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
            .await
            .unwrap();

        let mut range_read = vec![0; 64 * 1024];
        let mut file_read = vec![0; 64 * 1024];
        loop {
            let range_read_bytes = range.read(&mut range_read).await.unwrap();

            // Read exactly as many bytes from disk as the range reader returned.
            let file_read_bytes = file
                .read_exact(&mut file_read[0..range_read_bytes])
                .await
                .unwrap();

            assert_eq!(range_read_bytes, file_read_bytes);
            assert_eq!(
                range_read[0..range_read_bytes],
                file_read[0..file_read_bytes]
            );

            if file_read_bytes == 0 && range_read_bytes == 0 {
                break;
            }
        }
    }

    /// A `HEAD` probe for a missing file must surface the server's 404 as an `HttpError`.
    #[tokio::test]
    async fn test_not_found() {
        let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR")))
            .await
            .expect("could not initialize server");
        let err = AsyncHttpRangeReader::new(
            Client::new(),
            server.url().join("not-found").unwrap(),
            CheckSupportMethod::Head,
            HeaderMap::default(),
        )
        .await
        .expect_err("expected an error");

        assert_matches!(
            err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND)
        );
    }
}