use std::collections::HashMap;
use std::ffi::c_void;
use std::os::raw::c_int;
use std::ptr;
use rlx_ir::{Graph, NodeId};
use crate::array::{Array, MlxError, check};
use crate::ffi::{self, RLX_MLX_OK, mlx_array_t, mlx_compiled_t};
use crate::lower::{self, LeafKey, lower_with_env};
/// Data the compile callback (`lower_callback`) needs in order to lower the graph.
///
/// `CompiledFn::compile` boxes this value and passes the box's address to
/// `rlx_mlx_compile` as the opaque user-data pointer. `lower_callback` casts that
/// pointer back to `&CompiledState`.
struct CompiledState {
// Graph that `run_callback` lowers (via `lower_with_env`) on each callback invocation.
graph: Graph,
// Leaf nodes paired with their keys, in the same order as the callback's input handles.
leaf_order: Vec<(NodeId, LeafKey)>,
}
/// An `rlx_ir::Graph` compiled through the MLX compile FFI.
///
/// Owns the native compiled-function handle together with the boxed state the
/// native side calls back into. Construct with `CompiledFn::compile` and run with
/// `CompiledFn::invoke`.
pub struct CompiledFn {
// Native compiled-function handle; released in `Drop`.
handle: *mut mlx_compiled_t,
// Boxed so its address (the callback's user-data pointer) stays stable when `Self`
// moves. It is declared after `handle`, and `Drop::drop` frees `handle` before any
// field is dropped, so this state outlives the native handle.
_state: Box<CompiledState>,
// Copy of `_state.leaf_order`, used by `invoke` to validate input counts.
leaf_order: Vec<(NodeId, LeafKey)>,
}
// SAFETY: the raw `handle` pointer makes this type `!Send` by default. Moving a
// `CompiledFn` moves only the pointer and the `Box`; the boxed state's address does
// not change.
// NOTE(review): this relies on the native compiled handle being usable from whichever
// thread currently owns the `CompiledFn`. Confirm this against the C side. No `Sync`
// impl appears here, so shared `&CompiledFn` access across threads is not claimed.
unsafe impl Send for CompiledFn {}
impl CompiledFn {
    /// Compiles `graph` through the MLX compile FFI.
    ///
    /// The graph is moved into a boxed `CompiledState`. The box's address is
    /// registered as the user-data pointer for `lower_callback`. Boxing keeps that
    /// address stable when the box is moved into the returned value.
    ///
    /// # Errors
    /// Returns the error reported by `rlx_mlx_compile` (via `check`) if it fails.
    pub fn compile(graph: Graph) -> Result<Self, MlxError> {
        let leaf_order = lower::leaf_order(&graph);
        let state = Box::new(CompiledState {
            graph,
            leaf_order: leaf_order.clone(),
        });
        let mut handle: *mut mlx_compiled_t = ptr::null_mut();
        // SAFETY: the user-data pointer targets `state`, which is stored in
        // `self._state`. `Drop::drop` frees `handle` before `_state` is dropped, so
        // the pointer stays valid for as long as the native handle exists.
        // NOTE(review): the meaning of the third argument (`0`) is not visible here
        // (presumably a flags/options value). TODO confirm against the FFI header.
        let rc = unsafe {
            ffi::rlx_mlx_compile(
                lower_callback,
                state.as_ref() as *const CompiledState as *mut c_void,
                0,
                &mut handle,
            )
        };
        check(rc)?;
        Ok(Self {
            handle,
            _state: state,
            leaf_order,
        })
    }

    /// Leaf nodes (with their keys) in the order `invoke` expects its inputs.
    pub fn leaf_order(&self) -> &[(NodeId, LeafKey)] {
        &self.leaf_order
    }

    /// Runs the compiled function on `inputs`, which must follow `leaf_order()` order.
    ///
    /// Each returned `Array` wraps (via `Array::from_raw`) an output handle written by
    /// the compiled call.
    ///
    /// # Errors
    /// - `inputs.len()` differs from the number of leaves.
    /// - The compiled call fails (propagated via `check`).
    /// - The compiled call reports more outputs than the output buffer can hold.
    pub fn invoke(&self, inputs: &[Array]) -> Result<Vec<Array>, MlxError> {
        if inputs.len() != self.leaf_order.len() {
            return Err(MlxError(format!(
                "CompiledFn: expected {} leaves, got {}",
                self.leaf_order.len(),
                inputs.len()
            )));
        }
        let in_handles: Vec<*mut mlx_array_t> = inputs.iter().map(|a| a.ptr).collect();
        // Fixed output-buffer size. The native side receives it as `cap`.
        const CAP: usize = 64;
        let mut out_handles: Vec<*mut mlx_array_t> = vec![ptr::null_mut(); CAP];
        let mut n_out: usize = 0;
        // SAFETY: `in_handles` holds `in_handles.len()` live handles borrowed from
        // `inputs`. `out_handles` has room for exactly `CAP` handles.
        let rc = unsafe {
            ffi::rlx_mlx_compiled_call(
                self.handle,
                in_handles.as_ptr(),
                in_handles.len(),
                out_handles.as_mut_ptr(),
                CAP,
                &mut n_out,
            )
        };
        check(rc)?;
        // `Vec::truncate` is a no-op when `n_out > CAP`. Without this guard, every
        // slot (including never-written null ones) would be wrapped as an `Array`.
        if n_out > CAP {
            return Err(MlxError(format!(
                "CompiledFn: compiled call reported {n_out} outputs, exceeding cap {CAP}"
            )));
        }
        out_handles.truncate(n_out);
        Ok(out_handles.into_iter().map(Array::from_raw).collect())
    }
}
impl Drop for CompiledFn {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe {
ffi::rlx_mlx_compiled_free(self.handle);
}
self.handle = ptr::null_mut();
}
}
}
/// C-ABI callback registered with `rlx_mlx_compile` in `CompiledFn::compile`.
///
/// Forwards all work to `run_callback` inside `catch_unwind`, so no Rust panic
/// unwinds across the FFI boundary.
///
/// Returns `RLX_MLX_OK` on success. Returns `1` on failure, after recording a
/// message through `rlx_mlx_set_last_error`.
///
/// # Safety
/// `ud` must point to a live `CompiledState`. The pointer arguments must satisfy
/// the contract checked and documented in `run_callback`.
unsafe extern "C" fn lower_callback(
    ud: *mut c_void,
    inputs: *const *mut mlx_array_t,
    n_inputs: usize,
    out_outputs: *mut *mut mlx_array_t,
    cap: usize,
    out_n_outputs: *mut usize,
) -> c_int {
    // SAFETY: `ud` is the boxed `CompiledState` registered in `CompiledFn::compile`.
    // That box outlives the native compiled handle.
    let state: &CompiledState = unsafe { &*(ud as *const CompiledState) };
    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        run_callback(state, inputs, n_inputs, out_outputs, cap, out_n_outputs)
    }));
    match result {
        Ok(Ok(())) => RLX_MLX_OK,
        Ok(Err(e)) => {
            set_callback_error(&format!("rlx-mlx compile callback: {e}"));
            1
        }
        Err(_) => {
            set_callback_error("rlx-mlx compile callback panicked");
            1
        }
    }
}

