use ferrotorch_core::dtype::Float;
use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::tensor::Tensor;
use crate::collective::ReduceOp;
use crate::device_mesh::DeviceMesh;
/// How a [`DTensor`] is laid out along one dimension of its [`DeviceMesh`].
///
/// A `DTensor` carries exactly one `Placement` per mesh dimension.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Placement {
    /// Presumably every device along this mesh dimension holds the whole tensor.
    Replicate,
    /// The tensor is split along the given tensor dimension across this mesh
    /// dimension. The index must be smaller than the global rank.
    Shard(usize),
    /// Local values are (assumed to be) partial results awaiting the given
    /// reduction; no reduction is performed by this module itself.
    Partial(ReduceOp),
}

impl Placement {
    /// Returns `true` only for [`Placement::Replicate`].
    pub fn is_replicate(&self) -> bool {
        *self == Placement::Replicate
    }

    /// Returns `true` for any [`Placement::Shard`], regardless of its dimension.
    pub fn is_shard(&self) -> bool {
        self.shard_dim().is_some()
    }

    /// Returns `true` for any [`Placement::Partial`], regardless of its reduce op.
    pub fn is_partial(&self) -> bool {
        matches!(*self, Self::Partial(_))
    }

    /// The tensor dimension being sharded, or `None` when this placement is
    /// not a shard.
    pub fn shard_dim(&self) -> Option<usize> {
        if let Self::Shard(dim) = *self {
            Some(dim)
        } else {
            None
        }
    }
}
/// A tensor distributed over a [`DeviceMesh`].
///
/// Pairs the locally held data with the metadata (one placement per mesh
/// dimension, plus the logical global shape) describing how that local piece
/// relates to the full tensor.
#[derive(Debug, Clone)]
pub struct DTensor<T: Float> {
    /// The locally held data. Its shape is not required to equal
    /// `global_shape` (e.g. when some placement is `Shard`).
    local_tensor: Tensor<T>,
    /// One placement per mesh dimension; the constructors and `redistribute`
    /// enforce `placements.len() == mesh.ndim()`.
    placements: Vec<Placement>,
    /// Shape of the logical (undistributed) tensor, as reported by `shape()`.
    global_shape: Vec<usize>,
    /// The mesh this tensor is distributed over.
    mesh: DeviceMesh,
}
impl<T: Float> DTensor<T> {
    /// Wraps an already-distributed local tensor in a `DTensor`.
    ///
    /// `placements` must contain exactly one entry per mesh dimension, and
    /// every `Placement::Shard(d)` must name a dimension of `global_shape`.
    /// The local tensor's shape is *not* checked against `global_shape`; the
    /// caller is responsible for supplying a consistent local piece.
    ///
    /// # Errors
    ///
    /// * [`FerrotorchError::ShapeMismatch`] if `placements.len() != mesh.ndim()`.
    /// * [`FerrotorchError::InvalidArgument`] if some `Shard(d)` has
    ///   `d >= global_shape.len()`.
    pub fn from_local(
        local_tensor: Tensor<T>,
        mesh: DeviceMesh,
        placements: Vec<Placement>,
        global_shape: Vec<usize>,
    ) -> FerrotorchResult<Self> {
        Self::validate_placements(
            "DTensor::from_local",
            "placements",
            "shards",
            &placements,
            mesh.ndim(),
            global_shape.len(),
        )?;
        Ok(Self {
            local_tensor,
            placements,
            global_shape,
            mesh,
        })
    }

    /// Wraps `local` as a tensor replicated on every mesh dimension.
    ///
    /// The global shape is taken from `local.shape()`.
    ///
    /// # Errors
    ///
    /// Same as [`DTensor::from_local`]. With all-`Replicate` placements of
    /// length `mesh.ndim()`, validation cannot fail.
    pub fn from_local_replicated(local: Tensor<T>, mesh: DeviceMesh) -> FerrotorchResult<Self> {
        let global = local.shape().to_vec();
        let placements = vec![Placement::Replicate; mesh.ndim()];
        Self::from_local(local, mesh, placements, global)
    }

    /// Borrows the locally held tensor (not the global tensor).
    pub fn to_local(&self) -> &Tensor<T> {
        &self.local_tensor
    }

    /// The global (logical) shape, which may differ from the local tensor's shape.
    pub fn shape(&self) -> &[usize] {
        &self.global_shape
    }

    /// One placement per mesh dimension.
    pub fn placements(&self) -> &[Placement] {
        &self.placements
    }

    /// The mesh this tensor is distributed over.
    pub fn mesh(&self) -> &DeviceMesh {
        &self.mesh
    }

    /// Number of elements in the *global* tensor, i.e. the product of `shape()`.
    ///
    /// A 0-dimensional (scalar) shape yields 1, and any zero-length dimension
    /// yields 0.
    pub fn numel(&self) -> usize {
        // The empty product is already 1, so scalars need no special case.
        // Do not clamp to 1: a zero-length dimension means zero elements.
        self.global_shape.iter().product()
    }

