use crate::finetune::classification::SafetySample;
use crate::finetune::classify_pipeline::{BatchResult, ClassifyConfig, ClassifyPipeline};
use crate::transformer::TransformerConfig;
/// Coordinates data-parallel fine-tuning across several `ClassifyPipeline`
/// replicas — one replica per configured GPU index.
pub struct DataParallelCoordinator {
// One pipeline replica per entry in `gpu_indices`, in the same order.
pipelines: Vec<ClassifyPipeline>,
// GPU indices the coordinator was created with. Recorded for bookkeeping
// only — `new` never passes the index to the pipeline constructor — hence
// the `dead_code` allowance.
#[allow(dead_code)]
gpu_indices: Vec<u32>,
}
impl DataParallelCoordinator {
    /// Builds one `ClassifyPipeline` replica per entry in `gpu_indices`.
    ///
    /// # Errors
    /// Returns `Err` when `gpu_indices` is empty.
    pub fn new(
        model_config: &TransformerConfig,
        classify_config: ClassifyConfig,
        gpu_indices: &[u32],
    ) -> Result<Self, String> {
        if gpu_indices.is_empty() {
            return Err("At least one GPU index required".to_string());
        }
        // NOTE(review): the GPU index itself is never handed to the pipeline
        // constructor — every replica is built identically.
        let pipelines = gpu_indices
            .iter()
            .map(|_| ClassifyPipeline::new(model_config, classify_config.clone()))
            .collect();
        Ok(Self { pipelines, gpu_indices: gpu_indices.to_vec() })
    }

    /// Number of pipeline replicas managed by this coordinator.
    #[must_use]
    pub fn num_gpus(&self) -> usize {
        self.pipelines.len()
    }

    /// Mutable access to the first (primary) replica.
    pub fn primary_pipeline(&mut self) -> &mut ClassifyPipeline {
        &mut self.pipelines[0]
    }

    /// Shared access to the first (primary) replica.
    pub fn primary_pipeline_ref(&self) -> &ClassifyPipeline {
        &self.pipelines[0]
    }

    /// Trains one batch. When there is more than one replica and enough
    /// samples to give each replica at least one, the batch is split into
    /// contiguous shards (the last shard absorbing the remainder), each shard
    /// is trained on its replica, the per-shard metrics are aggregated, and
    /// replica weights are re-synchronized from the primary. Otherwise the
    /// whole batch goes to the primary replica.
    pub fn train_batch_parallel(&mut self, samples: &[SafetySample]) -> BatchResult {
        let num_gpus = self.pipelines.len();
        // Fallback path: a single replica, or too few samples to shard.
        if num_gpus == 1 || samples.len() < num_gpus {
            return self.pipelines[0].train_batch(samples);
        }
        let shard_size = samples.len() / num_gpus;
        let mut total = 0usize;
        let mut correct = 0usize;
        let mut weighted_loss = 0.0f32;
        let mut grad_norm_sum = 0.0f32;
        for gpu_idx in 0..num_gpus {
            let start = gpu_idx * shard_size;
            let end = if gpu_idx + 1 == num_gpus { samples.len() } else { start + shard_size };
            let result = self.pipelines[gpu_idx].train_batch(&samples[start..end]);
            total += result.total;
            correct += result.correct;
            // Weight each shard's mean loss by its sample count so the
            // aggregate below is a true per-sample average.
            weighted_loss += result.avg_loss * result.total as f32;
            grad_norm_sum += result.grad_norm;
        }
        // num_gpus > 1 on this path, so the replicas always get re-synced
        // (the sync itself is also guarded internally).
        self.sync_lora_weights_from_primary();
        BatchResult {
            avg_loss: weighted_loss / total as f32,
            correct,
            total,
            grad_norm: grad_norm_sum / num_gpus as f32,
        }
    }

