use crate::error::{Result, SklearsError};
use futures_core::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
/// String-backed identifier for a node in the cluster.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(pub String);
impl NodeId {
    /// Builds an identifier from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        NodeId(owned)
    }
    /// Borrows the identifier text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl std::fmt::Display for NodeId {
    /// Writes the raw identifier text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// A message exchanged between cluster nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMessage {
    /// Message identifier; its format is not constrained by this type.
    pub id: String,
    /// Node that sent the message.
    pub sender: NodeId,
    /// Node the message is addressed to.
    pub receiver: NodeId,
    /// Category of the message; presumably tells the receiver how to interpret `payload`.
    pub message_type: MessageType,
    /// Raw message body bytes.
    pub payload: Vec<u8>,
    /// Wall-clock time attached to the message (creation vs. send time is not specified here).
    pub timestamp: SystemTime,
    /// Delivery priority of the message.
    pub priority: MessagePriority,
    /// Retry counter — presumably incremented by the transport on each resend; verify.
    pub retry_count: u32,
}
/// Category of a [`DistributedMessage`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MessageType {
    DataTransfer,
    ParameterSync,
    GradientAggregation,
    Coordination,
    HealthCheck,
    FaultRecovery,
    LoadBalance,
    /// Application-defined message kind identified by name.
    Custom(String),
}
/// Priority of a [`DistributedMessage`].
///
/// The derived ordering follows declaration order, so
/// `Low < Normal < High < Critical`; the explicit discriminants 0–3 agree with it.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MessagePriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
/// Asynchronous messaging between nodes.
///
/// Every operation returns a boxed future, which keeps the trait usable as a trait object
/// (it is stored as `Arc<dyn ClusterNode>` elsewhere in this module).
pub trait MessagePassing: Send + Sync {
    /// Sends `message` to `target`.
    fn send_message(
        &self,
        target: NodeId,
        message: DistributedMessage,
    ) -> BoxFuture<'_, Result<()>>;
    /// Resolves to the next incoming message.
    fn receive_message(&self) -> BoxFuture<'_, Result<DistributedMessage>>;
    /// Sends `message` to the other nodes — presumably every known peer; verify per implementation.
    fn broadcast_message(&self, message: DistributedMessage) -> BoxFuture<'_, Result<()>>;
    /// Sends `message` to `target` and resolves to a reply message.
    fn send_and_receive(
        &self,
        target: NodeId,
        message: DistributedMessage,
    ) -> BoxFuture<'_, Result<DistributedMessage>>;
    /// Reports whether any incoming messages are waiting.
    fn has_pending_messages(&self) -> BoxFuture<'_, Result<bool>>;
    /// Reports how many incoming messages are waiting.
    fn pending_message_count(&self) -> BoxFuture<'_, Result<usize>>;
    /// Flushes any buffered outgoing messages.
    fn flush_outgoing(&self) -> BoxFuture<'_, Result<()>>;
}
/// A cluster member that can exchange messages and report its own state.
///
/// NOTE(review): `DefaultDistributedCluster` holds nodes as `Arc<dyn ClusterNode>`, through
/// which the `&mut self` methods below cannot be called.
pub trait ClusterNode: MessagePassing + Send + Sync {
    /// Identifier of this node.
    fn node_id(&self) -> &NodeId;
    /// Identifiers of the nodes this node knows to be in the cluster.
    fn cluster_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>>;
    /// Whether this node acts as the cluster coordinator.
    fn is_coordinator(&self) -> bool;
    /// Current health snapshot of this node.
    fn health_status(&self) -> BoxFuture<'_, Result<NodeHealth>>;
    /// Hardware resources of this node.
    fn resources(&self) -> BoxFuture<'_, Result<NodeResources>>;
    /// Joins the cluster coordinated by `coordinator`.
    fn join_cluster(&mut self, coordinator: NodeId) -> BoxFuture<'_, Result<()>>;
    /// Leaves the current cluster.
    fn leave_cluster(&mut self) -> BoxFuture<'_, Result<()>>;
    /// Reacts to `failed_node` having been declared failed.
    fn handle_node_failure(&mut self, failed_node: NodeId) -> BoxFuture<'_, Result<()>>;
}
/// Health snapshot reported by a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHealth {
    /// Aggregate health score — scale not defined here (presumably 0.0–1.0; confirm).
    pub health_score: f64,
    /// CPU usage — unit (fraction vs. percent) not defined here; confirm with producers.
    pub cpu_usage: f64,
    /// Memory usage — unit (fraction vs. percent) not defined here; confirm with producers.
    pub memory_usage: f64,
    /// Observed network latency.
    pub network_latency: Duration,
    /// Time of the last heartbeat.
    pub last_heartbeat: SystemTime,
    /// Count of recent errors; the window is not defined by this type.
    pub recent_errors: u32,
    /// Time the node has been running.
    pub uptime: Duration,
}
/// Hardware resources advertised by a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeResources {
    /// Number of CPU cores.
    pub cpu_cores: u32,
    /// Total memory — presumably in bytes; confirm.
    pub total_memory: u64,
    /// Memory currently available — same unit as `total_memory`.
    pub available_memory: u64,
    /// GPUs attached to the node.
    pub gpu_devices: Vec<GpuDevice>,
    /// Network bandwidth — unit not defined here; confirm.
    pub network_bandwidth: u64,
    /// Storage capacity — presumably in bytes; confirm.
    pub storage_capacity: u64,
    /// Free-form key/value labels for the node.
    pub tags: HashMap<String, String>,
}
/// Description of one GPU attached to a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
    /// Device index on the node.
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Total device memory — presumably in bytes; confirm.
    pub total_memory: u64,
    /// Currently available device memory — same unit as `total_memory`.
    pub available_memory: u64,
    /// Compute capability as a string (e.g. a vendor version number).
    pub compute_capability: String,
}
/// An estimator whose training and inference run against a [`DistributedCluster`].
pub trait DistributedEstimator: Send + Sync {
    /// Input accepted by `fit_distributed`.
    type TrainingData;
    /// Input accepted by `predict_distributed`.
    type PredictionInput;
    /// Output produced by `predict_distributed`.
    type PredictionOutput;
    /// Model parameters; the serde bounds let them be serialized (presumably for exchange
    /// between nodes).
    type Parameters: Serialize + for<'de> Deserialize<'de>;
    /// Trains the model using `cluster`.
    ///
    /// `training_data` is not bound to `'a`, so the returned future cannot borrow it;
    /// implementations must copy whatever they need before building the future.
    fn fit_distributed<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        training_data: &Self::TrainingData,
    ) -> BoxFuture<'a, Result<()>>;
    /// Produces predictions for `input`.
    ///
    /// `cluster` is not bound to `'a`, so the returned future cannot use it.
    fn predict_distributed<'a>(
        &'a self,
        cluster: &dyn DistributedCluster,
        input: &'a Self::PredictionInput,
    ) -> BoxFuture<'a, Result<Self::PredictionOutput>>;
    /// Returns the current model parameters.
    fn get_parameters(&self) -> Result<Self::Parameters>;
    /// Replaces the model parameters.
    fn set_parameters(&mut self, params: Self::Parameters) -> Result<()>;
    /// Synchronizes parameters with the cluster.
    ///
    /// The future's lifetime is tied to `self` (elision), so it cannot hold on to `cluster`.
    fn sync_parameters(&mut self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;
    /// Returns a snapshot of training progress.
    fn training_progress(&self) -> DistributedTrainingProgress;
}
/// Progress snapshot of a distributed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTrainingProgress {
    /// Current epoch index.
    pub epoch: u32,
    /// Number of epochs planned for the run.
    pub total_epochs: u32,
    /// Most recent training loss.
    pub training_loss: f64,
    /// Most recent validation loss, if validation is performed.
    pub validation_loss: Option<f64>,
    /// Number of samples processed so far.
    pub samples_processed: u64,
    /// When the run started.
    pub start_time: SystemTime,
    /// Estimated completion time, if known.
    pub estimated_completion: Option<SystemTime>,
    /// Nodes participating in the run.
    pub active_nodes: Vec<NodeId>,
    /// Per-node statistics keyed by node.
    pub node_statistics: HashMap<NodeId, NodeTrainingStats>,
}
/// Training statistics reported for one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTrainingStats {
    /// Samples this node has processed.
    pub samples_processed: u64,
    /// Processing rate — presumably samples per second; confirm.
    pub processing_rate: f64,
    /// Latest loss observed on this node.
    pub current_loss: f64,
    /// Memory usage — presumably in bytes; confirm.
    pub memory_usage: u64,
    /// CPU utilization — unit (fraction vs. percent) not defined here.
    pub cpu_utilization: f64,
}
/// Cluster-level view: membership, configuration, health and checkpointing.
pub trait DistributedCluster: Send + Sync {
    /// Identifiers of the nodes currently considered active.
    fn active_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>>;
    /// Identifier of the coordinating node.
    fn coordinator(&self) -> &NodeId;
    /// Configuration the cluster runs with.
    fn configuration(&self) -> &ClusterConfiguration;
    /// Adds `node` to the cluster.
    fn add_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>>;
    /// Removes `node` from the cluster.
    fn remove_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>>;
    /// Redistributes work across nodes.
    fn rebalance_load(&mut self) -> BoxFuture<'_, Result<()>>;
    /// Aggregated health snapshot of the cluster.
    fn cluster_health(&self) -> BoxFuture<'_, Result<ClusterHealth>>;
    /// Captures a checkpoint of the cluster.
    fn create_checkpoint(&self) -> BoxFuture<'_, Result<ClusterCheckpoint>>;
    /// Restores the cluster from `checkpoint`.
    fn restore_checkpoint(&mut self, checkpoint: ClusterCheckpoint) -> BoxFuture<'_, Result<()>>;
}
/// Tunable cluster-wide settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfiguration {
    /// Maximum number of nodes in the cluster.
    pub max_nodes: u32,
    /// Interval between heartbeats.
    pub heartbeat_interval: Duration,
    /// Silence after which a node is presumably treated as failed; confirm with the detector.
    pub failure_timeout: Duration,
    /// Maximum number of retries for an operation.
    pub max_retries: u32,
    /// How work is distributed across nodes.
    pub load_balancing: LoadBalancingStrategy,
    /// Fault-tolerance policy.
    pub fault_tolerance: FaultToleranceMode,
    /// Consistency guarantee requested for shared state.
    pub consistency_level: ConsistencyLevel,
}
/// Policy used to distribute work across nodes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    ResourceBased,
    LoadBased,
    LocalityAware,
    /// Named, application-defined strategy.
    Custom(String),
}
/// Policy for surviving node failures.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FaultToleranceMode {
    None,
    BasicRetry,
    CheckpointRecovery,
    RedundantComputation,
    Byzantine,
}
/// Consistency level requested for replicated or shared state.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConsistencyLevel {
    None,
    Eventual,
    Strong,
    Causal,
    Sequential,
}
/// Aggregated health snapshot of the whole cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterHealth {
    /// Overall health score — scale presumably 0.0–1.0 (the default cluster starts at 1.0).
    pub overall_health: f64,
    /// Number of nodes considered healthy.
    pub healthy_nodes: u32,
    /// Number of nodes considered failed.
    pub failed_nodes: u32,
    /// Average response time across nodes.
    pub average_response_time: Duration,
    /// Aggregate throughput — unit not defined here.
    pub total_throughput: f64,
    /// Cluster-wide resource utilization.
    pub resource_utilization: ClusterResourceUtilization,
}
/// Cluster-wide utilization figures — unit (fraction vs. percent) not defined here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterResourceUtilization {
    pub cpu_utilization: f64,
    pub memory_utilization: f64,
    pub network_utilization: f64,
    pub storage_utilization: f64,
}
/// Serializable snapshot of cluster state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterCheckpoint {
    /// Identifier of the checkpoint.
    pub checkpoint_id: String,
    /// When the checkpoint was taken.
    pub timestamp: SystemTime,
    /// Cluster configuration at checkpoint time.
    pub configuration: ClusterConfiguration,
    /// Per-node checkpoints keyed by node.
    pub node_states: HashMap<NodeId, NodeCheckpoint>,
    /// Opaque serialized cluster-level state.
    pub cluster_state: Vec<u8>,
}
/// Checkpointed state of a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCheckpoint {
    /// Node the state belongs to.
    pub node_id: NodeId,
    /// Opaque serialized node state.
    pub state_data: Vec<u8>,
    /// Health snapshot at checkpoint time.
    pub health: NodeHealth,
    /// Resources at checkpoint time.
    pub resources: NodeResources,
}
/// A dataset that can be split into partitions assigned to cluster nodes.
pub trait DistributedDataset: Send + Sync {
    /// Element type stored in partitions.
    type Item;
    /// Strategy type selecting how items are split.
    type PartitionStrategy;
    /// Total number of items in the dataset.
    fn size(&self) -> u64;
    /// Number of partitions currently held.
    fn partition_count(&self) -> u32;
    /// Splits the dataset across `cluster` using `strategy` and returns the partitions.
    fn partition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<Vec<DistributedPartition<Self::Item>>>>;
    /// Returns the partition with id `partition_id`.
    fn get_partition(
        &self,
        partition_id: u32,
    ) -> BoxFuture<'_, Result<DistributedPartition<Self::Item>>>;
    /// Rebuilds the partitioning with `new_strategy`.
    fn repartition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        new_strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<()>>;
    /// Gathers the items of all partitions.
    ///
    /// The future's lifetime is tied to `self` (elision), so it cannot hold on to `cluster`.
    fn collect(&self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<Vec<Self::Item>>>;
    /// Map from node to the ids of the partitions assigned to it.
    fn partition_assignment(&self) -> HashMap<NodeId, Vec<u32>>;
}
/// A slice of a dataset assigned to one node.
#[derive(Debug, Clone)]
pub struct DistributedPartition<T> {
    /// Partition identifier.
    pub partition_id: u32,
    /// Node that owns the partition.
    pub node_id: NodeId,
    /// Items in the partition.
    pub data: Vec<T>,
    /// Descriptive metadata.
    pub metadata: PartitionMetadata,
}
/// Metadata describing a partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionMetadata {
    /// Number of items in the partition.
    pub item_count: u64,
    /// Approximate payload size in bytes, as computed by the producer.
    pub size_bytes: u64,
    /// Optional schema label for the items.
    pub schema: Option<String>,
    /// Creation time.
    pub created_at: SystemTime,
    /// Last modification time.
    pub modified_at: SystemTime,
    /// Content checksum; the algorithm is chosen by the producer.
    pub checksum: String,
}
/// Built-in strategies for splitting a dataset across nodes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PartitioningStrategy {
    EvenSplit,
    /// Hash-based placement; the value is the hash seed.
    HashBased(u32),
    RangeBased,
    Random,
    Stratified,
    /// Named custom strategy; `DistributedNumericalDataset` recognizes round-robin aliases.
    Custom(String),
}
/// Central store for model parameters that workers push to and pull from.
pub trait ParameterServer: Send + Sync {
    /// Parameter representation; serde bounds allow it to be serialized.
    type Parameters: Serialize + for<'de> Deserialize<'de>;
    /// Seeds the server with `initial_params`.
    fn initialize(&mut self, initial_params: Self::Parameters) -> BoxFuture<'_, Result<()>>;
    /// Returns the current parameters.
    fn get_parameters(&self) -> BoxFuture<'_, Result<Self::Parameters>>;
    /// Updates the parameters from a batch of gradients (represented as `Parameters`).
    fn update_parameters(&mut self, gradients: Vec<Self::Parameters>) -> BoxFuture<'_, Result<()>>;
    /// Pushes parameters to the cluster.
    ///
    /// The future's lifetime is tied to `self` (elision), so it cannot hold on to `cluster`.
    fn push_parameters(&self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;
    /// Pulls parameters from the cluster.
    ///
    /// The future's lifetime is tied to `self` (elision), so it cannot hold on to `cluster`.
    fn pull_parameters(&mut self, cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>>;
    /// Combines several gradients into one.
    fn aggregate_gradients(
        &mut self,
        gradients: Vec<Self::Parameters>,
    ) -> BoxFuture<'_, Result<Self::Parameters>>;
    /// Applies an optimizer step using `aggregated_gradients`.
    fn apply_optimization(
        &mut self,
        aggregated_gradients: Self::Parameters,
    ) -> BoxFuture<'_, Result<()>>;
}
/// Scheme for combining gradients from several workers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GradientAggregation {
    Average,
    WeightedAverage,
    FederatedAveraging,
    ByzantineRobust,
    Compressed,
}
/// Failure detection, recovery and data-integrity services for a cluster.
///
/// NOTE(review): every returned future is tied to `self` only (lifetime elision), so it
/// cannot use the `cluster` argument — including the `&mut dyn DistributedCluster` passed
/// to the recovery methods. Any cluster work must happen before the future is returned.
pub trait FaultTolerance: Send + Sync {
    /// Returns the nodes considered failed.
    fn detect_failure(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<Vec<NodeId>>>;
    /// Recovers the cluster after `failed_nodes` failed.
    fn recover_from_failure(
        &mut self,
        cluster: &mut dyn DistributedCluster,
        failed_nodes: Vec<NodeId>,
    ) -> BoxFuture<'_, Result<()>>;
    /// Captures a fault-tolerance checkpoint.
    fn create_checkpoint(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<FaultToleranceCheckpoint>>;
    /// Restores state from `checkpoint`.
    fn restore_checkpoint(
        &mut self,
        cluster: &mut dyn DistributedCluster,
        checkpoint: FaultToleranceCheckpoint,
    ) -> BoxFuture<'_, Result<()>>;
    /// Replicates `data` across the cluster.
    fn replicate_data(
        &self,
        cluster: &dyn DistributedCluster,
        data: Vec<u8>,
    ) -> BoxFuture<'_, Result<()>>;
    /// Checks data and parameter consistency and reports findings.
    fn validate_integrity(
        &self,
        cluster: &dyn DistributedCluster,
    ) -> BoxFuture<'_, Result<IntegrityReport>>;
}
/// Snapshot used to recover training after failures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceCheckpoint {
    /// Checkpoint identifier.
    pub id: String,
    /// When the checkpoint was taken.
    pub timestamp: SystemTime,
    /// Opaque serialized training state.
    pub training_state: Vec<u8>,
    /// Opaque serialized model parameters.
    pub model_parameters: Vec<u8>,
    /// Map from node to the partition ids assigned to it.
    pub node_assignments: HashMap<NodeId, Vec<u32>>,
    /// Map from a replicated item key to the nodes holding a copy.
    pub replication_map: HashMap<String, Vec<NodeId>>,
}
/// Result of an integrity validation pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityReport {
    /// Overall integrity score — scale not defined here.
    pub integrity_score: f64,
    /// Whether data was found consistent.
    pub data_consistency: bool,
    /// Whether parameters were found in sync.
    pub parameter_sync: bool,
    /// Replication health score — scale not defined here.
    pub replication_health: f64,
    /// Human-readable descriptions of detected inconsistencies.
    pub inconsistencies: Vec<String>,
    /// Human-readable remediation suggestions.
    pub recommendations: Vec<String>,
}
/// In-process [`DistributedCluster`] that keeps node handles in a lock-protected map.
pub struct DefaultDistributedCluster {
    configuration: ClusterConfiguration,
    coordinator: NodeId,
    nodes: Arc<RwLock<HashMap<NodeId, Arc<dyn ClusterNode>>>>,
    health_monitor: Arc<RwLock<ClusterHealth>>,
}
impl std::fmt::Debug for DefaultDistributedCluster {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn ClusterNode` has no `Debug` bound, so the node map is shown as a fixed placeholder.
        let mut builder = f.debug_struct("DefaultDistributedCluster");
        builder.field("configuration", &self.configuration);
        builder.field("coordinator", &self.coordinator);
        builder.field("nodes", &"<HashMap<NodeId, Arc<dyn ClusterNode>>>");
        builder.field("health_monitor", &self.health_monitor);
        builder.finish()
    }
}
impl DefaultDistributedCluster {
    /// Creates a cluster with no registered nodes, led by `coordinator`.
    ///
    /// The initial health record has `overall_health` 1.0, zero node counts, zero throughput,
    /// zero utilization and a 10 ms average response time.
    pub fn new(coordinator: NodeId, configuration: ClusterConfiguration) -> Self {
        let idle = ClusterResourceUtilization {
            cpu_utilization: 0.0,
            memory_utilization: 0.0,
            network_utilization: 0.0,
            storage_utilization: 0.0,
        };
        let initial_health = ClusterHealth {
            overall_health: 1.0,
            healthy_nodes: 0,
            failed_nodes: 0,
            average_response_time: Duration::from_millis(10),
            total_throughput: 0.0,
            resource_utilization: idle,
        };
        DefaultDistributedCluster {
            configuration,
            coordinator,
            nodes: Arc::new(RwLock::new(HashMap::new())),
            health_monitor: Arc::new(RwLock::new(initial_health)),
        }
    }
}
impl DistributedCluster for DefaultDistributedCluster {
    /// Lists the identifiers of every node registered in the node map.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the node map's lock is poisoned.
    fn active_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>> {
        Box::pin(async move {
            let nodes = self.nodes.read().map_err(|_| {
                SklearsError::InvalidOperation("Failed to acquire read lock on nodes".to_string())
            })?;
            Ok(nodes.keys().cloned().collect())
        })
    }
    fn coordinator(&self) -> &NodeId {
        &self.coordinator
    }
    fn configuration(&self) -> &ClusterConfiguration {
        &self.configuration
    }
    /// Currently a no-op that reports success.
    ///
    /// NOTE(review): the node map stores `Arc<dyn ClusterNode>` handles, which a bare `NodeId`
    /// cannot supply, so nothing is registered here.
    fn add_node(&mut self, _node_id: NodeId) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move { Ok(()) })
    }
    /// Removes `node_id` from the node map; removing an unknown node is not an error.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the node map's lock is poisoned.
    fn remove_node(&mut self, node_id: NodeId) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            let mut nodes = self.nodes.write().map_err(|_| {
                SklearsError::InvalidOperation("Failed to acquire write lock on nodes".to_string())
            })?;
            nodes.remove(&node_id);
            Ok(())
        })
    }
    /// Placeholder: no rebalancing is performed.
    fn rebalance_load(&mut self) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move { Ok(()) })
    }
    /// Returns a copy of the stored health record (nothing in this impl updates it after `new`).
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the health lock is poisoned.
    fn cluster_health(&self) -> BoxFuture<'_, Result<ClusterHealth>> {
        Box::pin(async move {
            let health = self.health_monitor.read().map_err(|_| {
                SklearsError::InvalidOperation(
                    "Failed to acquire read lock on health monitor".to_string(),
                )
            })?;
            Ok(health.clone())
        })
    }
    /// Records the current configuration; node and cluster state are left empty.
    ///
    /// The id embeds the Unix time in whole seconds, so checkpoints taken within the same
    /// second share an id.
    fn create_checkpoint(&self) -> BoxFuture<'_, Result<ClusterCheckpoint>> {
        Box::pin(async move {
            let checkpoint = ClusterCheckpoint {
                checkpoint_id: format!("checkpoint_{}", chrono::Utc::now().timestamp()),
                timestamp: SystemTime::now(),
                configuration: self.configuration.clone(),
                node_states: HashMap::new(),
                cluster_state: Vec::new(),
            };
            Ok(checkpoint)
        })
    }
    /// Restores the cluster configuration captured in `checkpoint`.
    ///
    /// Only the configuration is restored: `create_checkpoint` records no node or cluster
    /// state, and node handles cannot be rebuilt from serialized bytes.
    fn restore_checkpoint(&mut self, checkpoint: ClusterCheckpoint) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move {
            // Previously the checkpoint was discarded entirely, making restore a silent no-op.
            self.configuration = checkpoint.configuration;
            Ok(())
        })
    }
}
impl Default for ClusterConfiguration {
fn default() -> Self {
Self {
max_nodes: 64,
heartbeat_interval: Duration::from_secs(30),
failure_timeout: Duration::from_secs(120),
max_retries: 3,
load_balancing: LoadBalancingStrategy::ResourceBased,
fault_tolerance: FaultToleranceMode::CheckpointRecovery,
consistency_level: ConsistencyLevel::Eventual,
}
}
}
/// Linear regression model exposed through the [`DistributedEstimator`] interface.
///
/// `parameters` is `None` until the model is fitted or parameters are set explicitly.
#[derive(Debug)]
pub struct DistributedLinearRegression {
    parameters: Option<Vec<f64>>,
    config: DistributedTrainingConfig,
    progress: DistributedTrainingProgress,
}
/// Training hyper-parameters for [`DistributedLinearRegression`].
#[derive(Debug, Clone)]
pub struct DistributedTrainingConfig {
    /// Step size used when updating parameters.
    pub learning_rate: f64,
    /// Number of training epochs.
    pub epochs: u32,
    /// Requested batch size.
    pub batch_size: u32,
    /// Gradient aggregation scheme to use across nodes.
    pub aggregation: GradientAggregation,
    /// Checkpoint interval, in epochs.
    pub checkpoint_frequency: u32,
}
impl Default for DistributedLinearRegression {
    fn default() -> Self {
        Self::new()
    }
}
impl DistributedLinearRegression {
    /// Creates an untrained model with the default training configuration and zeroed progress.
    pub fn new() -> Self {
        let progress = DistributedTrainingProgress {
            epoch: 0,
            total_epochs: 0,
            training_loss: 0.0,
            validation_loss: None,
            samples_processed: 0,
            start_time: SystemTime::now(),
            estimated_completion: None,
            active_nodes: Vec::new(),
            node_statistics: HashMap::new(),
        };
        let config = DistributedTrainingConfig::default();
        Self {
            parameters: None,
            config,
            progress,
        }
    }
    /// Builder-style setter that replaces the training configuration.
    pub fn with_config(self, config: DistributedTrainingConfig) -> Self {
        Self { config, ..self }
    }
}
impl Default for DistributedTrainingConfig {
    /// Learning rate 0.01, 100 epochs, batch size 32, plain averaging, checkpoint frequency 10.
    fn default() -> Self {
        DistributedTrainingConfig {
            learning_rate: 0.01,
            epochs: 100,
            batch_size: 32,
            aggregation: GradientAggregation::Average,
            checkpoint_frequency: 10,
        }
    }
}
impl DistributedEstimator for DistributedLinearRegression {
    /// Feature rows (one per sample) paired with one regression target per row.
    type TrainingData = (Vec<Vec<f64>>, Vec<f64>);
    /// Feature rows to score.
    type PredictionInput = Vec<Vec<f64>>;
    /// One prediction per input row.
    type PredictionOutput = Vec<f64>;
    /// Feature weights followed by the intercept as the final element.
    type Parameters = Vec<f64>;

    /// Fits the model with mini-batch gradient descent on the mean squared error.
    ///
    /// Training runs locally; `_cluster` is accepted for interface compatibility. Existing
    /// parameters of the right length are used as a warm start, otherwise all weights and the
    /// intercept start at zero. A `batch_size` of 0 means one full-dataset batch per step.
    /// `training_loss` is the mean squared error accumulated over each epoch (measured before
    /// each batch's update). Checkpointing (`checkpoint_frequency`) is not performed.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` when the data is empty, when the number of
    /// targets differs from the number of rows, when rows have differing lengths, or when
    /// existing parameters do not match the feature count.
    fn fit_distributed<'a>(
        &'a mut self,
        _cluster: &'a dyn DistributedCluster,
        training_data: &Self::TrainingData,
    ) -> BoxFuture<'a, Result<()>> {
        // `training_data` is not tied to `'a`, so the future needs its own copy.
        let (x, y) = training_data.clone();
        Box::pin(async move {
            if x.is_empty() {
                return Err(SklearsError::InvalidOperation(
                    "Training data is empty".to_string(),
                ));
            }
            if x.len() != y.len() {
                return Err(SklearsError::InvalidOperation(format!(
                    "Training data has {} feature rows but {} targets",
                    x.len(),
                    y.len()
                )));
            }
            let feature_count = x[0].len();
            if let Some(bad_row) = x.iter().position(|row| row.len() != feature_count) {
                return Err(SklearsError::InvalidOperation(format!(
                    "Row {bad_row} has {} features but row 0 has {feature_count}",
                    x[bad_row].len()
                )));
            }

            let expected_len = feature_count + 1;
            let mut params = match self.parameters.as_ref() {
                Some(existing) if existing.len() != expected_len => {
                    return Err(SklearsError::InvalidOperation(format!(
                        "Existing parameters have length {} but {expected_len} are required for {feature_count} features",
                        existing.len()
                    )));
                }
                Some(existing) => existing.clone(),
                None => vec![0.0; expected_len],
            };

            let learning_rate = self.config.learning_rate;
            let epochs = self.config.epochs;
            let batch_size = if self.config.batch_size == 0 {
                x.len()
            } else {
                self.config.batch_size as usize
            };
            let sample_count = x.len() as f64;

            self.progress.total_epochs = epochs;
            self.progress.start_time = SystemTime::now();
            self.progress.active_nodes = Vec::new();

            let mut gradient = vec![0.0; expected_len];
            for epoch in 0..epochs {
                self.progress.epoch = epoch;
                let mut squared_error_sum = 0.0;
                for (rows, targets) in x.chunks(batch_size).zip(y.chunks(batch_size)) {
                    gradient.fill(0.0);
                    for (row, &target) in rows.iter().zip(targets) {
                        let (weights, intercept) = params.split_at(feature_count);
                        let predicted = row
                            .iter()
                            .zip(weights)
                            .fold(intercept[0], |acc, (feature, weight)| acc + feature * weight);
                        let residual = predicted - target;
                        squared_error_sum += residual * residual;
                        for (slot, &feature) in gradient.iter_mut().zip(row) {
                            *slot += residual * feature;
                        }
                        // The intercept's "feature" is the constant 1.
                        gradient[feature_count] += residual;
                    }
                    // d(MSE)/dθ = (2 / m) · Σ residual · input over the batch.
                    let step = 2.0 * learning_rate / rows.len() as f64;
                    for (param, grad) in params.iter_mut().zip(&gradient) {
                        *param -= step * grad;
                    }
                }
                self.progress.samples_processed += x.len() as u64;
                self.progress.training_loss = squared_error_sum / sample_count;
            }
            self.parameters = Some(params);
            Ok(())
        })
    }

    /// Predicts `intercept + Σ weight · feature` for every input row.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the model has no parameters yet, or if a row's
    /// length differs from the number of learned weights.
    fn predict_distributed<'a>(
        &'a self,
        _cluster: &dyn DistributedCluster,
        input: &'a Self::PredictionInput,
    ) -> BoxFuture<'a, Result<Self::PredictionOutput>> {
        Box::pin(async move {
            let Some(params) = self.parameters.as_deref() else {
                return Err(SklearsError::InvalidOperation(
                    "Model not trained. Call fit_distributed first.".to_string(),
                ));
            };
            // The last parameter is the intercept; only the ones before it are feature weights,
            // so the intercept can never be multiplied into a prediction as a weight.
            let (intercept, weights) = match params.split_last() {
                Some((&intercept, weights)) => (intercept, weights),
                None => (0.0, &[] as &[f64]),
            };
            input
                .iter()
                .enumerate()
                .map(|(row_index, features)| {
                    if features.len() != weights.len() {
                        return Err(SklearsError::InvalidOperation(format!(
                            "Row {row_index} has {} features but the model expects {}",
                            features.len(),
                            weights.len()
                        )));
                    }
                    Ok(features
                        .iter()
                        .zip(weights)
                        .fold(intercept, |acc, (feature, weight)| acc + feature * weight))
                })
                .collect::<Result<Vec<f64>>>()
        })
    }

    /// Returns a copy of the learned parameters.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the model has no parameters yet.
    fn get_parameters(&self) -> Result<Self::Parameters> {
        self.parameters
            .clone()
            .ok_or_else(|| SklearsError::InvalidOperation("Model not trained".to_string()))
    }

    /// Replaces the parameters (weights followed by the intercept) without validation.
    fn set_parameters(&mut self, params: Self::Parameters) -> Result<()> {
        self.parameters = Some(params);
        Ok(())
    }

    /// Placeholder: parameters are not exchanged with the cluster.
    fn sync_parameters(&mut self, _cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<()>> {
        Box::pin(async move { Ok(()) })
    }

    /// Returns a snapshot of the training progress.
    fn training_progress(&self) -> DistributedTrainingProgress {
        self.progress.clone()
    }
}
/// In-memory dataset of numeric rows that can be split across cluster nodes.
///
/// `data` holds the source rows, `partitions` holds per-node copies of those rows, and
/// `assignment` maps each node to the ids of the partitions it received.
#[derive(Debug)]
pub struct DistributedNumericalDataset {
    data: Vec<Vec<f64>>,
    partitions: Vec<DistributedPartition<Vec<f64>>>,
    assignment: HashMap<NodeId, Vec<u32>>,
}
impl DistributedNumericalDataset {
    /// Wraps `data` without partitioning it.
    pub fn new(data: Vec<Vec<f64>>) -> Self {
        let partitions = Vec::new();
        let assignment = HashMap::new();
        Self {
            data,
            partitions,
            assignment,
        }
    }

