Skip to main content

kapsl_hal/
distributed_ops.rs

1//! Distributed Operations for Tensor Computations
2//!
3//! High-level distributed tensor operations that use the `MeshComm` trait.
4//! These operations enable distributed inference across multiple devices.
5
6use crate::device_mesh::{DType, DeviceMesh, MeshComm, ReduceOp};
7use std::sync::Arc;
8
/// Outcome of a distributed operation.
///
/// `Ok` carries the operation's value; `Err` carries a human-readable message
/// describing what went wrong.
pub type DistResult<T> = std::result::Result<T, String>;
11
12/// Distributed tensor operations
13pub struct DistributedOps<'a> {
14    mesh: &'a DeviceMesh,
15}
16
17impl<'a> DistributedOps<'a> {
18    /// Create a new distributed operations context
19    pub fn new(mesh: &'a DeviceMesh) -> Self {
20        Self { mesh }
21    }
22
23    /// Get the communication backend
24    fn comm(&self) -> DistResult<&Arc<dyn MeshComm + Send + Sync>> {
25        self.mesh
26            .comm
27            .as_ref()
28            .ok_or_else(|| "No communication backend configured".to_string())
29    }
30
31    /// All-reduce a f32 tensor across all ranks in a group
32    ///
33    /// After this operation, all ranks will have the same reduced values.
34    pub fn all_reduce_f32(&self, data: &mut [f32], op: ReduceOp, group: &str) -> DistResult<()> {
35        let comm = self.comm()?;
36
37        // Convert f32 slice to bytes for the communication layer
38        let byte_slice =
39            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
40
41        comm.all_reduce(byte_slice, DType::Float32, op, group)
42    }
43
44    /// All-reduce a i32 tensor across all ranks in a group
45    pub fn all_reduce_i32(&self, data: &mut [i32], op: ReduceOp, group: &str) -> DistResult<()> {
46        let comm = self.comm()?;
47
48        let byte_slice =
49            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
50
51        comm.all_reduce(byte_slice, DType::Int32, op, group)
52    }
53
54    /// Broadcast f32 tensor from root rank to all ranks in group
55    pub fn broadcast_f32(&self, data: &mut [f32], root_rank: usize, group: &str) -> DistResult<()> {
56        let comm = self.comm()?;
57
58        let byte_slice =
59            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
60
61        comm.broadcast(byte_slice, root_rank, group)
62    }
63
64    /// All-gather f32 tensors from all ranks
65    ///
66    /// Each rank contributes `local.len()` elements, and receives
67    /// `local.len() * world_size` elements in the output.
68    pub fn all_gather_f32(&self, local: &[f32], output: &mut [f32], group: &str) -> DistResult<()> {
69        let comm = self.comm()?;
70
71        let local_bytes =
72            unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, local.len() * 4) };
73
74        let output_bytes = unsafe {
75            std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut u8, output.len() * 4)
76        };
77
78        comm.all_gather(local_bytes, output_bytes, DType::Float32, group)
79    }
80
81    /// Scatter a tensor: divide data among ranks
82    ///
83    /// Only the root rank's `data` is used for input. After this operation,
84    /// each rank's `chunk` will contain its portion of the data.
85    pub fn scatter_f32(&self, data: &[f32], chunk: &mut [f32], root_rank: usize) -> DistResult<()> {
86        let comm = self.comm()?;
87
88        // Root broadcasts, then each rank picks its chunk
89        if self.mesh.rank == root_rank {
90            // Convert to bytes for broadcast
91            let data_bytes =
92                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4) };
93
94            // Broadcast the entire data
95            let mut broadcast_buf = data_bytes.to_vec();
96            comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
97
98            // Extract our chunk
99            let chunk_size = chunk.len();
100            let offset = self.mesh.rank * chunk_size;
101            for (i, val) in chunk.iter_mut().enumerate() {
102                let src_idx = offset + i;
103                if src_idx < data.len() {
104                    *val = data[src_idx];
105                }
106            }
107        } else {
108            // Non-root receives broadcast and extracts chunk
109            let total_size = data.len();
110            let mut broadcast_buf = vec![0u8; total_size * 4];
111            comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
112
113            // Convert bytes back to f32 and extract our chunk
114            let chunk_size = chunk.len();
115            let offset = self.mesh.rank * chunk_size;
116            for (i, val) in chunk.iter_mut().enumerate() {
117                let idx = (offset + i) * 4;
118                if idx + 4 <= broadcast_buf.len() {
119                    let bytes: [u8; 4] = broadcast_buf[idx..idx + 4].try_into().unwrap();
120                    *val = f32::from_le_bytes(bytes);
121                }
122            }
123        }
124
125        Ok(())
126    }
127
128    /// Gather tensors from all ranks to root
129    ///
130    /// Each rank's `local` data is gathered to the root rank's `output`.
131    /// Only the root rank's output will contain the complete gathered data.
132    pub fn gather_f32(
133        &self,
134        local: &[f32],
135        output: &mut [f32],
136        root_rank: usize,
137    ) -> DistResult<()> {
138        // Use all_gather first, then root takes the result
139        self.all_gather_f32(local, output, "world")?;
140
141        // For non-root ranks, output is partially filled but that's okay
142        // since only root is expected to use the complete result
143        if self.mesh.rank != root_rank {
144            // Clear output for non-root ranks (optional, for clarity)
145            // In practice, callers should only use root's output
146        }
147
148        Ok(())
149    }
150
151    /// Reduce-scatter: reduce and distribute results
152    ///
153    /// Combines reduction and scatter in one operation.
154    /// After this operation, each rank has a portion of the reduced result.
155    pub fn reduce_scatter_f32(
156        &self,
157        data: &mut [f32],
158        output: &mut [f32],
159        op: ReduceOp,
160        group: &str,
161    ) -> DistResult<()> {
162        let comm = self.comm()?;
163
164        let data_bytes =
165            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
166
167        let output_bytes = unsafe {
168            std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut u8, output.len() * 4)
169        };
170
171        comm.reduce_scatter(data_bytes, output_bytes, op, group)
172    }
173
174    /// Barrier synchronization across all ranks in a group
175    pub fn barrier(&self, group: &str) -> DistResult<()> {
176        let comm = self.comm()?;
177        comm.barrier(group)
178    }
179
180    /// Point-to-point send of f32 tensor
181    pub fn send_f32(&self, data: &[f32], dest_rank: usize) -> DistResult<()> {
182        let comm = self.comm()?;
183        let bytes =
184            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4) };
185        comm.send(bytes, dest_rank)
186    }
187
188    /// Point-to-point receive of f32 tensor
189    pub fn recv_f32(&self, data: &mut [f32], src_rank: usize) -> DistResult<()> {
190        let comm = self.comm()?;
191        let bytes =
192            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
193        comm.recv(bytes, src_rank)
194    }
195}
196
197/// Convenience function to create distributed ops from a mesh
198pub fn dist_ops(mesh: &DeviceMesh) -> DistributedOps<'_> {
199    DistributedOps::new(mesh)
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceBackend};
    use crate::device_mesh::DeviceMesh;

    /// Build `count` identical mock CUDA devices with ids `0..count`.
    fn create_test_devices(count: usize) -> Vec<Device> {
        (0..count)
            .map(|i| Device {
                id: i,
                name: format!("GPU_{}", i),
                backend: DeviceBackend::Cuda,
                memory_mb: 16000,
                compute_units: 80,
                pci_bus_id: None,
                partition_id: None,
                driver_version: None,
                compute_capability: None,
                utilization_gpu_pct: None,
                temperature_c: None,
                supports_fp16: true,
                supports_int8: true,
                cuda_version: Some("12.0".to_string()),
            })
            .collect()
    }

    #[test]
    fn test_all_reduce_f32_single_rank() {
        let devices = create_test_devices(1);
        let mesh = DeviceMesh::new_with_mock_comm(devices, 0);

        let ops = dist_ops(&mesh);
        let mut data = vec![1.0f32, 2.0, 3.0, 4.0];

        let result = ops.all_reduce_f32(&mut data, ReduceOp::Sum, "world");
        assert!(result.is_ok());

        // Single rank, data unchanged
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_broadcast_f32() {
        let devices = create_test_devices(2);
        let mesh = DeviceMesh::new_with_mock_comm(devices, 0);

        let ops = dist_ops(&mesh);
        let mut data = vec![42.0f32, 24.0];

        let result = ops.broadcast_f32(&mut data, 0, "world");
        assert!(result.is_ok());

        // The root's own buffer is the broadcast source and must be unchanged.
        assert_eq!(data, vec![42.0, 24.0]);
    }

    #[test]
    fn test_scatter_f32_single_rank() {
        let devices = create_test_devices(1);
        let mesh = DeviceMesh::new_with_mock_comm(devices, 0);

        let ops = dist_ops(&mesh);
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let mut chunk = vec![0.0f32; 4];

        // With one rank, the root's chunk is the whole tensor.
        ops.scatter_f32(&data, &mut chunk, 0).unwrap();
        assert_eq!(chunk, data);
    }

    #[test]
    fn test_barrier() {
        let devices = create_test_devices(4);
        let mesh = DeviceMesh::new_with_mock_comm(devices, 0);

        let ops = dist_ops(&mesh);
        let result = ops.barrier("world");
        assert!(result.is_ok());
    }

    #[test]
    fn test_send_recv_f32() {
        use crate::mock_comm::MockComm;
        use std::sync::{Arc, RwLock};

        // Create shared state for 2 ranks
        let state = Arc::new(RwLock::new(crate::mock_comm::MockCommState::new(2)));

        // Create two meshes with shared comm state
        let devices = create_test_devices(2);
        let mut mesh0 = DeviceMesh::new(devices.clone());
        mesh0.rank = 0;
        mesh0.comm = Some(Arc::new(MockComm::with_shared_state(0, state.clone())));

        let mut mesh1 = DeviceMesh::new(devices);
        mesh1.rank = 1;
        mesh1.comm = Some(Arc::new(MockComm::with_shared_state(1, state)));

        // Rank 0 sends to rank 1
        let ops0 = dist_ops(&mesh0);
        let send_data = vec![1.0f32, 2.0, 3.0];
        ops0.send_f32(&send_data, 1).unwrap();

        // Rank 1 receives from rank 0
        let ops1 = dist_ops(&mesh1);
        let mut recv_data = vec![0.0f32; 3];
        ops1.recv_f32(&mut recv_data, 0).unwrap();

        assert_eq!(recv_data, send_data);
    }

    #[test]
    fn test_no_comm_backend_error() {
        let devices = create_test_devices(2);
        let mesh = DeviceMesh::new(devices); // No mock comm attached

        let ops = dist_ops(&mesh);
        let mut data = vec![1.0f32];

        let result = ops.all_reduce_f32(&mut data, ReduceOp::Sum, "world");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No communication backend"));

        // Every other entry point must report the same error rather than panic.
        let mut ints = vec![1i32];
        let mut out = vec![0.0f32; 2];
        let results = [
            ops.all_reduce_i32(&mut ints, ReduceOp::Sum, "world"),
            ops.broadcast_f32(&mut data, 0, "world"),
            ops.all_gather_f32(&data, &mut out, "world"),
            ops.scatter_f32(&data, &mut out, 0),
            ops.gather_f32(&data, &mut out, 0),
            ops.reduce_scatter_f32(&mut data, &mut out, ReduceOp::Sum, "world"),
            ops.barrier("world"),
            ops.send_f32(&data, 1),
            ops.recv_f32(&mut data, 1),
        ];
        for result in results {
            assert!(result.unwrap_err().contains("No communication backend"));
        }
    }
}