use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use more_asserts::{assert_ge, assert_le};
use tokio::sync::Mutex;
use xet_data::progress_tracking::UniqueID;
use super::{ProgressUpdate, TrackingProgressUpdater};
/// Per-item state remembered between updates, so that each new report for an item
/// can be checked against the previous one.
#[derive(Debug)]
struct ItemProgressData {
    /// Most recent `total_bytes` reported for this item (starts at 0).
    total_count: u64,
    /// Most recent `bytes_completed` reported for this item (starts at 0).
    last_completed: u64,
}
/// Mutable verification state kept behind the wrapper's mutex.
///
/// The aggregate fields are rebuilt purely from the increments carried by each update.
/// They are then compared against the absolute values the same update reports.
#[derive(Debug, Default)]
pub struct ProgressUpdaterVerificationWrapperImpl {
    /// Per-item state keyed by tracking id. The item name is stored alongside it for use in
    /// failure messages.
    items: HashMap<UniqueID, (Arc<str>, ItemProgressData)>,
    /// Running sum of every `total_transfer_bytes_increment` received.
    total_transfer_bytes: u64,
    /// Running sum of every `total_transfer_bytes_completion_increment` received.
    total_transfer_bytes_completed: u64,
    /// Running sum of every `total_bytes_increment` received.
    total_bytes: u64,
    /// Running sum of every `total_bytes_completion_increment` received.
    total_process_bytes_completed: u64,
}
/// A [`TrackingProgressUpdater`] that checks every [`ProgressUpdate`] for internal
/// consistency before forwarding it to an inner updater.
///
/// Violations are reported by panicking (via assertions), not by returning errors.
pub struct ProgressUpdaterVerificationWrapper {
    /// Updater that receives each update once it has been verified.
    inner: Arc<dyn TrackingProgressUpdater>,
    /// Verification state reconstructed from the updates seen so far.
    tr: Mutex<ProgressUpdaterVerificationWrapperImpl>,
}
impl ProgressUpdaterVerificationWrapper {
    /// Wraps `inner` so that every update is verified before being forwarded to it.
    pub fn new(inner: Arc<dyn TrackingProgressUpdater>) -> Arc<Self> {
        Arc::new(Self {
            inner,
            tr: Mutex::new(ProgressUpdaterVerificationWrapperImpl::default()),
        })
    }

    /// Asserts that all reported progress has finished.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - a tracked item's completed bytes differ from its total;
    /// - the aggregate transfer counter has not reached its total;
    /// - the aggregate processing counter has not reached its total.
    pub async fn assert_complete(&self) {
        let tr = self.tr.lock().await;

        for (tracking_id, (item_name, data)) in tr.items.iter() {
            assert_eq!(
                data.last_completed, data.total_count,
                "Item({}) '{}' is not fully complete: {}/{}",
                tracking_id, item_name, data.last_completed, data.total_count
            );
        }

        assert_eq!(
            tr.total_transfer_bytes_completed, tr.total_transfer_bytes,
            "Transfer progress is not fully complete: {}/{} bytes",
            tr.total_transfer_bytes_completed, tr.total_transfer_bytes
        );

        // The processing counters are tracked exactly like the transfer counters,
        // so they must also be complete once all work has finished.
        assert_eq!(
            tr.total_process_bytes_completed, tr.total_bytes,
            "Processing progress is not fully complete: {}/{} bytes",
            tr.total_process_bytes_completed, tr.total_bytes
        );
    }
}
/// Verifies one aggregate counter of a [`ProgressUpdate`] and advances the tracked value.
///
/// - `label` names the counter in failure messages.
/// - `tracked` is the running sum of all increments seen so far.
/// - `reported` is the absolute value carried by the current update.
/// - `increment` is the delta that update claims to add.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `reported` is lower than the previously tracked value;
/// - `tracked + increment` overflows `u64`;
/// - `tracked + increment` does not equal `reported`.
fn verify_running_total(label: &str, tracked: &mut u64, reported: u64, increment: u64) {
    let previous = *tracked;
    assert_le!(
        previous,
        reported,
        "{}: new reported value {} is a decrease from previously tracked value {}",
        label,
        reported,
        previous
    );
    // checked_add so an overflow is reported even in builds without overflow checks.
    let updated = previous.checked_add(increment).unwrap_or_else(|| {
        panic!("{}: increment {} overflows previously tracked value {}", label, increment, previous)
    });
    assert_eq!(
        updated, reported,
        "{}: previously tracked value {} + increment {} = {} is out of step with reported value {}",
        label, previous, increment, updated, reported
    );
    *tracked = updated;
}