    /// Copies the primary replica's LoRA adapters and classifier head into
    /// every other replica so all replicas carry identical trainable weights.
    /// No-op when there is at most one replica.
    fn sync_lora_weights_from_primary(&mut self) {
        if self.pipelines.len() < 2 {
            return;
        }
        let Some((primary, replicas)) = self.pipelines.split_first_mut() else {
            return;
        };
        for replica in replicas {
            for (src, dst) in primary.lora_layers.iter().zip(replica.lora_layers.iter_mut()) {
                dst.lora_a_mut().data_mut().assign(src.lora_a().data());
                dst.lora_b_mut().data_mut().assign(src.lora_b().data());
            }
            replica.classifier.weight.data_mut().assign(primary.classifier.weight.data());
            replica.classifier.bias.data_mut().assign(primary.classifier.bias.data());
        }
    }
}
/// Splits `samples` into `num_workers` contiguous shards. Every worker gets
/// `len / num_workers` samples and the final worker additionally absorbs the
/// remainder, so no sample is lost or duplicated. With zero workers or an
/// empty input, the whole slice comes back as a single shard.
pub fn shard_samples<T>(samples: &[T], num_workers: usize) -> Vec<&[T]> {
    if num_workers == 0 || samples.is_empty() {
        return vec![samples];
    }
    let base = samples.len() / num_workers;
    let mut shards = Vec::with_capacity(num_workers);
    let mut start = 0;
    for worker in 0..num_workers {
        // The last worker's shard runs to the end of the slice.
        let end = if worker + 1 == num_workers { samples.len() } else { start + base };
        shards.push(&samples[start..end]);
        start = end;
    }
    shards
}
pub fn average_gradients(grads: &[Vec<f32>]) -> Vec<f32> {
if grads.is_empty() {
return Vec::new();
}
let len = grads[0].len();
let n = grads.len() as f32;
let mut avg = vec![0.0f32; len];
for grad in grads {
for (j, &v) in grad.iter().enumerate() {
avg[j] += v;
}
}
for v in &mut avg {
*v /= n;
}
avg
}
/// Reports whether `values` contains at least one NaN or infinite entry.
pub fn has_non_finite(values: &[f32]) -> bool {
    !values.iter().copied().all(f32::is_finite)
}
// Falsification-style tests: F-DP-001 (weight sync), F-DP-002 (sharding
// conservation), F-DP-003 (non-finite propagation), F-DP-004 (CPU fallback),
// F-DP-005 (single vs multi loss parity), F-HET-001/002 (layout & memory).
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use crate::transformer::ModelArchitecture;
// Tiny transformer + 2-class / rank-4 LoRA config shared by every test.
fn test_config() -> (TransformerConfig, ClassifyConfig) {
let model_config = TransformerConfig {
hidden_size: 32,
num_hidden_layers: 2,
num_attention_heads: 4,
num_kv_heads: 4,
intermediate_size: 64,
vocab_size: 100,
max_position_embeddings: 64,
rms_norm_eps: 1e-6,
rope_theta: 10000.0,
use_bias: false,
head_dim_override: None,
architecture: ModelArchitecture::Decoder,
hf_architecture: None,
hf_model_type: None,
tie_word_embeddings: false,
};
let classify_config =
ClassifyConfig { num_classes: 2, lora_rank: 4, ..ClassifyConfig::default() };
(model_config, classify_config)
}
// One GPU index → creation succeeds and num_gpus() reports 1.
#[test]
fn test_coordinator_creation() {
let (model_config, classify_config) = test_config();
let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0]);
assert!(coordinator.is_ok());
assert_eq!(
coordinator.as_ref().map(super::DataParallelCoordinator::num_gpus).unwrap_or(0),
1
);
}
// Empty GPU list must be rejected with an error.
#[test]
fn test_coordinator_empty_gpus_fails() {
let (model_config, classify_config) = test_config();
let result = DataParallelCoordinator::new(&model_config, classify_config, &[]);
assert!(result.is_err());
}
// Accessors expose the primary pipeline with the configured settings.
#[test]
fn test_multi_gpu_coordinator_accessors() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
.expect("creation should succeed");
assert_eq!(coordinator.num_gpus(), 1);
let primary = coordinator.primary_pipeline();
assert_eq!(primary.config.num_classes, 2);
let primary_ref = coordinator.primary_pipeline_ref();
assert_eq!(primary_ref.config.lora_rank, 4);
}
// A single-GPU coordinator is valid (train_batch_parallel falls back to it).
#[test]
fn test_single_gpu_fallback_path() {
let (model_config, classify_config) = test_config();
let coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
.expect("creation should succeed");
assert_eq!(coordinator.num_gpus(), 1);
}
// Weight sync on a single pipeline must be a harmless no-op.
#[test]
fn test_weight_sync_noop_single_gpu() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0])
.expect("creation should succeed");
coordinator.sync_lora_weights_from_primary();
}
// F-DP-001: perturb replica 1's lora_a, sync, and require bitwise equality.
#[test]
fn falsify_dp_001_weight_sync_makes_replicas_identical() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
.expect("creation should succeed");
let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
.lora_a()
.data()
.iter()
.map(|v| v + 1.0)
.collect();
let arr = ndarray::Array1::from(perturbed);
*coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;
let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
assert_ne!(w0, w1, "Weights should differ before sync");
coordinator.sync_lora_weights_from_primary();
let w0_after: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
let w1_after: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
assert_eq!(w0_after, w1_after, "F-DP-001: Weights MUST be identical after sync");
}
// F-DP-001 (control): without calling sync, the perturbation persists —
// proving the sync step is what restores equality.
#[test]
fn falsify_dp_001_weights_diverge_without_sync() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
.expect("creation should succeed");
let perturbed: Vec<f32> = coordinator.pipelines[1].lora_layers[0]
.lora_a()
.data()
.iter()
.map(|v| v + 0.5)
.collect();
let arr = ndarray::Array1::from(perturbed);
*coordinator.pipelines[1].lora_layers[0].lora_a_mut().data_mut() = arr;
let w0: Vec<f32> = coordinator.pipelines[0].lora_layers[0].lora_a().data().to_vec();
let w1: Vec<f32> = coordinator.pipelines[1].lora_layers[0].lora_a().data().to_vec();
assert_ne!(w0, w1, "Without sync, weights MUST diverge (proving sync is necessary)");
}
// F-DP-002: across several worker counts, sharding neither loses nor
// duplicates any of the 100 distinct samples.
#[test]
fn falsify_dp_002_no_sample_lost_or_duplicated() {
let samples: Vec<u32> = (0..100).collect();
for num_workers in [1, 2, 3, 4, 7, 10] {
let shards = shard_samples(&samples, num_workers);
assert_eq!(
shards.len(),
num_workers,
"Wrong number of shards for {num_workers} workers"
);
let total: usize = shards.iter().map(|s| s.len()).sum();
assert_eq!(total, 100, "F-DP-002: samples lost with {num_workers} workers");
let mut seen = std::collections::HashSet::new();
for shard in &shards {
for &s in *shard {
assert!(
seen.insert(s),
"F-DP-002: duplicate sample {s} with {num_workers} workers"
);
}
}
assert_eq!(seen.len(), 100);
}
}
// F-DP-002: with 10 samples over 3 workers, the final shard absorbs the
// remainder (3, 3, 4).
#[test]
fn falsify_dp_002_uneven_sharding_gets_remainder() {
let samples: Vec<u32> = (0..10).collect();
let shards = shard_samples(&samples, 3);
assert_eq!(shards[0].len(), 3);
assert_eq!(shards[1].len(), 3);
assert_eq!(shards[2].len(), 4); let total: usize = shards.iter().map(|s| s.len()).sum();
assert_eq!(total, 10);
}
// F-DP-003: a NaN in any worker's gradient must survive averaging (Jidoka:
// stop the line rather than hide the defect).
#[test]
fn falsify_dp_003_nan_gradient_propagates() {
let grads = vec![vec![1.0, 2.0, 3.0], vec![f32::NAN, 2.0, 3.0]];
let avg = average_gradients(&grads);
assert!(avg[0].is_nan(), "F-DP-003: NaN MUST propagate through averaging (Jidoka)");
assert!((avg[1] - 2.0).abs() < 1e-6);
assert!((avg[2] - 3.0).abs() < 1e-6);
}
// F-DP-003: same propagation guarantee for infinities.
#[test]
fn falsify_dp_003_inf_gradient_propagates() {
let grads = vec![vec![1.0, 2.0], vec![f32::INFINITY, 2.0]];
let avg = average_gradients(&grads);
assert!(avg[0].is_infinite(), "F-DP-003: Inf MUST propagate through averaging");
}
// F-DP-003: detector covers NaN, +Inf and -Inf, and passes finite data.
#[test]
fn falsify_dp_003_non_finite_detection() {
assert!(!has_non_finite(&[1.0, 2.0, 3.0]));
assert!(has_non_finite(&[1.0, f32::NAN, 3.0]));
assert!(has_non_finite(&[1.0, f32::INFINITY, 3.0]));
assert!(has_non_finite(&[1.0, f32::NEG_INFINITY, 3.0]));
}
// Element-wise mean of three workers' gradients.
#[test]
fn test_average_gradients_correct() {
let grads = vec![vec![2.0, 4.0, 6.0], vec![4.0, 6.0, 8.0], vec![6.0, 8.0, 10.0]];
let avg = average_gradients(&grads);
assert!((avg[0] - 4.0).abs() < 1e-6);
assert!((avg[1] - 6.0).abs() < 1e-6);
assert!((avg[2] - 8.0).abs() < 1e-6);
}
// Averaging a single worker's gradients is the identity.
#[test]
fn test_average_gradients_single_worker() {
let grads = vec![vec![1.0, 2.0, 3.0]];
let avg = average_gradients(&grads);
assert!((avg[0] - 1.0).abs() < 1e-6);
assert!((avg[1] - 2.0).abs() < 1e-6);
assert!((avg[2] - 3.0).abs() < 1e-6);
}
// No workers → empty result, not a panic.
#[test]
fn test_average_gradients_empty() {
let grads: Vec<Vec<f32>> = vec![];
let avg = average_gradients(&grads);
assert!(avg.is_empty());
}
// F-DP-004: a forward pass on the CPU pipeline yields finite hidden states
// of the expected [tokens x hidden_size] size.
#[test]
fn falsify_dp_004_cpu_pipeline_produces_finite_hidden() {
let (model_config, classify_config) = test_config();
let pipeline = ClassifyPipeline::new(&model_config, classify_config);
let token_ids = vec![1u32, 2, 3, 4, 5];
let hidden = pipeline.model.forward_hidden(&token_ids);
let data = hidden.data();
assert!(
data.iter().all(|v| v.is_finite()),
"F-DP-004: CPU fallback must produce finite hidden states"
);
assert_eq!(data.len(), token_ids.len() * model_config.hidden_size);
}
// Sync must cover the classifier head weights, not just the LoRA adapters.
#[test]
fn test_weight_sync_covers_classifier_head() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
.expect("creation should succeed");
let perturbed: Vec<f32> =
coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
let arr = ndarray::Array1::from(perturbed);
*coordinator.pipelines[1].classifier.weight.data_mut() = arr;
coordinator.sync_lora_weights_from_primary();
let w0: Vec<f32> = coordinator.pipelines[0].classifier.weight.data().to_vec();
let w1: Vec<f32> = coordinator.pipelines[1].classifier.weight.data().to_vec();
assert_eq!(w0, w1, "Classifier head weights must sync across replicas");
}
// One pipeline replica is created per GPU index, for 1..=4 GPUs.
#[test]
fn test_multi_gpu_creates_n_pipelines() {
let (model_config, classify_config) = test_config();
for n in [1, 2, 3, 4] {
let indices: Vec<u32> = (0..n).collect();
let coordinator =
DataParallelCoordinator::new(&model_config, classify_config.clone(), &indices)
.expect("creation should succeed");
assert_eq!(coordinator.num_gpus(), n as usize);
}
}
// F-DP-001 (exhaustive): perturb every LoRA layer (a and b) plus the
// classifier, then verify sync restores equality everywhere including bias.
#[test]
fn falsify_dp_001_weight_sync_all_layers_and_classifier() {
let (model_config, classify_config) = test_config();
let mut coordinator = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
.expect("creation should succeed");
for lora in &mut coordinator.pipelines[1].lora_layers {
let perturbed_a: Vec<f32> = lora.lora_a().data().iter().map(|v| v + 42.0).collect();
*lora.lora_a_mut().data_mut() = ndarray::Array1::from(perturbed_a);
let perturbed_b: Vec<f32> = lora.lora_b().data().iter().map(|v| v + 7.0).collect();
*lora.lora_b_mut().data_mut() = ndarray::Array1::from(perturbed_b);
}
let perturbed_w: Vec<f32> =
coordinator.pipelines[1].classifier.weight.data().iter().map(|v| v + 99.0).collect();
*coordinator.pipelines[1].classifier.weight.data_mut() = ndarray::Array1::from(perturbed_w);
coordinator.sync_lora_weights_from_primary();
for (i, (l0, l1)) in coordinator.pipelines[0]
.lora_layers
.iter()
.zip(coordinator.pipelines[1].lora_layers.iter())
.enumerate()
{
assert_eq!(
l0.lora_a().data().as_slice().unwrap(),
l1.lora_a().data().as_slice().unwrap(),
"F-DP-001: lora_a of layer {i} must match after sync"
);
assert_eq!(
l0.lora_b().data().as_slice().unwrap(),
l1.lora_b().data().as_slice().unwrap(),
"F-DP-001: lora_b of layer {i} must match after sync"
);
}
assert_eq!(
coordinator.pipelines[0].classifier.weight.data().as_slice().unwrap(),
coordinator.pipelines[1].classifier.weight.data().as_slice().unwrap(),
"F-DP-001: classifier weight must match after sync"
);
assert_eq!(
coordinator.pipelines[0].classifier.bias.data().as_slice().unwrap(),
coordinator.pipelines[1].classifier.bias.data().as_slice().unwrap(),
"F-DP-001: classifier bias must match after sync"
);
}
// F-DP-005: forward-only average loss over the same 20 samples must agree
// (within 25%) between a single pipeline and two freshly-created replicas.
#[test]
fn falsify_dp_005_single_vs_multi_gpu_loss_convergence() {
let (model_config, classify_config) = test_config();
let samples: Vec<SafetySample> = (0..20)
.map(|i| SafetySample { input: format!("test_sample_{i}"), label: i % 2 })
.collect();
let mut single_pipe = ClassifyPipeline::new(&model_config, classify_config.clone());
// Byte-level tokenization, truncated to 16 tokens per sample.
let token_ids_batch: Vec<Vec<u32>> = samples
.iter()
.map(|s| {
let bytes: Vec<u32> = s.input.bytes().map(u32::from).collect();
bytes[..bytes.len().min(16)].to_vec()
})
.collect();
let mut single_loss = 0.0f32;
for (ids, sample) in token_ids_batch.iter().zip(&samples) {
let (loss, _pred) = single_pipe.forward_only(ids, sample.label);
single_loss += loss;
}
let single_avg_loss = single_loss / samples.len() as f32;
let mut multi = DataParallelCoordinator::new(&model_config, classify_config, &[0, 1])
.expect("creation should succeed");
let id_label_pairs: Vec<(&Vec<u32>, usize)> =
token_ids_batch.iter().zip(samples.iter().map(|s| s.label)).collect();
let shards = shard_samples(&id_label_pairs, 2);
let mut multi_loss = 0.0f32;
let mut multi_count = 0usize;
for (shard_idx, shard) in shards.iter().enumerate() {
let pipe = &mut multi.pipelines[shard_idx];
for &(ids, label) in *shard {
let (loss, _pred) = pipe.forward_only(ids, label);
multi_loss += loss;
multi_count += 1;
}
}
let multi_avg_loss = multi_loss / multi_count as f32;
assert!(
(single_avg_loss - multi_avg_loss).abs() < 0.25 * single_avg_loss.abs() + 1e-6,
"F-DP-005: single GPU loss ({single_avg_loss:.6}) vs multi GPU loss ({multi_avg_loss:.6}) \
diverged beyond 25% tolerance"
);
}
// F-HET-001: two independently-built pipelines must expose identical
// gradient layouts (flat length and per-layer LoRA dimensions).
#[test]
fn falsify_het_001_gradient_layout_identical_across_pipelines() {
let (model_config, classify_config) = test_config();
let pipe_a = ClassifyPipeline::new(&model_config, classify_config.clone());
let pipe_b = ClassifyPipeline::new(&model_config, classify_config);
let grads_a = pipe_a.collect_lora_gradients();
let grads_b = pipe_b.collect_lora_gradients();
assert_eq!(
grads_a.len(),
grads_b.len(),
"F-HET-001: gradient layout length mismatch between pipelines"
);
assert_eq!(
grads_a.len(),
pipe_a.num_trainable_parameters(),
"F-HET-001: gradient length != num_trainable_parameters for pipeline A"
);
assert_eq!(
grads_b.len(),
pipe_b.num_trainable_parameters(),
"F-HET-001: gradient length != num_trainable_parameters for pipeline B"
);
assert_eq!(
pipe_a.lora_layers.len(),
pipe_b.lora_layers.len(),
"F-HET-001: different LoRA layer counts"
);
for (i, (la, lb)) in pipe_a.lora_layers.iter().zip(pipe_b.lora_layers.iter()).enumerate() {
assert_eq!(
la.lora_a().data().len(),
lb.lora_a().data().len(),
"F-HET-001: lora_a dimension mismatch at layer {i}"
);
assert_eq!(
la.lora_b().data().len(),
lb.lora_b().data().len(),
"F-HET-001: lora_b dimension mismatch at layer {i}"
);
}
}
// F-HET-002: rough f32 memory estimate (embeddings + attention + MLP +
// adapters) must fit an 8 GB budget, and the LoRA adapters must stay under
// 10% of the frozen model's footprint.
#[test]
fn falsify_het_002_memory_budget_within_vram() {
let (model_config, classify_config) = test_config();
let pipeline = ClassifyPipeline::new(&model_config, classify_config);
let hidden = model_config.hidden_size;
let layers = model_config.num_hidden_layers;
let vocab = model_config.vocab_size;
let model_params = vocab * hidden + layers * (4 * hidden * hidden) + layers * (2 * hidden * 4 * hidden); let model_bytes = model_params * 4;
let trainable = pipeline.num_trainable_parameters();
let adapter_bytes = trainable * 4;
let total_bytes = model_bytes + adapter_bytes;
let total_mb = total_bytes as f64 / (1024.0 * 1024.0);
assert!(
total_mb < 8192.0,
"F-HET-002: estimated memory {total_mb:.1} MB exceeds 8 GB VRAM budget"
);
let adapter_ratio = adapter_bytes as f64 / model_bytes as f64;
assert!(
adapter_ratio < 0.1,
"F-HET-002: adapter memory ratio {adapter_ratio:.4} exceeds 10% of model — \
LoRA should be much smaller than frozen model"
);
}
// F-DP-003: NaN + Inf mixed in one slot still averages to non-finite.
#[test]
fn falsify_dp_003_nan_and_inf_combined_in_gradient() {
assert!(has_non_finite(&[1.0, f32::NAN, f32::INFINITY, 4.0]));
assert!(has_non_finite(&[f32::NEG_INFINITY]));
let grads = vec![vec![f32::NAN, 1.0], vec![f32::INFINITY, 2.0]];
let avg = average_gradients(&grads);
assert!(avg[0].is_nan(), "NaN + Inf average should be NaN");
assert!(has_non_finite(&avg));
}
// F-DP-002: sharding an empty slice yields zero total samples.
#[test]
fn falsify_dp_002_shard_empty_samples() {
let samples: Vec<i32> = vec![];
let shards = shard_samples(&samples, 3);
let total: usize = shards.iter().map(|s| s.len()).sum();
assert_eq!(total, 0, "F-DP-002: sharding empty data must produce 0 total samples");
}
// F-DP-002: one sample over three workers is neither lost nor duplicated.
#[test]
fn falsify_dp_002_shard_single_sample() {
let samples = vec![42];
let shards = shard_samples(&samples, 3);
let total: usize = shards.iter().map(|s| s.len()).sum();
assert_eq!(total, 1, "F-DP-002: must not lose or duplicate the single sample");
}
}