//! tenflowers-dataset 0.1.1
//!
//! Data pipeline and dataset utilities for TenfloweRS.
//! Tests for distributed streaming module

#[cfg(test)]
mod tests {
    use crate::distributed_streaming::{
        CheckpointState, PartitionStrategy, StreamCoordinator, StreamingConfig,
        StreamingShardIterator, StreamingShardLoader, StreamingStats,
    };
    use crate::TensorDataset;
    use std::collections::HashSet;
    use std::sync::Arc;
    use tenflowers_core::Tensor;

    /// A valid (world_size, rank) pair should build a config that echoes
    /// both values back unchanged.
    #[test]
    fn test_streaming_config_creation() {
        let world_size = 4;
        let rank = 0;
        let config =
            StreamingConfig::new(world_size, rank).expect("config creation should succeed");
        assert_eq!(config.world_size, world_size);
        assert_eq!(config.rank, rank);
    }

    /// Construction must reject impossible (world_size, rank) combinations
    /// while accepting the boundary-valid one.
    #[test]
    fn test_streaming_config_validation() {
        // A world of zero workers is invalid.
        assert!(matches!(StreamingConfig::new(0, 0), Err(_)));
        // rank must satisfy rank < world_size.
        assert!(matches!(StreamingConfig::new(4, 4), Err(_)));
        // The largest legal rank is world_size - 1.
        assert!(matches!(StreamingConfig::new(4, 3), Ok(_)));
    }

