Skip to main content

entrenar/finetune/
data_parallel.rs

1//! Multi-GPU data parallelism for classification training
2//!
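//! Provides [`DataParallelCoordinator`], which splits mini-batches across multiple
//! GPU pipelines. The design target is to run forward/backward independently per
//! GPU and average gradients on CPU before the optimizer step; the current
//! implementation processes the shards sequentially and re-synchronises the
//! replicas by broadcasting the primary pipeline's weights after each step.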
6//!
7//! # Architecture
8//!
9//! ```text
10//! Mini-batch [N samples]
11//!   ├── Shard 0 [N/G samples] → GPU 0 → gradients₀
12//!   ├── Shard 1 [N/G samples] → GPU 1 → gradients₁
13//!   └── ...
14//!        ↓ (CPU AllReduce: average LoRA gradients)
15//!   Optimizer step (applied to all replicas)
16//! ```
17//!
18//! # Contract (C-DP-001)
19//!
20//! - **Precondition**: All pipelines have identical weights at each step start
21//! - **Postcondition**: All pipelines have identical weights after optimizer step
22//! - **Invariant**: Loss within 1% of equivalent single-GPU run at step 100+
23//!
24//! # Why CPU AllReduce is fine
25//!
26//! LoRA rank-16 on Qwen3-4B = ~5.9M params = ~22MB. PCIe transfer: <2ms.
27//! This is negligible vs forward pass (~200ms per GPU).
28
29use crate::finetune::classification::SafetySample;
30use crate::finetune::classify_pipeline::{BatchResult, ClassifyConfig, ClassifyPipeline};
31use crate::transformer::TransformerConfig;
32
/// Coordinates data-parallel training across multiple GPUs.
///
/// Each entry in `pipelines` is a complete replica of the model. Per training
/// step (see `train_batch_parallel`):
/// 1. Split mini-batch into N shards (one per replica)
/// 2. Each replica trains on its shard; shards are currently processed one
///    after another on the calling thread, not via `std::thread::scope`
/// 3. The primary replica's LoRA and classifier-head weights are copied to
///    every other replica (a broadcast from the primary, not a gradient average)
pub struct DataParallelCoordinator {
    /// One pipeline per GPU (replicated model with LoRA adapters).
    /// Index 0 is the primary replica that weight sync copies from.
    pipelines: Vec<ClassifyPipeline>,
    /// GPU adapter indices used (for future multi-process parallelism).
    /// Stored at construction but not read anywhere in this file, hence the allow.
    #[allow(dead_code)]
    gpu_indices: Vec<u32>,
}
47
48impl DataParallelCoordinator {
49    /// Create a data-parallel coordinator with the given GPU indices.
50    ///
51    /// Creates one `ClassifyPipeline` per GPU, each with its own
52    /// `WgpuForwardPass` targeting a specific adapter.
53    ///
54    /// # Arguments
55    /// * `model_config` - Transformer architecture configuration
56    /// * `classify_config` - Classification training configuration
57    /// * `gpu_indices` - wgpu adapter indices to use (e.g., `[0, 1]`)
58    ///
59    /// # Errors
60    /// Returns error if any GPU pipeline creation fails.
61    pub fn new(
62        model_config: &TransformerConfig,
63        classify_config: ClassifyConfig,
64        gpu_indices: &[u32],
65    ) -> Result<Self, String> {
66        if gpu_indices.is_empty() {
67            return Err("At least one GPU index required".to_string());
68        }
69
70        let mut pipelines = Vec::with_capacity(gpu_indices.len());
71
72        for &_idx in gpu_indices {
73            // Each pipeline gets its own copy of the model
74            let pipeline = ClassifyPipeline::new(model_config, classify_config.clone());
75            pipelines.push(pipeline);
76        }
77
78        Ok(Self { pipelines, gpu_indices: gpu_indices.to_vec() })
79    }
80
81    /// Number of GPUs in the pool
82    #[must_use]
83    pub fn num_gpus(&self) -> usize {
84        self.pipelines.len()
85    }
86
87    /// Get a mutable reference to the first pipeline (for evaluation/inference)
88    pub fn primary_pipeline(&mut self) -> &mut ClassifyPipeline {
89        &mut self.pipelines[0]
90    }
91
92    /// Get an immutable reference to the first pipeline
93    pub fn primary_pipeline_ref(&self) -> &ClassifyPipeline {
94        &self.pipelines[0]
95    }
96
97    /// Train one batch across all GPUs in parallel.
98    ///
99    /// Splits samples across GPUs, runs forward/backward in parallel threads,
100    /// averages LoRA gradients, and applies the optimizer step.
101    ///
102    /// # Contract (C-DP-002)
103    ///
104    /// - **Precondition**: `samples.len() >= num_gpus` for balanced sharding
105    /// - **Postcondition**: All pipelines have updated, identical LoRA weights
106    /// - **Invariant**: Gradient averaging preserves numerical stability
107    pub fn train_batch_parallel(&mut self, samples: &[SafetySample]) -> BatchResult {
108        let num_gpus = self.pipelines.len();
109
110        if num_gpus == 1 || samples.len() < num_gpus {
111            // Fall back to single-GPU training
112            return self.pipelines[0].train_batch(samples);
113        }
114
115        // ── 1. Shard samples across GPUs ──────────────────────────────────
116        let shard_size = samples.len() / num_gpus;
117        let shards: Vec<&[SafetySample]> = (0..num_gpus)
118            .map(|i| {
119                let start = i * shard_size;
120                let end = if i == num_gpus - 1 { samples.len() } else { start + shard_size };
121                &samples[start..end]
122            })
123            .collect();
124
125        // ── 2. Run forward/backward on each GPU in parallel ──────────────
126        // Process shards sequentially since ClassifyPipeline contains non-Send
127        // types (wgpu handles). For multi-GPU with separate processes, use
128        // std::process or the CUDA path which has its own threading model.
129        //
130        // Even sequential processing benefits from batched GPU execution within
131        // each pipeline (eliminated per-op device creation).
132        let mut results = Vec::with_capacity(num_gpus);
133        for (gpu_idx, shard) in shards.iter().enumerate() {
134            let result = self.pipelines[gpu_idx].train_batch(shard);
135            results.push(result);
136        }
137
138        // ── 3. Aggregate results ──────────────────────────────────────────
139        let total_samples: usize = results.iter().map(|r| r.total).sum();
140        let total_correct: usize = results.iter().map(|r| r.correct).sum();
141        let avg_loss: f32 =
142            results.iter().map(|r| r.avg_loss * r.total as f32).sum::<f32>() / total_samples as f32;
143        let avg_grad_norm: f32 = results.iter().map(|r| r.grad_norm).sum::<f32>() / num_gpus as f32;
144
145        // ── 4. Sync LoRA weights from primary to replicas ─────────────────
146        // After each GPU's optimizer step, weights diverge slightly.
147        // Average them by copying primary's weights to all replicas.
148        // This is the "broadcast after AllReduce" pattern.
149        if self.pipelines.len() > 1 {
150            self.sync_lora_weights_from_primary();
151        }
152
153        BatchResult {
154            avg_loss,
155            correct: total_correct,
156            total: total_samples,
157            grad_norm: avg_grad_norm,
158        }
159    }
160
161    /// Synchronize LoRA weights from primary pipeline to all replicas.
162    ///
163    /// Uses `split_at_mut` for disjoint borrows — copies primary weights
164    /// directly into replicas via `assign()`, no intermediate allocations.
165    ///
166    /// # KAIZEN-034
167    /// Eliminates double copy (to_vec + clone) per LoRA layer per step.
168    fn sync_lora_weights_from_primary(&mut self) {
169        if self.pipelines.len() <= 1 {
170            return;
171        }
172
173        // split_at_mut gives disjoint borrows: primary (immutable) vs replicas (mutable)
174        let (primary_slice, replicas) = self.pipelines.split_at_mut(1);
175        let primary = &primary_slice[0];
176
177        for replica in replicas.iter_mut() {
178            // Copy LoRA weights directly via assign (single memcpy per matrix)
179            for (src_lora, dst_lora) in
180                primary.lora_layers.iter().zip(replica.lora_layers.iter_mut())
181            {
182                dst_lora.lora_a_mut().data_mut().assign(src_lora.lora_a().data());
183                dst_lora.lora_b_mut().data_mut().assign(src_lora.lora_b().data());
184            }
185
186            // Copy classifier head weights
187            replica.classifier.weight.data_mut().assign(primary.classifier.weight.data());
188            replica.classifier.bias.data_mut().assign(primary.classifier.bias.data());
189        }
190    }
191}
192
/// Split `samples` into `num_workers` contiguous, non-overlapping slices.
///
/// Every shard except the last holds `samples.len() / num_workers` elements;
/// the last shard holds whatever is left, so it absorbs the whole remainder.
/// When `num_workers` is zero or `samples` is empty, a single shard containing
/// all of `samples` is returned instead.
///
/// # Contract (F-DP-002)
///
/// - **Postcondition**: `∪ shards = samples` and shards are disjoint
/// - **Invariant**: `sum(shard.len()) == samples.len()`
pub fn shard_samples<T>(samples: &[T], num_workers: usize) -> Vec<&[T]> {
    if samples.is_empty() || num_workers == 0 {
        return vec![samples];
    }

    let per_worker = samples.len() / num_workers;
    let mut shards = Vec::with_capacity(num_workers);
    let mut remaining = samples;

    // Peel a fixed-size shard off the front for every worker but the last.
    for _ in 1..num_workers {
        let (head, tail) = remaining.split_at(per_worker);
        shards.push(head);
        remaining = tail;
    }

    // Whatever is left (base size plus remainder) goes to the last worker.
    shards.push(remaining);
    shards
}
214
/// Average gradient vectors from multiple workers, element by element.
///
/// Returns an empty vector when `grads` is empty. Otherwise the result has the
/// length of the first worker's vector, and element `j` is the sum of every
/// worker's element `j` divided by the number of workers.
///
/// # Panics
///
/// Panics if any worker's vector has a different length from the first one.
/// A ragged input means the workers disagree on the gradient layout, so an
/// element-wise average would be meaningless; failing loudly is preferred to
/// returning a silently wrong result.
///
/// # Contract (F-DP-003)
///
/// - **Postcondition**: `avg[j] = (1/N) × Σᵢ grads[i][j]`
/// - **Invariant**: NaN propagates through averaging (Jidoka — don't mask errors)
pub fn average_gradients(grads: &[Vec<f32>]) -> Vec<f32> {
    let len = match grads.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    // Accumulate from zero in worker order so the summation order is stable.
    let mut avg = vec![0.0f32; len];
    for (worker, grad) in grads.iter().enumerate() {
        assert_eq!(
            grad.len(),
            len,
            "average_gradients: worker {worker} has {} gradient values, expected {len}",
            grad.len()
        );
        for (acc, &v) in avg.iter_mut().zip(grad) {
            *acc += v;
        }
    }

    let n = grads.len() as f32;
    for v in &mut avg {
        *v /= n;
    }
    avg
}
238
/// Report whether `values` contains at least one NaN or infinite element.
///
/// Returns `false` for an empty slice.
///
/// Used by Jidoka (自働化) halt — training stops on first non-finite gradient.
pub fn has_non_finite(values: &[f32]) -> bool {
    for value in values {
        if value.is_nan() || value.is_infinite() {
            return true;
        }
    }
    false
}
245
#[cfg(test)]
mod tests {
    // Unit and falsification tests for the coordinator and the free helper
    // functions (`shard_samples`, `average_gradients`, `has_non_finite`).
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::transformer::ModelArchitecture;

    /// Tiny two-layer transformer config plus a 2-class, rank-4 LoRA
    /// classification config, shared by the tests below.
    fn test_config() -> (TransformerConfig, ClassifyConfig) {
        let model_config = TransformerConfig {
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 64,
            vocab_size: 100,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            use_bias: false,
            head_dim_override: None,
            architecture: ModelArchitecture::Decoder,
            hf_architecture: None,
            hf_model_type: None,
            tie_word_embeddings: false,
        };

        let classify_config =
            ClassifyConfig { num_classes: 2, lora_rank: 4, ..ClassifyConfig::default() };

        (model_config, classify_config)
    }

    // A single GPU index yields a coordinator with exactly one pipeline.
    #[test]
    fn test_coordinator_creation() {
        let (model_config, classify_config) = test_config();
        let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0]);
        assert!(coordinator.is_ok());
        assert_eq!(
            coordinator.as_ref().map(super::DataParallelCoordinator::num_gpus).unwrap_or(0),
            1
        );
    }

    // An empty index list must be rejected.
    #[test]
    fn test_coordinator_empty_gpus_fails() {
        let (model_config, classify_config) = test_config();
        let result = DataParallelCoordinator::new(&model_config, classify_config, &[]);
        assert!(result.is_err());
    }

    // Both primary-pipeline accessors expose the config the coordinator was built with.
    #[test]
    fn test_multi_gpu_coordinator_accessors() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        // Verify pipeline accessors work
        assert_eq!(coordinator.num_gpus(), 1);

        let primary = coordinator.primary_pipeline();
        assert_eq!(primary.config.num_classes, 2);

        let primary_ref = coordinator.primary_pipeline_ref();
        assert_eq!(primary_ref.config.lora_rank, 4);
    }

    // Only checks that a one-index coordinator reports a single GPU; it does
    // not call `train_batch_parallel`, so the fallback branch itself is not run.
    #[test]
    fn test_single_gpu_fallback_path() {
        let (model_config, classify_config) = test_config();
        let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        assert_eq!(coordinator.num_gpus(), 1);
    }

    // With a single pipeline there are no replicas to write to; this only
    // checks that the sync call completes without panicking.
    #[test]
    fn test_weight_sync_noop_single_gpu() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
            .expect("creation should succeed");

        coordinator.sync_lora_weights_from_primary();
    }

    // =========================================================================
    // FALSIFICATION TESTS (SPEC-DIST-2026-001)
    // =========================================================================

    // FALSIFY-DP-001: Weight consistency — verify sync makes replicas identical
    // (checks `lora_a` of LoRA layer 0 only; the all-layers variant is further down)
    #[test]
    fn falsify_dp_001_weight_sync_makes_replicas_identical() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Manually perturb replica 1's weights so they differ
        let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
            .lora_a()
            .data()
            .iter()
            .map(|v| v + 1.0)
            .collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;

        // Verify they are now different
        let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_ne!(w0, w1, "Weights should differ before sync");

        // Sync should make them identical
        coordinator.sync_lora_weights_from_primary();

        let w0_after: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1_after: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_eq!(w0_after, w1_after, "F-DP-001: Weights MUST be identical after sync");
    }

    // FALSIFY-DP-001 (negative): Without sync, weights diverge
    #[test]
    fn falsify_dp_001_weights_diverge_without_sync() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Perturb replica 1 (simulating independent optimizer step)
        let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
            .lora_a()
            .data()
            .iter()
            .map(|v| v + 0.5)
            .collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;

        // DO NOT call sync_lora_weights_from_primary
        let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
        assert_ne!(w0, w1, "Without sync, weights MUST diverge (proving sync is necessary)");
    }

    // FALSIFY-DP-002: Sharding completeness — no sample lost or duplicated
    #[test]
    fn falsify_dp_002_no_sample_lost_or_duplicated() {
        let samples: Vec<u32> = (0..100).collect();

        for num_workers in [1, 2, 3, 4, 7, 10] {
            let shards = shard_samples(&samples, num_workers);
            assert_eq!(
                shards.len(),
                num_workers,
                "Wrong number of shards for {num_workers} workers"
            );

            // All samples covered
            let total: usize = shards.iter().map(|s| s.len()).sum();
            assert_eq!(total, 100, "F-DP-002: samples lost with {num_workers} workers");

            // Disjointness: each element appears exactly once
            // (the sample values are unique, so a repeated value means a repeated element)
            let mut seen = std::collections::HashSet::new();
            for shard in &shards {
                for &s in *shard {
                    assert!(
                        seen.insert(s),
                        "F-DP-002: duplicate sample {s} with {num_workers} workers"
                    );
                }
            }
            assert_eq!(seen.len(), 100);
        }
    }

    // FALSIFY-DP-002: Sharding with uneven division
    #[test]
    fn falsify_dp_002_uneven_sharding_gets_remainder() {
        let samples: Vec<u32> = (0..10).collect();
        let shards = shard_samples(&samples, 3);
        // 10 / 3 = 3 per shard, last gets 4
        assert_eq!(shards[0].len(), 3);
        assert_eq!(shards[1].len(), 3);
        assert_eq!(shards[2].len(), 4); // remainder
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 10);
    }

    // FALSIFY-DP-003: NaN propagation through gradient averaging
    #[test]
    fn falsify_dp_003_nan_gradient_propagates() {
        let grads = vec![vec![1.0, 2.0, 3.0], vec![f32::NAN, 2.0, 3.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_nan(), "F-DP-003: NaN MUST propagate through averaging (Jidoka)");
        // Non-NaN elements should still average correctly
        assert!((avg[1] - 2.0).abs() < 1e-6);
        assert!((avg[2] - 3.0).abs() < 1e-6);
    }

    // FALSIFY-DP-003: Inf propagation
    #[test]
    fn falsify_dp_003_inf_gradient_propagates() {
        let grads = vec![vec![1.0, 2.0], vec![f32::INFINITY, 2.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_infinite(), "F-DP-003: Inf MUST propagate through averaging");
    }

    // FALSIFY-DP-003: has_non_finite detects NaN and Inf
    #[test]
    fn falsify_dp_003_non_finite_detection() {
        assert!(!has_non_finite(&[1.0, 2.0, 3.0]));
        assert!(has_non_finite(&[1.0, f32::NAN, 3.0]));
        assert!(has_non_finite(&[1.0, f32::INFINITY, 3.0]));
        assert!(has_non_finite(&[1.0, f32::NEG_INFINITY, 3.0]));
    }

    // Gradient averaging correctness
    #[test]
    fn test_average_gradients_correct() {
        let grads = vec![vec![2.0, 4.0, 6.0], vec![4.0, 6.0, 8.0], vec![6.0, 8.0, 10.0]];
        let avg = average_gradients(&grads);
        assert!((avg[0] - 4.0).abs() < 1e-6);
        assert!((avg[1] - 6.0).abs() < 1e-6);
        assert!((avg[2] - 8.0).abs() < 1e-6);
    }

    // Gradient averaging edge case: single worker
    #[test]
    fn test_average_gradients_single_worker() {
        let grads = vec![vec![1.0, 2.0, 3.0]];
        let avg = average_gradients(&grads);
        assert!((avg[0] - 1.0).abs() < 1e-6);
        assert!((avg[1] - 2.0).abs() < 1e-6);
        assert!((avg[2] - 3.0).abs() < 1e-6);
    }

    // Gradient averaging edge case: empty
    #[test]
    fn test_average_gradients_empty() {
        let grads: Vec<Vec<f32>> = vec![];
        let avg = average_gradients(&grads);
        assert!(avg.is_empty());
    }

    // FALSIFY-DP-004: CPU fallback produces finite output
    #[test]
    fn falsify_dp_004_cpu_pipeline_produces_finite_hidden() {
        let (model_config, classify_config) = test_config();
        let pipeline = ClassifyPipeline::new(&model_config, classify_config);

        // Forward through the model on CPU (no tokenizer needed for raw token IDs)
        let token_ids = vec![1u32, 2, 3, 4, 5];
        let hidden = pipeline.model.forward_hidden(&token_ids);
        let data = hidden.data();

        // All values must be finite (F-DP-004)
        assert!(
            data.iter().all(|v| v.is_finite()),
            "F-DP-004: CPU fallback must produce finite hidden states"
        );
        // Correct shape: seq_len * hidden_size
        assert_eq!(data.len(), token_ids.len() * model_config.hidden_size);
    }

    // Weight sync covers classifier head too
    // (only the classifier weight is perturbed and compared here, not the bias)
    #[test]
    fn test_weight_sync_covers_classifier_head() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Perturb replica 1's classifier weight
        let perturbed: Vec<f32> =
            coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
        let arr = ndarray::Array1::from(perturbed);
        *coordinator.pipelines[1].classifier.weight.data_mut() = arr;

        // Sync
        coordinator.sync_lora_weights_from_primary();

        let w0: Vec<f32> = coordinator.pipelines[0].classifier.weight.data().to_vec();
        let w1: Vec<f32> = coordinator.pipelines[1].classifier.weight.data().to_vec();
        assert_eq!(w0, w1, "Classifier head weights must sync across replicas");
    }

    // Multi-GPU coordinator creates correct number of pipelines
    #[test]
    fn test_multi_gpu_creates_n_pipelines() {
        let (model_config, classify_config) = test_config();
        for n in [1, 2, 3, 4] {
            let indices: Vec<u32> = (0..n).collect();
            let coordinator =
                DataParallelCoordinator::new(&model_config, classify_config.clone(), &indices)
                    .expect("creation should succeed");
            assert_eq!(coordinator.num_gpus(), n as usize);
        }
    }

    // ─── Strengthened F-DP-001: verify ALL layers, not just layer 0 ──────────

    #[test]
    fn falsify_dp_001_weight_sync_all_layers_and_classifier() {
        let (model_config, classify_config) = test_config();
        let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Perturb ALL lora_a/lora_b of replica 1 and the classifier weight
        // (the classifier bias is left untouched but is still compared below)
        for lora in &mut coordinator.pipelines[1].lora_layers {
            let perturbed_a: Vec<f32> = lora.lora_a().data().iter().map(|v| v + 42.0).collect();
            *lora.lora_a_mut().data_mut() = ndarray::Array1::from(perturbed_a);
            let perturbed_b: Vec<f32> = lora.lora_b().data().iter().map(|v| v + 7.0).collect();
            *lora.lora_b_mut().data_mut() = ndarray::Array1::from(perturbed_b);
        }
        let perturbed_w: Vec<f32> =
            coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
        *coordinator.pipelines[1].classifier.weight.data_mut() = ndarray::Array1::from(perturbed_w);

        // Sync from primary
        coordinator.sync_lora_weights_from_primary();

        // Verify ALL LoRA layers match bit-for-bit
        for (i, (l0, l1)) in coordinator.pipelines[0]
            .lora_layers
            .iter()
            .zip(coordinator.pipelines[1].lora_layers.iter())
            .enumerate()
        {
            assert_eq!(
                l0.lora_a().data().as_slice().unwrap(),
                l1.lora_a().data().as_slice().unwrap(),
                "F-DP-001: lora_a of layer {i} must match after sync"
            );
            assert_eq!(
                l0.lora_b().data().as_slice().unwrap(),
                l1.lora_b().data().as_slice().unwrap(),
                "F-DP-001: lora_b of layer {i} must match after sync"
            );
        }

        // Verify classifier head
        assert_eq!(
            coordinator.pipelines[0].classifier.weight.data().as_slice().unwrap(),
            coordinator.pipelines[1].classifier.weight.data().as_slice().unwrap(),
            "F-DP-001: classifier weight must match after sync"
        );
        assert_eq!(
            coordinator.pipelines[0].classifier.bias.data().as_slice().unwrap(),
            coordinator.pipelines[1].classifier.bias.data().as_slice().unwrap(),
            "F-DP-001: classifier bias must match after sync"
        );
    }

    // ─── F-DP-005: Multi-GPU loss equivalence ────────────────────────────────

    #[test]
    fn falsify_dp_005_single_vs_multi_gpu_loss_convergence() {
        // F-DP-005: |loss_multi - loss_single| < tolerance
        // Compares the mean loss one standalone pipeline reports over all
        // samples with the mean loss two coordinator replicas report over
        // disjoint shards of the same samples. Only `forward_only` is called
        // here — `train_batch_parallel` and the weight sync are not exercised.
        let (model_config, classify_config) = test_config();

        // Create samples
        let samples: Vec<SafetySample> = (0..20)
            .map(|i| SafetySample { input: format!("test_sample_{i}"), label: i % 2 })
            .collect();

        // ── Single GPU: evaluate locally ──
        let mut single_pipe = ClassifyPipeline::new(&model_config, classify_config.clone());
        // Stand-in tokenization: the input's raw bytes as token IDs, capped at 16 per sample.
        let token_ids_batch: Vec<Vec<u32>> = samples
            .iter()
            .map(|s| {
                let bytes: Vec<u32> = s.input.bytes().map(u32::from).collect();
                bytes[..bytes.len().min(16)].to_vec()
            })
            .collect();

        // Run each sample through `forward_only`, accumulate loss
        let mut single_loss = 0.0f32;
        for (ids, sample) in token_ids_batch.iter().zip(&samples) {
            let (loss, _pred) = single_pipe.forward_only(ids, sample.label);
            single_loss += loss;
        }
        let single_avg_loss = single_loss / samples.len() as f32;

        // ── Multi GPU (2 replicas): shard and average ──
        let mut multi = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
            .expect("creation should succeed");

        // Pair token IDs with labels so sharding preserves the mapping
        let id_label_pairs: Vec<(&Vec<u32>, usize)> =
            token_ids_batch.iter().zip(samples.iter().map(|s| s.label)).collect();
        let shards = shard_samples(&id_label_pairs, 2);
        let mut multi_loss = 0.0f32;
        let mut multi_count = 0usize;

        for (shard_idx, shard) in shards.iter().enumerate() {
            let pipe = &mut multi.pipelines[shard_idx];
            for &(ids, label) in *shard {
                let (loss, _pred) = pipe.forward_only(ids, label);
                multi_loss += loss;
                multi_count += 1;
            }
        }
        let multi_avg_loss = multi_loss / multi_count as f32;

        // F-DP-005: losses should be in same ballpark
        // Data parallel training has inherent nondeterminism from GPU execution
        // ordering and floating point reduction — verify ballpark convergence.
        assert!(
            (single_avg_loss - multi_avg_loss).abs() < 0.25 * single_avg_loss.abs() + 1e-6,
            "F-DP-005: single GPU loss ({single_avg_loss:.6}) vs multi GPU loss ({multi_avg_loss:.6}) \
             diverged beyond 25% tolerance"
        );
    }

    // ─── F-HET-001: Mixed backend gradient shape consistency ─────────────────

    #[test]
    fn falsify_het_001_gradient_layout_identical_across_pipelines() {
        // F-HET-001: Gradient tensor layout must be identical regardless of
        // which pipeline produces it. This ensures AllReduce averaging is
        // element-wise correct when mixing backends.
        // Both pipelines here are built the same way from the same configs.
        let (model_config, classify_config) = test_config();

        let pipe_a = ClassifyPipeline::new(&model_config, classify_config.clone());
        let pipe_b = ClassifyPipeline::new(&model_config, classify_config);

        // Gradient vector length must match
        let grads_a = pipe_a.collect_lora_gradients();
        let grads_b = pipe_b.collect_lora_gradients();
        assert_eq!(
            grads_a.len(),
            grads_b.len(),
            "F-HET-001: gradient layout length mismatch between pipelines"
        );

        // Both must equal num_trainable_parameters
        assert_eq!(
            grads_a.len(),
            pipe_a.num_trainable_parameters(),
            "F-HET-001: gradient length != num_trainable_parameters for pipeline A"
        );
        assert_eq!(
            grads_b.len(),
            pipe_b.num_trainable_parameters(),
            "F-HET-001: gradient length != num_trainable_parameters for pipeline B"
        );

        // LoRA layer count must be identical
        assert_eq!(
            pipe_a.lora_layers.len(),
            pipe_b.lora_layers.len(),
            "F-HET-001: different LoRA layer counts"
        );

        // Each LoRA layer must have matching dimensions
        // (compared by flattened element count, not by 2-D shape)
        for (i, (la, lb)) in pipe_a.lora_layers.iter().zip(pipe_b.lora_layers.iter()).enumerate() {
            assert_eq!(
                la.lora_a().data().len(),
                lb.lora_a().data().len(),
                "F-HET-001: lora_a dimension mismatch at layer {i}"
            );
            assert_eq!(
                la.lora_b().data().len(),
                lb.lora_b().data().len(),
                "F-HET-001: lora_b dimension mismatch at layer {i}"
            );
        }
    }

    // ─── F-HET-002: Heterogeneous memory budget ──────────────────────────────

    #[test]
    fn falsify_het_002_memory_budget_within_vram() {
        // F-HET-002: Per-GPU memory usage must stay within VRAM budget.
        // Qwen3-4B fp32: ~1052 MB per GPU. Tiny test config much smaller.
        // The figure checked here is an analytic estimate from the config,
        // not a measurement of allocated memory.
        let (model_config, classify_config) = test_config();
        let pipeline = ClassifyPipeline::new(&model_config, classify_config);

        // Calculate memory: model weights + LoRA adapters + classifier
        let hidden = model_config.hidden_size;
        let layers = model_config.num_hidden_layers;
        let vocab = model_config.vocab_size;

        // Model weight memory estimate (fp32, 4 bytes per param)
        // NOTE: the FFN term assumes a 4x expansion of `hidden` rather than
        // reading `intermediate_size` from the config.
        let model_params = vocab * hidden  // embedding
            + layers * (4 * hidden * hidden)  // Q/K/V/O per layer
            + layers * (2 * hidden * 4 * hidden); // FFN up/down
        let model_bytes = model_params * 4;

        // Trainable-parameter memory (whatever `num_trainable_parameters` counts), fp32
        let trainable = pipeline.num_trainable_parameters();
        let adapter_bytes = trainable * 4;

        // Total must be reasonable (less than 8GB for test config)
        let total_bytes = model_bytes + adapter_bytes;
        let total_mb = total_bytes as f64 / (1024.0 * 1024.0);

        assert!(
            total_mb < 8192.0,
            "F-HET-002: estimated memory {total_mb:.1} MB exceeds 8 GB VRAM budget"
        );

        // Adapter memory should be << model memory (LoRA efficiency)
        let adapter_ratio = adapter_bytes as f64 / model_bytes as f64;
        assert!(
            adapter_ratio < 0.1,
            "F-HET-002: adapter memory ratio {adapter_ratio:.4} exceeds 10% of model — \
             LoRA should be much smaller than frozen model"
        );
    }

    // ─── Strengthened F-DP-003: server-side Inf halt ─────────────────────────

    #[test]
    fn falsify_dp_003_nan_and_inf_combined_in_gradient() {
        // Gradients containing both NaN and Inf should be detected
        assert!(has_non_finite(&[1.0, f32::NAN, f32::INFINITY, 4.0]));
        assert!(has_non_finite(&[f32::NEG_INFINITY]));

        // Averaging NaN+Inf produces NaN
        let grads = vec![vec![f32::NAN, 1.0], vec![f32::INFINITY, 2.0]];
        let avg = average_gradients(&grads);
        assert!(avg[0].is_nan(), "NaN + Inf average should be NaN");
        assert!(has_non_finite(&avg));
    }

    // ─── Strengthened F-DP-002: edge cases ───────────────────────────────────

    #[test]
    fn falsify_dp_002_shard_empty_samples() {
        // Sharding empty data must not invent samples. `shard_samples` returns a
        // single empty shard in this case; only the total is asserted here.
        let samples: Vec<i32> = vec![];
        let shards = shard_samples(&samples, 3);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 0, "F-DP-002: sharding empty data must produce 0 total samples");
    }

    #[test]
    fn falsify_dp_002_shard_single_sample() {
        // 1 sample across 3 workers: only last worker should get it.
        // Only the total count is asserted, not which shard holds the sample.
        let samples = vec![42];
        let shards = shard_samples(&samples, 3);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 1, "F-DP-002: must not lose or duplicate the single sample");
    }
}