hf-xet 1.5.2

Client library and tooling for the Hugging Face Xet data storage system.
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use more_asserts::{assert_ge, assert_le};
use tokio::sync::Mutex;
use xet_data::progress_tracking::UniqueID;

use super::{ProgressUpdate, TrackingProgressUpdater};

/// Internal structure to track and validate progress data for one item.
#[derive(Debug)]
struct ItemProgressData {
    /// `total_bytes` from the most recent update for this item; 0 when the item is first inserted.
    total_count: u64,
    /// `bytes_completed` from the most recent update for this item; 0 when the item is first inserted.
    last_completed: u64,
}

/// Mutable validation state kept behind the wrapper's mutex.
///
/// Each aggregate counter below is advanced by the matching `*_increment` field of every
/// `ProgressUpdate` and is then asserted equal to the absolute value carried by that same update.
#[derive(Debug, Default)]
pub struct ProgressUpdaterVerificationWrapperImpl {
    /// Per-item state keyed by tracking id: the item name recorded on first sight plus its progress data.
    items: HashMap<UniqueID, (Arc<str>, ItemProgressData)>,
    /// Running sum of every `total_transfer_bytes_increment` received so far.
    total_transfer_bytes: u64,
    /// Running sum of every `total_transfer_bytes_completion_increment` received so far.
    total_transfer_bytes_completed: u64,
    /// Running sum of every `total_bytes_increment` received so far.
    total_bytes: u64,
    /// Running sum of every `total_bytes_completion_increment` received so far.
    total_process_bytes_completed: u64,
}

/// A wrapper that forwards updates to an inner `TrackingProgressUpdater`
/// while also validating each update for correctness:
///
/// - An item's `bytes_completed` must be non-decreasing and never exceed its `total_bytes`.
/// - An item's `bytes_completed` must equal the previously recorded value plus `bytes_completion_increment`.
/// - An item's `total_bytes` may grow across updates but must never decrease.
/// - Each aggregate value in an update (`total_transfer_bytes`, `total_transfer_bytes_completed`, `total_bytes`,
///   `total_bytes_completed`) must not decrease and must equal the running sum of its increments.
/// - Final verification (`assert_complete()`) ensures every item's completed bytes equal its total bytes and that
///   the completed transfer bytes equal the total transfer bytes.
///
/// Every violation is reported by a panicking assertion.
pub struct ProgressUpdaterVerificationWrapper {
    /// Updater that receives each update once it has passed validation.
    inner: Arc<dyn TrackingProgressUpdater>,
    /// Validation state; locked for the whole of each `register_updates` call.
    tr: Mutex<ProgressUpdaterVerificationWrapperImpl>,
}

impl ProgressUpdaterVerificationWrapper {
    /// Wraps `inner` in a verifier that starts with empty tracking state.
    ///
    /// Each update passed to the returned wrapper is checked first and handed to `inner` afterwards.
    pub fn new(inner: Arc<dyn TrackingProgressUpdater>) -> Arc<Self> {
        let tr = Mutex::new(ProgressUpdaterVerificationWrapperImpl::default());
        Arc::new(Self { inner, tr })
    }

    /// Final check to run after the last update has been registered.
    ///
    /// # Panics
    ///
    /// Panics if any tracked item's completed bytes differ from its total bytes, or if the
    /// completed transfer bytes differ from the total transfer bytes.
    pub async fn assert_complete(&self) {
        let state = self.tr.lock().await;

        for (id, (name, progress)) in &state.items {
            assert_eq!(
                progress.last_completed, progress.total_count,
                "Item({}) '{}' is not fully complete: {}/{}",
                id, name, progress.last_completed, progress.total_count
            );
        }

        assert_eq!(state.total_transfer_bytes_completed, state.total_transfer_bytes);
    }
}

