stream_download/source/
mod.rs

1//! Provides the [`SourceStream`] trait which abstracts over the transport used to
2//! stream remote content.
3
4use std::convert::Infallible;
5use std::error::Error;
6use std::fmt::Debug;
7use std::future;
8use std::io::{self, SeekFrom};
9use std::time::{Duration, Instant};
10
11use bytes::{BufMut, Bytes, BytesMut};
12use futures_util::{Future, Stream, StreamExt, TryStream};
13use handle::{
14    DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle,
15};
16use tokio::sync::mpsc;
17use tokio::task::yield_now;
18use tokio::time::timeout;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, error, instrument, trace, warn};
21
22use crate::storage::StorageWriter;
23use crate::{ProgressFn, ReconnectFn, Settings, StreamPhase, StreamState};
24
25pub(crate) mod handle;
26
/// Enum representing the final outcome of the stream.
///
/// A value of this type is handed to [`SourceStream::on_finish`] when the download task
/// stops. Note that a download which stops because of an error is also reported as
/// [`StreamOutcome::Completed`]; in that case the error itself is carried by the `result`
/// argument of `on_finish`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StreamOutcome {
    /// The stream completed naturally.
    Completed,
    /// The stream was cancelled by the user.
    CancelledByUser,
}
35
/// Represents a remote resource that can be streamed over the network. Streaming
/// over http is implemented via the [`HttpStream`](crate::http::HttpStream)
/// implementation if the `http` feature is enabled.
///
/// The implementation must also implement the
/// [Stream](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) trait,
/// yielding [`Bytes`] chunks on success.
pub trait SourceStream:
    TryStream<Ok = Bytes>
    + Stream<Item = Result<Self::Ok, Self::Error>>
    + Unpin
    + Send
    + Sync
    + Sized
    + 'static
{
    /// Parameters used to create the remote resource.
    type Params: Send;

    /// Error type returned when creating the stream fails.
    type StreamCreationError: DecodeError + Send;

    /// Creates an instance of the stream.
    ///
    /// # Errors
    ///
    /// Resolves to [`Self::StreamCreationError`] if the stream could not be created.
    fn create(
        params: Self::Params,
    ) -> impl Future<Output = Result<Self, Self::StreamCreationError>> + Send;

    /// Returns the size of the remote resource in bytes. The result should be `None`
    /// if the stream is infinite or doesn't have a known length.
    fn content_length(&self) -> Option<u64>;

    /// Seeks to a specific position in the stream. This method is only called if the
    /// requested range has not been downloaded, so this method should jump to the
    /// requested position in the stream as quickly as possible.
    ///
    /// The start value should be inclusive and the end value should be exclusive.
    /// The download task passes `None` for `end` when the content length is not known.
    fn seek_range(
        &mut self,
        start: u64,
        end: Option<u64>,
    ) -> impl Future<Output = io::Result<()>> + Send;

    /// Attempts to reconnect to the server when a failure occurs.
    ///
    /// The download task invokes this when waiting for the next chunk times out.
    /// `current_position` is the position of the storage writer at that moment.
    fn reconnect(&mut self, current_position: u64) -> impl Future<Output = io::Result<()>> + Send;

    /// Returns whether seeking is supported in the stream.
    /// If this method returns `false`, [`SourceStream::seek_range`] will never be invoked.
    fn supports_seek(&self) -> bool;

    /// Called when the stream finishes downloading.
    ///
    /// `result` is `Ok(())` when the download completed, the download error when it failed,
    /// or an error of kind [`io::ErrorKind::Interrupted`] when it was cancelled by the user.
    /// `outcome` distinguishes user cancellation from every other way of stopping.
    ///
    /// The default implementation ignores `outcome` and returns `result` unchanged.
    /// If the returned future resolves to an error, the download is marked as failed.
    fn on_finish(
        &mut self,
        result: io::Result<()>,
        #[expect(unused)] outcome: StreamOutcome,
    ) -> impl Future<Output = io::Result<()>> + Send {
        future::ready(result)
    }
}
93
/// Trait for decoding extra error information asynchronously.
///
/// This is required for [`SourceStream::StreamCreationError`] so that an implementation
/// can produce a richer description of a stream creation failure than its `Display` output.
pub trait DecodeError: Error + Send + Sized {
    /// Decodes extra error information.
    ///
    /// Consumes the error and resolves to a human-readable description. The default
    /// implementation resolves immediately to the error's `Display` representation.
    fn decode_error(self) -> impl Future<Output = String> + Send {
        future::ready(self.to_string())
    }
}
101
102impl DecodeError for Infallible {
103    async fn decode_error(self) -> String {
104        // This will never get called since it's infallible
105        String::new()
106    }
107}
108
/// Tells the download loop what to do after a chunk (or the end of the stream) was handled.
#[derive(PartialEq, Eq)]
enum DownloadAction {
    /// Keep polling the stream for more data.
    Continue,
    /// Nothing is left to download; leave the download loop.
    Complete,
}
114
/// State of the background download task: pulls chunks from a [`SourceStream`] and writes
/// them into a [`StorageWriter`], while tracking which byte ranges have been stored.
pub(crate) struct Source<S: SourceStream, W: StorageWriter> {
    /// Destination for the downloaded bytes.
    writer: W,
    /// Byte ranges that have been written so far (shared with the [`SourceHandle`]).
    downloaded: Downloaded,
    /// Marked as failed when the download ends with an error.
    download_status: DownloadStatus,
    /// Position that is currently being waited for, if any; checked after every write.
    requested_position: RequestedPosition,
    /// Notified when the requested position has been written or the stream is done.
    position_reached: PositionReached,
    /// Used to wait for a read when the writer cannot accept more data.
    notify_read: NotifyRead,
    /// Total length of the remote content, if known.
    content_length: Option<u64>,
    /// Sender half of the seek channel. Cloned into every [`SourceHandle`]; keeping a copy
    /// here guarantees the channel is never closed while the download loop runs.
    seek_tx: mpsc::Sender<u64>,
    /// Receiver half of the seek channel, polled by the download loop.
    seek_rx: mpsc::Receiver<u64>,
    /// Number of bytes to prefetch, counted from `prefetch_start_position`.
    prefetch_bytes: u64,
    /// Maximum number of bytes passed to the writer in a single `write` call.
    batch_write_size: usize,
    /// Upper bound on how long to wait for the next chunk and for a reconnect attempt.
    retry_timeout: Duration,
    /// Optional callback invoked with progress information.
    on_progress: Option<ProgressFn<S>>,
    /// Optional callback invoked after a reconnect attempt.
    on_reconnect: Option<ReconnectFn<S>>,
    /// Whether the current prefetch phase has finished. Starts out `true` when
    /// `prefetch_bytes` is zero.
    prefetch_complete: bool,
    /// Position at which the current prefetch phase started.
    prefetch_start_position: u64,
    /// Bytes of a prefetch chunk that could not be written yet; they are prepended to the
    /// next chunk once prefetching is over.
    remaining_bytes: Option<Bytes>,
    /// Token used to stop the download task; also passed to the callbacks.
    cancellation_token: CancellationToken,
}
135
136impl<S, W> Source<S, W>
137where
138    S: SourceStream<Error: Debug>,
139    W: StorageWriter,
140{
141    pub(crate) fn new(
142        writer: W,
143        content_length: Option<u64>,
144        settings: Settings<S>,
145        cancellation_token: CancellationToken,
146    ) -> Self {
147        // buffer size of 1 is fine here because we wait for the position to update after we send
148        // each request
149        let (seek_tx, seek_rx) = mpsc::channel(1);
150        Self {
151            writer,
152            downloaded: Downloaded::default(),
153            download_status: DownloadStatus::default(),
154            requested_position: RequestedPosition::default(),
155            position_reached: PositionReached::default(),
156            notify_read: NotifyRead::default(),
157            seek_tx,
158            seek_rx,
159            content_length,
160            prefetch_complete: settings.prefetch_bytes == 0,
161            prefetch_bytes: settings.prefetch_bytes,
162            batch_write_size: settings.batch_write_size,
163            retry_timeout: settings.retry_timeout,
164            on_progress: settings.on_progress,
165            on_reconnect: settings.on_reconnect,
166            prefetch_start_position: 0,
167            remaining_bytes: None,
168            cancellation_token,
169        }
170    }
171
172    #[instrument(skip_all)]
173    pub(crate) async fn download(&mut self, mut stream: S) {
174        let res = self.download_inner(&mut stream).await;
175        let (res, stream_res) = match res {
176            Ok(StreamOutcome::Completed) => (Ok(()), StreamOutcome::Completed),
177            Ok(StreamOutcome::CancelledByUser) => (
178                Err(io::Error::new(
179                    io::ErrorKind::Interrupted,
180                    "stream cancelled by user",
181                )),
182                StreamOutcome::CancelledByUser,
183            ),
184            Err(e) => (Err(e), StreamOutcome::Completed),
185        };
186        let res = stream.on_finish(res, stream_res).await;
187        if let Err(e) = res {
188            if stream_res == StreamOutcome::Completed {
189                error!("download failed: {e:?}");
190            }
191            self.download_status.set_failed();
192        }
193        self.signal_download_complete();
194    }
195
196    async fn download_inner(&mut self, stream: &mut S) -> io::Result<StreamOutcome> {
197        debug!("starting file download");
198        let download_start = std::time::Instant::now();
199
200        loop {
201            // Some streams may get stuck if the connection has a hiccup while waiting for the next
202            // chunk. Forcing the client to abort and retry may help in these cases.
203            let next_chunk = timeout(self.retry_timeout, stream.next());
204            tokio::select! {
205                position = self.seek_rx.recv() => {
206                    // seek_tx can't be dropped here since we keep a reference in this struct
207                    self.handle_seek(stream, position.expect("seek_tx dropped")).await?;
208                },
209                bytes = next_chunk => {
210                    let Ok(bytes) = bytes else {
211                        self.handle_reconnect(stream).await?;
212                        continue;
213                    };
214                    if self
215                        .handle_bytes(stream, bytes, download_start)
216                        .await?
217                        == DownloadAction::Complete
218                    {
219                        debug!(
220                            download_duration = format!("{:?}", download_start.elapsed()),
221                            "stream finished downloading"
222                        );
223                        break;
224                    }
225                }
226                () = self.cancellation_token.cancelled() => {
227                    debug!("received cancellation request, stopping download task");
228                    return Ok(StreamOutcome::CancelledByUser);
229                }
230            };
231        }
232        self.report_download_complete(stream, download_start)?;
233        Ok(StreamOutcome::Completed)
234    }
235
    /// Handles a seek request for `position`.
    ///
    /// Nothing happens unless [`Self::should_seek`] decides that the stream has to jump.
    /// Otherwise the prefetch state is updated and the stream and writer are moved to the
    /// start of the range that still needs to be downloaded.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from querying the writer position or from seeking.
    async fn handle_seek(&mut self, stream: &mut S, position: u64) -> io::Result<()> {
        if self.should_seek(stream, position)? {
            debug!("seek position not yet downloaded");
            let current_stream_position = self.writer.stream_position()?;
            if self.prefetch_complete {
                // Start a new prefetch phase at the seek target.
                debug!("re-starting prefetch");
                self.prefetch_start_position = position;
                self.prefetch_complete = false;
            } else {
                // A prefetch was in progress: record what it has written so far as downloaded
                // and end it.
                debug!("seeking during prefetch, ending prefetch early");
                self.downloaded
                    .add(self.prefetch_start_position..current_stream_position);
                self.prefetch_complete = true;
            }
            if let Some(content_length) = self.content_length {
                // Get the minimum possible start position to ensure we capture the entire range
                let min_start_position = current_stream_position.min(position);
                debug!(
                    start = min_start_position,
                    end = content_length,
                    "checking for seek range",
                );
                // NOTE(review): if no gap is found, no seek is issued even though the prefetch
                // state above has already been changed — confirm this is intended.
                if let Some(gap) = self.downloaded.next_gap(min_start_position..content_length) {
                    // Gap start may be too low if we're seeking forward, so check it against the
                    // position
                    let seek_start = gap.start.max(position);
                    debug!(seek_start, seek_end = gap.end, "requesting seek range");
                    self.seek(stream, seek_start, Some(gap.end)).await?;
                }
            } else {
                // Unknown length: seek to the requested position with an open-ended range.
                self.seek(stream, position, None).await?;
            }
        }
        Ok(())
    }
271
272    async fn handle_reconnect(&mut self, stream: &mut S) -> io::Result<()> {
273        warn!("timed out reading next chunk, retrying");
274        let pos = self.writer.stream_position()?;
275        // We already know there's a network issue if we're attempting a reconnect.
276        // A retry policy on the client may cause an exponential backoff to be triggered here, so
277        // we'll cap the reconnect time to prevent additional delays between reconnect attempts.
278        let reconnect_pos = tokio::time::timeout(self.retry_timeout, stream.reconnect(pos)).await;
279        if reconnect_pos
280            .inspect_err(|e| warn!("error attempting to reconnect: {e:?}"))
281            .is_ok()
282            && let Some(on_reconnect) = &mut self.on_reconnect
283        {
284            on_reconnect(stream, &self.cancellation_token);
285        }
286
287        Ok(())
288    }
289
290    async fn handle_prefetch(
291        &mut self,
292        stream: &mut S,
293        bytes: Option<Bytes>,
294        start_position: u64,
295        download_start: Instant,
296    ) -> io::Result<DownloadAction> {
297        let Some(bytes) = bytes else {
298            self.prefetch_complete = true;
299            debug!("file shorter than prefetch length, download finished");
300            self.writer.flush()?;
301            let position = self.writer.stream_position()?;
302            self.downloaded.add(start_position..position);
303
304            return self.finish_or_find_next_gap(stream).await;
305        };
306        let written = self.write_batched(&bytes).await?;
307        self.writer.flush()?;
308
309        let stream_position = self.writer.stream_position()?;
310        let partial_write = written < bytes.len();
311
312        // End prefetch early if we weren't able to write the entire contents
313        if partial_write {
314            debug!(
315                written,
316                bytes_len = bytes.len(),
317                "failed to write all during prefetch"
318            );
319            self.remaining_bytes = Some(bytes.slice(written..));
320        }
321        if (stream_position >= start_position + self.prefetch_bytes) || partial_write {
322            self.downloaded.add(start_position..stream_position);
323            debug!("prefetch complete");
324            self.prefetch_complete = true;
325        }
326
327        self.report_prefetch_progress(stream, stream_position, download_start, written);
328        Ok(DownloadAction::Continue)
329    }
330
331    async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadAction> {
332        if stream.supports_seek()
333            && let Some(content_length) = self.content_length
334        {
335            let gap = self.downloaded.next_gap(0..content_length);
336            if let Some(gap) = gap {
337                debug!(
338                    missing = format!("{gap:?}"),
339                    "downloading missing stream chunk"
340                );
341                self.seek(stream, gap.start, Some(gap.end)).await?;
342                return Ok(DownloadAction::Continue);
343            }
344        }
345        self.writer.flush()?;
346        self.signal_download_complete();
347        Ok(DownloadAction::Complete)
348    }
349
350    async fn write_batched(&mut self, bytes: &[u8]) -> io::Result<usize> {
351        let mut written = 0;
352        loop {
353            let write_size = self.batch_write_size.min(bytes[written..].len());
354            let batch_written = self.writer.write(&bytes[written..written + write_size])?;
355            if batch_written == 0 {
356                return Ok(written);
357            }
358            written += batch_written;
359            // yield between writes to ensure we don't spend too long on writes
360            // without an await point
361            yield_now().await;
362        }
363    }
364
365    async fn handle_bytes(
366        &mut self,
367        stream: &mut S,
368        bytes: Option<Result<Bytes, S::Error>>,
369        download_start: Instant,
370    ) -> io::Result<DownloadAction> {
371        let bytes = match bytes.transpose() {
372            Ok(bytes) => bytes,
373            Err(e) => {
374                error!("Error fetching chunk from stream: {e:?}");
375                return Ok(DownloadAction::Continue);
376            }
377        };
378
379        if !self.prefetch_complete {
380            return self
381                .handle_prefetch(stream, bytes, self.prefetch_start_position, download_start)
382                .await;
383        }
384
385        let bytes = match (self.remaining_bytes.take(), bytes) {
386            (Some(remaining), Some(bytes)) => {
387                let mut combined = BytesMut::new();
388                combined.put(remaining);
389                combined.put(bytes);
390                combined.freeze()
391            }
392            (Some(remaining), None) => remaining,
393            (None, Some(bytes)) => bytes,
394            (None, None) => {
395                return self.finish_or_find_next_gap(stream).await;
396            }
397        };
398        let bytes_len = bytes.len();
399        let new_position = self.write(bytes).await?;
400        self.report_downloading_progress(stream, new_position, download_start, bytes_len)?;
401
402        Ok(DownloadAction::Continue)
403    }
404
    /// Writes the whole of `bytes` to the writer and returns the writer position afterwards.
    ///
    /// The writer may accept only part of the data at a time; in that case this method waits
    /// for a read notification and tries again until everything has been written. After each
    /// attempt the downloaded range is extended and a pending requested position is
    /// acknowledged once the writer has reached it.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the writer.
    async fn write(&mut self, bytes: Bytes) -> io::Result<u64> {
        let mut written = 0;
        // Position before anything from this chunk was written; the downloaded range is always
        // recorded relative to it.
        let position = self.writer.stream_position()?;
        let mut new_position = position;
        // Keep writing until we process the whole buffer.
        // If the reader is falling behind, this may take several attempts.
        while written < bytes.len() {
            // Issued before the write attempt; waited on below if nothing could be written.
            self.notify_read.request();
            let new_written = self.write_batched(&bytes[written..]).await?;
            trace!(written, new_written, len = bytes.len(), "wrote data");

            if new_written > 0 {
                self.writer.flush()?;
                written += new_written;
            }
            new_position = self.writer.stream_position()?;
            if new_position > position {
                self.downloaded.add(position..new_position);
            }

            // Acknowledge a pending position request as soon as the writer has reached it.
            if let Some(requested) = self.requested_position.get() {
                debug!(
                    requested_position = requested,
                    current_position = new_position,
                    "received requested position"
                );

                if new_position >= requested {
                    debug!("notifying position reached");
                    self.requested_position.clear();
                    self.position_reached.notify_position_reached();
                }
            }
            if new_written == 0 {
                // We're not able to write any data, so we need to wait for space to be available
                debug!("waiting for next read");
                self.notify_read.wait_for_read().await;
                debug!("read finished");
            }

            trace!(
                previous_position = position,
                new_position,
                chunk_size = bytes.len(),
                "received response chunk"
            );
        }
        Ok(new_position)
    }
454
455    fn should_seek(&mut self, stream: &S, position: u64) -> io::Result<bool> {
456        if !stream.supports_seek() {
457            warn!("Attempting to seek, but it's unsupported. Waiting for stream to catch up.");
458            return Ok(false);
459        }
460        Ok(if let Some(range) = self.downloaded.get(position) {
461            !range.contains(&self.writer.stream_position()?)
462        } else {
463            true
464        })
465    }
466
    /// Moves both the stream and the writer to `start`.
    ///
    /// The stream is asked to serve the range `start..end` first; the writer is only
    /// repositioned if that succeeded.
    async fn seek(&mut self, stream: &mut S, start: u64, end: Option<u64>) -> io::Result<()> {
        stream.seek_range(start, end).await?;
        self.writer.seek(SeekFrom::Start(start))?;
        Ok(())
    }
472
    /// Notifies waiters on `position_reached` that the stream is done.
    fn signal_download_complete(&self) {
        self.position_reached.notify_stream_done();
    }
476
477    fn report_progress(&mut self, stream: &S, info: StreamState) {
478        if let Some(on_progress) = self.on_progress.as_mut() {
479            on_progress(stream, info, &self.cancellation_token);
480        }
481    }
482
483    fn report_prefetch_progress(
484        &mut self,
485        stream: &S,
486        stream_position: u64,
487        download_start: Instant,
488        chunk_size: usize,
489    ) {
490        self.report_progress(
491            stream,
492            StreamState {
493                current_position: stream_position,
494                current_chunk: (0..stream_position),
495                elapsed: download_start.elapsed(),
496                phase: StreamPhase::Prefetching {
497                    target: self.prefetch_bytes,
498                    chunk_size,
499                },
500            },
501        );
502    }
503
504    fn report_downloading_progress(
505        &mut self,
506        stream: &S,
507        new_position: u64,
508        download_start: Instant,
509        chunk_size: usize,
510    ) -> io::Result<()> {
511        let pos = self.writer.stream_position()?;
512        self.report_progress(
513            stream,
514            StreamState {
515                current_position: pos,
516                current_chunk: self
517                    .downloaded
518                    .get(new_position - 1)
519                    .expect("position already downloaded"),
520                elapsed: download_start.elapsed(),
521                phase: StreamPhase::Downloading { chunk_size },
522            },
523        );
524        Ok(())
525    }
526
527    fn report_download_complete(&mut self, stream: &S, download_start: Instant) -> io::Result<()> {
528        let pos = self.writer.stream_position()?;
529        self.report_progress(
530            stream,
531            StreamState {
532                current_position: pos,
533                elapsed: download_start.elapsed(),
534                // ensure no subtraction overflow
535                current_chunk: self.downloaded.get(pos.max(1) - 1).unwrap_or_default(),
536                phase: StreamPhase::Complete,
537            },
538        );
539        Ok(())
540    }
541
542    pub(crate) fn source_handle(&self) -> SourceHandle {
543        SourceHandle {
544            downloaded: self.downloaded.clone(),
545            download_status: self.download_status.clone(),
546            requested_position: self.requested_position.clone(),
547            notify_read: self.notify_read.clone(),
548            position_reached: self.position_reached.clone(),
549            seek_tx: self.seek_tx.clone(),
550            content_length: self.content_length,
551        }
552    }
553}