Skip to main content

rumus_distributed/
collective.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2//! Async collective operations: AllReduce via dedicated comm threads.
3
4use std::sync::mpsc;
5use std::sync::{Arc, Condvar, Mutex};
6
7use rumus::tensor::Tensor;
8
9// ---------------------------------------------------------------------------
10// CollectiveBarrier — reusable cross-rank synchronization
11// ---------------------------------------------------------------------------
12
/// Cross-rank barrier that averages `f32` vectors element-wise.
///
/// Shared by TP AllReduce, FSDP Reduce-Scatter, and PP gradient exchange.
///
/// The barrier is reusable: once every rank has collected the result of a
/// round, the next round can start on the same instance. A rank that comes
/// back early for the next round is held at the entrance until the previous
/// result has been fully handed out, so rounds can never bleed into each other.
pub struct CollectiveBarrier {
    /// Number of calls to [`reduce`](Self::reduce) that make up one round.
    pub world_size: usize,
    state: Mutex<BarrierState>,
    cvar: Condvar,
}

/// Mutable state of the barrier, guarded by `CollectiveBarrier::state`.
struct BarrierState {
    /// Contributions gathered so far for the round currently being collected.
    buffers: Vec<Vec<f32>>,
    /// Averaged vector of the last completed round; `Some` only while that
    /// round is still being handed out to its participants.
    result: Option<Vec<f32>>,
    /// How many participants have already picked up `result`.
    read_count: usize,
}

impl CollectiveBarrier {
    /// Creates a barrier for `world_size` participating ranks.
    ///
    /// # Panics
    ///
    /// Panics if `world_size` is zero: such a barrier could never complete a
    /// round and every call to [`reduce`](Self::reduce) would block forever.
    pub fn new(world_size: usize) -> Self {
        assert!(world_size > 0, "CollectiveBarrier requires world_size >= 1");
        Self {
            world_size,
            state: Mutex::new(BarrierState {
                buffers: Vec::new(),
                result: None,
                read_count: 0,
            }),
            cvar: Condvar::new(),
        }
    }

