oxibonsai_model/multi_gpu.rs
1//! Multi-GPU / multi-device inference utilities.
2//!
3//! This module provides abstractions for device mesh partitioning and
4//! NCCL-style collective operations, implemented over rayon thread pools
5//! as a CPU simulation. A real GPU backend would swap in NCCL/cuBLAS calls.
6//!
7//! ## Architecture
8//!
//! ```text
//! ┌──────────────────────────────────────────────────────┐
//! │                 DeviceMesh (tp × pp)                 │
//! │  ┌──────────┐   ┌──────────┐   ┌──────────┐          │
//! │  │ Device 0 │   │ Device 1 │   │ Device 2 │   ...    │
//! │  │ (tp=0,   │   │ (tp=1,   │   │ (tp=0,   │          │
//! │  │  pp=0)   │   │  pp=0)   │   │  pp=1)   │          │
//! │  └──────────┘   └──────────┘   └──────────┘          │
//! └──────────────────────────────────────────────────────┘
//!
//! NcclCollectives ─► all_reduce_sum / all_gather / broadcast …
//! partition_weights_column / partition_weights_row ─► shards
//! ```
22
23use rayon::prelude::*;
24
25// ─────────────────────────────────────────────────────────────────────────────
26// DeviceId
27// ─────────────────────────────────────────────────────────────────────────────
28
/// A logical device identifier (CPU thread group simulating a GPU).
///
/// A thin `usize` newtype: cheap to copy and totally ordered, so it can be
/// used directly as a map key or sorted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DeviceId(pub usize);

// ─────────────────────────────────────────────────────────────────────────────
// DeviceInfo
// ─────────────────────────────────────────────────────────────────────────────

/// Simulated device capabilities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceInfo {
    /// The logical device identifier.
    pub id: DeviceId,
    /// Simulated device memory in bytes.
    pub memory_bytes: usize,
    /// Simulated number of compute units (analogous to CUDA SMs).
    pub compute_units: usize,
    /// Human-readable device name (e.g. "SimGPU-0").
    pub name: String,
}

impl DeviceInfo {
    /// Simulated memory per device: 24 GiB.
    #[cfg(target_pointer_width = "64")]
    const SIMULATED_MEMORY_BYTES: usize = 24 * 1024 * 1024 * 1024;

    /// Simulated memory per device, saturated to `usize::MAX`.
    ///
    /// 24 GiB does not fit in a 32-bit (or smaller) `usize`; computing it in
    /// `usize` there is a compile-time overflow error, so saturate instead.
    #[cfg(not(target_pointer_width = "64"))]
    const SIMULATED_MEMORY_BYTES: usize = usize::MAX;

    /// Simulated number of compute units (SMs) per device.
    const SIMULATED_COMPUTE_UNITS: usize = 108;

