use std::collections::HashMap;
use crate::autograd::GradientStore;
use crate::backend::gpu::context::MultiGpuContext;
use crate::nn::{Module, Parameter};
use crate::tensor::{self, Tensor};
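
/// Replicates a module's parameters across several GPUs and splits each
/// input batch along dimension 0, running one shard per device and gathering
/// the outputs back on the first device. `replicas[0]` is the master copy.
///
/// # Example
///
/// A minimal sketch of the intended usage; `my_module`, the `"weight"` /
/// `"bias"` parameter names, and the `matmul` / `add` tensor methods are
/// illustrative assumptions, not part of this API.
///
/// ```ignore
/// let dp = DataParallel::new(my_module, vec![0, 1]);
/// let output = dp.forward(&input, |params, x| {
///     // Hypothetical per-replica forward pass using the replica's tensors.
///     x.matmul(&params["weight"]).add(&params["bias"])
/// });
/// ```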
pub struct DataParallel<M: Module> {
    pub module: M,
    pub device_ids: Vec<usize>,
    replicas: Vec<HashMap<String, Tensor>>,
}

impl<M: Module> DataParallel<M> {
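    /// Builds the wrapper, validating the device list against the global
    /// `MultiGpuContext` and copying the master parameters to every
    /// additional device. Replica 0 reuses the module's own tensors, which
    /// are expected to already live on `device_ids[0]`.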
    pub fn new(module: M, device_ids: Vec<usize>) -> Self {
        assert!(!device_ids.is_empty(), "DataParallel: need at least one device");
        let mgpu = MultiGpuContext::get().expect("MultiGpuContext required for DataParallel");
        for &id in &device_ids {
            assert!(id < mgpu.num_devices(), "device {} out of range", id);
        }

        let master_dict = module.state_dict("");
        let mut replicas = Vec::with_capacity(device_ids.len());
        replicas.push(master_dict.clone());
        for &dev in &device_ids[1..] {
            let mut replica_dict = HashMap::new();
            for (name, tensor) in &master_dict {
                replica_dict.insert(name.clone(), tensor.to_device(dev));
            }
            replicas.push(replica_dict);
        }
        Self { module, device_ids, replicas }
    }
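
    /// Splits `input` along dimension 0 into one chunk per device, runs
    /// `fwd` against each replica's parameter dictionary on its own thread,
    /// then moves every output back to the first device and concatenates
    /// them along the batch dimension.
    ///
    /// Panics if the batch size is not divisible by the number of devices.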
    pub fn forward<F>(&self, input: &Tensor, fwd: F) -> Tensor
    where
        F: Fn(&HashMap<String, Tensor>, &Tensor) -> Tensor + Send + Sync,
    {
        let n = self.device_ids.len();
        let batch = input.shape()[0];
        assert!(batch % n == 0, "DataParallel: batch {} not divisible by {} devices", batch, n);
        let chunk_size = batch / n;

        // Scatter: slice the batch and move each chunk onto its target device.
        let chunks: Vec<Tensor> = (0..n)
            .map(|i| {
                input
                    .slice_range(0, i * chunk_size, (i + 1) * chunk_size)
                    .to_device(self.device_ids[i])
            })
            .collect();

        // Run each replica's forward pass on its own scoped thread.
        let outputs: Vec<Tensor> = std::thread::scope(|s| {
            let handles: Vec<_> = chunks
                .iter()
                .enumerate()
                .map(|(i, chunk)| {
                    let replica = &self.replicas[i];
                    let fwd_ref = &fwd;
                    s.spawn(move || fwd_ref(replica, chunk))
                })
                .collect();
            handles
                .into_iter()
                .map(|h| h.join().expect("DataParallel: forward thread panicked"))
                .collect()
        });

        // Gather: bring every output back to the first device and concatenate
        // along the batch dimension.
        let gathered: Vec<Tensor> = outputs
            .into_iter()
            .map(|t| t.to_device(self.device_ids[0]))
            .collect();
        tensor::cat(&gathered, 0)
    }
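
    /// Re-copies the master module's current parameters to every replica.
    /// Call this whenever the master weights change (e.g. after an optimizer
    /// step) so the other devices see the updated values.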
    pub fn broadcast_weights(&mut self) {
        let master_dict = self.module.state_dict("");
        for (i, &dev) in self.device_ids[1..].iter().enumerate() {
            for (name, tensor) in &master_dict {
                self.replicas[i + 1].insert(name.clone(), tensor.to_device(dev));
            }
        }
        self.replicas[0] = master_dict;
    }
}
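
/// Averages per-parameter gradients across devices by copying each device's
/// gradient into a host-visible staging buffer, summing on the CPU, and
/// returning a single `GradientStore` with the averaged result.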
pub struct AllReduceSync {
    device_ids: Vec<usize>,
}

impl AllReduceSync {
    pub fn new(device_ids: Vec<usize>) -> Self {
        Self { device_ids }
    }
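
    /// For every parameter, reads the gradient back from each device that
    /// produced one, averages the values on the CPU over the contributing
    /// devices, and accumulates the result into the returned store.
    ///
    /// A minimal usage sketch; how the per-device `GradientStore`s and the
    /// parameter list are produced is assumed, not shown:
    ///
    /// ```ignore
    /// let sync = AllReduceSync::new(vec![0, 1]);
    /// // `per_device_grads`: one GradientStore per device, e.g. from
    /// // independent backward passes (hypothetical setup).
    /// let averaged = sync.all_reduce(&mut per_device_grads, &params);
    /// ```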
    pub fn all_reduce(
        &self,
        grad_stores: &mut [GradientStore],
        params: &[Parameter],
    ) -> GradientStore {
        let mgpu = MultiGpuContext::get().expect("MultiGpuContext required for AllReduce");
        let mut averaged = GradientStore::new();
        for param in params {
            let gid = param.grad_id();
            let numel = param.tensor.numel();
            // Gradients are stored as f32, so each element occupies 4 bytes.
            let byte_size = (numel * 4) as u64;
            let mut staging_buffers = Vec::with_capacity(self.device_ids.len());
            let mut map_receivers = Vec::with_capacity(self.device_ids.len());

            // Stage 1: on every device that holds a gradient for this parameter,
            // copy it into a MAP_READ staging buffer and request an async map.
            for (store_idx, &dev_id) in self.device_ids.iter().enumerate() {
                let ctx = mgpu.device(dev_id);
                if let Some(grad) = grad_stores[store_idx].get(gid) {
                    grad.storage.ensure_gpu();
                    let grad_buf = grad.storage.gpu_buffer();
                    let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
                        label: Some("allreduce_staging"),
                        size: byte_size,
                        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });
                    let mut encoder = ctx
                        .device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
                    encoder.copy_buffer_to_buffer(&grad_buf, 0, &staging, 0, byte_size);
                    ctx.queue.submit(std::iter::once(encoder.finish()));
                    drop(grad_buf);

                    let (tx, rx) = std::sync::mpsc::sync_channel(1);
                    staging.slice(..).map_async(wgpu::MapMode::Read, move |r| {
                        let _ = tx.send(r);
                    });
                    staging_buffers.push(Some(staging));
                    map_receivers.push(Some(rx));
                } else {
                    staging_buffers.push(None);
                    map_receivers.push(None);
                }
            }

            // Stage 2: block until every device has finished its copy and
            // resolved the pending map requests.
            for &dev_id in &self.device_ids {
                mgpu.device(dev_id).device.poll(wgpu::Maintain::Wait);
            }

            // Stage 3: read the mapped buffers on the host and sum them element-wise.
            let mut sum = vec![0.0f32; numel];
            let mut count = 0usize;
            for (staging, rx) in staging_buffers.iter().zip(map_receivers.iter()) {
                if let (Some(staging_buf), Some(receiver)) = (staging, rx) {
                    receiver
                        .recv()
                        .expect("map callback not called")
                        .expect("buffer map failed");
                    let view = staging_buf.slice(..).get_mapped_range();
                    let f32_data: &[f32] = bytemuck::cast_slice(&view);
                    for (acc, &v) in sum.iter_mut().zip(f32_data) {
                        *acc += v;
                    }
                    drop(view);
                    staging_buf.unmap();
                    count += 1;
                }
            }

            // Average over the devices that actually contributed a gradient.
            if count > 0 {
                let divisor = count as f32;
                for v in &mut sum {
                    *v /= divisor;
                }
            }
            averaged
                .accumulate(gid, Tensor::new(sum, param.tensor.shape().to_vec()))
                .expect("AllReduce: accumulate failed");
        }
        averaged
    }
}