use crate::ffi;
use rlx_driver::{CollectiveError, Transport};
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
/// Fetches the most recent error message recorded by the MLX FFI layer.
///
/// Returns the literal `"(null)"` when the FFI hands back a null pointer;
/// otherwise the C string is copied into an owned `String`, with invalid
/// UTF-8 replaced lossily.
fn last_error() -> String {
    // SAFETY: the call takes no arguments; we only inspect the returned pointer.
    let msg = unsafe { ffi::rlx_mlx_last_error() };
    if msg.is_null() {
        return "(null)".to_string();
    }
    // SAFETY: `msg` is non-null here. Assumes the FFI returns a NUL-terminated
    // string that stays valid until we finish copying it — TODO confirm.
    let text = unsafe { CStr::from_ptr(msg) };
    text.to_string_lossy().into_owned()
}
/// Builds a `CollectiveError::TransportError` for a failed MLX call.
///
/// `ctx` names the operation that failed; the message recorded by the FFI
/// layer (see `last_error`) is appended to it.
fn terr(ctx: &str) -> CollectiveError {
    let detail = last_error();
    let reason = format!("mlx {ctx}: {}", detail);
    CollectiveError::TransportError { reason }
}
/// Transport backed by the MLX distributed FFI; constructed by `MlxTransport::init`.
pub struct MlxTransport {
    /// Rank value written by `rlx_mlx_dist_init` during `init`.
    rank: u32,
    /// Group size written by `rlx_mlx_dist_init` during `init`.
    world: u32,
}
impl MlxTransport {
    /// Reports whether the MLX distributed runtime says it is usable.
    ///
    /// Returns `true` only when the FFI query succeeds and sets a non-zero flag.
    pub fn is_available() -> bool {
        let mut out: c_int = 0;
        // SAFETY: `out` is a live, writable local for the duration of the call.
        let rc = unsafe { ffi::rlx_mlx_dist_is_available(&mut out) };
        rc == ffi::RLX_MLX_OK && out != 0
    }

    /// Initialises the MLX distributed group and records the rank and group size.
    ///
    /// `strict` and `backend` are forwarded unchanged to `rlx_mlx_dist_init`.
    ///
    /// # Errors
    ///
    /// Returns `CollectiveError::TransportError` when `backend` contains an
    /// interior NUL byte, when the FFI call fails, or when the FFI reports a
    /// negative rank or group size (which cannot be represented as `u32`).
    pub fn init(strict: bool, backend: &str) -> Result<Self, CollectiveError> {
        let cb = CString::new(backend).map_err(|e| CollectiveError::TransportError {
            reason: e.to_string(),
        })?;
        let mut rank: c_int = 0;
        let mut size: c_int = 0;
        // SAFETY: `cb` outlives the call and is NUL-terminated; `rank` and `size`
        // are live, writable locals.
        let rc = unsafe {
            ffi::rlx_mlx_dist_init(c_int::from(strict), cb.as_ptr(), &mut rank, &mut size)
        };
        if rc != ffi::RLX_MLX_OK {
            return Err(terr("init"));
        }
        // A plain `as u32` would turn a negative value into a huge rank/size;
        // reject it explicitly instead.
        let (Ok(rank), Ok(world)) = (u32::try_from(rank), u32::try_from(size)) else {
            return Err(CollectiveError::TransportError {
                reason: format!("mlx init: invalid group (rank {rank}, size {size})"),
            });
        };
        Ok(Self { rank, world })
    }

    /// Runs the MLX all-sum collective over `data` and writes the result back in place.
    ///
    /// When the group size is at most 1 the data is left untouched and no FFI
    /// call is made.
    ///
    /// # Errors
    ///
    /// Returns `CollectiveError::TransportError` when the FFI call fails; `data`
    /// is not modified in that case.
    pub fn all_sum(&self, data: &mut [f32]) -> Result<(), CollectiveError> {
        if self.world <= 1 {
            return Ok(());
        }
        let mut out = vec![0f32; data.len()];
        // SAFETY: both buffers hold exactly `data.len()` f32 values. Assumes the
        // C side reads/writes no more than that many elements — TODO confirm.
        let rc =
            unsafe { ffi::rlx_mlx_dist_all_sum_f32(data.as_ptr(), out.as_mut_ptr(), data.len()) };
        if rc != ffi::RLX_MLX_OK {
            return Err(terr("all_sum"));
        }
        data.copy_from_slice(&out);
        Ok(())
    }