#[async_trait]
impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper {
    /// Checks `update` against the state built from earlier updates, then forwards it.
    ///
    /// For every item update:
    /// - its total may not shrink;
    /// - its completed count may not go backwards or exceed the total;
    /// - its completed count must equal the previous completed count plus its increment.
    ///
    /// Each aggregate counter must likewise be non-decreasing and consistent with its
    /// increment.
    ///
    /// # Panics
    ///
    /// Panics on the first violation found.
    async fn register_updates(&self, update: ProgressUpdate) {
        let mut tr = self.tr.lock().await;

        for up in &update.item_updates {
            let (_, data) = tr.items.entry(up.tracking_id).or_insert_with(|| {
                (
                    up.item_name.clone(),
                    ItemProgressData {
                        total_count: 0,
                        last_completed: 0,
                    },
                )
            });

            // A fresh entry has total_count == 0, so this check is trivially satisfied
            // on the first report for an item.
            assert_ge!(
                up.total_bytes,
                data.total_count,
                "total_count for '{}' decreased; was {}, now {}",
                up.item_name,
                data.total_count,
                up.total_bytes
            );
            data.total_count = up.total_bytes;

            let last = data.last_completed;
            assert!(
                up.bytes_completed >= last,
                "Item '{}' completed_count went backwards: old={}, new={}",
                up.item_name,
                last,
                up.bytes_completed
            );
            assert!(
                up.bytes_completed <= up.total_bytes,
                "Item '{}' completed_count {} exceeds total {}",
                up.item_name,
                up.bytes_completed,
                up.total_bytes
            );

            let expected_new = last.checked_add(up.bytes_completion_increment).unwrap_or_else(|| {
                panic!(
                    "Item '{}': last_completed={} + update_increment={} overflows u64",
                    up.item_name, last, up.bytes_completion_increment
                )
            });
            assert_eq!(
                up.bytes_completed, expected_new,
                "Item '{}': mismatch: last_completed={} + update_increment={} != completed_count={}",
                up.item_name, last, up.bytes_completion_increment, up.bytes_completed
            );
            data.last_completed = up.bytes_completed;
        }

        verify_running_total(
            "total_transfer_bytes",
            &mut tr.total_transfer_bytes,
            update.total_transfer_bytes,
            update.total_transfer_bytes_increment,
        );
        verify_running_total(
            "total_transfer_bytes_completed",
            &mut tr.total_transfer_bytes_completed,
            update.total_transfer_bytes_completed,
            update.total_transfer_bytes_completion_increment,
        );
        verify_running_total(
            "total_bytes",
            &mut tr.total_bytes,
            update.total_bytes,
            update.total_bytes_increment,
        );
        verify_running_total(
            "total_bytes_completed",
            &mut tr.total_process_bytes_completed,
            update.total_bytes_completed,
            update.total_bytes_completion_increment,
        );

        // `tr` is still held here, so concurrent callers are serialized and `inner`
        // receives updates in the same order in which they were verified.
        self.inner.register_updates(update).await;
    }

    /// Forwards the flush request to the inner updater unchanged.
    async fn flush(&self) {
        self.inner.flush().await;
    }
}
#[cfg(test)]
mod tests {
    use super::super::ItemProgressUpdate;
    use super::*;

    /// Inner updater that records every item update it receives, so tests can check
    /// what the wrapper forwarded.
    #[derive(Debug, Default)]
    struct DummyLogger {
        pub all_updates: Mutex<Vec<ItemProgressUpdate>>,
    }

    #[async_trait]
    impl TrackingProgressUpdater for DummyLogger {
        async fn register_updates(&self, updates: ProgressUpdate) {
            let mut guard = self.all_updates.lock().await;
            guard.extend_from_slice(&updates.item_updates);
        }
    }

    /// Builds a verification wrapper around a fresh [`DummyLogger`], returning both.
    fn new_wrapper() -> (Arc<DummyLogger>, Arc<ProgressUpdaterVerificationWrapper>) {
        let logger = Arc::new(DummyLogger::default());
        let wrapper = ProgressUpdaterVerificationWrapper::new(logger.clone());
        (logger, wrapper)
    }

