use std::sync::mpsc;
use std::sync::{Arc, Condvar, Mutex};
use rumus::tensor::Tensor;
/// Blocking rendezvous that averages one `Vec<f32>` contribution from each of
/// `world_size` participants and hands the element-wise mean back to all of
/// them.
///
/// The barrier is reusable: once every participant of a round has collected
/// the result, the state is reset and the next round may begin. A participant
/// that calls [`CollectiveBarrier::reduce`] again while the previous round is
/// still being drained is held back until that round has finished, so
/// contributions from different rounds are never mixed.
pub struct CollectiveBarrier {
    /// Number of contributions that make up one round.
    pub world_size: usize,
    state: Mutex<BarrierState>,
    cvar: Condvar,
}

/// Round state protected by `CollectiveBarrier::state`.
struct BarrierState {
    /// Contributions received so far for the round being collected.
    buffers: Vec<Vec<f32>>,
    /// Mean of the current round; `Some` only while participants are still
    /// collecting it.
    result: Option<Vec<f32>>,
    /// How many participants have collected `result` in the current round.
    read_count: usize,
}

impl CollectiveBarrier {
    /// Creates a barrier for rounds of `world_size` contributions.
    ///
    /// A `world_size` of zero can never complete a round, so any call to
    /// [`CollectiveBarrier::reduce`] on such a barrier blocks forever.
    pub fn new(world_size: usize) -> Self {
        Self {
            world_size,
            state: Mutex::new(BarrierState {
                buffers: Vec::new(),
                result: None,
                read_count: 0,
            }),
            cvar: Condvar::new(),
        }
    }

    /// Contributes `local` to the current round and blocks until all
    /// `world_size` contributions have arrived, then returns their
    /// element-wise mean.
    ///
    /// The result has the length of the first contribution of the round;
    /// shorter contributions only affect their own prefix and longer ones are
    /// truncated. Every element is divided by `world_size`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned by a panic in another
    /// participant.
    pub fn reduce(&self, local: Vec<f32>) -> Vec<f32> {
        let guard = self.state.lock().unwrap();

        // Entry gate: while `result` is `Some`, the previous round is still
        // being read. Joining now would stack this buffer on top of the old
        // ones and hand back the stale mean, so wait for the drain to finish.
        let mut state = self
            .cvar
            .wait_while(guard, |s| s.result.is_some())
            .unwrap();

        state.buffers.push(local);
        if state.buffers.len() == self.world_size {
            // Last arrival of the round: compute and publish the mean.
            let averaged = Self::mean(&state.buffers, self.world_size);
            state.result = Some(averaged);
            state.read_count = 0;
            self.cvar.notify_all();
        } else {
            state = self
                .cvar
                .wait_while(state, |s| s.result.is_none())
                .unwrap();
        }

        state.read_count += 1;
        if state.read_count == self.world_size {
            // Last reader: reset for the next round and take the result
            // instead of cloning it, since nobody else needs it any more.
            state.buffers.clear();
            state.read_count = 0;
            let result = state
                .result
                .take()
                .expect("result is published before any reader proceeds");
            // Release participants parked at the entry gate.
            self.cvar.notify_all();
            result
        } else {
            state
                .result
                .clone()
                .expect("result is published before any reader proceeds")
        }
    }

    /// Element-wise sum of `buffers` divided by `world_size`, sized after the
    /// first buffer.
    fn mean(buffers: &[Vec<f32>], world_size: usize) -> Vec<f32> {
        let len = buffers.first().map_or(0, Vec::len);
        let mut summed = vec![0.0f32; len];
        for buf in buffers {
            for (acc, &value) in summed.iter_mut().zip(buf) {
                *acc += value;
            }
        }
        let divisor = world_size as f32;
        for value in &mut summed {
            *value /= divisor;
        }
        summed
    }
}
/// One all-reduce job handed to the worker thread started by `CommThread::spawn`.
pub struct CommRequest {
/// Buffer the worker maps for reading to obtain this participant's data.
pub staging_buf: wgpu::Buffer,
/// Buffer that receives the reduced data via `queue.write_buffer` at offset 0.
pub dst_buf: wgpu::Buffer,
/// Size of the payload in bytes. NOTE(review): the worker loop visible in
/// this file maps the whole staging buffer and does not read this field.
pub byte_size: u64,
/// Barrier whose `reduce` combines this request's data with the other participants'.
pub barrier: Arc<CollectiveBarrier>,
/// Receives `()` once the write to `dst_buf` has been queued; a failed send is ignored.
pub response_tx: mpsc::SyncSender<()>,
}
/// Handle to the background worker thread that services `CommRequest`s.
pub struct CommThread {
// Sending half of the bounded request channel consumed by the worker.
tx: mpsc::SyncSender<CommRequest>,
// Join handle of the worker thread; stored but not joined in the code shown here.
_handle: std::thread::JoinHandle<()>,
}
impl CommThread {
    /// Starts the worker thread and returns the handle used to feed it.
    ///
    /// Requests travel over a bounded channel holding up to 16 entries. The
    /// worker handles them one at a time, in submission order, and exits once
    /// every sender has been dropped.
    pub fn spawn(
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
    ) -> Self {
        let (tx, requests) = mpsc::sync_channel::<CommRequest>(16);
        let worker = move || {
            // Iterating the receiver ends when the channel is disconnected.
            for req in requests {
                Self::service(&device, &queue, req);
            }
        };
        Self {
            tx,
            _handle: std::thread::spawn(worker),
        }
    }

