Skip to main content

entrenar/finetune/
ring_allreduce.rs

1//! Ring AllReduce over TCP for distributed gradient averaging.
2//!
3//! Implements the bandwidth-optimal ring AllReduce algorithm:
4//! 1. **Scatter-reduce**: N-1 rounds, each worker sends 1/N chunk to right neighbor
5//! 2. **All-gather**: N-1 rounds, each worker broadcasts its reduced chunk
6//!
7//! Total data per worker: `2 * (N-1)/N * D * 4` bytes — matches theoretical lower bound.
8//!
9//! # Contract
10//!
11//! C-RING-001: Output equals arithmetic mean of all inputs across workers.
12//! C-RING-001: All workers produce identical output.
13//!
14//! # Two-Worker Degeneration
15//!
16//! For N=2, this degenerates to a simple send-and-average:
17//! - Each worker sends its full vector to the other
18//! - Each worker averages its local vector with the received one
19//!   This is correct but not bandwidth-optimal for N=2 (ring is equivalent).
20
21use std::io::{Read, Write};
22use std::net::TcpStream;
23
/// A participant in the ring AllReduce.
///
/// Each worker owns two TCP connections: one to its right neighbor, used
/// only for sending, and one from its left neighbor, used only for
/// receiving. Together the workers form a directed ring.
#[derive(Debug)]
pub struct RingAllReduceWorker {
    /// This worker's rank in `[0, world_size)`.
    rank: usize,
    /// Total number of workers in the ring.
    world_size: usize,
    /// TCP stream to the right neighbor, `(rank + 1) % N` — for sending.
    send_stream: TcpStream,
    /// TCP stream from the left neighbor, `(rank + N - 1) % N` — for receiving.
    recv_stream: TcpStream,
}
38
39impl RingAllReduceWorker {
40    /// Create a new ring worker with pre-established TCP connections.
41    ///
42    /// # Arguments
43    /// * `rank` - This worker's rank in [0, world_size)
44    /// * `world_size` - Total number of workers (N >= 2)
45    /// * `send_stream` - TCP connection to worker (rank + 1) % N
46    /// * `recv_stream` - TCP connection from worker (rank - 1) % N
47    pub fn new(
48        rank: usize,
49        world_size: usize,
50        send_stream: TcpStream,
51        recv_stream: TcpStream,
52    ) -> Self {
53        assert!(world_size >= 2, "ring AllReduce requires >= 2 workers");
54        assert!(rank < world_size, "rank must be < world_size");
55        Self { rank, world_size, send_stream, recv_stream }
56    }
57
58    /// Perform AllReduce on `data`, returning the averaged result.
59    ///
60    /// After this call, `data` contains `(1/N) * sum(all workers' input)`.
61    ///
62    /// # Contract (C-RING-001)
63    ///
64    /// - Output is the arithmetic mean of all inputs
65    /// - All workers produce identical output
66    /// - Data length must be the same on all workers
67    ///
68    /// # Errors
69    ///
70    /// Returns `Err` on TCP I/O failure.
71    pub fn allreduce(&mut self, data: &mut [f32]) -> Result<(), String> {
72        contract_pre_gradient_allreduce!();
73        let n = self.world_size;
74        let d = data.len();
75
76        // Compute chunk boundaries: chunks[i] = (start, len)
77        let chunk_size = d / n;
78        let remainder = d % n;
79        let chunks: Vec<(usize, usize)> = (0..n)
80            .map(|i| {
81                let start = i * chunk_size + i.min(remainder);
82                let len = chunk_size + usize::from(i < remainder);
83                (start, len)
84            })
85            .collect();
86
87        // Find the maximum chunk size for buffer allocation
88        let max_chunk_len = chunks.iter().map(|(_, len)| *len).max().unwrap_or(0);
89        let mut send_buf = vec![0u8; max_chunk_len * 4];
90        let mut recv_buf = vec![0u8; max_chunk_len * 4];
91
92        // ── Phase 1: Scatter-reduce ──────────────────────────────────────
93        // After N-1 rounds, worker w holds sum of chunk w from all workers.
94        for round in 0..(n - 1) {
95            // Chunk index to send (we send the chunk we just reduced)
96            let send_chunk_idx = (self.rank + n - round) % n;
97            let (send_start, send_len) = chunks[send_chunk_idx];
98
99            // Chunk index to receive
100            let recv_chunk_idx = (self.rank + n - round - 1) % n;
101            let (recv_start, recv_len) = chunks[recv_chunk_idx];
102
103            // Serialize chunk to send
104            f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
105
106            // Send and receive simultaneously
107            // (TCP is full-duplex, so this won't deadlock)
108            self.send_stream
109                .write_all(&send_buf[..send_len * 4])
110                .map_err(|e| format!("ring send error (round {round}): {e}"))?;
111            self.recv_stream
112                .read_exact(&mut recv_buf[..recv_len * 4])
113                .map_err(|e| format!("ring recv error (round {round}): {e}"))?;
114
115            // Element-wise add received chunk into local chunk
116            for i in 0..recv_len {
117                let received =
118                    f32::from_le_bytes(recv_buf[i * 4..(i + 1) * 4].try_into().expect("4 bytes"));
119                data[recv_start + i] += received;
120            }
121        }
122
123        // ── Phase 2: All-gather ──────────────────────────────────────────
124        // After N-1 rounds, all workers hold all reduced chunks.
125        for round in 0..(n - 1) {
126            let send_chunk_idx = (self.rank + n - round + 1) % n;
127            let (send_start, send_len) = chunks[send_chunk_idx];
128
129            let recv_chunk_idx = (self.rank + n - round) % n;
130            let (recv_start, recv_len) = chunks[recv_chunk_idx];
131
132            // Serialize reduced chunk to send
133            f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
134
135            self.send_stream
136                .write_all(&send_buf[..send_len * 4])
137                .map_err(|e| format!("ring allgather send error (round {round}): {e}"))?;
138            self.recv_stream
139                .read_exact(&mut recv_buf[..recv_len * 4])
140                .map_err(|e| format!("ring allgather recv error (round {round}): {e}"))?;
141
142            // Copy received chunk into place (no addition — just replace)
143            for i in 0..recv_len {
144                data[recv_start + i] =
145                    f32::from_le_bytes(recv_buf[i * 4..(i + 1) * 4].try_into().expect("4 bytes"));
146            }
147        }
148
149        // ── Phase 3: Divide by N to get mean ────────────────────────────
150        let inv_n = 1.0 / n as f32;
151        for x in data.iter_mut() {
152            *x *= inv_n;
153        }
154
155        Ok(())
156    }
157}
158
/// Serialize `src` into the front of `dst` as consecutive little-endian
/// 4-byte values.
///
/// Bytes of `dst` beyond `src.len() * 4` are left untouched. Panics if
/// `dst` is shorter than `src.len() * 4` bytes.
fn f32_slice_to_bytes(src: &[f32], dst: &mut [u8]) {
    let encoded_len = src.len() * 4;
    for (value, out) in src.iter().zip(dst[..encoded_len].chunks_exact_mut(4)) {
        out.copy_from_slice(&value.to_le_bytes());
    }
}
165
166/// Simple AllReduce for exactly 2 workers using direct exchange.
167///
168/// Simpler than the ring algorithm for the N=2 case: each worker sends
169/// its entire vector, receives the other's, and averages.
170///
171/// # Contract (C-RING-001)
172///
173/// Output = (local + remote) / 2
174pub fn allreduce_pair(
175    data: &mut [f32],
176    send_stream: &mut TcpStream,
177    recv_stream: &mut TcpStream,
178) -> Result<(), String> {
179    let byte_len = data.len() * 4;
180    let mut send_buf = vec![0u8; byte_len];
181    let mut recv_buf = vec![0u8; byte_len];
182
183    f32_slice_to_bytes(data, &mut send_buf);
184
185    send_stream.write_all(&send_buf).map_err(|e| format!("pair send error: {e}"))?;
186    recv_stream.read_exact(&mut recv_buf).map_err(|e| format!("pair recv error: {e}"))?;
187
188    for i in 0..data.len() {
189        let remote = f32::from_le_bytes(recv_buf[i * 4..(i + 1) * 4].try_into().expect("4 bytes"));
190        data[i] = (data[i] + remote) * 0.5;
191    }
192
193    Ok(())
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread;

    /// Set up a ring of N workers connected via TCP loopback.
    ///
    /// Worker `w` sends to `(w + 1) % N` and receives from `(w + N - 1) % N`.
    /// Returns the workers ordered by rank.
    fn setup_ring(n: usize) -> Vec<RingAllReduceWorker> {
        // One listener per worker; listener `w` accepts the connection from
        // worker `w`'s left neighbor.
        let listeners: Vec<TcpListener> =
            (0..n).map(|_| TcpListener::bind("127.0.0.1:0").expect("bind")).collect();
        let addrs: Vec<_> = listeners.iter().map(|l| l.local_addr().expect("addr")).collect();

        // Accept on background threads so the connects below never wait on us.
        let accept_handles: Vec<_> = listeners
            .into_iter()
            .map(|listener| {
                thread::spawn(move || {
                    let (stream, _) = listener.accept().expect("accept");
                    stream
                })
            })
            .collect();

        // Worker w connects to the listener of its right neighbor.
        let send_streams: Vec<TcpStream> = (0..n)
            .map(|w| {
                let stream = TcpStream::connect(addrs[(w + 1) % n]).expect("connect");
                stream.set_nodelay(true).ok();
                stream
            })
            .collect();

        // Accepted stream `w` is the connection coming from worker w's left neighbor.
        let recv_streams: Vec<TcpStream> = accept_handles
            .into_iter()
            .map(|handle| {
                let stream = handle.join().expect("accept thread");
                stream.set_nodelay(true).ok();
                stream
            })
            .collect();

        send_streams
            .into_iter()
            .zip(recv_streams)
            .enumerate()
            .map(|(rank, (send, recv))| RingAllReduceWorker::new(rank, n, send, recv))
            .collect()
    }

    /// Run one AllReduce with one thread per worker; `inputs[w]` is worker
    /// `w`'s vector. Returns each worker's output, ordered by rank.
    fn run_ring(inputs: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
        let workers = setup_ring(inputs.len());
        // Every worker must be running before any is joined, so the handles
        // are collected eagerly.
        let handles: Vec<_> = workers
            .into_iter()
            .zip(inputs)
            .map(|(mut worker, mut data)| {
                thread::spawn(move || {
                    worker.allreduce(&mut data).expect("allreduce");
                    data
                })
            })
            .collect();
        handles.into_iter().map(|h| h.join().expect("join worker")).collect()
    }

    /// Assert that `actual` matches `expected` element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32, label: &str) {
        assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
        for (i, (&v, &e)) in actual.iter().zip(expected).enumerate() {
            assert!((v - e).abs() < tol, "{label}[{i}]: {v} != {e}");
        }
    }

    #[test]
    fn test_ring_allreduce_2_workers_identical() {
        let input = vec![1.0f32, 2.0, 3.0];
        let results = run_ring(vec![input.clone(), input.clone()]);

        // Identical inputs → output equals input
        for (rank, out) in results.iter().enumerate() {
            assert_close(out, &input, 1e-6, &format!("w{rank}"));
        }
    }

    #[test]
    fn test_ring_allreduce_2_workers_distinct() {
        // Expected: (2+8)/2=5, (4+6)/2=5, (6+4)/2=5
        let results = run_ring(vec![vec![2.0f32, 4.0, 6.0], vec![8.0f32, 6.0, 4.0]]);

        for (rank, out) in results.iter().enumerate() {
            assert_close(out, &[5.0; 3], 1e-6, &format!("w{rank}"));
        }
    }

    #[test]
    fn test_ring_allreduce_3_workers() {
        // Expected: [1/3, 1/3, 1/3]
        let results = run_ring(vec![
            vec![1.0f32, 0.0, 0.0],
            vec![0.0f32, 1.0, 0.0],
            vec![0.0f32, 0.0, 1.0],
        ]);

        let expected = [1.0f32 / 3.0; 3];
        for (rank, out) in results.iter().enumerate() {
            assert_close(out, &expected, 1e-5, &format!("w{rank}"));
        }
    }

    #[test]
    fn test_ring_allreduce_non_divisible_length() {
        // D=7, N=3 → chunks of [3, 2, 2]
        let results = run_ring(vec![
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![7.0f32, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            vec![0.0f32; 7],
        ]);

        // Expected: (d0 + d1 + d2) / 3 = 8/3 everywhere
        let expected = [8.0f32 / 3.0; 7];
        assert_close(&results[0], &expected, 1e-5, "w0");
        assert_eq!(results[0], results[1], "w0 == w1");
        assert_eq!(results[0], results[2], "w0 == w2");
    }

    #[test]
    fn test_ring_allreduce_large_vector() {
        let d = 100_000;
        let d0: Vec<f32> = (0..d).map(|i| i as f32).collect();
        let d1: Vec<f32> = (0..d).map(|i| (d - 1 - i) as f32).collect();
        // d0[i] + d1[i] = d-1 for all i, so average = (d-1)/2

        let results = run_ring(vec![d0, d1]);

        let expected = (d as f32 - 1.0) / 2.0;
        for (i, &v) in results[0].iter().enumerate() {
            assert!((v - expected).abs() < 1e-2, "w0[{i}]: {v} != {expected}");
        }
        assert_eq!(results[0], results[1], "results must be identical");
    }

    #[test]
    fn test_allreduce_pair() {
        // Two independent connections: A → B (accepted by listener_b) and
        // B → A (accepted by listener_a).
        let listener_a = TcpListener::bind("127.0.0.1:0").expect("bind");
        let listener_b = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr_a = listener_a.local_addr().expect("addr");
        let addr_b = listener_b.local_addr().expect("addr");

        let ha = thread::spawn(move || {
            let send = TcpStream::connect(addr_b).expect("connect to b");
            let (recv, _) = listener_a.accept().expect("accept from b");
            (send, recv)
        });

        let mut send_b = TcpStream::connect(addr_a).expect("connect to a");
        let (mut recv_b, _) = listener_b.accept().expect("accept from a");
        let (mut send_a, mut recv_a) = ha.join().expect("join");

        let mut d_a = vec![10.0f32, 20.0, 30.0];
        let mut d_b = vec![30.0f32, 20.0, 10.0];

        let hb = thread::spawn(move || {
            allreduce_pair(&mut d_b, &mut send_b, &mut recv_b).expect("pair b");
            d_b
        });

        allreduce_pair(&mut d_a, &mut send_a, &mut recv_a).expect("pair a");
        let result_b = hb.join().expect("join");

        assert_eq!(d_a, vec![20.0, 20.0, 20.0]);
        assert_eq!(result_b, vec![20.0, 20.0, 20.0]);
    }
}