#![cfg(target_os = "macos")]
use std::ffi::CString;
use std::sync::Arc;
use rlx_ir::Shape;
use crate::array::{Array, MlxError, check};
use crate::ffi::{self, MlxDtype, mlx_array_t};
use crate::op_registry::{MlxKernel, register_mlx_kernel};
/// Metal Shading Language source of the kernel, embedded at compile time from
/// `../cpp/kernels/batched_lu_solve.metal`.
const KERNEL_MSL: &str = include_str!("../cpp/kernels/batched_lu_solve.metal");
/// Name this kernel reports from `MlxKernel::name` and passes to the Metal
/// dispatch shim.
pub const KERNEL_NAME: &str = "rlx_linalg_batched_dense_solve_metal_f32";
/// Unit type whose `MlxKernel` implementation dispatches the embedded Metal
/// kernel; it carries no state.
pub struct BatchedLuSolveMetal;
impl MlxKernel for BatchedLuSolveMetal {
    /// Returns [`KERNEL_NAME`], the identifier this kernel reports to the registry.
    fn name(&self) -> &str {
        KERNEL_NAME
    }

    /// Dispatches the embedded Metal kernel on the two inputs (`A`, `b`) and
    /// returns the output array `x` with the rank-2 shape `[B, n]` taken from
    /// `output_shape`.
    ///
    /// The kernel source is specialised per call with a `#define NMAX <p>`
    /// header, where `<p>` is `n` rounded up to a power of two and clamped to
    /// at least 8.
    ///
    /// `_attrs` is ignored.
    ///
    /// # Errors
    ///
    /// Returns `MlxError` when
    /// * `inputs` does not hold exactly two arrays,
    /// * `output_shape` is not rank 2,
    /// * `n` exceeds the cap of 128,
    /// * `B`, `n` or `B * n` cannot be represented as a C `int`,
    /// * the FFI dispatch reports a failure through its return code, or
    /// * the dispatch reports success but leaves the output handle null.
    ///
    /// # Panics
    ///
    /// NOTE(review): `unwrap_static()` presumably panics on a non-static
    /// dimension; confirm against `rlx_ir::Shape`.
    fn execute(
        &self,
        inputs: &[&Array],
        output_shape: &Shape,
        _attrs: &[u8],
    ) -> Result<Array, MlxError> {
        // Largest NMAX accepted; the literal is repeated in the error text below.
        const NMAX_CAP: usize = 128;

        if inputs.len() != 2 {
            return Err(MlxError(format!(
                "{KERNEL_NAME}: expected 2 inputs (A, b), got {}",
                inputs.len()
            )));
        }
        let out_dims = output_shape.dims();
        if out_dims.len() != 2 {
            return Err(MlxError(format!(
                "{KERNEL_NAME}: expected rank-2 output [B, n], got rank {}",
                out_dims.len()
            )));
        }
        let b_dim = out_dims[0].unwrap_static();
        let n_dim = out_dims[1].unwrap_static();

        // Test the cap on `n` itself, before rounding up. For every accepted
        // value this is equivalent to testing the rounded NMAX, and it keeps
        // absurdly large `n` away from the power-of-two rounding.
        if n_dim > NMAX_CAP {
            return Err(MlxError(format!(
                "{KERNEL_NAME}: n={n_dim} exceeds NMAX cap of 128 \
                 (threadgroup memory bound at f32). Lowering should \
                 fall back to MLX-CPU `ops::solve` for n in this range."
            )));
        }
        let nmax = next_pow2(n_dim).max(8);

        // The shim takes C ints; convert with range checks rather than `as`,
        // which would silently truncate and launch a wrong-sized grid.
        let to_c_int = |value: usize, what: &str| -> Result<std::os::raw::c_int, MlxError> {
            <std::os::raw::c_int as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
                MlxError(format!(
                    "{KERNEL_NAME}: {what}={value} does not fit in a C int"
                ))
            })
        };
        let batch_c = to_c_int(b_dim, "B")?;
        let n_c = to_c_int(n_dim, "n")?;
        let thread_count = b_dim.checked_mul(n_dim).ok_or_else(|| {
            MlxError(format!(
                "{KERNEL_NAME}: B*n overflows for B={b_dim}, n={n_dim}"
            ))
        })?;
        let grid_x = to_c_int(thread_count, "B*n")?;
        // NOTE(review): B == 0 or n == 0 produces a zero-sized launch; whether
        // the shim tolerates that is not visible from this file.

        let header = format!("#define NMAX {nmax}\n");
        let name_c = CString::new(KERNEL_NAME).expect("kernel name has no interior NUL");
        let source_c = CString::new(KERNEL_MSL).expect("kernel source has no interior NUL");
        let header_c = CString::new(header).expect("generated header has no interior NUL");
        let in_a_c = CString::new("A").expect("literal has no interior NUL");
        let in_b_c = CString::new("b").expect("literal has no interior NUL");
        let out_c = CString::new("x").expect("literal has no interior NUL");

        let in_name_ptrs: [*const std::os::raw::c_char; 2] = [in_a_c.as_ptr(), in_b_c.as_ptr()];
        let in_array_ptrs: [*mut mlx_array_t; 2] = [inputs[0].ptr, inputs[1].ptr];
        let out_shape_i32: [std::os::raw::c_int; 2] = [batch_c, n_c];
        let mut out_handle: *mut mlx_array_t = std::ptr::null_mut();

        // SAFETY: every `*const c_char` comes from a `CString` local that
        // outlives the call; the name, array and shape pointer tables each
        // hold exactly the 2 entries announced by the adjacent count
        // arguments; `out_handle` is a live, writable local. The input handles
        // are taken from borrowed `Array`s and are assumed valid for the
        // duration of the call.
        // NOTE(review): the six integers after the dtype are presumably grid
        // (x, y, z) followed by threadgroup (x, y, z); confirm against the
        // FFI declaration.
        let rc = unsafe {
            ffi::rlx_mlx_op_metal_kernel_dispatch(
                name_c.as_ptr(),
                source_c.as_ptr(),
                header_c.as_ptr(),
                in_name_ptrs.as_ptr(),
                2,
                out_c.as_ptr(),
                in_array_ptrs.as_ptr(),
                out_shape_i32.as_ptr(),
                2,
                MlxDtype::F32,
                grid_x,
                1,
                1,
                n_c,
                1,
                1,
                &mut out_handle,
            )
        };
        check(rc)?;
        if out_handle.is_null() {
            return Err(MlxError(format!(
                "{KERNEL_NAME}: dispatch reported success but returned a null output handle"
            )));
        }
        Ok(Array::from_raw(out_handle))
    }
}
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` and `1` both map to `1`. When no such power of two is representable in
/// `usize` (that is, `n > 2^(usize::BITS - 1)`), the result saturates to
/// `usize::MAX` so that callers comparing against an upper bound reject the
/// value instead of looping forever on a shifted-out bit.
fn next_pow2(n: usize) -> usize {
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
/// Adds a shared [`BatchedLuSolveMetal`] instance to the MLX kernel registry.
pub fn register() {
    let kernel = Arc::new(BatchedLuSolveMetal);
    register_mlx_kernel(kernel);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op_registry::lookup_mlx_kernel;

    /// After `register()`, a lookup by `KERNEL_NAME` succeeds and the kernel
    /// found reports that same name.
    #[test]
    fn registration_round_trips() {
        register();
        let found = lookup_mlx_kernel(KERNEL_NAME);
        let kernel = found.expect("kernel must be findable after register()");
        assert_eq!(kernel.name(), KERNEL_NAME);
    }

    /// The embedded kernel source contains the three marker strings checked below.
    #[test]
    fn msl_source_present() {
        let source_has = |needle: &str| KERNEL_MSL.contains(needle);
        assert!(source_has("threadgroup float Aloc"));
        assert!(source_has("threadgroup_position_in_grid"));
        assert!(source_has("Doolittle LU"));
    }
}