rumus_distributed/
collective.rs1use std::sync::mpsc;
5use std::sync::{Arc, Condvar, Mutex};
6
7use rumus::tensor::Tensor;
8
/// In-process rendezvous at which `world_size` participants average their
/// `f32` buffers element-wise.
///
/// Every participant calls [`CollectiveBarrier::reduce`] once per round; the
/// call blocks until all contributions of that round have arrived and then
/// returns the same averaged vector to each caller. The barrier is reusable
/// for any number of consecutive rounds.
pub struct CollectiveBarrier {
    /// Number of `reduce` calls that make up one round.
    pub world_size: usize,
    state: Mutex<BarrierState>,
    cvar: Condvar,
}

/// Mutable bookkeeping for the round currently in flight.
struct BarrierState {
    /// Contributions received so far in the current round, in arrival order.
    buffers: Vec<Vec<f32>>,
    /// Averaged output; `Some` from the moment the round is complete until
    /// the last participant has picked it up.
    result: Option<Vec<f32>>,
    /// How many participants have picked up `result` so far.
    read_count: usize,
}

impl CollectiveBarrier {
    /// Creates a barrier for `world_size` participants.
    ///
    /// # Panics
    ///
    /// Panics if `world_size` is zero: no round could ever complete, so every
    /// call to [`CollectiveBarrier::reduce`] would block forever.
    pub fn new(world_size: usize) -> Self {
        assert!(
            world_size > 0,
            "CollectiveBarrier requires at least one participant"
        );
        Self {
            world_size,
            state: Mutex::new(BarrierState {
                buffers: Vec::with_capacity(world_size),
                result: None,
                read_count: 0,
            }),
            cvar: Condvar::new(),
        }
    }

    /// Contributes `local` to the current round and returns the element-wise
    /// mean over all `world_size` contributions.
    ///
    /// Blocks until every participant has contributed. The output has the
    /// length of the first buffer that arrived in the round; contributions
    /// are combined with `zip`, so any length mismatch is silently truncated.
    /// Callers are expected to pass equally sized buffers.
    ///
    /// Buffers are summed in arrival order, so floating-point rounding may
    /// differ between runs.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex was poisoned by a panic in another
    /// participant.
    pub fn reduce(&self, local: Vec<f32>) -> Vec<f32> {
        let guard = self.state.lock().expect("collective barrier mutex poisoned");

        // Entry gate: while `result` is still `Some`, the previous round is
        // being drained. Without this wait, a fast participant re-entering
        // for the next round would push onto the old round's buffers, observe
        // the stale result immediately, and could trigger the reset before a
        // slower participant has read its value (leaving that one blocked
        // forever).
        let mut state = self
            .cvar
            .wait_while(guard, |s| s.result.is_some())
            .expect("collective barrier mutex poisoned");

        state.buffers.push(local);

        if state.buffers.len() == self.world_size {
            // Last contributor of the round: compute and publish the mean.
            let mean = Self::average(&state.buffers);
            state.result = Some(mean);
            state.read_count = 0;
            self.cvar.notify_all();
        } else {
            state = self
                .cvar
                .wait_while(state, |s| s.result.is_none())
                .expect("collective barrier mutex poisoned");
        }

        state.read_count += 1;
        let out = if state.read_count == self.world_size {
            // Last reader: reset for the next round and hand over the stored
            // vector itself instead of cloning it one more time.
            state.buffers.clear();
            state.read_count = 0;
            let owned = state
                .result
                .take()
                .expect("result is published before any participant reads it");
            // Wake participants parked at the entry gate for the next round.
            self.cvar.notify_all();
            owned
        } else {
            state
                .result
                .as_ref()
                .expect("result is published before any participant reads it")
                .clone()
        };

        drop(state);
        out
    }

