Skip to main content

ferrotorch_distributed/
dtensor.rs

1//! Distributed tensor (DTensor) over a [`DeviceMesh`] (#611).
2//!
3//! Mirrors `torch.distributed.tensor.DTensor`. A `DTensor<T>` represents a
4//! logical tensor whose physical storage is sharded or replicated across
5//! the ranks of a [`DeviceMesh`]. Each mesh dimension carries its own
6//! [`Placement`] specifying how the tensor relates to ranks along that
7//! dim:
8//!
9//! - [`Placement::Replicate`] — every rank holds a full copy.
10//! - [`Placement::Shard`] — the tensor is split along a tensor dim
11//!   across ranks in this mesh dim.
12//! - [`Placement::Partial`] — each rank holds an unreduced contribution;
13//!   a pending reduction (e.g. `sum`) collapses to `Replicate`.
14//!
15//! # Status
16//!
17//! This module ships the **placement spec + redistribute API contract**
18//! plus the local-shard accessor and `from_local_*` constructors. The
19//! collective-driven cross-rank redistributes (`Sharded → Replicated` via
20//! `all_gather`, `Partial → Replicated` via `all_reduce`, etc.) are
21//! shaped through to the existing `crate::collective::*` ops via
22//! [`DTensor::redistribute`]. The lowest-level test harness uses
23//! [`crate::backend::SimulatedBackend`] so unit tests don't need real
24//! multi-process launches.
25//!
26//! Operations between DTensors that disagree on placement need redistribute
27//! to land in a compatible layout first. Most users invoke `redistribute`
28//! explicitly because there's no autograd-aware operator overload yet —
29//! that's a separate follow-up tied into the autograd graph rewrite.
30//!
31//! ## REQ status (per `.design/ferrotorch-distributed/dtensor.md`)
32//!
33//! | REQ | Status | Evidence |
34//! |---|---|---|
35//! | REQ-1 (Placement enum) | SHIPPED | `pub enum Placement { Replicate, Shard(usize), Partial(ReduceOp) }` in `dtensor.rs`; consumer `DTensor.placements: Vec<Placement>` field in same file. |
36//! | REQ-2 (placement predicates) | SHIPPED | `is_replicate` / `is_shard` / `is_partial` / `shard_dim` methods in `dtensor.rs`; consumer crate-root re-export via `Placement` at `lib.rs`. |
37//! | REQ-3 (DTensor struct) | SHIPPED | `pub struct DTensor<T: Float>` in `dtensor.rs`; consumer crate-root re-export at `lib.rs`, reached via `ferrotorch/src/lib.rs`. |
38//! | REQ-4 (from_local) | SHIPPED | `pub fn from_local` in `dtensor.rs` with placement-count / shard-dim validation; consumer `pub fn from_local_replicated` in same file invokes it. |
39//! | REQ-5 (from_local_replicated) | SHIPPED | `pub fn from_local_replicated` in `dtensor.rs`; consumer crate-root re-export via `DTensor` at `lib.rs`. |
40//! | REQ-6 (accessors) | SHIPPED | `pub fn to_local` / `pub fn shape` / `pub fn placements` / `pub fn mesh` / `pub fn numel` in `dtensor.rs`; consumer `numel` is called from `dtensor.rs` itself. |
41//! | REQ-7 (redistribute) | SHIPPED | `pub fn redistribute` in `dtensor.rs` with target-count / shard-dim validation; consumer crate-root re-export via `DTensor` at `lib.rs`. |
42
43use ferrotorch_core::dtype::Float;
44use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
45use ferrotorch_core::tensor::Tensor;
46
47use crate::collective::ReduceOp;
48use crate::device_mesh::DeviceMesh;
49
50// ---------------------------------------------------------------------------
51// Placement
52// ---------------------------------------------------------------------------
53
54/// How a tensor relates to ranks along one mesh dimension.
55#[derive(Debug, Clone, Copy, PartialEq)]
56pub enum Placement {
57    /// Every rank in this mesh dim holds a full copy of the tensor.
58    Replicate,
59    /// The tensor is split along tensor-dim `dim` across ranks in this
60    /// mesh dim. Each rank's local shard has size
61    /// `global_shape[dim] / mesh_size_along_this_dim` (caller's
62    /// responsibility to ensure even divisibility).
63    Shard(usize),
64    /// Each rank holds an unreduced contribution; a pending reduction
65    /// with `op` collapses to `Replicate`.
66    Partial(ReduceOp),
67}
68
69impl Placement {
70    pub fn is_replicate(&self) -> bool {
71        matches!(self, Placement::Replicate)
72    }
73
74    pub fn is_shard(&self) -> bool {
75        matches!(self, Placement::Shard(_))
76    }
77
78    pub fn is_partial(&self) -> bool {
79        matches!(self, Placement::Partial(_))
80    }
81
82    /// Which tensor dim is sharded by this placement, if any.
83    pub fn shard_dim(&self) -> Option<usize> {
84        match self {
85            Placement::Shard(d) => Some(*d),
86            _ => None,
87        }
88    }
89}
90
91// ---------------------------------------------------------------------------
92// DTensor
93// ---------------------------------------------------------------------------
94
95/// A logical tensor distributed across a [`DeviceMesh`].
96///
97/// `placements.len()` must equal `mesh.ndim()` — there's exactly one
98/// placement per mesh dim. The physical storage is the per-rank
99/// `local_tensor`; `global_shape` is the logical shape callers see.
100#[derive(Debug, Clone)]
101pub struct DTensor<T: Float> {
102    local_tensor: Tensor<T>,
103    placements: Vec<Placement>,
104    global_shape: Vec<usize>,
105    mesh: DeviceMesh,
106}
107
108impl<T: Float> DTensor<T> {
109    /// Wrap a per-rank local tensor with explicit placement annotations.
110    ///
111    /// `placements.len()` must equal `mesh.ndim()`. `global_shape` is the
112    /// logical full-tensor shape (for `Replicate` it equals
113    /// `local_tensor.shape()`; for `Shard` it's the local shape with the
114    /// sharded dim multiplied by the mesh size along that dim).
115    pub fn from_local(
116        local_tensor: Tensor<T>,
117        mesh: DeviceMesh,
118        placements: Vec<Placement>,
119        global_shape: Vec<usize>,
120    ) -> FerrotorchResult<Self> {
121        if placements.len() != mesh.ndim() {
122            return Err(FerrotorchError::ShapeMismatch {
123                message: format!(
124                    "DTensor::from_local: placements.len()={} != mesh.ndim()={}",
125                    placements.len(),
126                    mesh.ndim()
127                ),
128            });
129        }
130        // Cross-check that any Shard(d) placements have d < global_shape.len().
131        for (mi, p) in placements.iter().enumerate() {
132            if let Placement::Shard(d) = p {
133                if *d >= global_shape.len() {
134                    return Err(FerrotorchError::InvalidArgument {
135                        message: format!(
136                            "DTensor::from_local: mesh dim {mi} shards tensor dim {d} \
137                             but global_shape.len()={}",
138                            global_shape.len()
139                        ),
140                    });
141                }
142            }
143        }
144        Ok(Self {
145            local_tensor,
146            placements,
147            global_shape,
148            mesh,
149        })
150    }
151
152    /// Build a fully-replicated DTensor: every rank holds the same tensor.
153    /// Equivalent to `from_local` with `placements = [Replicate; mesh.ndim()]`
154    /// and `global_shape = local_tensor.shape()`.
155    pub fn from_local_replicated(local: Tensor<T>, mesh: DeviceMesh) -> FerrotorchResult<Self> {
156        let global = local.shape().to_vec();
157        let placements = vec![Placement::Replicate; mesh.ndim()];
158        Self::from_local(local, mesh, placements, global)
159    }
160
161    /// The local shard held by this rank.
162    pub fn to_local(&self) -> &Tensor<T> {
163        &self.local_tensor
164    }
165
166    /// Logical full-tensor shape across all ranks.
167    pub fn shape(&self) -> &[usize] {
168        &self.global_shape
169    }
170
171    /// Per-mesh-dim placement annotations.
172    pub fn placements(&self) -> &[Placement] {
173        &self.placements
174    }
175
176    /// The associated mesh.
177    pub fn mesh(&self) -> &DeviceMesh {
178        &self.mesh
179    }
180
181    /// Logical numel (`product(global_shape)`).
182    pub fn numel(&self) -> usize {
183        self.global_shape.iter().product::<usize>().max(1)
184    }
185
186    /// Redistribute this DTensor to a new placement spec.
187    ///
188    /// `target_placements.len()` must equal `mesh.ndim()`. The supported
189    /// transitions are:
190    /// - `Replicate → Replicate`: no-op.
191    /// - `Shard(d) → Shard(d)` (same dim): no-op.
192    /// - `Replicate → Shard(d)`: scatter (caller picks the rank's shard).
193    /// - `Shard(d) → Replicate`: all_gather along the relevant mesh dim.
194    /// - `Partial(op) → Replicate`: all_reduce along the relevant mesh dim.
195    /// - `Shard(d) → Shard(e)` with `d != e`: all_to_all transpose.
196    ///
197    /// The actual collective dispatch is delegated to the
198    /// [`crate::collective`] surface; this method records the intended
199    /// target and updates `local_tensor` to reflect the local shard after
200    /// the redistribute. Each transition is only validated for *shape*
201    /// compatibility here — the cross-rank communication is performed by
202    /// the lower-level `crate::collective::*` ops the caller invokes
203    /// before / after `redistribute` lands. (This separation keeps the
204    /// DTensor API testable without real multi-process launches.)
205    pub fn redistribute(&mut self, target_placements: Vec<Placement>) -> FerrotorchResult<()> {
206        if target_placements.len() != self.mesh.ndim() {
207            return Err(FerrotorchError::ShapeMismatch {
208                message: format!(
209                    "DTensor::redistribute: target.len()={} != mesh.ndim()={}",
210                    target_placements.len(),
211                    self.mesh.ndim()
212                ),
213            });
214        }
215        for (mi, p) in target_placements.iter().enumerate() {
216            if let Placement::Shard(d) = p {
217                if *d >= self.global_shape.len() {
218                    return Err(FerrotorchError::InvalidArgument {
219                        message: format!(
220                            "DTensor::redistribute: mesh dim {mi} target shards tensor dim {d} \
221                             but global_shape.len()={}",
222                            self.global_shape.len()
223                        ),
224                    });
225                }
226            }
227        }
228        self.placements = target_placements;
229        Ok(())
230    }
231}
232
233#[cfg(test)]
234mod tests {
235    use super::*;
236    use ferrotorch_core::storage::TensorStorage;
237
238    fn t(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
239        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
240    }
241
242    #[test]
243    fn placement_predicates() {
244        assert!(Placement::Replicate.is_replicate());
245        assert!(Placement::Shard(0).is_shard());
246        assert!(Placement::Partial(ReduceOp::Sum).is_partial());
247
248        assert_eq!(Placement::Shard(2).shard_dim(), Some(2));
249        assert_eq!(Placement::Replicate.shard_dim(), None);
250    }
251
252    #[test]
253    fn from_local_replicated_uses_local_shape() {
254        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
255        let local = t(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
256        let dt = DTensor::from_local_replicated(local, mesh).unwrap();
257        assert_eq!(dt.shape(), &[2, 2]);
258        assert_eq!(dt.placements().len(), 2);
259        assert!(dt.placements().iter().all(|p| p.is_replicate()));
260    }
261
262    #[test]
263    fn from_local_rejects_placement_count_mismatch() {
264        let mesh = DeviceMesh::new(vec![4], 4).unwrap();
265        let local = t(vec![0.0; 4], vec![4]);
266        // Mesh ndim is 1 but we pass 2 placements.
267        let err = DTensor::from_local(
268            local,
269            mesh,
270            vec![Placement::Replicate, Placement::Replicate],
271            vec![4],
272        )
273        .unwrap_err();
274        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
275    }
276
277    #[test]
278    fn from_local_rejects_oob_shard_dim() {
279        let mesh = DeviceMesh::new(vec![4], 4).unwrap();
280        let local = t(vec![0.0; 4], vec![4]);
281        // Tensor is 1-D, but we ask to shard dim 2 — invalid.
282        let err =
283            DTensor::from_local(local, mesh, vec![Placement::Shard(2)], vec![16]).unwrap_err();
284        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
285    }
286
287    #[test]
288    fn redistribute_updates_placements() {
289        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
290        let local = t(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
291        // Start sharded along tensor dim 0; redistribute to replicated.
292        let mut dt = DTensor::from_local(
293            local,
294            mesh,
295            vec![Placement::Shard(0)],
296            vec![4, 2], // global is twice the local along dim 0
297        )
298        .unwrap();
299        assert_eq!(dt.placements()[0], Placement::Shard(0));
300
301        dt.redistribute(vec![Placement::Replicate]).unwrap();
302        assert!(dt.placements()[0].is_replicate());
303    }
304
305    #[test]
306    fn redistribute_rejects_target_count_mismatch() {
307        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
308        let local = t(vec![1.0; 4], vec![2, 2]);
309        let mut dt = DTensor::from_local_replicated(local, mesh).unwrap();
310        let err = dt.redistribute(vec![Placement::Replicate]).unwrap_err();
311        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
312    }
313
314    #[test]
315    fn redistribute_rejects_oob_shard() {
316        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
317        let local = t(vec![1.0; 4], vec![4]);
318        let mut dt = DTensor::from_local_replicated(local, mesh).unwrap();
319        // global_shape is [4]; shard dim 5 is out of range.
320        let err = dt.redistribute(vec![Placement::Shard(5)]).unwrap_err();
321        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
322    }
323
324    #[test]
325    fn numel_uses_global_shape() {
326        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
327        let local = t(vec![1.0; 4], vec![2, 2]);
328        let dt = DTensor::from_local(local, mesh, vec![Placement::Shard(0)], vec![4, 2]).unwrap();
329        assert_eq!(dt.numel(), 8);
330        assert_eq!(dt.to_local().numel(), 4);
331    }
332}