use std::sync::Arc;
use std::time::{Duration, Instant};
/// Sampling interval of 125 ms, i.e. eight samples per second.
pub const SAMPLE_INTERVAL: Duration = Duration::from_micros(125_000);
use reqwest::Url;
use tokio::sync::mpsc;
pub use tokio_util::sync::CancellationToken;
/// Coarse stage of a download, announced through `ProgressEvent::PhaseChanged`.
///
/// Only equality, hashing and `Debug` are derived; there is no ordering
/// between phases at the type level.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Phase {
    Evaluating,
    ResolvingConflicts,
    Downloading,
    Assembling,
    Flushing,
    Verifying,
}
/// A single progress notification delivered to a `ProgressReporter`.
#[derive(Clone, Debug)]
pub enum ProgressEvent {
    /// The download moved into the given [`Phase`].
    PhaseChanged(Phase),
    /// A file name was resolved for the download.
    FilenameResolved(String),
    /// Overall byte progress. `total` is optional (presumably `None` when the
    /// size is unknown — confirm at the emitter).
    Progress { downloaded: u64, total: Option<u64> },
    /// Overall throughput in bytes per second.
    Speed { bytes_per_second: f64 },
    /// A part identified by `ulid` was registered (presumably covering the byte
    /// range `offset..offset + size` — confirm at the emitter).
    PartAdded {
        ulid: String,
        offset: u64,
        size: u64,
    },
    /// Byte progress of the part identified by `ulid`.
    PartProgress {
        ulid: String,
        downloaded: u64,
        total: u64,
    },
    /// The part identified by `ulid` finished.
    PartFinished { ulid: String },
    /// Throughput of the part identified by `ulid`, in bytes per second.
    PartSpeed { ulid: String, bytes_per_second: f64 },
    /// The part identified by `ulid` is being retried; `attempt` is the retry
    /// counter (its base — 0 or 1 — is decided by the emitter).
    PartRetrying { ulid: String, attempt: u32 },
    /// Free-form informational text.
    Message(String),
    /// The download finished at `path`; `already_complete` flags a download
    /// that needed no work (presumably — confirm at the emitter).
    Completed {
        path: std::path::PathBuf,
        already_complete: bool,
    },
    /// The download was cancelled.
    Cancelled,
    /// The download failed with the given message.
    Failed { message: String },
}
/// Sink for [`ProgressEvent`]s produced while a download runs.
///
/// The `'static + Send + Sync` bounds allow an implementation to be shared
/// across threads as an `Arc<dyn ProgressReporter>`.
pub trait ProgressReporter: 'static + Send + Sync {
    /// Consumes one event. Takes `&self`, so implementations needing mutable
    /// state must use interior mutability.
    fn on_event(&self, event: ProgressEvent);
}
/// [`ProgressReporter`] that drops every event it receives.
///
/// This is the reporter installed by `DownloadContext::new`.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoopReporter;

impl ProgressReporter for NoopReporter {
    fn on_event(&self, _: ProgressEvent) {}
}
/// [`ProgressReporter`] that forwards every event into a Tokio unbounded
/// channel.
///
/// If the receiving half has been dropped, the send error is discarded, so
/// reporting never fails or blocks.
pub struct ChannelReporter {
    tx: mpsc::UnboundedSender<ProgressEvent>,
}

impl ProgressReporter for ChannelReporter {
    fn on_event(&self, event: ProgressEvent) {
        // A closed receiver only means nobody is listening any more.
        _ = self.tx.send(event);
    }
}

/// Builds a [`ChannelReporter`] together with the receiver its events arrive on.
///
/// The channel is unbounded: events queue in memory until the receiver reads
/// them.
pub fn channel_reporter() -> (Arc<ChannelReporter>, mpsc::UnboundedReceiver<ProgressEvent>) {
    let (sender, receiver) = mpsc::unbounded_channel();
    let reporter = Arc::new(ChannelReporter { tx: sender });
    (reporter, receiver)
}
/// Reporter that hands events to an inner [`ProgressReporter`] on a spawned
/// Tokio task, so the producer's `on_event` call only enqueues.
///
/// Events travel over an unbounded channel. The worker loops until
/// `recv()` yields `None`, which happens once the sender has been dropped
/// and the queue is drained. The task's `JoinHandle` is only held; it is
/// never awaited or aborted.
///
/// NOTE(review): the inner reporter runs on a runtime task, so a blocking
/// `on_event` implementation will occupy a runtime worker thread.
pub struct AsyncReporter {
    tx: mpsc::UnboundedSender<ProgressEvent>,
    _worker: tokio::task::JoinHandle<()>,
}

