#![cfg(target_os = "macos")]
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use rlx_ir::Shape;
use crate::array::{Array, MlxError};
/// A named kernel that can be stored in an [`MlxKernelRegistry`] and invoked
/// through a trait object.
///
/// The `Send + Sync` bound lets implementations live inside the shared,
/// `static` registry as `Arc<dyn MlxKernel>`.
pub trait MlxKernel: Send + Sync {
    /// Name of this kernel; the registry uses it as the lookup key.
    fn name(&self) -> &str;

    /// Runs the kernel.
    ///
    /// * `inputs` - borrowed input arrays.
    /// * `output_shape` - shape description for the result.
    /// * `attrs` - raw attribute bytes; their encoding is defined by the
    ///   implementation, not by this trait.
    ///
    /// # Errors
    ///
    /// Returns an [`MlxError`] when the implementation cannot produce a result.
    fn execute(
        &self,
        inputs: &[&Array],
        output_shape: &Shape,
        attrs: &[u8],
    ) -> Result<Array, MlxError>;
}
/// Table of [`MlxKernel`] implementations keyed by kernel name.
///
/// The map sits behind an [`RwLock`], so kernels can be registered and looked
/// up through a shared reference.
pub struct MlxKernelRegistry {
    /// Registered kernels, keyed by the string each kernel reports from `name()`.
    kernels: RwLock<HashMap<String, Arc<dyn MlxKernel>>>,
}
impl MlxKernelRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            kernels: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `k` under the name it reports from [`MlxKernel::name`].
    ///
    /// If a kernel with the same name is already present it is replaced and a
    /// warning is written to stderr.
    pub fn register(&self, k: Arc<dyn MlxKernel>) {
        let name = k.name();
        // Keep the write lock only for the insertion itself. `insert` hands
        // back the entry it displaced, so no separate `contains_key` probe is
        // needed. A poisoned lock is recovered rather than propagated: the map
        // is only ever mutated by a single `insert`, so it is structurally
        // valid even if some other thread panicked while holding the guard.
        let previous = {
            let mut kernels = self
                .kernels
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            kernels.insert(name.to_owned(), Arc::clone(&k))
        };
        // Log after the guard is gone: `eprintln!` panics when stderr cannot
        // be written, and doing that under the lock would poison the registry.
        if previous.is_some() {
            eprintln!(
                "rlx-mlx: MlxKernel '{name}' was already registered — \
                 replacing the previous entry"
            );
        }
    }

    /// Returns a shared handle to the kernel registered under `name`, or
    /// `None` when no kernel with that name exists.
    pub fn lookup(&self, name: &str) -> Option<Arc<dyn MlxKernel>> {
        self.kernels
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(name)
            .cloned()
    }
}
impl Default for MlxKernelRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide kernel registry.
///
/// The registry is built on the first call and the same instance is handed
/// out on every call after that.
pub fn global_mlx_kernels() -> &'static MlxKernelRegistry {
    static GLOBAL_REGISTRY: OnceLock<MlxKernelRegistry> = OnceLock::new();
    GLOBAL_REGISTRY.get_or_init(MlxKernelRegistry::new)
}
/// Registers `k` in the process-wide registry returned by
/// [`global_mlx_kernels`].
pub fn register_mlx_kernel(k: Arc<dyn MlxKernel>) {
    let registry = global_mlx_kernels();
    registry.register(k);
}
/// Looks `name` up in the process-wide registry returned by
/// [`global_mlx_kernels`]; `None` means no kernel is registered under it.
pub fn lookup_mlx_kernel(name: &str) -> Option<Arc<dyn MlxKernel>> {
    let registry = global_mlx_kernels();
    registry.lookup(name)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;

    /// Minimal kernel used to exercise the registry and the trait-object
    /// call path.
    struct StubKernel;

    impl MlxKernel for StubKernel {
        fn name(&self) -> &str {
            "stub.mlx"
        }

        /// Ignores the shape and attributes and forwards the first input via
        /// `clone_handle`.
        fn execute(
            &self,
            inputs: &[&Array],
            _output_shape: &Shape,
            _attrs: &[u8],
        ) -> Result<Array, MlxError> {
            let first = inputs[0];
            first.clone_handle()
        }
    }

    /// A kernel registered in a fresh registry can be found again by name.
    #[test]
    fn register_and_lookup_round_trips() {
        let registry = MlxKernelRegistry::new();
        let stub: Arc<dyn MlxKernel> = Arc::new(StubKernel);
        registry.register(stub);

        let found = registry.lookup("stub.mlx");
        let kernel = found.expect("registered kernel must be findable");
        assert_eq!(kernel.name(), "stub.mlx");
    }

    /// Calling `execute` through `Arc<dyn MlxKernel>` works end to end and the
    /// stub's output reads back as the input values.
    #[test]
    fn execute_signature_compiles_and_runs() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = Array::from_f32_slice(&data, &[4], DType::F32).expect("input array");
        let out_shape = Shape::new(&[4], DType::F32);

        let kernel: Arc<dyn MlxKernel> = Arc::new(StubKernel);
        let inputs = [&input];
        let result = kernel
            .execute(&inputs, &out_shape, &[])
            .expect("stub kernel must succeed");

        let result_data = result.to_f32().expect("readback");
        assert_eq!(result_data, data, "stub clones input — values must match");
    }
}