/// Records `msg` as the FFI layer's last error.
///
/// Interior NUL bytes are removed first, so the C-string conversion cannot fail.
/// Without this, the error report would be silently dropped.
fn set_callback_error(msg: &str) {
    let c = std::ffi::CString::new(msg.replace('\0', ""))
        .expect("interior NUL bytes were removed above");
    // SAFETY: `c` is a valid NUL-terminated string that lives across the call.
    // NOTE(review): assumes `rlx_mlx_set_last_error` copies the string, because `c`
    // is freed right after the call (the original code relied on this too).
    unsafe {
        ffi::rlx_mlx_set_last_error(c.as_ptr());
    }
}
/// Body of `lower_callback`: lowers `state.graph` on the handles supplied by the
/// native side.
///
/// Steps:
/// 1. Bind each input handle to its leaf node, in `state.leaf_order` order.
/// 2. Lower the graph with `lower_with_env`, using empty parameter tables.
/// 3. Write the resulting handles into `out_outputs[..count]` and store `count` in
///    `*out_n_outputs`.
///
/// The Rust `Array` wrapper of every written output is forgotten rather than
/// dropped, so those handles are handed to the caller.
///
/// # Errors
/// - `n_inputs` differs from the number of graph leaves.
/// - A required pointer argument is null.
/// - Lowering fails.
/// - The graph yields more than `cap` outputs.
fn run_callback(
    state: &CompiledState,
    inputs: *const *mut mlx_array_t,
    n_inputs: usize,
    out_outputs: *mut *mut mlx_array_t,
    cap: usize,
    out_n_outputs: *mut usize,
) -> Result<(), MlxError> {
    if n_inputs != state.leaf_order.len() {
        return Err(MlxError(format!(
            "compile callback: leaf count mismatch ({} vs {})",
            n_inputs,
            state.leaf_order.len()
        )));
    }
    if out_n_outputs.is_null() {
        return Err(MlxError(
            "compile callback: null output-count pointer".to_string(),
        ));
    }
    // `slice::from_raw_parts` requires a non-null pointer even for length 0. An empty
    // input list, which C callers may pass as NULL, is therefore handled separately.
    let in_slice: &[*mut mlx_array_t] = if n_inputs == 0 {
        &[]
    } else if inputs.is_null() {
        return Err(MlxError(format!(
            "compile callback: null inputs pointer with {n_inputs} inputs"
        )));
    } else {
        // SAFETY: `inputs` is non-null and the caller supplies `n_inputs` handles.
        unsafe { std::slice::from_raw_parts(inputs, n_inputs) }
    };
    // NOTE(review): `Array::from_raw` presumably takes ownership of each input handle,
    // as the `mem::forget` on outputs below suggests. That is only sound if the native
    // side hands over owned handles. Confirm, otherwise this double-frees.
    let mut env: HashMap<NodeId, Array> = HashMap::with_capacity(state.graph.nodes().len());
    env.extend(
        state
            .leaf_order
            .iter()
            .zip(in_slice)
            .map(|((id, _key), &handle)| (*id, Array::from_raw(handle))),
    );
    let empty_params: HashMap<String, Vec<f32>> = HashMap::new();
    let empty_typed: HashMap<String, (Vec<u8>, rlx_ir::DType)> = HashMap::new();
    let outs = lower_with_env(&state.graph, env, &empty_params, &empty_typed)?;
    let count = outs.len();
    if count > cap {
        return Err(MlxError(format!(
            "compile callback: {} outputs exceeds cap {}",
            count, cap
        )));
    }
    if count > 0 && out_outputs.is_null() {
        return Err(MlxError(
            "compile callback: null outputs pointer".to_string(),
        ));
    }
    for (slot, arr) in outs.into_iter().enumerate() {
        // SAFETY: `slot < count <= cap`, and `out_outputs` is non-null with room for
        // `cap` handles.
        unsafe {
            *out_outputs.add(slot) = arr.ptr;
        }
        // Ownership of the handle now belongs to the caller; skip `Array`'s destructor.
        std::mem::forget(arr);
    }
    // SAFETY: checked non-null above.
    unsafe {
        *out_n_outputs = count;
    }
    Ok(())
}