1use crate::finetune::classification::SafetySample;
30use crate::finetune::classify_pipeline::{BatchResult, ClassifyConfig, ClassifyPipeline};
31use crate::transformer::TransformerConfig;
32
/// Owns a set of [`ClassifyPipeline`] replicas, one per GPU index handed to
/// [`DataParallelCoordinator::new`], and drives them as a group.
///
/// The replica at position 0 is treated as the primary: it is the one exposed
/// by the accessor methods and the source of truth when weights are synced.
pub struct DataParallelCoordinator {
    /// One pipeline per requested GPU index, in the order the indices were given.
    pipelines: Vec<ClassifyPipeline>,
    // The indices are stored at construction time but are not read by any
    // method of this type in this file, hence the lint allowance.
    #[allow(dead_code)]
    /// GPU indices this coordinator was created with.
    gpu_indices: Vec<u32>,
}
47
48impl DataParallelCoordinator {
49 pub fn new(
62 model_config: &TransformerConfig,
63 classify_config: ClassifyConfig,
64 gpu_indices: &[u32],
65 ) -> Result<Self, String> {
66 if gpu_indices.is_empty() {
67 return Err("At least one GPU index required".to_string());
68 }
69
70 let mut pipelines = Vec::with_capacity(gpu_indices.len());
71
72 for &_idx in gpu_indices {
73 let pipeline = ClassifyPipeline::new(model_config, classify_config.clone());
75 pipelines.push(pipeline);
76 }
77
78 Ok(Self { pipelines, gpu_indices: gpu_indices.to_vec() })
79 }
80
81 #[must_use]
83 pub fn num_gpus(&self) -> usize {
84 self.pipelines.len()
85 }
86
87 pub fn primary_pipeline(&mut self) -> &mut ClassifyPipeline {
89 &mut self.pipelines[0]
90 }
91
92 pub fn primary_pipeline_ref(&self) -> &ClassifyPipeline {
94 &self.pipelines[0]
95 }
96
97 pub fn train_batch_parallel(&mut self, samples: &[SafetySample]) -> BatchResult {
108 let num_gpus = self.pipelines.len();
109
110 if num_gpus == 1 || samples.len() < num_gpus {
111 return self.pipelines[0].train_batch(samples);
113 }
114
115 let shard_size = samples.len() / num_gpus;
117 let shards: Vec<&[SafetySample]> = (0..num_gpus)
118 .map(|i| {
119 let start = i * shard_size;
120 let end = if i == num_gpus - 1 { samples.len() } else { start + shard_size };
121 &samples[start..end]
122 })
123 .collect();
124
125 let mut results = Vec::with_capacity(num_gpus);
133 for (gpu_idx, shard) in shards.iter().enumerate() {
134 let result = self.pipelines[gpu_idx].train_batch(shard);
135 results.push(result);
136 }
137
138 let total_samples: usize = results.iter().map(|r| r.total).sum();
140 let total_correct: usize = results.iter().map(|r| r.correct).sum();
141 let avg_loss: f32 =
142 results.iter().map(|r| r.avg_loss * r.total as f32).sum::<f32>() / total_samples as f32;
143 let avg_grad_norm: f32 = results.iter().map(|r| r.grad_norm).sum::<f32>() / num_gpus as f32;
144
145 if self.pipelines.len() > 1 {
150 self.sync_lora_weights_from_primary();
151 }
152
153 BatchResult {
154 avg_loss,
155 correct: total_correct,
156 total: total_samples,
157 grad_norm: avg_grad_norm,
158 }
159 }
160
161 fn sync_lora_weights_from_primary(&mut self) {
169 if self.pipelines.len() <= 1 {
170 return;
171 }
172
173 let (primary_slice, replicas) = self.pipelines.split_at_mut(1);
175 let primary = &primary_slice[0];
176
177 for replica in replicas.iter_mut() {
178 for (src_lora, dst_lora) in
180 primary.lora_layers.iter().zip(replica.lora_layers.iter_mut())
181 {
182 dst_lora.lora_a_mut().data_mut().assign(src_lora.lora_a().data());
183 dst_lora.lora_b_mut().data_mut().assign(src_lora.lora_b().data());
184 }
185
186 replica.classifier.weight.data_mut().assign(primary.classifier.weight.data());
188 replica.classifier.bias.data_mut().assign(primary.classifier.bias.data());
189 }
190 }
191}
192
/// Splits `samples` into `num_workers` contiguous, non-overlapping shards.
///
/// Each of the first `num_workers - 1` shards holds
/// `samples.len() / num_workers` elements; the final shard holds whatever is
/// left, i.e. its own share plus the division remainder. When there are fewer
/// samples than workers the leading shards are empty and the last one carries
/// everything.
///
/// If `num_workers` is zero or `samples` is empty, a single shard containing
/// the whole input is returned.
pub fn shard_samples<T>(samples: &[T], num_workers: usize) -> Vec<&[T]> {
    if samples.is_empty() || num_workers == 0 {
        return vec![samples];
    }

    let per_worker = samples.len() / num_workers;
    let mut shards: Vec<&[T]> = Vec::with_capacity(num_workers);
    let mut remaining = samples;

    // Peel an equally sized shard off the front for every worker but the last.
    for _ in 1..num_workers {
        let (head, tail) = remaining.split_at(per_worker);
        shards.push(head);
        remaining = tail;
    }

    // The last worker takes its share plus the remainder.
    shards.push(remaining);
    shards
}
214
/// Computes the element-wise mean of the gradient vectors in `grads`.
///
/// Element `j` of the result is the sum of element `j` over all vectors,
/// divided by the number of vectors. Non-finite inputs are not filtered: a NaN
/// or infinity in any vector shows up in the corresponding output element.
///
/// Returns an empty vector when `grads` is empty.
///
/// # Panics
///
/// Panics if the vectors do not all have the same length. Averaging ragged
/// gradients has no meaningful result, so the mismatch is reported up front
/// instead of surfacing as an index-out-of-bounds panic (a later vector is
/// longer) or a silently skewed mean (a later vector is shorter).
pub fn average_gradients(grads: &[Vec<f32>]) -> Vec<f32> {
    let (first, rest) = match grads.split_first() {
        Some(parts) => parts,
        None => return Vec::new(),
    };

    let len = first.len();
    for (offset, grad) in rest.iter().enumerate() {
        assert_eq!(
            grad.len(),
            len,
            "average_gradients: gradient length mismatch between vector {} and vector 0",
            offset + 1
        );
    }

    let n = grads.len() as f32;
    let mut avg = vec![0.0f32; len];
    for grad in grads {
        for (acc, &v) in avg.iter_mut().zip(grad) {
            *acc += v;
        }
    }
    for v in &mut avg {
        *v /= n;
    }
    avg
}
238
/// Reports whether `values` contains at least one NaN or infinite element.
///
/// An empty slice contains none, so it yields `false`.
pub fn has_non_finite(values: &[f32]) -> bool {
    values.iter().copied().any(|value| value.is_nan() || value.is_infinite())
}
245
// Unit tests for the coordinator and the free helper functions of this file.
// Tests named `falsify_*` carry an identifier (F-DP-00x / F-HET-00x) in their
// assertion messages naming the property they try to break.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::transformer::ModelArchitecture;

    /// Returns a deliberately tiny model config (hidden size 32, 2 layers,
    /// 4 heads, vocab 100) together with a 2-class, rank-4 classify config.
    /// All other classify settings come from `ClassifyConfig::default()`.
    fn test_config() -> (TransformerConfig, ClassifyConfig) {
        let model_config = TransformerConfig {
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 64,
            vocab_size: 100,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            use_bias: false,
            head_dim_override: None,
            architecture: ModelArchitecture::Decoder,
            hf_architecture: None,
            hf_model_type: None,
            tie_word_embeddings: false,
        };

        let classify_config =
            ClassifyConfig { num_classes: 2, lora_rank: 4, ..ClassifyConfig::default() };

        (model_config, classify_config)
    }

    /// A single GPU index yields an `Ok` coordinator holding one pipeline.
    #[test]
    fn test_coordinator_creation() {
        let (model_config, classify_config) = test_config();
        let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0]);
        assert!(coordinator.is_ok());
        assert_eq!(
            coordinator.as_ref().map(super::DataParallelCoordinator::num_gpus).unwrap_or(0),
            1
        );
    }

    /// An empty index list must be rejected by the constructor.
    #[test]
    fn test_coordinator_empty_gpus_fails() {
        let (model_config, classify_config) = test_config();
        let result = DataParallelCoordinator::new(&model_config, classify_config, &[]);
        assert!(result.is_err());
    }

    /// Exercises `num_gpus`, `primary_pipeline` and `primary_pipeline_ref`.
    /// Despite the name, the coordinator here is built with a single index.
    #[test]
    fn test_multi_gpu_coordinator_accessors() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        assert_eq!(coordinator.num_gpus(), 1);

        // The primary pipeline carries the classify config it was built with.
        let primary = coordinator.primary_pipeline();
        assert_eq!(primary.config.num_classes, 2);

        let primary_ref = coordinator.primary_pipeline_ref();
        assert_eq!(primary_ref.config.lora_rank, 4);
    }

    /// Checks only that a one-index coordinator reports one pipeline; it does
    /// not call `train_batch_parallel` itself.
    #[test]
    fn test_single_gpu_fallback_path() {
        let (model_config, classify_config) = test_config();
        let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        assert_eq!(coordinator.num_gpus(), 1);
    }

    /// Syncing with a single pipeline must return without panicking.
    #[test]
    fn test_weight_sync_noop_single_gpu() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        coordinator.sync_lora_weights_from_primary();
    }

    /// F-DP-001: after a sync, the replica's first-layer `lora_a` equals the
    /// primary's, even though it was perturbed beforehand.
    #[test]
    fn falsify_dp_001_weight_sync_makes_replicas_identical() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Shift every element of the replica's lora_a by 1.0 so it differs from the primary.
        let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
            .lora_a()
            .data()
            .iter()
            .map(|v| v + 1.0)
            .collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;

        // Precondition: the two replicas really differ before syncing.
        let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_ne!(w0, w1, "Weights should differ before sync");

        coordinator.sync_lora_weights_from_primary();

        let w0_after: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1_after: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_eq!(w0_after, w1_after, "F-DP-001: Weights MUST be identical after sync");
    }

    /// Control case for F-DP-001: a perturbed replica stays different from the
    /// primary when no sync is performed.
    #[test]
    fn falsify_dp_001_weights_diverge_without_sync() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Shift the replica's lora_a by 0.5; no training step is run in this test.
        let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
            .lora_a()
            .data()
            .iter()
            .map(|v| v + 0.5)
            .collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;

        let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_ne!(w0, w1, "Without sync, weights MUST diverge (proving sync is necessary)");
    }

    /// F-DP-002: for several worker counts, sharding 100 distinct values gives
    /// exactly `num_workers` shards that together contain each value once.
    #[test]
    fn falsify_dp_002_no_sample_lost_or_duplicated() {
        let samples: Vec<u32> = (0..100).collect();

        for num_workers in [1, 2, 3, 4, 7, 10] {
            let shards = shard_samples(&samples, num_workers);
            assert_eq!(
                shards.len(),
                num_workers,
                "Wrong number of shards for {num_workers} workers"
            );

            // Nothing lost: shard lengths add up to the input length.
            let total: usize = shards.iter().map(|s| s.len()).sum();
            assert_eq!(total, 100, "F-DP-002: samples lost with {num_workers} workers");

            // Nothing duplicated: every value is inserted into the set exactly once.
            let mut seen = std::collections::HashSet::new();
            for shard in &shards {
                for &s in *shard {
                    assert!(
                        seen.insert(s),
                        "F-DP-002: duplicate sample {s} with {num_workers} workers"
                    );
                }
            }
            assert_eq!(seen.len(), 100);
        }
    }

    /// F-DP-002: 10 samples over 3 workers split as 3 / 3 / 4, i.e. the last
    /// shard absorbs the division remainder.
    #[test]
    fn falsify_dp_002_uneven_sharding_gets_remainder() {
        let samples: Vec<u32> = (0..10).collect();
        let shards = shard_samples(&samples, 3);
        assert_eq!(shards[0].len(), 3);
        assert_eq!(shards[1].len(), 3);
        assert_eq!(shards[2].len(), 4); let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 10);
    }

    /// F-DP-003: a NaN in one worker's gradient makes that element of the
    /// average NaN, while the other elements are averaged normally.
    #[test]
    fn falsify_dp_003_nan_gradient_propagates() {
        let grads = vec![vec![1.0, 2.0, 3.0], vec![f32::NAN, 2.0, 3.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_nan(), "F-DP-003: NaN MUST propagate through averaging (Jidoka)");
        assert!((avg[1] - 2.0).abs() < 1e-6);
        assert!((avg[2] - 3.0).abs() < 1e-6);
    }

    /// F-DP-003: an infinity in one worker's gradient stays infinite after averaging.
    #[test]
    fn falsify_dp_003_inf_gradient_propagates() {
        let grads = vec![vec![1.0, 2.0], vec![f32::INFINITY, 2.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_infinite(), "F-DP-003: Inf MUST propagate through averaging");
    }

    /// F-DP-003: `has_non_finite` flags NaN and both infinities, and nothing else.
    #[test]
    fn falsify_dp_003_non_finite_detection() {
        assert!(!has_non_finite(&[1.0, 2.0, 3.0]));
        assert!(has_non_finite(&[1.0, f32::NAN, 3.0]));
        assert!(has_non_finite(&[1.0, f32::INFINITY, 3.0]));
        assert!(has_non_finite(&[1.0, f32::NEG_INFINITY, 3.0]));
    }

    /// Three workers: each output element is the arithmetic mean of its column.
    #[test]
    fn test_average_gradients_correct() {
        let grads = vec![vec![2.0, 4.0, 6.0], vec![4.0, 6.0, 8.0], vec![6.0, 8.0, 10.0]];
        let avg = average_gradients(&grads);
        assert!((avg[0] - 4.0).abs() < 1e-6);
        assert!((avg[1] - 6.0).abs() < 1e-6);
        assert!((avg[2] - 8.0).abs() < 1e-6);
    }

    /// With one worker the average equals that worker's gradient.
    #[test]
    fn test_average_gradients_single_worker() {
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let avg = average_gradients(&grads);
        assert!((avg[0] - 1.0).abs() < 1e-6);
        assert!((avg[1] - 2.0).abs() < 1e-6);
        assert!((avg[2] - 3.0).abs() < 1e-6);
    }

    /// No workers at all yields an empty average.
    #[test]
    fn test_average_gradients_empty() {
        let grads: Vec<Vec<f32>> = vec![];
        let avg = average_gradients(&grads);
        assert!(avg.is_empty());
    }

    /// F-DP-004: the hidden states of a directly constructed pipeline are all
    /// finite and hold `tokens * hidden_size` values.
    #[test]
    fn falsify_dp_004_cpu_pipeline_produces_finite_hidden() {
        let (model_config, classify_config) = test_config();
        let pipeline = ClassifyPipeline::new(&model_config, classify_config);

        let token_ids = vec![1u32, 2, 3, 4, 5];
        let hidden = pipeline.model.forward_hidden(&token_ids);
        let data = hidden.data();

        assert!(
            data.iter().all(|v| v.is_finite()),
            "F-DP-004: CPU fallback must produce finite hidden states"
        );
        assert_eq!(data.len(), token_ids.len() * model_config.hidden_size);
    }

    /// The sync also copies the classifier head weight, not only the LoRA matrices.
    #[test]
    fn test_weight_sync_covers_classifier_head() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Push the replica's classifier weight far away from the primary's.
        let perturbed: Vec<f32> =
            coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].classifier.weight.data_mut() = arr;

        coordinator.sync_lora_weights_from_primary();

        let w0: Vec<f32> = coordinator.pipelines[0].classifier.weight.data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].classifier.weight.data().to_vec();
        assert_eq!(w0, w1, "Classifier head weights must sync across replicas");
    }

    /// The coordinator holds exactly as many pipelines as indices supplied.
    #[test]
    fn test_multi_gpu_creates_n_pipelines() {
        let (model_config, classify_config) = test_config();
        for n in [1, 2, 3, 4] {
            let indices: Vec<u32> = (0..n).collect();
            let coordinator =
                DataParallelCoordinator::new(&model_config, classify_config.clone(), &indices)
                    .expect("creation should succeed");
            assert_eq!(coordinator.num_gpus(), n as usize);
        }
    }

    /// F-DP-001 (exhaustive variant): after perturbing every LoRA layer's A and
    /// B matrices and the classifier weight on the replica, a sync must make
    /// all of them, plus the classifier bias, match the primary again.
    #[test]
    fn falsify_dp_001_weight_sync_all_layers_and_classifier() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Perturb both matrices of every LoRA layer on the replica.
        for lora in &mut coordinator.pipelines[1].lora_layers {
            let perturbed_a: Vec<f32> = lora.lora_a().data().iter().map(|v| v + 42.0).collect();
            *lora.lora_a_mut().data_mut() = ndarray::Array1::from(perturbed_a);
            let perturbed_b: Vec<f32> = lora.lora_b().data().iter().map(|v| v + 7.0).collect();
            *lora.lora_b_mut().data_mut() = ndarray::Array1::from(perturbed_b);
        }
        let perturbed_w: Vec<f32> =
            coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
        *coordinator.pipelines[1].classifier.weight.data_mut() = ndarray::Array1::from(perturbed_w);

        coordinator.sync_lora_weights_from_primary();

        // Compare layer by layer so a failure names the offending layer.
        for (i, (l0, l1)) in coordinator.pipelines[0]
            .lora_layers
            .iter()
            .zip(coordinator.pipelines[1].lora_layers.iter())
            .enumerate()
        {
            assert_eq!(
                l0.lora_a().data().as_slice().unwrap(),
                l1.lora_a().data().as_slice().unwrap(),
                "F-DP-001: lora_a of layer {i} must match after sync"
            );
            assert_eq!(
                l0.lora_b().data().as_slice().unwrap(),
                l1.lora_b().data().as_slice().unwrap(),
                "F-DP-001: lora_b of layer {i} must match after sync"
            );
        }

        assert_eq!(
            coordinator.pipelines[0].classifier.weight.data().as_slice().unwrap(),
            coordinator.pipelines[1].classifier.weight.data().as_slice().unwrap(),
            "F-DP-001: classifier weight must match after sync"
        );
        // The bias was not perturbed above; this checks it is equal after the sync.
        assert_eq!(
            coordinator.pipelines[0].classifier.bias.data().as_slice().unwrap(),
            coordinator.pipelines[1].classifier.bias.data().as_slice().unwrap(),
            "F-DP-001: classifier bias must match after sync"
        );
    }

    /// F-DP-005: the mean forward-only loss over 20 samples on one pipeline is
    /// within 25% of the mean obtained when the same samples are split in two
    /// halves and evaluated on the two pipelines of a coordinator.
    #[test]
    fn falsify_dp_005_single_vs_multi_gpu_loss_convergence() {
        let (model_config, classify_config) = test_config();

        // 20 samples with alternating labels 0 / 1.
        let samples: Vec<SafetySample> = (0..20)
            .map(|i| SafetySample { input: format!("test_sample_{i}"), label: i % 2 })
            .collect();

        let mut single_pipe = ClassifyPipeline::new(&model_config, classify_config.clone());
        // Stand-in tokenisation: the raw bytes of the input, capped at 16 ids.
        // NOTE(review): several of these byte values (e.g. b't' = 116) exceed the
        // vocab_size of 100 set in test_config; presumably forward_only copes
        // with out-of-range ids — confirm.
        let token_ids_batch: Vec<Vec<u32>> = samples
            .iter()
            .map(|s| {
                let bytes: Vec<u32> = s.input.bytes().map(u32::from).collect();
                bytes[..bytes.len().min(16)].to_vec()
            })
            .collect();

        // Reference: mean loss of all samples on a single pipeline.
        let mut single_loss = 0.0f32;
        for (ids, sample) in token_ids_batch.iter().zip(&samples) {
            let (loss, _pred) = single_pipe.forward_only(ids, sample.label);
            single_loss += loss;
        }
        let single_avg_loss = single_loss / samples.len() as f32;

        let mut multi = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Shard (token ids, label) pairs in two and run shard i on pipeline i.
        let id_label_pairs: Vec<(&Vec<u32>, usize)> =
            token_ids_batch.iter().zip(samples.iter().map(|s| s.label)).collect();
        let shards = shard_samples(&id_label_pairs, 2);
        let mut multi_loss = 0.0f32;
        let mut multi_count = 0usize;

        for (shard_idx, shard) in shards.iter().enumerate() {
            let pipe = &mut multi.pipelines[shard_idx];
            for &(ids, label) in *shard {
                let (loss, _pred) = pipe.forward_only(ids, label);
                multi_loss += loss;
                multi_count += 1;
            }
        }
        let multi_avg_loss = multi_loss / multi_count as f32;

        // Relative tolerance of 25% of the single-pipeline loss, plus a small absolute slack.
        assert!(
            (single_avg_loss - multi_avg_loss).abs() < 0.25 * single_avg_loss.abs() + 1e-6,
            "F-DP-005: single GPU loss ({single_avg_loss:.6}) vs multi GPU loss ({multi_avg_loss:.6}) \
             diverged beyond 25% tolerance"
        );
    }

    /// F-HET-001: two pipelines built from the same configs expose the same
    /// gradient layout: equal flattened gradient length (matching
    /// `num_trainable_parameters`), equal LoRA layer count and equal A/B sizes.
    #[test]
    fn falsify_het_001_gradient_layout_identical_across_pipelines() {
        let (model_config, classify_config) = test_config();

        let pipe_a = ClassifyPipeline::new(&model_config, classify_config.clone());
        let pipe_b = ClassifyPipeline::new(&model_config, classify_config);

        let grads_a = pipe_a.collect_lora_gradients();
        let grads_b = pipe_b.collect_lora_gradients();
        assert_eq!(
            grads_a.len(),
            grads_b.len(),
            "F-HET-001: gradient layout length mismatch between pipelines"
        );

        // The flattened gradient must cover every trainable parameter.
        assert_eq!(
            grads_a.len(),
            pipe_a.num_trainable_parameters(),
            "F-HET-001: gradient length != num_trainable_parameters for pipeline A"
        );
        assert_eq!(
            grads_b.len(),
            pipe_b.num_trainable_parameters(),
            "F-HET-001: gradient length != num_trainable_parameters for pipeline B"
        );

        assert_eq!(
            pipe_a.lora_layers.len(),
            pipe_b.lora_layers.len(),
            "F-HET-001: different LoRA layer counts"
        );

        // Per-layer check so a mismatch is reported with its layer index.
        for (i, (la, lb)) in pipe_a.lora_layers.iter().zip(pipe_b.lora_layers.iter()).enumerate() {
            assert_eq!(
                la.lora_a().data().len(),
                lb.lora_a().data().len(),
                "F-HET-001: lora_a dimension mismatch at layer {i}"
            );
            assert_eq!(
                la.lora_b().data().len(),
                lb.lora_b().data().len(),
                "F-HET-001: lora_b dimension mismatch at layer {i}"
            );
        }
    }

    /// F-HET-002: a back-of-the-envelope memory estimate for the test model
    /// stays under 8192 MB, and the trainable (adapter) part is below 10% of
    /// the estimated frozen-model size.
    #[test]
    fn falsify_het_002_memory_budget_within_vram() {
        let (model_config, classify_config) = test_config();
        let pipeline = ClassifyPipeline::new(&model_config, classify_config);

        let hidden = model_config.hidden_size;
        let layers = model_config.num_hidden_layers;
        let vocab = model_config.vocab_size;

        // Estimated parameter count, times 4 bytes per value.
        // NOTE(review): the terms look like embedding + 4 square projections per
        // layer + a 4x-wide two-matrix MLP per layer; the MLP term uses
        // 4 * hidden rather than the config's intermediate_size — confirm intent.
        let model_params = vocab * hidden + layers * (4 * hidden * hidden) + layers * (2 * hidden * 4 * hidden); let model_bytes = model_params * 4;

        let trainable = pipeline.num_trainable_parameters();
        let adapter_bytes = trainable * 4;

        let total_bytes = model_bytes + adapter_bytes;
        let total_mb = total_bytes as f64 / (1024.0 * 1024.0);

        assert!(
            total_mb < 8192.0,
            "F-HET-002: estimated memory {total_mb:.1} MB exceeds 8 GB VRAM budget"
        );

        let adapter_ratio = adapter_bytes as f64 / model_bytes as f64;
        assert!(
            adapter_ratio < 0.1,
            "F-HET-002: adapter memory ratio {adapter_ratio:.4} exceeds 10% of model — \
             LoRA should be much smaller than frozen model"
        );
    }

    /// F-DP-003: NaN and infinity appearing together are still detected, and
    /// averaging NaN with infinity yields NaN.
    #[test]
    fn falsify_dp_003_nan_and_inf_combined_in_gradient() {
        assert!(has_non_finite(&[1.0, f32::NAN, f32::INFINITY, 4.0]));
        assert!(has_non_finite(&[f32::NEG_INFINITY]));

        let grads = vec![vec![f32::NAN, 1.0], vec![f32::INFINITY, 2.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_nan(), "NaN + Inf average should be NaN");
        assert!(has_non_finite(&avg));
    }

    /// F-DP-002: sharding an empty input produces shards with no samples in total.
    #[test]
    fn falsify_dp_002_shard_empty_samples() {
        let samples: Vec<i32> = vec![];
        let shards = shard_samples(&samples, 3);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 0, "F-DP-002: sharding empty data must produce 0 total samples");
    }

    /// F-DP-002: one sample across three workers is neither dropped nor duplicated.
    #[test]
    fn falsify_dp_002_shard_single_sample() {
        let samples = vec![42];
        let shards = shard_samples(&samples, 3);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 1, "F-DP-002: must not lose or duplicate the single sample");
    }
}