    #[tokio::test]
    async fn test_verification_wrapper() {
        let (logger, wrapper) = new_wrapper();
        let file_a = (UniqueID::new(), "fileA");
        let file_b = (UniqueID::new(), "fileB");
        wrapper
            .register_updates(ProgressUpdate {
                item_updates: vec![
                    ItemProgressUpdate {
                        tracking_id: file_a.0,
                        item_name: file_a.1.into(),
                        total_bytes: 100,
                        bytes_completed: 50,
                        bytes_completion_increment: 50,
                    },
                    ItemProgressUpdate {
                        tracking_id: file_b.0,
                        item_name: file_b.1.into(),
                        total_bytes: 200,
                        bytes_completed: 100,
                        bytes_completion_increment: 100,
                    },
                ],
                total_transfer_bytes: 100,
                total_transfer_bytes_increment: 100,
                total_transfer_bytes_completed: 50,
                total_transfer_bytes_completion_increment: 50,
                total_bytes: 200,
                total_bytes_increment: 200,
                total_bytes_completed: 100,
                total_bytes_completion_increment: 100,
                ..Default::default()
            })
            .await;
        wrapper
            .register_updates(ProgressUpdate {
                item_updates: vec![
                    ItemProgressUpdate {
                        tracking_id: file_a.0,
                        item_name: file_a.1.into(),
                        total_bytes: 100,
                        bytes_completed: 100,
                        bytes_completion_increment: 50,
                    },
                    ItemProgressUpdate {
                        tracking_id: file_b.0,
                        item_name: file_b.1.into(),
                        total_bytes: 200,
                        bytes_completed: 200,
                        bytes_completion_increment: 100,
                    },
                ],
                total_transfer_bytes: 150,
                total_transfer_bytes_increment: 50,
                total_transfer_bytes_completed: 150,
                total_transfer_bytes_completion_increment: 100,
                total_bytes: 200,
                total_bytes_increment: 0,
                total_bytes_completed: 200,
                total_bytes_completion_increment: 100,
                ..Default::default()
            })
            .await;
        wrapper.assert_complete().await;
        let final_updates = logger.all_updates.lock().await;
        assert_eq!(final_updates.len(), 4);
    }

    /// An item whose completed count does not equal previous + increment must be rejected.
    #[tokio::test]
    #[should_panic]
    async fn test_rejects_mismatched_item_increment() {
        let (_logger, wrapper) = new_wrapper();
        wrapper
            .register_updates(ProgressUpdate {
                item_updates: vec![ItemProgressUpdate {
                    tracking_id: UniqueID::new(),
                    item_name: "fileA".into(),
                    total_bytes: 100,
                    bytes_completed: 50,
                    // Inconsistent on purpose: 0 + 40 != 50.
                    bytes_completion_increment: 40,
                }],
                ..Default::default()
            })
            .await;
    }

    /// An aggregate counter whose increment disagrees with its reported total must be rejected.
    #[tokio::test]
    #[should_panic]
    async fn test_rejects_mismatched_aggregate_increment() {
        let (_logger, wrapper) = new_wrapper();
        wrapper
            .register_updates(ProgressUpdate {
                total_transfer_bytes: 100,
                // Inconsistent on purpose: 0 + 60 != 100.
                total_transfer_bytes_increment: 60,
                ..Default::default()
            })
            .await;
    }

    /// `assert_complete` must fail while an item is only partially done.
    #[tokio::test]
    #[should_panic]
    async fn test_assert_complete_detects_unfinished_item() {
        let (_logger, wrapper) = new_wrapper();
        wrapper
            .register_updates(ProgressUpdate {
                item_updates: vec![ItemProgressUpdate {
                    tracking_id: UniqueID::new(),
                    item_name: "fileA".into(),
                    total_bytes: 100,
                    bytes_completed: 50,
                    bytes_completion_increment: 50,
                }],
                total_bytes: 100,
                total_bytes_increment: 100,
                total_bytes_completed: 50,
                total_bytes_completion_increment: 50,
                ..Default::default()
            })
            .await;
        wrapper.assert_complete().await;
    }
}