impl AsyncReporter {
    /// Starts the forwarding task and returns the shared front-end.
    ///
    /// # Panics
    ///
    /// Calls `tokio::spawn`, which panics when invoked outside a Tokio runtime.
    pub fn spawn<R>(inner: R) -> Arc<Self>
    where
        R: ProgressReporter,
    {
        let (tx, mut queue) = mpsc::unbounded_channel::<ProgressEvent>();
        let forward = async move {
            while let Some(event) = queue.recv().await {
                inner.on_event(event);
            }
        };
        let worker = tokio::spawn(forward);
        Arc::new(Self { tx, _worker: worker })
    }
}

impl ProgressReporter for AsyncReporter {
    fn on_event(&self, event: ProgressEvent) {
        // Fails only if the worker's receiver is gone; the event is then dropped.
        _ = self.tx.send(event);
    }
}
/// Per-download context: the progress sink, the cancellation token and an
/// optional source URL.
///
/// Clones share the same reporter (it is held in an `Arc`).
#[derive(Clone)]
pub struct DownloadContext {
    /// Sink that receives every event passed to [`DownloadContext::emit`].
    pub reporter: Arc<dyn ProgressReporter>,
    /// Token queried by [`DownloadContext::is_cancelled`].
    pub cancel: CancellationToken,
    /// Source URL, if one was attached via [`DownloadContext::with_url`].
    pub url: Option<Url>,
}

impl DownloadContext {
    /// Creates a context with a [`NoopReporter`], a freshly created
    /// [`CancellationToken`] and no URL.
    pub fn new() -> Self {
        let reporter: Arc<dyn ProgressReporter> = Arc::new(NoopReporter);
        Self {
            reporter,
            cancel: CancellationToken::new(),
            url: None,
        }
    }

    /// Returns this context with its reporter replaced.
    pub fn with_reporter(self, reporter: Arc<dyn ProgressReporter>) -> Self {
        Self { reporter, ..self }
    }

    /// Returns this context with its cancellation token replaced.
    pub fn with_cancel(self, cancel: CancellationToken) -> Self {
        Self { cancel, ..self }
    }

    /// Returns this context with `url` attached.
    pub fn with_url(self, url: Url) -> Self {
        Self {
            url: Some(url),
            ..self
        }
    }

    /// Forwards `event` to the configured reporter.
    pub fn emit(&self, event: ProgressEvent) {
        let sink = &self.reporter;
        sink.on_event(event);
    }

    /// Reports whether the context's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }
}

impl Default for DownloadContext {
    /// Same as [`DownloadContext::new`].
    fn default() -> Self {
        DownloadContext::new()
    }
}