    /// Appends one partition per non-empty bucket, pairing `buckets[i]` with `nodes[i]`.
    ///
    /// A partition's id is the partition count at the moment it is pushed, so ids keep matching
    /// each partition's index in `self.partitions`.
    fn flush_buckets_into_partitions(&mut self, nodes: &[NodeId], buckets: Vec<Vec<Vec<f64>>>) {
        let occupied = nodes
            .iter()
            .zip(buckets)
            .filter(|(_, rows)| !rows.is_empty());
        for (node_id, rows) in occupied {
            let partition_id = self.partitions.len() as u32;
            let partition = build_numerical_partition(partition_id, node_id, rows);
            self.partitions.push(partition);
            self.assignment
                .entry(node_id.clone())
                .or_default()
                .push(partition_id);
        }
    }
}
/// Folds `bytes` into a running 64-bit FNV-1a state: XOR each byte in, then multiply by the
/// FNV prime (wrapping).
#[inline]
fn fnv1a_mix(hash: &mut u64, bytes: &[u8]) {
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    *hash = bytes
        .iter()
        .fold(*hash, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
}
/// Computes a deterministic FNV-1a digest of `rows`, rendered as 16 lowercase hex digits.
///
/// The row count and every row length are mixed in ahead of the values, so regrouping the same
/// numbers into different rows changes the digest. `-0.0` is hashed as `0.0`; all other values
/// (NaN payloads included) are hashed by their exact bit pattern.
fn numerical_partition_checksum(rows: &[Vec<f64>]) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    let mut digest = FNV_OFFSET_BASIS;
    let row_count = rows.len() as u64;
    fnv1a_mix(&mut digest, &row_count.to_le_bytes());
    rows.iter().for_each(|row| {
        let width = row.len() as u64;
        fnv1a_mix(&mut digest, &width.to_le_bytes());
        for value in row.iter().copied() {
            let canonical = if value == 0.0 { 0.0_f64 } else { value };
            fnv1a_mix(&mut digest, &canonical.to_bits().to_le_bytes());
        }
    });
    format!("{:016x}", digest)
}
/// Hashes a single row with FNV-1a, mixing `seed` in first so the seed perturbs the result.
///
/// `-0.0` is treated as `0.0`; the row length itself is not mixed in.
fn hash_numerical_row(row: &[f64], seed: u32) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    let mut state = FNV_OFFSET_BASIS;
    let seed_bytes = u64::from(seed).to_le_bytes();
    fnv1a_mix(&mut state, &seed_bytes);
    row.iter()
        .map(|&value| if value == 0.0 { 0.0_f64 } else { value })
        .for_each(|canonical| fnv1a_mix(&mut state, &canonical.to_bits().to_le_bytes()));
    state
}
/// Builds the partition `partition_id` for `node_id` from `partition_data`, with fresh metadata.
///
/// `item_count` is the number of rows. `size_bytes` is the `f64` payload of every value in
/// every row; per-row `Vec` headers are not counted. `created_at` and `modified_at` share one
/// timestamp, and the checksum comes from `numerical_partition_checksum`.
fn build_numerical_partition(
    partition_id: u32,
    node_id: &NodeId,
    partition_data: Vec<Vec<f64>>,
) -> DistributedPartition<Vec<f64>> {
    let item_count = partition_data.len() as u64;
    // Rows hold several values each, so size must count values rather than rows.
    let value_count: u64 = partition_data.iter().map(|row| row.len() as u64).sum();
    let size_bytes = value_count * std::mem::size_of::<f64>() as u64;
    let checksum = numerical_partition_checksum(&partition_data);
    let now = SystemTime::now();
    DistributedPartition {
        partition_id,
        node_id: node_id.clone(),
        metadata: PartitionMetadata {
            item_count,
            size_bytes,
            schema: Some("numerical_array".to_string()),
            created_at: now,
            modified_at: now,
            checksum,
        },
        data: partition_data,
    }
}
impl DistributedDataset for DistributedNumericalDataset {
    type Item = Vec<f64>;
    type PartitionStrategy = PartitioningStrategy;