    /// Replaces the placement metadata with `target_placements`.
    ///
    /// Only metadata changes: `local_tensor` is left untouched and no data is
    /// exchanged between devices. On error, `self` is not modified.
    /// NOTE(review): no collective is issued here; confirm callers move the
    /// local data themselves so it matches the new placements.
    ///
    /// # Errors
    ///
    /// * [`FerrotorchError::ShapeMismatch`] if
    ///   `target_placements.len() != self.mesh().ndim()`.
    /// * [`FerrotorchError::InvalidArgument`] if some `Shard(d)` has
    ///   `d >= self.shape().len()`.
    pub fn redistribute(&mut self, target_placements: Vec<Placement>) -> FerrotorchResult<()> {
        Self::validate_placements(
            "DTensor::redistribute",
            "target",
            "target shards",
            &target_placements,
            self.mesh.ndim(),
            self.global_shape.len(),
        )?;
        self.placements = target_placements;
        Ok(())
    }

    /// Shared validation for `from_local` and `redistribute`.
    ///
    /// Checks that there is one placement per mesh dimension and that every
    /// `Shard(d)` names an existing tensor dimension (`d < tensor_ndim`).
    /// `op`, `len_name` and `shard_verb` only shape the error text, so each
    /// caller keeps its own wording.
    fn validate_placements(
        op: &str,
        len_name: &str,
        shard_verb: &str,
        placements: &[Placement],
        mesh_ndim: usize,
        tensor_ndim: usize,
    ) -> FerrotorchResult<()> {
        if placements.len() != mesh_ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{op}: {len_name}.len()={} != mesh.ndim()={mesh_ndim}",
                    placements.len()
                ),
            });
        }
        // Report the first mesh dimension whose shard index is out of range.
        let first_bad = placements
            .iter()
            .enumerate()
            .find_map(|(mi, p)| match *p {
                Placement::Shard(d) if d >= tensor_ndim => Some((mi, d)),
                _ => None,
            });
        if let Some((mi, d)) = first_bad {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "{op}: mesh dim {mi} {shard_verb} tensor dim {d} \
                     but global_shape.len()={tensor_ndim}"
                ),
            });
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU-backed `f32` tensor with gradient tracking disabled.
    fn cpu_tensor(values: Vec<f32>, dims: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(values);
        Tensor::from_storage(storage, dims, false).unwrap()
    }

    #[test]
    fn placement_predicates() {
        let replicate = Placement::Replicate;
        let shard0 = Placement::Shard(0);
        let partial_sum = Placement::Partial(ReduceOp::Sum);

        assert!(replicate.is_replicate());
        assert!(shard0.is_shard());
        assert!(partial_sum.is_partial());
        assert_eq!(Placement::Shard(2).shard_dim(), Some(2));
        assert_eq!(replicate.shard_dim(), None);
    }

    #[test]
    fn from_local_replicated_uses_local_shape() {
        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
        let local = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let dt = DTensor::from_local_replicated(local, mesh).unwrap();

        assert_eq!(dt.shape(), &[2, 2]);
        let placements = dt.placements();
        assert_eq!(placements.len(), 2);
        assert!(placements.iter().all(Placement::is_replicate));
    }

    #[test]
    fn from_local_rejects_placement_count_mismatch() {
        let mesh = DeviceMesh::new(vec![4], 4).unwrap();
        let local = cpu_tensor(vec![0.0; 4], vec![4]);
        // Two placements for a one-dimensional mesh.
        let placements = vec![Placement::Replicate; 2];

        let result = DTensor::from_local(local, mesh, placements, vec![4]);

        assert!(matches!(result, Err(FerrotorchError::ShapeMismatch { .. })));
    }

    #[test]
    fn from_local_rejects_oob_shard_dim() {
        let mesh = DeviceMesh::new(vec![4], 4).unwrap();
        let local = cpu_tensor(vec![0.0; 4], vec![4]);
        // Tensor dim 2 does not exist in a rank-1 global shape.
        let placements = vec![Placement::Shard(2)];

        let result = DTensor::from_local(local, mesh, placements, vec![16]);

        assert!(matches!(result, Err(FerrotorchError::InvalidArgument { .. })));
    }

    #[test]
    fn redistribute_updates_placements() {
        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
        let local = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let mut dt =
            DTensor::from_local(local, mesh, vec![Placement::Shard(0)], vec![4, 2]).unwrap();
        assert_eq!(dt.placements()[0], Placement::Shard(0));

        dt.redistribute(vec![Placement::Replicate]).unwrap();

        assert!(dt.placements()[0].is_replicate());
    }

    #[test]
    fn redistribute_rejects_target_count_mismatch() {
        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
        let local = cpu_tensor(vec![1.0; 4], vec![2, 2]);
        let mut dt = DTensor::from_local_replicated(local, mesh).unwrap();

        // Only one target placement for a two-dimensional mesh.
        let outcome = dt.redistribute(vec![Placement::Replicate]);

        assert!(matches!(outcome, Err(FerrotorchError::ShapeMismatch { .. })));
    }

    #[test]
    fn redistribute_rejects_oob_shard() {
        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
        let local = cpu_tensor(vec![1.0; 4], vec![4]);
        let mut dt = DTensor::from_local_replicated(local, mesh).unwrap();

        let outcome = dt.redistribute(vec![Placement::Shard(5)]);

        assert!(matches!(outcome, Err(FerrotorchError::InvalidArgument { .. })));
    }

    #[test]
    fn numel_uses_global_shape() {
        let mesh = DeviceMesh::new(vec![2], 2).unwrap();
        let local = cpu_tensor(vec![1.0; 4], vec![2, 2]);

        let dt = DTensor::from_local(local, mesh, vec![Placement::Shard(0)], vec![4, 2]).unwrap();

        assert_eq!(dt.numel(), 8);
        assert_eq!(dt.to_local().numel(), 4);
    }
}