Skip to main content

async_http_range_reader/
lib.rs

//! This library provides the [`AsyncHttpRangeReader`] type.
//!
//! It allows streaming a file over HTTP while also allowing random access. The type implements both
//! [`AsyncRead`] as well as [`AsyncSeek`]. This is supported through the use of range requests.
//! Each individual read will request a portion of the file using an HTTP range request.
//!
//! Requesting numerous small reads might turn out to be relatively slow because each read needs to
//! perform an HTTP request. To alleviate this issue [`AsyncHttpRangeReader::prefetch`] is provided.
//! Using this method you can *prefetch* a number of bytes which will be streamed in in the
//! background. If a read operation is reading from already (pre)fetched ranges it will stream from
//! the internal cache instead.
//!
//! Internally the [`AsyncHttpRangeReader`] stores a memory map which allows sparsely reading the
//! data into memory without actually requiring the whole file to be resident in memory.
//!
//! The primary use-case for this library is to be able to sparsely stream a zip archive over HTTP
//! but it is designed in a generic fashion.
18
19mod error;
20mod sparse_range;
21
22use futures::{FutureExt, Stream, StreamExt};
23use http_content_range::{ContentRange, ContentRangeBytes};
24use memmap2::MmapMut;
25use reqwest::header::HeaderMap;
26use reqwest::{Response, Url};
27use sparse_range::SparseRange;
28use std::{
29    io::{self, SeekFrom},
30    ops::Range,
31    pin::Pin,
32    sync::Arc,
33    task::{ready, Context, Poll},
34};
35use tokio::{
36    io::{AsyncRead, AsyncSeek, ReadBuf},
37    sync::watch::Sender,
38    sync::{watch, Mutex},
39};
40use tokio_stream::wrappers::WatchStream;
41use tokio_util::sync::PollSender;
42use tracing::{info_span, Instrument};
43
44pub use error::AsyncHttpRangeReaderError;
45
46/// An `AsyncRangeReader` enables reading from a file over HTTP using range requests.
47///
48/// See the [`crate`] level documentation for more information.
49///
50/// The general entrypoint is [`AsyncHttpRangeReader::new`]. Depending on the
51/// [`CheckSupportMethod`], this will either call [`AsyncHttpRangeReader::initial_tail_request`] or
52/// [`AsyncHttpRangeReader::initial_head_request`] to send the initial request and then
53/// [`AsyncHttpRangeReader::from_range_response`] or [`AsyncHttpRangeReader::from_head_response`] to
54/// initialize the async reader. If you want to apply a caching layer, you can send the initial head
55/// (or tail) request yourself with your cache headers (e.g. through the
56/// [http-cache-semantics](https://docs.rs/http-cache-semantics) crate):
57///
58/// ```rust
59/// # use url::Url;
60/// # use async_http_range_reader::{AsyncHttpRangeReader, AsyncHttpRangeReaderError};
61/// # use reqwest::header::HeaderMap;
62/// async fn get_reader_cached(
63///     url: Url,
64/// ) -> Result<Option<AsyncHttpRangeReader>, AsyncHttpRangeReaderError> {
65///     let etag = "63c550e8-5ae";
66///     let client = reqwest::Client::new();
67///     let response = client
68///         .head(url.clone())
69///         .header(reqwest::header::IF_NONE_MATCH, etag)
70///         .send()
71///         .await?;
72///     if response.status() == reqwest::StatusCode::NOT_MODIFIED {
73///         Ok(None)
74///     } else {
75///         let reader = AsyncHttpRangeReader::from_head_response(client, response, url, HeaderMap::default()).await?;
76///         Ok(Some(reader))
77///     }
78/// }
79/// ```
80#[derive(Debug)]
81pub struct AsyncHttpRangeReader {
82    inner: Mutex<Inner>,
83    len: u64,
84}
85
86#[derive(Default, Clone, Debug)]
87struct StreamerState {
88    resident_range: SparseRange,
89    requested_ranges: Vec<Range<u64>>,
90    error: Option<AsyncHttpRangeReaderError>,
91}
92
93#[derive(Debug)]
94struct Inner {
95    /// A read-only view on the memory mapped data. The `downloaded_range` indicates the regions of
96    /// memory that contain bytes that have been downloaded.
97    data: &'static [u8],
98
99    /// The current read position in the stream
100    pos: u64,
101
102    /// The range of bytes that have been requested for download
103    requested_range: SparseRange,
104
105    /// The range of bytes that have actually been downloaded to `data`.
106    streamer_state: StreamerState,
107
108    /// A channel receiver that holds the last downloaded range (or an error) from the background
109    /// task.
110    streamer_state_rx: WatchStream<StreamerState>,
111
112    /// A channel sender to send range requests to the background task
113    ///
114    /// Contract: All ranges sent must be inside the range of the memory map
115    request_tx: tokio::sync::mpsc::Sender<Range<u64>>,
116
117    /// An optional object to reserve a slot in the `request_tx` sender. When in the process of
118    /// sending a requests this contains an actual value.
119    poll_request_tx: Option<PollSender<Range<u64>>>,
120}
121
/// Selects how the initial request is performed.
///
/// We support either directly requesting N bytes from the end of the file or, if the server does
/// not support negative byte offsets, starting with a HEAD request instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckSupportMethod {
    /// Perform a range request with a negative byte range. This will return the N bytes from the
    /// *end* of the file as well as the file-size. This is especially useful to also immediately
    /// get some bytes from the end of the file.
    NegativeRangeRequest(u64),

    /// Perform a head request to get the length of the file and check if the server supports range
    /// requests.
    Head,
}
135
136fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
137    response
138        .error_for_status()
139        .map_err(reqwest_middleware::Error::Reqwest)
140}
141
142impl AsyncHttpRangeReader {
143    /// Construct a new `AsyncHttpRangeReader`.
144    pub async fn new(
145        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
146        url: Url,
147        check_method: CheckSupportMethod,
148        extra_headers: HeaderMap,
149    ) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
150        let client = client.into();
151        match check_method {
152            CheckSupportMethod::NegativeRangeRequest(initial_chunk_size) => {
153                let response = Self::initial_tail_request(
154                    client.clone(),
155                    url.clone(),
156                    initial_chunk_size,
157                    HeaderMap::default(),
158                )
159                .await?;
160                let response_headers = response.headers().clone();
161                let self_ = Self::from_range_response(client, response, url, extra_headers).await?;
162                Ok((self_, response_headers))
163            }
164            CheckSupportMethod::Head => {
165                let response =
166                    Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
167                        .await?;
168                let response_headers = response.headers().clone();
169                let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
170                Ok((self_, response_headers))
171            }
172        }
173    }
174
175    /// Send an initial range request to determine if the remote accepts range
176    /// requests. This will return a number of bytes from the end of the stream. Use the
177    /// `initial_chunk_size` parameter to define how many bytes should be requested from the end.
178    pub async fn initial_tail_request(
179        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
180        url: reqwest::Url,
181        initial_chunk_size: u64,
182        extra_headers: HeaderMap,
183    ) -> Result<Response, AsyncHttpRangeReaderError> {
184        let client = client.into();
185        let tail_response = client
186            .get(url)
187            .header(
188                reqwest::header::RANGE,
189                format!("bytes=-{initial_chunk_size}"),
190            )
191            .headers(extra_headers)
192            .send()
193            .await
194            .and_then(error_for_status)
195            .map_err(Arc::new)
196            .map_err(AsyncHttpRangeReaderError::HttpError)?;
197        Ok(tail_response)
198    }
199
200    #[deprecated(note = "use `from_range_response` instead")]
201    pub async fn from_tail_response(
202        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
203        tail_request_response: Response,
204        url: Url,
205        extra_headers: HeaderMap,
206    ) -> Result<Self, AsyncHttpRangeReaderError> {
207        Self::from_range_response(client, tail_request_response, url, extra_headers).await
208    }
209
210    /// Initialize the reader from [`AsyncHttpRangeReader::initial_tail_request`] (or a user
211    /// provided range response)
212    pub async fn from_range_response(
213        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
214        response: Response,
215        url: Url,
216        extra_headers: HeaderMap,
217    ) -> Result<Self, AsyncHttpRangeReaderError> {
218        let client = client.into();
219
220        // Get the size of the file from this initial request
221        let content_range_header = response
222            .headers()
223            .get(reqwest::header::CONTENT_RANGE)
224            .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
225            .to_str()
226            .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
227        // The parser ensures finish < complete_length
228        let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
229            AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
230        })?;
231        let (start, end_inclusive, complete_length) = match content_range {
232            ContentRange::Bytes(ContentRangeBytes {
233                first_byte,
234                last_byte,
235                complete_length,
236            }) => (first_byte, last_byte, complete_length),
237            _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
238        };
239
240        // Allocate a memory map to hold the data
241        let memory_map = memmap2::MmapOptions::new()
242            .len(complete_length as usize)
243            .map_anon()
244            .map_err(Arc::new)
245            .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
246
247        // SAFETY: Get a read-only slice to the memory. This is safe because the memory map is never
248        // reallocated and we keep track of the initialized part.
249        let memory_map_slice =
250            unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
251
252        let requested_range = SparseRange::from_range(start..end_inclusive + 1);
253
254        // adding more than 2 entries to the channel would block the sender. I assumed two would
255        // suffice because I would want to 1) prefetch a certain range and 2) read stuff via the
256        // AsyncRead implementation. Any extra would simply have to wait for one of these to
257        // succeed. I eventually used 10 because who cares.
258        let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
259        let (state_tx, state_rx) = watch::channel(StreamerState::default());
260        tokio::spawn(run_streamer(
261            client,
262            url,
263            extra_headers,
264            Some((response, start, end_inclusive + 1)),
265            memory_map,
266            state_tx,
267            request_rx,
268        ));
269
270        // Configure the initial state of the streamer.
271        let mut streamer_state = StreamerState::default();
272        streamer_state
273            .requested_ranges
274            .push(start..end_inclusive + 1);
275
276        let reader = Self {
277            len: memory_map_slice.len() as u64,
278            inner: Mutex::new(Inner {
279                data: memory_map_slice,
280                pos: 0,
281                requested_range,
282                streamer_state,
283                streamer_state_rx: WatchStream::new(state_rx),
284                request_tx,
285                poll_request_tx: None,
286            }),
287        };
288        Ok(reader)
289    }
290
291    /// Send an initial range request to determine if the remote accepts range
292    /// requests and get the content length
293    pub async fn initial_head_request(
294        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
295        url: reqwest::Url,
296        extra_headers: HeaderMap,
297    ) -> Result<Response, AsyncHttpRangeReaderError> {
298        let client = client.into();
299
300        // Perform a HEAD request to get the content-length.
301        let head_response = client
302            .head(url.clone())
303            .headers(extra_headers)
304            .send()
305            .await
306            .and_then(error_for_status)
307            .map_err(Arc::new)
308            .map_err(AsyncHttpRangeReaderError::HttpError)?;
309        Ok(head_response)
310    }
311
312    /// Initialize the reader from [`AsyncHttpRangeReader::initial_head_request`] (or a user
313    /// provided response)
314    pub async fn from_head_response(
315        client: impl Into<reqwest_middleware::ClientWithMiddleware>,
316        head_response: Response,
317        url: Url,
318        extra_headers: HeaderMap,
319    ) -> Result<Self, AsyncHttpRangeReaderError> {
320        let client = client.into();
321
322        // Are range requests supported?
323        if head_response
324            .headers()
325            .get(reqwest::header::ACCEPT_RANGES)
326            .and_then(|h| h.to_str().ok())
327            != Some("bytes")
328        {
329            return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
330        }
331
332        let content_length: u64 = head_response
333            .headers()
334            .get(reqwest::header::CONTENT_LENGTH)
335            .ok_or(AsyncHttpRangeReaderError::ContentLengthMissing)?
336            .to_str()
337            .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?
338            .parse()
339            .map_err(|_err| AsyncHttpRangeReaderError::ContentLengthMissing)?;
340
341        // Allocate a memory map to hold the data
342        let memory_map = memmap2::MmapOptions::new()
343            .len(content_length as _)
344            .map_anon()
345            .map_err(Arc::new)
346            .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
347
348        // SAFETY: Get a read-only slice to the memory. This is safe because the memory map is never
349        // reallocated and we keep track of the initialized part.
350        let memory_map_slice =
351            unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
352
353        let requested_range = SparseRange::default();
354
355        // adding more than 2 entries to the channel would block the sender. I assumed two would
356        // suffice because I would want to 1) prefetch a certain range and 2) read stuff via the
357        // AsyncRead implementation. Any extra would simply have to wait for one of these to
358        // succeed. I eventually used 10 because who cares.
359        let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
360        let (state_tx, state_rx) = watch::channel(StreamerState::default());
361        tokio::spawn(run_streamer(
362            client,
363            url,
364            extra_headers,
365            None,
366            memory_map,
367            state_tx,
368            request_rx,
369        ));
370
371        // Configure the initial state of the streamer.
372        let streamer_state = StreamerState::default();
373
374        let reader = Self {
375            len: memory_map_slice.len() as u64,
376            inner: Mutex::new(Inner {
377                data: memory_map_slice,
378                pos: 0,
379                requested_range,
380                streamer_state,
381                streamer_state_rx: WatchStream::new(state_rx),
382                request_tx,
383                poll_request_tx: None,
384            }),
385        };
386        Ok(reader)
387    }
388
389    /// Returns the ranges that this instance actually performed HTTP requests for.
390    pub async fn requested_ranges(&self) -> Vec<Range<u64>> {
391        let mut inner = self.inner.lock().await;
392        if let Some(Some(new_state)) = inner.streamer_state_rx.next().now_or_never() {
393            inner.streamer_state = new_state;
394        }
395        inner.streamer_state.requested_ranges.clone()
396    }
397
398    /// Prefetches a range of bytes from the remote. When specifying a large range this can
399    /// drastically reduce the number of requests required to the server.
400    pub async fn prefetch(&mut self, bytes: Range<u64>) {
401        let inner = self.inner.get_mut();
402
403        // Ensure the range is withing the file size and non-zero of length.
404        let range = bytes.start..(bytes.end.min(inner.data.len() as u64));
405        if range.start >= range.end {
406            return;
407        }
408
409        // Check if the range has been requested or not.
410        let inner = self.inner.get_mut();
411        if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) {
412            let _ = inner.request_tx.send(range).await;
413            inner.requested_range = new_range;
414        }
415    }
416
417    /// Returns the length of the stream in bytes
418    #[allow(clippy::len_without_is_empty)]
419    pub fn len(&self) -> u64 {
420        self.len
421    }
422}
423
424/// A task that will download parts from the remote archive and "send" them to the frontend as they
425/// become available.
426#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
427async fn run_streamer(
428    client: reqwest_middleware::ClientWithMiddleware,
429    url: Url,
430    extra_headers: HeaderMap,
431    response: Option<(Response, u64, u64)>,
432    mut memory_map: MmapMut,
433    mut state_tx: Sender<StreamerState>,
434    mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>,
435) {
436    let mut state = StreamerState::default();
437
438    if let Some((response, start, end_exclusive)) = response {
439        // Add the initial range to the state
440        state.requested_ranges.push(start..end_exclusive);
441
442        // Stream the initial data in memory
443        if !stream_response(
444            response,
445            start,
446            end_exclusive,
447            &mut memory_map,
448            &mut state_tx,
449            &mut state,
450        )
451        .await
452        {
453            return;
454        }
455    }
456
457    // Listen for any new incoming requests
458    'outer: loop {
459        let range = match request_rx.recv().await {
460            Some(range) => range,
461            None => {
462                break 'outer;
463            }
464        };
465
466        // Determine the range that we need to cover
467        let uncovered_ranges = match state.resident_range.cover(range) {
468            None => continue,
469            Some((_, uncovered_ranges)) => uncovered_ranges,
470        };
471
472        // Download and stream each range.
473        for range in uncovered_ranges {
474            // Update the requested ranges
475            state
476                .requested_ranges
477                .push(*range.start()..*range.end() + 1);
478
479            // Execute the request
480            let range_string = format!("bytes={}-{}", range.start(), range.end());
481            let span = info_span!("fetch_range", range = range_string.as_str());
482            let response = match client
483                .get(url.clone())
484                .header(reqwest::header::RANGE, range_string)
485                .headers(extra_headers.clone())
486                .send()
487                .instrument(span)
488                .await
489                .and_then(error_for_status)
490                .map_err(std::io::Error::other)
491            {
492                Err(e) => {
493                    state.error = Some(e.into());
494                    let _ = state_tx.send(state);
495                    break 'outer;
496                }
497                Ok(response) => response,
498            };
499
500            if let Err(err) =
501                validate_content_range(&response, *range.start(), *range.end(), memory_map.len())
502            {
503                state.error = Some(err);
504                let _ = state_tx.send(state);
505                break 'outer;
506            }
507
508            // If the server returns a successful, but non-206 response (e.g., 200), then it
509            // doesn't support range requests (even if the `Accept-Ranges` header is set).
510            if response.status() != reqwest::StatusCode::PARTIAL_CONTENT {
511                state.error = Some(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
512                let _ = state_tx.send(state);
513                break 'outer;
514            }
515
516            if !stream_response(
517                response,
518                *range.start(),
519                *range.end() + 1,
520                &mut memory_map,
521                &mut state_tx,
522                &mut state,
523            )
524            .await
525            {
526                break 'outer;
527            }
528        }
529    }
530}
531
532/// Ensure that the response range headers match the request range headers
533fn validate_content_range(
534    response: &Response,
535    expected_start: u64,
536    expected_end_inclusive: u64,
537    expected_complete_length: usize,
538) -> Result<(), AsyncHttpRangeReaderError> {
539    let content_range_header = response
540        .headers()
541        .get(reqwest::header::CONTENT_RANGE)
542        .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
543        .to_str()
544        .map_err(|_err| AsyncHttpRangeReaderError::ContentRangeMissing)?;
545    let content_range = ContentRange::parse(content_range_header).ok_or_else(|| {
546        AsyncHttpRangeReaderError::ContentRangeParser(content_range_header.to_string())
547    })?;
548    let (actual_start, actual_end_inclusive, actual_complete_length) = match content_range {
549        ContentRange::Bytes(ContentRangeBytes {
550            first_byte,
551            last_byte,
552            complete_length,
553        }) => (first_byte, last_byte, complete_length),
554        _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
555    };
556    if expected_start != actual_start
557        || expected_end_inclusive != actual_end_inclusive
558        || expected_complete_length as u64 != actual_complete_length
559    {
560        return Err(AsyncHttpRangeReaderError::RangeMismatch {
561            expected_start,
562            expected_end_inclusive,
563            expected_complete_length,
564            actual_start,
565            actual_end_inclusive,
566            actual_complete_length,
567        });
568    }
569
570    Ok(())
571}
572
573/// Streams the data from the specified response to the memory map updating progress in between.
574/// Returns `true` if everything went fine, `false` if anything went wrong. The error state, if any,
575/// is stored in `state_tx` so the "frontend" will consume it.
576///
577/// The response must return bytes for the range of precisely `start..end_exclusive`.
578async fn stream_response(
579    tail_request_response: Response,
580    start: u64,
581    end_exclusive: u64,
582    memory_map: &mut MmapMut,
583    state_tx: &mut Sender<StreamerState>,
584    state: &mut StreamerState,
585) -> bool {
586    // Enforce request channel contract
587    assert!(
588        (end_exclusive as usize) <= memory_map.len(),
589        "end is outside of memory map {} > {}",
590        end_exclusive,
591        memory_map.len()
592    );
593
594    let mut offset = start;
595    let mut byte_stream = tail_request_response.bytes_stream();
596    while let Some(bytes) = byte_stream.next().await {
597        let bytes = match bytes {
598            Err(e) => {
599                state.error = Some(e.into());
600                let _ = state_tx.send(state.clone());
601                return false;
602            }
603            Ok(bytes) => bytes,
604        };
605
606        // Determine the range of these bytes in the complete file
607        let byte_range = offset..offset + bytes.len() as u64;
608
609        // Update the offset
610        offset += bytes.len() as u64;
611
612        // Prevent the server from sending more bytes than advertised in a response
613        if offset > end_exclusive {
614            state.error = Some(AsyncHttpRangeReaderError::ResponseTooLong {
615                expected: end_exclusive - start,
616            });
617            let _ = state_tx.send(state.clone());
618            return false;
619        }
620
621        // Copy the data from the stream to memory
622        memory_map[byte_range.start as usize..byte_range.end as usize]
623            .copy_from_slice(bytes.as_ref());
624
625        // Update the range of bytes that have been downloaded
626        state.resident_range.update(byte_range);
627
628        // Notify anyone that's listening that we have downloaded some extra data
629        if state_tx.send(state.clone()).is_err() {
630            // If we failed to set the state it means there is no receiver. In that case we should
631            // just exit.
632            return false;
633        }
634    }
635
636    // Prevent the server from sending less bytes than advertised in a response
637    if offset != end_exclusive {
638        state.error = Some(AsyncHttpRangeReaderError::ResponseTooShort {
639            expected: end_exclusive - start,
640            actual: offset - start,
641        });
642        let _ = state_tx.send(state.clone());
643        return false;
644    }
645
646    true
647}
648
649impl AsyncSeek for AsyncHttpRangeReader {
650    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
651        let me = self.get_mut();
652        let inner = me.inner.get_mut();
653
654        inner.pos = match position {
655            SeekFrom::Start(pos) => pos,
656            SeekFrom::End(relative) => (inner.data.len() as i64).saturating_add(relative) as u64,
657            SeekFrom::Current(relative) => (inner.pos as i64).saturating_add(relative) as u64,
658        };
659
660        Ok(())
661    }
662
663    fn poll_complete(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
664        let inner = self.inner.get_mut();
665        Poll::Ready(Ok(inner.pos))
666    }
667}
668
669impl AsyncRead for AsyncHttpRangeReader {
670    fn poll_read(
671        self: Pin<&mut Self>,
672        cx: &mut Context<'_>,
673        buf: &mut ReadBuf<'_>,
674    ) -> Poll<io::Result<()>> {
675        let me = self.get_mut();
676        let inner = me.inner.get_mut();
677
678        // If a previous error occurred we return that.
679        if let Some(e) = inner.streamer_state.error.as_ref() {
680            return Poll::Ready(Err(io::Error::other(e.clone())));
681        }
682
683        // Determine the range to be fetched
684        let range = inner.pos..(inner.pos + buf.remaining() as u64).min(inner.data.len() as u64);
685        if range.start >= range.end {
686            return Poll::Ready(Ok(()));
687        }
688
689        // Ensure we requested the required bytes
690        while !inner.requested_range.is_covered(range.clone()) {
691            // If there is an active range request wait for it to complete
692            if let Some(mut poll) = inner.poll_request_tx.take() {
693                match poll.poll_reserve(cx) {
694                    Poll::Ready(_) => {
695                        let _ = poll.send_item(range.clone());
696                        inner.requested_range.update(range.clone());
697                        break;
698                    }
699                    Poll::Pending => {
700                        inner.poll_request_tx = Some(poll);
701                        return Poll::Pending;
702                    }
703                }
704            }
705
706            // Request the range
707            inner.poll_request_tx = Some(PollSender::new(inner.request_tx.clone()));
708        }
709
710        // If there is still a request poll open but there is no need for a request, abort it.
711        if let Some(mut poll) = inner.poll_request_tx.take() {
712            poll.abort_send();
713        }
714
715        loop {
716            // Is the range already available?
717            if inner
718                .streamer_state
719                .resident_range
720                .is_covered(range.clone())
721            {
722                let len = (range.end - range.start) as usize;
723                buf.initialize_unfilled_to(len)
724                    .copy_from_slice(&inner.data[range.start as usize..range.end as usize]);
725                buf.advance(len);
726                inner.pos += len as u64;
727                return Poll::Ready(Ok(()));
728            }
729
730            // Otherwise wait for new data to come in
731            match ready!(Pin::new(&mut inner.streamer_state_rx).poll_next(cx)) {
732                None => unreachable!(),
733                Some(state) => {
734                    inner.streamer_state = state;
735                    if let Some(e) = inner.streamer_state.error.as_ref() {
736                        return Poll::Ready(Err(io::Error::other(e.clone())));
737                    }
738                }
739            }
740        }
741    }
742}
743
744#[cfg(test)]
745mod static_directory_server;
746
747#[cfg(test)]
748mod test {
749    use super::*;
750    use crate::static_directory_server::StaticDirectoryServer;
751    use assert_matches::assert_matches;
752    use async_zip::tokio::read::seek::ZipFileReader;
753    use axum::body::Body;
754    use axum::extract::Request;
755    use axum::response::IntoResponse;
756    use futures::AsyncReadExt;
757    use reqwest::header;
758    use reqwest::Method;
759    use reqwest::{Client, StatusCode};
760    use rstest::*;
761    use std::path::Path;
762    use tokio::io::AsyncReadExt as _;
763
764    #[rstest]
765    #[case(CheckSupportMethod::Head)]
766    #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
767    #[tokio::test]
768    async fn async_range_reader_zip(#[case] check_method: CheckSupportMethod) {
769        // Spawn a static file server
770        let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
771        let server = StaticDirectoryServer::new(&path)
772            .await
773            .expect("could not initialize server");
774
775        // check that file is there and has the right size
776        let filepath = path.join("andes-1.8.3-pyhd8ed1ab_0.conda");
777        assert!(
778            filepath.exists(),
779            "The conda package is not there yet. Did you run `git lfs pull`?"
780        );
781        let file_size = std::fs::metadata(&filepath).unwrap().len();
782        assert_eq!(
783            file_size, 2_463_995,
784            "The conda package is not there yet. Did you run `git lfs pull`?"
785        );
786
787        // Construct an AsyncRangeReader
788        let (mut range, _) = AsyncHttpRangeReader::new(
789            Client::new(),
790            server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
791            check_method,
792            HeaderMap::default(),
793        )
794        .await
795        .expect("Could not download range - did you run `git lfs pull`?");
796
797        // Make sure we have read the last couple of bytes
798        range.prefetch(range.len() - 8192..range.len()).await;
799
800        assert_eq!(range.len(), file_size);
801
802        let mut reader = ZipFileReader::with_tokio(tokio::io::BufReader::with_capacity(0, range))
803            .await
804            .unwrap();
805
806        assert_eq!(
807            reader
808                .file()
809                .entries()
810                .iter()
811                .map(|e| e.filename().as_str().unwrap_or(""))
812                .collect::<Vec<_>>(),
813            vec![
814                "metadata.json",
815                "info-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
816                "pkg-andes-1.8.3-pyhd8ed1ab_0.tar.zst",
817            ]
818        );
819
820        // Get the number of performed requests so far
821        let request_ranges = reader
822            .inner_mut()
823            .get_mut()
824            .get_mut()
825            .requested_ranges()
826            .await;
827        assert_eq!(request_ranges.len(), 1);
828        assert_eq!(
829            request_ranges[0].end - request_ranges[0].start,
830            8192,
831            "first request should be the size of the initial chunk size"
832        );
833        assert_eq!(
834            request_ranges[0].end, file_size,
835            "first request should be at the end"
836        );
837
838        // Prefetch the data for the metadata.json file
839        let entry = reader.file().entries().first().unwrap();
840        let offset = entry.header_offset();
841        // Get the size of the entry plus the header + size of the filename. We should also actually
842        // include bytes for the extra fields but we don't have that information.
843        let size = entry.compressed_size() + 30 + entry.filename().as_bytes().len() as u64;
844
845        // The zip archive uses as BufReader which reads in chunks of 8192. To ensure we prefetch
846        // enough data we round the size up to the nearest multiple of the buffer size.
847        let buffer_size = 8192;
848        let size = size.div_ceil(buffer_size) * buffer_size;
849
850        // Fetch the bytes from the zip archive that contain the requested file.
851        reader
852            .inner_mut()
853            .get_mut()
854            .get_mut()
855            .prefetch(offset..offset + size as u64)
856            .await;
857
858        // Read the contents of the metadata.json file
859        let mut contents = String::new();
860        reader
861            .reader_with_entry(0)
862            .await
863            .unwrap()
864            .read_to_string(&mut contents)
865            .await
866            .unwrap();
867
868        // Get the number of performed requests
869        let request_ranges = reader
870            .inner_mut()
871            .get_mut()
872            .get_mut()
873            .requested_ranges()
874            .await;
875
876        assert_eq!(contents, r#"{"conda_pkg_format_version": 2}"#);
877        assert_eq!(request_ranges.len(), 2);
878        assert_eq!(
879            request_ranges[1],
880            0..size,
881            "expected only two range requests"
882        );
883    }
884
885    #[rstest]
886    #[case(CheckSupportMethod::Head)]
887    #[case(CheckSupportMethod::NegativeRangeRequest(8192))]
888    #[tokio::test]
889    async fn async_range_reader(#[case] check_method: CheckSupportMethod) {
890        // Spawn a static file server
891        let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
892        let server = StaticDirectoryServer::new(&path)
893            .await
894            .expect("could not initialize server");
895
896        // Construct an AsyncRangeReader
897        let (mut range, _) = AsyncHttpRangeReader::new(
898            Client::new(),
899            server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
900            check_method,
901            HeaderMap::default(),
902        )
903        .await
904        .expect("bla");
905
906        // Also open a simple file reader
907        let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
908            .await
909            .unwrap();
910
911        // Read until the end and make sure that the contents matches
912        let mut range_read = vec![0; 64 * 1024];
913        let mut file_read = vec![0; 64 * 1024];
914        loop {
915            // Read with the async reader
916            let range_read_bytes = range.read(&mut range_read).await.unwrap();
917
918            // Read directly from the file
919            let file_read_bytes = file
920                .read_exact(&mut file_read[0..range_read_bytes])
921                .await
922                .unwrap();
923
924            assert_eq!(range_read_bytes, file_read_bytes);
925            assert_eq!(
926                range_read[0..range_read_bytes],
927                file_read[0..file_read_bytes]
928            );
929
930            if file_read_bytes == 0 && range_read_bytes == 0 {
931                break;
932            }
933        }
934    }
935
936    #[tokio::test]
937    async fn test_not_found() {
938        let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR")))
939            .await
940            .expect("could not initialize server");
941        let err = AsyncHttpRangeReader::new(
942            Client::new(),
943            server.url().join("not-found").unwrap(),
944            CheckSupportMethod::Head,
945            HeaderMap::default(),
946        )
947        .await
948        .expect_err("expected an error");
949
950        assert_matches!(
951            err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND)
952        );
953    }
954
955    /// Spawn a server where the HEAD response reports `head_size` bytes, and range requests always
956    /// claim to be `pretend_size` bytes, while actually serving `actual_size`.
957    async fn spawn_mismatch_server(
958        head_content_length: usize,
959        pretend_size: usize,
960        actual_size: usize,
961    ) -> Url {
962        let app =
963            axum::Router::new().fallback(async move |request: Request| match *request.method() {
964                Method::HEAD => {
965                    let headers = [
966                        (header::CONTENT_LENGTH, head_content_length.to_string()),
967                        (header::ACCEPT_RANGES, "bytes".to_string()),
968                    ];
969                    (StatusCode::OK, headers).into_response()
970                }
971                Method::GET => {
972                    let range_header = request
973                        .headers()
974                        .get(header::RANGE)
975                        .unwrap()
976                        .to_str()
977                        .unwrap()
978                        .to_string();
979
980                    let range_spec = range_header.strip_prefix("bytes=").unwrap();
981                    let (start_str, _end_str) = range_spec.split_once('-').unwrap();
982                    let start = start_str.parse::<usize>().unwrap();
983                    // The end is inclusive
984                    let end = start + pretend_size - 1;
985
986                    axum::response::Response::builder()
987                        .status(StatusCode::PARTIAL_CONTENT)
988                        // Note that the client ignores this value currently, it only checks the
989                        // actual size
990                        .header(
991                            header::CONTENT_RANGE,
992                            format!("bytes {start}-{end}/{head_content_length}"),
993                        )
994                        .body(Body::from(vec![1u8; actual_size]))
995                        .unwrap()
996                        .into_response()
997                }
998                _ => StatusCode::METHOD_NOT_ALLOWED.into_response(),
999            });
1000
1001        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1002        let local_addr = listener.local_addr().unwrap();
1003        tokio::spawn(async move {
1004            axum::serve(listener, app.into_make_service())
1005                .await
1006                .unwrap();
1007        });
1008
1009        Url::parse(&format!("http://localhost:{}/file", local_addr.port())).unwrap()
1010    }
1011
1012    /// HEAD says 512 bytes, but range responses return 1024 bytes — overflows
1013    /// the memory map.
1014    #[tokio::test]
1015    async fn test_content_length_response_beyond_content_length() {
1016        /// Extract the [`AsyncHttpRangeReaderError`] from an `io::Error` returned by `read`.
1017        fn into_range_error(err: std::io::Error) -> AsyncHttpRangeReaderError {
1018            err.into_inner()
1019                .unwrap()
1020                .downcast::<AsyncHttpRangeReaderError>()
1021                .map(|e| *e)
1022                .unwrap()
1023        }
1024
1025        let cases: Vec<(usize, usize, usize, Option<AsyncHttpRangeReaderError>)> = vec![
1026            // Baseline
1027            (512, 512, 512, None),
1028            // The requested and declared length is 512, while the actual content is 1024
1029            (
1030                512,
1031                512,
1032                1024,
1033                Some(AsyncHttpRangeReaderError::ResponseTooLong { expected: 512 }),
1034            ),
1035            // The declared total length is 512, but it says and sends a range of 1024
1036            (
1037                512,
1038                1024,
1039                1024,
1040                Some(AsyncHttpRangeReaderError::ContentRangeParser(
1041                    "bytes 0-1023/512".to_string(),
1042                )),
1043            ),
1044            // The declared total length is 512, but it says a range of 1024
1045            (
1046                512,
1047                1024,
1048                512,
1049                Some(AsyncHttpRangeReaderError::ContentRangeParser(
1050                    "bytes 0-1023/512".to_string(),
1051                )),
1052            ),
1053            // Baseline
1054            (1024, 512, 512, None),
1055            // We requested 512, but we're getting 1024
1056            (
1057                1024,
1058                512,
1059                1024,
1060                Some(AsyncHttpRangeReaderError::ResponseTooLong { expected: 512 }),
1061            ),
1062            // We requested 512, but we're getting 1024
1063            (
1064                1024,
1065                1024,
1066                1024,
1067                Some(AsyncHttpRangeReaderError::RangeMismatch {
1068                    expected_start: 0,
1069                    expected_end_inclusive: 511,
1070                    expected_complete_length: 1024,
1071                    actual_start: 0,
1072                    actual_end_inclusive: 1023,
1073                    actual_complete_length: 1024,
1074                }),
1075            ),
1076            // We requested 512, but the header says 1024
1077            (
1078                1024,
1079                1024,
1080                512,
1081                Some(AsyncHttpRangeReaderError::RangeMismatch {
1082                    expected_start: 0,
1083                    expected_end_inclusive: 511,
1084                    expected_complete_length: 1024,
1085                    actual_start: 0,
1086                    actual_end_inclusive: 1023,
1087                    actual_complete_length: 1024,
1088                }),
1089            ),
1090        ];
1091        for (head_content_length, range_header_length, range_actual_length, expected_error) in cases
1092        {
1093            let url = spawn_mismatch_server(
1094                head_content_length,
1095                range_header_length,
1096                range_actual_length,
1097            )
1098            .await;
1099
1100            let (mut reader, _) = AsyncHttpRangeReader::new(
1101                Client::new(),
1102                url,
1103                CheckSupportMethod::Head,
1104                HeaderMap::default(),
1105            )
1106            .await
1107            .unwrap();
1108
1109            assert_eq!(reader.len(), head_content_length as u64);
1110            reader.prefetch(0..512).await;
1111
1112            let mut buf = vec![0u8; 512];
1113            let result = reader.read(&mut buf).await;
1114            let label =
1115                format!("{head_content_length} {range_header_length} {range_actual_length}");
1116            match expected_error {
1117                None => {
1118                    assert_matches!(result, Ok(_), "{label}");
1119                }
1120                Some(expected) => {
1121                    // The nested error don't support `PartialEq`
1122                    assert_eq!(
1123                        into_range_error(result.unwrap_err()).to_string(),
1124                        expected.to_string(),
1125                        "{label}"
1126                    );
1127                }
1128            }
1129        }
1130    }
1131}