Skip to main content

ferrotorch_distributed/
fsdp.rs

1//! Fully Sharded Data Parallel (FSDP) wrapper.
2//!
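//! [`FSDP`] wraps a [`Module`] and shards its parameters across ranks.
//! Under the default [`ShardingStrategy::FullShard`], parameters are
//! all-gathered during forward to form the full tensors, and during
//! gradient synchronization the full-parameter gradients are
//! reduce-scattered so each rank only stores its shard's gradient. The
//! other [`ShardingStrategy`] variants shard less (or nothing) and use
//! different collectives; see their documentation.
//!
//! With `FullShard`, this reduces per-rank parameter memory between forward
//! passes from O(params) to O(params / world_size), at the cost of
//! additional communication during forward and backward.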
10//!
11//! ## REQ status (per `.design/ferrotorch-distributed/fsdp.md`)
12//!
13//! Full evidence rows (impl + non-test production consumer + upstream
14//! cites) live in the design doc; this synopsis is a one-line summary per
15//! REQ.
16//!
17//! | REQ | Status | Evidence |
18//! |---|---|---|
19//! | REQ-1 (`ShardingStrategy` enum) | SHIPPED | `pub enum ShardingStrategy` in `fsdp.rs` mirrors `ShardingStrategy(Enum)` in `torch/distributed/fsdp/api.py`; consumer `pub fn new_with_strategy` (same file) and `lib.rs` re-export |
20//! | REQ-2 (`FSDP<M, T>` struct) | SHIPPED | `pub struct FSDP` in `fsdp.rs` mirrors `class FullyShardedDataParallel` in `torch/distributed/fsdp/fully_sharded_data_parallel.py`; consumer `pub use fsdp::FSDP` in `lib.rs` |
21//! | REQ-3 (`new` + `new_with_strategy` constructors) | SHIPPED | `pub fn new` + `pub fn new_with_strategy` in `fsdp.rs`; consumer `new` invokes `new_with_strategy` (same file) plus `lib.rs` re-export |
22//! | REQ-4 (`FullShard` ZeRO-3 strategy) | SHIPPED | `ShardingStrategy::FullShard` arm of `new_with_strategy` in `fsdp.rs`; consumer `pub fn new` default path (same file) |
23//! | REQ-5 (`ShardGradOp` ZeRO-2 strategy + `broadcast_updated_params`) | SHIPPED | `ShardingStrategy::ShardGradOp` arms + `pub fn broadcast_updated_params` in `fsdp.rs`; consumer of `crate::collective::all_gather` for param re-sync |
24//! | REQ-6 (`NoShard` DDP-equivalent strategy) | SHIPPED | `ShardingStrategy::NoShard` arms across `new_with_strategy` / `forward` / `sync_gradients` in `fsdp.rs`; consumer of `crate::collective::allreduce` on the NoShard path |
25//! | REQ-7 (`HybridShard` intra-/inter-node strategy) | SHIPPED | `ShardingStrategy::HybridShard` arms in `fsdp.rs` build `Arc<SubBackend>` pair; consumer of `crate::backend::SubBackend::new` and `crate::collective::{reduce_scatter, allreduce}` |
26//! | REQ-8 (`forward` all-gather + run inner) | SHIPPED | `pub fn forward` in `fsdp.rs`; consumer of `crate::collective::all_gather` (FullShard / HybridShard); surfaced via `lib.rs` re-export |
27//! | REQ-9 (`prefetch_forward_params` + `has_pending_prefetch`) | SHIPPED | `pub fn prefetch_forward_params` + `pub fn has_pending_prefetch` in `fsdp.rs`; consumer of `crate::async_collective::async_all_gather`; `forward` joins via `handle.wait()?` (same file) |
28//! | REQ-10 (`sync_gradients` reduce-scatter / allreduce) | SHIPPED | `pub fn sync_gradients` in `fsdp.rs`; consumer of `crate::collective::{reduce_scatter, allreduce}`; surfaced via `lib.rs` re-export |
29//! | REQ-11 (`broadcast_updated_params` for ZeRO-2) | SHIPPED | `pub fn broadcast_updated_params` in `fsdp.rs`; consumer of `crate::collective::all_gather`; surfaced via `lib.rs` re-export |
30//! | REQ-12 (`update_shards` flat-buffer optimizer hook) | SHIPPED | `pub fn update_shards` in `fsdp.rs`; consumer via `lib.rs` re-export for downstream optimizers that produce flat parameter buffers |
31//! | REQ-13 (accessors) | SHIPPED | `pub fn strategy` / `module` / `module_mut` / `into_inner` / `backend` in `fsdp.rs`; consumer via `lib.rs` re-export — the user-facing accessor surface |
32
33use std::sync::Arc;
34
35use ferrotorch_core::storage::TensorStorage;
36use ferrotorch_core::{FerrotorchResult, Float, Tensor};
37use ferrotorch_nn::{Module, Parameter};
38
39use crate::async_collective::{PendingCollective, async_all_gather};
40use crate::backend::{Backend, SubBackend};
41use crate::collective::{ReduceOp, all_gather, allreduce, reduce_scatter};
42
/// How [`FSDP`] distributes parameters, gradients, and optimizer state
/// across ranks. Mirrors PyTorch's `ShardingStrategy`.
///
/// - [`FullShard`](ShardingStrategy::FullShard): parameters, gradients, and
///   optimizer state are all sharded (ZeRO-3). Lowest memory, highest
///   communication.
/// - [`ShardGradOp`](ShardingStrategy::ShardGradOp): parameters stay
///   replicated on every rank, so forward needs no parameter all-gather.
///   Only gradients and optimizer state are sharded (ZeRO-2). After
///   `optimizer.step()`, call [`FSDP::broadcast_updated_params`] so every
///   rank sees the updated parameter shards.
/// - [`NoShard`](ShardingStrategy::NoShard): no sharding at all. The full
///   gradient is allreduced, exactly like DDP. This lets the FSDP wrapper
///   stand in for DDP while debugging or in single-node experiments.
/// - [`HybridShard`](ShardingStrategy::HybridShard): FullShard-style
///   sharding inside a node, DDP-style replication across nodes. Two
///   [`SubBackend`]s are derived from the global backend, so the intra-node
///   `all_gather` / `reduce_scatter` and the inter-node `allreduce` run on
///   separate groups. CL-327.
///
/// CL-372.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// ZeRO-3 / classic FSDP: parameters, gradients, and optimizer state are
    /// all sharded. This is the default and matches the only behavior that
    /// existed before CL-372.
    #[default]
    FullShard,
    /// ZeRO-2: gradients and optimizer state are sharded; parameters remain
    /// replicated on every rank.
    ShardGradOp,
    /// ZeRO-0 / DDP equivalent: nothing is sharded, and gradients are
    /// allreduced.
    NoShard,
    /// Shard within a node (intra-node FullShard) and replicate across nodes
    /// (inter-node DDP).
    ///
    /// `intra_node_size` is the number of ranks per node (e.g. local GPUs).
    /// The global `world_size` must be a positive multiple of it. CL-327.
    HybridShard { intra_node_size: usize },
}
84
85/// Fully Sharded Data Parallel module wrapper.
86///
87/// Wraps an inner [`Module`] and shards each parameter across ranks so that
88/// each rank only stores `1 / world_size` of the full parameter tensor.
89///
90/// # Forward pass
91///
92/// Before calling the inner module's `forward()`, FSDP all-gathers each
93/// shard to reconstruct the full parameter tensor and installs it into the
94/// module. The full-parameter tensors are stored in [`full_params`] so
95/// that backward can accumulate gradients on them.
96///
97/// # Gradient synchronization
98///
99/// After `backward()`, call [`sync_gradients`] to:
100/// 1. Read gradients from the full-parameter tensors stored during forward.
101/// 2. Reduce-scatter the full gradients so each rank gets only its shard
102///    portion of the gradient.
103/// 3. Set each shard parameter's gradient from the reduce-scattered result.
104///
105/// # Example
106///
107/// ```ignore
108/// let mut fsdp = FSDP::new(model, backend)?;
109///
110/// loop {
111///     let output = fsdp.forward(&input)?;
112///     let loss = criterion.forward(&output, &target)?;
113///     ferrotorch_core::backward(&loss)?;
114///     fsdp.sync_gradients()?;
115///     optimizer.step()?;
116///     optimizer.zero_grad()?;
117/// }
118/// ```
119pub struct FSDP<M: Module<T>, T: Float> {
120    module: M,
121    backend: Arc<dyn Backend>,
122    /// Active sharding strategy. Drives the behavior of `new`,
123    /// `forward`, `sync_gradients`, and `broadcast_updated_params`.
124    strategy: ShardingStrategy,
125    /// Original full-parameter shapes before sharding.
126    original_shapes: Vec<Vec<usize>>,
127    /// Full-param tensors from the last forward pass, kept alive so
128    /// backward can accumulate gradients on them.
129    full_params: Vec<Tensor<T>>,
130    /// Pending async all-gather handles produced by
131    /// [`FSDP::prefetch_forward_params`]. One entry per parameter, in
132    /// the same order as [`Module::parameters_mut`]. `None` means no
133    /// prefetch is in flight and `forward()` will use the synchronous
134    /// all-gather path. CL-373.
135    pending_prefetch: Option<Vec<PendingCollective<T>>>,
136    /// Intra-node subgroup used by [`ShardingStrategy::HybridShard`].
137    /// `None` for all other strategies. CL-327.
138    intra_node_group: Option<Arc<SubBackend>>,
139    /// Inter-node subgroup used by [`ShardingStrategy::HybridShard`].
140    /// `None` for all other strategies. CL-327.
141    inter_node_group: Option<Arc<SubBackend>>,
142    _marker: std::marker::PhantomData<T>,
143}
144
145impl<M: Module<T>, T: Float> FSDP<M, T> {
146    /// Wrap a module for fully-sharded data-parallel training.
147    ///
148    /// Each parameter is split evenly across `world_size` ranks. This rank
149    /// keeps only its shard (the `rank`-th chunk). The original parameter
150    /// shapes are recorded for reconstruction during forward.
151    ///
152    /// # Panics
153    ///
154    /// Panics if any parameter's element count is not evenly divisible by
155    /// `world_size`.
156    pub fn new(module: M, backend: Arc<dyn Backend>) -> FerrotorchResult<Self> {
157        Self::new_with_strategy(module, backend, ShardingStrategy::FullShard)
158    }
159
160    /// Wrap a module for data-parallel training with a specific
161    /// [`ShardingStrategy`].
162    ///
163    /// - `FullShard` — shard parameters, gradients, and optimizer state
164    ///   (the classic FSDP / ZeRO-3 behavior; identical to [`new`]).
165    /// - `ShardGradOp` — keep parameters replicated on every rank and
166    ///   only shard gradients + optimizer state (ZeRO-2). After
167    ///   calling the optimizer step on the shard gradients, the caller
168    ///   must call [`broadcast_updated_params`] to re-sync the updated
169    ///   parameter shards back to every rank. CL-372.
170    /// - `NoShard` — no sharding (ZeRO-0 / DDP equivalent). Gradients
171    ///   are allreduced across ranks in `sync_gradients` and all ranks
172    ///   update the full parameters locally.
173    pub fn new_with_strategy(
174        mut module: M,
175        backend: Arc<dyn Backend>,
176        strategy: ShardingStrategy,
177    ) -> FerrotorchResult<Self> {
178        let rank = backend.rank();
179        let world_size = backend.world_size();
180        let mut original_shapes = Vec::new();
181
182        // For HybridShard, build the intra- and inter-node subgroups
183        // upfront so we can use the intra_group's world_size to shard
184        // parameters below.
185        let (intra_node_group, inter_node_group) = match strategy {
186            ShardingStrategy::HybridShard { intra_node_size } => {
187                if intra_node_size == 0 || world_size % intra_node_size != 0 {
188                    return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
189                        message: format!(
190                            "HybridShard: world_size={world_size} must be a positive multiple of intra_node_size={intra_node_size}"
191                        ),
192                    });
193                }
194
195                // This node's row index in the (inter_size x intra_size)
196                // rank grid.
197                let node_idx = rank / intra_node_size;
198                let local_idx = rank % intra_node_size;
199
200                // Intra-node members: contiguous block of `intra_node_size`
201                // ranks starting at `node_idx * intra_node_size`.
202                let intra_members: Vec<usize> = (node_idx * intra_node_size
203                    ..node_idx * intra_node_size + intra_node_size)
204                    .collect();
205                let intra = Arc::new(SubBackend::new(Arc::clone(&backend), intra_members)?);
206
207                // Inter-node members: ranks with the same local_idx across
208                // every node (i.e., ranks stride-selected by
209                // `intra_node_size`).
210                let inter_members: Vec<usize> = (0..world_size / intra_node_size)
211                    .map(|n| n * intra_node_size + local_idx)
212                    .collect();
213                let inter = Arc::new(SubBackend::new(Arc::clone(&backend), inter_members)?);
214
215                (Some(intra), Some(inter))
216            }
217            _ => (None, None),
218        };
219
220        {
221            let params = module.parameters_mut();
222            for param in params {
223                let tensor = param.tensor();
224                let shape = tensor.shape().to_vec();
225                original_shapes.push(shape);
226
227                match strategy {
228                    ShardingStrategy::FullShard => {
229                        let numel = tensor.numel();
230                        assert!(
231                            numel % world_size == 0,
232                            "FSDP: parameter with {numel} elements is not evenly divisible by world_size {world_size}"
233                        );
234                        let data = tensor.data_vec()?;
235                        let chunk_size = numel / world_size;
236                        let start = rank * chunk_size;
237                        let end = start + chunk_size;
238                        let shard_data = data[start..end].to_vec();
239                        let shard_tensor = Tensor::from_storage(
240                            TensorStorage::cpu(shard_data),
241                            vec![chunk_size],
242                            true,
243                        )?;
244                        *param = Parameter::new(shard_tensor);
245                    }
246                    ShardingStrategy::HybridShard { .. } => {
247                        // Shard within the node only. Each rank keeps
248                        // `1 / intra_node_size` of each parameter.
249                        //
250                        // INVARIANT: `intra_node_group` is `Some(_)` for every
251                        // `HybridShard { .. }` strategy. The local binding was
252                        // assigned `Some(intra)` ~50 lines above (in the
253                        // `match strategy` that opens this fn) on exactly the
254                        // `HybridShard { .. }` arm, and was never reassigned
255                        // before reaching this match. Since this arm is only
256                        // entered under the same `HybridShard { .. }` pattern,
257                        // `as_ref()` here cannot observe `None` — `expect()`
258                        // is provably unreachable. Category C per
259                        // rust-fix-discipline.
260                        let intra = intra_node_group.as_ref().expect(
261                            "FSDP::new_with_strategy: intra_node_group is Some \
262                             for HybridShard (set ~50 lines above on the same \
263                             match arm; never reassigned)",
264                        );
265                        let intra_size = intra.world_size();
266                        let intra_rank = intra.rank();
267                        let numel = tensor.numel();
268                        assert!(
269                            numel % intra_size == 0,
270                            "FSDP HybridShard: parameter with {numel} elements is not evenly divisible by intra_node_size {intra_size}"
271                        );
272                        let data = tensor.data_vec()?;
273                        let chunk_size = numel / intra_size;
274                        let start = intra_rank * chunk_size;
275                        let end = start + chunk_size;
276                        let shard_data = data[start..end].to_vec();
277                        let shard_tensor = Tensor::from_storage(
278                            TensorStorage::cpu(shard_data),
279                            vec![chunk_size],
280                            true,
281                        )?;
282                        *param = Parameter::new(shard_tensor);
283                    }
284                    ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
285                        // Keep the full parameter on this rank; only
286                        // gradients (and optimizer state, as an
287                        // external concern) are sharded for ShardGradOp.
288                        // NoShard is a plain DDP-style replication.
289                        //
290                        // For ShardGradOp, each rank still needs to know
291                        // which slice of the flat parameter is "its"
292                        // shard for the optimizer step. That's derived
293                        // at grad-sync time from world_size + rank.
294                    }
295                }
296            }
297        }
298
299        Ok(Self {
300            module,
301            backend,
302            strategy,
303            original_shapes,
304            full_params: Vec::new(),
305            pending_prefetch: None,
306            intra_node_group,
307            inter_node_group,
308            _marker: std::marker::PhantomData,
309        })
310    }
311
312    /// Return the active sharding strategy.
313    pub fn strategy(&self) -> ShardingStrategy {
314        self.strategy
315    }
316
317    /// Kick off asynchronous all-gathers for every parameter so the
318    /// next [`forward`](Self::forward) call consumes the pre-gathered
319    /// tensors instead of blocking on a fresh all-gather.
320    ///
321    /// This is FSDP's equivalent of PyTorch's backward prefetch:
322    /// communication for layer N+1 (or the next forward pass) overlaps
323    /// with compute for layer N. The caller should insert local
324    /// compute (e.g., the previous layer's backward, or input
325    /// preprocessing) between `prefetch_forward_params` and `forward`
326    /// to realize the overlap.
327    ///
328    /// Only valid for [`ShardingStrategy::FullShard`] — the other
329    /// strategies keep parameters replicated on every rank so there's
330    /// nothing to all-gather. Calling this on a non-`FullShard` FSDP
331    /// returns an `InvalidArgument` error.
332    ///
333    /// # Invariant
334    ///
335    /// Exactly one `prefetch_forward_params` → `forward` pair should be
336    /// in flight at any time on a given FSDP instance. Calling
337    /// `prefetch_forward_params` twice in a row (without an intervening
338    /// `forward`) returns an `InvalidArgument` error.
339    ///
340    /// CL-373.
341    pub fn prefetch_forward_params(&mut self) -> FerrotorchResult<()> {
342        if self.strategy != ShardingStrategy::FullShard {
343            return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
344                message: format!(
345                    "FSDP::prefetch_forward_params: prefetch is only valid for FullShard, got {:?}",
346                    self.strategy
347                ),
348            });
349        }
350        if self.pending_prefetch.is_some() {
351            return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
352                message: "FSDP::prefetch_forward_params called twice without intervening forward()"
353                    .into(),
354            });
355        }
356
357        let mut handles = Vec::new();
358        {
359            let params = self.module.parameters();
360            for param in params {
361                let shard = param.tensor().clone();
362                let h = async_all_gather(shard, Arc::clone(&self.backend));
363                handles.push(h);
364            }
365        }
366        self.pending_prefetch = Some(handles);
367        Ok(())
368    }
369
370    /// True if a prefetch is currently pending. Primarily useful for
371    /// tests and diagnostics.
372    pub fn has_pending_prefetch(&self) -> bool {
373        self.pending_prefetch.is_some()
374    }
375
376    /// Immutable access to the inner module.
377    pub fn module(&self) -> &M {
378        &self.module
379    }
380
381    /// Mutable access to the inner module.
382    pub fn module_mut(&mut self) -> &mut M {
383        &mut self.module
384    }
385
386    /// Consume the wrapper and return the inner module.
387    pub fn into_inner(self) -> M {
388        self.module
389    }
390
391    /// The backend used for communication.
392    pub fn backend(&self) -> &Arc<dyn Backend> {
393        &self.backend
394    }
395
396    /// Reconstruct full parameters from shards across all ranks and run
397    /// the inner module's forward pass.
398    ///
399    /// The all-gathered full-parameter tensors are stored in `self.full_params`
400    /// so their gradients can be read after backward.
401    ///
402    /// For `ShardGradOp` and `NoShard` strategies, parameters are already
403    /// full on every rank, so no all-gather happens and `full_params` is
404    /// populated from the current parameter tensors directly.
405    ///
406    /// If [`prefetch_forward_params`](Self::prefetch_forward_params) was
407    /// called earlier, the pending async all-gather handles are consumed
408    /// here instead of running the synchronous all_gather — this is how
409    /// FSDP hides all-gather latency behind whatever local compute
410    /// happened between `prefetch_forward_params` and `forward`.
411    pub fn forward(&mut self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
412        let world_size = self.backend.world_size();
413        self.full_params.clear();
414
415        // Grab any pending prefetch handles. They are consumed here so
416        // a subsequent forward() without a matching prefetch reverts to
417        // synchronous all_gather.
418        let mut pending = self.pending_prefetch.take();
419
420        match self.strategy {
421            ShardingStrategy::FullShard => {
422                if let Some(ref p) = pending {
423                    if p.len() != self.module.parameters().len() {
424                        return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
425                            message: format!(
426                                "FSDP prefetch: have {} pending handles but module has {} parameters",
427                                p.len(),
428                                self.module.parameters().len(),
429                            ),
430                        });
431                    }
432                }
433                // Pop handles from the front by reversing so we can pop from back.
434                if let Some(ref mut p) = pending {
435                    p.reverse();
436                }
437
438                let params = self.module.parameters_mut();
439                for (i, param) in params.into_iter().enumerate() {
440                    let orig_shape = &self.original_shapes[i];
441
442                    // Get the gathered full tensor either from a pending
443                    // async handle or via a fresh synchronous all_gather.
444                    let full = if let Some(ref mut handles) = pending {
445                        // Consume the last handle (original order via reverse).
446                        let handle = handles.pop().ok_or_else(|| {
447                            ferrotorch_core::FerrotorchError::InvalidArgument {
448                                message: "FSDP prefetch: exhausted pending handles".into(),
449                            }
450                        })?;
451                        handle.wait()?
452                    } else {
453                        let shard = param.tensor().clone();
454                        if world_size == 1 {
455                            shard
456                        } else {
457                            all_gather(&shard, self.backend.as_ref())?
458                        }
459                    };
460
461                    // Reshape to the original parameter shape and enable grad.
462                    let full = Tensor::from_storage(
463                        TensorStorage::cpu(full.data_vec()?),
464                        orig_shape.clone(),
465                        true,
466                    )?;
467
468                    self.full_params.push(full.clone());
469
470                    // Install the full parameter into the module for this forward pass.
471                    *param = Parameter::new(full);
472                }
473            }
474            ShardingStrategy::HybridShard { .. } => {
475                // HybridShard: all-gather the shards within the node only.
476                // Prefetch is not supported for this strategy yet — the
477                // intra-node subgroup needs its own async path. CL-327.
478                if pending.is_some() {
479                    return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
480                        message:
481                            "FSDP prefetch_forward_params is not yet implemented for HybridShard"
482                                .into(),
483                    });
484                }
485
486                // INVARIANT: `self.intra_node_group` is `Some(_)` whenever
487                // `self.strategy` is `HybridShard { .. }`. Both fields are
488                // private and only assigned in `new_with_strategy`, which sets
489                // them as a coupled pair on the `HybridShard { .. }` arm. The
490                // strategy field is never reassigned after construction (no
491                // setter, no internal mutation). Since this arm is gated on
492                // the strategy, `as_ref()` here cannot observe `None`.
493                // Category C per rust-fix-discipline.
494                let intra = self.intra_node_group.as_ref().expect(
495                    "FSDP::forward (HybridShard): intra_node_group is Some for \
496                     HybridShard strategy (paired with strategy in \
497                     new_with_strategy; never reassigned)",
498                );
499                let intra_ref: &dyn Backend = &**intra;
500                let intra_size = intra.world_size();
501
502                let params = self.module.parameters_mut();
503                for (i, param) in params.into_iter().enumerate() {
504                    let orig_shape = &self.original_shapes[i];
505
506                    let shard = param.tensor().clone();
507                    let full = if intra_size == 1 {
508                        shard
509                    } else {
510                        all_gather(&shard, intra_ref)?
511                    };
512
513                    let full = Tensor::from_storage(
514                        TensorStorage::cpu(full.data_vec()?),
515                        orig_shape.clone(),
516                        true,
517                    )?;
518
519                    self.full_params.push(full.clone());
520                    *param = Parameter::new(full);
521                }
522            }
523            ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
524                // Parameters are already full on every rank. Prefetch
525                // is meaningless for these strategies — surface any
526                // accidental misuse.
527                if pending.is_some() {
528                    return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
529                        message: "FSDP prefetch_forward_params is only valid for FullShard; \
530                                  use ShardingStrategy::FullShard or don't prefetch"
531                            .into(),
532                    });
533                }
534                // We still need to wrap them in requires_grad=true leaves
535                // so the autograd graph can flow through the forward
536                // pass, and stash them in full_params so sync_gradients
537                // can read the backward-accumulated gradients.
538                let params = self.module.parameters_mut();
539                for param in params.into_iter() {
540                    let t = param.tensor().clone();
541                    let data = t.data_vec()?;
542                    let shape = t.shape().to_vec();
543                    let full = Tensor::from_storage(TensorStorage::cpu(data), shape, true)?;
544                    self.full_params.push(full.clone());
545                    *param = Parameter::new(full);
546                }
547            }
548        }
549
550        let output = self.module.forward(input)?;
551
552        // After forward, restore shard parameters (FullShard /
553        // HybridShard) or leave the full params in place
554        // (ShardGradOp / NoShard).
555        match self.strategy {
556            ShardingStrategy::FullShard => self.restore_shards()?,
557            ShardingStrategy::HybridShard { .. } => self.restore_hybrid_shards()?,
558            ShardingStrategy::ShardGradOp | ShardingStrategy::NoShard => {
559                // Nothing to do: params are already full on every rank.
560            }
561        }
562
563        Ok(output)
564    }
565
566    /// Replace the current full parameter tensors with this rank's
567    /// intra-node shards. Used by [`ShardingStrategy::HybridShard`] after
568    /// forward completes. CL-327.
569    fn restore_hybrid_shards(&mut self) -> FerrotorchResult<()> {
570        // INVARIANT: `restore_hybrid_shards` is only reached from `forward()`
571        // under `ShardingStrategy::HybridShard { .. }`. `self.intra_node_group`
572        // is `Some(_)` whenever `self.strategy` is `HybridShard` (paired in
573        // `new_with_strategy`; both fields private; strategy never reassigned).
574        // Category C per rust-fix-discipline.
575        let intra = self
576            .intra_node_group
577            .as_ref()
578            .expect(
579                "FSDP::restore_hybrid_shards: intra_node_group is Some for \
580                 HybridShard (only callsite is forward() under HybridShard arm)",
581            )
582            .clone();
583        let intra_size = intra.world_size();
584        let intra_rank = intra.rank();
585
586        let params = self.module.parameters_mut();
587        for (i, param) in params.into_iter().enumerate() {
588            let tensor = param.tensor();
589            let data = tensor.data_vec()?;
590            let numel = data.len();
591            let chunk_size = numel / intra_size;
592            let start = intra_rank * chunk_size;
593            let end = start + chunk_size;
594            let shard_data = data[start..end].to_vec();
595
596            let shard_tensor =
597                Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
598            *param = Parameter::new(shard_tensor);
599
600            // Preserve the original shape metadata.
601            let _ = &self.original_shapes[i];
602        }
603
604        Ok(())
605    }
606
607    /// Replace full parameters with their local shards to free memory.
608    fn restore_shards(&mut self) -> FerrotorchResult<()> {
609        let rank = self.backend.rank();
610        let world_size = self.backend.world_size();
611
612        let params = self.module.parameters_mut();
613        for (i, param) in params.into_iter().enumerate() {
614            let tensor = param.tensor();
615            let data = tensor.data_vec()?;
616            let numel = data.len();
617            let chunk_size = numel / world_size;
618            let start = rank * chunk_size;
619            let end = start + chunk_size;
620            let shard_data = data[start..end].to_vec();
621
622            let shard_tensor =
623                Tensor::from_storage(TensorStorage::cpu(shard_data), vec![chunk_size], true)?;
624            *param = Parameter::new(shard_tensor);
625
626            // Preserve the original shape metadata.
627            let _ = &self.original_shapes[i];
628        }
629
630        Ok(())
631    }
632
633    /// Reduce-scatter gradients from the full-parameter tensors stored
634    /// during forward, then set each shard parameter's gradient.
635    ///
636    /// Call this after `backward()` and before `optimizer.step()`.
637    ///
638    /// # How it works
639    ///
640    /// 1. For each parameter, read the gradient from the full-param tensor
641    ///    that was used during forward (stored in `self.full_params`).
642    /// 2. Reduce-scatter the full gradient across ranks (mean reduction) so
643    ///    each rank gets only its shard portion.
644    /// 3. Set the shard parameter's `.grad()` to the reduce-scattered result.
645    ///
646    /// Using reduce-scatter (not allreduce) is correct for FSDP because each
647    /// rank only needs its own shard of the gradient to update its shard of
648    /// the parameter.
649    pub fn sync_gradients(&mut self) -> FerrotorchResult<()> {
650        let rank = self.backend.rank();
651        let world_size = self.backend.world_size();
652        let params = self.module.parameters_mut();
653
654        if self.full_params.len() != params.len() {
655            return Err(ferrotorch_core::FerrotorchError::InvalidArgument {
656                message: format!(
657                    "FSDP sync_gradients: expected {} full_params but have {}. \
658                     Was forward() called before backward()?",
659                    params.len(),
660                    self.full_params.len(),
661                ),
662            });
663        }
664
665        for (i, param) in params.into_iter().enumerate() {
666            let full_param = &self.full_params[i];
667
668            // Read the gradient from the full-parameter tensor. If no
669            // gradient was computed (e.g., parameter was unused in forward),
670            // use zeros so all ranks exchange buffers of the same size.
671            let grad = full_param.grad()?;
672            let full_grad = match grad {
673                Some(g) => g,
674                None => {
675                    let numel = full_param.numel();
676                    Tensor::from_storage(
677                        TensorStorage::cpu(vec![<T as num_traits::Zero>::zero(); numel]),
678                        full_param.shape().to_vec(),
679                        false,
680                    )?
681                }
682            };
683
684            // Flatten for reduction ops.
685            let grad_data = full_grad.data_vec()?;
686            let flat_grad = Tensor::from_storage(
687                TensorStorage::cpu(grad_data),
688                vec![full_grad.numel()],
689                false,
690            )?;
691
692            match self.strategy {
693                ShardingStrategy::FullShard => {
694                    // Reduce-scatter: each rank gets its shard of the
695                    // averaged gradient. Parameter is already the
696                    // shard here (installed by restore_shards after
697                    // forward), so set_grad lines up.
698                    let shard_grad = if world_size == 1 {
699                        flat_grad
700                    } else {
701                        reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
702                    };
703                    param.tensor().set_grad(Some(shard_grad))?;
704                }
705                ShardingStrategy::ShardGradOp => {
706                    // ZeRO-2: reduce-scatter the flat gradient into a
707                    // per-rank slice, then set the parameter's .grad()
708                    // to a tensor shaped like the full parameter but
709                    // with only this rank's shard positions populated
710                    // and the rest zeroed. The optimizer's update at
711                    // the non-shard positions becomes a no-op, so
712                    // each rank effectively updates only its shard
713                    // slice. After optimizer.step, the caller must
714                    // invoke broadcast_updated_params to re-sync.
715                    let numel = flat_grad.numel();
716                    assert!(
717                        numel % world_size == 0,
718                        "FSDP ShardGradOp: parameter with {numel} elements is not evenly \
719                         divisible by world_size {world_size}"
720                    );
721                    let shard_grad_flat = if world_size == 1 {
722                        flat_grad
723                    } else {
724                        reduce_scatter(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
725                    };
726                    let chunk_size = numel / world_size;
727                    let shard_data = shard_grad_flat.data_vec()?;
728                    // Pad with zeros at non-shard positions so the
729                    // gradient tensor matches the full parameter shape.
730                    let mut padded = vec![<T as num_traits::Zero>::zero(); numel];
731                    let start = rank * chunk_size;
732                    padded[start..start + chunk_size].copy_from_slice(&shard_data);
733                    let padded_grad = Tensor::from_storage(
734                        TensorStorage::cpu(padded),
735                        full_param.shape().to_vec(),
736                        false,
737                    )?;
738                    param.tensor().set_grad(Some(padded_grad))?;
739                }
740                ShardingStrategy::HybridShard { .. } => {
741                    // HybridShard: reduce-scatter within the node to
742                    // get this rank's intra-node shard of the mean
743                    // gradient, then allreduce the shard across nodes
744                    // so every replica of this intra-rank has the
745                    // same gradient. Parameter is already the
746                    // intra-node shard (installed by
747                    // restore_hybrid_shards after forward).
748                    //
749                    // INVARIANT: `intra_node_group` and `inter_node_group`
750                    // are both `Some(_)` whenever `self.strategy` is
751                    // `HybridShard { .. }`. The two fields are assigned as a
752                    // coupled pair in `new_with_strategy` on the same arm
753                    // that this match enters; all three are private and
754                    // never reassigned. Both `expect()`s are provably
755                    // unreachable. Category C per rust-fix-discipline.
756                    let intra = self.intra_node_group.as_ref().expect(
757                        "FSDP::sync_gradients (HybridShard): intra_node_group \
758                         is Some for HybridShard (paired with strategy in \
759                         new_with_strategy; never reassigned)",
760                    );
761                    let inter = self.inter_node_group.as_ref().expect(
762                        "FSDP::sync_gradients (HybridShard): inter_node_group \
763                         is Some for HybridShard (paired with strategy in \
764                         new_with_strategy; never reassigned)",
765                    );
766                    let intra_ref: &dyn Backend = &**intra;
767                    let inter_ref: &dyn Backend = &**inter;
768                    let intra_size = intra.world_size();
769                    let inter_size = inter.world_size();
770
771                    // Step 1: intra-node reduce_scatter (mean).
772                    let intra_shard = if intra_size == 1 {
773                        flat_grad
774                    } else {
775                        reduce_scatter(&flat_grad, intra_ref, ReduceOp::Mean)?
776                    };
777
778                    // Step 2: inter-node allreduce (mean across
779                    // replicas).
780                    let replicated = if inter_size == 1 {
781                        intra_shard
782                    } else {
783                        allreduce(&intra_shard, inter_ref, ReduceOp::Mean)?
784                    };
785
786                    param.tensor().set_grad(Some(replicated))?;
787                }
788                ShardingStrategy::NoShard => {
789                    // Plain DDP: allreduce the full gradient so every
790                    // rank has the same averaged gradient, then set
791                    // it on the full parameter.
792                    let reduced = if world_size == 1 {
793                        flat_grad
794                    } else {
795                        allreduce(&flat_grad, self.backend.as_ref(), ReduceOp::Mean)?
796                    };
797                    let reduced_full = Tensor::from_storage(
798                        TensorStorage::cpu(reduced.data_vec()?),
799                        full_param.shape().to_vec(),
800                        false,
801                    )?;
802                    param.tensor().set_grad(Some(reduced_full))?;
803                }
804            }
805        }
806
807        // Clear full_params to free memory now that gradients have been read.
808        self.full_params.clear();
809
810        Ok(())
811    }
812
813    /// For `ShardGradOp`: after `optimizer.step()`, each rank has
814    /// applied the update to its own shard of the full parameter
815    /// (because `sync_gradients` zeroed the non-shard positions of the
816    /// gradient). This method re-syncs the parameter tensors so every
817    /// rank has the fully updated parameter, by summing contributions
818    /// via an allreduce: each rank contributes its updated shard, zero
819    /// elsewhere; the sum across ranks is the full updated parameter.
820    ///
821    /// More precisely, this method reconstructs the full parameter as
822    /// an allgather of per-rank shards. It is a no-op for `FullShard`
823    /// and `NoShard` strategies (they already have consistent
824    /// parameters after step).
825    ///
826    /// Call this AFTER `optimizer.step()` and BEFORE the next
827    /// `forward()`. CL-372.
828    pub fn broadcast_updated_params(&mut self) -> FerrotorchResult<()> {
829        if self.strategy != ShardingStrategy::ShardGradOp {
830            // Nothing to do for FullShard (already shard-local) or
831            // NoShard (all ranks already have the same params).
832            return Ok(());
833        }
834
835        let rank = self.backend.rank();
836        let world_size = self.backend.world_size();
837        if world_size == 1 {
838            return Ok(());
839        }
840
841        let params = self.module.parameters_mut();
842        for param in params {
843            // Extract this rank's shard from the updated full parameter.
844            let full = param.tensor();
845            let full_data = full.data_vec()?;
846            let numel = full_data.len();
847            assert!(
848                numel % world_size == 0,
849                "FSDP broadcast_updated_params: parameter with {numel} elements is not evenly \
850                 divisible by world_size {world_size}"
851            );
852            let chunk_size = numel / world_size;
853            let start = rank * chunk_size;
854            let end = start + chunk_size;
855            let shard = full_data[start..end].to_vec();
856            let shard_tensor =
857                Tensor::from_storage(TensorStorage::cpu(shard), vec![chunk_size], false)?;
858
859            // All-gather across ranks to get the full updated parameter.
860            let gathered = all_gather(&shard_tensor, self.backend.as_ref())?;
861            let full_shape = full.shape().to_vec();
862            let new_full =
863                Tensor::from_storage(TensorStorage::cpu(gathered.data_vec()?), full_shape, true)?;
864            *param = Parameter::new(new_full);
865        }
866        Ok(())
867    }
868
869    /// Update shard parameters from a flat data slice.
870    ///
871    /// This is used by optimizers that produce a flat parameter buffer.
872    /// The slice must have exactly the number of elements expected for
873    /// this rank's shards.
874    pub fn update_shards(&mut self, flat_data: &[T]) -> FerrotorchResult<()> {
875        let params = self.module.parameters_mut();
876        let total_shard_numel: usize = params.iter().map(|p| p.tensor().numel()).sum();
877
878        assert!(
879            flat_data.len() == total_shard_numel,
880            "FSDP update_shards: expected {} elements but got {}",
881            total_shard_numel,
882            flat_data.len(),
883        );
884
885        let mut offset = 0;
886        for param in params {
887            let numel = param.tensor().numel();
888            let shard_data = flat_data[offset..offset + numel].to_vec();
889            let shard_tensor = Tensor::from_storage(
890                TensorStorage::cpu(shard_data),
891                param.tensor().shape().to_vec(),
892                true,
893            )?;
894            *param = Parameter::new(shard_tensor);
895            offset += numel;
896        }
897
898        Ok(())
899    }
900}
901
902// FSDP does NOT implement Module<T> because forward() requires &mut self
903// (to store full_params). Callers must use fsdp.forward() directly.
904
905#[cfg(test)]
906mod tests {
907    use super::*;
908    use crate::backend::SimulatedBackend;
909    use ferrotorch_core::storage::TensorStorage;
910    use ferrotorch_core::{FerrotorchResult, Tensor};
911    use ferrotorch_nn::Parameter;
912    use std::thread;
913
914    /// Minimal module with one parameter for testing FSDP.
915    struct TestModule<T: Float> {
916        weight: Parameter<T>,
917        training: bool,
918    }
919
920    impl<T: Float> TestModule<T> {
921        fn new(data: &[T]) -> FerrotorchResult<Self> {
922            Ok(Self {
923                weight: Parameter::from_slice(data, &[data.len()])?,
924                training: true,
925            })
926        }
927    }
928
929    impl<T: Float> Module<T> for TestModule<T> {
930        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
931            // Simple forward: multiply input by weight sum (produces a scalar
932            // that depends on all weight elements).
933            let w_data = self.weight.tensor().data_vec()?;
934            let w_sum: T = w_data
935                .iter()
936                .copied()
937                .fold(<T as num_traits::Zero>::zero(), |a, b| a + b);
938            let i_data = input.data_vec()?;
939            let out: Vec<T> = i_data.iter().map(|&x| x * w_sum).collect();
940            Tensor::from_storage(TensorStorage::cpu(out), input.shape().to_vec(), false)
941        }
942
943        fn parameters(&self) -> Vec<&Parameter<T>> {
944            vec![&self.weight]
945        }
946
947        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
948            vec![&mut self.weight]
949        }
950
951        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
952            vec![("weight".into(), &self.weight)]
953        }
954
955        fn train(&mut self) {
956            self.training = true;
957        }
958
959        fn eval(&mut self) {
960            self.training = false;
961        }
962
963        fn is_training(&self) -> bool {
964            self.training
965        }
966    }
967
968    #[test]
969    fn test_fsdp_sharding() {
970        // 2 ranks, parameter [10, 20, 30, 40].
971        // Rank 0 gets [10, 20], Rank 1 gets [30, 40].
972        let group = SimulatedBackend::create_group(2).unwrap();
973        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
974
975        let handles: Vec<_> = arcs
976            .iter()
977            .cloned()
978            .map(|b| {
979                thread::spawn(move || {
980                    let rank = b.rank();
981                    let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
982                    let fsdp = FSDP::new(model, b).unwrap();
983
984                    let shard = fsdp.module().weight.tensor().data_vec().unwrap();
985                    (rank, shard)
986                })
987            })
988            .collect();
989
990        for h in handles {
991            let (rank, shard) = h.join().unwrap();
992            if rank == 0 {
993                assert_eq!(shard, &[10.0, 20.0]);
994            } else {
995                assert_eq!(shard, &[30.0, 40.0]);
996            }
997        }
998    }
999
1000    #[test]
1001    fn test_fsdp_shard_requires_grad() {
1002        // Shard parameters must have requires_grad=true.
1003        let group = SimulatedBackend::create_group(2).unwrap();
1004        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1005
1006        let handles: Vec<_> = arcs
1007            .iter()
1008            .cloned()
1009            .map(|b| {
1010                thread::spawn(move || {
1011                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1012                    let fsdp = FSDP::new(model, b).unwrap();
1013                    fsdp.module().weight.tensor().requires_grad()
1014                })
1015            })
1016            .collect();
1017
1018        for h in handles {
1019            assert!(h.join().unwrap(), "shard must have requires_grad=true");
1020        }
1021    }
1022
1023    #[test]
1024    fn test_fsdp_forward_restores_shards() {
1025        // After forward(), parameters should be back to shard size.
1026        let group = SimulatedBackend::create_group(2).unwrap();
1027        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1028
1029        let handles: Vec<_> = arcs
1030            .iter()
1031            .cloned()
1032            .map(|b| {
1033                thread::spawn(move || {
1034                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1035                    let mut fsdp = FSDP::new(model, b).unwrap();
1036
1037                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1038                    let _output = fsdp.forward(&input).unwrap();
1039
1040                    // After forward, shard should be size 2 (4 / 2 ranks).
1041                    let shard = fsdp.module().weight.tensor();
1042                    assert_eq!(shard.numel(), 2);
1043                    assert!(shard.requires_grad());
1044                })
1045            })
1046            .collect();
1047
1048        for h in handles {
1049            h.join().unwrap();
1050        }
1051    }
1052
1053    #[test]
1054    fn test_fsdp_forward_produces_correct_output() {
1055        // 2 ranks, param [1, 2, 3, 4], weight_sum = 10.
1056        // Input [2.0] -> output should be [20.0] on all ranks.
1057        let group = SimulatedBackend::create_group(2).unwrap();
1058        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1059
1060        let handles: Vec<_> = arcs
1061            .iter()
1062            .cloned()
1063            .map(|b| {
1064                thread::spawn(move || {
1065                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1066                    let mut fsdp = FSDP::new(model, b).unwrap();
1067
1068                    let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1069                    let output = fsdp.forward(&input).unwrap();
1070                    let data = output.data_vec().unwrap();
1071                    assert!(
1072                        (data[0] - 20.0).abs() < 1e-6,
1073                        "expected 20.0, got {}",
1074                        data[0]
1075                    );
1076                })
1077            })
1078            .collect();
1079
1080        for h in handles {
1081            h.join().unwrap();
1082        }
1083    }
1084
1085    #[test]
1086    fn test_fsdp_update_shards() {
1087        let group = SimulatedBackend::create_group(1).unwrap();
1088        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1089        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1090        let mut fsdp = FSDP::new(model, b).unwrap();
1091
1092        fsdp.update_shards(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1093        let data = fsdp.module().weight.tensor().data_vec().unwrap();
1094        assert_eq!(data, &[10.0, 20.0, 30.0, 40.0]);
1095    }
1096
1097    #[test]
1098    #[should_panic(expected = "expected 4 elements but got 2")]
1099    fn test_fsdp_update_shards_size_validation() {
1100        let group = SimulatedBackend::create_group(1).unwrap();
1101        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1102        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1103        let mut fsdp = FSDP::new(model, b).unwrap();
1104
1105        // Wrong size: should panic.
1106        fsdp.update_shards(&[10.0, 20.0]).unwrap();
1107    }
1108
1109    #[test]
1110    fn test_fsdp_sync_gradients_single_rank() {
1111        // Single rank: sync_gradients should pass through the gradient
1112        // from the full param to the shard param.
1113        let group = SimulatedBackend::create_group(1).unwrap();
1114        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1115        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1116        let mut fsdp = FSDP::new(model, b).unwrap();
1117
1118        // Run forward to populate full_params.
1119        let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1120        let _output = fsdp.forward(&input).unwrap();
1121
1122        // Manually set gradient on full_params (simulating backward).
1123        let grad = Tensor::from_storage(
1124            TensorStorage::cpu(vec![0.1f32, 0.2, 0.3, 0.4]),
1125            vec![4],
1126            false,
1127        )
1128        .unwrap();
1129        fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1130
1131        fsdp.sync_gradients().unwrap();
1132
1133        // Shard param should now have the full gradient (single rank = no scatter).
1134        let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
1135        let data = shard_grad.data_vec().unwrap();
1136        assert_eq!(data, &[0.1, 0.2, 0.3, 0.4]);
1137    }
1138
1139    #[test]
1140    fn test_fsdp_shard_grad_op_keeps_full_params() {
1141        // ShardGradOp (ZeRO-2): params stay replicated on every rank.
1142        // Verify each rank has the full parameter after `new_with_strategy`.
1143        let group = SimulatedBackend::create_group(2).unwrap();
1144        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1145
1146        let handles: Vec<_> = arcs
1147            .iter()
1148            .cloned()
1149            .map(|b| {
1150                thread::spawn(move || {
1151                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1152                    let fsdp =
1153                        FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1154                    assert_eq!(fsdp.strategy(), ShardingStrategy::ShardGradOp);
1155                    fsdp.module().weight.tensor().data_vec().unwrap()
1156                })
1157            })
1158            .collect();
1159
1160        for h in handles {
1161            let data = h.join().unwrap();
1162            assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
1163        }
1164    }
1165
1166    #[test]
1167    // reason: sharded-grad zero-padding writes the exact bit pattern 0.0 to
1168    // off-shard slots (no arithmetic). The on-shard slots are tested with
1169    // an explicit epsilon (the abs/<1e-6 lines just above each assert_eq!),
1170    // because those go through reduce-scatter mean. The 0.0 == 0.0 sentinel
1171    // and ints {0,1,2,3,4} are bit-exact, so equality is the right check.
1172    #[allow(clippy::float_cmp)]
1173    fn test_fsdp_shard_grad_op_sync_gradients_multi_rank() {
1174        // ZeRO-2: two ranks, param [1,2,3,4], both ranks produce full grad
1175        // [1,2,3,4]. reduce_scatter(mean) gives rank 0 the slice [1,2] and
1176        // rank 1 the slice [3,4]. Each rank's .grad() is then padded so the
1177        // other-rank positions are zero, giving
1178        //   rank 0 grad: [1,2,0,0]
1179        //   rank 1 grad: [0,0,3,4]
1180        let group = SimulatedBackend::create_group(2).unwrap();
1181        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1182
1183        let handles: Vec<_> = arcs
1184            .iter()
1185            .cloned()
1186            .map(|b| {
1187                thread::spawn(move || {
1188                    let rank = b.rank();
1189                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1190                    let mut fsdp =
1191                        FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1192
1193                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1194                    let _output = fsdp.forward(&input).unwrap();
1195
1196                    let grad = Tensor::from_storage(
1197                        TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1198                        vec![4],
1199                        false,
1200                    )
1201                    .unwrap();
1202                    fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1203
1204                    fsdp.sync_gradients().unwrap();
1205
1206                    // Param remains full (not restored to shard) under
1207                    // ShardGradOp, and .grad() is a tensor of the full
1208                    // shape with only this rank's shard populated.
1209                    let w = fsdp.module().weight.tensor();
1210                    assert_eq!(w.numel(), 4, "ShardGradOp keeps params full");
1211                    let g = w.grad().unwrap().unwrap();
1212                    let gd = g.data_vec().unwrap();
1213                    assert_eq!(gd.len(), 4, "grad should be full-shape");
1214                    (rank, gd)
1215                })
1216            })
1217            .collect();
1218
1219        for h in handles {
1220            let (rank, gd) = h.join().unwrap();
1221            if rank == 0 {
1222                assert!((gd[0] - 1.0).abs() < 1e-6, "rank 0 [0]: {}", gd[0]);
1223                assert!((gd[1] - 2.0).abs() < 1e-6, "rank 0 [1]: {}", gd[1]);
1224                assert_eq!(gd[2], 0.0, "rank 0 [2] should be zero");
1225                assert_eq!(gd[3], 0.0, "rank 0 [3] should be zero");
1226            } else {
1227                assert_eq!(gd[0], 0.0, "rank 1 [0] should be zero");
1228                assert_eq!(gd[1], 0.0, "rank 1 [1] should be zero");
1229                assert!((gd[2] - 3.0).abs() < 1e-6, "rank 1 [2]: {}", gd[2]);
1230                assert!((gd[3] - 4.0).abs() < 1e-6, "rank 1 [3]: {}", gd[3]);
1231            }
1232        }
1233    }
1234
1235    #[test]
1236    fn test_fsdp_shard_grad_op_broadcast_updated_params() {
1237        // Full ZeRO-2 loop: each rank simulates an optimizer.step() that
1238        // applies a per-shard update (adds rank*10 to its slice), then
1239        // calls broadcast_updated_params. After that, every rank should
1240        // see the fully updated parameter.
1241        //
1242        // Rank 0 slice is [0,1]; rank 1 slice is [2,3].
1243        // Starting param: [1,2,3,4].
1244        // After per-rank update:
1245        //   rank 0 local param: [1+10, 2+10, 3, 4]     = [11, 12, 3, 4]
1246        //   rank 1 local param: [1, 2, 3+20, 4+20]     = [1, 2, 23, 24]
1247        // After broadcast_updated_params (allgather of rank-local shards):
1248        //   both ranks: [11, 12, 23, 24]
1249        let group = SimulatedBackend::create_group(2).unwrap();
1250        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1251
1252        let handles: Vec<_> = arcs
1253            .iter()
1254            .cloned()
1255            .map(|b| {
1256                thread::spawn(move || {
1257                    let rank = b.rank();
1258                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1259                    let mut fsdp =
1260                        FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1261
1262                    // Simulate per-rank optimizer step: each rank overwrites
1263                    // its shard slice, leaving the other slice untouched.
1264                    let mut local = fsdp.module().weight.tensor().data_vec().unwrap();
1265                    if rank == 0 {
1266                        local[0] += 10.0;
1267                        local[1] += 10.0;
1268                    } else {
1269                        local[2] += 20.0;
1270                        local[3] += 20.0;
1271                    }
1272                    let new_param =
1273                        Tensor::from_storage(TensorStorage::cpu(local), vec![4], true).unwrap();
1274                    *fsdp.module.parameters_mut()[0] = Parameter::new(new_param);
1275
1276                    // Re-sync.
1277                    fsdp.broadcast_updated_params().unwrap();
1278
1279                    fsdp.module().weight.tensor().data_vec().unwrap()
1280                })
1281            })
1282            .collect();
1283
1284        for h in handles {
1285            let data = h.join().unwrap();
1286            assert_eq!(data, &[11.0, 12.0, 23.0, 24.0]);
1287        }
1288    }
1289
1290    #[test]
1291    fn test_fsdp_no_shard_is_ddp_equivalent() {
1292        // NoShard (ZeRO-0 / DDP): each rank has the full parameter and
1293        // allreduce-averages gradients. Param [1,2,3,4]; both ranks set
1294        // identical grads [1,2,3,4]; after sync both ranks should see the
1295        // same averaged grad [1,2,3,4] (identity since both contributions
1296        // are equal).
1297        let group = SimulatedBackend::create_group(2).unwrap();
1298        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1299
1300        let handles: Vec<_> = arcs
1301            .iter()
1302            .cloned()
1303            .map(|b| {
1304                thread::spawn(move || {
1305                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1306                    let mut fsdp =
1307                        FSDP::new_with_strategy(model, b, ShardingStrategy::NoShard).unwrap();
1308                    assert_eq!(fsdp.strategy(), ShardingStrategy::NoShard);
1309
1310                    // Params stay full.
1311                    assert_eq!(fsdp.module().weight.tensor().numel(), 4);
1312
1313                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1314                    let _output = fsdp.forward(&input).unwrap();
1315
1316                    let grad = Tensor::from_storage(
1317                        TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1318                        vec![4],
1319                        false,
1320                    )
1321                    .unwrap();
1322                    fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1323
1324                    fsdp.sync_gradients().unwrap();
1325
1326                    let w = fsdp.module().weight.tensor();
1327                    assert_eq!(w.numel(), 4, "NoShard keeps params full");
1328                    w.grad().unwrap().unwrap().data_vec().unwrap()
1329                })
1330            })
1331            .collect();
1332
1333        for h in handles {
1334            let gd = h.join().unwrap();
1335            assert_eq!(gd.len(), 4);
1336            // Mean of two identical [1,2,3,4] vectors is [1,2,3,4].
1337            for (i, expected) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
1338                assert!(
1339                    (gd[i] - expected).abs() < 1e-6,
1340                    "NoShard allreduce: got {} at {}, expected {}",
1341                    gd[i],
1342                    i,
1343                    expected
1344                );
1345            }
1346        }
1347    }
1348
1349    #[test]
1350    fn test_fsdp_prefetched_forward_matches_sync_forward() {
1351        // Prefetch then forward should produce exactly the same output
1352        // as a plain synchronous forward. 2 ranks, weight [1,2,3,4],
1353        // input [2.0] -> expected output [20.0].
1354        let group = SimulatedBackend::create_group(2).unwrap();
1355        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1356
1357        let handles: Vec<_> = arcs
1358            .iter()
1359            .cloned()
1360            .map(|b| {
1361                thread::spawn(move || {
1362                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1363                    let mut fsdp = FSDP::new(model, b).unwrap();
1364
1365                    // Kick off prefetch BEFORE doing other work.
1366                    assert!(!fsdp.has_pending_prefetch());
1367                    fsdp.prefetch_forward_params().unwrap();
1368                    assert!(fsdp.has_pending_prefetch());
1369
1370                    // Simulated "local compute" between prefetch and forward.
1371                    let _scratch: f32 = (0..100).map(|i| i as f32).sum();
1372
1373                    let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1374                    let output = fsdp.forward(&input).unwrap();
1375                    // After forward, the prefetch handles have been consumed.
1376                    assert!(!fsdp.has_pending_prefetch());
1377
1378                    output.data_vec().unwrap()[0]
1379                })
1380            })
1381            .collect();
1382
1383        for h in handles {
1384            let v = h.join().unwrap();
1385            assert!((v - 20.0).abs() < 1e-6, "expected 20.0, got {v}");
1386        }
1387    }
1388
1389    #[test]
1390    fn test_fsdp_forward_without_prefetch_still_works() {
1391        // Smoke test: forward() without a prior prefetch should use the
1392        // synchronous path and produce the same result.
1393        let group = SimulatedBackend::create_group(2).unwrap();
1394        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1395
1396        let handles: Vec<_> = arcs
1397            .iter()
1398            .cloned()
1399            .map(|b| {
1400                thread::spawn(move || {
1401                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1402                    let mut fsdp = FSDP::new(model, b).unwrap();
1403                    assert!(!fsdp.has_pending_prefetch());
1404                    let input = ferrotorch_core::from_slice(&[2.0f32], &[1]).unwrap();
1405                    let output = fsdp.forward(&input).unwrap();
1406                    output.data_vec().unwrap()[0]
1407                })
1408            })
1409            .collect();
1410
1411        for h in handles {
1412            assert!((h.join().unwrap() - 20.0).abs() < 1e-6);
1413        }
1414    }
1415
1416    #[test]
1417    fn test_fsdp_prefetch_rejects_double_call() {
1418        // Calling prefetch twice without an intervening forward should
1419        // return an error (stale pending handles would be leaked).
1420        let group = SimulatedBackend::create_group(1).unwrap();
1421        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1422        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1423        let mut fsdp = FSDP::new(model, b).unwrap();
1424        fsdp.prefetch_forward_params().unwrap();
1425        let r = fsdp.prefetch_forward_params();
1426        assert!(r.is_err());
1427        let err = format!("{}", r.unwrap_err());
1428        assert!(err.contains("called twice"), "err = {err}");
1429
1430        // Consume the pending handles so the test doesn't leak a live
1431        // background thread that waits on channel traffic. Running
1432        // forward() drains and joins them cleanly.
1433        let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1434        let _ = fsdp.forward(&input).unwrap();
1435    }
1436
1437    #[test]
1438    fn test_fsdp_prefetch_rejects_non_fullshard() {
1439        // Prefetch only makes sense for FullShard (the other strategies
1440        // don't do a parameter all_gather). Calling it on ShardGradOp
1441        // must return a clear error.
1442        let group = SimulatedBackend::create_group(1).unwrap();
1443        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1444        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1445        let mut fsdp = FSDP::new_with_strategy(model, b, ShardingStrategy::ShardGradOp).unwrap();
1446        let r = fsdp.prefetch_forward_params();
1447        assert!(r.is_err());
1448        let err = format!("{}", r.unwrap_err());
1449        assert!(err.contains("FullShard"), "err = {err}");
1450    }
1451
1452    #[test]
1453    fn test_fsdp_no_shard_broadcast_is_noop() {
1454        // broadcast_updated_params should be a no-op for NoShard and
1455        // FullShard strategies.
1456        let group = SimulatedBackend::create_group(1).unwrap();
1457        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
1458        let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1459        let mut fsdp = FSDP::new_with_strategy(model, b, ShardingStrategy::NoShard).unwrap();
1460        fsdp.broadcast_updated_params().unwrap();
1461        assert_eq!(
1462            fsdp.module().weight.tensor().data_vec().unwrap(),
1463            &[1.0, 2.0, 3.0, 4.0]
1464        );
1465    }
1466
1467    #[test]
1468    fn test_fsdp_sync_gradients_multi_rank() {
1469        // 2 ranks, param size 4 -> shard size 2.
1470        // Both ranks set identical gradients on full_params: [1, 2, 3, 4].
1471        // reduce_scatter(mean) on [1,2,3,4] -> rank 0 gets [1,2], rank 1 gets [3,4].
1472        let group = SimulatedBackend::create_group(2).unwrap();
1473        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1474
1475        let handles: Vec<_> = arcs
1476            .iter()
1477            .cloned()
1478            .map(|b| {
1479                thread::spawn(move || {
1480                    let rank = b.rank();
1481                    let model = TestModule::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).unwrap();
1482                    let mut fsdp = FSDP::new(model, b).unwrap();
1483
1484                    // Run forward.
1485                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1486                    let _output = fsdp.forward(&input).unwrap();
1487
1488                    // Set gradient on full_params.
1489                    let grad = Tensor::from_storage(
1490                        TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1491                        vec![4],
1492                        false,
1493                    )
1494                    .unwrap();
1495                    fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1496
1497                    fsdp.sync_gradients().unwrap();
1498
1499                    let shard_grad = fsdp.module().weight.tensor().grad().unwrap().unwrap();
1500                    let data = shard_grad.data_vec().unwrap();
1501                    (rank, data)
1502                })
1503            })
1504            .collect();
1505
1506        for h in handles {
1507            let (rank, data) = h.join().unwrap();
1508            if rank == 0 {
1509                // Mean of [1,2] from both ranks = [1,2].
1510                assert_eq!(data.len(), 2);
1511                assert!(
1512                    (data[0] - 1.0).abs() < 1e-6,
1513                    "rank 0: expected 1.0, got {}",
1514                    data[0]
1515                );
1516                assert!(
1517                    (data[1] - 2.0).abs() < 1e-6,
1518                    "rank 0: expected 2.0, got {}",
1519                    data[1]
1520                );
1521            } else {
1522                // Mean of [3,4] from both ranks = [3,4].
1523                assert_eq!(data.len(), 2);
1524                assert!(
1525                    (data[0] - 3.0).abs() < 1e-6,
1526                    "rank 1: expected 3.0, got {}",
1527                    data[0]
1528                );
1529                assert!(
1530                    (data[1] - 4.0).abs() < 1e-6,
1531                    "rank 1: expected 4.0, got {}",
1532                    data[1]
1533                );
1534            }
1535        }
1536    }
1537
1538    // -----------------------------------------------------------------------
1539    // HybridShard tests. CL-327
1540    // -----------------------------------------------------------------------
1541
1542    #[test]
1543    fn test_fsdp_hybrid_shard_rejects_uneven_world_size() {
1544        // world_size=4, intra_node_size=3 is not a divisor → error.
1545        let group = SimulatedBackend::create_group(4).unwrap();
1546        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1547
1548        let b = arcs[0].clone();
1549        let model = TestModule::<f32>::new(&[1.0f32; 8]).unwrap();
1550        let result = FSDP::new_with_strategy(
1551            model,
1552            b,
1553            ShardingStrategy::HybridShard { intra_node_size: 3 },
1554        );
1555        assert!(result.is_err(), "expected uneven intra_node_size to fail");
1556    }
1557
1558    #[test]
1559    fn test_fsdp_hybrid_shard_intra_node_sharding() {
1560        // world_size=4, intra_node_size=2:
1561        //   node 0 = {rank 0, rank 1}
1562        //   node 1 = {rank 2, rank 3}
1563        // Each rank shards a param of size 4 into chunks of size 4/2=2.
1564        //   rank 0 (intra 0, node 0) -> [weight[0], weight[1]]
1565        //   rank 1 (intra 1, node 0) -> [weight[2], weight[3]]
1566        //   rank 2 (intra 0, node 1) -> [weight[0], weight[1]]
1567        //   rank 3 (intra 1, node 1) -> [weight[2], weight[3]]
1568        let group = SimulatedBackend::create_group(4).unwrap();
1569        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1570
1571        let handles: Vec<_> = arcs
1572            .iter()
1573            .cloned()
1574            .map(|b| {
1575                thread::spawn(move || {
1576                    let rank = b.rank();
1577                    let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1578                    let fsdp = FSDP::new_with_strategy(
1579                        model,
1580                        b,
1581                        ShardingStrategy::HybridShard { intra_node_size: 2 },
1582                    )
1583                    .unwrap();
1584                    let shard = fsdp.module().weight.tensor().data_vec().unwrap();
1585                    (rank, shard)
1586                })
1587            })
1588            .collect();
1589
1590        for h in handles {
1591            let (rank, shard) = h.join().unwrap();
1592            assert_eq!(shard.len(), 2, "each rank gets 4/2 = 2 elements");
1593            let expected: &[f32] = if rank % 2 == 0 {
1594                &[10.0, 20.0]
1595            } else {
1596                &[30.0, 40.0]
1597            };
1598            assert_eq!(shard, expected, "rank {} shard mismatch", rank);
1599        }
1600    }
1601
1602    #[test]
1603    fn test_fsdp_hybrid_shard_sync_gradients() {
1604        // world_size=4, intra_node_size=2:
1605        //   node 0 = {rank 0, rank 1}, node 1 = {rank 2, rank 3}.
1606        // Each rank produces the same full grad [1,2,3,4]. After:
1607        //   intra-node reduce_scatter(mean) → each intra-rank gets its shard
1608        //     of the node-local mean gradient: rank 0 gets [1,2], rank 1
1609        //     gets [3,4], rank 2 gets [1,2], rank 3 gets [3,4].
1610        //   inter-node allreduce(mean) across replicas of the same intra
1611        //     position: replicas produce identical shards, so the mean is
1612        //     unchanged.
1613        // Final expected shard gradients:
1614        //   ranks 0, 2 → [1, 2]
1615        //   ranks 1, 3 → [3, 4]
1616        let group = SimulatedBackend::create_group(4).unwrap();
1617        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1618
1619        let handles: Vec<_> = arcs
1620            .iter()
1621            .cloned()
1622            .map(|b| {
1623                thread::spawn(move || {
1624                    let rank = b.rank();
1625                    let model = TestModule::<f32>::new(&[10.0, 20.0, 30.0, 40.0]).unwrap();
1626                    let mut fsdp = FSDP::new_with_strategy(
1627                        model,
1628                        b,
1629                        ShardingStrategy::HybridShard { intra_node_size: 2 },
1630                    )
1631                    .unwrap();
1632
1633                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1634                    let _output = fsdp.forward(&input).unwrap();
1635
1636                    // After forward, the module param is back to a shard
1637                    // (intra-node) but full_params[0] holds the gathered
1638                    // full tensor. Attach a synthetic full-shape gradient.
1639                    let grad = Tensor::from_storage(
1640                        TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
1641                        vec![4],
1642                        false,
1643                    )
1644                    .unwrap();
1645                    fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1646
1647                    fsdp.sync_gradients().unwrap();
1648
1649                    let w = fsdp.module().weight.tensor();
1650                    assert_eq!(w.numel(), 2, "HybridShard keeps params as intra-shards");
1651                    let g = w.grad().unwrap().unwrap();
1652                    let gd = g.data_vec().unwrap();
1653                    (rank, gd)
1654                })
1655            })
1656            .collect();
1657
1658        for h in handles {
1659            let (rank, gd) = h.join().unwrap();
1660            assert_eq!(gd.len(), 2, "shard grad should have 2 elements");
1661            let expected: &[f32] = if rank % 2 == 0 {
1662                &[1.0, 2.0]
1663            } else {
1664                &[3.0, 4.0]
1665            };
1666            for (i, e) in expected.iter().enumerate() {
1667                assert!(
1668                    (gd[i] - e).abs() < 1e-6,
1669                    "rank {} [{}]: expected {}, got {}",
1670                    rank,
1671                    i,
1672                    e,
1673                    gd[i]
1674                );
1675            }
1676        }
1677    }
1678
1679    #[test]
1680    fn test_fsdp_hybrid_shard_inter_node_averaging() {
1681        // Construct a scenario where the two nodes disagree on grads, so
1682        // the inter-node allreduce has meaningful work:
1683        //   node 0 (ranks 0,1) full grad = [2,4,6,8]
1684        //   node 1 (ranks 2,3) full grad = [10,20,30,40]
1685        // Node means after intra reduce_scatter:
1686        //   rank 0 [2,4], rank 1 [6,8], rank 2 [10,20], rank 3 [30,40].
1687        // After inter-node allreduce(mean):
1688        //   rank 0 (intra 0) = mean(rank 0's shard, rank 2's shard) = [6,12]
1689        //   rank 1 (intra 1) = mean(rank 1's shard, rank 3's shard) = [18,24]
1690        //   rank 2 (intra 0) = mean(rank 2's shard, rank 0's shard) = [6,12]
1691        //   rank 3 (intra 1) = mean(rank 3's shard, rank 1's shard) = [18,24]
1692        let group = SimulatedBackend::create_group(4).unwrap();
1693        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
1694
1695        let handles: Vec<_> = arcs
1696            .iter()
1697            .cloned()
1698            .map(|b| {
1699                thread::spawn(move || {
1700                    let rank = b.rank();
1701                    let model = TestModule::<f32>::new(&[0.0f32; 4]).unwrap();
1702                    let mut fsdp = FSDP::new_with_strategy(
1703                        model,
1704                        b,
1705                        ShardingStrategy::HybridShard { intra_node_size: 2 },
1706                    )
1707                    .unwrap();
1708
1709                    let input = ferrotorch_core::from_slice(&[1.0f32], &[1]).unwrap();
1710                    let _ = fsdp.forward(&input).unwrap();
1711
1712                    // Node-specific gradient: node 0 ranks 0/1 get
1713                    // [2,4,6,8], node 1 ranks 2/3 get [10,20,30,40].
1714                    let grad_vec: Vec<f32> = if rank < 2 {
1715                        vec![2.0, 4.0, 6.0, 8.0]
1716                    } else {
1717                        vec![10.0, 20.0, 30.0, 40.0]
1718                    };
1719                    let grad =
1720                        Tensor::from_storage(TensorStorage::cpu(grad_vec), vec![4], false).unwrap();
1721                    fsdp.full_params[0].set_grad(Some(grad)).unwrap();
1722
1723                    fsdp.sync_gradients().unwrap();
1724
1725                    let gd = fsdp
1726                        .module()
1727                        .weight
1728                        .tensor()
1729                        .grad()
1730                        .unwrap()
1731                        .unwrap()
1732                        .data_vec()
1733                        .unwrap();
1734                    (rank, gd)
1735                })
1736            })
1737            .collect();
1738
1739        for h in handles {
1740            let (rank, gd) = h.join().unwrap();
1741            assert_eq!(gd.len(), 2);
1742            // intra-position 0 -> [6, 12]; intra-position 1 -> [18, 24]
1743            let expected: &[f32] = if rank % 2 == 0 {
1744                &[6.0, 12.0]
1745            } else {
1746                &[18.0, 24.0]
1747            };
1748            for (i, e) in expected.iter().enumerate() {
1749                assert!(
1750                    (gd[i] - e).abs() < 1e-4,
1751                    "hybrid inter-node mean rank {} [{}]: expected {}, got {}",
1752                    rank,
1753                    i,
1754                    e,
1755                    gd[i]
1756                );
1757            }
1758        }
1759    }
1760}