Skip to main content

oxibonsai_model/
multi_gpu.rs

1//! Multi-GPU / multi-device inference utilities.
2//!
3//! This module provides abstractions for device mesh partitioning and
4//! NCCL-style collective operations, implemented over rayon thread pools
5//! as a CPU simulation.  A real GPU backend would swap in NCCL/cuBLAS calls.
6//!
7//! ## Architecture
8//!
9//! ```text
10//!  ┌─────────────────────────────────────────────────┐
11//!  │                 DeviceMesh (tp × pp)             │
12//!  │  ┌──────────┐  ┌──────────┐  ┌──────────┐       │
13//!  │  │ Device 0 │  │ Device 1 │  │ Device 2 │  ...  │
14//!  │  │ (tp=0,   │  │ (tp=1,   │  │ (tp=0,   │       │
15//!  │  │  pp=0)   │  │  pp=0)   │  │  pp=1)   │       │
16//!  │  └──────────┘  └──────────┘  └──────────┘       │
17//!  └─────────────────────────────────────────────────┘
18//!
19//!   NcclCollectives  ─►  all_reduce_sum / all_gather / broadcast …
20//!   partition_weights_column / partition_weights_row  ─►  shards
21//! ```
22
23use rayon::prelude::*;
24
25// ─────────────────────────────────────────────────────────────────────────────
26// DeviceId
27// ─────────────────────────────────────────────────────────────────────────────
28
29/// A logical device identifier (CPU thread group simulating a GPU).
30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
31pub struct DeviceId(pub usize);
32
33// ─────────────────────────────────────────────────────────────────────────────
34// DeviceInfo
35// ─────────────────────────────────────────────────────────────────────────────
36
37/// Simulated device capabilities.
38#[derive(Debug, Clone)]
39pub struct DeviceInfo {
40    /// The logical device identifier.
41    pub id: DeviceId,
42    /// Simulated device memory in bytes.
43    pub memory_bytes: usize,
44    /// Simulated number of compute units (analogous to CUDA SMs).
45    pub compute_units: usize,
46    /// Human-readable device name (e.g. "SimGPU-0").
47    pub name: String,
48}
49
50impl DeviceInfo {
51    fn simulated(linear_id: usize) -> Self {
52        Self {
53            id: DeviceId(linear_id),
54            // Simulate 24 GiB per device.
55            memory_bytes: 24 * 1024 * 1024 * 1024,
56            // Simulate 108 SMs per device.
57            compute_units: 108,
58            name: format!("SimGPU-{linear_id}"),
59        }
60    }
61}
62
63// ─────────────────────────────────────────────────────────────────────────────
64// DeviceMesh
65// ─────────────────────────────────────────────────────────────────────────────
66
67/// A 2-D logical device mesh: tensor-parallel dimension × pipeline-parallel dimension.
68///
69/// Devices are stored in row-major order: device at `(tp_rank, pp_rank)` has
70/// linear index `tp_rank + pp_rank * tp_size`.
71pub struct DeviceMesh {
72    devices: Vec<DeviceInfo>,
73    tp_size: usize,
74    pp_size: usize,
75}
76
77impl DeviceMesh {
78    /// Create a 1-D tensor-parallel mesh of `n` simulated devices.
79    pub fn tensor_parallel(n: usize) -> Self {
80        Self::new(n, 1)
81    }
82
83    /// Create a 2-D (`tp_size` × `pp_size`) mesh.
84    ///
85    /// Total device count is `tp_size * pp_size`.
86    pub fn new(tp_size: usize, pp_size: usize) -> Self {
87        let total = tp_size * pp_size;
88        let devices = (0..total).map(DeviceInfo::simulated).collect();
89        Self {
90            devices,
91            tp_size,
92            pp_size,
93        }
94    }
95
96    /// Total number of devices in the mesh.
97    pub fn size(&self) -> usize {
98        self.devices.len()
99    }
100
101    /// Get the device at tensor-parallel rank `tp_rank` and pipeline-parallel rank `pp_rank`.
102    ///
103    /// Returns `None` if either rank is out of bounds.
104    pub fn get(&self, tp_rank: usize, pp_rank: usize) -> Option<&DeviceInfo> {
105        if tp_rank >= self.tp_size || pp_rank >= self.pp_size {
106            return None;
107        }
108        let idx = tp_rank + pp_rank * self.tp_size;
109        self.devices.get(idx)
110    }
111
112    /// All devices in the tensor-parallel group for a given `pp_rank`.
113    ///
114    /// Returns an empty vec if `pp_rank` is out of range.
115    pub fn tp_group(&self, pp_rank: usize) -> Vec<&DeviceInfo> {
116        if pp_rank >= self.pp_size {
117            return Vec::new();
118        }
119        (0..self.tp_size)
120            .filter_map(|tp| self.get(tp, pp_rank))
121            .collect()
122    }
123
124    /// All devices in the pipeline-parallel group for a given `tp_rank`.
125    ///
126    /// Returns an empty vec if `tp_rank` is out of range.
127    pub fn pp_group(&self, tp_rank: usize) -> Vec<&DeviceInfo> {
128        if tp_rank >= self.tp_size {
129            return Vec::new();
130        }
131        (0..self.pp_size)
132            .filter_map(|pp| self.get(tp_rank, pp))
133            .collect()
134    }
135}
136
137// ─────────────────────────────────────────────────────────────────────────────
138// CollectiveResult
139// ─────────────────────────────────────────────────────────────────────────────
140
141/// Result of a collective communication operation.
142#[derive(Debug, Clone)]
143pub struct CollectiveResult {
144    /// The reduced / gathered data.
145    pub data: Vec<f32>,
146    /// Number of devices that participated.
147    pub participating_devices: usize,
148    /// Name tag identifying the operation (e.g. `"all_reduce_sum"`).
149    pub op_name: &'static str,
150}
151
152// ─────────────────────────────────────────────────────────────────────────────
153// NcclCollectives
154// ─────────────────────────────────────────────────────────────────────────────
155
156/// NCCL-style collective operations simulated on the CPU via rayon.
157///
158/// In production these would be replaced by NCCL library calls.
159pub struct NcclCollectives;
160
161impl NcclCollectives {
162    /// All-reduce (sum): element-wise sum of tensors from all participating devices;
163    /// the result is the same on every device.
164    ///
165    /// All shards must have the same length.
166    pub fn all_reduce_sum(shards: &[Vec<f32>]) -> CollectiveResult {
167        let n = shards.first().map(|s| s.len()).unwrap_or(0);
168        let data: Vec<f32> = (0..n)
169            .into_par_iter()
170            .map(|i| shards.iter().map(|s| s[i]).sum::<f32>())
171            .collect();
172        CollectiveResult {
173            data,
174            participating_devices: shards.len(),
175            op_name: "all_reduce_sum",
176        }
177    }
178
179    /// All-reduce (max): element-wise maximum across all device tensors.
180    ///
181    /// All shards must have the same length.
182    pub fn all_reduce_max(shards: &[Vec<f32>]) -> CollectiveResult {
183        let n = shards.first().map(|s| s.len()).unwrap_or(0);
184        let data: Vec<f32> = (0..n)
185            .into_par_iter()
186            .map(|i| {
187                shards
188                    .iter()
189                    .map(|s| s[i])
190                    .fold(f32::NEG_INFINITY, f32::max)
191            })
192            .collect();
193        CollectiveResult {
194            data,
195            participating_devices: shards.len(),
196            op_name: "all_reduce_max",
197        }
198    }
199
200    /// All-gather: concatenate tensors from all devices in rank order.
201    pub fn all_gather(shards: &[Vec<f32>]) -> CollectiveResult {
202        let data: Vec<f32> = shards.iter().flat_map(|s| s.iter().copied()).collect();
203        CollectiveResult {
204            data,
205            participating_devices: shards.len(),
206            op_name: "all_gather",
207        }
208    }
209
210    /// Reduce-scatter: sum the global `data` across all ranks, then scatter
211    /// equal-sized shards back to each device.
212    ///
213    /// If `data.len()` is not evenly divisible by `world_size`, the last shard
214    /// will be shorter.
215    pub fn reduce_scatter(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
216        if world_size == 0 {
217            return Vec::new();
218        }
219        // Here "reduce-scatter" treats each device as already holding its own
220        // portion of the data, and after the reduce step each device gets its
221        // equal shard.  In the simulation we simply split the input:
222        let base = data.len() / world_size;
223        let remainder = data.len() % world_size;
224        (0..world_size)
225            .map(|rank| {
226                let start = rank * base + rank.min(remainder);
227                let end = start + base + if rank < remainder { 1 } else { 0 };
228                data[start..end.min(data.len())].to_vec()
229            })
230            .collect()
231    }
232
233    /// Broadcast: replicate `data` from device 0 to all `world_size` devices.
234    pub fn broadcast(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
235        (0..world_size).map(|_| data.to_vec()).collect()
236    }
237}
238
239// ─────────────────────────────────────────────────────────────────────────────
240// Weight partition helpers
241// ─────────────────────────────────────────────────────────────────────────────
242
243/// Partition a row-major weight matrix `[rows × cols]` into column-parallel shards.
244///
245/// Splits along the `cols` dimension, giving each device `cols / world_size`
246/// (or `cols / world_size + 1` for the first few devices if not evenly divisible).
247pub fn partition_weights_column(
248    weights: &[f32],
249    rows: usize,
250    cols: usize,
251    world_size: usize,
252) -> Vec<Vec<f32>> {
253    if world_size == 0 {
254        return Vec::new();
255    }
256    let base_cols = cols / world_size;
257    let remainder = cols % world_size;
258    (0..world_size)
259        .map(|rank| {
260            let col_start = rank * base_cols + rank.min(remainder);
261            let shard_cols = base_cols + if rank < remainder { 1 } else { 0 };
262            let mut shard = Vec::with_capacity(rows * shard_cols);
263            for row in 0..rows {
264                let row_base = row * cols;
265                shard.extend_from_slice(
266                    &weights[row_base + col_start..row_base + col_start + shard_cols],
267                );
268            }
269            shard
270        })
271        .collect()
272}
273
274/// Partition a row-major weight matrix `[rows × cols]` into row-parallel shards.
275///
276/// Splits along the `rows` dimension, giving each device a contiguous block of rows.
277pub fn partition_weights_row(
278    weights: &[f32],
279    rows: usize,
280    cols: usize,
281    world_size: usize,
282) -> Vec<Vec<f32>> {
283    if world_size == 0 {
284        return Vec::new();
285    }
286    let base_rows = rows / world_size;
287    let remainder = rows % world_size;
288    (0..world_size)
289        .map(|rank| {
290            let row_start = rank * base_rows + rank.min(remainder);
291            let shard_rows = base_rows + if rank < remainder { 1 } else { 0 };
292            weights[row_start * cols..(row_start + shard_rows) * cols].to_vec()
293        })
294        .collect()
295}
296
297/// Merge column-parallel shards back into a single `[rows × cols]` weight matrix.
298///
299/// Assumes shards are produced by [`partition_weights_column`] with the same `rows`.
300pub fn merge_column_shards(shards: &[Vec<f32>], rows: usize) -> Vec<f32> {
301    if shards.is_empty() || rows == 0 {
302        return Vec::new();
303    }
304    // Each shard: rows × (shard_cols)
305    let total_cols: usize = shards.iter().map(|s| s.len() / rows).sum();
306    let mut result = vec![0.0f32; rows * total_cols];
307
308    let mut col_offset = 0usize;
309    for shard in shards {
310        let shard_cols = shard.len() / rows;
311        for row in 0..rows {
312            let dst_start = row * total_cols + col_offset;
313            let src_start = row * shard_cols;
314            result[dst_start..dst_start + shard_cols]
315                .copy_from_slice(&shard[src_start..src_start + shard_cols]);
316        }
317        col_offset += shard_cols;
318    }
319    result
320}