impl std::fmt::Debug for DownloadContext {
    /// Prints `cancel` and `url`. The reporter is skipped because
    /// `ProgressReporter` has no `Debug` bound; `finish_non_exhaustive`
    /// marks the omission with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DownloadContext");
        out.field("cancel", &self.cancel);
        out.field("url", &self.url);
        out.finish_non_exhaustive()
    }
}
/// Drops samples from the front of `window` whose age at `now` strictly
/// exceeds `window_len`.
///
/// Trimming stops at the first sample that is still fresh. It also stops
/// before the window would shrink below one element, so a non-empty window
/// always keeps its newest sample. Samples timestamped after `now` count as
/// age zero (saturating subtraction).
pub(crate) fn trim_speed_window(
    window: &mut std::collections::VecDeque<(Instant, u64)>,
    now: Instant,
    window_len: Duration,
) {
    let is_stale = |sampled_at: Instant| now.saturating_duration_since(sampled_at) > window_len;
    while window.len() > 1 {
        match window.front() {
            Some(&(sampled_at, _)) if is_stale(sampled_at) => {
                window.pop_front();
            }
            _ => break,
        }
    }
}
/// Average throughput across `window`, in bytes per second.
///
/// Compares only the first and last samples. Returns `None` when there are
/// fewer than two samples or when the last sample is not strictly later than
/// the first. A byte count that went backwards yields `Some(0.0)` because of
/// saturating subtraction.
pub(crate) fn speed_window_rate(
    window: &std::collections::VecDeque<(Instant, u64)>,
) -> Option<f64> {
    if window.len() < 2 {
        return None;
    }
    let &(start, start_bytes) = window.front()?;
    let &(end, end_bytes) = window.back()?;
    let elapsed_secs = end.saturating_duration_since(start).as_secs_f64();
    (elapsed_secs > 0.0).then(|| end_bytes.saturating_sub(start_bytes) as f64 / elapsed_secs)
}
/// Lock-free download byte counter with an optional total size and an
/// average-rate ETA estimate.
///
/// All methods take `&self`. Both counters are `AtomicU64`s accessed with
/// `Relaxed` ordering, which is sufficient here because each value is read
/// and written independently.
///
/// NOTE: a stored total of `0` is the internal "unknown" marker, so
/// `new(Some(0))` / `set_total(Some(0))` read back as `None` from
/// [`ProgressTracker::total`].
pub(crate) struct ProgressTracker {
    started_at: Instant,
    downloaded: std::sync::atomic::AtomicU64,
    // Expected size in bytes; 0 means "unknown" (see `total()`).
    total: std::sync::atomic::AtomicU64,
}

impl ProgressTracker {
    /// Starts tracking now, with zero bytes downloaded and the given total.
    pub fn new(total: Option<u64>) -> Self {
        Self {
            started_at: Instant::now(),
            downloaded: std::sync::atomic::AtomicU64::new(0),
            total: std::sync::atomic::AtomicU64::new(total.unwrap_or(0)),
        }
    }

    /// Adds `delta` bytes and returns the updated running total.
    ///
    /// `fetch_add` wraps on overflow. The return value uses the same wrapping
    /// arithmetic, so it always equals the stored counter and never panics,
    /// in debug builds as well as release.
    pub fn advance(&self, delta: u64) -> u64 {
        let prev = self
            .downloaded
            .fetch_add(delta, std::sync::atomic::Ordering::Relaxed);
        prev.wrapping_add(delta)
    }

    /// Bytes downloaded so far.
    pub fn downloaded(&self) -> u64 {
        self.downloaded.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Expected total size, or `None` when unknown (stored as 0).
    pub fn total(&self) -> Option<u64> {
        match self.total.load(std::sync::atomic::Ordering::Relaxed) {
            0 => None,
            t => Some(t),
        }
    }

    /// Replaces the expected total size; `None` (or `Some(0)`) marks it unknown.
    // NOTE(review): presumably not called anywhere in the crate yet, hence the
    // allow; remove it once a caller exists.
    #[allow(dead_code)]
    pub fn set_total(&self, total: Option<u64>) {
        self.total
            .store(total.unwrap_or(0), std::sync::atomic::Ordering::Relaxed);
    }

    /// Wall-clock time since this tracker was created.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Estimated time remaining, based on the average rate since creation.
    ///
    /// The estimate is `(total - downloaded) / (downloaded / elapsed)`.
    /// Returns `Duration::ZERO` in any of these cases:
    /// - the total is unknown;
    /// - nothing has been downloaded yet;
    /// - the download is already at or past the total;
    /// - no time has elapsed;
    /// - the estimate cannot be represented as a `Duration`.
    pub fn eta(&self) -> Duration {
        let Some(total) = self.total() else {
            return Duration::ZERO;
        };
        let downloaded = self.downloaded();
        if downloaded == 0 || downloaded >= total {
            return Duration::ZERO;
        }
        let elapsed = self.elapsed().as_secs_f64();
        if elapsed <= 0.0 {
            return Duration::ZERO;
        }
        let rate = downloaded as f64 / elapsed;
        if rate <= 0.0 {
            return Duration::ZERO;
        }
        // Cannot underflow: downloaded < total was checked above.
        let remaining = (total - downloaded) as f64;
        Duration::try_from_secs_f64(remaining / rate).unwrap_or(Duration::ZERO)
    }
}