use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
/// Opaque, string-backed identifier naming the component a progress event is
/// about. Any string is accepted, including the empty string; no validation
/// is performed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ComponentId(String);

impl ComponentId {
    /// Borrows the identifier text as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the identifier, handing back the owned `String` it wraps.
    pub fn into_inner(self) -> String {
        let Self(raw) = self;
        raw
    }
}

/// Builds a `ComponentId` from anything convertible into a `String`
/// (`&str`, `String`, ...).
impl<S: Into<String>> From<S> for ComponentId {
    fn from(value: S) -> Self {
        let raw: String = value.into();
        Self(raw)
    }
}

/// Writes the raw identifier text. Because it goes through
/// `Formatter::write_str`, width/fill/alignment flags are not applied.
impl std::fmt::Display for ComponentId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Unit of measure for the `total` carried by [`ProgressEvent::Started`].
///
/// NOTE(review): presumably the same unit also applies to the later
/// `Progress`/`Done` amounts for the same component; confirm with producers.
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProgressUnit {
    /// Amounts are counted in bytes.
    Bytes,
    /// Amounts are counted in discrete items; what an "item" is depends on
    /// the producer (TODO confirm with callers).
    Items,
}
/// A single progress notification delivered to a [`ProgressReporter`].
///
/// Component-specific variants identify their component by [`ComponentId`].
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum ProgressEvent {
    /// Work on a component has begun.
    Started {
        /// Component this event refers to.
        id: ComponentId,
        /// Expected total amount, or `None` when it is not known up front.
        total: Option<u64>,
        /// Unit in which `total` is expressed.
        unit: ProgressUnit,
    },
    /// An intermediate progress update for a component.
    Progress {
        /// Component this event refers to.
        id: ComponentId,
        /// Amount processed so far. When emitted by `ProgressRead`'s driver this
        /// is the cumulative number of bytes read, not a per-read delta.
        fetched: u64,
        /// Expected total amount, if known.
        total: Option<u64>,
    },
    /// The component was skipped.
    /// NOTE(review): the reason for skipping is not visible in this module.
    Skipped {
        /// Component this event refers to.
        id: ComponentId,
    },
    /// Work on the component has finished.
    Done {
        /// Component this event refers to.
        id: ComponentId,
        /// Amount actually transferred (presumably in the unit announced by
        /// `Started`; TODO confirm).
        transferred: u64,
    },
    /// Free-form status text that is not tied to a specific component.
    Message(String),
}
/// A sink for [`ProgressEvent`]s.
///
/// Implementors must be `Send + Sync` so one reporter can be shared across
/// threads and tasks, typically behind a [`SharedReporter`].
pub trait ProgressReporter: Send + Sync {
    /// Delivers one event to the reporter.
    fn report(&self, event: ProgressEvent);
}

/// A [`ProgressReporter`] that discards every event it is given.
#[derive(Debug, Default)]
pub struct NullReporter;

impl ProgressReporter for NullReporter {
    #[inline]
    fn report(&self, _: ProgressEvent) {
        // Deliberately a no-op: the event is dropped unseen.
    }
}

/// Type-erased, reference-counted reporter handle; cloning it only bumps the
/// reference count, so it can be handed to many readers or tasks.
pub type SharedReporter = Arc<dyn ProgressReporter>;
/// An [`AsyncRead`] adapter that counts the bytes pulled through it and
/// publishes the running total on a `tokio::sync::watch` channel.
///
/// Build it with [`ProgressRead::new`], which also returns the driver future
/// that turns counter updates into [`ProgressEvent::Progress`] reports.
pub struct ProgressRead<R> {
    /// The wrapped reader that actually produces the data.
    inner: R,
    /// Sending half of the watch channel that stores the cumulative byte count.
    tx: tokio::sync::watch::Sender<u64>,
}

