use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};

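/// Configuration for a distributed ML training run: cluster size, parallelism
/// mode, gradient aggregation strategy, optimizer settings, checkpointing, and
/// fault-tolerance behavior.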
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMLConfig {
    pub num_workers: usize,
    pub training_mode: TrainingMode,
    pub aggregation_strategy: AggregationStrategy,
    pub batch_size_per_worker: usize,
    pub learning_rate: f64,
    pub max_epochs: usize,
    pub checkpoint_interval: usize,
    pub checkpoint_dir: PathBuf,
    pub enable_fault_tolerance: bool,
    pub health_check_interval: Duration,
    pub max_gradient_staleness: usize,
}

impl Default for DistributedMLConfig {
    fn default() -> Self {
        Self {
            num_workers: 4,
            training_mode: TrainingMode::DataParallel,
            aggregation_strategy: AggregationStrategy::AllReduce,
            batch_size_per_worker: 32,
            learning_rate: 0.001,
            max_epochs: 100,
            checkpoint_interval: 10,
            checkpoint_dir: PathBuf::from("/tmp/oxirs_ml_checkpoints"),
            enable_fault_tolerance: true,
            health_check_interval: Duration::from_secs(30),
            max_gradient_staleness: 10,
        }
    }
}

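/// How the training workload is split across workers.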
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingMode {
    /// Every worker holds a full model replica and trains on its own data shard.
    DataParallel,
    /// The model itself is partitioned across workers.
    ModelParallel,
    /// Combination of data and model parallelism.
    Hybrid,
}

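/// Strategy used to combine the gradients produced by the workers.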
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationStrategy {
    /// Average the gradients from all workers each step.
    AllReduce,
    /// Workers push gradients to a parameter server and wait for the combined update.
    ParameterServerSync,
    /// Workers push gradients to a parameter server without waiting for each other.
    ParameterServerAsync,
    /// Weight each worker's gradients by the number of samples it processed.
    FederatedAveraging,
}

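/// Lifecycle state of a training worker.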
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkerStatus {
    Idle,
    Training,
    Synchronizing,
    Failed,
    Stopped,
}

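/// Runtime information tracked for each registered worker.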
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    pub worker_id: String,
    pub rank: usize,
    pub status: WorkerStatus,
    pub last_heartbeat: chrono::DateTime<chrono::Utc>,
    pub gradients_processed: usize,
    pub current_loss: f64,
}

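/// Aggregated metrics for a single training epoch.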
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub global_step: usize,
    pub average_loss: f64,
    pub learning_rate: f64,
    pub throughput_samples_per_sec: f64,
    pub worker_metrics: Vec<WorkerMetrics>,
}

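/// Per-worker metrics collected during an epoch.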
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetrics {
    pub worker_id: String,
    pub local_loss: f64,
    pub gradient_norm: f64,
    pub samples_processed: usize,
}

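/// Coordinator for distributed ML training. It owns the shared model
/// parameters, tracks worker state, and drives the epoch loop: partition data,
/// run local steps, aggregate gradients, and update the parameters.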
pub struct DistributedMLTrainer {
    config: DistributedMLConfig,
    workers: Arc<RwLock<HashMap<String, WorkerInfo>>>,
    model_parameters: Arc<RwLock<Vec<Array1<f64>>>>,
    training_state: Arc<RwLock<TrainingState>>,
}

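/// Mutable training progress shared behind an async lock.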
#[derive(Debug, Clone)]
struct TrainingState {
    current_epoch: usize,
    global_step: usize,
    best_loss: f64,
    training_history: Vec<TrainingMetrics>,
}

impl DistributedMLTrainer {
    pub fn new(config: DistributedMLConfig) -> Self {
        Self {
            config,
            workers: Arc::new(RwLock::new(HashMap::new())),
            model_parameters: Arc::new(RwLock::new(Vec::new())),
            training_state: Arc::new(RwLock::new(TrainingState {
                current_epoch: 0,
                global_step: 0,
                best_loss: f64::INFINITY,
                training_history: Vec::new(),
            })),
        }
    }

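    /// Stores the initial model parameters, creates the checkpoint directory if
    /// needed, and registers `num_workers` workers in the `Idle` state.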
    pub async fn initialize(&self, initial_parameters: Vec<Array1<f64>>) -> Result<()> {
        info!(
            "Initializing distributed ML training cluster with {} workers",
            self.config.num_workers
        );

        {
            let mut params = self.model_parameters.write().await;
            *params = initial_parameters;
        }

        if !self.config.checkpoint_dir.exists() {
            tokio::fs::create_dir_all(&self.config.checkpoint_dir).await?;
        }

        let mut workers = self.workers.write().await;
        for rank in 0..self.config.num_workers {
            let worker_id = format!("worker_{}", rank);
            let worker = WorkerInfo {
                worker_id: worker_id.clone(),
                rank,
                status: WorkerStatus::Idle,
                last_heartbeat: chrono::Utc::now(),
                gradients_processed: 0,
                current_loss: 0.0,
            };
            workers.insert(worker_id, worker);
        }

        info!("Distributed training cluster initialized successfully");
        Ok(())
    }

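    /// Runs the training loop for `max_epochs` epochs. Each epoch partitions the
    /// data across workers, executes the local training steps, aggregates the
    /// resulting gradients with the configured strategy, applies the update, and
    /// records metrics; checkpoints and worker health checks run on their
    /// configured intervals. Returns the metrics of the final epoch, or an error
    /// if no epochs were run.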
    pub async fn train(
        &self,
        training_data: Vec<Vec<f64>>,
        labels: Vec<f64>,
    ) -> Result<TrainingMetrics> {
        info!(
            "Starting distributed training for {} epochs",
            self.config.max_epochs
        );
        let start_time = Instant::now();

        for epoch in 0..self.config.max_epochs {
            let epoch_start = Instant::now();

            // Partition the training data across workers.
            let data_partitions = self.partition_data(&training_data, &labels);

            // Execute one local training step per worker.
            let worker_results = self.execute_parallel_training_step(data_partitions).await?;

            // Aggregate the workers' gradients with the configured strategy.
            let aggregated_gradients = self.aggregate_gradients(&worker_results).await?;

            // Apply the aggregated update to the shared model parameters.
            self.update_parameters(&aggregated_gradients).await?;

            // Collect epoch metrics.
            let average_loss =
                worker_results.iter().map(|r| r.loss).sum::<f64>() / worker_results.len() as f64;

            let worker_metrics: Vec<WorkerMetrics> = worker_results
                .iter()
                .map(|r| WorkerMetrics {
                    worker_id: r.worker_id.clone(),
                    local_loss: r.loss,
                    gradient_norm: r.gradient_norm,
                    samples_processed: r.samples_processed,
                })
                .collect();

            let epoch_duration = epoch_start.elapsed();
            let throughput = (training_data.len() as f64) / epoch_duration.as_secs_f64();

            let metrics = TrainingMetrics {
                epoch,
                global_step: epoch * self.config.num_workers,
                average_loss,
                learning_rate: self.config.learning_rate,
                throughput_samples_per_sec: throughput,
                worker_metrics,
            };

            // Record progress and track the best loss seen so far.
            {
                let mut state = self.training_state.write().await;
                state.current_epoch = epoch;
                state.global_step = metrics.global_step;
                if average_loss < state.best_loss {
                    state.best_loss = average_loss;
                }
                state.training_history.push(metrics.clone());
            }

            info!(
                "Epoch {}/{}: loss={:.6}, throughput={:.2} samples/sec",
                epoch + 1,
                self.config.max_epochs,
                average_loss,
                throughput
            );

            if (epoch + 1) % self.config.checkpoint_interval == 0 {
                self.save_checkpoint(epoch).await?;
            }

            if self.config.enable_fault_tolerance {
                self.check_worker_health().await?;
            }
        }

        let total_duration = start_time.elapsed();
        info!(
            "Distributed training completed in {:.2}s",
            total_duration.as_secs_f64()
        );

        // Return the metrics recorded for the final epoch.
        let state = self.training_state.read().await;
        state
            .training_history
            .last()
            .cloned()
            .ok_or_else(|| anyhow!("No training epochs were run"))
    }

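    /// Splits the data and labels into contiguous, roughly equal chunks, one per worker.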
    fn partition_data(&self, data: &[Vec<f64>], labels: &[f64]) -> Vec<(Vec<Vec<f64>>, Vec<f64>)> {
        let chunk_size = (data.len() + self.config.num_workers - 1) / self.config.num_workers;

        (0..self.config.num_workers)
            .map(|i| {
                // Clamp the bounds so workers beyond the data length receive empty chunks.
                let start = (i * chunk_size).min(data.len());
                let end = ((i + 1) * chunk_size).min(data.len());

                let data_chunk = data[start..end].to_vec();
                let labels_chunk = labels[start..end].to_vec();

                (data_chunk, labels_chunk)
            })
            .collect()
    }

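    /// Runs one training step per data partition. Workers are simulated in-process
    /// and processed sequentially here; in a real deployment each partition would be
    /// dispatched to a separate worker.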
    async fn execute_parallel_training_step(
        &self,
        data_partitions: Vec<(Vec<Vec<f64>>, Vec<f64>)>,
    ) -> Result<Vec<WorkerTrainingResult>> {
        let mut results = Vec::new();

        for (rank, (data, labels)) in data_partitions.iter().enumerate() {
            let worker_id = format!("worker_{}", rank);

            let (gradients, loss) = self.compute_gradients_and_loss(data, labels)?;
            // Global L2 norm across all parameter gradients.
            let gradient_norm = gradients
                .iter()
                .map(|g| g.iter().map(|x| x * x).sum::<f64>())
                .sum::<f64>()
                .sqrt();

            results.push(WorkerTrainingResult {
                worker_id,
                gradients,
                loss,
                gradient_norm,
                samples_processed: data.len(),
            });
        }

        Ok(results)
    }

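    /// Simplified placeholder for the local training step: the "model" predicts the
    /// mean of the input features, the loss is the mean squared error, and every
    /// gradient entry accumulates the scaled prediction error.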
    fn compute_gradients_and_loss(
        &self,
        data: &[Vec<f64>],
        labels: &[f64],
    ) -> Result<(Vec<Array1<f64>>, f64)> {
        // Placeholder model: a fixed number of fixed-size parameter groups.
        let num_params = 2;
        let mut gradients = vec![Array1::zeros(10); num_params];

        if data.is_empty() {
            return Ok((gradients, 0.0));
        }

        let mut total_loss = 0.0;

        for (x, &y) in data.iter().zip(labels.iter()) {
            // "Prediction" is simply the mean of the input features.
            let prediction = x.iter().sum::<f64>() / x.len() as f64;
            let error = prediction - y;
            total_loss += error * error;

            // Accumulate a uniform, error-proportional gradient for every parameter.
            for grad in &mut gradients {
                for i in 0..grad.len() {
                    grad[i] += error * 2.0 / data.len() as f64;
                }
            }
        }

        let loss = total_loss / data.len() as f64;
        Ok((gradients, loss))
    }

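    /// Combines the workers' gradients according to the configured aggregation strategy.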
    async fn aggregate_gradients(
        &self,
        results: &[WorkerTrainingResult],
    ) -> Result<Vec<Array1<f64>>> {
        match self.config.aggregation_strategy {
            AggregationStrategy::AllReduce => self.allreduce_aggregation(results).await,
            AggregationStrategy::ParameterServerSync => {
                self.parameter_server_sync_aggregation(results).await
            }
            AggregationStrategy::ParameterServerAsync => {
                self.parameter_server_async_aggregation(results).await
            }
            AggregationStrategy::FederatedAveraging => {
                self.federated_averaging_aggregation(results).await
            }
        }
    }

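    /// All-reduce aggregation: element-wise mean of the gradients from all workers.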
    async fn allreduce_aggregation(
        &self,
        results: &[WorkerTrainingResult],
    ) -> Result<Vec<Array1<f64>>> {
        if results.is_empty() {
            return Err(anyhow!("No worker results to aggregate"));
        }

        // Size the accumulators to match the first worker's gradient shapes.
        let mut aggregated: Vec<Array1<f64>> = results[0]
            .gradients
            .iter()
            .map(|g| Array1::zeros(g.len()))
            .collect();

        for result in results {
            for (i, grad) in result.gradients.iter().enumerate() {
                for j in 0..grad.len() {
                    aggregated[i][j] += grad[j];
                }
            }
        }

        // Average over the number of contributing workers.
        let num_workers = results.len() as f64;
        for grad in &mut aggregated {
            for val in grad.iter_mut() {
                *val /= num_workers;
            }
        }

        debug!(
            "AllReduce aggregation completed for {} workers",
            results.len()
        );
        Ok(aggregated)
    }

    /// Synchronous parameter-server aggregation. Currently delegates to the
    /// all-reduce implementation.
    async fn parameter_server_sync_aggregation(
        &self,
        results: &[WorkerTrainingResult],
    ) -> Result<Vec<Array1<f64>>> {
        self.allreduce_aggregation(results).await
    }

    /// Asynchronous parameter-server aggregation. Currently delegates to the
    /// all-reduce implementation; gradient staleness is not yet handled here.
    async fn parameter_server_async_aggregation(
        &self,
        results: &[WorkerTrainingResult],
    ) -> Result<Vec<Array1<f64>>> {
        self.allreduce_aggregation(results).await
    }

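    /// Federated-averaging aggregation: each worker's gradients are weighted by its
    /// share of the total number of processed samples before summing.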
    async fn federated_averaging_aggregation(
        &self,
        results: &[WorkerTrainingResult],
    ) -> Result<Vec<Array1<f64>>> {
        if results.is_empty() {
            return Err(anyhow!("No worker results to aggregate"));
        }

        let total_samples: usize = results.iter().map(|r| r.samples_processed).sum();
        if total_samples == 0 {
            return Err(anyhow!("Worker results contain no processed samples"));
        }

        // Size the accumulators to match the first worker's gradient shapes.
        let mut aggregated: Vec<Array1<f64>> = results[0]
            .gradients
            .iter()
            .map(|g| Array1::zeros(g.len()))
            .collect();

        for result in results {
            // Weight each worker by its share of the total sample count.
            let weight = result.samples_processed as f64 / total_samples as f64;
            for (i, grad) in result.gradients.iter().enumerate() {
                for j in 0..grad.len() {
                    aggregated[i][j] += grad[j] * weight;
                }
            }
        }

        debug!(
            "Federated averaging completed with {} total samples",
            total_samples
        );
        Ok(aggregated)
    }

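    /// Applies one gradient-descent step: `param -= learning_rate * gradient`.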
    async fn update_parameters(&self, gradients: &[Array1<f64>]) -> Result<()> {
        let mut params = self.model_parameters.write().await;

        for (param, grad) in params.iter_mut().zip(gradients.iter()) {
            for i in 0..param.len().min(grad.len()) {
                param[i] -= self.config.learning_rate * grad[i];
            }
        }

        Ok(())
    }

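    /// Serializes the current parameters and training state to a JSON checkpoint file.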
    async fn save_checkpoint(&self, epoch: usize) -> Result<()> {
        let checkpoint_path = self
            .config
            .checkpoint_dir
            .join(format!("checkpoint_epoch_{}.json", epoch));

        let params = self.model_parameters.read().await;
        let state = self.training_state.read().await;

        let checkpoint = CheckpointData {
            epoch,
            global_step: state.global_step,
            best_loss: state.best_loss,
            parameters: params.iter().map(|p| p.to_vec()).collect(),
        };

        let json = serde_json::to_string_pretty(&checkpoint)?;
        tokio::fs::write(&checkpoint_path, json).await?;

        info!("Checkpoint saved to {:?}", checkpoint_path);
        Ok(())
    }

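    /// Marks workers as `Failed` when their last heartbeat is older than the
    /// configured health-check interval.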
    async fn check_worker_health(&self) -> Result<()> {
        let mut workers = self.workers.write().await;
        let now = chrono::Utc::now();

        for (worker_id, worker) in workers.iter_mut() {
            let elapsed = (now - worker.last_heartbeat).num_seconds();
            if elapsed > self.config.health_check_interval.as_secs() as i64 {
                warn!("Worker {} missed heartbeat ({}s ago)", worker_id, elapsed);
                worker.status = WorkerStatus::Failed;
            }
        }

        Ok(())
    }

    /// Returns the metrics of the most recent epoch, if any training has run.
    pub async fn get_metrics(&self) -> Result<TrainingMetrics> {
        let state = self.training_state.read().await;
        state
            .training_history
            .last()
            .cloned()
            .ok_or_else(|| anyhow!("No training metrics available"))
    }

    /// Returns a snapshot of all registered workers and their current status.
    pub async fn get_worker_status(&self) -> Vec<WorkerInfo> {
        let workers = self.workers.read().await;
        workers.values().cloned().collect()
    }
}

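/// Result of a single worker's local training step.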
#[derive(Debug, Clone)]
struct WorkerTrainingResult {
    worker_id: String,
    gradients: Vec<Array1<f64>>,
    loss: f64,
    gradient_norm: f64,
    samples_processed: usize,
}

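/// Serializable snapshot of training progress and model parameters.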
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CheckpointData {
    epoch: usize,
    global_step: usize,
    best_loss: f64,
    parameters: Vec<Vec<f64>>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_distributed_trainer_creation() {
        let config = DistributedMLConfig::default();
        let trainer = DistributedMLTrainer::new(config);

        let initial_params = vec![Array1::zeros(10); 3];
        trainer.initialize(initial_params).await.unwrap();

        let workers = trainer.get_worker_status().await;
        assert_eq!(workers.len(), 4);
    }

    #[tokio::test]
    async fn test_data_partitioning() {
        let config = DistributedMLConfig {
            num_workers: 2,
            ..Default::default()
        };
        let trainer = DistributedMLTrainer::new(config);

        let data = vec![vec![1.0, 2.0]; 10];
        let labels = vec![1.0; 10];

        let partitions = trainer.partition_data(&data, &labels);
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0].0.len(), 5);
        assert_eq!(partitions[1].0.len(), 5);
    }

    #[tokio::test]
    async fn test_gradient_aggregation() {
        let config = DistributedMLConfig::default();
        let trainer = DistributedMLTrainer::new(config);

        let results = vec![
            WorkerTrainingResult {
                worker_id: "w1".to_string(),
                gradients: vec![Array1::from_vec(vec![1.0, 2.0, 3.0])],
                loss: 0.5,
                gradient_norm: 1.0,
                samples_processed: 10,
            },
            WorkerTrainingResult {
                worker_id: "w2".to_string(),
                gradients: vec![Array1::from_vec(vec![2.0, 3.0, 4.0])],
                loss: 0.6,
                gradient_norm: 1.5,
                samples_processed: 10,
            },
        ];

        let aggregated = trainer.allreduce_aggregation(&results).await.unwrap();
        assert_eq!(aggregated.len(), 1);
        assert!((aggregated[0][0] - 1.5).abs() < 1e-6);
        assert!((aggregated[0][1] - 2.5).abs() < 1e-6);
        assert!((aggregated[0][2] - 3.5).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_federated_averaging() {
        let config = DistributedMLConfig {
            aggregation_strategy: AggregationStrategy::FederatedAveraging,
            ..Default::default()
        };
        let trainer = DistributedMLTrainer::new(config);

        let results = vec![
            WorkerTrainingResult {
                worker_id: "w1".to_string(),
                gradients: vec![Array1::from_vec(vec![1.0, 2.0])],
                loss: 0.5,
                gradient_norm: 1.0,
                samples_processed: 20,
            },
            WorkerTrainingResult {
                worker_id: "w2".to_string(),
                gradients: vec![Array1::from_vec(vec![3.0, 4.0])],
                loss: 0.6,
                gradient_norm: 1.5,
                samples_processed: 10,
            },
        ];

        let aggregated = trainer
            .federated_averaging_aggregation(&results)
            .await
            .unwrap();
        assert_eq!(aggregated.len(), 1);
        assert!((aggregated[0][0] - (1.0 * 20.0 / 30.0 + 3.0 * 10.0 / 30.0)).abs() < 1e-6);
        assert!((aggregated[0][1] - (2.0 * 20.0 / 30.0 + 4.0 * 10.0 / 30.0)).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_training_flow() {
        let config = DistributedMLConfig {
            num_workers: 2,
            max_epochs: 2,
            checkpoint_interval: 1,
            ..Default::default()
        };
        let trainer = DistributedMLTrainer::new(config);

        let initial_params = vec![Array1::from_vec(vec![0.5; 10]); 2];
        trainer.initialize(initial_params).await.unwrap();

        let data = vec![vec![1.0, 2.0, 3.0]; 20];
        let labels = vec![2.0; 20];

        let metrics = trainer.train(data, labels).await.unwrap();
        assert_eq!(metrics.epoch, 1);
        assert!(metrics.average_loss >= 0.0);
        assert_eq!(metrics.worker_metrics.len(), 2);
    }
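
    /// Minimal additional sketch: a single call to `update_parameters` should move
    /// each parameter by `learning_rate * gradient` in the negative direction
    /// (assumes the default learning rate of 0.001).
    #[tokio::test]
    async fn test_parameter_update_step() {
        let config = DistributedMLConfig::default();
        let trainer = DistributedMLTrainer::new(config);

        trainer
            .initialize(vec![Array1::from_vec(vec![1.0; 4])])
            .await
            .unwrap();

        let gradients = vec![Array1::from_vec(vec![2.0; 4])];
        trainer.update_parameters(&gradients).await.unwrap();

        let params = trainer.model_parameters.read().await;
        for &p in params[0].iter() {
            // Expect 1.0 - 0.001 * 2.0 for every element.
            assert!((p - (1.0 - 0.001 * 2.0)).abs() < 1e-12);
        }
    }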
}