    /// Runs the MLX all-gather collective and returns the gathered buffer.
    ///
    /// The output buffer is sized `local.len() * world_size`. When the group
    /// size is at most 1 a copy of `local` is returned without an FFI call.
    ///
    /// # Errors
    ///
    /// Returns `CollectiveError::TransportError` when the output size overflows
    /// `usize` or when the FFI call fails.
    pub fn all_gather(&self, local: &[f32]) -> Result<Vec<f32>, CollectiveError> {
        if self.world <= 1 {
            return Ok(local.to_vec());
        }
        let total = usize::try_from(self.world)
            .ok()
            .and_then(|world| local.len().checked_mul(world))
            .ok_or_else(|| CollectiveError::TransportError {
                reason: format!(
                    "mlx all_gather: {} elements x {} ranks overflows usize",
                    local.len(),
                    self.world
                ),
            })?;
        let mut out = vec![0f32; total];
        // SAFETY: the pointers and lengths passed describe the live `local` and
        // `out` buffers exactly.
        let rc = unsafe {
            ffi::rlx_mlx_dist_all_gather_f32(
                local.as_ptr(),
                local.len(),
                out.as_mut_ptr(),
                out.len(),
            )
        };
        if rc != ffi::RLX_MLX_OK {
            return Err(terr("all_gather"));
        }
        Ok(out)
    }

    /// Converts a peer rank into the `c_int` passed to the FFI, rejecting
    /// values that would wrap to a negative number under an `as` cast.
    fn peer(ctx: &str, rank: u32) -> Result<c_int, CollectiveError> {
        c_int::try_from(rank).map_err(|_| CollectiveError::TransportError {
            reason: format!("mlx {ctx}: peer rank {rank} does not fit in a C int"),
        })
    }

    /// Sends `data` to rank `dst` through `rlx_mlx_dist_send_f32`.
    fn raw_send(&self, dst: u32, data: &[f32]) -> Result<(), CollectiveError> {
        let dst = Self::peer("send", dst)?;
        // SAFETY: the pointer/length pair describes the live `data` slice.
        let rc = unsafe { ffi::rlx_mlx_dist_send_f32(data.as_ptr(), data.len(), dst) };
        if rc != ffi::RLX_MLX_OK {
            return Err(terr("send"));
        }
        Ok(())
    }

    /// Receives `out.len()` floats from rank `src` through `rlx_mlx_dist_recv_f32`.
    fn raw_recv(&self, src: u32, out: &mut [f32]) -> Result<(), CollectiveError> {
        let src = Self::peer("recv", src)?;
        // SAFETY: the pointer/length pair describes the live, writable `out` slice.
        let rc = unsafe { ffi::rlx_mlx_dist_recv_f32(out.as_mut_ptr(), out.len(), src) };
        if rc != ffi::RLX_MLX_OK {
            return Err(terr("recv"));
        }
        Ok(())
    }
}
impl Transport for MlxTransport {
fn rank(&self) -> u32 {
self.rank
}
fn world_size(&self) -> u32 {
self.world
}
fn send_bytes(&self, to: u32, _tag: u32, bytes: &[u8]) -> Result<(), CollectiveError> {
let len = bytes.len();
let hdr = [(len / 4096) as f32, (len % 4096) as f32];
self.raw_send(to, &hdr)?;
let nf = len.div_ceil(4);
if nf > 0 {
let mut buf = vec![0f32; nf];
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), buf.as_mut_ptr() as *mut u8, len);
}
self.raw_send(to, &buf)?;
}
Ok(())
}
fn recv_bytes(&self, from: u32, _tag: u32) -> Result<Vec<u8>, CollectiveError> {
let mut hdr = [0f32; 2];
self.raw_recv(from, &mut hdr)?;
let len = (hdr[0] as usize) * 4096 + (hdr[1] as usize);
let nf = len.div_ceil(4);
let mut out = vec![0u8; len];
if nf > 0 {
let mut buf = vec![0f32; nf];
self.raw_recv(from, &mut buf)?;
unsafe {
std::ptr::copy_nonoverlapping(buf.as_ptr() as *const u8, out.as_mut_ptr(), len);
}
}
Ok(out)
}
fn barrier(&self) -> Result<(), CollectiveError> {
if self.world <= 1 {
return Ok(());
}
let rc = unsafe { ffi::rlx_mlx_dist_barrier() };
if rc != ffi::RLX_MLX_OK {
return Err(terr("barrier"));
}
Ok(())
}
}
use crate::array::{Array, check};
use crate::op_registry::{MlxKernel, register_mlx_kernel};
use rlx_ir::Shape as IrShape;
use rlx_ir::op_registry::{OpExtension, register_op};
use std::sync::Arc;
/// Operation name shared by the all-reduce op extension and its MLX kernel.
pub const ALL_REDUCE: &str = "collective.all_reduce";
/// Op-registry description of the `ALL_REDUCE` custom operation.
struct AllReduceExt;

