use crate::data::dataset::{Batch, Dataset};
use crate::error::{Error, Result};
use numr::runtime::Runtime;
/// A view over a wrapped dataset that exposes only the items belonging to one
/// rank of a `world_size`-way split.
///
/// Shard item `i` maps to item `i * world_size + rank` of the wrapped dataset,
/// so ranks receive interleaved (strided) slices that together cover every
/// item exactly once.
#[derive(Debug, Clone)]
pub struct ShardedDataset<D> {
    /// The wrapped, unsharded dataset.
    inner: D,
    /// This shard's rank; always `< world_size` once constructed.
    rank: usize,
    /// Total number of shards; always `> 0` once constructed.
    world_size: usize,
    /// Number of items owned by this rank, computed once from the wrapped
    /// dataset's length at construction time.
    shard_len: usize,
}
impl<D> ShardedDataset<D> {
    /// Borrows the wrapped (unsharded) dataset.
    pub fn inner(&self) -> &D {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the rank this shard was built for.
    pub fn rank(&self) -> usize {
        let Self { rank, .. } = self;
        *rank
    }

    /// Returns the total number of shards the dataset is split into.
    pub fn world_size(&self) -> usize {
        let Self { world_size, .. } = self;
        *world_size
    }
}
impl<D> ShardedDataset<D> {
fn from_parts(inner: D, rank: usize, world_size: usize, total: usize) -> Result<Self> {
if world_size == 0 {
return Err(Error::DataError {
reason: "world_size must be > 0".to_string(),
});
}
if rank >= world_size {
return Err(Error::DataError {
reason: format!("rank {rank} >= world_size {world_size}"),
});
}
let shard_len = total.saturating_sub(rank).div_ceil(world_size);
Ok(Self {
inner,
rank,
world_size,
shard_len,
})
}
}
/// Wraps `inner` so that only the items assigned to `rank` (out of
/// `world_size` shards) are visible.
///
/// The wrapped dataset's length is read once, here. The shard length is
/// fixed from that value.
///
/// # Errors
/// Returns `Error::DataError` when `world_size == 0` or `rank >= world_size`.
pub fn new_sharded<R: Runtime, D: Dataset<R>>(
    inner: D,
    rank: usize,
    world_size: usize,
) -> Result<ShardedDataset<D>> {
    // Measure the length before `inner` is moved into the wrapper.
    let item_count = <D as Dataset<R>>::len(&inner);
    ShardedDataset::from_parts(inner, rank, world_size, item_count)
}
impl<R: Runtime, D: Dataset<R>> Dataset<R> for ShardedDataset<D> {
fn len(&self) -> usize {
self.shard_len
}
fn get(&self, idx: usize, device: &R::Device) -> Result<Batch<R>> {
if idx >= self.shard_len {
return Err(Error::DataError {
reason: format!(
"shard index {idx} out of bounds for shard size {}",
self.shard_len
),
});
}
let global_idx = idx * self.world_size + self.rank;
self.inner.get(global_idx, device)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataset::Batch;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use numr::tensor::Tensor;

    /// Dataset whose item `idx` has input `idx` and target `idx + 0.5`, so a
    /// fetched batch reveals which global index it came from.
    struct DummyDataset {
        size: usize,
    }

    impl Dataset<CpuRuntime> for DummyDataset {
        fn len(&self) -> usize {
            self.size
        }
        fn get(
            &self,
            idx: usize,
            device: &<CpuRuntime as Runtime>::Device,
        ) -> Result<Batch<CpuRuntime>> {
            let val = idx as f32;
            Ok(Batch {
                inputs: Tensor::from_slice(&[val], &[1], device),
                targets: Tensor::from_slice(&[val + 0.5], &[1], device),
            })
        }
    }

    /// Builds the shard for `rank` of `world_size` over a dataset of `size` items.
    fn shard(size: usize, rank: usize, world_size: usize) -> ShardedDataset<DummyDataset> {
        new_sharded::<CpuRuntime, _>(DummyDataset { size }, rank, world_size).unwrap()
    }

    /// Reads back the global index encoded in the input of shard item `idx`.
    fn global_index(s: &ShardedDataset<DummyDataset>, idx: usize, device: &CpuDevice) -> usize {
        let b = s.get(idx, device).unwrap();
        b.inputs.to_vec::<f32>()[0] as usize
    }

    #[test]
    fn test_shard_disjoint() {
        let device = CpuDevice::new();
        let s0 = shard(10, 0, 3);
        let s1 = shard(10, 1, 3);
        let s2 = shard(10, 2, 3);
        assert_eq!(s0.len(), 4);
        assert_eq!(s1.len(), 3);
        assert_eq!(s2.len(), 3);

        // Rank 0 sees a stride-3 slice starting at 0.
        let b = s0.get(0, &device).unwrap();
        assert_eq!(b.inputs.to_vec::<f32>(), vec![0.0]);
        let b = s0.get(1, &device).unwrap();
        assert_eq!(b.inputs.to_vec::<f32>(), vec![3.0]);
        let b = s0.get(2, &device).unwrap();
        assert_eq!(b.inputs.to_vec::<f32>(), vec![6.0]);
        let b = s0.get(3, &device).unwrap();
        assert_eq!(b.inputs.to_vec::<f32>(), vec![9.0]);

        // Across all ranks every global index appears exactly once, which
        // means the shards are disjoint and together cover the dataset.
        let mut seen = Vec::new();
        for s in [&s0, &s1, &s2] {
            for i in 0..s.len() {
                seen.push(global_index(s, i, &device));
            }
        }
        seen.sort_unstable();
        assert_eq!(seen, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn test_shard_out_of_bounds() {
        let device = CpuDevice::new();
        assert!(shard(10, 0, 3).get(4, &device).is_err());
        assert!(shard(10, 1, 3).get(3, &device).is_err());
    }

    #[test]
    fn test_shard_invalid_rank() {
        let ds = DummyDataset { size: 10 };
        assert!(new_sharded::<CpuRuntime, _>(ds, 3, 3).is_err());
        let ds = DummyDataset { size: 10 };
        assert!(new_sharded::<CpuRuntime, _>(ds, 7, 3).is_err());
    }

    #[test]
    fn test_shard_zero_world_size() {
        let ds = DummyDataset { size: 10 };
        assert!(new_sharded::<CpuRuntime, _>(ds, 0, 0).is_err());
    }

    #[test]
    fn test_more_ranks_than_items() {
        let device = CpuDevice::new();
        // Ranks at or beyond the item count get an empty shard.
        assert_eq!(shard(2, 3, 5).len(), 0);
        assert!(shard(2, 3, 5).get(0, &device).is_err());
        let s1 = shard(2, 1, 5);
        assert_eq!(s1.len(), 1);
        assert_eq!(global_index(&s1, 0, &device), 1);
    }

    #[test]
    fn test_empty_dataset() {
        assert_eq!(shard(0, 0, 2).len(), 0);
        assert_eq!(shard(0, 1, 2).len(), 0);
    }
}