1use crate::device_mesh::{DType, DeviceMesh, MeshComm, ReduceOp};
7use std::sync::Arc;
8
/// Result type returned by every distributed operation in this module.
///
/// Failures are reported as a plain, human-readable `String` message.
pub type DistResult<T> = std::result::Result<T, String>;
11
12pub struct DistributedOps<'a> {
14 mesh: &'a DeviceMesh,
15}
16
17impl<'a> DistributedOps<'a> {
18 pub fn new(mesh: &'a DeviceMesh) -> Self {
20 Self { mesh }
21 }
22
23 fn comm(&self) -> DistResult<&Arc<dyn MeshComm + Send + Sync>> {
25 self.mesh
26 .comm
27 .as_ref()
28 .ok_or_else(|| "No communication backend configured".to_string())
29 }
30
31 pub fn all_reduce_f32(&self, data: &mut [f32], op: ReduceOp, group: &str) -> DistResult<()> {
35 let comm = self.comm()?;
36
37 let byte_slice =
39 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
40
41 comm.all_reduce(byte_slice, DType::Float32, op, group)
42 }
43
44 pub fn all_reduce_i32(&self, data: &mut [i32], op: ReduceOp, group: &str) -> DistResult<()> {
46 let comm = self.comm()?;
47
48 let byte_slice =
49 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
50
51 comm.all_reduce(byte_slice, DType::Int32, op, group)
52 }
53
54 pub fn broadcast_f32(&self, data: &mut [f32], root_rank: usize, group: &str) -> DistResult<()> {
56 let comm = self.comm()?;
57
58 let byte_slice =
59 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
60
61 comm.broadcast(byte_slice, root_rank, group)
62 }
63
64 pub fn all_gather_f32(&self, local: &[f32], output: &mut [f32], group: &str) -> DistResult<()> {
69 let comm = self.comm()?;
70
71 let local_bytes =
72 unsafe { std::slice::from_raw_parts(local.as_ptr() as *const u8, local.len() * 4) };
73
74 let output_bytes = unsafe {
75 std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut u8, output.len() * 4)
76 };
77
78 comm.all_gather(local_bytes, output_bytes, DType::Float32, group)
79 }
80
81 pub fn scatter_f32(&self, data: &[f32], chunk: &mut [f32], root_rank: usize) -> DistResult<()> {
86 let comm = self.comm()?;
87
88 if self.mesh.rank == root_rank {
90 let data_bytes =
92 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4) };
93
94 let mut broadcast_buf = data_bytes.to_vec();
96 comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
97
98 let chunk_size = chunk.len();
100 let offset = self.mesh.rank * chunk_size;
101 for (i, val) in chunk.iter_mut().enumerate() {
102 let src_idx = offset + i;
103 if src_idx < data.len() {
104 *val = data[src_idx];
105 }
106 }
107 } else {
108 let total_size = data.len();
110 let mut broadcast_buf = vec![0u8; total_size * 4];
111 comm.broadcast(&mut broadcast_buf, root_rank, "world")?;
112
113 let chunk_size = chunk.len();
115 let offset = self.mesh.rank * chunk_size;
116 for (i, val) in chunk.iter_mut().enumerate() {
117 let idx = (offset + i) * 4;
118 if idx + 4 <= broadcast_buf.len() {
119 let bytes: [u8; 4] = broadcast_buf[idx..idx + 4].try_into().unwrap();
120 *val = f32::from_le_bytes(bytes);
121 }
122 }
123 }
124
125 Ok(())
126 }
127
128 pub fn gather_f32(
133 &self,
134 local: &[f32],
135 output: &mut [f32],
136 root_rank: usize,
137 ) -> DistResult<()> {
138 self.all_gather_f32(local, output, "world")?;
140
141 if self.mesh.rank != root_rank {
144 }
147
148 Ok(())
149 }
150
151 pub fn reduce_scatter_f32(
156 &self,
157 data: &mut [f32],
158 output: &mut [f32],
159 op: ReduceOp,
160 group: &str,
161 ) -> DistResult<()> {
162 let comm = self.comm()?;
163
164 let data_bytes =
165 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
166
167 let output_bytes = unsafe {
168 std::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut u8, output.len() * 4)
169 };
170
171 comm.reduce_scatter(data_bytes, output_bytes, op, group)
172 }
173
174 pub fn barrier(&self, group: &str) -> DistResult<()> {
176 let comm = self.comm()?;
177 comm.barrier(group)
178 }
179
180 pub fn send_f32(&self, data: &[f32], dest_rank: usize) -> DistResult<()> {
182 let comm = self.comm()?;
183 let bytes =
184 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4) };
185 comm.send(bytes, dest_rank)
186 }
187
188 pub fn recv_f32(&self, data: &mut [f32], src_rank: usize) -> DistResult<()> {
190 let comm = self.comm()?;
191 let bytes =
192 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
193 comm.recv(bytes, src_rank)
194 }
195}
196
197pub fn dist_ops(mesh: &DeviceMesh) -> DistributedOps<'_> {
199 DistributedOps::new(mesh)
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceBackend};
    use crate::device_mesh::DeviceMesh;

    /// Builds `count` mock CUDA devices with ids `0..count` and otherwise
    /// identical properties.
    fn create_test_devices(count: usize) -> Vec<Device> {
        let mut devices = Vec::with_capacity(count);
        for id in 0..count {
            devices.push(Device {
                id,
                name: format!("GPU_{}", id),
                backend: DeviceBackend::Cuda,
                memory_mb: 16000,
                compute_units: 80,
                pci_bus_id: None,
                partition_id: None,
                driver_version: None,
                compute_capability: None,
                utilization_gpu_pct: None,
                temperature_c: None,
                supports_fp16: true,
                supports_int8: true,
                cuda_version: Some("12.0".to_string()),
            });
        }
        devices
    }

    /// A sum all-reduce over a single-rank mock mesh succeeds and leaves the
    /// values as they were.
    #[test]
    fn test_all_reduce_f32_single_rank() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(1), 0);
        let mut data = vec![1.0f32, 2.0, 3.0, 4.0];

        let outcome = dist_ops(&mesh).all_reduce_f32(&mut data, ReduceOp::Sum, "world");

        assert!(outcome.is_ok());
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Broadcasting from rank 0 on a two-device mock mesh succeeds.
    #[test]
    fn test_broadcast_f32() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(2), 0);
        let mut data = vec![42.0f32, 24.0];

        let outcome = dist_ops(&mesh).broadcast_f32(&mut data, 0, "world");

        assert!(outcome.is_ok());
    }

    /// A barrier on a four-device mock mesh succeeds.
    #[test]
    fn test_barrier() {
        let mesh = DeviceMesh::new_with_mock_comm(create_test_devices(4), 0);

        let outcome = dist_ops(&mesh).barrier("world");

        assert!(outcome.is_ok());
    }

    /// Data sent from rank 0 arrives unchanged at rank 1 when both mock
    /// communicators share the same state.
    #[test]
    fn test_send_recv_f32() {
        use crate::mock_comm::MockComm;
        use std::sync::{Arc, RwLock};

        let shared = Arc::new(RwLock::new(crate::mock_comm::MockCommState::new(2)));
        let devices = create_test_devices(2);

        let mut sender_mesh = DeviceMesh::new(devices.clone());
        sender_mesh.rank = 0;
        sender_mesh.comm = Some(Arc::new(MockComm::with_shared_state(0, shared.clone())));

        let mut receiver_mesh = DeviceMesh::new(devices);
        receiver_mesh.rank = 1;
        receiver_mesh.comm = Some(Arc::new(MockComm::with_shared_state(1, shared)));

        let payload = vec![1.0f32, 2.0, 3.0];
        let sender = dist_ops(&sender_mesh);
        sender.send_f32(&payload, 1).unwrap();

        let mut received = vec![0.0f32; 3];
        let receiver = dist_ops(&receiver_mesh);
        receiver.recv_f32(&mut received, 0).unwrap();

        assert_eq!(received, payload);
    }

    /// Operations on a mesh without a communication backend fail with a
    /// descriptive error.
    #[test]
    fn test_no_comm_backend_error() {
        let mesh = DeviceMesh::new(create_test_devices(2));
        let ops = dist_ops(&mesh);
        let mut data = vec![1.0f32];

        let outcome = ops.all_reduce_f32(&mut data, ReduceOp::Sum, "world");

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("No communication backend"));
    }
}