    /// Queues `req` for the worker, blocking while the channel is full.
    ///
    /// # Panics
    ///
    /// Panics with "comm thread dead" when the worker is no longer receiving.
    pub fn submit(&self, req: CommRequest) {
        let outcome = self.tx.send(req);
        outcome.expect("comm thread dead");
    }

    /// Handles a single request: reads the staging buffer back to the CPU,
    /// runs it through the barrier, queues the reduced data for upload into
    /// the destination buffer, and signals the requester.
    ///
    /// Panics if the map callback is never delivered or reports an error.
    fn service(device: &wgpu::Device, queue: &wgpu::Queue, req: CommRequest) {
        let staged = req.staging_buf.slice(..);

        // Bridge the asynchronous map callback to this thread with a channel.
        let (mapped_tx, mapped_rx) = mpsc::sync_channel(1);
        staged.map_async(wgpu::MapMode::Read, move |outcome| {
            let _ = mapped_tx.send(outcome);
        });
        // Block until the device has processed outstanding work so the
        // callback above has fired.
        device.poll(wgpu::Maintain::Wait);
        mapped_rx.recv().unwrap().unwrap();

        // Copy the mapped bytes out as f32 values; the mapped view must be
        // released before the buffer is unmapped.
        let bytes = staged.get_mapped_range();
        let local: Vec<f32> = bytemuck::cast_slice(&bytes).to_vec();
        drop(bytes);
        req.staging_buf.unmap();

        let averaged = req.barrier.reduce(local);
        queue.write_buffer(&req.dst_buf, 0, bytemuck::cast_slice(&averaged));

        // The requester may already have dropped its handle; that is fine.
        let _ = req.response_tx.send(());
    }
}
/// Completion token returned by `async_allreduce`.
pub struct AllReduceHandle {
    /// Receiving end of the per-request completion channel.
    rx: mpsc::Receiver<()>,
}

impl AllReduceHandle {
    /// Blocks until a completion message arrives or the sending side has been
    /// dropped. Both outcomes return normally; the caller cannot tell them
    /// apart.
    pub fn wait(self) {
        match self.rx.recv() {
            Ok(()) | Err(mpsc::RecvError) => {}
        }
    }
}
/// Kicks off an all-reduce of `src_buf` without blocking on its completion.
///
/// The function creates a `byte_size`-byte staging buffer usable as a copy
/// destination and for read mapping, submits a GPU copy of the first
/// `byte_size` bytes of `src_buf` into it, and then hands the staging buffer,
/// `dst_buf` and a clone of `barrier` to `comm` as a `CommRequest`.
///
/// Returns an `AllReduceHandle` whose `wait` blocks until the comm thread
/// reports that the request has been processed.
pub fn async_allreduce(
    comm: &CommThread,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    src_buf: &wgpu::Buffer,
    dst_buf: wgpu::Buffer,
    byte_size: u64,
    barrier: &Arc<CollectiveBarrier>,
) -> AllReduceHandle {
    let staging_desc = wgpu::BufferDescriptor {
        label: Some("allreduce_staging"),
        size: byte_size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    let staging_buf = device.create_buffer(&staging_desc);

    // Record and submit the device-side copy into the staging buffer.
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
    encoder.copy_buffer_to_buffer(src_buf, 0, &staging_buf, 0, byte_size);
    queue.submit(Some(encoder.finish()));

    // One-slot channel over which the comm thread reports completion.
    let (response_tx, rx) = mpsc::sync_channel(1);
    let request = CommRequest {
        staging_buf,
        dst_buf,
        byte_size,
        barrier: Arc::clone(barrier),
        response_tx,
    };
    comm.submit(request);

    AllReduceHandle { rx }
}