#[async_trait]
impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper {
    /// Checks `update` against the state accumulated so far, records it, and then forwards it to the
    /// inner updater. The state lock is held until the inner updater's call has returned.
    ///
    /// # Panics
    ///
    /// Panics on the first violated invariant: an item's total shrinking, an item's completed bytes
    /// moving backwards, exceeding its total, or disagreeing with its increment; or any aggregate value
    /// decreasing or disagreeing with the running sum of its increments.
    async fn register_updates(&self, update: ProgressUpdate) {
        let mut state = self.tr.lock().await;

        // Per-item validation.
        for item in &update.item_updates {
            let (_, progress) = state.items.entry(item.tracking_id).or_insert_with(|| {
                (
                    item.item_name.clone(),
                    ItemProgressData {
                        total_count: 0,
                        last_completed: 0,
                    },
                )
            });

            // A recorded total of 0 (including a freshly inserted item) makes this comparison hold trivially,
            // so no separate first-sighting branch is needed.
            assert_ge!(
                item.total_bytes,
                progress.total_count,
                "total_count for '{}' decreased; was {}, now {}",
                item.item_name,
                progress.total_count,
                item.total_bytes
            );
            progress.total_count = item.total_bytes;

            let previously_completed = progress.last_completed;

            assert!(
                previously_completed <= item.bytes_completed,
                "Item '{}' completed_count went backwards: old={}, new={}",
                item.item_name,
                previously_completed,
                item.bytes_completed
            );

            assert!(
                item.total_bytes >= item.bytes_completed,
                "Item '{}' completed_count {} exceeds total {}",
                item.item_name,
                item.bytes_completed,
                item.total_bytes
            );

            let expected_new = previously_completed + item.bytes_completion_increment;
            assert_eq!(
                item.bytes_completed, expected_new,
                "Item '{}': mismatch: last_completed={} + update_increment={} != completed_count={}",
                item.item_name, previously_completed, item.bytes_completion_increment, item.bytes_completed
            );

            progress.last_completed = item.bytes_completed;
        }

        // Aggregate transfer total: must not shrink and must track its increments.
        assert_le!(
            state.total_transfer_bytes,
            update.total_transfer_bytes,
            "New total bytes {} a decrease from previous report of total bytes {}",
            update.total_transfer_bytes,
            state.total_transfer_bytes
        );
        state.total_transfer_bytes += update.total_transfer_bytes_increment;
        assert_eq!(
            state.total_transfer_bytes, update.total_transfer_bytes,
            "New increment {} put tracked checked transfer bytes {} out of step from reported total bytes {}",
            update.total_transfer_bytes_increment, state.total_transfer_bytes, update.total_transfer_bytes,
        );

        // Aggregate transfer bytes completed.
        assert_le!(
            state.total_transfer_bytes_completed,
            update.total_transfer_bytes_completed,
            "New total bytes completed {} a decrease from previous report of total bytes {}",
            update.total_transfer_bytes_completed,
            state.total_transfer_bytes_completed
        );
        state.total_transfer_bytes_completed += update.total_transfer_bytes_completion_increment;
        assert_eq!(
            state.total_transfer_bytes_completed, update.total_transfer_bytes_completed,
            "Total bytes completed {} does not match tracked total bytes {}",
            update.total_transfer_bytes_completed, state.total_transfer_bytes_completed
        );

        // Aggregate processing total.
        assert_le!(
            state.total_bytes,
            update.total_bytes,
            "New total bytes {} a decrease from previous report of total bytes {}",
            update.total_bytes,
            state.total_bytes
        );
        state.total_bytes += update.total_bytes_increment;
        assert_eq!(
            state.total_bytes, update.total_bytes,
            "New increment {} put checked total processing bytes {} out of step from reported total bytes {}",
            update.total_bytes_increment, state.total_bytes, update.total_bytes,
        );

        // Aggregate processing bytes completed.
        assert_le!(
            state.total_process_bytes_completed,
            update.total_bytes_completed,
            "New total bytes completed {} a decrease from previous report of total bytes {}",
            update.total_bytes_completed,
            state.total_process_bytes_completed
        );
        state.total_process_bytes_completed += update.total_bytes_completion_increment;
        assert_eq!(
            state.total_process_bytes_completed, update.total_bytes_completed,
            "Total bytes completed {} does not match tracked total bytes {}",
            update.total_bytes_completed, state.total_process_bytes_completed
        );

        // Only a fully validated update reaches the wrapped updater.
        self.inner.register_updates(update).await;
    }

    /// Delegates the flush to the inner updater; no validation is performed here.
    async fn flush(&self) {
        self.inner.flush().await;
    }
}

#[cfg(test)]
mod tests {
    use super::super::ItemProgressUpdate;
    use super::*;

    /// Inner updater for the test: it records every item update it is handed.
    #[derive(Debug, Default)]
    struct DummyLogger {
        pub all_updates: Mutex<Vec<ItemProgressUpdate>>,
    }

    #[async_trait]
    impl TrackingProgressUpdater for DummyLogger {
        async fn register_updates(&self, updates: ProgressUpdate) {
            let mut recorded = self.all_updates.lock().await;
            recorded.extend(updates.item_updates.iter().cloned());
        }
    }

    /// Builds the item update for `file`, given as an `(id, name)` pair.
    fn item_update(
        file: (UniqueID, &str),
        total_bytes: u64,
        bytes_completed: u64,
        bytes_completion_increment: u64,
    ) -> ItemProgressUpdate {
        ItemProgressUpdate {
            tracking_id: file.0,
            item_name: file.1.into(),
            total_bytes,
            bytes_completed,
            bytes_completion_increment,
        }
    }

    /// Two consistent updates covering two files must pass validation, leave the wrapper
    /// complete, and reach the inner updater as four item updates.
    #[tokio::test]
    async fn test_verification_wrapper() {
        let logger = Arc::new(DummyLogger::default());
        let wrapper = ProgressUpdaterVerificationWrapper::new(logger.clone());

        let file_a = (UniqueID::new(), "fileA");
        let file_b = (UniqueID::new(), "fileB");

        // First update: both files are half done.
        let halfway = ProgressUpdate {
            item_updates: vec![item_update(file_a, 100, 50, 50), item_update(file_b, 200, 100, 100)],
            total_transfer_bytes: 100,
            total_transfer_bytes_increment: 100,
            total_transfer_bytes_completed: 50,
            total_transfer_bytes_completion_increment: 50,
            total_bytes: 200,
            total_bytes_increment: 200,
            total_bytes_completed: 100,
            total_bytes_completion_increment: 100,
            ..Default::default()
        };
        wrapper.register_updates(halfway).await;

        // Second update: both files reach their totals.
        let finished = ProgressUpdate {
            item_updates: vec![item_update(file_a, 100, 100, 50), item_update(file_b, 200, 200, 100)],
            total_transfer_bytes: 150,
            total_transfer_bytes_increment: 50,
            total_transfer_bytes_completed: 150,
            total_transfer_bytes_completion_increment: 100,
            total_bytes: 200,
            total_bytes_increment: 0,
            total_bytes_completed: 200,
            total_bytes_completion_increment: 100,
            ..Default::default()
        };
        wrapper.register_updates(finished).await;

        wrapper.assert_complete().await;

        let final_updates = logger.all_updates.lock().await;
        assert_eq!(final_updates.len(), 4);
    }
}