    /// Push local data, wait for all ranks, return the reduced (summed + averaged) result.
    ///
    /// Blocks the calling thread until `world_size` calls have contributed to
    /// the current round. Every participant of a round receives the same
    /// vector: the element-wise sum of all contributions divided by
    /// `world_size`.
    ///
    /// All ranks are expected to pass vectors of the same length. The result
    /// has the length of the first contribution to arrive; when lengths
    /// differ, elements beyond the shorter of the two vectors are ignored
    /// during summation.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned, i.e. another participant
    /// panicked while holding it.
    pub fn reduce(&self, local: Vec<f32>) -> Vec<f32> {
        const POISONED: &str = "collective barrier mutex poisoned";

        // Entrance gate: while the previous round's result is still being
        // handed out, joining now would read that stale result and corrupt
        // the reader count. Wait until the last reader has cleared it.
        let mut state = self
            .cvar
            .wait_while(self.state.lock().expect(POISONED), |s| s.result.is_some())
            .expect(POISONED);

        state.buffers.push(local);

        if state.buffers.len() == self.world_size {
            // Last arrival: sum all buffers, then average.
            let len = state.buffers[0].len();
            let mut mean = vec![0.0f32; len];
            for buf in &state.buffers {
                for (acc, &v) in mean.iter_mut().zip(buf) {
                    *acc += v;
                }
            }
            let n = self.world_size as f32;
            for v in &mut mean {
                *v /= n;
            }
            // The inputs are no longer needed once the result exists.
            state.buffers.clear();
            state.result = Some(mean);
            state.read_count = 0;
            self.cvar.notify_all();
        } else {
            state = self
                .cvar
                .wait_while(state, |s| s.result.is_none())
                .expect(POISONED);
        }

        state.read_count += 1;
        if state.read_count == self.world_size {
            // Last reader: take ownership of the result (no clone needed),
            // reset for the next round and release ranks held at the gate.
            state.read_count = 0;
            let result = state
                .result
                .take()
                .expect("result is published before any rank reads it");
            self.cvar.notify_all();
            result
        } else {
            state
                .result
                .as_ref()
                .expect("result is published before any rank reads it")
                .clone()
        }
    }
}
80
81// ---------------------------------------------------------------------------
82// CommThread — dedicated background thread for async collectives
83// ---------------------------------------------------------------------------
84
85/// Request from the compute thread to the comm thread.
86pub struct CommRequest {
87    pub staging_buf: wgpu::Buffer,
88    pub dst_buf: wgpu::Buffer,
89    pub byte_size: u64,
90    pub barrier: Arc<CollectiveBarrier>,
91    pub response_tx: mpsc::SyncSender<()>,
92}
93
94/// Dedicated communication thread for async AllReduce.
95///
96/// Owns `Arc<Device>` and `Arc<Queue>` for its GPU — can `poll` and `write_buffer`.
97pub struct CommThread {
98    tx: mpsc::SyncSender<CommRequest>,
99    _handle: std::thread::JoinHandle<()>,
100}
101
102impl CommThread {
103    /// Spawn a comm thread for the given device/queue.
104    pub fn spawn(
105        device: Arc<wgpu::Device>,
106        queue: Arc<wgpu::Queue>,
107    ) -> Self {
108        let (tx, rx) = mpsc::sync_channel::<CommRequest>(16);
109
110        let handle = std::thread::spawn(move || {
111            while let Ok(req) = rx.recv() {
112                // Map the staging buffer (blocks this thread only, not compute).
113                let slice = req.staging_buf.slice(..);
114                let (map_tx, map_rx) = mpsc::sync_channel(1);
115                slice.map_async(wgpu::MapMode::Read, move |r| {
116                    let _ = map_tx.send(r);
117                });
118                device.poll(wgpu::Maintain::Wait);
119                map_rx.recv().unwrap().unwrap();
120
121                // Read the data.
122                let view = slice.get_mapped_range();
123                let local: Vec<f32> = bytemuck::cast_slice(&view).to_vec();
124                drop(view);
125                req.staging_buf.unmap();
126
127                // Barrier: sum with all ranks.
128                let reduced = req.barrier.reduce(local);
129
130                // Upload reduced result to the destination buffer.
131                queue.write_buffer(&req.dst_buf, 0, bytemuck::cast_slice(&reduced));
132
133                // Signal completion.
134                let _ = req.response_tx.send(());
135            }
136        });
137
138        Self { tx, _handle: handle }
139    }
140
141    /// Submit a non-blocking AllReduce request.
142    pub fn submit(&self, req: CommRequest) {
143        self.tx.send(req).expect("comm thread dead");
144    }
145}
146
147// ---------------------------------------------------------------------------
148// AsyncAllReduce — high-level non-blocking API
149// ---------------------------------------------------------------------------
150
/// Handle to an in-flight AllReduce submitted to the comm thread.
pub struct AllReduceHandle {
    /// Receives the completion signal sent for this request.
    rx: mpsc::Receiver<()>,
}

impl AllReduceHandle {
    /// Block until the comm thread signals completion for this request.
    ///
    /// Also returns if the sending side was dropped without signalling; the
    /// two outcomes are not distinguished.
    pub fn wait(self) {
        match self.rx.recv() {
            // Completion signal received.
            Ok(()) => {}
            // Sender dropped without signalling: nothing left to wait for.
            Err(_) => {}
        }
    }
}
162
163/// Submit a non-blocking AllReduce via the comm thread.
164///
165/// 1. Encodes copy from `src_buf` to a staging buffer.
166/// 2. Submits the copy command.
167/// 3. Sends the staging + dst to the comm thread.
168/// 4. Returns a handle the compute thread can `.wait()` on later.
169pub fn async_allreduce(
170    comm: &CommThread,
171    device: &wgpu::Device,
172    queue: &wgpu::Queue,
173    src_buf: &wgpu::Buffer,
174    dst_buf: wgpu::Buffer,
175    byte_size: u64,
176    barrier: &Arc<CollectiveBarrier>,
177) -> AllReduceHandle {
178    // Create staging buffer.
179    let staging = device.create_buffer(&wgpu::BufferDescriptor {
180        label: Some("allreduce_staging"),
181        size: byte_size,
182        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
183        mapped_at_creation: false,
184    });
185
186    // Encode + submit the GPU copy.
187    let mut enc = device.create_command_encoder(&Default::default());
188    enc.copy_buffer_to_buffer(src_buf, 0, &staging, 0, byte_size);
189    queue.submit(std::iter::once(enc.finish()));
190
191    // Send to comm thread (non-blocking from compute thread's perspective).
192    let (resp_tx, resp_rx) = mpsc::sync_channel(1);
193    comm.submit(CommRequest {
194        staging_buf: staging,
195        dst_buf,
196        byte_size,
197        barrier: Arc::clone(barrier),
198        response_tx: resp_tx,
199    });
200
201    AllReduceHandle { rx: resp_rx }
202}