    /// Number of source rows.
    fn size(&self) -> u64 {
        self.data.len() as u64
    }

    /// Number of partitions produced by the most recent successful `partition` call.
    fn partition_count(&self) -> u32 {
        self.partitions.len() as u32
    }

    /// Splits the rows across the cluster's active nodes according to `strategy`.
    ///
    /// Any previous partitioning is replaced on success. Strategies:
    /// - `EvenSplit`: contiguous, equally sized chunks in row order.
    /// - `RangeBased`: rows ordered by their first value (ties keep row order), then chunked.
    /// - `HashBased(seed)`: row goes to node `hash(row, seed) % nodes`.
    /// - `Random`: each row goes to a uniformly random node.
    /// - `Stratified`: rows grouped by their last value, then dealt round-robin so every node
    ///   gets a share of each group.
    /// - `Custom("round_robin" | "roundrobin" | "round-robin")`: row `i` goes to node `i % nodes`.
    ///
    /// Nodes that would receive no rows get no partition; partition ids are contiguous from 0.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if the cluster has no active nodes or the custom
    /// strategy name is not registered; the existing partitioning is then left untouched.
    fn partition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<Vec<DistributedPartition<Self::Item>>>> {
        // Names accepted for the built-in round-robin custom strategy.
        const ROUND_ROBIN_ALIASES: [&str; 3] = ["round_robin", "roundrobin", "round-robin"];
        Box::pin(async move {
            let nodes = cluster.active_nodes().await?;
            let num_nodes = nodes.len();
            if num_nodes == 0 {
                return Err(SklearsError::InvalidOperation(
                    "No active nodes in cluster".to_string(),
                ));
            }
            // Validate before touching state: previously an unknown name cleared the existing
            // partitions first, leaving the dataset unpartitioned after a failed call.
            if let PartitioningStrategy::Custom(name) = &strategy {
                if !ROUND_ROBIN_ALIASES.contains(&name.as_str()) {
                    return Err(SklearsError::InvalidOperation(format!(
                        "Custom partitioning strategy '{name}' is not registered"
                    )));
                }
            }

            // Every strategy produces at most `num_nodes` buckets; `buckets[i]` belongs to `nodes[i]`.
            let buckets: Vec<Vec<Vec<f64>>> = match strategy {
                PartitioningStrategy::EvenSplit => {
                    // `.max(1)` keeps `chunks` valid for an empty dataset.
                    let chunk_size = self.data.len().div_ceil(num_nodes).max(1);
                    self.data
                        .chunks(chunk_size)
                        .map(<[Vec<f64>]>::to_vec)
                        .collect()
                }
                PartitioningStrategy::RangeBased => {
                    let mut order: Vec<usize> = (0..self.data.len()).collect();
                    order.sort_by(|&a, &b| {
                        let key_a = self.data[a].first().copied().unwrap_or(0.0);
                        let key_b = self.data[b].first().copied().unwrap_or(0.0);
                        key_a.total_cmp(&key_b).then_with(|| a.cmp(&b))
                    });
                    let chunk_size = order.len().div_ceil(num_nodes).max(1);
                    order
                        .chunks(chunk_size)
                        .map(|chunk| chunk.iter().map(|&idx| self.data[idx].clone()).collect())
                        .collect()
                }
                PartitioningStrategy::HashBased(seed) => {
                    let mut buckets: Vec<Vec<Vec<f64>>> = vec![Vec::new(); num_nodes];
                    for row in &self.data {
                        let node_index =
                            (hash_numerical_row(row, seed) % num_nodes as u64) as usize;
                        buckets[node_index].push(row.clone());
                    }
                    buckets
                }
                PartitioningStrategy::Random => {
                    use scirs2_core::random::thread_rng;
                    let mut rng = thread_rng();
                    let mut buckets: Vec<Vec<Vec<f64>>> = vec![Vec::new(); num_nodes];
                    for row in &self.data {
                        let node_index = rng.gen_range(0..num_nodes);
                        buckets[node_index].push(row.clone());
                    }
                    buckets
                }
                PartitioningStrategy::Stratified => {
                    use std::collections::BTreeMap;
                    // Group by the bit pattern of the last value (-0.0 folded into 0.0).
                    let mut strata: BTreeMap<u64, Vec<usize>> = BTreeMap::new();
                    for (idx, row) in self.data.iter().enumerate() {
                        let label = row.last().copied().unwrap_or(0.0);
                        let normalized = if label == 0.0 { 0.0_f64 } else { label };
                        strata.entry(normalized.to_bits()).or_default().push(idx);
                    }
                    let mut buckets: Vec<Vec<Vec<f64>>> = vec![Vec::new(); num_nodes];
                    for (position, &idx) in strata.values().flatten().enumerate() {
                        buckets[position % num_nodes].push(self.data[idx].clone());
                    }
                    buckets
                }
                PartitioningStrategy::Custom(_) => {
                    // The name was validated against ROUND_ROBIN_ALIASES above.
                    let mut buckets: Vec<Vec<Vec<f64>>> = vec![Vec::new(); num_nodes];
                    for (idx, row) in self.data.iter().enumerate() {
                        buckets[idx % num_nodes].push(row.clone());
                    }
                    buckets
                }
            };

            self.partitions.clear();
            self.assignment.clear();
            self.flush_buckets_into_partitions(&nodes, buckets);
            Ok(self.partitions.clone())
        })
    }

