#![allow(
clippy::cast_possible_truncation,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::uninlined_format_args,
clippy::explicit_iter_loop,
clippy::redundant_else
)]
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::{FerrotorchError, Tensor};
use ferrotorch_distributed::backend::{Backend, SimulatedBackend};
use ferrotorch_distributed::{
DEFAULT_COLLECTIVE_TIMEOUT, DTensor, DeviceMesh, DistCheckpointError, DistributedError,
Placement, ReduceOp, SubBackend, all_gather, all_gather_with_timeout, all_to_all,
all_to_all_single_uneven, all_to_all_with_timeout, allreduce, allreduce_with_timeout,
async_all_gather, async_reduce_scatter, barrier, broadcast, flat_shard_metadata,
is_gloo_available, is_mpi_available, is_ucc_available, load_distributed, recv, recv_into,
recv_into_with_timeout, recv_with_timeout, reduce_scatter, reduce_scatter_tensor,
reduce_scatter_with_timeout, send, sendrecv,
};
use serde::Deserialize;
/// Prints a skip notice to stderr, tagged with the invoking module's path,
/// and then returns from the function the macro was expanded in.
///
/// The single argument must be a literal; it is printed after the module
/// path as the reason for skipping.
macro_rules! cascade_skip {
    ($why:literal) => {{
        let origin = module_path!();
        eprintln!(" [cascade_skip] {} — {}", origin, $why);
        return;
    }};
}
/// Top-level shape deserialized from the conformance fixtures JSON file:
/// one metadata record plus the list of fixture entries.
#[derive(Debug, Deserialize)]
struct FixtureFile {
/// Information about how the fixture file was generated.
metadata: FixtureMetadata,
/// Every fixture entry in the file; tests select entries by their `op` name.
fixtures: Vec<Fixture>,
}
/// Metadata section of the fixtures file.
///
/// All four keys are required when deserializing (none is marked
/// `#[serde(default)]`). In the reviewed portion of this file only
/// `torch_version` is read, for a diagnostic line on stderr.
#[derive(Debug, Deserialize)]
struct FixtureMetadata {
/// Printed in the coverage test's summary line.
torch_version: String,
// Deserialized but not read by the code in view.
#[allow(dead_code, reason = "metadata kept for diagnostics")]
python_platform: String,
// Deserialized but not read by the code in view.
#[allow(dead_code, reason = "metadata kept for diagnostics")]
generated_at: String,
// Deserialized but not read by the code in view.
#[allow(dead_code, reason = "metadata kept for diagnostics")]
conformance_note: String,
}
/// One fixture entry.
///
/// Only `op` is required. Every other field is an `Option` marked
/// `#[serde(default)]`, so a key that is absent from the JSON deserializes to
/// `None`; each test then `expect`s the fields it needs. Fields carrying
/// `#[allow(dead_code)]` are deserialized but not read in the reviewed
/// portion of this file.
#[derive(Debug, Deserialize)]
struct Fixture {
/// Operation name; `fixtures_for` filters on this value.
op: String,
/// Free-form expected value. The tests in view read it either as a bool
/// (backend-availability checks) or as an array of numbers (single-rank
/// collectives).
#[serde(default)]
expected: Option<serde_json::Value>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
note: Option<String>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
platform_note: Option<String>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
expected_error: Option<String>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
cascade_skip_reason: Option<String>,
/// Compared against `DEFAULT_COLLECTIVE_TIMEOUT.as_secs()`.
#[serde(default)]
expected_secs: Option<u64>,
/// Passed to `SimulatedBackend::create_group` and `DeviceMesh::new`.
#[serde(default)]
world_size: Option<usize>,
/// Expected length of the vector returned by `create_group`.
#[serde(default)]
expected_len: Option<usize>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
expected_rank_0: Option<usize>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
expected_world_size_0: Option<usize>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
expected_ok: Option<bool>,
/// Flat tensor data for tests that use a single input tensor.
#[serde(default)]
input: Option<Vec<f32>>,
/// Tensor shape paired with `input` / `input_rank*`; the `DeviceMesh` test
/// passes it as the mesh shape instead.
#[serde(default)]
shape: Option<Vec<usize>>,
/// Per-rank input shape used by the two-rank all_gather and
/// reduce_scatter tests.
#[serde(default)]
input_shape: Option<Vec<usize>>,
/// Flat input data for the tensor handed to `group[0]` in two-rank tests.
#[serde(default)]
input_rank0: Option<Vec<f32>>,
/// Flat input data for the tensor handed to `group[1]` in two-rank tests.
#[serde(default)]
input_rank1: Option<Vec<f32>>,
/// Expected result checked against the output of both ranks.
#[serde(default)]
expected_all_ranks: Option<Vec<f32>>,
/// Expected result checked against rank 0's output only.
#[serde(default)]
expected_rank0: Option<Vec<f32>>,
/// Expected result checked against rank 1's output only.
#[serde(default)]
expected_rank1: Option<Vec<f32>>,
// NOTE(review): the allow-reason below mentions fixture lookup, but no
// lookup by `op_type` appears in the reviewed portion of this file.
#[serde(default)]
#[allow(
dead_code,
reason = "metadata kept for diagnostics — op_type used in fixture lookup"
)]
op_type: Option<String>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
root: Option<usize>,
/// Member list passed to `SubBackend::new`; its first entry also selects
/// which backend of the group becomes the parent.
#[serde(default)]
members: Option<Vec<usize>>,
/// Compared against `SubBackend::members()`.
#[serde(default)]
expected_members: Option<Vec<usize>>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
mesh_shape: Option<Vec<usize>>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
mesh_world_size: Option<usize>,
/// Compared against `DeviceMesh::ndim()`.
#[serde(default)]
expected_ndim: Option<usize>,
/// Compared against `DeviceMesh::size()`.
#[serde(default)]
expected_size: Option<usize>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
num_shards: Option<usize>,
#[serde(default)]
#[allow(dead_code, reason = "metadata kept for diagnostics")]
total_elements: Option<usize>,
/// Data rank 0 sends in the `sendrecv` round-trip test.
#[serde(default)]
rank0_sends: Option<Vec<f32>>,
/// Data rank 1 sends in the `sendrecv` round-trip test.
#[serde(default)]
rank1_sends: Option<Vec<f32>>,
/// Data rank 0 is expected to receive in the `sendrecv` round-trip test.
#[serde(default)]
expected_rank0_receives: Option<Vec<f32>>,
/// Data rank 1 is expected to receive in the `sendrecv` round-trip test.
#[serde(default)]
expected_rank1_receives: Option<Vec<f32>>,
}
/// Reads and deserializes `tests/conformance/fixtures.json`, located relative
/// to the crate's `CARGO_MANIFEST_DIR`.
///
/// # Panics
///
/// Panics if the file cannot be read (the message points at the regeneration
/// script) or if its contents do not deserialize into a [`FixtureFile`].
fn load_fixtures() -> FixtureFile {
    let mut fixture_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    fixture_path.push("tests");
    fixture_path.push("conformance");
    fixture_path.push("fixtures.json");

    let raw = match std::fs::read(&fixture_path) {
        Ok(raw) => raw,
        Err(e) => panic!(
            "read {} failed: {e}. Regenerate via \
             scripts/regenerate_distributed_fixtures.py",
            fixture_path.display()
        ),
    };

    match serde_json::from_slice(&raw) {
        Ok(parsed) => parsed,
        Err(e) => panic!("parse {}: {e}", fixture_path.display()),
    }
}
/// Returns references to every fixture in `file` whose `op` equals `op`,
/// in file order. The result is empty when nothing matches.
fn fixtures_for<'a>(file: &'a FixtureFile, op: &str) -> Vec<&'a Fixture> {
    let mut matching = Vec::new();
    for fixture in &file.fixtures {
        if fixture.op == op {
            matching.push(fixture);
        }
    }
    matching
}
/// Builds an `f32` tensor from flat `data` and `shape` using CPU storage.
///
/// # Panics
///
/// Panics if `Tensor::from_storage` returns an error.
fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
    let storage = TensorStorage::cpu(data);
    // NOTE(review): the trailing `false` is presumably the requires-grad flag — confirm.
    let built = Tensor::from_storage(storage, shape, false);
    built.unwrap()
}
/// Asserts that `actual` and `expected` have the same length and agree
/// element-wise within the absolute tolerance `tol`.
///
/// Elements that compare exactly equal are accepted without computing a
/// difference, so matching infinities pass. A NaN on either side, or
/// infinities of opposite sign, still fail.
///
/// `ctx` is prefixed to every failure message.
///
/// # Panics
///
/// Panics on a length mismatch or on the first element whose absolute
/// difference exceeds `tol`.
fn assert_close(actual: &[f32], expected: &[f32], tol: f32, ctx: &str) {
    assert_eq!(
        actual.len(),
        expected.len(),
        "{ctx}: length mismatch — actual={}, expected={}",
        actual.len(),
        expected.len()
    );
    for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
        // `inf - inf` is NaN, which would fail the `<=` check below even
        // though the two values are identical.
        if a == e {
            continue;
        }
        let diff = (a - e).abs();
        assert!(
            diff <= tol,
            "{ctx}: index {i}: |{a} - {e}| = {diff} > tol {tol}"
        );
    }
}
/// Checks that the fixture file contains at least one entry for each op name
/// in the `required` list below.
///
/// Every missing op is reported in a single failure, and the list of
/// available ops is sorted so the message is stable between runs. On success
/// a summary line (op count, fixture count, torch version) goes to stderr.
#[test]
fn fixture_file_covers_every_expected_op() {
    let file = load_fixtures();
    let mut by_op: HashMap<&str, usize> = HashMap::new();
    for f in &file.fixtures {
        *by_op.entry(f.op.as_str()).or_default() += 1;
    }
    let required = [
        "is_gloo_available",
        "is_mpi_available",
        "is_ucc_available",
        "ReduceOp_variants",
        "DEFAULT_COLLECTIVE_TIMEOUT_secs",
        "SimulatedBackend_create_group_world_size_1",
        "SimulatedBackend_create_group_world_size_2",
        "allreduce_world_size_1_sum",
        "allreduce_world_size_2_sum",
        "broadcast_world_size_1",
        "all_gather_world_size_1",
        "reduce_scatter_world_size_1_sum",
        "barrier_world_size_1",
        "send_recv_round_trip",
        "DeviceMesh_new_valid",
        "Placement_variants",
        "DTensor_from_local_valid",
        "DistributedError_display",
        "PendingCollective_op_name",
    ];
    // An op is present exactly when it has an entry: counts start at 1.
    let missing: Vec<&str> = required
        .iter()
        .copied()
        .filter(|op| !by_op.contains_key(op))
        .collect();
    // HashMap key order is unspecified; sort so the failure message is deterministic.
    let mut have: Vec<&str> = by_op.keys().copied().collect();
    have.sort_unstable();
    assert!(
        missing.is_empty(),
        "fixture file missing op(s) {missing:?} (have: {have:?})"
    );
    eprintln!(
        " fixture_file_covers_every_expected_op: {} ops, {} fixtures, torch={}",
        by_op.len(),
        file.fixtures.len(),
        file.metadata.torch_version,
    );
}
/// Compares `is_gloo_available()` with the boolean stored in the
/// `is_gloo_available` fixture.
///
/// If the two disagree, the test logs the reason via `cascade_skip!` and
/// returns early instead of failing, so the closing `assert_eq!` only runs
/// when the values already agree.
#[test]
fn is_gloo_available_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "is_gloo_available")
        .into_iter()
        .next()
        .expect("fixture is_gloo_available not found");
    let Some(torch_says) = case.expected.as_ref().and_then(serde_json::Value::as_bool) else {
        panic!("is_gloo_available fixture must have bool `expected`");
    };
    let ours = is_gloo_available();
    match (torch_says, ours) {
        (true, false) => cascade_skip!(
            "torch.distributed.is_gloo_available()=True but is_gloo_available()=False — \
             divergence; tracking issue #882"
        ),
        (false, true) => cascade_skip!(
            "torch.distributed.is_gloo_available()=False but is_gloo_available()=True \
             (build has `--features=gloo-backend`, fixture predates native backend); \
             post-#1132 expected divergence"
        ),
        _ => {}
    }
    assert_eq!(
        ours, torch_says,
        "is_gloo_available() parity with torch.distributed.is_gloo_available()"
    );
}
/// Compares `is_mpi_available()` with the boolean stored in the
/// `is_mpi_available` fixture.
///
/// If the two disagree, the test logs the reason via `cascade_skip!` and
/// returns early instead of failing, so the closing `assert_eq!` only runs
/// when the values already agree.
#[test]
fn is_mpi_available_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "is_mpi_available")
        .into_iter()
        .next()
        .expect("fixture is_mpi_available not found");
    let Some(torch_says) = case.expected.as_ref().and_then(serde_json::Value::as_bool) else {
        panic!("is_mpi_available fixture must have bool `expected`");
    };
    let ours = is_mpi_available();
    match (torch_says, ours) {
        (true, false) => cascade_skip!(
            "torch.distributed.is_mpi_available()=True but is_mpi_available()=False — \
             divergence; tracking issue #889"
        ),
        (false, true) => cascade_skip!(
            "torch.distributed.is_mpi_available()=False but is_mpi_available()=True \
             (build has `--features=mpi-native`, fixture predates native backend); \
             post-#1133 expected divergence"
        ),
        _ => {}
    }
    assert_eq!(
        ours, torch_says,
        "is_mpi_available() parity with torch.distributed.is_mpi_available()"
    );
}
/// Compares `is_ucc_available()` with the boolean stored in the
/// `is_ucc_available` fixture.
///
/// If the two disagree, the test logs the reason via `cascade_skip!` and
/// returns early instead of failing, so the closing `assert_eq!` only runs
/// when the values already agree.
#[test]
fn is_ucc_available_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "is_ucc_available")
        .into_iter()
        .next()
        .expect("fixture is_ucc_available not found");
    let Some(torch_says) = case.expected.as_ref().and_then(serde_json::Value::as_bool) else {
        panic!("is_ucc_available fixture must have bool `expected`");
    };
    let ours = is_ucc_available();
    match (torch_says, ours) {
        (true, false) => cascade_skip!(
            "torch.distributed.is_ucc_available()=True but is_ucc_available()=False — \
             divergence; tracking issue #890"
        ),
        (false, true) => cascade_skip!(
            "torch.distributed.is_ucc_available()=False but is_ucc_available()=True \
             (build has `--features=ucc-native`, fixture predates native router); \
             post-#1134 expected divergence"
        ),
        _ => {}
    }
    assert_eq!(
        ours, torch_says,
        "is_ucc_available() parity with torch.distributed.is_ucc_available()"
    );
}
/// Exercises the `Sum` and `Mean` variants of `ReduceOp`: matching on them,
/// using a value again after assigning it elsewhere, and comparing for
/// equality and inequality.
#[test]
fn reduce_op_variants_exist() {
    let (sum, mean) = (ReduceOp::Sum, ReduceOp::Mean);

    // Both matches are written without a wildcard arm on purpose, so they
    // stay exhaustive over the variants.
    match sum {
        ReduceOp::Mean => unreachable!(),
        ReduceOp::Sum => {}
    }
    match mean {
        ReduceOp::Mean => {}
        ReduceOp::Sum => unreachable!(),
    }

    // `sum` is still used below, after this assignment.
    let _duplicate = sum;
    assert_eq!(sum, ReduceOp::Sum);
    assert_eq!(mean, ReduceOp::Mean);
    assert_ne!(sum, mean);
}
/// Checks that `DEFAULT_COLLECTIVE_TIMEOUT`, in whole seconds, equals the
/// `expected_secs` value of the `DEFAULT_COLLECTIVE_TIMEOUT_secs` fixture.
#[test]
fn default_collective_timeout_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "DEFAULT_COLLECTIVE_TIMEOUT_secs")
        .into_iter()
        .next()
        .expect("fixture DEFAULT_COLLECTIVE_TIMEOUT_secs not found");
    let Some(want_secs) = case.expected_secs else {
        panic!("fixture must have expected_secs");
    };
    let got_secs = DEFAULT_COLLECTIVE_TIMEOUT.as_secs();
    assert_eq!(
        got_secs, want_secs,
        "DEFAULT_COLLECTIVE_TIMEOUT must be {want_secs} seconds"
    );
}
/// Creates a simulated group with the fixture's `world_size` and checks the
/// number of backends returned, plus the rank and world size reported by the
/// first backend.
#[test]
fn simulated_backend_create_group_world_size_1() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "SimulatedBackend_create_group_world_size_1")
        .into_iter()
        .next()
        .expect("fixture create_group_world_size_1 not found");
    let Some(world_size) = case.world_size else {
        panic!("fixture must have world_size");
    };
    let Some(expected_len) = case.expected_len else {
        panic!("fixture must have expected_len");
    };

    let ranks = SimulatedBackend::create_group(world_size).unwrap();
    let count = ranks.len();
    assert_eq!(
        count, expected_len,
        "create_group({world_size}) must return {expected_len} backends"
    );

    let first = &ranks[0];
    assert_eq!(first.rank(), 0, "rank 0 backend must have rank=0");
    assert_eq!(first.world_size(), 1, "world_size must be 1");
}
/// Creates a simulated group with the fixture's `world_size` and checks that
/// it holds that many backends, each reporting its index as its rank and the
/// same world size.
#[test]
fn simulated_backend_create_group_world_size_2() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "SimulatedBackend_create_group_world_size_2")
        .into_iter()
        .next()
        .expect("fixture create_group_world_size_2 not found");
    let Some(world_size) = case.world_size else {
        panic!("fixture must have world_size");
    };

    let ranks = SimulatedBackend::create_group(world_size).unwrap();
    let count = ranks.len();
    assert_eq!(
        count, world_size,
        "create_group({world_size}) must return {world_size} backends"
    );

    for i in 0..count {
        let backend = &ranks[i];
        assert_eq!(backend.rank(), i, "backend[{i}] must have rank={i}");
        assert_eq!(
            backend.world_size(),
            world_size,
            "backend[{i}] must have world_size={world_size}"
        );
    }
}
/// Checks that requesting a group of size zero yields an `Err`.
#[test]
fn simulated_backend_create_group_world_size_0_error() {
    let rejected = SimulatedBackend::create_group(0).is_err();
    assert!(rejected, "create_group(0) must return Err, got Ok");
}
/// Runs `allreduce` with `ReduceOp::Sum` on a single-rank group and compares
/// the output with the fixture's `expected` array (tolerance 1e-6).
#[test]
fn allreduce_world_size_1_sum_is_identity() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "allreduce_world_size_1_sum")
        .into_iter()
        .next()
        .expect("fixture allreduce_world_size_1_sum not found");
    let Some(input) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want: Vec<f32> = case
        .expected
        .as_ref()
        .and_then(serde_json::Value::as_array)
        .expect("fixture must have expected")
        .iter()
        .map(|entry| entry.as_f64().unwrap() as f32)
        .collect();

    let ranks = SimulatedBackend::create_group(1).unwrap();
    let tensor = make_tensor(input, shape);
    let reduced = allreduce(&tensor, &ranks[0], ReduceOp::Sum).unwrap();
    let got = reduced.data_vec().unwrap();
    assert_close(&got, &want, 1e-6, "allreduce(world_size=1, Sum)");
}
/// Runs `allreduce` with `ReduceOp::Mean` on a single-rank group and compares
/// the output with the fixture's `expected` array (tolerance 1e-6).
#[test]
fn allreduce_world_size_1_mean_is_identity() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "allreduce_world_size_1_mean")
        .into_iter()
        .next()
        .expect("fixture allreduce_world_size_1_mean not found");
    let Some(input) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want: Vec<f32> = case
        .expected
        .as_ref()
        .and_then(serde_json::Value::as_array)
        .expect("fixture must have expected")
        .iter()
        .map(|entry| entry.as_f64().unwrap() as f32)
        .collect();

    let ranks = SimulatedBackend::create_group(1).unwrap();
    let tensor = make_tensor(input, shape);
    let reduced = allreduce(&tensor, &ranks[0], ReduceOp::Mean).unwrap();
    let got = reduced.data_vec().unwrap();
    assert_close(&got, &want, 1e-6, "allreduce(world_size=1, Mean)");
}
/// Runs `allreduce` with `ReduceOp::Sum` on a two-rank simulated group, one
/// scoped thread per rank, and checks that both ranks produce the fixture's
/// `expected_all_ranks` (tolerance 1e-6).
#[test]
fn allreduce_world_size_2_sum_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "allreduce_world_size_2_sum")
        .into_iter()
        .next()
        .expect("fixture allreduce_world_size_2_sum not found");
    let rank0_input = case
        .input_rank0
        .as_deref()
        .expect("fixture must have input_rank0")
        .to_vec();
    let rank1_input = case
        .input_rank1
        .as_deref()
        .expect("fixture must have input_rank1")
        .to_vec();
    let shape = case.shape.as_deref().expect("fixture must have shape");
    let want = case
        .expected_all_ranks
        .as_deref()
        .expect("fixture must have expected_all_ranks");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_input, shape.to_vec());
    let t1 = make_tensor(rank1_input, shape.to_vec());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    // Each rank runs on its own thread; rank 0 is spawned and joined first.
    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let reduced = allreduce(&t0, rank0, ReduceOp::Sum).unwrap();
            reduced.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let reduced = allreduce(&t1, rank1, ReduceOp::Sum).unwrap();
            reduced.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want, 1e-6, "allreduce(world_size=2, Sum) rank0");
    assert_close(&got1, want, 1e-6, "allreduce(world_size=2, Sum) rank1");
}
/// Runs `allreduce` with `ReduceOp::Mean` on a two-rank simulated group, one
/// scoped thread per rank, and checks that both ranks produce the fixture's
/// `expected_all_ranks` (tolerance 1e-6).
#[test]
fn allreduce_world_size_2_mean_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "allreduce_world_size_2_mean")
        .into_iter()
        .next()
        .expect("fixture allreduce_world_size_2_mean not found");
    let rank0_input = case
        .input_rank0
        .as_deref()
        .expect("fixture must have input_rank0")
        .to_vec();
    let rank1_input = case
        .input_rank1
        .as_deref()
        .expect("fixture must have input_rank1")
        .to_vec();
    let shape = case.shape.as_deref().expect("fixture must have shape");
    let want = case
        .expected_all_ranks
        .as_deref()
        .expect("fixture must have expected_all_ranks");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_input, shape.to_vec());
    let t1 = make_tensor(rank1_input, shape.to_vec());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    // Each rank runs on its own thread; rank 0 is spawned and joined first.
    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let reduced = allreduce(&t0, rank0, ReduceOp::Mean).unwrap();
            reduced.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let reduced = allreduce(&t1, rank1, ReduceOp::Mean).unwrap();
            reduced.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want, 1e-6, "allreduce(world_size=2, Mean) rank0");
    assert_close(&got1, want, 1e-6, "allreduce(world_size=2, Mean) rank1");
}
/// Runs `broadcast` from root 0 on a single-rank group and compares the
/// output with the fixture's `expected` array (tolerance 1e-6).
#[test]
fn broadcast_world_size_1_is_identity() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "broadcast_world_size_1")
        .into_iter()
        .next()
        .expect("fixture broadcast_world_size_1 not found");
    let Some(input) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want: Vec<f32> = case
        .expected
        .as_ref()
        .and_then(serde_json::Value::as_array)
        .expect("fixture must have expected")
        .iter()
        .map(|entry| entry.as_f64().unwrap() as f32)
        .collect();

    let ranks = SimulatedBackend::create_group(1).unwrap();
    let tensor = make_tensor(input, shape);
    let sent = broadcast(&tensor, &ranks[0], 0).unwrap();
    let got = sent.data_vec().unwrap();
    assert_close(&got, &want, 1e-6, "broadcast(world_size=1)");
}
/// Runs `broadcast` with root 0 on a two-rank simulated group, one scoped
/// thread per rank, and compares each rank's output with the fixture's
/// `expected_rank0` / `expected_rank1` (tolerance 1e-6).
#[test]
fn broadcast_world_size_2_from_root_0_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "broadcast_world_size_2_from_root_0")
        .into_iter()
        .next()
        .expect("fixture broadcast_world_size_2_from_root_0 not found");
    let rank0_input = case
        .input_rank0
        .as_deref()
        .expect("fixture must have input_rank0")
        .to_vec();
    let rank1_input = case
        .input_rank1
        .as_deref()
        .expect("fixture must have input_rank1")
        .to_vec();
    let shape = case.shape.as_deref().expect("fixture must have shape");
    let want0 = case
        .expected_rank0
        .as_deref()
        .expect("fixture must have expected_rank0");
    let want1 = case
        .expected_rank1
        .as_deref()
        .expect("fixture must have expected_rank1");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_input, shape.to_vec());
    let t1 = make_tensor(rank1_input, shape.to_vec());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let out = broadcast(&t0, rank0, 0).unwrap();
            out.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let out = broadcast(&t1, rank1, 0).unwrap();
            out.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want0, 1e-6, "broadcast rank0");
    assert_close(&got1, want1, 1e-6, "broadcast rank1 receives root's tensor");
}
/// Runs `all_gather` on a single-rank group and compares the output with the
/// fixture's `expected` array (tolerance 1e-6).
#[test]
fn all_gather_world_size_1_is_identity() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "all_gather_world_size_1")
        .into_iter()
        .next()
        .expect("fixture all_gather_world_size_1 not found");
    let Some(input) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want: Vec<f32> = case
        .expected
        .as_ref()
        .and_then(serde_json::Value::as_array)
        .expect("fixture must have expected")
        .iter()
        .map(|entry| entry.as_f64().unwrap() as f32)
        .collect();

    let ranks = SimulatedBackend::create_group(1).unwrap();
    let tensor = make_tensor(input, shape);
    let gathered = all_gather(&tensor, &ranks[0]).unwrap();
    let got = gathered.data_vec().unwrap();
    assert_close(&got, &want, 1e-6, "all_gather(world_size=1)");
}
/// Runs `all_gather` on a two-rank simulated group, one scoped thread per
/// rank, and checks that both ranks produce the fixture's
/// `expected_all_ranks` (tolerance 1e-6).
#[test]
fn all_gather_world_size_2_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "all_gather_world_size_2")
        .into_iter()
        .next()
        .expect("fixture all_gather_world_size_2 not found");
    let rank0_input = case
        .input_rank0
        .as_deref()
        .expect("fixture must have input_rank0")
        .to_vec();
    let rank1_input = case
        .input_rank1
        .as_deref()
        .expect("fixture must have input_rank1")
        .to_vec();
    let per_rank_shape = case
        .input_shape
        .as_deref()
        .expect("fixture must have input_shape");
    let want = case
        .expected_all_ranks
        .as_deref()
        .expect("fixture must have expected_all_ranks");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_input, per_rank_shape.to_vec());
    let t1 = make_tensor(rank1_input, per_rank_shape.to_vec());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let gathered = all_gather(&t0, rank0).unwrap();
            gathered.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let gathered = all_gather(&t1, rank1).unwrap();
            gathered.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want, 1e-6, "all_gather(world_size=2) rank0");
    assert_close(&got1, want, 1e-6, "all_gather(world_size=2) rank1");
}
/// Runs `reduce_scatter` with `ReduceOp::Sum` on a single-rank group and
/// compares the output with the fixture's `expected` array (tolerance 1e-6).
#[test]
fn reduce_scatter_world_size_1_sum_is_identity() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "reduce_scatter_world_size_1_sum")
        .into_iter()
        .next()
        .expect("fixture reduce_scatter_world_size_1_sum not found");
    let Some(input) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want: Vec<f32> = case
        .expected
        .as_ref()
        .and_then(serde_json::Value::as_array)
        .expect("fixture must have expected")
        .iter()
        .map(|entry| entry.as_f64().unwrap() as f32)
        .collect();

    let ranks = SimulatedBackend::create_group(1).unwrap();
    let tensor = make_tensor(input, shape);
    let scattered = reduce_scatter(&tensor, &ranks[0], ReduceOp::Sum).unwrap();
    let got = scattered.data_vec().unwrap();
    assert_close(&got, &want, 1e-6, "reduce_scatter(world_size=1, Sum)");
}
/// Runs `reduce_scatter` with `ReduceOp::Sum` on a two-rank simulated group,
/// one scoped thread per rank, and compares each rank's output with the
/// fixture's `expected_rank0` / `expected_rank1` (tolerance 1e-6).
#[test]
fn reduce_scatter_world_size_2_sum_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "reduce_scatter_world_size_2_sum")
        .into_iter()
        .next()
        .expect("fixture reduce_scatter_world_size_2_sum not found");
    let rank0_input = case
        .input_rank0
        .as_deref()
        .expect("fixture must have input_rank0")
        .to_vec();
    let rank1_input = case
        .input_rank1
        .as_deref()
        .expect("fixture must have input_rank1")
        .to_vec();
    let per_rank_shape = case
        .input_shape
        .as_deref()
        .expect("fixture must have input_shape");
    let want0 = case
        .expected_rank0
        .as_deref()
        .expect("fixture must have expected_rank0");
    let want1 = case
        .expected_rank1
        .as_deref()
        .expect("fixture must have expected_rank1");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_input, per_rank_shape.to_vec());
    let t1 = make_tensor(rank1_input, per_rank_shape.to_vec());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let chunk = reduce_scatter(&t0, rank0, ReduceOp::Sum).unwrap();
            chunk.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let chunk = reduce_scatter(&t1, rank1, ReduceOp::Sum).unwrap();
            chunk.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want0, 1e-6, "reduce_scatter rank0 gets first chunk");
    assert_close(&got1, want1, 1e-6, "reduce_scatter rank1 gets second chunk");
}
/// Checks that `barrier` on a single-rank simulated group returns `Ok`.
#[test]
fn barrier_world_size_1_is_ok() {
    let ranks = SimulatedBackend::create_group(1).unwrap();
    let outcome = barrier(&ranks[0]);
    if outcome.is_err() {
        panic!("barrier(world_size=1) must return Ok, got: {outcome:?}");
    }
}
/// Calls `barrier` from both ranks of a two-rank simulated group, one scoped
/// thread per rank, and checks that both calls return `Ok`.
#[test]
fn barrier_world_size_2_synchronises() {
    let ranks = SimulatedBackend::create_group(2).unwrap();
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    let (first, second) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| barrier(rank0));
        let worker1 = s.spawn(|| barrier(rank1));
        let first = worker0.join().unwrap();
        let second = worker1.join().unwrap();
        (first, second)
    });

    assert!(
        first.is_ok(),
        "barrier(world_size=2) rank0 must return Ok, got: {first:?}"
    );
    assert!(
        second.is_ok(),
        "barrier(world_size=2) rank1 must return Ok, got: {second:?}"
    );
}
/// Sends the fixture's `input` tensor from rank 0 to rank 1 of a two-rank
/// simulated group and checks that the received data equals what was sent
/// (tolerance 1e-6).
#[test]
fn send_recv_round_trip_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "send_recv_round_trip")
        .into_iter()
        .next()
        .expect("fixture send_recv_round_trip not found");
    let Some(payload) = case.input.clone() else {
        panic!("fixture must have input");
    };
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let outgoing = make_tensor(payload.clone(), shape.clone());
    let (sender, receiver) = (&ranks[0], &ranks[1]);

    // Sender and receiver run on separate threads; the send result is
    // checked before the receive result.
    let received = std::thread::scope(|s| {
        let sending = s.spawn(|| send(&outgoing, 1, sender));
        let receiving = s.spawn(|| recv::<f32>(&shape, 0, receiver));
        sending.join().unwrap().unwrap();
        receiving.join().unwrap().unwrap()
    });

    let got = received.data_vec().unwrap();
    assert_close(&got, &payload, 1e-6, "send/recv round-trip");
}
/// Checks that `send` with the caller's own rank as destination returns an
/// error whose `Debug` text contains "InvalidArgument" or "self rank".
#[test]
fn send_to_self_returns_error() {
    let ranks = SimulatedBackend::create_group(2).unwrap();
    let tensor = make_tensor(vec![1.0_f32, 2.0], vec![2]);

    let Err(err) = send(&tensor, 0, &ranks[0]) else {
        panic!("send to self rank must return Err, got Ok");
    };
    let rendered = format!("{err:?}");
    let recognised = rendered.contains("InvalidArgument") || rendered.contains("self rank");
    assert!(
        recognised,
        "send to self must produce an InvalidArgument-like error, got: {rendered}"
    );
}
/// Checks that `send` to destination rank 5 in a two-rank group returns `Err`.
#[test]
fn send_dst_out_of_range_returns_error() {
    let ranks = SimulatedBackend::create_group(2).unwrap();
    let tensor = make_tensor(vec![1.0_f32], vec![1]);
    let rejected = send(&tensor, 5, &ranks[0]).is_err();
    assert!(
        rejected,
        "send to dst_rank >= world_size must return Err, got Ok"
    );
}
/// Runs `sendrecv` on both ranks of a two-rank simulated group, each rank
/// targeting the other, and compares what each rank receives with the
/// fixture's `expected_rank{0,1}_receives` (tolerance 1e-6).
#[test]
fn sendrecv_round_trip_matches_reference() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "sendrecv_round_trip")
        .into_iter()
        .next()
        .expect("fixture sendrecv_round_trip not found");
    let rank0_payload = case
        .rank0_sends
        .as_deref()
        .expect("fixture must have rank0_sends")
        .to_vec();
    let rank1_payload = case
        .rank1_sends
        .as_deref()
        .expect("fixture must have rank1_sends")
        .to_vec();
    let Some(shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let want0 = case
        .expected_rank0_receives
        .as_deref()
        .expect("fixture must have expected_rank0_receives");
    let want1 = case
        .expected_rank1_receives
        .as_deref()
        .expect("fixture must have expected_rank1_receives");

    let ranks = SimulatedBackend::create_group(2).unwrap();
    let t0 = make_tensor(rank0_payload, shape.clone());
    let t1 = make_tensor(rank1_payload, shape.clone());
    let (rank0, rank1) = (&ranks[0], &ranks[1]);

    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(|| {
            let incoming = sendrecv(&t0, &shape, 1, rank0).unwrap();
            incoming.data_vec().unwrap()
        });
        let worker1 = s.spawn(|| {
            let incoming = sendrecv(&t1, &shape, 0, rank1).unwrap();
            incoming.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, want0, 1e-6, "sendrecv: rank0 receives rank1's data");
    assert_close(&got1, want1, 1e-6, "sendrecv: rank1 receives rank0's data");
}
/// Builds a `SubBackend` over the fixture's `members`, using the backend at
/// index `members[0]` of a simulated group as parent, and checks that
/// `members()` equals the fixture's `expected_members`.
#[test]
fn sub_backend_members_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "SubBackend_members")
        .into_iter()
        .next()
        .expect("fixture SubBackend_members not found");
    let Some(world_size) = case.world_size else {
        panic!("fixture must have world_size");
    };
    let Some(members) = case.members.clone() else {
        panic!("fixture must have members");
    };
    let Some(expected_members) = case.expected_members.clone() else {
        panic!("fixture must have expected_members");
    };

    let ranks = SimulatedBackend::create_group(world_size).unwrap();
    let parent_index = members[0];
    // The group is consumed here; only the backend at `parent_index` is kept.
    let chosen = ranks
        .into_iter()
        .nth(parent_index)
        .expect("group must have enough ranks");
    let parent: Arc<dyn Backend> = Arc::new(chosen);

    let sub = SubBackend::new(parent, members).unwrap();
    assert_eq!(
        sub.members(),
        &expected_members[..],
        "SubBackend::members() must match the construction members list"
    );
}
/// Builds a `SubBackend` over the fixture's `members` and checks its rank
/// translation in both directions.
///
/// For each position `local` in `members`, `to_global(local)` must return the
/// member at that position and `to_local` of that member must return
/// `Some(local)`. Every rank in `0..world_size` that is not listed in
/// `members` must map to `None`; the non-members are derived from the fixture
/// rather than assumed.
#[test]
fn sub_backend_rank_mapping_matches_fixture() {
    let file = load_fixtures();
    let cases = fixtures_for(&file, "SubBackend_rank_mapping");
    assert!(
        !cases.is_empty(),
        "fixture SubBackend_rank_mapping not found"
    );
    let f = cases[0];
    let world_size = f.world_size.expect("fixture must have world_size");
    let members = f.members.clone().expect("fixture must have members");
    // The first member selects the parent backend, so the list cannot be empty.
    let first_member = *members
        .first()
        .expect("fixture members must not be empty");

    let group = SimulatedBackend::create_group(world_size).unwrap();
    let parent: Arc<dyn Backend> = Arc::new(
        group
            .into_iter()
            .nth(first_member)
            .expect("group must have enough ranks"),
    );
    let sub = SubBackend::new(parent, members.clone()).unwrap();

    for (local, &global) in members.iter().enumerate() {
        assert_eq!(
            sub.to_global(local),
            global,
            "to_global({local}) must return {global}"
        );
        assert_eq!(
            sub.to_local(global),
            Some(local),
            "to_local({global}) must return Some({local})"
        );
    }

    // Negative check: ranks outside the member list have no local rank.
    for outsider in (0..world_size).filter(|rank| !members.contains(rank)) {
        assert_eq!(
            sub.to_local(outsider),
            None,
            "to_local({outsider}) must return None (rank {outsider} is not in members {members:?})"
        );
    }
}
/// Constructs a `DeviceMesh` from the fixture's `shape` and `world_size` and
/// checks `ndim()` and `size()` against the fixture's expectations.
#[test]
fn device_mesh_new_valid_matches_fixture() {
    let file = load_fixtures();
    let case = fixtures_for(&file, "DeviceMesh_new_valid")
        .into_iter()
        .next()
        .expect("fixture DeviceMesh_new_valid not found");
    let Some(mesh_shape) = case.shape.clone() else {
        panic!("fixture must have shape");
    };
    let Some(world_size) = case.world_size else {
        panic!("fixture must have world_size");
    };
    let Some(want_ndim) = case.expected_ndim else {
        panic!("fixture must have expected_ndim");
    };
    let Some(want_size) = case.expected_size else {
        panic!("fixture must have expected_size");
    };

    let mesh = DeviceMesh::new(mesh_shape, world_size).unwrap();
    assert_eq!(mesh.ndim(), want_ndim, "DeviceMesh ndim");
    assert_eq!(mesh.size(), want_size, "DeviceMesh size");
}
/// Checks that `DeviceMesh::new` rejects shape `[2, 3]` combined with a world
/// size of 4.
#[test]
fn device_mesh_new_shape_mismatch_returns_error() {
    let rejected = DeviceMesh::new(vec![2, 3], 4).is_err();
    assert!(
        rejected,
        "DeviceMesh::new([2,3], 4) must return Err for shape mismatch"
    );
}
/// Checks that `DeviceMesh::new` rejects an empty shape.
#[test]
fn device_mesh_new_empty_shape_returns_error() {
    let rejected = DeviceMesh::new(vec![], 1).is_err();
    assert!(
        rejected,
        "DeviceMesh::new([], 1) must return Err for empty shape"
    );
}
/// Checks the predicate methods (`is_replicate`, `is_shard`, `is_partial`)
/// and `shard_dim` on one value of each `Placement` variant: `Replicate`,
/// `Shard(0)` and `Partial(ReduceOp::Sum)`.
#[test]
fn placement_variants_match_fixture() {
    let replicate = Placement::Replicate;
    let shard = Placement::Shard(0);
    let partial = Placement::Partial(ReduceOp::Sum);

    for (holds, msg) in [
        (replicate.is_replicate(), "Replicate::is_replicate must be true"),
        (!replicate.is_shard(), "Replicate::is_shard must be false"),
        (!replicate.is_partial(), "Replicate::is_partial must be false"),
        (replicate.shard_dim().is_none(), "Replicate::shard_dim must be None"),
    ] {
        assert!(holds, "{msg}");
    }

    for (holds, msg) in [
        (!shard.is_replicate(), "Shard::is_replicate must be false"),
        (shard.is_shard(), "Shard::is_shard must be true"),
        (!shard.is_partial(), "Shard::is_partial must be false"),
    ] {
        assert!(holds, "{msg}");
    }
    assert_eq!(
        shard.shard_dim(),
        Some(0),
        "Shard(0)::shard_dim must be Some(0)"
    );

    for (holds, msg) in [
        (!partial.is_replicate(), "Partial::is_replicate must be false"),
        (!partial.is_shard(), "Partial::is_shard must be false"),
        (partial.is_partial(), "Partial::is_partial must be true"),
        (partial.shard_dim().is_none(), "Partial::shard_dim must be None"),
    ] {
        assert!(holds, "{msg}");
    }
}
/// Checks that `DTensor::from_local` succeeds for a two-element local tensor
/// on a one-dimensional two-rank mesh, with a single `Shard(0)` placement and
/// `vec![4]` as the final argument.
#[test]
fn dtensor_from_local_valid_matches_fixture() {
    let mesh = DeviceMesh::new(vec![2], 2).unwrap();
    let local = make_tensor(vec![1.0_f32, 2.0], vec![2]);
    let placements = vec![Placement::Shard(0)];
    let outcome = DTensor::from_local(local, mesh, placements, vec![4]);
    if outcome.is_err() {
        panic!("DTensor::from_local with Shard(0) on 2-rank mesh must succeed, got: {outcome:?}");
    }
}
/// Checks that `DTensor::from_local` returns `Err` when a single placement is
/// supplied for a two-dimensional (`[2, 2]`) mesh.
#[test]
fn dtensor_from_local_placement_mismatch_returns_error() {
    let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
    let local = make_tensor(vec![1.0_f32, 2.0], vec![2]);
    let placements = vec![Placement::Replicate];
    let rejected = DTensor::from_local(local, mesh, placements, vec![2]).is_err();
    assert!(
        rejected,
        "DTensor::from_local with mismatched placements must return Err"
    );
}
/// Spot-checks the `Display` text of four `DistributedError` variants for
/// expected substrings, and checks that a `DistributedError` converts into a
/// `FerrotorchError` with a non-empty `Debug` rendering.
#[test]
fn distributed_error_display_matches_fixture() {
    let world_size_text = DistributedError::InvalidWorldSize { world_size: 0 }.to_string();
    assert!(
        world_size_text.contains("world size") || world_size_text.contains("world_size"),
        "InvalidWorldSize Display must mention 'world size', got: {world_size_text:?}"
    );

    let rank_text = DistributedError::InvalidRank {
        rank: 5,
        world_size: 3,
    }
    .to_string();
    assert!(
        rank_text.contains("rank") || rank_text.contains("5"),
        "InvalidRank Display must mention rank, got: {rank_text:?}"
    );

    let timeout_text = DistributedError::Timeout { seconds: 30 }.to_string();
    assert!(
        timeout_text.contains("timed out")
            || timeout_text.contains("timeout")
            || timeout_text.contains("30"),
        "Timeout Display must mention timeout, got: {timeout_text:?}"
    );

    let unavailable_text = DistributedError::BackendUnavailable { backend: "gloo" }.to_string();
    assert!(
        unavailable_text.contains("gloo"),
        "BackendUnavailable Display must mention backend name, got: {unavailable_text:?}"
    );

    let converted: FerrotorchError = DistributedError::BackendUnavailable { backend: "test" }.into();
    let converted_debug = format!("{converted:?}");
    assert!(
        !converted_debug.is_empty(),
        "DistributedError must convert to FerrotorchError"
    );
}
/// Runs `async_all_gather` on both ranks of a two-rank simulated group, waits
/// on each handle in its own scoped thread, and checks that both ranks end up
/// with `[0, 10, 1, 11]` (tolerance 1e-6).
#[test]
fn async_all_gather_matches_sync_reference() {
    let ranks = SimulatedBackend::create_group(2).unwrap();
    let mut shared: Vec<Arc<dyn Backend>> = Vec::new();
    for backend in ranks {
        shared.push(Arc::new(backend) as Arc<dyn Backend>);
    }

    let t0 = make_tensor(vec![0.0_f32, 10.0], vec![2]);
    let t1 = make_tensor(vec![1.0_f32, 11.0], vec![2]);
    let want = vec![0.0_f32, 10.0, 1.0, 11.0];
    let backend0 = Arc::clone(&shared[0]);
    let backend1 = Arc::clone(&shared[1]);

    let (got0, got1) = std::thread::scope(|s| {
        let worker0 = s.spawn(move || {
            let gathered = async_all_gather(t0, backend0).wait().unwrap();
            gathered.data_vec().unwrap()
        });
        let worker1 = s.spawn(move || {
            let gathered = async_all_gather(t1, backend1).wait().unwrap();
            gathered.data_vec().unwrap()
        });
        let got0 = worker0.join().unwrap();
        let got1 = worker1.join().unwrap();
        (got0, got1)
    });

    assert_close(&got0, &want, 1e-6, "async_all_gather rank0");
    assert_close(&got1, &want, 1e-6, "async_all_gather rank1");
}
/// Checks that the handle returned by `async_all_gather` reports
/// `"async_all_gather"` from `op_name()`.
///
/// After the assertion, the matching collective is started on rank 1 and both
/// handles are waited on. The `Result`s of those waits are deliberately
/// ignored; only a panic in the peer thread is propagated.
#[test]
fn pending_collective_op_name_matches_fixture() {
    let group = SimulatedBackend::create_group(2).unwrap();
    let arcs: Vec<Arc<dyn Backend>> = group
        .into_iter()
        .map(|b| Arc::new(b) as Arc<dyn Backend>)
        .collect();
    let t0 = make_tensor(vec![1.0_f32], vec![1]);
    let arc0 = Arc::clone(&arcs[0]);
    let handle = async_all_gather(t0, arc0);
    assert_eq!(
        handle.op_name(),
        "async_all_gather",
        "PendingCollective::op_name must return 'async_all_gather'"
    );

    // Start rank 1's side and wait on both handles.
    // NOTE(review): presumably this drains the collective so rank 0 does not
    // outlive the test unfinished — confirm.
    std::thread::scope(|s| {
        let t1 = make_tensor(vec![2.0_f32], vec![1]);
        let arc1 = Arc::clone(&arcs[1]);
        let h1 = s.spawn(|| async_all_gather(t1, arc1).wait());
        let result = handle.wait();
        h1.join().unwrap().ok();
        result.ok();
    });
}
/// Runs `async_reduce_scatter` with `ReduceOp::Sum` on a two-rank simulated
/// group and checks that rank 0 receives the first half of the element-wise
/// sum and rank 1 the second half.
#[test]
fn async_reduce_scatter_matches_sync_reference() {
    let backends: Vec<Arc<dyn Backend>> = SimulatedBackend::create_group(2)
        .unwrap()
        .into_iter()
        .map(|backend| Arc::new(backend) as Arc<dyn Backend>)
        .collect();
    let rank0_backend = Arc::clone(&backends[0]);
    let rank1_backend = Arc::clone(&backends[1]);
    let rank0_input = make_tensor(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]);
    let rank1_input = make_tensor(vec![5.0_f32, 6.0, 7.0, 8.0], vec![4]);
    // The element-wise sum is [6, 8, 10, 12]; each rank keeps one half.
    let want_rank0 = vec![6.0_f32, 8.0];
    let want_rank1 = vec![10.0_f32, 12.0];

    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let chunk = async_reduce_scatter(rank0_input, rank0_backend, ReduceOp::Sum)
                .wait()
                .unwrap();
            chunk.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let chunk = async_reduce_scatter(rank1_input, rank1_backend, ReduceOp::Sum)
                .wait()
                .unwrap();
            chunk.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_close(
        &got0,
        &want_rank0,
        1e-6,
        "async_reduce_scatter rank0 gets first chunk",
    );
    assert_close(
        &got1,
        &want_rank1,
        1e-6,
        "async_reduce_scatter rank1 gets second chunk",
    );
}
/// Builds flat shard metadata for a single 100-element tensor across four
/// ranks and reads back every field of the resulting shard spec.
#[test]
fn tensor_shard_spec_fields_accessible() {
    let state_dict: HashMap<String, Tensor<f32>> = HashMap::from([(
        "w".to_string(),
        make_tensor(vec![0.0_f32; 100], vec![100]),
    )]);
    let metadata = flat_shard_metadata(&state_dict, 4);
    let spec = metadata.tensor_specs.get("w").expect("spec must exist");

    assert_eq!(spec.full_shape, vec![400], "TensorShardSpec::full_shape");
    assert_eq!(spec.shard_dim, 0, "TensorShardSpec::shard_dim");
    let shard_count = spec.shard_sizes.len();
    assert_eq!(shard_count, 4, "TensorShardSpec::shard_sizes len");
    // Every rank is expected to hold exactly 100 elements.
    let has_other_size = spec.shard_sizes.iter().any(|&size| size != 100);
    assert!(!has_other_size, "all shard_sizes must be 100");
}
/// Builds flat shard metadata from an empty state dict and checks the rank
/// count is recorded and no tensor specs are produced.
#[test]
fn shard_metadata_fields_accessible() {
    let metadata = flat_shard_metadata(&HashMap::<String, Tensor<f32>>::new(), 4);
    assert_eq!(metadata.num_ranks, 4, "ShardMetadata::num_ranks");
    let no_specs = metadata.tensor_specs.is_empty();
    assert!(no_specs, "ShardMetadata::tensor_specs empty");
}
/// With a single rank, the flat shard metadata for a 100-element tensor must
/// describe one shard covering the whole tensor along dimension 0.
#[test]
fn flat_shard_metadata_single_shard_matches_fixture() {
    let state_dict: HashMap<String, Tensor<f32>> = HashMap::from([(
        "weight".to_string(),
        make_tensor(vec![1.0_f32; 100], vec![100]),
    )]);
    let metadata = flat_shard_metadata(&state_dict, 1);
    assert_eq!(metadata.num_ranks, 1, "num_ranks must be 1");

    let weight_spec = metadata
        .tensor_specs
        .get("weight")
        .expect("weight spec must exist");
    assert_eq!(weight_spec.full_shape, vec![100], "full_shape must be [100]");
    assert_eq!(weight_spec.shard_dim, 0, "shard_dim must be 0");
    assert_eq!(
        weight_spec.shard_sizes,
        vec![100],
        "single shard: shard_sizes=[100]"
    );
}
/// With four ranks and a 25-element local tensor, the flat shard metadata
/// must report a full shape of 100 and four shards of 25 elements each.
#[test]
fn flat_shard_metadata_four_shards_matches_fixture() {
    let state_dict: HashMap<String, Tensor<f32>> = HashMap::from([(
        "weight".to_string(),
        make_tensor(vec![1.0_f32; 25], vec![25]),
    )]);
    let metadata = flat_shard_metadata(&state_dict, 4);
    assert_eq!(metadata.num_ranks, 4, "num_ranks must be 4");

    let weight_spec = metadata
        .tensor_specs
        .get("weight")
        .expect("weight spec must exist");
    assert_eq!(
        weight_spec.full_shape,
        vec![100],
        "full_shape must be [25 * 4 = 100]"
    );
    let shard_count = weight_spec.shard_sizes.len();
    assert_eq!(shard_count, 4, "must have 4 shard_sizes entries");
    let has_other_size = weight_spec.shard_sizes.iter().any(|&size| size != 25);
    assert!(
        !has_other_size,
        "each shard must have 25 elements, got: {:?}",
        weight_spec.shard_sizes
    );
}
/// Structural checks on `RpcError`: the `FunctionNotFound` message contains
/// either the function name or "not found", the `Timeout` message is
/// non-empty, and an `RpcError` converts into a `FerrotorchError`.
#[test]
fn rpc_error_display_structural() {
    use ferrotorch_distributed::RpcError;

    let s = RpcError::FunctionNotFound {
        name: "my_fn".to_string(),
    }
    .to_string();
    let names_function = s.contains("my_fn");
    let says_not_found = s.contains("not found");
    assert!(
        names_function || says_not_found,
        "RpcError::FunctionNotFound Display must mention function name, got: {s:?}"
    );

    let s = RpcError::Timeout.to_string();
    assert!(
        !s.is_empty(),
        "RpcError::Timeout Display must be non-empty, got: {s:?}"
    );

    let ft: FerrotorchError = RpcError::Timeout.into();
    let debug_repr = format!("{ft:?}");
    assert!(
        !debug_repr.is_empty(),
        "RpcError must convert to FerrotorchError"
    );
}
/// Placeholder for a live NCCL allreduce check. It is unconditionally skipped:
/// `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_nccl_allreduce() {
cascade_skip!(
"NCCL allreduce requires GPU + NCCL feature; not available on this box — tracking issue #882"
);
}
/// Placeholder for a live `TcpBackend` rendezvous check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_tcp_backend_rendezvous() {
cascade_skip!(
"TcpBackend requires a live multi-process TCP rendezvous; single-process test not feasible — tracking issue #882"
);
}
/// Placeholder for a DDP gradient-synchronisation check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_ddp_gradient_sync() {
cascade_skip!(
"DDP gradient synchronisation requires multi-rank forward/backward; deferred — tracking issue #883"
);
}
/// Placeholder for an FSDP parameter-sharding check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_fsdp_parameter_sharding() {
cascade_skip!("FSDP parameter sharding requires >= 2 ranks; deferred — tracking issue #884");
}
/// Placeholder for a pipeline-parallel microbatch check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_pipeline_microbatch() {
cascade_skip!(
"Pipeline parallelism requires multi-stage multi-rank execution; deferred — tracking issue #885"
);
}
/// Placeholder for an RPC remote-invocation check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_rpc_remote_invocation() {
cascade_skip!(
"RpcAgent remote invocation requires live TCP connections; deferred — tracking issue #886"
);
}
/// Placeholder for a synchronised batch-norm check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_sync_batch_norm() {
cascade_skip!(
"SyncBatchNorm2d requires cross-rank stat synchronisation; deferred — tracking issue #887"
);
}
/// Placeholder for a distributed checkpoint round-trip check. It is unconditionally
/// skipped: `cascade_skip!` prints the reason to stderr and returns immediately.
#[test]
fn live_distributed_checkpoint_round_trip() {
cascade_skip!(
"Distributed checkpoint round-trip requires multi-rank shard writes; deferred — tracking issue #888"
);
}
/// One-second timeout handed to the `*_with_timeout` collective and
/// point-to-point calls exercised by the tests in this file.
const PASS1_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1);
/// Checks `allreduce_with_timeout` with `ReduceOp::Sum`: on a two-rank
/// simulated group both ranks must see the element-wise sum, and on a
/// single-rank group the input must come back unchanged.
#[test]
fn allreduce_with_timeout_matches_default_world_size_2() {
    // Two-rank case.
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];
    let input0 = make_tensor(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]);
    let input1 = make_tensor(vec![10.0_f32, 20.0, 30.0, 40.0], vec![4]);
    let want = vec![11.0_f32, 22.0, 33.0, 44.0];

    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let reduced =
                allreduce_with_timeout(&input0, rank0, ReduceOp::Sum, PASS1_TIMEOUT).unwrap();
            reduced.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let reduced =
                allreduce_with_timeout(&input1, rank1, ReduceOp::Sum, PASS1_TIMEOUT).unwrap();
            reduced.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });
    assert_close(&got0, &want, 1e-6, "allreduce_with_timeout rank0");
    assert_close(&got1, &want, 1e-6, "allreduce_with_timeout rank1");

    // Single-rank case.
    let single = SimulatedBackend::create_group(1).unwrap();
    let lone_input = make_tensor(vec![7.0_f32, 8.0, 9.0], vec![3]);
    let reduced =
        allreduce_with_timeout(&lone_input, &single[0], ReduceOp::Sum, PASS1_TIMEOUT).unwrap();
    let got = reduced.data_vec().unwrap();
    assert_close(&got, &[7.0, 8.0, 9.0], 1e-6, "allreduce_with_timeout solo");
}
/// Checks `all_gather_with_timeout` on a two-rank simulated group against a
/// hard-coded expectation, then repeats the same gather with plain
/// `all_gather` on a fresh group and requires identical per-rank results.
#[test]
fn all_gather_with_timeout_matches_default_world_size_2() {
    // Pass 1: the timeout variant.
    let timed_group = SimulatedBackend::create_group(2).unwrap();
    let timed_rank0 = &timed_group[0];
    let timed_rank1 = &timed_group[1];
    let timed_in0 = make_tensor(vec![0.0_f32, 1.0], vec![2]);
    let timed_in1 = make_tensor(vec![10.0_f32, 11.0], vec![2]);
    let want = vec![0.0_f32, 1.0, 10.0, 11.0];

    let (timed_out0, timed_out1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let gathered = all_gather_with_timeout(&timed_in0, timed_rank0, PASS1_TIMEOUT).unwrap();
            gathered.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let gathered = all_gather_with_timeout(&timed_in1, timed_rank1, PASS1_TIMEOUT).unwrap();
            gathered.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });
    assert_close(&timed_out0, &want, 1e-6, "all_gather_with_timeout rank0");
    assert_close(&timed_out1, &want, 1e-6, "all_gather_with_timeout rank1");

    // Pass 2: the default variant, on an independent group with equal inputs.
    let plain_group = SimulatedBackend::create_group(2).unwrap();
    let plain_rank0 = &plain_group[0];
    let plain_rank1 = &plain_group[1];
    let plain_in0 = make_tensor(vec![0.0_f32, 1.0], vec![2]);
    let plain_in1 = make_tensor(vec![10.0_f32, 11.0], vec![2]);

    let (plain_out0, plain_out1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let gathered = all_gather(&plain_in0, plain_rank0).unwrap();
            gathered.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let gathered = all_gather(&plain_in1, plain_rank1).unwrap();
            gathered.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_eq!(
        timed_out0, plain_out0,
        "all_gather_with_timeout must match all_gather (rank0)"
    );
    assert_eq!(
        timed_out1, plain_out1,
        "all_gather_with_timeout must match all_gather (rank1)"
    );
}
/// Checks `reduce_scatter_with_timeout` with `ReduceOp::Sum` on a two-rank
/// simulated group: rank 0 must receive the first half of the element-wise
/// sum and rank 1 the second half.
#[test]
fn reduce_scatter_with_timeout_matches_default_world_size_2() {
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];
    let input0 = make_tensor(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]);
    let input1 = make_tensor(vec![5.0_f32, 6.0, 7.0, 8.0], vec![4]);
    let want0 = vec![6.0_f32, 8.0];
    let want1 = vec![10.0_f32, 12.0];

    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let chunk = reduce_scatter_with_timeout(&input0, rank0, ReduceOp::Sum, PASS1_TIMEOUT)
                .unwrap();
            chunk.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let chunk = reduce_scatter_with_timeout(&input1, rank1, ReduceOp::Sum, PASS1_TIMEOUT)
                .unwrap();
            chunk.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_close(&got0, &want0, 1e-6, "reduce_scatter_with_timeout rank0");
    assert_close(&got1, &want1, 1e-6, "reduce_scatter_with_timeout rank1");
}
/// Checks `reduce_scatter_tensor` with `ReduceOp::Sum` on a two-rank
/// simulated group: each rank's output must have shape `[2]` and hold its
/// half of the element-wise sum.
#[test]
fn reduce_scatter_tensor_matches_reduce_scatter_world_size_2() {
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];
    let input0 = make_tensor(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]);
    let input1 = make_tensor(vec![5.0_f32, 6.0, 7.0, 8.0], vec![4]);
    let want0 = vec![6.0_f32, 8.0];
    let want1 = vec![10.0_f32, 12.0];

    // Each worker reports both the output shape and the output data.
    let ((shape0, data0), (shape1, data1)) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let chunk = reduce_scatter_tensor(&input0, rank0, ReduceOp::Sum).unwrap();
            let dims = chunk.shape().to_vec();
            (dims, chunk.data_vec().unwrap())
        });
        let worker1 = scope.spawn(|| {
            let chunk = reduce_scatter_tensor(&input1, rank1, ReduceOp::Sum).unwrap();
            let dims = chunk.shape().to_vec();
            (dims, chunk.data_vec().unwrap())
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_eq!(
        shape0,
        vec![2_usize],
        "reduce_scatter_tensor rank0 shape (chunk = numel/world_size)"
    );
    assert_eq!(shape1, vec![2_usize], "reduce_scatter_tensor rank1 shape");
    assert_close(&data0, &want0, 1e-6, "reduce_scatter_tensor rank0");
    assert_close(&data1, &want1, 1e-6, "reduce_scatter_tensor rank1");
}
/// Checks `all_to_all` on a two-rank simulated group: rank 0 must end up with
/// the first half of each rank's input and rank 1 with the second half.
#[test]
fn all_to_all_world_size_2_matches_reference() {
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];
    let input0 = make_tensor(vec![10.0_f32, 20.0, 30.0, 40.0], vec![4]);
    let input1 = make_tensor(vec![50.0_f32, 60.0, 70.0, 80.0], vec![4]);

    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let exchanged = all_to_all(&input0, rank0).unwrap();
            exchanged.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let exchanged = all_to_all(&input1, rank1).unwrap();
            exchanged.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_close(
        &got0,
        &[10.0, 20.0, 50.0, 60.0],
        1e-6,
        "all_to_all rank0 (self chunk + chunk from rank 1)",
    );
    assert_close(
        &got1,
        &[30.0, 40.0, 70.0, 80.0],
        1e-6,
        "all_to_all rank1 (chunk from rank 0 + self chunk)",
    );
}
/// Checks `all_to_all_with_timeout`: on a two-rank simulated group each rank
/// must receive its chunk from both inputs, and on a single-rank group the
/// input must come back unchanged.
#[test]
fn all_to_all_with_timeout_matches_default_world_size_2() {
    // Two-rank case.
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];
    let input0 = make_tensor(vec![10.0_f32, 20.0, 30.0, 40.0], vec![4]);
    let input1 = make_tensor(vec![50.0_f32, 60.0, 70.0, 80.0], vec![4]);

    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let exchanged = all_to_all_with_timeout(&input0, rank0, PASS1_TIMEOUT).unwrap();
            exchanged.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let exchanged = all_to_all_with_timeout(&input1, rank1, PASS1_TIMEOUT).unwrap();
            exchanged.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });
    assert_close(
        &got0,
        &[10.0, 20.0, 50.0, 60.0],
        1e-6,
        "all_to_all_with_timeout rank0",
    );
    assert_close(
        &got1,
        &[30.0, 40.0, 70.0, 80.0],
        1e-6,
        "all_to_all_with_timeout rank1",
    );

    // Single-rank case.
    let single = SimulatedBackend::create_group(1).unwrap();
    let lone_input = make_tensor(vec![1.0_f32, 2.0, 3.0], vec![3]);
    let exchanged = all_to_all_with_timeout(&lone_input, &single[0], PASS1_TIMEOUT).unwrap();
    let got = exchanged.data_vec().unwrap();
    assert_close(&got, &[1.0, 2.0, 3.0], 1e-6, "all_to_all_with_timeout solo");
}
/// Checks `all_to_all_single_uneven` on a two-rank simulated group with
/// different split sizes per rank: rank 0 passes `[1, 3]` / `[1, 2]` and
/// rank 1 passes `[2, 2]` / `[3, 2]` as the two split-size slices.
#[test]
fn all_to_all_single_uneven_world_size_2_matches_reference() {
    let pair = SimulatedBackend::create_group(2).unwrap();
    let rank0 = &pair[0];
    let rank1 = &pair[1];

    // Each worker builds its own input tensor on its own thread.
    let (got0, got1) = std::thread::scope(|scope| {
        let worker0 = scope.spawn(|| {
            let input = make_tensor(vec![10.0_f32, 20.0, 21.0, 22.0], vec![4]);
            let exchanged = all_to_all_single_uneven(&input, &[1, 3], &[1, 2], rank0).unwrap();
            exchanged.data_vec().unwrap()
        });
        let worker1 = scope.spawn(|| {
            let input = make_tensor(vec![30.0_f32, 31.0, 40.0, 41.0], vec![4]);
            let exchanged = all_to_all_single_uneven(&input, &[2, 2], &[3, 2], rank1).unwrap();
            exchanged.data_vec().unwrap()
        });
        let from_rank0 = worker0.join().unwrap();
        let from_rank1 = worker1.join().unwrap();
        (from_rank0, from_rank1)
    });

    assert_close(
        &got0,
        &[10.0, 30.0, 31.0],
        1e-6,
        "all_to_all_single_uneven rank0 (self + from-rank-1)",
    );
    assert_close(
        &got1,
        &[20.0, 21.0, 22.0, 40.0, 41.0],
        1e-6,
        "all_to_all_single_uneven rank1 (from-rank-0 + self)",
    );
}
/// Sends a 2x2 tensor from rank 0 to rank 1 with `send` and receives it with
/// `recv_with_timeout`, checking the shape and contents survive the trip.
#[test]
fn recv_with_timeout_matches_default_round_trip() {
    let pair = SimulatedBackend::create_group(2).unwrap();
    let sender = &pair[0];
    let receiver = &pair[1];
    let payload = vec![100.0_f32, 200.0, 300.0, 400.0];
    let shape = vec![2_usize, 2];
    let outgoing = make_tensor(payload.clone(), shape.clone());

    let received = std::thread::scope(|scope| {
        let send_worker = scope.spawn(|| send(&outgoing, 1, sender));
        let recv_worker =
            scope.spawn(|| recv_with_timeout::<f32>(&shape, 0, receiver, PASS1_TIMEOUT));
        // Surface a send failure before looking at the receive result.
        send_worker.join().unwrap().unwrap();
        recv_worker.join().unwrap().unwrap()
    });

    assert_eq!(
        received.shape(),
        shape.as_slice(),
        "recv_with_timeout preserves shape"
    );
    assert_close(
        &received.data_vec().unwrap(),
        &payload,
        1e-6,
        "recv_with_timeout round-trip",
    );
}
/// Sends a 3-element tensor from rank 0 to rank 1 and receives it with
/// `recv_into` into a destination pre-filled with the sentinel `-1.0`,
/// checking that the destination ends up holding the payload.
#[test]
fn recv_into_overwrites_destination() {
    let group = SimulatedBackend::create_group(2).unwrap();
    let payload = vec![11.0_f32, 22.0, 33.0];
    // The shape is only needed to build the outgoing tensor, so it is passed
    // directly rather than kept in a local and cloned.
    let t = make_tensor(payload.clone(), vec![3_usize]);
    let result = std::thread::scope(|s| {
        let b0 = &group[0];
        let b1 = &group[1];
        let h_send = s.spawn(|| send(&t, 1, b0));
        let h_recv = s.spawn(|| {
            // Sentinel-filled destination: any surviving -1.0 means the
            // receive did not write that element.
            let mut dst = make_tensor(vec![-1.0_f32, -1.0, -1.0], vec![3_usize]);
            recv_into(&mut dst, 0, b1).unwrap();
            dst.data_vec().unwrap()
        });
        h_send.join().unwrap().unwrap();
        h_recv.join().unwrap()
    });
    assert_close(&result, &payload, 1e-6, "recv_into overwrites destination");
    assert!(
        result.iter().all(|&v| v >= 0.0),
        "recv_into must overwrite the sentinel -1.0 buffer; got {result:?}"
    );
}
/// Sends a 4-element tensor from rank 0 to rank 1 and receives it with
/// `recv_into_with_timeout` into a destination pre-filled with `-1.0`,
/// checking that the destination ends up holding the payload.
#[test]
fn recv_into_with_timeout_overwrites_destination() {
    let group = SimulatedBackend::create_group(2).unwrap();
    let payload = vec![7.0_f32, 8.0, 9.0, 10.0];
    // The shape is only needed to build the outgoing tensor, so it is passed
    // directly rather than kept in a local and cloned.
    let t = make_tensor(payload.clone(), vec![4_usize]);
    let result = std::thread::scope(|s| {
        let b0 = &group[0];
        let b1 = &group[1];
        let h_send = s.spawn(|| send(&t, 1, b0));
        let h_recv = s.spawn(|| {
            let mut dst = make_tensor(vec![-1.0_f32, -1.0, -1.0, -1.0], vec![4_usize]);
            recv_into_with_timeout(&mut dst, 0, b1, PASS1_TIMEOUT).unwrap();
            dst.data_vec().unwrap()
        });
        h_send.join().unwrap().unwrap();
        h_recv.join().unwrap()
    });
    assert_close(
        &result,
        &payload,
        1e-6,
        "recv_into_with_timeout overwrites destination",
    );
}
/// Exercises `DistCheckpointError` through `load_distributed` and by direct
/// construction:
///
/// * a directory without `metadata.json` must yield the `Io` variant, with a
///   Display message that describes the failure;
/// * a directory whose `metadata.json` is malformed must yield the
///   `Serialization` variant, which must convert into `FerrotorchError`;
/// * the remaining variants are constructed directly and must have a
///   non-empty Display message.
///
/// The scratch directory name includes the process id so concurrent test
/// processes do not share it, and it is removed by a drop guard so it is
/// cleaned up even when an assertion panics.
#[test]
fn dist_checkpoint_error_variants_surface_via_load() {
    /// Best-effort removal of the scratch directory when dropped.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    let scratch = ScratchDir(std::env::temp_dir().join(format!(
        "ferrotorch_pass1_dist_ckpt_err_{}",
        std::process::id()
    )));
    let base = &scratch.0;
    // Clear anything left behind by an earlier process that reused this pid.
    let _ = std::fs::remove_dir_all(base);

    // Case 1: no metadata.json at all.
    let dir_no_meta = base.join("no_metadata");
    std::fs::create_dir_all(&dir_no_meta).expect("create test dir");
    let err1 = load_distributed::<f32>(&dir_no_meta, 0, 1)
        .expect_err("load_distributed must fail without metadata.json");
    assert!(
        matches!(err1, DistCheckpointError::Io { .. }),
        "missing metadata.json must produce DistCheckpointError::Io, got {err1:?}"
    );
    let msg1 = err1.to_string();
    assert!(
        msg1.contains("I/O") || msg1.contains("metadata") || msg1.contains("reading"),
        "DistCheckpointError::Io display must describe the I/O failure, got: {msg1:?}"
    );

    // Case 2: metadata.json exists but is not valid JSON.
    let dir_bad = base.join("bad_metadata");
    std::fs::create_dir_all(&dir_bad).expect("create test dir");
    std::fs::write(dir_bad.join("metadata.json"), b"{not valid json}")
        .expect("write malformed metadata");
    let err2 = load_distributed::<f32>(&dir_bad, 0, 1)
        .expect_err("load_distributed must fail on malformed metadata");
    assert!(
        matches!(err2, DistCheckpointError::Serialization { .. }),
        "malformed metadata.json must produce DistCheckpointError::Serialization, got {err2:?}"
    );
    let _: FerrotorchError = err2.into();

    // Case 3: variants that are constructed directly rather than provoked.
    let invalid = DistCheckpointError::InvalidArgument {
        message: "x".into(),
    };
    let missing = DistCheckpointError::MissingShard {
        path: "/none".into(),
    };
    let tensor_err = DistCheckpointError::Tensor {
        message: "x".into(),
    };
    let meta = DistCheckpointError::Metadata {
        message: "x".into(),
    };
    let async_err = DistCheckpointError::AsyncFailed {
        message: "x".into(),
    };
    for e in [&invalid, &missing, &tensor_err, &meta, &async_err] {
        assert!(
            !e.to_string().is_empty(),
            "DistCheckpointError variant Display must be non-empty"
        );
    }
    // `scratch` is dropped here, removing the directory tree.
}