impl<R: std::fmt::Debug> std::fmt::Debug for ProgressRead<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy the counter out first so the watch read guard is released
        // before any formatting work happens.
        let bytes_read: u64 = *self.tx.borrow();
        let mut builder = f.debug_struct("ProgressRead");
        builder.field("inner", &self.inner);
        builder.field("bytes_read", &bytes_read);
        builder.finish_non_exhaustive()
    }
}
impl<R> ProgressRead<R> {
    /// Wraps `inner` so that each successful, non-empty read adds to a shared
    /// byte counter.
    ///
    /// Returns the wrapper together with a *driver* future. The driver waits
    /// for counter changes and forwards each observed value to `reporter` as a
    /// [`ProgressEvent::Progress`] carrying `id` and `total`. The wrapper owns
    /// the only sender, so the driver finishes once the wrapper is dropped.
    ///
    /// Nothing is reported unless the driver is polled (for example by
    /// spawning it). The counter is a single watch value updated in place, so
    /// the driver reports the latest total it observes; several reads may be
    /// coalesced into one event.
    pub fn new(
        inner: R,
        reporter: SharedReporter,
        id: ComponentId,
        total: Option<u64>,
    ) -> (Self, impl Future<Output = ()>) {
        let (sender, mut watcher) = tokio::sync::watch::channel(0u64);
        let driver = async move {
            loop {
                // An error here means the sender (held by the wrapper) is gone.
                if watcher.changed().await.is_err() {
                    break;
                }
                let fetched = *watcher.borrow_and_update();
                let event = ProgressEvent::Progress {
                    id: id.clone(),
                    fetched,
                    total,
                };
                reporter.report(event);
            }
        };
        (Self { inner, tx: sender }, driver)
    }
}
/// Forwards reads to the wrapped reader and adds the number of newly filled
/// bytes to the shared counter.
///
/// The counter is touched only when the inner read completes successfully and
/// actually produced data. Pending polls, errors and zero-byte reads leave it
/// unchanged. The inner poll result is always returned unchanged.
impl<R: AsyncRead + Unpin> AsyncRead for ProgressRead<R> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let filled_before = buf.filled().len();
        let outcome = Pin::new(&mut self.inner).poll_read(cx, buf);
        if matches!(outcome, Poll::Ready(Ok(()))) {
            let delta = (buf.filled().len() - filled_before) as u64;
            if delta != 0 {
                // Update the stored running total in place; the driver side of
                // the watch channel observes the change.
                self.tx.send_modify(|running_total| *running_total += delta);
            }
        }
        outcome
    }
}
use std::future::Future;
#[cfg(any(test, feature = "test"))]
pub mod test_support {
    //! Helpers for tests that need to assert on reported progress events.

    use std::sync::{Mutex, MutexGuard, PoisonError};

    use super::{ProgressEvent, ProgressReporter};

    /// A [`ProgressReporter`] that stores every event it receives, in arrival
    /// order, so tests can inspect them afterwards via
    /// [`RecordingReporter::events`].
    #[derive(Default)]
    pub struct RecordingReporter {
        events: Mutex<Vec<ProgressEvent>>,
    }

    impl RecordingReporter {
        /// Creates an empty recorder; equivalent to `Default::default()`.
        pub fn new() -> Self {
            Self::default()
        }

        /// Returns a snapshot copy of every event recorded so far, oldest first.
        pub fn events(&self) -> Vec<ProgressEvent> {
            self.lock().clone()
        }

        /// Locks the event log, recovering the data if the mutex is poisoned.
        ///
        /// Poisoning only means some thread panicked while holding the lock.
        /// The recorded events are still useful, and panicking again (notably
        /// from the `Debug` impl) would only obscure the original failure.
        fn lock(&self) -> MutexGuard<'_, Vec<ProgressEvent>> {
            self.events.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }

    impl std::fmt::Debug for RecordingReporter {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Show only the event count to keep the output short.
            f.debug_struct("RecordingReporter")
                .field("events", &self.lock().len())
                .finish()
        }
    }

    impl ProgressReporter for RecordingReporter {
        fn report(&self, event: ProgressEvent) {
            self.lock().push(event);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::test_support::RecordingReporter;
    use super::*;

    /// `NullReporter` must silently accept every event variant.
    #[test]
    fn test_null_reporter_does_not_panic() {
        let reporter = NullReporter;
        reporter.report(ProgressEvent::Started {
            id: "layer1".into(),
            total: Some(1024),
            unit: ProgressUnit::Bytes,
        });
        reporter.report(ProgressEvent::Progress {
            id: "layer1".into(),
            fetched: 512,
            total: Some(1024),
        });
        reporter.report(ProgressEvent::Skipped {
            id: "layer2".into(),
        });
        reporter.report(ProgressEvent::Done {
            id: "layer1".into(),
            transferred: 1024,
        });
        reporter.report(ProgressEvent::Message("done".to_string()));
    }

    /// `&str` and `String` conversions store the same text and compare equal.
    #[test]
    fn test_component_id_conversions() {
        let cases = [
            "sha256:abc123",
            "objects:my-stream",
            "",
            "docker://quay.io/foo:latest",
        ];
        for input in cases {
            let from_str: ComponentId = input.into();
            let from_string: ComponentId = input.to_string().into();
            assert_eq!(
                from_str.as_str(),
                input,
                "ComponentId::from(&str) should store value"
            );
            assert_eq!(
                from_string.as_str(),
                input,
                "ComponentId::from(String) should store value"
            );
            assert_eq!(from_str.to_string(), input, "Display should round-trip");
            assert_eq!(from_str, from_string, "both constructors should be equal");
        }
    }

    /// Equal ids hash equally, so `ComponentId` works as a `HashMap` key.
    #[test]
    fn test_component_id_hash_map_key() {
        let mut map: HashMap<ComponentId, u32> = HashMap::new();
        let id: ComponentId = "layer1".into();
        map.insert(id.clone(), 42);
        assert_eq!(
            map.get(&ComponentId::from("layer1")),
            Some(&42),
            "lookup by equal ComponentId should succeed"
        );
        assert_eq!(
            map.get(&ComponentId::from("layer2")),
            None,
            "lookup by different ComponentId should return None"
        );
        let removed = map.remove(&id);
        assert_eq!(removed, Some(42));
        assert!(map.is_empty());
    }

    /// Each variant's derived `Debug` output should lead with the variant name.
    #[test]
    fn test_progress_event_debug_all_variants() {
        let cases = [
            (
                ProgressEvent::Started {
                    id: "x".into(),
                    total: Some(100),
                    unit: ProgressUnit::Bytes,
                },
                "Started",
            ),
            (
                ProgressEvent::Started {
                    id: "y".into(),
                    total: None,
                    unit: ProgressUnit::Items,
                },
                "Started",
            ),
            (
                ProgressEvent::Progress {
                    id: "x".into(),
                    fetched: 50,
                    total: Some(100),
                },
                "Progress",
            ),
            (ProgressEvent::Skipped { id: "z".into() }, "Skipped"),
            (
                ProgressEvent::Done {
                    id: "x".into(),
                    transferred: 100,
                },
                "Done",
            ),
            (ProgressEvent::Message("status update".into()), "Message"),
        ];
        for (event, variant) in &cases {
            let debug = format!("{event:?}");
            assert!(
                debug.starts_with(*variant),
                "Debug output {debug:?} should start with {variant:?}"
            );
        }
    }

    /// A clone compares equal to its source (uses the derived `PartialEq`).
    #[test]
    fn test_progress_event_clone() {
        let event = ProgressEvent::Started {
            id: "layer".into(),
            total: Some(1000),
            unit: ProgressUnit::Bytes,
        };
        let cloned = event.clone();
        assert_eq!(event, cloned, "Clone should produce an identical value");
    }

    /// The recorder keeps events in the order they were reported.
    #[test]
    fn test_recording_reporter_captures_events_in_order() {
        let reporter = RecordingReporter::new();
        reporter.report(ProgressEvent::Message("hello".into()));
        reporter.report(ProgressEvent::Started {
            id: "c1".into(),
            total: Some(100),
            unit: ProgressUnit::Bytes,
        });
        reporter.report(ProgressEvent::Done {
            id: "c1".into(),
            transferred: 100,
        });
        let events = reporter.events();
        assert_eq!(events.len(), 3, "all three events should be recorded");
        assert!(
            matches!(&events[0], ProgressEvent::Message(m) if m == "hello"),
            "first event should be Message"
        );
        assert!(
            matches!(&events[1], ProgressEvent::Started { id, .. } if id.as_str() == "c1"),
            "second event should be Started for c1"
        );
        assert!(
            matches!(&events[2], ProgressEvent::Done { id, .. } if id.as_str() == "c1"),
            "third event should be Done for c1"
        );
    }

    /// The type-erased `SharedReporter` alias itself can be sent to and used
    /// from several threads at once.
    #[test]
    fn test_shared_reporter_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SharedReporter>();

        let recorder = Arc::new(RecordingReporter::new());
        // Coerce to the alias so the threads really exercise `SharedReporter`,
        // not just the concrete recorder type.
        let shared: SharedReporter = Arc::clone(&recorder);
        let handles: Vec<_> = (0..4u32)
            .map(|i| {
                let r = Arc::clone(&shared);
                std::thread::spawn(move || {
                    r.report(ProgressEvent::Message(format!("thread {i}")));
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("thread should not panic");
        }
        let mut messages: Vec<String> = recorder
            .events()
            .into_iter()
            .filter_map(|event| match event {
                ProgressEvent::Message(text) => Some(text),
                _ => None,
            })
            .collect();
        messages.sort_unstable();
        assert_eq!(
            messages,
            ["thread 0", "thread 1", "thread 2", "thread 3"],
            "each of the four threads should have recorded its event exactly once"
        );
    }

    #[test]
    fn test_progress_unit_variants() {
        let bytes = ProgressUnit::Bytes;
        let items = ProgressUnit::Items;
        assert_ne!(bytes, items);
        assert!(!format!("{bytes:?}").is_empty());
        assert!(!format!("{items:?}").is_empty());
    }

    /// Reads `data` to the end through a `ProgressRead`, then drops the reader
    /// and waits for the driver to finish. Returns every recorded event.
    async fn run_progress_read(
        data: Vec<u8>,
        id: ComponentId,
        total: Option<u64>,
    ) -> Vec<ProgressEvent> {
        use tokio::io::AsyncReadExt;
        let reporter = Arc::new(RecordingReporter::new());
        let cursor = tokio::io::BufReader::new(std::io::Cursor::new(data));
        let (mut reader, driver) =
            ProgressRead::new(cursor, Arc::clone(&reporter) as SharedReporter, id, total);
        let driver_handle = tokio::spawn(driver);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        // Dropping the reader drops the watch sender, which ends the driver.
        drop(reader);
        driver_handle.await.unwrap();
        reporter.events()
    }

    #[tokio::test]
    async fn test_progress_read_emits_events() {
        let id: ComponentId = "test-layer".into();
        let total: u64 = 1024;
        let data = vec![0u8; total as usize];
        let events = run_progress_read(data, id.clone(), Some(total)).await;
        // (component id, cumulative bytes, advertised total) per Progress event.
        let progress: Vec<(&ComponentId, u64, Option<u64>)> = events
            .iter()
            .filter_map(|event| match event {
                ProgressEvent::Progress {
                    id: eid,
                    fetched,
                    total: etot,
                } => Some((eid, *fetched, *etot)),
                _ => None,
            })
            .collect();
        assert!(
            !progress.is_empty(),
            "expected at least one Progress event"
        );
        for (eid, _, etot) in &progress {
            assert_eq!(*eid, &id);
            assert_eq!(*etot, Some(total));
        }
        // The counter only grows and each report observes a newer value.
        assert!(
            progress.windows(2).all(|pair| pair[0].1 < pair[1].1),
            "fetched counts should be strictly increasing: {progress:?}"
        );
        let last_fetched = progress
            .last()
            .map(|entry| entry.1)
            .expect("non-emptiness asserted above");
        assert_eq!(
            last_fetched, total,
            "last Progress event should have fetched == total"
        );
    }

    #[tokio::test]
    async fn test_progress_read_empty_source_no_events() {
        let events = run_progress_read(vec![], "empty".into(), Some(0)).await;
        assert!(
            events.is_empty(),
            "no events should be emitted for an empty source"
        );
    }

    #[tokio::test]
    async fn test_progress_read_single_byte_one_event() {
        let events = run_progress_read(vec![42u8], "single".into(), Some(1)).await;
        let progress_count = events
            .iter()
            .filter(|e| matches!(e, ProgressEvent::Progress { .. }))
            .count();
        assert_eq!(
            progress_count, 1,
            "single byte should produce exactly one Progress event"
        );
    }
}