use crate::{
distributed_sharding::{ShardConfig, ShardStrategy},
error_taxonomy::helpers as error_helpers,
Dataset,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, RwLock};
use tenflowers_core::{Result, Tensor, TensorError};
/// Configuration for distributed streaming of a dataset across workers.
///
/// Construct with [`StreamingConfig::new`], customize via the `with_*`
/// builder methods, and check invariants with [`StreamingConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Total number of participating workers; must be > 0 (enforced by `validate`).
    pub world_size: usize,
    /// This worker's index; must be < `world_size` (enforced by `validate`).
    pub rank: usize,
    /// Strategy used to split the dataset across workers.
    pub partition_strategy: PartitionStrategy,
    /// Capacity of the prefetch buffer (unit presumably samples — confirm with consumer).
    pub prefetch_buffer_size: usize,
    /// Optional shuffle seed; `None` presumably disables seeded shuffling — confirm.
    pub shuffle_seed: Option<u64>,
    /// Optional checkpoint interval; unit not shown here (likely samples/steps — confirm).
    pub checkpoint_interval: Option<usize>,
    /// Whether fault tolerance is enabled (set by `with_fault_tolerance`).
    pub fault_tolerant: bool,
    /// Replication count for fault tolerance; must not exceed `world_size`
    /// (enforced by `validate`).
    pub replication_factor: usize,
    /// Whether dynamic load balancing is enabled.
    pub dynamic_balancing: bool,
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
world_size: 1,
rank: 0,
partition_strategy: PartitionStrategy::HashBased {
num_partitions: 1,
hash_seed: 0,
},
prefetch_buffer_size: 128,
shuffle_seed: None,
checkpoint_interval: Some(1000),
fault_tolerant: false,
replication_factor: 1,
dynamic_balancing: false,
}
}
}
impl StreamingConfig {
    /// Creates a configuration for worker `rank` out of `world_size` workers,
    /// with every other field taken from [`Default`].
    ///
    /// # Errors
    /// Returns an invalid-configuration error when `world_size == 0` or when
    /// `rank >= world_size`.
    pub fn new(world_size: usize, rank: usize) -> Result<Self> {
        if world_size == 0 {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::new",
                "world_size",
                "world_size must be > 0",
            ));
        }
        if rank >= world_size {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::new",
                "rank",
                format!("rank {} must be < world_size {}", rank, world_size),
            ));
        }
        Ok(Self {
            world_size,
            rank,
            ..Default::default()
        })
    }
    /// Sets the partition strategy (builder style).
    pub fn with_partition_strategy(mut self, strategy: PartitionStrategy) -> Self {
        self.partition_strategy = strategy;
        self
    }
    /// Sets the prefetch buffer capacity (builder style).
    pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
        self.prefetch_buffer_size = size;
        self
    }
    /// Sets the shuffle seed (builder style).
    pub fn with_shuffle_seed(mut self, seed: u64) -> Self {
        self.shuffle_seed = Some(seed);
        self
    }
    /// Enables checkpointing at the given interval (builder style).
    pub fn with_checkpointing(mut self, interval: usize) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }
    /// Enables fault tolerance with the given replication factor (builder
    /// style). A factor of 0 is nonsensical and is rejected by `validate`.
    pub fn with_fault_tolerance(mut self, replication_factor: usize) -> Self {
        self.fault_tolerant = true;
        self.replication_factor = replication_factor;
        self
    }
    /// Enables or disables dynamic load balancing (builder style).
    pub fn with_dynamic_balancing(mut self, enabled: bool) -> Self {
        self.dynamic_balancing = enabled;
        self
    }
    /// Checks the internal consistency of the configuration.
    ///
    /// # Errors
    /// Returns an invalid-configuration error when:
    /// - `world_size == 0`,
    /// - `rank >= world_size`,
    /// - `replication_factor > world_size`, or
    /// - fault tolerance is enabled while `replication_factor == 0`.
    pub fn validate(&self) -> Result<()> {
        if self.world_size == 0 {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::validate",
                "world_size",
                "world_size must be > 0",
            ));
        }
        if self.rank >= self.world_size {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::validate",
                "rank",
                format!(
                    "rank {} must be < world_size {}",
                    self.rank, self.world_size
                ),
            ));
        }
        if self.replication_factor > self.world_size {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::validate",
                "replication_factor",
                format!(
                    "replication_factor {} cannot exceed world_size {}",
                    self.replication_factor, self.world_size
                ),
            ));
        }
        // Bug fix: a fault-tolerant stream with zero replicas would have no
        // copies to recover from; `with_fault_tolerance(0)` could previously
        // produce a config that passed validation.
        if self.fault_tolerant && self.replication_factor == 0 {
            return Err(error_helpers::invalid_configuration(
                "StreamingConfig::validate",
                "replication_factor",
                "replication_factor must be >= 1 when fault_tolerant is enabled",
            ));
        }
        Ok(())
    }
}
/// Strategy for splitting a dataset across distributed workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// Assign samples to workers in round-robin order.
    RoundRobin,
    /// Each worker takes a contiguous slice of the dataset.
    Contiguous,
    /// Assign samples by hashing into `num_partitions` buckets with `hash_seed`.
    HashBased {
        num_partitions: usize,
        hash_seed: u64,
    },
    /// Explicit index ranges, one `(start, end)` pair per partition.
    RangeBased { ranges: Vec<(usize, usize)> },
    /// Class-balanced partitioning over `num_classes` classes
    /// (presumably preserving per-class proportions — confirm with the consumer).
    Stratified { num_classes: usize },
    /// Wraps `base_strategy` and rebalances when load imbalance crosses
    /// `rebalance_threshold` — presumably; the rebalance logic lives elsewhere.
    Adaptive {
        base_strategy: Box<PartitionStrategy>,
        rebalance_threshold: f64,
    },
    /// User-defined strategy referenced by an opaque id string.
    Custom { partition_id: String },
}
impl Default for PartitionStrategy {
fn default() -> Self {
Self::RoundRobin
}
}
/// Serializable snapshot of a worker's streaming progress, used to resume
/// after a restart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointState {
    /// Epoch at the time of the checkpoint.
    pub epoch: usize,
    /// Position within the stream (unit not shown here — presumably a sample
    /// index; confirm with the producer of this state).
    pub position: usize,
    /// Shuffle seed in effect when the checkpoint was taken, if any.
    pub shuffle_seed: Option<u64>,
    /// Rank of the worker that produced this checkpoint.
    pub rank: usize,
    /// Timestamp of the checkpoint (epoch unit/format not shown here — confirm).
    pub timestamp: u64,
    /// Indices already processed, so resumed runs can skip them.
    pub processed_indices: HashSet<usize>,
}
/// Aggregate counters describing streaming activity on one worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingStats {
    /// Total samples loaded.
    pub samples_loaded: u64,
    /// Samples served from local storage.
    pub local_samples: u64,
    /// Samples fetched from remote workers.
    pub remote_samples: u64,
    /// Loads satisfied by the prefetch buffer.
    pub prefetch_hits: u64,
    /// Loads that missed the prefetch buffer.
    pub prefetch_misses: u64,
    /// Average load latency in microseconds.
    pub avg_load_time_us: u64,
    /// Number of checkpoints written.
    pub num_checkpoints: u64,
    /// Worker utilization ratio (presumably 0.0..=1.0 — confirm with producer).
    pub worker_utilization: f64,
}
impl Default for StreamingStats {
fn default() -> Self {
Self {
samples_loaded: 0,
local_samples: 0,
remote_samples: 0,
prefetch_hits: 0,
prefetch_misses: 0,
avg_load_time_us: 0,
num_checkpoints: 0,
worker_utilization: 0.0,
}
}
}
/// Health report for a single worker, as seen by the monitoring side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerHealth {
    /// Rank of the worker this report describes.
    pub rank: usize,
    /// Current status classification.
    pub status: WorkerStatus,
    /// Time of the last heartbeat (epoch unit/format not shown here — confirm).
    pub last_heartbeat: u64,
    /// Total samples the worker has processed.
    pub samples_processed: u64,
    /// Average throughput (unit presumably samples/sec, matching
    /// `WorkerMetrics::throughput_samples_per_sec` — confirm).
    pub average_throughput: f64,
}
/// Coarse status classification of a worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkerStatus {
    /// Worker is processing normally.
    Active,
    /// Worker is up but not currently processing.
    Idle,
    /// Worker is responding but below expected throughput
    /// (classification criteria live elsewhere — confirm).
    Slow,
    /// Worker is considered failed.
    Failed,
    /// Status could not be determined.
    Unknown,
}
/// Point-in-time performance metrics for a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetrics {
    /// Rank of the worker these metrics describe.
    pub rank: usize,
    /// Measured throughput in samples per second.
    pub throughput_samples_per_sec: f64,
    /// Number of items pending in the worker's queue.
    pub queue_depth: usize,
    /// CPU utilization (presumably a 0.0..=1.0 or percentage value — confirm).
    pub cpu_utilization: f64,
    /// Memory usage in megabytes.
    pub memory_usage_mb: f64,
}