    /// Round-robin partitioning of 100 samples over 4 ranks leaves rank 0
    /// with exactly 100 / 4 = 25 samples.
    #[test]
    fn test_round_robin_partitioning() {
        let n = 100usize;
        let features =
            Tensor::<f32>::from_vec(vec![1.0; n], &[n, 1]).expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; n], &[n]).expect("tensor creation should succeed");

        let config = StreamingConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_partition_strategy(PartitionStrategy::RoundRobin);
        let loader = StreamingShardLoader::new(TensorDataset::new(features, labels), config)
            .expect("loader creation should succeed");

        assert_eq!(loader.len(), n / 4);
    }

    /// Contiguous partitioning assigns each rank one consecutive slice;
    /// rank 1 of 4 over 100 samples gets a 25-sample slice.
    #[test]
    fn test_contiguous_partitioning() {
        let n = 100usize;
        let features =
            Tensor::<f32>::from_vec(vec![1.0; n], &[n, 1]).expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; n], &[n]).expect("tensor creation should succeed");

        let config = StreamingConfig::new(4, 1)
            .expect("config creation should succeed")
            .with_partition_strategy(PartitionStrategy::Contiguous);
        let loader = StreamingShardLoader::new(TensorDataset::new(features, labels), config)
            .expect("loader creation should succeed");

        assert_eq!(loader.len(), n / 4);
    }

    /// Hash-based partitioning distributes samples by key hash, so shard
    /// sizes are data-dependent: rank 0 should receive a non-empty subset
    /// that never exceeds the dataset size.
    #[test]
    fn test_hash_based_partitioning() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
            .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = StreamingConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_partition_strategy(PartitionStrategy::HashBased {
                num_partitions: 4,
                hash_seed: 42,
            });

        let loader =
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed");

        // `!is_empty()` is the idiomatic non-empty check (clippy::len_zero).
        assert!(!loader.is_empty());
        assert!(loader.len() <= 100);
    }

    /// Two loaders configured with the same shuffle seed must derive the
    /// same shard index order — shuffling has to be deterministic so that
    /// all ranks agree on the partition.
    #[test]
    fn test_deterministic_shuffling() {
        let make_dataset = || {
            let features = Tensor::<f32>::from_vec(vec![1.0; 50], &[50, 1])
                .expect("tensor creation should succeed");
            let labels = Tensor::<f32>::from_vec(vec![1.0; 50], &[50])
                .expect("tensor creation should succeed");
            TensorDataset::new(features, labels)
        };
        let make_config = || {
            StreamingConfig::new(2, 0)
                .expect("config creation should succeed")
                .with_shuffle_seed(123)
        };

        let loader1 = StreamingShardLoader::new(make_dataset(), make_config())
            .expect("loader creation should succeed");
        let loader2 = StreamingShardLoader::new(make_dataset(), make_config())
            .expect("loader creation should succeed");

        assert_eq!(loader1.assigned_indices, loader2.assigned_indices);
    }

    /// Consecutive `next` calls on a freshly built loader keep yielding
    /// samples while the shard is not exhausted.
    #[test]
    fn test_streaming_next() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let loader = StreamingShardLoader::new(
            TensorDataset::new(features, labels),
            StreamingConfig::new(2, 0).expect("config creation should succeed"),
        )
        .expect("loader creation should succeed");

        // The shard holds 10 / 2 = 5 samples, so the first two reads succeed.
        for _ in 0..2 {
            let sample = loader.next().expect("next should succeed");
            assert!(sample.is_some());
        }
    }

    /// Prefetching fills the buffer so that a subsequent `next` is served
    /// from it, which is visible as a prefetch hit in the stats.
    #[test]
    fn test_streaming_prefetch() {
        let xs = Tensor::<f32>::from_vec(vec![1.0; 20], &[20, 1])
            .expect("tensor creation should succeed");
        let ys =
            Tensor::<f32>::from_vec(vec![1.0; 20], &[20]).expect("tensor creation should succeed");

        let config = StreamingConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_prefetch_buffer_size(5);
        let loader = StreamingShardLoader::new(TensorDataset::new(xs, ys), config)
            .expect("loader creation should succeed");

        loader.prefetch(5).expect("prefetch should succeed");

        let sample = loader.next().expect("next should succeed");
        assert!(sample.is_some());

        let stats = loader.get_stats().expect("get_stats should succeed");
        assert!(stats.prefetch_hits > 0);
    }

    /// With checkpointing every 5 samples, consuming 6 samples must record
    /// at least one checkpoint in the stats.
    #[test]
    fn test_checkpoint_creation() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 20], &[20, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 20], &[20]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = StreamingConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_checkpointing(5);
        let loader =
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed");

        // Consume past the checkpoint interval. Fail loudly on a `next`
        // error instead of silently discarding it with `let _ = ...`, so a
        // failure points at the faulty call rather than the final assert.
        for _ in 0..6 {
            loader.next().expect("next should succeed");
        }

        let stats = loader.get_stats().expect("get_stats should succeed");
        assert!(stats.num_checkpoints > 0);
    }

    /// A checkpoint taken mid-stream must capture the current position so
    /// that restoring it rewinds the loader, discarding progress made after
    /// the checkpoint.
    #[test]
    fn test_checkpoint_restore() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 20], &[20, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 20], &[20]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = StreamingConfig::new(2, 0).expect("config creation should succeed");
        let loader =
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed");

        // Advance to position 3; propagate errors instead of discarding
        // them so a failure points at the faulty `next` call.
        for _ in 0..3 {
            loader.next().expect("next should succeed");
        }

        let checkpoint = loader
            .get_checkpoint()
            .expect("get_checkpoint should succeed");

        // Advance further; restoring must undo these reads.
        for _ in 0..3 {
            loader.next().expect("next should succeed");
        }

        loader
            .restore_from_checkpoint(checkpoint)
            .expect("restore should succeed");

        let restored_checkpoint = loader
            .get_checkpoint()
            .expect("get_checkpoint should succeed");
        assert_eq!(restored_checkpoint.position, 3);
    }

    /// `reset` rewinds the stream to its beginning, so a partially consumed
    /// loader yields samples again afterwards.
    #[test]
    fn test_stream_reset() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 20], &[20, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 20], &[20]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = StreamingConfig::new(2, 0).expect("config creation should succeed");
        let loader =
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed");

        // Consume half of the 10-sample shard; surface any `next` error
        // rather than ignoring it with `let _ = ...`.
        for _ in 0..5 {
            loader.next().expect("next should succeed");
        }

        loader.reset().expect("reset should succeed");

        let sample = loader.next().expect("next should succeed");
        assert!(sample.is_some());
    }

    /// Building a coordinator from a valid config should not fail.
    #[test]
    fn test_stream_coordinator_creation() {
        let config = StreamingConfig::new(4, 0).expect("config creation should succeed");
        assert!(StreamCoordinator::new(config).is_ok());
    }

    /// Registering a worker makes its health record retrievable, keyed by
    /// the worker's rank.
    #[test]
    fn test_worker_registration() {
        let config = StreamingConfig::new(4, 0).expect("config creation should succeed");
        let coord = StreamCoordinator::new(config).expect("coordinator creation should succeed");

        coord
            .register_worker(0, vec![0, 1, 2, 3, 4])
            .expect("worker registration should succeed");

        let health = coord
            .get_worker_health(0)
            .expect("get_worker_health should succeed");
        assert!(health.is_some());
        assert_eq!(health.expect("health should exist").rank, 0);
    }

    /// After an explicit health update, the stored record must reflect the
    /// reported sample count and throughput.
    #[test]
    fn test_worker_health_update() {
        let config = StreamingConfig::new(4, 0).expect("config creation should succeed");
        let coord = StreamCoordinator::new(config).expect("coordinator creation should succeed");

        coord
            .register_worker(0, vec![])
            .expect("worker registration should succeed");
        coord
            .update_worker_health(0, 100, 50.0)
            .expect("health update should succeed");

        let record = coord
            .get_worker_health(0)
            .expect("get_worker_health should succeed")
            .expect("health should exist");

        assert_eq!(record.samples_processed, 100);
        // Throughput is a float; compare with a small tolerance.
        assert!((record.average_throughput - 50.0).abs() < 1e-6);
    }

    /// A checkpoint registered with the coordinator must be retrievable
    /// intact for the same rank.
    #[test]
    fn test_coordinator_checkpoint_management() {
        let config = StreamingConfig::new(4, 0).expect("config creation should succeed");
        let coord = StreamCoordinator::new(config).expect("coordinator creation should succeed");

        let saved = CheckpointState {
            rank: 0,
            epoch: 1,
            position: 100,
            shuffle_seed: Some(42),
            timestamp: 12345,
            processed_indices: HashSet::new(),
        };
        coord
            .register_checkpoint(0, saved.clone())
            .expect("checkpoint registration should succeed");

        let loaded = coord
            .get_checkpoint(0)
            .expect("get_checkpoint should succeed")
            .expect("checkpoint should exist");

        assert_eq!(loaded.epoch, 1);
        assert_eq!(loaded.position, 100);
    }

    /// The iterator adapter drains the loader's shard, yielding only `Ok`
    /// items and at least one sample.
    #[test]
    fn test_iterator_adapter() {
        let xs = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let ys =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(xs, ys);

        let config = StreamingConfig::new(2, 0).expect("config creation should succeed");
        let loader = Arc::new(
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed"),
        );

        let mut count = 0;
        for result in StreamingShardIterator::new(loader) {
            assert!(result.is_ok());
            count += 1;
        }

        assert!(count > 0);
    }

    /// Streaming over an empty dataset is valid: the shard is empty and
    /// `next` immediately reports exhaustion with `None`.
    #[test]
    fn test_empty_dataset_streaming() {
        let features = Tensor::<f32>::from_vec(Vec::new(), &[0, 1])
            .expect("empty tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(Vec::new(), &[0])
            .expect("empty tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = StreamingConfig::new(2, 0).expect("config creation should succeed");
        let loader =
            StreamingShardLoader::new(dataset, config).expect("loader creation should succeed");

        assert!(loader.is_empty());
        assert_eq!(loader.len(), 0);

        let sample = loader.next().expect("next should succeed");
        assert!(sample.is_none());
    }

    /// Round-robin is the default partition strategy.
    #[test]
    fn test_partition_strategy_default() {
        assert!(matches!(
            PartitionStrategy::default(),
            PartitionStrategy::RoundRobin
        ));
    }

    /// Default stats start with every counter at zero.
    #[test]
    fn test_streaming_stats_default() {
        let StreamingStats {
            samples_loaded,
            prefetch_hits,
            prefetch_misses,
            ..
        } = StreamingStats::default();
        assert_eq!(samples_loaded, 0);
        assert_eq!(prefetch_hits, 0);
        assert_eq!(prefetch_misses, 0);
    }
}