    /// Element-wise mean of `buffers`, sized after the first buffer.
    fn average(buffers: &[Vec<f32>]) -> Vec<f32> {
        let len = buffers.first().map_or(0, Vec::len);
        let mut mean = vec![0.0f32; len];
        for buf in buffers {
            for (acc, &v) in mean.iter_mut().zip(buf) {
                *acc += v;
            }
        }
        let count = buffers.len() as f32;
        for v in &mut mean {
            *v /= count;
        }
        mean
    }
}
80
81pub struct CommRequest {
87 pub staging_buf: wgpu::Buffer,
88 pub dst_buf: wgpu::Buffer,
89 pub byte_size: u64,
90 pub barrier: Arc<CollectiveBarrier>,
91 pub response_tx: mpsc::SyncSender<()>,
92}
93
94pub struct CommThread {
98 tx: mpsc::SyncSender<CommRequest>,
99 _handle: std::thread::JoinHandle<()>,
100}
101
102impl CommThread {
103 pub fn spawn(
105 device: Arc<wgpu::Device>,
106 queue: Arc<wgpu::Queue>,
107 ) -> Self {
108 let (tx, rx) = mpsc::sync_channel::<CommRequest>(16);
109
110 let handle = std::thread::spawn(move || {
111 while let Ok(req) = rx.recv() {
112 let slice = req.staging_buf.slice(..);
114 let (map_tx, map_rx) = mpsc::sync_channel(1);
115 slice.map_async(wgpu::MapMode::Read, move |r| {
116 let _ = map_tx.send(r);
117 });
118 device.poll(wgpu::Maintain::Wait);
119 map_rx.recv().unwrap().unwrap();
120
121 let view = slice.get_mapped_range();
123 let local: Vec<f32> = bytemuck::cast_slice(&view).to_vec();
124 drop(view);
125 req.staging_buf.unmap();
126
127 let reduced = req.barrier.reduce(local);
129
130 queue.write_buffer(&req.dst_buf, 0, bytemuck::cast_slice(&reduced));
132
133 let _ = req.response_tx.send(());
135 }
136 });
137
138 Self { tx, _handle: handle }
139 }
140
141 pub fn submit(&self, req: CommRequest) {
143 self.tx.send(req).expect("comm thread dead");
144 }
145}
146
/// Completion token for one in-flight all-reduce.
///
/// The handle owns the receiving end of a channel on which the comm thread
/// sends `()` after it has queued the reduced data for the destination
/// buffer.
#[must_use = "the destination buffer is only valid after the handle has been waited on"]
pub struct AllReduceHandle {
    rx: mpsc::Receiver<()>,
}

impl AllReduceHandle {
    /// Blocks until the request has been signalled as complete.
    ///
    /// Also returns, without any indication, when the sending side was
    /// dropped before signalling. Use [`AllReduceHandle::wait_checked`] to
    /// tell the two outcomes apart.
    pub fn wait(self) {
        let _ = self.wait_checked();
    }

    /// Blocks like [`AllReduceHandle::wait`] but reports the outcome.
    ///
    /// # Errors
    ///
    /// Returns [`mpsc::RecvError`] when the sending side was dropped without
    /// ever signalling completion.
    pub fn wait_checked(self) -> Result<(), mpsc::RecvError> {
        self.rx.recv()
    }
}
162
163pub fn async_allreduce(
170 comm: &CommThread,
171 device: &wgpu::Device,
172 queue: &wgpu::Queue,
173 src_buf: &wgpu::Buffer,
174 dst_buf: wgpu::Buffer,
175 byte_size: u64,
176 barrier: &Arc<CollectiveBarrier>,
177) -> AllReduceHandle {
178 let staging = device.create_buffer(&wgpu::BufferDescriptor {
180 label: Some("allreduce_staging"),
181 size: byte_size,
182 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
183 mapped_at_creation: false,
184 });
185
186 let mut enc = device.create_command_encoder(&Default::default());
188 enc.copy_buffer_to_buffer(src_buf, 0, &staging, 0, byte_size);
189 queue.submit(std::iter::once(enc.finish()));
190
191 let (resp_tx, resp_rx) = mpsc::sync_channel(1);
193 comm.submit(CommRequest {
194 staging_buf: staging,
195 dst_buf,
196 byte_size,
197 barrier: Arc::clone(barrier),
198 response_tx: resp_tx,
199 });
200
201 AllReduceHandle { rx: resp_rx }
202}