    /// Build the simulated description of the device with linear index `linear_id`.
    ///
    /// The returned device has `id == DeviceId(linear_id)` and is named
    /// `"SimGPU-{linear_id}"`; memory and compute units are the fixed
    /// simulation constants above.
    fn simulated(linear_id: usize) -> Self {
        Self {
            id: DeviceId(linear_id),
            memory_bytes: Self::SIMULATED_MEMORY_BYTES,
            compute_units: Self::SIMULATED_COMPUTE_UNITS,
            name: format!("SimGPU-{linear_id}"),
        }
    }
}
62
63// ─────────────────────────────────────────────────────────────────────────────
64// DeviceMesh
65// ─────────────────────────────────────────────────────────────────────────────
66
67/// A 2-D logical device mesh: tensor-parallel dimension × pipeline-parallel dimension.
68///
69/// Devices are stored in row-major order: device at `(tp_rank, pp_rank)` has
70/// linear index `tp_rank + pp_rank * tp_size`.
71pub struct DeviceMesh {
72 devices: Vec<DeviceInfo>,
73 tp_size: usize,
74 pp_size: usize,
75}
76
77impl DeviceMesh {
78 /// Create a 1-D tensor-parallel mesh of `n` simulated devices.
79 pub fn tensor_parallel(n: usize) -> Self {
80 Self::new(n, 1)
81 }
82
83 /// Create a 2-D (`tp_size` × `pp_size`) mesh.
84 ///
85 /// Total device count is `tp_size * pp_size`.
86 pub fn new(tp_size: usize, pp_size: usize) -> Self {
87 let total = tp_size * pp_size;
88 let devices = (0..total).map(DeviceInfo::simulated).collect();
89 Self {
90 devices,
91 tp_size,
92 pp_size,
93 }
94 }
95
96 /// Total number of devices in the mesh.
97 pub fn size(&self) -> usize {
98 self.devices.len()
99 }
100
101 /// Get the device at tensor-parallel rank `tp_rank` and pipeline-parallel rank `pp_rank`.
102 ///
103 /// Returns `None` if either rank is out of bounds.
104 pub fn get(&self, tp_rank: usize, pp_rank: usize) -> Option<&DeviceInfo> {
105 if tp_rank >= self.tp_size || pp_rank >= self.pp_size {
106 return None;
107 }
108 let idx = tp_rank + pp_rank * self.tp_size;
109 self.devices.get(idx)
110 }
111
112 /// All devices in the tensor-parallel group for a given `pp_rank`.
113 ///
114 /// Returns an empty vec if `pp_rank` is out of range.
115 pub fn tp_group(&self, pp_rank: usize) -> Vec<&DeviceInfo> {
116 if pp_rank >= self.pp_size {
117 return Vec::new();
118 }
119 (0..self.tp_size)
120 .filter_map(|tp| self.get(tp, pp_rank))
121 .collect()
122 }
123
124 /// All devices in the pipeline-parallel group for a given `tp_rank`.
125 ///
126 /// Returns an empty vec if `tp_rank` is out of range.
127 pub fn pp_group(&self, tp_rank: usize) -> Vec<&DeviceInfo> {
128 if tp_rank >= self.tp_size {
129 return Vec::new();
130 }
131 (0..self.pp_size)
132 .filter_map(|pp| self.get(tp_rank, pp))
133 .collect()
134 }
135}
136
137// ─────────────────────────────────────────────────────────────────────────────
138// CollectiveResult
139// ─────────────────────────────────────────────────────────────────────────────
140
/// Result of a collective communication operation.
///
/// Derives `PartialEq` (not `Eq`, since the payload is `f32`) so results can
/// be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct CollectiveResult {
    /// The reduced / gathered data.
    pub data: Vec<f32>,
    /// Number of devices that participated.
    pub participating_devices: usize,
    /// Name tag identifying the operation (e.g. `"all_reduce_sum"`).
    pub op_name: &'static str,
}
151
152// ─────────────────────────────────────────────────────────────────────────────
153// NcclCollectives
154// ─────────────────────────────────────────────────────────────────────────────
155
156/// NCCL-style collective operations simulated on the CPU via rayon.
157///
158/// In production these would be replaced by NCCL library calls.
159pub struct NcclCollectives;
160
161impl NcclCollectives {
162 /// All-reduce (sum): element-wise sum of tensors from all participating devices;
163 /// the result is the same on every device.
164 ///
165 /// All shards must have the same length.
166 pub fn all_reduce_sum(shards: &[Vec<f32>]) -> CollectiveResult {
167 let n = shards.first().map(|s| s.len()).unwrap_or(0);
168 let data: Vec<f32> = (0..n)
169 .into_par_iter()
170 .map(|i| shards.iter().map(|s| s[i]).sum::<f32>())
171 .collect();
172 CollectiveResult {
173 data,
174 participating_devices: shards.len(),
175 op_name: "all_reduce_sum",
176 }
177 }
178
179 /// All-reduce (max): element-wise maximum across all device tensors.
180 ///
181 /// All shards must have the same length.
182 pub fn all_reduce_max(shards: &[Vec<f32>]) -> CollectiveResult {
183 let n = shards.first().map(|s| s.len()).unwrap_or(0);
184 let data: Vec<f32> = (0..n)
185 .into_par_iter()
186 .map(|i| {
187 shards
188 .iter()
189 .map(|s| s[i])
190 .fold(f32::NEG_INFINITY, f32::max)
191 })
192 .collect();
193 CollectiveResult {
194 data,
195 participating_devices: shards.len(),
196 op_name: "all_reduce_max",
197 }
198 }
199
200 /// All-gather: concatenate tensors from all devices in rank order.
201 pub fn all_gather(shards: &[Vec<f32>]) -> CollectiveResult {
202 let data: Vec<f32> = shards.iter().flat_map(|s| s.iter().copied()).collect();
203 CollectiveResult {
204 data,
205 participating_devices: shards.len(),
206 op_name: "all_gather",
207 }
208 }
209
210 /// Reduce-scatter: sum the global `data` across all ranks, then scatter
211 /// equal-sized shards back to each device.
212 ///
213 /// If `data.len()` is not evenly divisible by `world_size`, the last shard
214 /// will be shorter.
215 pub fn reduce_scatter(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
216 if world_size == 0 {
217 return Vec::new();
218 }
219 // Here "reduce-scatter" treats each device as already holding its own
220 // portion of the data, and after the reduce step each device gets its
221 // equal shard. In the simulation we simply split the input:
222 let base = data.len() / world_size;
223 let remainder = data.len() % world_size;
224 (0..world_size)
225 .map(|rank| {
226 let start = rank * base + rank.min(remainder);
227 let end = start + base + if rank < remainder { 1 } else { 0 };
228 data[start..end.min(data.len())].to_vec()
229 })
230 .collect()
231 }
232
233 /// Broadcast: replicate `data` from device 0 to all `world_size` devices.
234 pub fn broadcast(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
235 (0..world_size).map(|_| data.to_vec()).collect()
236 }
237}
238
239// ─────────────────────────────────────────────────────────────────────────────
240// Weight partition helpers
241// ─────────────────────────────────────────────────────────────────────────────
242
/// Split a row-major `[rows × cols]` weight matrix into column-parallel shards.
///
/// The `cols` dimension is cut into `world_size` contiguous column ranges:
/// every rank gets `cols / world_size` columns, and the first
/// `cols % world_size` ranks get one extra. Shard `r` is itself a row-major
/// `[rows × shard_cols]` block holding rank `r`'s columns from every row.
///
/// Returns an empty vec when `world_size == 0`.
///
/// # Panics
///
/// Panics if `weights` holds fewer than `rows * cols` elements.
pub fn partition_weights_column(
    weights: &[f32],
    rows: usize,
    cols: usize,
    world_size: usize,
) -> Vec<Vec<f32>> {
    if world_size == 0 {
        return Vec::new();
    }
    let per_rank = cols / world_size;
    let leftover = cols % world_size;

    let mut shards = Vec::with_capacity(world_size);
    // First column owned by the rank currently being built.
    let mut first_col = 0usize;
    for rank in 0..world_size {
        let width = if rank < leftover { per_rank + 1 } else { per_rank };
        let mut shard = Vec::with_capacity(rows * width);
        for row_start in (0..rows).map(|r| r * cols) {
            let lo = row_start + first_col;
            shard.extend_from_slice(&weights[lo..lo + width]);
        }
        shards.push(shard);
        first_col += width;
    }
    shards
}
273
/// Split a row-major `[rows × cols]` weight matrix into row-parallel shards.
///
/// Each rank receives a contiguous block of whole rows: `rows / world_size`
/// rows each, plus one extra for the first `rows % world_size` ranks. Because
/// the input is row-major, every shard is a contiguous copy of `weights`.
///
/// Returns an empty vec when `world_size == 0`.
///
/// # Panics
///
/// Panics if `weights` holds fewer than `rows * cols` elements.
pub fn partition_weights_row(
    weights: &[f32],
    rows: usize,
    cols: usize,
    world_size: usize,
) -> Vec<Vec<f32>> {
    if world_size == 0 {
        return Vec::new();
    }
    let per_rank = rows / world_size;
    let leftover = rows % world_size;

    let mut shards = Vec::with_capacity(world_size);
    // First row owned by the rank currently being built.
    let mut next_row = 0usize;
    for rank in 0..world_size {
        let height = per_rank + usize::from(rank < leftover);
        let block = &weights[next_row * cols..(next_row + height) * cols];
        shards.push(block.to_vec());
        next_row += height;
    }
    shards
}
296
/// Merge column-parallel shards back into a single row-major `[rows × cols]` weight matrix.
///
/// Inverse of [`partition_weights_column`]: each shard is a row-major
/// `[rows × shard_cols]` block, and the shards are placed side by side in
/// order, so `cols` is the sum of the shards' column counts. Empty shards
/// (ranks that received no columns) are allowed.
///
/// Returns an empty vec when `shards` is empty or `rows == 0`.
///
/// # Panics
///
/// Panics if a shard's length is not a multiple of `rows`. Such a shard cannot
/// be a `[rows × shard_cols]` block, and merging it would silently drop its
/// trailing elements.
pub fn merge_column_shards(shards: &[Vec<f32>], rows: usize) -> Vec<f32> {
    if shards.is_empty() || rows == 0 {
        return Vec::new();
    }

    // Column count of each shard, validating that it is a whole number of rows.
    let shard_cols: Vec<usize> = shards
        .iter()
        .enumerate()
        .map(|(rank, shard)| {
            assert!(
                shard.len() % rows == 0,
                "merge_column_shards: shard {rank} has {} elements, not a multiple of rows = {rows}",
                shard.len()
            );
            shard.len() / rows
        })
        .collect();

    let total_len: usize = shards.iter().map(Vec::len).sum();
    let mut result = Vec::with_capacity(total_len);
    // Build the output sequentially: for each row, append that row's slice
    // from every shard in rank order.
    for row in 0..rows {
        for (shard, &width) in shards.iter().zip(&shard_cols) {
            result.extend_from_slice(&shard[row * width..(row + 1) * width]);
        }
    }
    result
}