1use std::convert::Infallible;
5use std::error::Error;
6use std::fmt::Debug;
7use std::future;
8use std::io::{self, SeekFrom};
9use std::time::{Duration, Instant};
10
11use bytes::{BufMut, Bytes, BytesMut};
12use futures_util::{Future, Stream, StreamExt, TryStream};
13use handle::{
14 DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle,
15};
16use tokio::sync::mpsc;
17use tokio::task::yield_now;
18use tokio::time::timeout;
19use tokio_util::sync::CancellationToken;
20use tracing::{debug, error, instrument, trace, warn};
21
22use crate::storage::StorageWriter;
23use crate::{ProgressFn, ReconnectFn, Settings, StreamPhase, StreamState};
24
25pub(crate) mod handle;
26
/// How a download task ended, as passed to [`SourceStream::on_finish`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StreamOutcome {
    /// The download was not cancelled. This is reported both when the download
    /// finished and when it stopped on an error; the accompanying result tells
    /// the two apart.
    Completed,
    /// The download stopped because the cancellation token was triggered.
    CancelledByUser,
}
35
/// A stream of byte chunks that the download task can drive.
///
/// Implementors yield `Result<Bytes, Self::Error>` items and additionally expose
/// the repositioning and reconnection hooks the downloader calls.
pub trait SourceStream:
    TryStream<Ok = Bytes>
    + Stream<Item = Result<Self::Ok, Self::Error>>
    + Unpin
    + Send
    + Sync
    + Sized
    + 'static
{
    /// Input consumed by [`SourceStream::create`] to build the stream.
    type Params: Send;

    /// Error returned when [`SourceStream::create`] fails.
    type StreamCreationError: DecodeError + Send;

    /// Builds a new stream from `params`.
    fn create(
        params: Self::Params,
    ) -> impl Future<Output = Result<Self, Self::StreamCreationError>> + Send;

    /// Length of the content in bytes, or `None` when it is not known.
    // NOTE(review): presumably the total size of the remote resource; confirm with implementors.
    fn content_length(&self) -> Option<u64>;

    /// Repositions the stream so that subsequent items start at byte `start`.
    ///
    /// The downloader passes either `None` for `end` or the (exclusive) end of the
    /// missing range it wants to fill. It moves its writer to `start` only after
    /// this future resolves successfully.
    fn seek_range(
        &mut self,
        start: u64,
        end: Option<u64>,
    ) -> impl Future<Output = io::Result<()>> + Send;

    /// Re-establishes the stream after a read timed out.
    ///
    /// `current_position` is the writer's position at the time of the timeout.
    fn reconnect(&mut self, current_position: u64) -> impl Future<Output = io::Result<()>> + Send;

    /// Whether [`SourceStream::seek_range`] may be used. When this returns `false`
    /// the downloader never issues seeks and lets the stream catch up instead.
    fn supports_seek(&self) -> bool;

    /// Hook invoked once when the download task ends.
    ///
    /// `result` is the download result (an `Interrupted` error when cancelled) and
    /// `outcome` says whether the task was cancelled. The returned result decides
    /// whether the download is marked as failed. The default implementation
    /// returns `result` unchanged and ignores `outcome`.
    fn on_finish(
        &mut self,
        result: io::Result<()>,
        #[expect(unused)] outcome: StreamOutcome,
    ) -> impl Future<Output = io::Result<()>> + Send {
        future::ready(result)
    }
}
93
/// An error that can be turned into a human-readable message asynchronously.
pub trait DecodeError: Error + Send + Sized {
    /// Consumes the error and produces its message.
    ///
    /// The default implementation resolves immediately with the error's
    /// `to_string()` output.
    fn decode_error(self) -> impl Future<Output = String> + Send {
        future::ready(self.to_string())
    }
}
101
102impl DecodeError for Infallible {
103 async fn decode_error(self) -> String {
104 String::new()
106 }
107}
108
/// Tells the download loop what to do after a stream item has been handled.
#[derive(PartialEq, Eq)]
enum DownloadAction {
    /// Keep polling the stream for more data.
    Continue,
    /// Nothing is left to fetch; leave the download loop.
    Complete,
}
114
/// State of a single download task: pulls chunks from a [`SourceStream`] and
/// writes them into a [`StorageWriter`], tracking which byte ranges are present.
pub(crate) struct Source<S: SourceStream, W: StorageWriter> {
    /// Destination for downloaded bytes.
    writer: W,
    /// Byte ranges that have been written so far.
    downloaded: Downloaded,
    /// Shared status flag; marked failed when the download ends with an error.
    download_status: DownloadStatus,
    /// Position a reader is waiting for, if any.
    requested_position: RequestedPosition,
    /// Used to notify readers that a requested position was reached or that the
    /// download is done.
    position_reached: PositionReached,
    /// Used to wait for a read when the writer accepts no more data.
    notify_read: NotifyRead,
    /// Content length supplied at construction, if known.
    content_length: Option<u64>,
    /// Sender half of the seek channel; cloned into every `SourceHandle`.
    seek_tx: mpsc::Sender<u64>,
    /// Receiver half of the seek channel, polled by the download loop.
    seek_rx: mpsc::Receiver<u64>,
    /// Number of bytes to prefetch from the prefetch start position.
    prefetch_bytes: u64,
    /// Maximum number of bytes handed to the writer per `write` call.
    batch_write_size: usize,
    /// How long to wait for the next chunk (and for a reconnect) before giving up
    /// on that attempt.
    retry_timeout: Duration,
    /// Optional progress callback.
    on_progress: Option<ProgressFn<S>>,
    /// Optional callback invoked after a reconnect.
    on_reconnect: Option<ReconnectFn<S>>,
    /// `true` once the current prefetch phase is over.
    prefetch_complete: bool,
    /// Position at which the current prefetch phase started.
    prefetch_start_position: u64,
    /// Bytes the writer did not accept during prefetch, kept for the next write.
    remaining_bytes: Option<Bytes>,
    /// Token that stops the download task when cancelled.
    cancellation_token: CancellationToken,
}
135
136impl<S, W> Source<S, W>
137where
138 S: SourceStream<Error: Debug>,
139 W: StorageWriter,
140{
141 pub(crate) fn new(
142 writer: W,
143 content_length: Option<u64>,
144 settings: Settings<S>,
145 cancellation_token: CancellationToken,
146 ) -> Self {
147 let (seek_tx, seek_rx) = mpsc::channel(1);
150 Self {
151 writer,
152 downloaded: Downloaded::default(),
153 download_status: DownloadStatus::default(),
154 requested_position: RequestedPosition::default(),
155 position_reached: PositionReached::default(),
156 notify_read: NotifyRead::default(),
157 seek_tx,
158 seek_rx,
159 content_length,
160 prefetch_complete: settings.prefetch_bytes == 0,
161 prefetch_bytes: settings.prefetch_bytes,
162 batch_write_size: settings.batch_write_size,
163 retry_timeout: settings.retry_timeout,
164 on_progress: settings.on_progress,
165 on_reconnect: settings.on_reconnect,
166 prefetch_start_position: 0,
167 remaining_bytes: None,
168 cancellation_token,
169 }
170 }
171
172 #[instrument(skip_all)]
173 pub(crate) async fn download(&mut self, mut stream: S) {
174 let res = self.download_inner(&mut stream).await;
175 let (res, stream_res) = match res {
176 Ok(StreamOutcome::Completed) => (Ok(()), StreamOutcome::Completed),
177 Ok(StreamOutcome::CancelledByUser) => (
178 Err(io::Error::new(
179 io::ErrorKind::Interrupted,
180 "stream cancelled by user",
181 )),
182 StreamOutcome::CancelledByUser,
183 ),
184 Err(e) => (Err(e), StreamOutcome::Completed),
185 };
186 let res = stream.on_finish(res, stream_res).await;
187 if let Err(e) = res {
188 if stream_res == StreamOutcome::Completed {
189 error!("download failed: {e:?}");
190 }
191 self.download_status.set_failed();
192 }
193 self.signal_download_complete();
194 }
195
196 async fn download_inner(&mut self, stream: &mut S) -> io::Result<StreamOutcome> {
197 debug!("starting file download");
198 let download_start = std::time::Instant::now();
199
200 loop {
201 let next_chunk = timeout(self.retry_timeout, stream.next());
204 tokio::select! {
205 position = self.seek_rx.recv() => {
206 self.handle_seek(stream, position.expect("seek_tx dropped")).await?;
208 },
209 bytes = next_chunk => {
210 let Ok(bytes) = bytes else {
211 self.handle_reconnect(stream).await?;
212 continue;
213 };
214 if self
215 .handle_bytes(stream, bytes, download_start)
216 .await?
217 == DownloadAction::Complete
218 {
219 debug!(
220 download_duration = format!("{:?}", download_start.elapsed()),
221 "stream finished downloading"
222 );
223 break;
224 }
225 }
226 () = self.cancellation_token.cancelled() => {
227 debug!("received cancellation request, stopping download task");
228 return Ok(StreamOutcome::CancelledByUser);
229 }
230 };
231 }
232 self.report_download_complete(stream, download_start)?;
233 Ok(StreamOutcome::Completed)
234 }
235
    /// Handles a seek request for `position` received over the seek channel.
    ///
    /// Does nothing unless `should_seek` approves the request. Otherwise the
    /// prefetch state is toggled (a finished prefetch is restarted at `position`;
    /// an in-progress prefetch is ended and its written range recorded) and the
    /// stream is repositioned: to the next missing range when the content length
    /// is known, or straight to `position` when it is not.
    async fn handle_seek(&mut self, stream: &mut S, position: u64) -> io::Result<()> {
        if self.should_seek(stream, position)? {
            debug!("seek position not yet downloaded");
            let current_stream_position = self.writer.stream_position()?;
            if self.prefetch_complete {
                // Start a new prefetch phase at the requested position.
                debug!("re-starting prefetch");
                self.prefetch_start_position = position;
                self.prefetch_complete = false;
            } else {
                // Record what the interrupted prefetch wrote before leaving it.
                debug!("seeking during prefetch, ending prefetch early");
                self.downloaded
                    .add(self.prefetch_start_position..current_stream_position);
                self.prefetch_complete = true;
            }
            if let Some(content_length) = self.content_length {
                // Search from the lower of the two positions so the whole
                // candidate range is considered.
                let min_start_position = current_stream_position.min(position);
                debug!(
                    start = min_start_position,
                    end = content_length,
                    "checking for seek range",
                );
                if let Some(gap) = self.downloaded.next_gap(min_start_position..content_length) {
                    // The gap may begin before the requested position; never
                    // seek to a point earlier than what was asked for.
                    let seek_start = gap.start.max(position);
                    debug!(seek_start, seek_end = gap.end, "requesting seek range");
                    self.seek(stream, seek_start, Some(gap.end)).await?;
                }
            } else {
                self.seek(stream, position, None).await?;
            }
        }
        Ok(())
    }
271
272 async fn handle_reconnect(&mut self, stream: &mut S) -> io::Result<()> {
273 warn!("timed out reading next chunk, retrying");
274 let pos = self.writer.stream_position()?;
275 let reconnect_pos = tokio::time::timeout(self.retry_timeout, stream.reconnect(pos)).await;
279 if reconnect_pos
280 .inspect_err(|e| warn!("error attempting to reconnect: {e:?}"))
281 .is_ok()
282 && let Some(on_reconnect) = &mut self.on_reconnect
283 {
284 on_reconnect(stream, &self.cancellation_token);
285 }
286
287 Ok(())
288 }
289
290 async fn handle_prefetch(
291 &mut self,
292 stream: &mut S,
293 bytes: Option<Bytes>,
294 start_position: u64,
295 download_start: Instant,
296 ) -> io::Result<DownloadAction> {
297 let Some(bytes) = bytes else {
298 self.prefetch_complete = true;
299 debug!("file shorter than prefetch length, download finished");
300 self.writer.flush()?;
301 let position = self.writer.stream_position()?;
302 self.downloaded.add(start_position..position);
303
304 return self.finish_or_find_next_gap(stream).await;
305 };
306 let written = self.write_batched(&bytes).await?;
307 self.writer.flush()?;
308
309 let stream_position = self.writer.stream_position()?;
310 let partial_write = written < bytes.len();
311
312 if partial_write {
314 debug!(
315 written,
316 bytes_len = bytes.len(),
317 "failed to write all during prefetch"
318 );
319 self.remaining_bytes = Some(bytes.slice(written..));
320 }
321 if (stream_position >= start_position + self.prefetch_bytes) || partial_write {
322 self.downloaded.add(start_position..stream_position);
323 debug!("prefetch complete");
324 self.prefetch_complete = true;
325 }
326
327 self.report_prefetch_progress(stream, stream_position, download_start, written);
328 Ok(DownloadAction::Continue)
329 }
330
331 async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadAction> {
332 if stream.supports_seek()
333 && let Some(content_length) = self.content_length
334 {
335 let gap = self.downloaded.next_gap(0..content_length);
336 if let Some(gap) = gap {
337 debug!(
338 missing = format!("{gap:?}"),
339 "downloading missing stream chunk"
340 );
341 self.seek(stream, gap.start, Some(gap.end)).await?;
342 return Ok(DownloadAction::Continue);
343 }
344 }
345 self.writer.flush()?;
346 self.signal_download_complete();
347 Ok(DownloadAction::Complete)
348 }
349
350 async fn write_batched(&mut self, bytes: &[u8]) -> io::Result<usize> {
351 let mut written = 0;
352 loop {
353 let write_size = self.batch_write_size.min(bytes[written..].len());
354 let batch_written = self.writer.write(&bytes[written..written + write_size])?;
355 if batch_written == 0 {
356 return Ok(written);
357 }
358 written += batch_written;
359 yield_now().await;
362 }
363 }
364
365 async fn handle_bytes(
366 &mut self,
367 stream: &mut S,
368 bytes: Option<Result<Bytes, S::Error>>,
369 download_start: Instant,
370 ) -> io::Result<DownloadAction> {
371 let bytes = match bytes.transpose() {
372 Ok(bytes) => bytes,
373 Err(e) => {
374 error!("Error fetching chunk from stream: {e:?}");
375 return Ok(DownloadAction::Continue);
376 }
377 };
378
379 if !self.prefetch_complete {
380 return self
381 .handle_prefetch(stream, bytes, self.prefetch_start_position, download_start)
382 .await;
383 }
384
385 let bytes = match (self.remaining_bytes.take(), bytes) {
386 (Some(remaining), Some(bytes)) => {
387 let mut combined = BytesMut::new();
388 combined.put(remaining);
389 combined.put(bytes);
390 combined.freeze()
391 }
392 (Some(remaining), None) => remaining,
393 (None, Some(bytes)) => bytes,
394 (None, None) => {
395 return self.finish_or_find_next_gap(stream).await;
396 }
397 };
398 let bytes_len = bytes.len();
399 let new_position = self.write(bytes).await?;
400 self.report_downloading_progress(stream, new_position, download_start, bytes_len)?;
401
402 Ok(DownloadAction::Continue)
403 }
404
    /// Writes all of `bytes` to the writer and returns the writer's position
    /// afterwards.
    ///
    /// The loop keeps going until every byte has been accepted. After each
    /// attempt the newly written range is recorded, and a reader waiting for a
    /// requested position is notified once that position has been reached. When
    /// the writer accepts nothing, the task waits for a read before retrying.
    async fn write(&mut self, bytes: Bytes) -> io::Result<u64> {
        let mut written = 0;
        let position = self.writer.stream_position()?;
        let mut new_position = position;
        while written < bytes.len() {
            // Register interest in the next read before writing, so a read that
            // happens during the write is not missed by the wait below.
            // NOTE(review): inferred from the request/wait pairing; confirm against `NotifyRead`.
            self.notify_read.request();
            let new_written = self.write_batched(&bytes[written..]).await?;
            trace!(written, new_written, len = bytes.len(), "wrote data");

            if new_written > 0 {
                self.writer.flush()?;
                written += new_written;
            }
            new_position = self.writer.stream_position()?;
            if new_position > position {
                self.downloaded.add(position..new_position);
            }

            if let Some(requested) = self.requested_position.get() {
                debug!(
                    requested_position = requested,
                    current_position = new_position,
                    "received requested position"
                );

                if new_position >= requested {
                    debug!("notifying position reached");
                    self.requested_position.clear();
                    self.position_reached.notify_position_reached();
                }
            }
            if new_written == 0 {
                // The writer took nothing; wait for a read before trying again.
                debug!("waiting for next read");
                self.notify_read.wait_for_read().await;
                debug!("read finished");
            }

            trace!(
                previous_position = position,
                new_position,
                chunk_size = bytes.len(),
                "received response chunk"
            );
        }
        Ok(new_position)
    }
454
455 fn should_seek(&mut self, stream: &S, position: u64) -> io::Result<bool> {
456 if !stream.supports_seek() {
457 warn!("Attempting to seek, but it's unsupported. Waiting for stream to catch up.");
458 return Ok(false);
459 }
460 Ok(if let Some(range) = self.downloaded.get(position) {
461 !range.contains(&self.writer.stream_position()?)
462 } else {
463 true
464 })
465 }
466
467 async fn seek(&mut self, stream: &mut S, start: u64, end: Option<u64>) -> io::Result<()> {
468 stream.seek_range(start, end).await?;
469 self.writer.seek(SeekFrom::Start(start))?;
470 Ok(())
471 }
472
    /// Notifies waiters on `position_reached` that the stream is done.
    fn signal_download_complete(&self) {
        self.position_reached.notify_stream_done();
    }
476
477 fn report_progress(&mut self, stream: &S, info: StreamState) {
478 if let Some(on_progress) = self.on_progress.as_mut() {
479 on_progress(stream, info, &self.cancellation_token);
480 }
481 }
482
483 fn report_prefetch_progress(
484 &mut self,
485 stream: &S,
486 stream_position: u64,
487 download_start: Instant,
488 chunk_size: usize,
489 ) {
490 self.report_progress(
491 stream,
492 StreamState {
493 current_position: stream_position,
494 current_chunk: (0..stream_position),
495 elapsed: download_start.elapsed(),
496 phase: StreamPhase::Prefetching {
497 target: self.prefetch_bytes,
498 chunk_size,
499 },
500 },
501 );
502 }
503
504 fn report_downloading_progress(
505 &mut self,
506 stream: &S,
507 new_position: u64,
508 download_start: Instant,
509 chunk_size: usize,
510 ) -> io::Result<()> {
511 let pos = self.writer.stream_position()?;
512 self.report_progress(
513 stream,
514 StreamState {
515 current_position: pos,
516 current_chunk: self
517 .downloaded
518 .get(new_position - 1)
519 .expect("position already downloaded"),
520 elapsed: download_start.elapsed(),
521 phase: StreamPhase::Downloading { chunk_size },
522 },
523 );
524 Ok(())
525 }
526
527 fn report_download_complete(&mut self, stream: &S, download_start: Instant) -> io::Result<()> {
528 let pos = self.writer.stream_position()?;
529 self.report_progress(
530 stream,
531 StreamState {
532 current_position: pos,
533 elapsed: download_start.elapsed(),
534 current_chunk: self.downloaded.get(pos.max(1) - 1).unwrap_or_default(),
536 phase: StreamPhase::Complete,
537 },
538 );
539 Ok(())
540 }
541
542 pub(crate) fn source_handle(&self) -> SourceHandle {
543 SourceHandle {
544 downloaded: self.downloaded.clone(),
545 download_status: self.download_status.clone(),
546 requested_position: self.requested_position.clone(),
547 notify_read: self.notify_read.clone(),
548 position_reached: self.position_reached.clone(),
549 seek_tx: self.seek_tx.clone(),
550 content_length: self.content_length,
551 }
552 }
553}