use educe::Educe;
use human_repr::HumanCount as _;
use humantime::format_duration;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::io::ReadBuf;
/// Minimum time that must pass between two emitted progress lines.
///
/// Despite the name this is an interval (a period), not a rate: a new line is only
/// emitted once strictly more than this much time has passed since the previous one.
const UPDATE_FREQUENCY: Duration = Duration::from_secs(5);
pin_project! {
/// Pairs an inner value (typically a reader) with a [`Progress`] tracker.
///
/// When `Inner` implements [`tokio::io::AsyncRead`], every completed read through this
/// wrapper adds the number of bytes appended to the caller's buffer to the tracker, which
/// periodically emits a progress line (and invokes its callback, if one was configured).
#[derive(Debug, Clone)]
pub struct WithProgress<Inner> {
// The wrapped value; structurally pinned so `poll_read` can be forwarded to it.
#[pin]
inner: Inner,
// Counter and reporter updated after each completed read.
progress: Progress,
}
}
impl<R: tokio::io::AsyncRead> tokio::io::AsyncRead for WithProgress<R> {
    /// Forwards the read to the wrapped reader and records how many bytes it appended to `buf`.
    ///
    /// Progress is updated only when the inner read completes (`Poll::Ready`), whether it
    /// succeeded or failed; a pending read leaves the counter untouched. The inner result
    /// is returned unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let prev_len = buf.filled().len();
        let this = self.project();
        let poll = this.inner.poll_read(cx, buf);
        if poll.is_ready() {
            // `ReadBuf::set_filled` lets a reader shrink the filled region. A plain
            // subtraction could then underflow (a panic in debug builds, a huge bogus
            // increment in release builds), so the delta is saturated at zero.
            let appended = buf.filled().len().saturating_sub(prev_len);
            this.progress.inc(appended as u64);
        }
        poll
    }
}
impl<S> WithProgress<S> {
pub fn wrap_sync_read_with_callback(
message: &str,
read: S,
total_items: u64,
callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
) -> WithProgress<S> {
WithProgress {
inner: read,
progress: Progress::new(message)
.with_callback(callback)
.with_total(total_items),
}
}
pub fn bytes(mut self) -> Self {
self.progress.item_type = ItemType::Bytes;
self
}
}
/// Tracks how many items (or bytes) have been processed and periodically reports it.
///
/// Report lines are rendered by `msg` and emitted by `emit_log_if_required`, at most
/// once per `UPDATE_FREQUENCY`.
#[derive(Clone, Educe)]
#[educe(Debug)]
pub struct Progress {
/// Items (or bytes) completed so far.
completed_items: u64,
/// Expected total, if known. When present and non-zero, reports include
/// "/ total, N%".
total_items: Option<u64>,
/// Value of `completed_items` when the last report was emitted; used to compute speed.
last_logged_items: u64,
/// When tracking started; used for the elapsed-time part of a report.
start: Instant,
/// When the last report was emitted (initially equal to `start`).
last_logged: Instant,
/// Prefix of every report line.
message: String,
/// Whether counts are rendered as byte sizes or as plain item counts.
item_type: ItemType,
/// Optional sink that receives each emitted report line.
/// Excluded from `Debug` because `dyn Fn` has no `Debug` implementation.
#[educe(Debug(ignore))]
callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
}
/// Controls how progress counts and speeds are rendered.
#[derive(Debug, Clone, Copy)]
enum ItemType {
    /// Plain counts, e.g. "512 / 1024" and "10 items/s".
    Items,
    /// Human-readable byte sizes, e.g. "512 MiB / 1 GiB" and "3.1 MiB/s".
    Bytes,
}
impl Progress {
fn new(message: &str) -> Self {
let now = Instant::now();
Self {
completed_items: 0,
last_logged_items: 0,
total_items: None,
start: now,
last_logged: now,
message: message.into(),
item_type: ItemType::Items,
callback: None,
}
}
fn with_callback(mut self, callback: Option<Arc<dyn Fn(String) + Sync + Send>>) -> Self {
self.callback = callback;
self
}
fn with_total(mut self, total: u64) -> Self {
self.total_items = Some(total);
self
}
fn inc(&mut self, value: u64) {
self.completed_items += value;
self.emit_log_if_required();
}
#[cfg(test)]
fn set(&mut self, value: u64) {
self.completed_items = value;
self.emit_log_if_required();
}
fn msg(&self, now: Instant) -> String {
let message = &self.message;
let elapsed_secs = (now - self.start).as_secs_f64();
let elapsed_duration = format_duration(Duration::from_secs(elapsed_secs as u64));
let seconds_since_last_msg = (now - self.last_logged).as_secs_f64().max(0.1);
let at = match self.item_type {
ItemType::Bytes => self.completed_items.human_count_bytes().to_string(),
ItemType::Items => self.completed_items.to_string(),
};
let total = if let Some(total) = self.total_items {
let mut output = String::new();
if total > 0 {
output += " / ";
output += &match self.item_type {
ItemType::Bytes => total.human_count_bytes().to_string(),
ItemType::Items => total.to_string(),
};
output += &format!(", {}%", self.completed_items * 100 / total);
}
output
} else {
String::new()
};
let diff = (self.completed_items - self.last_logged_items) as f64 / seconds_since_last_msg;
let speed = match self.item_type {
ItemType::Bytes => format!("{}/s", diff.human_count_bytes()),
ItemType::Items => format!("{diff:.0} items/s"),
};
format!("{message} {at}{total}, {speed}, elapsed time: {elapsed_duration}")
}
fn emit_log_if_required(&mut self) {
let now = Instant::now();
if (now - self.last_logged) > UPDATE_FREQUENCY {
let msg = self.msg(now);
if let Some(cb) = self.callback.as_ref() {
cb(msg.clone());
}
tracing::info!(
target: "forest::progress",
"{}",
msg
);
self.last_logged = now;
self.last_logged_items = self.completed_items;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: u64 = 1024 * 1024;
    const GIB: u64 = 1024 * MIB;

    /// Byte counts use binary units, and throughput is measured against the count
    /// recorded at the last report.
    #[test]
    fn test_progress_msg_bytes() {
        let mut p = Progress::new("test");
        let t0 = p.start;
        p.item_type = ItemType::Bytes;
        p.total_items = Some(GIB);

        // Done; half of the data arrived during the last second.
        p.set(GIB);
        p.last_logged_items = GIB / 2;
        let line = p.msg(t0 + Duration::from_secs(1));
        assert_eq!(line, "test 1 GiB / 1 GiB, 100%, 512 MiB/s, elapsed time: 1s");

        // Halfway, with a slow average over a longer window.
        p.set(GIB / 2);
        p.last_logged_items = 128 * MIB;
        let line = p.msg(t0 + Duration::from_secs(125));
        assert_eq!(line, "test 512 MiB / 1 GiB, 50%, 3.1 MiB/s, elapsed time: 2m 5s");

        // Fractional sizes and a truncated percentage.
        p.set(GIB / 10);
        p.last_logged_items = MIB;
        let line = p.msg(t0 + Duration::from_secs(10));
        assert_eq!(line, "test 102.4 MiB / 1 GiB, 9%, 10.1 MiB/s, elapsed time: 10s");
    }

    /// Plain item counts are printed verbatim, with speed rounded to whole items.
    #[test]
    fn test_progress_msg_items() {
        const TOTAL: u64 = 1024;

        let mut p = Progress::new("test");
        let t0 = p.start;
        p.item_type = ItemType::Items;
        p.total_items = Some(TOTAL);

        p.set(TOTAL);
        p.last_logged_items = TOTAL / 2;
        let line = p.msg(t0 + Duration::from_secs(1));
        assert_eq!(line, "test 1024 / 1024, 100%, 512 items/s, elapsed time: 1s");

        p.set(TOTAL / 2);
        p.last_logged_items = TOTAL / 3;
        let line = p.msg(t0 + Duration::from_secs(125));
        assert_eq!(line, "test 512 / 1024, 50%, 1 items/s, elapsed time: 2m 5s");

        p.set(TOTAL / 10);
        p.last_logged_items = 0;
        let line = p.msg(t0 + Duration::from_secs(10));
        assert_eq!(line, "test 102 / 1024, 9%, 10 items/s, elapsed time: 10s");
    }
}