impl OpExtension for AllReduceExt {
    /// Name the op is registered under.
    fn name(&self) -> &str {
        ALL_REDUCE
    }

    /// The op takes exactly one input.
    fn num_inputs(&self) -> usize {
        1
    }

    /// The output shape is a copy of the first input's shape; attributes are
    /// not consulted. Panics if `inputs` is empty.
    fn infer_shape(&self, inputs: &[&IrShape], _attrs: &[u8]) -> IrShape {
        let operand: &IrShape = inputs[0];
        operand.clone()
    }
}
/// MLX kernel registered under `ALL_REDUCE`.
struct AllReduceMlx;

impl MlxKernel for AllReduceMlx {
    /// Name the kernel is registered under.
    fn name(&self) -> &str {
        ALL_REDUCE
    }

    /// Passes the first input array to `rlx_mlx_dist_all_sum_array` and wraps
    /// the array it produces.
    ///
    /// The requested output shape and attributes are not consulted. Panics if
    /// `inputs` is empty; a non-OK status is converted to an error by `check`.
    fn execute(
        &self,
        inputs: &[&Array],
        _output_shape: &IrShape,
        _attrs: &[u8],
    ) -> Result<Array, crate::array::MlxError> {
        let operand = inputs[0].ptr;
        let mut reduced: *mut ffi::mlx_array_t = std::ptr::null_mut();
        // SAFETY: `reduced` is a live, writable local. Assumes `operand` is a
        // valid array handle owned by `inputs[0]` — TODO confirm.
        let status = unsafe { ffi::rlx_mlx_dist_all_sum_array(operand, &mut reduced) };
        check(status)?;
        Ok(Array::from_raw(reduced))
    }
}
/// Registers the `ALL_REDUCE` op extension, then its MLX kernel.
pub fn register_collective() {
    let extension = Arc::new(AllReduceExt);
    register_op(extension);

    let kernel = Arc::new(AllReduceMlx);
    register_mlx_kernel(kernel);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the availability query links and returns without crashing.
    #[test]
    fn is_available_links_and_runs() {
        let _available = MlxTransport::is_available();
    }

    /// Without a launcher, init is expected to yield a one-rank group on which
    /// the collectives behave as identities.
    #[test]
    fn init_singleton_group_without_launcher() {
        use rlx_driver::Transport as _;

        let transport = MlxTransport::init(false, "any").expect("singleton init");
        assert_eq!(transport.world_size(), 1);
        assert_eq!(transport.rank(), 0);

        let mut values = vec![1.0f32, 2.0, 3.0];
        transport.all_sum(&mut values).unwrap();
        assert_eq!(
            values,
            vec![1.0, 2.0, 3.0],
            "all_sum on singleton is identity"
        );

        let collected = transport.all_gather(&[7.0, 8.0]).unwrap();
        assert_eq!(collected, vec![7.0, 8.0]);

        // Same check through the fully qualified trait path.
        assert_eq!(
            <MlxTransport as rlx_driver::Transport>::world_size(&transport),
            1
        );
    }

    /// Compiles a one-op graph using the registered all-reduce and checks that
    /// the input comes back unchanged.
    #[test]
    fn device_resident_all_reduce_runs_on_mlx() {
        use crate::backend::MlxExecutable;
        use rlx_ir::{DType, Graph, Shape};

        register_collective();
        // The init result is deliberately ignored here.
        let _ = MlxTransport::init(false, "any");

        let mut graph = Graph::new("device_all_reduce");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        let reduced = graph.custom_op(ALL_REDUCE, vec![], vec![input]);
        graph.set_outputs(vec![reduced]);

        let mut executable = MlxExecutable::compile(graph);
        let results = executable.run(&[("x", &[1.0f32, 2.0, 3.0, 4.0])]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], vec![1.0, 2.0, 3.0, 4.0]);
    }
}