ferrotorch_distributed/fsdp.rs
1//! Fully Sharded Data Parallel (FSDP) wrapper.
2//!
3//! [`FSDP`] wraps a [`Module`] and shards its parameters across ranks.
4//! During forward, parameters are all-gathered to form the full tensors.
5//! During gradient synchronization, full-parameter gradients are
6//! reduce-scattered so each rank only stores its shard's gradient.
7//!
8//! This reduces per-rank memory from O(params) to O(params / world_size)
9//! at the cost of additional communication during forward and backward.
10//!
11//! ## REQ status (per `.design/ferrotorch-distributed/fsdp.md`)
12//!
13//! Full evidence rows (impl + non-test production consumer + upstream
14//! cites) live in the design doc; this synopsis is a one-line summary per
15//! REQ.
16//!
17//! | REQ | Status | Evidence |
18//! |---|---|---|
19//! | REQ-1 (`ShardingStrategy` enum) | SHIPPED | `pub enum ShardingStrategy` in `fsdp.rs` mirrors `ShardingStrategy(Enum)` in `torch/distributed/fsdp/api.py`; consumer `pub fn new_with_strategy` (same file) and `lib.rs` re-export |
20//! | REQ-2 (`FSDP<M, T>` struct) | SHIPPED | `pub struct FSDP` in `fsdp.rs` mirrors `class FullyShardedDataParallel` in `torch/distributed/fsdp/fully_sharded_data_parallel.py`; consumer `pub use fsdp::FSDP` in `lib.rs` |
21//! | REQ-3 (`new` + `new_with_strategy` constructors) | SHIPPED | `pub fn new` + `pub fn new_with_strategy` in `fsdp.rs`; consumer `new` invokes `new_with_strategy` (same file) plus `lib.rs` re-export |
22//! | REQ-4 (`FullShard` ZeRO-3 strategy) | SHIPPED | `ShardingStrategy::FullShard` arm of `new_with_strategy` in `fsdp.rs`; consumer `pub fn new` default path (same file) |
23//! | REQ-5 (`ShardGradOp` ZeRO-2 strategy + `broadcast_updated_params`) | SHIPPED | `ShardingStrategy::ShardGradOp` arms + `pub fn broadcast_updated_params` in `fsdp.rs`; consumer of `crate::collective::all_gather` for param re-sync |
24//! | REQ-6 (`NoShard` DDP-equivalent strategy) | SHIPPED | `ShardingStrategy::NoShard` arms across `new_with_strategy` / `forward` / `sync_gradients` in `fsdp.rs`; consumer of `crate::collective::allreduce` on the NoShard path |
25//! | REQ-7 (`HybridShard` intra-/inter-node strategy) | SHIPPED | `ShardingStrategy::HybridShard` arms in `fsdp.rs` build `Arc<SubBackend>` pair; consumer of `crate::backend::SubBackend::new` and `crate::collective::{reduce_scatter, allreduce}` |
26//! | REQ-8 (`forward` all-gather + run inner) | SHIPPED | `pub fn forward` in `fsdp.rs`; consumer of `crate::collective::all_gather` (FullShard / HybridShard); surfaced via `lib.rs` re-export |
27//! | REQ-9 (`prefetch_forward_params` + `has_pending_prefetch`) | SHIPPED | `pub fn prefetch_forward_params` + `pub fn has_pending_prefetch` in `fsdp.rs`; consumer of `crate::async_collective::async_all_gather`; `forward` joins via `handle.wait()?` (same file) |
28//! | REQ-10 (`sync_gradients` reduce-scatter / allreduce) | SHIPPED | `pub fn sync_gradients` in `fsdp.rs`; consumer of `crate::collective::{reduce_scatter, allreduce}`; surfaced via `lib.rs` re-export |
29//! | REQ-11 (`broadcast_updated_params` for ZeRO-2) | SHIPPED | `pub fn broadcast_updated_params` in `fsdp.rs`; consumer of `crate::collective::all_gather`; surfaced via `lib.rs` re-export |
30//! | REQ-12 (`update_shards` flat-buffer optimizer hook) | SHIPPED | `pub fn update_shards` in `fsdp.rs`; consumer via `lib.rs` re-export for downstream optimizers that produce flat parameter buffers |
31//! | REQ-13 (accessors) | SHIPPED | `pub fn strategy` / `module` / `module_mut` / `into_inner` / `backend` in `fsdp.rs`; consumer via `lib.rs` re-export — the user-facing accessor surface |
32
33use std::sync::Arc;
34
35use ferrotorch_core::storage::TensorStorage;
36use ferrotorch_core::{FerrotorchResult, Float, Tensor};
37use ferrotorch_nn::{Module, Parameter};
38
39use crate::async_collective::{PendingCollective, async_all_gather};
40use crate::backend::{Backend, SubBackend};
41use crate::collective::{ReduceOp, all_gather, allreduce, reduce_scatter};
42
/// How [`FSDP`] distributes parameters, gradients, and optimizer state
/// across ranks. Mirrors PyTorch's `ShardingStrategy`.
///
/// | Variant       | Parameters          | Gradients / optimizer state                   | Equivalent |
/// |---------------|---------------------|-----------------------------------------------|------------|
/// | `FullShard`   | sharded             | sharded                                       | ZeRO-3     |
/// | `ShardGradOp` | replicated          | sharded                                       | ZeRO-2     |
/// | `NoShard`     | replicated          | replicated (gradients allreduced)             | DDP        |
/// | `HybridShard` | sharded within node | sharded within node, replicated across nodes  | —          |
///
/// Notes:
/// - `FullShard` uses the least memory and the most communication.
/// - `ShardGradOp` skips the parameter all-gather in forward. After
///   `optimizer.step()`, call [`FSDP::broadcast_updated_params`] to re-sync
///   the updated parameter shards to every rank.
/// - `NoShard` lets the FSDP wrapper stand in for DDP, e.g. while
///   debugging or for single-node experiments.
/// - `HybridShard` derives two [`SubBackend`]s from the global backend.
///   Intra-node `all_gather` / `reduce_scatter` and inter-node `allreduce`
///   run on them independently. CL-327.
///
/// CL-372.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state (ZeRO-3 / full
    /// FSDP). This is the default and matches the behavior from before
    /// CL-372.
    #[default]
    FullShard,
    /// Keep parameters replicated; shard only gradients and optimizer
    /// state (ZeRO-2).
    ShardGradOp,
    /// No sharding (ZeRO-0 / DDP equivalent). Gradients are allreduced.
    NoShard,
    /// Shard within a node (intra-node FullShard) and replicate across
    /// nodes (inter-node DDP). CL-327.
    HybridShard {
        /// Number of ranks per node (i.e., local GPUs). The total
        /// `world_size` must be a positive multiple of this value.
        intra_node_size: usize,
    },
}
84
85/// Fully Sharded Data Parallel module wrapper.
86///
87/// Wraps an inner [`Module`] and shards each parameter across ranks so that
88/// each rank only stores `1 / world_size` of the full parameter tensor.
89///
90/// # Forward pass
91///
92/// Before calling the inner module's `forward()`, FSDP all-gathers each
93/// shard to reconstruct the full parameter tensor and installs it into the
94/// module. The full-parameter tensors are stored in [`full_params`] so
95/// that backward can accumulate gradients on them.
96///
97/// # Gradient synchronization
98///
99/// After `backward()`, call [`sync_gradients`] to:
100/// 1. Read gradients from the full-parameter tensors stored during forward.
101/// 2. Reduce-scatter the full gradients so each rank gets only its shard
102/// portion of the gradient.
103/// 3. Set each shard parameter's gradient from the reduce-scattered result.
104///
105/// # Example
106///
107/// ```ignore
108/// let mut fsdp = FSDP::new(model, backend)?;
109///
110/// loop {
111/// let output = fsdp.forward(&input)?;
112/// let loss = criterion.forward(&output, &target)?;
113/// ferrotorch_core::backward(&loss)?;
114/// fsdp.sync_gradients()?;
115/// optimizer.step()?;
116/// optimizer.zero_grad()?;
117/// }
118/// ```
119pub struct FSDP<M: Module<T>, T: Float> {
120 module: M,
121 backend: Arc<dyn Backend>,
122 /// Active sharding strategy. Drives the behavior of `new`,
123 /// `forward`, `sync_gradients`, and `broadcast_updated_params`.
124 strategy: ShardingStrategy,
125 /// Original full-parameter shapes before sharding.
126 original_shapes: Vec<Vec<usize>>,
127 /// Full-param tensors from the last forward pass, kept alive so
128 /// backward can accumulate gradients on them.
129 full_params: Vec<Tensor<T>>,
130 /// Pending async all-gather handles produced by
131 /// [`FSDP::prefetch_forward_params`]. One entry per parameter, in
132 /// the same order as [`Module::parameters_mut`]. `None` means no
133 /// prefetch is in flight and `forward()` will use the synchronous
134 /// all-gather path. CL-373.
135 pending_prefetch: Option<Vec<PendingCollective<T>>>,
136 /// Intra-node subgroup used by [`ShardingStrategy::HybridShard`].
137 /// `None` for all other strategies. CL-327.
138 intra_node_group: Option<Arc<SubBackend>>,
139 /// Inter-node subgroup used by [`ShardingStrategy::HybridShard`].
140 /// `None` for all other strategies. CL-327.
141 inter_node_group: Option<Arc<SubBackend>>,
142 _marker: std::marker::PhantomData<T>,
143}
144
145impl<M: Module<T>, T: Float> FSDP<M, T> {
146 /// Wrap a module for fully-sharded data-parallel training.
147 ///
148 /// Each parameter is split evenly across `world_size` ranks. This rank
149 /// keeps only its shard (the `rank`-th chunk). The original parameter
150 /// shapes are recorded for reconstruction during forward.
151 ///
152 /// # Panics
153 ///
154 /// Panics if any parameter's element count is not evenly divisible by
155 /// `world_size`.
156 pub fn new(module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
157 Self::new_with_strategy(module, backend, ShardingStrategy::FullShard)
158 }
159
160 /// Wrap a module for data-parallel training with a specific
161 /// [`ShardingStrategy`].
162 ///
163 /// - `FullShard` — shard parameters, gradients, and optimizer state
164 /// (the classic FSDP / ZeRO-3 behavior; identical to [`new`]).
165 /// - `ShardGradOp` — keep parameters replicated on every rank and
166 /// only shard gradients + optimizer state (ZeRO-2). After
167 /// calling the optimizer step on the shard gradients, the caller
168 /// must call [`broadcast_updated_params`] to re-sync the updated
169 /// parameter shards back to every rank. CL-372.
170 /// - `NoShard` — no sharding (ZeRO-0 / DDP equivalent). Gradients
171 /// are allreduced across ranks in `sync_gradients` and all ranks
172 /// update the full parameters locally.
173 pub fn new_with_strategy(
174 mut module: M,
175 backend: Arc<dyn Backend>,
176 strategy: ShardingStrategy,
177 ) -> FerrotorchResult<Self> {
178 let rank = backend.rank();
179 let world_size = backend.world_size();
180 let mut original_shapes = Vec::new();
181
182 // For HybridShard, build the intra- and inter-node subgroups
183 // upfront so we can use the intra_group's world_size to shard
184 // parameters below.
185 let (intra_node_group, inter_node_group) = match strategy {
186 ShardingStrategy::HybridShard { intra_node_size } => {
187 if intra_node_size == 0 || world_size % intra_node_size != 0 {
188 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
189 message: format!(
190 "HybridShard: world_size={world_size} must be a positive multiple of intra_node_size={intra_node_size}"
191 ),
192 });
193 }
194
195 // This node's row index in the (inter_size x intra_size)
196 // rank grid.
197 let node_idx = rank / intra_node_size;
198 let local_idx = rank % intra_node_size;
199
200 // Intra-node members: contiguous block of `intra_node_size`
201 // ranks starting at `node_idx * intra_node_size`.
202 let intra_members: Vec<usize> = (node_idx * intra_node_size
203 ..node_idx * intra_node_size + intra_node_size)
204 .collect();
205 let intra = Arc::new(SubBackend::new(Arc::clone(&backend), intra_members)?);
206
207 // Inter-node members: ranks with the same local_idx across
208 // every node (i.e., ranks stride-selected by
209 // `intra_node_size`).
210 let inter_members: Vec<usize> = (0..world_size / intra_node_size)
211 .map(|n| n * intra_node_size + local_idx)
212 .collect();
213 let inter = Arc::new(SubBackend::new(Arc::clone(&backend), inter_members)?);
214
215 (Some(intra), Some(inter))
216 }
217 _ => (None, None),
218 };
219
220 {
221 let params = module.parameters_mut();
222 for param in params {
223 let tensor = param.tensor();
224 let shape = tensor.shape().to_vec();
225 original_shapes.push(shape);
226
227 match strategy {
228 ShardingStrategy::FullShard => {
229 let numel = tensor.numel();
230 assert!(
231 numel % world_size == 0,
232 "FSDP: parameter with {numel} elements is not evenly divisible by world_size {world_size}"
233 );
234 let data = tensor.data_vec()?;
235 let chunk_size = numel / world_size;
236 let start = rank * chunk_size;
237 let end = start + chunk_size;
238 let shard_data = data[start..end].to_vec();
239 let shard_tensor = Tensor::from_storage(
240 TensorStorage::cpu(shard_data),
241 vec![chunk_size],
242 true,
243 )?;
244 *param = Parameter::new(shard_tensor);
245 }
246 ShardingStrategy::HybridShard { .. } => {
247 // Shard within the node only. Each rank keeps
248 // `1 / intra_node_size` of each parameter.
249 //
250 // INVARIANT: `intra_node_group` is `Some(_)` for every
251 // `HybridShard { .. }` strategy. The local binding was
252 // assigned `Some(intra)` ~50 lines above (in the
253 // `match strategy` that opens this fn) on exactly the
254 // `HybridShard { .. }` arm, and was never reassigned
255 // before reaching this match. Since this arm is only
256 // entered under the same `HybridShard { .. }` pattern,
257 // `as_ref()` here cannot observe `None` — `expect()`
258 // is provably unreachable. Category C per
259 // rust-fix-discipline.
260 let intra = intra_node_group.as_ref().expect(
261 "FSDP::new_with_strategy: intra_node_group is Some \
262 for HybridShard (set ~50 lines above on the same \
263 match arm; never reassigned)",
264 );
265 let intra_size = intra.world_size();
266 let intra_rank = intra.rank();
267 let numel = tensor.numel();
268 assert!(
269 numel % intra_size == 0,
270 "FSDP HybridShard: parameter with {numel} elements is not evenly divisible by intra_node_size {intra_size}"
271 );
272 let data = tensor.data_vec()?;
273 let chunk_size = numel / intra_size;
274 let start = intra_rank * chunk_size;
275 let end = start + chunk_size;
276 let shard_data = data[start..end].to_vec();
277 let shard_tensor = Tensor::from_storage(
278 TensorStorage::cpu(shard_data),
279 vec![chunk_size],
280 true,
281 )?;
282 *param = Parameter::new(shard_tensor);
283 }
284 ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
285 // Keep the full parameter on this rank; only
286 // gradients (and optimizer state, as an
287 // external concern) are sharded for ShardGradOp.
288 // NoShard is a plain DDP-style replication.
289 //
290 // For ShardGradOp, each rank still needs to know
291 // which slice of the flat parameter is "its"
292 // shard for the optimizer step. That's derived
293 // at grad-sync time from world_size + rank.
294 }
295 }
296 }
297 }
298
299 Ok(Self {
300 module,
301 backend,
302 strategy,
303 original_shapes,
304 full_params: Vec::new(),
305 pending_prefetch: None,
306 intra_node_group,
307 inter_node_group,
308 _marker: std::marker::PhantomData,
309 })
310 }
311
312 /// Return the active sharding strategy.
313 pub fn strategy(&self) -> ShardingStrategy {
314 self.strategy
315 }
316
317 /// Kick off asynchronous all-gathers for every parameter so the
318 /// next [`forward`](Self::forward) call consumes the pre-gathered
319 /// tensors instead of blocking on a fresh all-gather.
320 ///
321 /// This is FSDP's equivalent of PyTorch's backward prefetch:
322 /// communication for layer N+1 (or the next forward pass) overlaps
323 /// with compute for layer N. The caller should insert local
324 /// compute (e.g., the previous layer's backward, or input
325 /// preprocessing) between `prefetch_forward_params` and `forward`
326 /// to realize the overlap.
327 ///
328 /// Only valid for [`ShardingStrategy::FullShard`] — the other
329 /// strategies keep parameters replicated on every rank so there's
330 /// nothing to all-gather. Calling this on a non-`FullShard` FSDP
331 /// returns an `InvalidArgument` error.
332 ///
333 /// # Invariant
334 ///
335 /// Exactly one `prefetch_forward_params` → `forward` pair should be
336 /// in flight at any time on a given FSDP instance. Calling
337 /// `prefetch_forward_params` twice in a row (without an intervening
338 /// `forward`) returns an `InvalidArgument` error.
339 ///
340 /// CL-373.
341 pub fn prefetch_forward_params(&mut self) -> FerrotorchResult<()> {
342 if self.strategy != ShardingStrategy::FullShard {
343 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
344 message: format!(
345 "FSDP::prefetch_forward_params: prefetch is only valid for FullShard, got {:?}",
346 self.strategy
347 ),
348 });
349 }
350 if self.pending_prefetch.is_some() {
351 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
352 message: "FSDP::prefetch_forward_params called twice without intervening forward()"
353 .into(),
354 });
355 }
356
357 let mut handles = Vec::new();
358 {
359 let params = self.module.parameters();
360 for param in params {
361 let shard = param.tensor().clone();
362 let h = async_all_gather(shard, Arc::clone(&self.backend));
363 handles.push(h);
364 }
365 }
366 self.pending_prefetch = Some(handles);
367 Ok(())
368 }
369
370 /// True if a prefetch is currently pending. Primarily useful for
371 /// tests and diagnostics.
372 pub fn has_pending_prefetch(&self) -> bool {
373 self.pending_prefetch.is_some()
374 }
375
376 /// Immutable access to the inner module.
377 pub fn module(&self) -> &M {
378 &self.module
379 }
380
381 /// Mutable access to the inner module.
382 pub fn module_mut(&mut self) -> &mut M {
383 &mut self.module
384 }
385
386 /// Consume the wrapper and return the inner module.
387 pub fn into_inner(self) -> M {
388 self.module
389 }
390
391 /// The backend used for communication.
392 pub fn backend(&self) -> &Arc<dyn Backend> {
393 &self.backend
394 }
395
396 /// Reconstruct full parameters from shards across all ranks and run
397 /// the inner module's forward pass.
398 ///
399 /// The all-gathered full-parameter tensors are stored in `self.full_params`
400 /// so their gradients can be read after backward.
401 ///
402 /// For `ShardGradOp` and `NoShard` strategies, parameters are already
403 /// full on every rank, so no all-gather happens and `full_params` is
404 /// populated from the current parameter tensors directly.
405 ///
406 /// If [`prefetch_forward_params`](Self::prefetch_forward_params) was
407 /// called earlier, the pending async all-gather handles are consumed
408 /// here instead of running the synchronous all_gather — this is how
409 /// FSDP hides all-gather latency behind whatever local compute
410 /// happened between `prefetch_forward_params` and `forward`.
411 pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
412 let world_size = self.backend.world_size();
413 self.full_params.clear();
414
415 // Grab any pending prefetch handles. They are consumed here so
416 // a subsequent forward() without a matching prefetch reverts to
417 // synchronous all_gather.
418 let mut pending = self.pending_prefetch.take();
419
420 match self.strategy {
421 ShardingStrategy::FullShard => {
422 if let Some(ref p) = pending {
423 if p.len() != self.module.parameters().len() {
424 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
425 message: format!(
426 "FSDP prefetch: have {} pending handles but module has {} parameters",
427 p.len(),
428 self.module.parameters().len(),
429 ),
430 });
431 }
432 }
433 // Pop handles from the front by reversing so we can pop from back.
434 if let Some(ref mut p) = pending {
435 p.reverse();
436 }
437
438 let params = self.module.parameters_mut();
439 for (i, param) in params.into_iter().enumerate() {
440 let orig_shape = &self.original_shapes[i];
441
442 // Get the gathered full tensor either from a pending
443 // async handle or via a fresh synchronous all_gather.
444 let full = if let Some(ref mut handles) = pending {
445 // Consume the last handle (original order via reverse).
446 let handle = handles.pop().ok_or_else(|| {
447 ferrotorch_core::FerrotorchError::InvalidArgument {
448 message: "FSDP prefetch: exhausted pending handles".into(),
449 }
450 })?;
451 handle.wait()?
452 } else {
453 let shard = param.tensor().clone();
454 if world_size == 1 {
455 shard
456 } else {
457 all_gather(&shard, self.backend.as_ref())?
458 }
459 };
460
461 // Reshape to the original parameter shape and enable grad.
462 let full = Tensor::from_storage(
463 TensorStorage::cpu(full.data_vec()?),
464 orig_shape.clone(),
465 true,
466 )?;
467
468 self.full_params.push(full.clone());
469
470 // Install the full parameter into the module for this forward pass.
471 *param = Parameter::new(full);
472 }
473 }
474 ShardingStrategy::HybridShard { .. } => {
475 // HybridShard: all-gather the shards within the node only.
476 // Prefetch is not supported for this strategy yet — the
477 // intra-node subgroup needs its own async path. CL-327.
478 if pending.is_some() {
479 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
480 message:
481 "FSDP prefetch_forward_params is not yet implemented for HybridShard"
482 .into(),
483 });
484 }
485
486 // INVARIANT: `self.intra_node_group` is `Some(_)` whenever
487 // `self.strategy` is `HybridShard { .. }`. Both fields are
488 // private and only assigned in `new_with_strategy`, which sets
489 // them as a coupled pair on the `HybridShard { .. }` arm. The
490 // strategy field is never reassigned after construction (no
491 // setter, no internal mutation). Since this arm is gated on
492 // the strategy, `as_ref()` here cannot observe `None`.
493 // Category C per rust-fix-discipline.
494 let intra = self.intra_node_group.as_ref().expect(
495 "FSDP::forward (HybridShard): intra_node_group is Some for \
496 HybridShard strategy (paired with strategy in \
497 new_with_strategy; never reassigned)",
498 );
499 let intra_ref: &dyn Backend = &**intra;
500 let intra_size = intra.world_size();
501
502 let params = self.module.parameters_mut();
503 for (i, param) in params.into_iter().enumerate() {
504 let orig_shape = &self.original_shapes[i];
505
506 let shard = param.tensor().clone();
507 let full = if intra_size == 1 {
508 shard
509 } else {
510 all_gather(&shard, intra_ref)?
511 };
512
513 let full = Tensor::from_storage(
514 TensorStorage::cpu(full.data_vec()?),
515 orig_shape.clone(),
516 true,
517 )?;
518
519 self.full_params.push(full.clone());
520 *param = Parameter::new(full);
521 }
522 }
523 ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
524 // Parameters are already full on every rank. Prefetch
525 // is meaningless for these strategies — surface any
526 // accidental misuse.
527 if pending.is_some() {
528 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
529 message: "FSDP prefetch_forward_params is only valid for FullShard; \
530 use ShardingStrategy::FullShard or don't prefetch"
531 .into(),
532 });
533 }
534 // We still need to wrap them in requires_grad=true leaves
535 // so the autograd graph can flow through the forward
536 // pass, and stash them in full_params so sync_gradients
537 // can read the backward-accumulated gradients.
538 let params = self.module.parameters_mut();
539 for param in params.into_iter() {
540 let t = param.tensor().clone();
541 let data = t.data_vec()?;
542 let shape = t.shape().to_vec();
543 let full = Tensor::from_storage(TensorStorage::cpu(data), shape, true)?;
544 self.full_params.push(full.clone());
545 *param = Parameter::new(full);
546 }
547 }
548 }
549
550 let output = self.module.forward(input)?;
551
552 // After forward, restore shard parameters (FullShard /
553 // HybridShard) or leave the full params in place
554 // (ShardGradOp / NoShard).
555 match self.strategy {
556 ShardingStrategy::FullShard => self.restore_shards()?,
557 ShardingStrategy::HybridShard { .. } => self.restore_hybrid_shards()?,
558 ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
559 // Nothing to do: params are already full on every rank.
560 }
561 }
562
563 Ok(output)
564 }
565
566 /// Replace the current full parameter tensors with this rank's
567 /// intra-node shards. Used by [`ShardingStrategy::HybridShard`] after
568 /// forward completes. CL-327.
569 fn restore_hybrid_shards(&mut self) -> FerrotorchResult<()> {
570 // INVARIANT: `restore_hybrid_shards` is only reached from `forward()`
571 // under `ShardingStrategy::HybridShard { .. }`. `self.intra_node_group`
572 // is `Some(_)` whenever `self.strategy` is `HybridShard` (paired in
573 // `new_with_strategy`; both fields private; strategy never reassigned).
574 // Category C per rust-fix-discipline.
575 let intra = self
576 .intra_node_group
577 .as_ref()
578 .expect(
579 "FSDP::restore_hybrid_shards: intra_node_group is Some for \
580 HybridShard (only callsite is forward() under HybridShard arm)",
581 )
582 .clone();
583 let intra_size = intra.world_size();
584 let intra_rank = intra.rank();
585
586 let params = self.module.parameters_mut();
587 for (i, param) in params.into_iter().enumerate() {
588 let tensor = param.tensor();
589 let data = tensor.data_vec()?;
590 let numel = data.len();
591 let chunk_size = numel / intra_size;
592 let start = intra_rank * chunk_size;
593 let end = start + chunk_size;
594 let shard_data = data[start..end].to_vec();
595
596 let shard_tensor =
597 Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
598 *param = Parameter::new(shard_tensor);
599
600 // Preserve the original shape metadata.
601 let _ = &self.original_shapes[i];
602 }
603
604 Ok(())
605 }
606
607 /// Replace full parameters with their local shards to free memory.
608 fn restore_shards(&mut self) -> FerrotorchResult<()> {
609 let rank = self.backend.rank();
610 let world_size = self.backend.world_size();
611
612 let params = self.module.parameters_mut();
613 for (i, param) in params.into_iter().enumerate() {
614 let tensor = param.tensor();
615 let data = tensor.data_vec()?;
616 let numel = data.len();
617 let chunk_size = numel / world_size;
618 let start = rank * chunk_size;
619 let end = start + chunk_size;
620 let shard_data = data[start..end].to_vec();
621
622 let shard_tensor =
623 Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
624 *param = Parameter::new(shard_tensor);
625
626 // Preserve the original shape metadata.
627 let _ = &self.original_shapes[i];
628 }
629
630 Ok(())
631 }
632
633 /// Reduce-scatter gradients from the full-parameter tensors stored
634 /// during forward, then set each shard parameter's gradient.
635 ///
636 /// Call this after `backward()` and before `optimizer.step()`.
637 ///
638 /// # How it works
639 ///
640 /// 1. For each parameter, read the gradient from the full-param tensor
641 /// that was used during forward (stored in `self.full_params`).
642 /// 2. Reduce-scatter the full gradient across ranks (mean reduction) so
643 /// each rank gets only its shard portion.
644 /// 3. Set the shard parameter's `.grad()` to the reduce-scattered result.
645 ///
646 /// Using reduce-scatter (not allreduce) is correct for FSDP because each
647 /// rank only needs its own shard of the gradient to update its shard of
648 /// the parameter.
649 pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
650 let rank = self.backend.rank();
651 let world_size = self.backend.world_size();
652 let params = self.module.parameters_mut();
653
654 if self.full_params.len() != params.len() {
655 return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
656 message: format!(
657 "FSDP sync_gradients: expected {} full_params but have {}. \
658 Was forward() called before backward()?",
659 params.len(),
660 self.full_params.len(),
661 ),
662 });
663 }
664
665 for (i, param) in params.into_iter().enumerate() {
666 let full_param = &self.full_params[i];
667
668 // Read the gradient from the full-parameter tensor. If no
669 // gradient was computed (e.g., parameter was unused in forward),
670 // use zeros so all ranks exchange buffers of the same size.
671 let grad = full_param.grad()?;
672 let full_grad = match grad {
673 Some(g) => g,
674 None => {
675 let numel = full_param.numel();
676 Tensor::from_storage(
677 TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); numel]),
678 full_param.shape().to_vec(),
679 false,
680 )?
681 }
682 };
683
684 // Flatten for reduction ops.
685 let grad_data = full_grad.data_vec()?;
686 let flat_grad = Tensor::from_storage(
687 TensorStorage::cpu(grad_data),
688 vec![full_grad.numel()],
689 false,
690 )?;
691
692 match self.strategy {
693 ShardingStrategy::FullShard => {
694 // Reduce-scatter: each rank gets its shard of the
695 // averaged gradient. Parameter is already the
696 // shard here (installed by restore_shards after
697 // forward), so set_grad lines up.
698 let shard_grad = if world_size == 1 {
699 flat_grad
700 } else {
701 reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
702 };
703 param.tensor().set_grad(Some(shard_grad))?;
704 }
705 ShardingStrategy::ShardGradOp => {
706 // ZeRO-2: reduce-scatter the flat gradient into a
707 // per-rank slice, then set the parameter's .grad()
708 // to a tensor shaped like the full parameter but
709 // with only this rank's shard positions populated
710 // and the rest zeroed. The optimizer's update at
711 // the non-shard positions becomes a no-op, so
712 // each rank effectively updates only its shard
713 // slice. After optimizer.step, the caller must
714 // invoke broadcast_updated_params to re-sync.
715 let numel = flat_grad.numel();
716 assert!(
717 numel % world_size == 0,
718 "FSDP ShardGradOp: parameter with {numel} elements is not evenly \
719 divisible by world_size {world_size}"
720 );
721 let shard_grad_flat = if world_size == 1 {
722 flat_grad
723 } else {
724 reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
725 };
726 let chunk_size = numel / world_size;
727 let shard_data = shard_grad_flat.data_vec()?;
728 // Pad with zeros at non-shard positions so the
729 // gradient tensor matches the full parameter shape.
730 let mut padded = vec![<T as num_traits::Zero>::zero(); numel];
731 let start = rank * chunk_size;
732 padded[start..start + chunk_size].copy_from_slice(&shard_data);
733 let padded_grad = Tensor::from_storage(
734 TensorStorage::cpu(padded),
735 full_param.shape().to_vec(),
736 false,
737 )?;
738 param.tensor().set_grad(Some(padded_grad))?;
739 }
740 ShardingStrategy::HybridShard { .. } => {
741 // HybridShard: reduce-scatter within the node to
742 // get this rank's intra-node shard of the mean
743 // gradient, then allreduce the shard across nodes
744 // so every replica of this intra-rank has the
745 // same gradient. Parameter is already the
746 // intra-node shard (installed by
747 // restore_hybrid_shards after forward).
748 //
749 // INVARIANT: `intra_node_group` and `inter_node_group`
750 // are both `Some(_)` whenever `self.strategy` is
751 // `HybridShard { .. }`. The two fields are assigned as a
752 // coupled pair in `new_with_strategy` on the same arm
753 // that this match enters; all three are private and
754 // never reassigned. Both `expect()`s are provably
755 // unreachable. Category C per rust-fix-discipline.
756 let intra = self.intra_node_group.as_ref().expect(
757 "FSDP::sync_gradients (HybridShard): intra_node_group \
758 is Some for HybridShard (paired with strategy in \
759 new_with_strategy; never reassigned)",
760 );
761 let inter = self.inter_node_group.as_ref().expect(
762 "FSDP::sync_gradients (HybridShard): inter_node_group \
763 is Some for HybridShard (paired with strategy in \
764 new_with_strategy; never reassigned)",
765 );
766 let intra_ref: &dyn Backend = &**intra;
767 let inter_ref: &dyn Backend = &**inter;
768 let intra_size = intra.world_size();
769 let inter_size = inter.world_size();
770
771 // Step 1: intra-node reduce_scatter (mean).
772 let intra_shard = if intra_size == 1 {
773 flat_grad
774 } else {
775 reduce_scatter(&flat_grad, intra_ref, ReduceOp::Mean)?
776 };
777
778 // Step 2: inter-node allreduce (mean across
779 // replicas).
780 let replicated = if inter_size == 1 {
781 intra_shard
782 } else {
783 allreduce(&intra_shard, inter_ref, ReduceOp::Mean)?
784 };
785
786 param.tensor().set_grad(Some(replicated))?;
787 }
788 ShardingStrategy::NoShard => {
789 // Plain DDP: allreduce the full gradient so every
790 // rank has the same averaged gradient, then set
791 // it on the full parameter.
792 let reduced = if world_size == 1 {
793 flat_grad
794 } else {
795 allreduce(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
796 };
797 let reduced_full = Tensor::from_storage(
798 TensorStorage::cpu(reduced.data_vec()?),
799 full_param.shape().to_vec(),
800 false,
801 )?;
802 param.tensor().set_grad(Some(reduced_full))?;
803 }
804 }
805 }
806
807 // Clear full_params to free memory now that gradients have been read.
808 self.full_params.clear();
809
810 Ok(())
811 }
812
813 /// For `ShardGradOp`: after `optimizer.step()`, each rank has
814 /// applied the update to its own shard of the full parameter
815 /// (because `sync_gradients` zeroed the non-shard positions of the
816 /// gradient). This method re-syncs the parameter tensors so every
817 /// rank has the fully updated parameter, by summing contributions
818 /// via an allreduce: each rank contributes its updated shard, zero
819 /// elsewhere; the sum across ranks is the full updated parameter.
820 ///
821 /// More precisely, this method reconstructs the full parameter as
822 /// an allgather of per-rank shards. It is a no-op for `FullShard`
823 /// and `NoShard` strategies (they already have consistent
824 /// parameters after step).
825 ///
826 /// Call this AFTER `optimizer.step()` and BEFORE the next
827 /// `forward()`. CL-372.
828 pub fn broadcast_updated_params(&mut self) -> FerrotorchResult<()> {
829 if self.strategy != ShardingStrategy::ShardGradOp {
830 // Nothing to do for FullShard (already shard-local) or
831 // NoShard (all ranks already have the same params).
832 return Ok(());
833 }
834
835 let rank = self.backend.rank();
836 let world_size = self.backend.world_size();
837 if world_size == 1 {
838 return Ok(());
839 }
840
841 let params = self.module.parameters_mut();
842 for param in params {
843 // Extract this rank's shard from the updated full parameter.
844 let full = param.tensor();
845 let full_data = full.data_vec()?;
846 let numel = full_data.len();
847 assert!(
848 numel % world_size == 0,
849 "FSDP broadcast_updated_params: parameter with {numel} elements is not evenly \
850 divisible by world_size {world_size}"
851 );
852 let chunk_size = numel / world_size;
853 let start = rank * chunk_size;
854 let end = start + chunk_size;
855 let shard = full_data[start..end].to_vec();
856 let shard_tensor =
857 Tensor::from_storage(TensorStorage::cpu(shard), vec![chunk_size], false)?;
858
859 // All-gather across ranks to get the full updated parameter.
860 let gathered = all_gather(&shard_tensor, self.backend.as_ref())?;
861 let full_shape = full.shape().to_vec();
862 let new_full =
863 Tensor::from_storage(TensorStorage::cpu(gathered.data_vec()?), full_shape, true)?;
864 *param = Parameter::new(new_full);
865 }
866 Ok(())
867 }
868
869 /// Update shard parameters from a flat data slice.
870 ///
871 /// This is used by optimizers that produce a flat parameter buffer.
872 /// The slice must have exactly the number of elements expected for
873 /// this rank's shards.
874 pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
875 let params = self.module.parameters_mut();
876 let total_shard_numel: usize = params.iter().map(|p| p.tensor().numel()).sum();
877
878 assert!(
879 flat_data.len() == total_shard_numel,
880 "FSDP update_shards: expected {} elements but got {}",
881 total_shard_numel,
882 flat_data.len(),
883 );
884
885 let mut offset = 0;
886 for param in params {
887 let numel = param.tensor().numel();
888 let shard_data = flat_data[offset..offset + numel].to_vec();
889 let shard_tensor = Tensor::from_storage(
890 TensorStorage::cpu(shard_data),
891 param.tensor().shape().to_vec(),
892 true,
893 )?;
894 *param = Parameter::new(shard_tensor);
895 offset += numel;
896 }
897
898 Ok(())
899 }
900}
901
902// FSDP does NOT implement Module<T> because forward() requires &mut self
903// (to store full_params). Callers must use fsdp.forward() directly.
904
905#[cfg(test)]
906mod tests {
907 use super::*;
908 use crate::backend::SimulatedBackend;
909 use ferrotorch_core::storage::TensorStorage;
910 use ferrotorch_core::{FerrotorchResult, Tensor};
911 use ferrotorch_nn::Parameter;
912 use std::thread;
913
914 /// Minimal module with one parameter for testing FSDP.
915 struct TestModule<T: Float> {
916 weight: Parameter<T>,
917 training: bool,
918 }
919
920 impl<T: Float> TestModule<T> {
921 fn new(data: &[T]) -> FerrotorchResult<Self> {
922 Ok(Self {
923 weight: Parameter::from_slice(data, &[data.len()])?,
924 training: true,
925 })
926 }
927 }
928
929 impl<T: Float> Module<T> for TestModule<T> {
930 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
931 // Simple forward: multiply input by weight sum (produces a scalar
932 // that depends on all weight elements).
933 let w_data = self.weight.tensor().data_vec()?;
934 let w_sum: T = w_data
935 .iter()
936 .copied()
937 .fold(<T as num_traits::Zero>::zero(), |a, b| a + b);
938 let i_data = input.data_vec()?;
939 let out: Vec<T> = i_data.iter().map(|&x| x * w_sum).collect();
940 Tensor::from_storage(TensorStorage::cpu(out), input.shape().to_vec(), false)
941 }
942
943 fn parameters(&self) -> Vec<&Parameter<T>> {
944 vec![&self.weight]
945 }
946
947 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
948 vec![&mut self.weight]
949 }
950
951 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
952 vec![("weight".into(), &self.weight)]
953 }
954
955 fn train(&mut self) {
956 self.training = true;
957 }
958
959 fn eval(&mut self) {
960 self.training = false;
961 }
962
963 fn is_training(&self) -> bool {
964 self.training
965 }
966 }
967
968 #[test]
969 fn test_fsdp_sharding() {
970 // 2 ranks, parameter [10, 20, 30, 40].
971 // Rank 0 gets [10, 20], Rank 1 gets [30, 40].
972 let group = SimulatedBackend::create_group(2).unwrap();
973 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
974
975 let handles: Vec<_> = arcs
976 .iter()
977 .cloned()
978 .map(|b| {
979 thread::spawn(move || {
980 let rank = b.rank();
981 let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
982 let fsdp = FSDP::new(model, b).unwrap();
983
984 let shard = fsdp.module().weight.tensor().data_vec().unwrap();
985 (rank, shard)
986 })
987 })
988 .collect();
989
990 for h in handles {
991 let (rank, shard) = h.join().unwrap();
992 if rank == 0 {
993 assert_eq!(shard, &[10.0, 20.0]);
994 } else {
995 assert_eq!(shard, &[30.0, 40.0]);
996 }
997 }
998 }
999
1000 #[test]
1001 fn test_fsdp_shard_requires_grad() {
1002 // Shard parameters must have requires_grad=true.
1003 let group = SimulatedBackend::create_group(2).unwrap();
1004 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1005
1006 let handles: Vec<_> = arcs
1007 .iter()
1008 .cloned()
1009 .map(|b| {
1010 thread::spawn(move || {
1011 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1012 let fsdp = FSDP::new(model, b).unwrap();
1013 fsdp.module().weight.tensor().requires_grad()
1014 })
1015 })
1016 .collect();
1017
1018 for h in handles {
1019 assert!(h.join().unwrap(), "shard must have requires_grad=true");
1020 }
1021 }
1022
1023 #[test]
1024 fn test_fsdp_forward_restores_shards() {
1025 // After forward(), parameters should be back to shard size.
1026 let group = SimulatedBackend::create_group(2).unwrap();
1027 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1028
1029 let handles: Vec<_> = arcs
1030 .iter()
1031 .cloned()
1032 .map(|b| {
1033 thread::spawn(move || {
1034 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1035 let mut fsdp = FSDP::new(model, b).unwrap();
1036
1037 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1038 let _output = fsdp.forward(&input).unwrap();
1039
1040 // After forward, shard should be size 2 (4 / 2 ranks).
1041 let shard = fsdp.module().weight.tensor();
1042 assert_eq!(shard.numel(), 2);
1043 assert!(shard.requires_grad());
1044 })
1045 })
1046 .collect();
1047
1048 for h in handles {
1049 h.join().unwrap();
1050 }
1051 }
1052
1053 #[test]
1054 fn test_fsdp_forward_produces_correct_output() {
1055 // 2 ranks, param [1, 2, 3, 4], weight_sum = 10.
1056 // Input [2.0] -> output should be [20.0] on all ranks.
1057 let group = SimulatedBackend::create_group(2).unwrap();
1058 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1059
1060 let handles: Vec<_> = arcs
1061 .iter()
1062 .cloned()
1063 .map(|b| {
1064 thread::spawn(move || {
1065 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1066 let mut fsdp = FSDP::new(model, b).unwrap();
1067
1068 let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1069 let output = fsdp.forward(&input).unwrap();
1070 let data = output.data_vec().unwrap();
1071 assert!(
1072 (data[0] - 20.0).abs() < 1e-6,
1073 "expected 20.0, got {}",
1074 data[0]
1075 );
1076 })
1077 })
1078 .collect();
1079
1080 for h in handles {
1081 h.join().unwrap();
1082 }
1083 }
1084
1085 #[test]
1086 fn test_fsdp_update_shards() {
1087 let group = SimulatedBackend::create_group(1).unwrap();
1088 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1089 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1090 let mut fsdp = FSDP::new(model, b).unwrap();
1091
1092 fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1093 let data = fsdp.module().weight.tensor().data_vec().unwrap();
1094 assert_eq!(data, &[10.0, 20.0, 30.0, 40.0]);
1095 }
1096
1097 #[test]
1098 #[should_panic(expected = "expected 4 elements but got 2")]
1099 fn test_fsdp_update_shards_size_validation() {
1100 let group = SimulatedBackend::create_group(1).unwrap();
1101 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1102 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1103 let mut fsdp = FSDP::new(model, b).unwrap();
1104
1105 // Wrong size: should panic.
1106 fsdp.update_shards(&[10.0, 20.0]).unwrap();
1107 }
1108
1109 #[test]
1110 fn test_fsdp_sync_gradients_single_rank() {
1111 // Single rank: sync_gradients should pass through the gradient
1112 // from the full param to the shard param.
1113 let group = SimulatedBackend::create_group(1).unwrap();
1114 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1115 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1116 let mut fsdp = FSDP::new(model, b).unwrap();
1117
1118 // Run forward to populate full_params.
1119 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1120 let _output = fsdp.forward(&input).unwrap();
1121
1122 // Manually set gradient on full_params (simulating backward).
1123 let grad = Tensor::from_storage(
1124 TensorStorage::cpu(vec![0.1f32, 0.2, 0.3, 0.4]),
1125 vec![4],
1126 false,
1127 )
1128 .unwrap();
1129 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1130
1131 fsdp.sync_gradients().unwrap();
1132
1133 // Shard param should now have the full gradient (single rank = no scatter).
1134 let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
1135 let data = shard_grad.data_vec().unwrap();
1136 assert_eq!(data, &[0.1, 0.2, 0.3, 0.4]);
1137 }
1138
1139 #[test]
1140 fn test_fsdp_shard_grad_op_keeps_full_params() {
1141 // ShardGradOp (ZeRO-2): params stay replicated on every rank.
1142 // Verify each rank has the full parameter after `new_with_strategy`.
1143 let group = SimulatedBackend::create_group(2).unwrap();
1144 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1145
1146 let handles: Vec<_> = arcs
1147 .iter()
1148 .cloned()
1149 .map(|b| {
1150 thread::spawn(move || {
1151 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1152 let fsdp =
1153 FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1154 assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
1155 fsdp.module().weight.tensor().data_vec().unwrap()
1156 })
1157 })
1158 .collect();
1159
1160 for h in handles {
1161 let data = h.join().unwrap();
1162 assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
1163 }
1164 }
1165
1166 #[test]
1167 // reason: sharded-grad zero-padding writes the exact bit pattern 0.0 to
1168 // off-shard slots (no arithmetic). The on-shard slots are tested with
1169 // an explicit epsilon (the abs/<1e-6 lines just above each assert_eq!),
1170 // because those go through reduce-scatter mean. The 0.0 == 0.0 sentinel
1171 // and ints {0,1,2,3,4} are bit-exact, so equality is the right check.
1172 #[allow(clippy::float_cmp)]
1173 fn test_fsdp_shard_grad_op_sync_gradients_multi_rank() {
1174 // ZeRO-2: two ranks, param [1,2,3,4], both ranks produce full grad
1175 // [1,2,3,4]. reduce_scatter(mean) gives rank 0 the slice [1,2] and
1176 // rank 1 the slice [3,4]. Each rank's .grad() is then padded so the
1177 // other-rank positions are zero, giving
1178 // rank 0 grad: [1,2,0,0]
1179 // rank 1 grad: [0,0,3,4]
1180 let group = SimulatedBackend::create_group(2).unwrap();
1181 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1182
1183 let handles: Vec<_> = arcs
1184 .iter()
1185 .cloned()
1186 .map(|b| {
1187 thread::spawn(move || {
1188 let rank = b.rank();
1189 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1190 let mut fsdp =
1191 FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1192
1193 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1194 let _output = fsdp.forward(&input).unwrap();
1195
1196 let grad = Tensor::from_storage(
1197 TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1198 vec![4],
1199 false,
1200 )
1201 .unwrap();
1202 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1203
1204 fsdp.sync_gradients().unwrap();
1205
1206 // Param remains full (not restored to shard) under
1207 // ShardGradOp, and .grad() is a tensor of the full
1208 // shape with only this rank's shard populated.
1209 let w = fsdp.module().weight.tensor();
1210 assert_eq!(w.numel(), 4, "ShardGradOp keeps params full");
1211 let g = w.grad().unwrap().unwrap();
1212 let gd = g.data_vec().unwrap();
1213 assert_eq!(gd.len(), 4, "grad should be full-shape");
1214 (rank, gd)
1215 })
1216 })
1217 .collect();
1218
1219 for h in handles {
1220 let (rank, gd) = h.join().unwrap();
1221 if rank == 0 {
1222 assert!((gd[0] - 1.0).abs() < 1e-6, "rank 0 [0]: {}", gd[0]);
1223 assert!((gd[1] - 2.0).abs() < 1e-6, "rank 0 [1]: {}", gd[1]);
1224 assert_eq!(gd[2], 0.0, "rank 0 [2] should be zero");
1225 assert_eq!(gd[3], 0.0, "rank 0 [3] should be zero");
1226 } else {
1227 assert_eq!(gd[0], 0.0, "rank 1 [0] should be zero");
1228 assert_eq!(gd[1], 0.0, "rank 1 [1] should be zero");
1229 assert!((gd[2] - 3.0).abs() < 1e-6, "rank 1 [2]: {}", gd[2]);
1230 assert!((gd[3] - 4.0).abs() < 1e-6, "rank 1 [3]: {}", gd[3]);
1231 }
1232 }
1233 }
1234
1235 #[test]
1236 fn test_fsdp_shard_grad_op_broadcast_updated_params() {
1237 // Full ZeRO-2 loop: each rank simulates an optimizer.step() that
1238 // applies a per-shard update (adds rank*10 to its slice), then
1239 // calls broadcast_updated_params. After that, every rank should
1240 // see the fully updated parameter.
1241 //
1242 // Rank 0 slice is [0,1]; rank 1 slice is [2,3].
1243 // Starting param: [1,2,3,4].
1244 // After per-rank update:
1245 // rank 0 local param: [1+10, 2+10, 3, 4] = [11, 12, 3, 4]
1246 // rank 1 local param: [1, 2, 3+20, 4+20] = [1, 2, 23, 24]
1247 // After broadcast_updated_params (allgather of rank-local shards):
1248 // both ranks: [11, 12, 23, 24]
1249 let group = SimulatedBackend::create_group(2).unwrap();
1250 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1251
1252 let handles: Vec<_> = arcs
1253 .iter()
1254 .cloned()
1255 .map(|b| {
1256 thread::spawn(move || {
1257 let rank = b.rank();
1258 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1259 let mut fsdp =
1260 FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1261
1262 // Simulate per-rank optimizer step: each rank overwrites
1263 // its shard slice, leaving the other slice untouched.
1264 let mut local = fsdp.module().weight.tensor().data_vec().unwrap();
1265 if rank == 0 {
1266 local[0] += 10.0;
1267 local[1] += 10.0;
1268 } else {
1269 local[2] += 20.0;
1270 local[3] += 20.0;
1271 }
1272 let new_param =
1273 Tensor::from_storage(TensorStorage::cpu(local), vec![4], true).unwrap();
1274 *fsdp.module.parameters_mut()[0] = Parameter::new(new_param);
1275
1276 // Re-sync.
1277 fsdp.broadcast_updated_params().unwrap();
1278
1279 fsdp.module().weight.tensor().data_vec().unwrap()
1280 })
1281 })
1282 .collect();
1283
1284 for h in handles {
1285 let data = h.join().unwrap();
1286 assert_eq!(data, &[11.0, 12.0, 23.0, 24.0]);
1287 }
1288 }
1289
1290 #[test]
1291 fn test_fsdp_no_shard_is_ddp_equivalent() {
1292 // NoShard (ZeRO-0 / DDP): each rank has the full parameter and
1293 // allreduce-averages gradients. Param [1,2,3,4]; both ranks set
1294 // identical grads [1,2,3,4]; after sync both ranks should see the
1295 // same averaged grad [1,2,3,4] (identity since both contributions
1296 // are equal).
1297 let group = SimulatedBackend::create_group(2).unwrap();
1298 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1299
1300 let handles: Vec<_> = arcs
1301 .iter()
1302 .cloned()
1303 .map(|b| {
1304 thread::spawn(move || {
1305 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1306 let mut fsdp =
1307 FSDP::new_with_strategy(model, b, ShardingStrategy::NoShard).unwrap();
1308 assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
1309
1310 // Params stay full.
1311 assert_eq!(fsdp.module().weight.tensor().numel(), 4);
1312
1313 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1314 let _output = fsdp.forward(&input).unwrap();
1315
1316 let grad = Tensor::from_storage(
1317 TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1318 vec![4],
1319 false,
1320 )
1321 .unwrap();
1322 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1323
1324 fsdp.sync_gradients().unwrap();
1325
1326 let w = fsdp.module().weight.tensor();
1327 assert_eq!(w.numel(), 4, "NoShard keeps params full");
1328 w.grad().unwrap().unwrap().data_vec().unwrap()
1329 })
1330 })
1331 .collect();
1332
1333 for h in handles {
1334 let gd = h.join().unwrap();
1335 assert_eq!(gd.len(), 4);
1336 // Mean of two identical [1,2,3,4] vectors is [1,2,3,4].
1337 for (i, expected) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
1338 assert!(
1339 (gd[i] - expected).abs() < 1e-6,
1340 "NoShard allreduce: got {} at {}, expected {}",
1341 gd[i],
1342 i,
1343 expected
1344 );
1345 }
1346 }
1347 }
1348
1349 #[test]
1350 fn test_fsdp_prefetched_forward_matches_sync_forward() {
1351 // Prefetch then forward should produce exactly the same output
1352 // as a plain synchronous forward. 2 ranks, weight [1,2,3,4],
1353 // input [2.0] -> expected output [20.0].
1354 let group = SimulatedBackend::create_group(2).unwrap();
1355 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1356
1357 let handles: Vec<_> = arcs
1358 .iter()
1359 .cloned()
1360 .map(|b| {
1361 thread::spawn(move || {
1362 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1363 let mut fsdp = FSDP::new(model, b).unwrap();
1364
1365 // Kick off prefetch BEFORE doing other work.
1366 assert!(!fsdp.has_pending_prefetch());
1367 fsdp.prefetch_forward_params().unwrap();
1368 assert!(fsdp.has_pending_prefetch());
1369
1370 // Simulated "local compute" between prefetch and forward.
1371 let _scratch: f32 = (0..100).map(|i| i as f32).sum();
1372
1373 let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1374 let output = fsdp.forward(&input).unwrap();
1375 // After forward, the prefetch handles have been consumed.
1376 assert!(!fsdp.has_pending_prefetch());
1377
1378 output.data_vec().unwrap()[0]
1379 })
1380 })
1381 .collect();
1382
1383 for h in handles {
1384 let v = h.join().unwrap();
1385 assert!((v - 20.0).abs() < 1e-6, "expected 20.0, got {v}");
1386 }
1387 }
1388
1389 #[test]
1390 fn test_fsdp_forward_without_prefetch_still_works() {
1391 // Smoke test: forward() without a prior prefetch should use the
1392 // synchronous path and produce the same result.
1393 let group = SimulatedBackend::create_group(2).unwrap();
1394 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1395
1396 let handles: Vec<_> = arcs
1397 .iter()
1398 .cloned()
1399 .map(|b| {
1400 thread::spawn(move || {
1401 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1402 let mut fsdp = FSDP::new(model, b).unwrap();
1403 assert!(!fsdp.has_pending_prefetch());
1404 let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1405 let output = fsdp.forward(&input).unwrap();
1406 output.data_vec().unwrap()[0]
1407 })
1408 })
1409 .collect();
1410
1411 for h in handles {
1412 assert!((h.join().unwrap() - 20.0).abs() < 1e-6);
1413 }
1414 }
1415
1416 #[test]
1417 fn test_fsdp_prefetch_rejects_double_call() {
1418 // Calling prefetch twice without an intervening forward should
1419 // return an error (stale pending handles would be leaked).
1420 let group = SimulatedBackend::create_group(1).unwrap();
1421 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1422 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1423 let mut fsdp = FSDP::new(model, b).unwrap();
1424 fsdp.prefetch_forward_params().unwrap();
1425 let r = fsdp.prefetch_forward_params();
1426 assert!(r.is_err());
1427 let err = format!("{}", r.unwrap_err());
1428 assert!(err.contains("called twice"), "err = {err}");
1429
1430 // Consume the pending handles so the test doesn't leak a live
1431 // background thread that waits on channel traffic. Running
1432 // forward() drains and joins them cleanly.
1433 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1434 let _ = fsdp.forward(&input).unwrap();
1435 }
1436
1437 #[test]
1438 fn test_fsdp_prefetch_rejects_non_fullshard() {
1439 // Prefetch only makes sense for FullShard (the other strategies
1440 // don't do a parameter all_gather). Calling it on ShardGradOp
1441 // must return a clear error.
1442 let group = SimulatedBackend::create_group(1).unwrap();
1443 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1444 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1445 let mut fsdp = FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1446 let r = fsdp.prefetch_forward_params();
1447 assert!(r.is_err());
1448 let err = format!("{}", r.unwrap_err());
1449 assert!(err.contains("FullShard"), "err = {err}");
1450 }
1451
1452 #[test]
1453 fn test_fsdp_no_shard_broadcast_is_noop() {
1454 // broadcast_updated_params should be a no-op for NoShard and
1455 // FullShard strategies.
1456 let group = SimulatedBackend::create_group(1).unwrap();
1457 let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1458 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1459 let mut fsdp = FSDP::new_with_strategy(model, b, ShardingStrategy::NoShard).unwrap();
1460 fsdp.broadcast_updated_params().unwrap();
1461 assert_eq!(
1462 fsdp.module().weight.tensor().data_vec().unwrap(),
1463 &[1.0, 2.0, 3.0, 4.0]
1464 );
1465 }
1466
1467 #[test]
1468 fn test_fsdp_sync_gradients_multi_rank() {
1469 // 2 ranks, param size 4 -> shard size 2.
1470 // Both ranks set identical gradients on full_params: [1, 2, 3, 4].
1471 // reduce_scatter(mean) on [1,2,3,4] -> rank 0 gets [1,2], rank 1 gets [3,4].
1472 let group = SimulatedBackend::create_group(2).unwrap();
1473 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1474
1475 let handles: Vec<_> = arcs
1476 .iter()
1477 .cloned()
1478 .map(|b| {
1479 thread::spawn(move || {
1480 let rank = b.rank();
1481 let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1482 let mut fsdp = FSDP::new(model, b).unwrap();
1483
1484 // Run forward.
1485 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1486 let _output = fsdp.forward(&input).unwrap();
1487
1488 // Set gradient on full_params.
1489 let grad = Tensor::from_storage(
1490 TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1491 vec![4],
1492 false,
1493 )
1494 .unwrap();
1495 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1496
1497 fsdp.sync_gradients().unwrap();
1498
1499 let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
1500 let data = shard_grad.data_vec().unwrap();
1501 (rank, data)
1502 })
1503 })
1504 .collect();
1505
1506 for h in handles {
1507 let (rank, data) = h.join().unwrap();
1508 if rank == 0 {
1509 // Mean of [1,2] from both ranks = [1,2].
1510 assert_eq!(data.len(), 2);
1511 assert!(
1512 (data[0] - 1.0).abs() < 1e-6,
1513 "rank 0: expected 1.0, got {}",
1514 data[0]
1515 );
1516 assert!(
1517 (data[1] - 2.0).abs() < 1e-6,
1518 "rank 0: expected 2.0, got {}",
1519 data[1]
1520 );
1521 } else {
1522 // Mean of [3,4] from both ranks = [3,4].
1523 assert_eq!(data.len(), 2);
1524 assert!(
1525 (data[0] - 3.0).abs() < 1e-6,
1526 "rank 1: expected 3.0, got {}",
1527 data[0]
1528 );
1529 assert!(
1530 (data[1] - 4.0).abs() < 1e-6,
1531 "rank 1: expected 4.0, got {}",
1532 data[1]
1533 );
1534 }
1535 }
1536 }
1537
1538 // -----------------------------------------------------------------------
1539 // HybridShard tests. CL-327
1540 // -----------------------------------------------------------------------
1541
1542 #[test]
1543 fn test_fsdp_hybrid_shard_rejects_uneven_world_size() {
1544 // world_size=4, intra_node_size=3 is not a divisor → error.
1545 let group = SimulatedBackend::create_group(4).unwrap();
1546 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1547
1548 let b = arcs[0].clone();
1549 let model = TestModule::<f32>::new(&[1.0f32; 8]).unwrap();
1550 let result = FSDP::new_with_strategy(
1551 model,
1552 b,
1553 ShardingStrategy::HybridShard { intra_node_size: 3 },
1554 );
1555 assert!(result.is_err(), "expected uneven intra_node_size to fail");
1556 }
1557
1558 #[test]
1559 fn test_fsdp_hybrid_shard_intra_node_sharding() {
1560 // world_size=4, intra_node_size=2:
1561 // node 0 = {rank 0, rank 1}
1562 // node 1 = {rank 2, rank 3}
1563 // Each rank shards a param of size 4 into chunks of size 4/2=2.
1564 // rank 0 (intra 0, node 0) -> [weight[0], weight[1]]
1565 // rank 1 (intra 1, node 0) -> [weight[2], weight[3]]
1566 // rank 2 (intra 0, node 1) -> [weight[0], weight[1]]
1567 // rank 3 (intra 1, node 1) -> [weight[2], weight[3]]
1568 let group = SimulatedBackend::create_group(4).unwrap();
1569 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1570
1571 let handles: Vec<_> = arcs
1572 .iter()
1573 .cloned()
1574 .map(|b| {
1575 thread::spawn(move || {
1576 let rank = b.rank();
1577 let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1578 let fsdp = FSDP::new_with_strategy(
1579 model,
1580 b,
1581 ShardingStrategy::HybridShard { intra_node_size: 2 },
1582 )
1583 .unwrap();
1584 let shard = fsdp.module().weight.tensor().data_vec().unwrap();
1585 (rank, shard)
1586 })
1587 })
1588 .collect();
1589
1590 for h in handles {
1591 let (rank, shard) = h.join().unwrap();
1592 assert_eq!(shard.len(), 2, "each rank gets 4/2 = 2 elements");
1593 let expected: &[f32] = if rank % 2 == 0 {
1594 &[10.0, 20.0]
1595 } else {
1596 &[30.0, 40.0]
1597 };
1598 assert_eq!(shard, expected, "rank {} shard mismatch", rank);
1599 }
1600 }
1601
1602 #[test]
1603 fn test_fsdp_hybrid_shard_sync_gradients() {
1604 // world_size=4, intra_node_size=2:
1605 // node 0 = {rank 0, rank 1}, node 1 = {rank 2, rank 3}.
1606 // Each rank produces the same full grad [1,2,3,4]. After:
1607 // intra-node reduce_scatter(mean) → each intra-rank gets its shard
1608 // of the node-local mean gradient: rank 0 gets [1,2], rank 1
1609 // gets [3,4], rank 2 gets [1,2], rank 3 gets [3,4].
1610 // inter-node allreduce(mean) across replicas of the same intra
1611 // position: replicas produce identical shards, so the mean is
1612 // unchanged.
1613 // Final expected shard gradients:
1614 // ranks 0, 2 → [1, 2]
1615 // ranks 1, 3 → [3, 4]
1616 let group = SimulatedBackend::create_group(4).unwrap();
1617 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1618
1619 let handles: Vec<_> = arcs
1620 .iter()
1621 .cloned()
1622 .map(|b| {
1623 thread::spawn(move || {
1624 let rank = b.rank();
1625 let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1626 let mut fsdp = FSDP::new_with_strategy(
1627 model,
1628 b,
1629 ShardingStrategy::HybridShard { intra_node_size: 2 },
1630 )
1631 .unwrap();
1632
1633 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1634 let _output = fsdp.forward(&input).unwrap();
1635
1636 // After forward, the module param is back to a shard
1637 // (intra-node) but full_params[0] holds the gathered
1638 // full tensor. Attach a synthetic full-shape gradient.
1639 let grad = Tensor::from_storage(
1640 TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1641 vec![4],
1642 false,
1643 )
1644 .unwrap();
1645 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1646
1647 fsdp.sync_gradients().unwrap();
1648
1649 let w = fsdp.module().weight.tensor();
1650 assert_eq!(w.numel(), 2, "HybridShard keeps params as intra-shards");
1651 let g = w.grad().unwrap().unwrap();
1652 let gd = g.data_vec().unwrap();
1653 (rank, gd)
1654 })
1655 })
1656 .collect();
1657
1658 for h in handles {
1659 let (rank, gd) = h.join().unwrap();
1660 assert_eq!(gd.len(), 2, "shard grad should have 2 elements");
1661 let expected: &[f32] = if rank % 2 == 0 {
1662 &[1.0, 2.0]
1663 } else {
1664 &[3.0, 4.0]
1665 };
1666 for (i, e) in expected.iter().enumerate() {
1667 assert!(
1668 (gd[i] - e).abs() < 1e-6,
1669 "rank {} [{}]: expected {}, got {}",
1670 rank,
1671 i,
1672 e,
1673 gd[i]
1674 );
1675 }
1676 }
1677 }
1678
1679 #[test]
1680 fn test_fsdp_hybrid_shard_inter_node_averaging() {
1681 // Construct a scenario where the two nodes disagree on grads, so
1682 // the inter-node allreduce has meaningful work:
1683 // node 0 (ranks 0,1) full grad = [2,4,6,8]
1684 // node 1 (ranks 2,3) full grad = [10,20,30,40]
1685 // Node means after intra reduce_scatter:
1686 // rank 0 [2,4], rank 1 [6,8], rank 2 [10,20], rank 3 [30,40].
1687 // After inter-node allreduce(mean):
1688 // rank 0 (intra 0) = mean(rank 0's shard, rank 2's shard) = [6,12]
1689 // rank 1 (intra 1) = mean(rank 1's shard, rank 3's shard) = [18,24]
1690 // rank 2 (intra 0) = mean(rank 2's shard, rank 0's shard) = [6,12]
1691 // rank 3 (intra 1) = mean(rank 3's shard, rank 1's shard) = [18,24]
1692 let group = SimulatedBackend::create_group(4).unwrap();
1693 let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1694
1695 let handles: Vec<_> = arcs
1696 .iter()
1697 .cloned()
1698 .map(|b| {
1699 thread::spawn(move || {
1700 let rank = b.rank();
1701 let model = TestModule::<f32>::new(&[0.0f32; 4]).unwrap();
1702 let mut fsdp = FSDP::new_with_strategy(
1703 model,
1704 b,
1705 ShardingStrategy::HybridShard { intra_node_size: 2 },
1706 )
1707 .unwrap();
1708
1709 let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1710 let _ = fsdp.forward(&input).unwrap();
1711
1712 // Node-specific gradient: node 0 ranks 0/1 get
1713 // [2,4,6,8], node 1 ranks 2/3 get [10,20,30,40].
1714 let grad_vec: Vec<f32> = if rank < 2 {
1715 vec![2.0, 4.0, 6.0, 8.0]
1716 } else {
1717 vec![10.0, 20.0, 30.0, 40.0]
1718 };
1719 let grad =
1720 Tensor::from_storage(TensorStorage::cpu(grad_vec), vec![4], false).unwrap();
1721 fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1722
1723 fsdp.sync_gradients().unwrap();
1724
1725 let gd = fsdp
1726 .module()
1727 .weight
1728 .tensor()
1729 .grad()
1730 .unwrap()
1731 .unwrap()
1732 .data_vec()
1733 .unwrap();
1734 (rank, gd)
1735 })
1736 })
1737 .collect();
1738
1739 for h in handles {
1740 let (rank, gd) = h.join().unwrap();
1741 assert_eq!(gd.len(), 2);
1742 // intra-position 0 -> [6, 12]; intra-position 1 -> [18, 24]
1743 let expected: &[f32] = if rank % 2 == 0 {
1744 &[6.0, 12.0]
1745 } else {
1746 &[18.0, 24.0]
1747 };
1748 for (i, e) in expected.iter().enumerate() {
1749 assert!(
1750 (gd[i] - e).abs() < 1e-4,
1751 "hybrid inter-node mean rank {} [{}]: expected {}, got {}",
1752 rank,
1753 i,
1754 e,
1755 gd[i]
1756 );
1757 }
1758 }
1759 }
1760}