    /// Returns a copy of the partition with id `partition_id` (ids equal list positions).
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidOperation` if no such partition exists.
    fn get_partition(
        &self,
        partition_id: u32,
    ) -> BoxFuture<'_, Result<DistributedPartition<Self::Item>>> {
        Box::pin(async move {
            self.partitions
                .get(partition_id as usize)
                .cloned()
                .ok_or_else(|| {
                    SklearsError::InvalidOperation(format!("Partition {} not found", partition_id))
                })
        })
    }

    /// Rebuilds the partitioning with `new_strategy`.
    ///
    /// When partitions exist, their rows (gathered via `collect`) become the source data first;
    /// otherwise the current rows are partitioned directly.
    ///
    /// # Errors
    /// Propagates errors from `collect` and `partition`.
    fn repartition<'a>(
        &'a mut self,
        cluster: &'a dyn DistributedCluster,
        new_strategy: Self::PartitionStrategy,
    ) -> BoxFuture<'a, Result<()>> {
        Box::pin(async move {
            // Without partitions `collect` yields nothing, and adopting that result would
            // silently erase the whole dataset.
            if !self.partitions.is_empty() {
                let collected_data = self.collect(cluster).await?;
                self.data = collected_data;
            }
            self.partition(cluster, new_strategy).await?;
            Ok(())
        })
    }

    /// Concatenates the rows of all partitions in partition order; the cluster is not consulted.
    fn collect(&self, _cluster: &dyn DistributedCluster) -> BoxFuture<'_, Result<Vec<Self::Item>>> {
        Box::pin(async move {
            let mut collected = Vec::new();
            for partition in &self.partitions {
                collected.extend(partition.data.clone());
            }
            Ok(collected)
        })
    }

    /// Returns a copy of the node → partition-id map.
    fn partition_assignment(&self) -> HashMap<NodeId, Vec<u32>> {
        self.assignment.clone()
    }
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_node_id_creation() {
let node_id = NodeId::new("worker-01");
assert_eq!(node_id.as_str(), "worker-01");
assert_eq!(node_id.to_string(), "worker-01");
}
#[test]
fn test_message_priority_ordering() {
assert!(MessagePriority::Critical > MessagePriority::High);
assert!(MessagePriority::High > MessagePriority::Normal);
assert!(MessagePriority::Normal > MessagePriority::Low);
}
#[test]
fn test_cluster_configuration_default() {
let config = ClusterConfiguration::default();
assert_eq!(config.max_nodes, 64);
assert_eq!(config.load_balancing, LoadBalancingStrategy::ResourceBased);
assert_eq!(
config.fault_tolerance,
FaultToleranceMode::CheckpointRecovery
);
}
#[test]
fn test_distributed_linear_regression_creation() {
let model = DistributedLinearRegression::new();
assert!(model.parameters.is_none());
assert_eq!(model.progress.epoch, 0);
}
#[test]
fn test_distributed_dataset_size() {
let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
let dataset = DistributedNumericalDataset::new(data);
assert_eq!(dataset.size(), 3);
assert_eq!(dataset.partition_count(), 0); }
#[test]
fn test_message_type_serialization() {
let msg_type = MessageType::ParameterSync;
let serialized = serde_json::to_string(&msg_type).unwrap_or_default();
let deserialized: MessageType =
serde_json::from_str(&serialized).expect("valid JSON operation");
assert_eq!(msg_type, deserialized);
}
#[test]
fn test_partitioning_strategy_variants() {
let strategies = vec![
PartitioningStrategy::EvenSplit,
PartitioningStrategy::HashBased(4),
PartitioningStrategy::RangeBased,
PartitioningStrategy::Random,
PartitioningStrategy::Stratified,
PartitioningStrategy::Custom("custom_strategy".to_string()),
];
for strategy in strategies {
let serialized = serde_json::to_string(&strategy).unwrap_or_default();
let _deserialized: PartitioningStrategy =
serde_json::from_str(&serialized).expect("valid JSON operation");
}
}
#[test]
fn test_distributed_training_config() {
let config = DistributedTrainingConfig::default();
assert_eq!(config.learning_rate, 0.01);
assert_eq!(config.epochs, 100);
assert_eq!(config.batch_size, 32);
}
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_default_cluster_operations() {
let coordinator = NodeId::new("coordinator");
let config = ClusterConfiguration::default();
let cluster = DefaultDistributedCluster::new(coordinator.clone(), config);
assert_eq!(cluster.coordinator(), &coordinator);
let nodes = cluster.active_nodes().await.expect("expected valid value");
assert!(nodes.is_empty());
let health = cluster
.cluster_health()
.await
.expect("expected valid value");
assert_eq!(health.overall_health, 1.0);
}
#[test]
fn test_numerical_partition_checksum_is_deterministic() {
let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
let first = numerical_partition_checksum(&data);
let second = numerical_partition_checksum(&data);
assert_eq!(
first, second,
"identical data must produce identical checksums"
);
}
#[test]
fn test_numerical_partition_checksum_changes_with_data() {
let data_a = vec![vec![1.0, 2.0, 3.0]];
let data_b = vec![vec![1.0, 2.0, 4.0]];
assert_ne!(
numerical_partition_checksum(&data_a),
numerical_partition_checksum(&data_b),
"different data must produce different checksums"
);
}
#[test]
fn test_numerical_partition_checksum_respects_row_structure() {
let grouped_a = vec![vec![1.0, 2.0], vec![3.0]];
let grouped_b = vec![vec![1.0], vec![2.0, 3.0]];
assert_ne!(
numerical_partition_checksum(&grouped_a),
numerical_partition_checksum(&grouped_b)
);
}
#[test]
fn test_numerical_partition_checksum_is_not_index_placeholder() {
let checksum = numerical_partition_checksum(&[vec![1.0, 2.0]]);
assert!(!checksum.starts_with("checksum_"));
assert_eq!(checksum.len(), 16);
}
#[test]
fn test_hash_numerical_row_determinism_and_seed_sensitivity() {
let row = vec![1.5, -2.25, 3.0];
assert_eq!(hash_numerical_row(&row, 7), hash_numerical_row(&row, 7));
assert_ne!(hash_numerical_row(&row, 7), hash_numerical_row(&row, 8));
}
#[cfg(feature = "async_support")]
struct TestCluster {
coordinator: NodeId,
configuration: ClusterConfiguration,
nodes: Vec<NodeId>,
}
#[cfg(feature = "async_support")]
impl TestCluster {
fn with_nodes(count: usize) -> Self {
let nodes = (0..count)
.map(|i| NodeId::new(format!("node-{i}")))
.collect();
Self {
coordinator: NodeId::new("coordinator"),
configuration: ClusterConfiguration::default(),
nodes,
}
}
}
#[cfg(feature = "async_support")]
impl DistributedCluster for TestCluster {
fn active_nodes(&self) -> BoxFuture<'_, Result<Vec<NodeId>>> {
let nodes = self.nodes.clone();
Box::pin(async move { Ok(nodes) })
}
fn coordinator(&self) -> &NodeId {
&self.coordinator
}
fn configuration(&self) -> &ClusterConfiguration {
&self.configuration
}
fn add_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>> {
self.nodes.push(node);
Box::pin(async move { Ok(()) })
}
fn remove_node(&mut self, node: NodeId) -> BoxFuture<'_, Result<()>> {
self.nodes.retain(|existing| existing != &node);
Box::pin(async move { Ok(()) })
}
fn rebalance_load(&mut self) -> BoxFuture<'_, Result<()>> {
Box::pin(async move { Ok(()) })
}
fn cluster_health(&self) -> BoxFuture<'_, Result<ClusterHealth>> {
Box::pin(async move {
Err(SklearsError::InvalidOperation(
"cluster_health is unused in partition tests".to_string(),
))
})
}
fn create_checkpoint(&self) -> BoxFuture<'_, Result<ClusterCheckpoint>> {
Box::pin(async move {
Err(SklearsError::InvalidOperation(
"create_checkpoint is unused in partition tests".to_string(),
))
})
}
fn restore_checkpoint(
&mut self,
_checkpoint: ClusterCheckpoint,
) -> BoxFuture<'_, Result<()>> {
Box::pin(async move { Ok(()) })
}
}
#[cfg(feature = "async_support")]
fn parse_node_index(node_id: &NodeId) -> usize {
node_id
.as_str()
.strip_prefix("node-")
.and_then(|suffix| suffix.parse().ok())
.expect("test node ids follow the node-<index> convention")
}
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_partition_without_nodes_errors() {
let mut dataset = DistributedNumericalDataset::new(vec![vec![1.0]]);
let cluster = TestCluster::with_nodes(0);
let result = dataset
.partition(&cluster, PartitioningStrategy::EvenSplit)
.await;
assert!(result.is_err(), "partitioning with no nodes must error");
}
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_even_split_covers_all_items() {
let data: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64, (i * 2) as f64]).collect();
let mut dataset = DistributedNumericalDataset::new(data.clone());
let cluster = TestCluster::with_nodes(3);
let partitions = dataset
.partition(&cluster, PartitioningStrategy::EvenSplit)
.await
.expect("even split must succeed");
let mut collected: Vec<Vec<f64>> = partitions.iter().flat_map(|p| p.data.clone()).collect();
collected.sort_by(|a, b| a[0].total_cmp(&b[0]));
assert_eq!(collected, data, "union of partitions must cover all items");
}
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_round_robin_custom_strategy_assignment() {
let data: Vec<Vec<f64>> = (0..9).map(|i| vec![i as f64]).collect();
let mut dataset = DistributedNumericalDataset::new(data);
let cluster = TestCluster::with_nodes(3);
let partitions = dataset
.partition(
&cluster,
PartitioningStrategy::Custom("round_robin".to_string()),
)
.await
.expect("round-robin custom strategy must succeed");
assert_eq!(partitions.len(), 3);
for partition in &partitions {
let node_index = parse_node_index(&partition.node_id);
for row in &partition.data {
assert_eq!((row[0] as usize) % 3, node_index);
}
assert_eq!(partition.data.len(), 3);
}
}
/// Hash-based partitioning with a fixed seed must be reproducible.
///
/// Two independent datasets built from the same rows are partitioned with
/// seed 13 and must yield identical `(partition_id, rows)` layouts. Each row
/// must sit on the node given by `hash_numerical_row(row, 13) % 4`, and the
/// partitions together must reproduce the input rows.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_hash_based_partitioning_is_deterministic() {
    // Twenty rows: column 0 is a unique key, column 1 cycles through 0..4.
    let rows: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, (i % 4) as f64]).collect();
    let cluster = TestCluster::with_nodes(4);

    // Two independent datasets over identical rows, partitioned with the same seed.
    let mut dataset_one = DistributedNumericalDataset::new(rows.clone());
    let mut dataset_two = DistributedNumericalDataset::new(rows.clone());
    let run_one = dataset_one
        .partition(&cluster, PartitioningStrategy::HashBased(13))
        .await
        .expect("hash partitioning must succeed");
    let run_two = dataset_two
        .partition(&cluster, PartitioningStrategy::HashBased(13))
        .await
        .expect("hash partitioning must succeed");

    // Reduce a run to (partition id, rows) pairs so two runs can be compared directly.
    let summarize = |parts: &[DistributedPartition<Vec<f64>>]| {
        parts
            .iter()
            .map(|part| (part.partition_id, part.data.clone()))
            .collect::<Vec<_>>()
    };
    assert_eq!(
        summarize(&run_one),
        summarize(&run_two),
        "hash partitioning must be deterministic for identical data and seed"
    );

    // Pair every row with the index of the node that holds it, then check placement.
    let placements = run_one.iter().flat_map(|part| {
        let node_index = parse_node_index(&part.node_id);
        part.data.iter().map(move |row| (node_index, row))
    });
    for (node_index, row) in placements {
        assert_eq!((hash_numerical_row(row, 13) % 4) as usize, node_index);
    }

    // Taken together, the partitions must reproduce the original rows.
    let mut reassembled: Vec<Vec<f64>> = run_one
        .iter()
        .flat_map(|part| part.data.iter().cloned())
        .collect();
    reassembled.sort_by(|lhs, rhs| lhs[0].total_cmp(&rhs[0]));
    assert_eq!(reassembled, rows);
}
/// Range-based partitioning must lay rows out in ascending key order.
///
/// Six unsorted single-column rows are split over three nodes. Reading the
/// partitions in the returned order must produce the keys fully sorted, with
/// two rows per partition.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_range_based_partitioning_orders_by_value() {
    let keys = [9.0, 1.0, 7.0, 3.0, 5.0, 2.0];
    let data: Vec<Vec<f64>> = keys.iter().map(|&key| vec![key]).collect();
    let mut dataset = DistributedNumericalDataset::new(data);
    let cluster = TestCluster::with_nodes(3);
    let partitions = dataset
        .partition(&cluster, PartitioningStrategy::RangeBased)
        .await
        .expect("range partitioning must succeed");

    // Concatenate the first column of every partition, preserving partition order.
    let mut ordered_keys: Vec<f64> = Vec::new();
    for partition in &partitions {
        ordered_keys.extend(partition.data.iter().map(|row| row[0]));
    }

    let mut expected = keys.to_vec();
    expected.sort_by(f64::total_cmp);
    assert_eq!(
        ordered_keys, expected,
        "range partitions must be value-ordered"
    );
    assert_eq!(partitions.len(), 3);
    for partition in &partitions {
        assert_eq!(partition.data.len(), 2);
    }
}
/// Stratified partitioning must spread every class evenly across the nodes.
///
/// Builds 18 rows: three classes in column 1, each with six distinct keys in
/// column 0. They are partitioned over a 3-node cluster. Checks that:
/// - there is one partition per node;
/// - no row is lost or duplicated (count plus sorted-multiset comparison);
/// - every partition holds exactly two rows of each class.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_stratified_partitioning_balances_classes() {
    let mut data: Vec<Vec<f64>> = Vec::new();
    for class in 0..3 {
        for k in 0..6 {
            data.push(vec![k as f64, class as f64]);
        }
    }
    let mut dataset = DistributedNumericalDataset::new(data.clone());
    let cluster = TestCluster::with_nodes(3);
    let partitions = dataset
        .partition(&cluster, PartitioningStrategy::Stratified)
        .await
        .expect("stratified partitioning must succeed");
    assert_eq!(
        partitions.len(),
        3,
        "stratified must produce one partition per node"
    );
    let total: usize = partitions.iter().map(|p| p.data.len()).sum();
    assert_eq!(total, data.len(), "stratified must not drop items");

    // A matching total alone would accept one dropped row offset by one
    // duplicated row. Compare the sorted rows too, as the other
    // partitioning tests do.
    let by_class_then_key =
        |a: &Vec<f64>, b: &Vec<f64>| a[1].total_cmp(&b[1]).then_with(|| a[0].total_cmp(&b[0]));
    let mut reassembled: Vec<Vec<f64>> = partitions
        .iter()
        .flat_map(|p| p.data.iter().cloned())
        .collect();
    reassembled.sort_by(by_class_then_key);
    let mut expected = data;
    expected.sort_by(by_class_then_key);
    assert_eq!(
        reassembled, expected,
        "stratified must not lose or duplicate items"
    );

    for partition in &partitions {
        // Tally rows per class label (column 1) within this partition.
        let mut class_counts: HashMap<i64, usize> = HashMap::new();
        for row in &partition.data {
            *class_counts.entry(row[1] as i64).or_insert(0) += 1;
        }
        assert_eq!(
            class_counts.len(),
            3,
            "each partition must hold all classes"
        );
        for count in class_counts.values() {
            assert_eq!(*count, 2, "each class must be evenly represented");
        }
    }
}
/// Random partitioning may place rows anywhere, but it must keep every row
/// exactly once.
///
/// Thirty single-column rows (values 0..30) are split over four nodes. The
/// multiset of values across all partitions must equal the input.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_random_partitioning_covers_all_items() {
    let cluster = TestCluster::with_nodes(4);
    let rows: Vec<Vec<f64>> = (0..30_i32).map(|i| vec![f64::from(i)]).collect();
    let mut dataset = DistributedNumericalDataset::new(rows);
    let partitions = dataset
        .partition(&cluster, PartitioningStrategy::Random)
        .await
        .expect("random partitioning must succeed");

    // Gather the key of every row from every partition, then sort for comparison.
    let mut observed: Vec<f64> = Vec::new();
    for partition in &partitions {
        for row in &partition.data {
            observed.push(row[0]);
        }
    }
    observed.sort_by(f64::total_cmp);

    let expected: Vec<f64> = (0..30_i32).map(f64::from).collect();
    assert_eq!(
        observed, expected,
        "random partitioning must not lose or duplicate items"
    );
}
/// An unrecognised `PartitioningStrategy::Custom` name must yield `Err`
/// rather than silently falling back to another strategy.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_unknown_custom_strategy_errors() {
    let cluster = TestCluster::with_nodes(2);
    let mut dataset = DistributedNumericalDataset::new(vec![vec![1.0], vec![2.0]]);
    let strategy = PartitioningStrategy::Custom(String::from("totally_unknown"));
    let outcome = dataset.partition(&cluster, strategy).await;
    assert!(
        outcome.is_err(),
        "unknown custom strategy must return an honest error"
    );
}
/// Partition checksums must depend on the partition's contents.
///
/// Three single-node datasets are partitioned with `EvenSplit`: A, B (A with
/// one value changed), and a second copy of A. B's checksum must differ from
/// A's, the copy of A must match A exactly, and the checksum must not start
/// with the literal prefix `checksum_`.
#[cfg(feature = "async_support")]
#[tokio::test]
async fn test_partition_checksums_reflect_content() {
    let cluster = TestCluster::with_nodes(1);
    let mut dataset_a = DistributedNumericalDataset::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    let mut dataset_b = DistributedNumericalDataset::new(vec![vec![1.0, 2.0], vec![3.0, 5.0]]);
    let mut dataset_a_again =
        DistributedNumericalDataset::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    let partitions_a = dataset_a
        .partition(&cluster, PartitioningStrategy::EvenSplit)
        .await
        .expect("partition a");
    let partitions_b = dataset_b
        .partition(&cluster, PartitioningStrategy::EvenSplit)
        .await
        .expect("partition b");
    let partitions_a_again = dataset_a_again
        .partition(&cluster, PartitioningStrategy::EvenSplit)
        .await
        .expect("partition a again");
    // Check every result's length before indexing [0], so an empty result
    // fails with a clear assertion instead of an index panic.
    assert_eq!(partitions_a.len(), 1);
    assert_eq!(partitions_b.len(), 1);
    assert_eq!(partitions_a_again.len(), 1);
    assert_ne!(
        partitions_a[0].metadata.checksum, partitions_b[0].metadata.checksum,
        "different data must yield different checksums"
    );
    assert_eq!(
        partitions_a[0].metadata.checksum, partitions_a_again[0].metadata.checksum,
        "identical data must yield identical checksums"
    );
    // NOTE(review): presumably guards against a legacy label-style checksum
    // of the form `checksum_...`; confirm against the checksum implementation.
    assert!(
        !partitions_a[0].metadata.checksum.starts_with("checksum_"),
        "checksum must not be a `checksum_`-prefixed label"
    );
}
}