#![cfg(all(feature = "mlx", target_os = "macos"))]
use std::sync::Arc;
use rlx_ir::Shape;
use rlx_mlx::MlxError;
use rlx_mlx::array::Array;
use rlx_mlx::op_registry::{MlxKernel, register_mlx_kernel};
use crate::ops::UMAP_KNN;
/// MLX-side placeholder registered under the [`UMAP_KNN`] op name.
///
/// It performs no computation. [`MlxKernel::execute`] always fails with an
/// [`MlxError`] that tells the caller to use
/// `rlx_umap::session::cosine_knn_mlx` instead. Per the error text, the host
/// k-NN step cannot run inside `mlx::compile`.
#[derive(Debug)]
struct KnnForwardMlx;

impl MlxKernel for KnnForwardMlx {
    /// Op identifier this kernel is registered under.
    fn name(&self) -> &str {
        UMAP_KNN
    }

    /// Always returns `Err`; inputs, output shape and attributes are ignored.
    ///
    /// # Errors
    ///
    /// Unconditionally returns an [`MlxError`] whose message names the
    /// supported alternative entry point.
    fn execute(
        &self,
        _inputs: &[&Array],
        _output_shape: &Shape,
        _attrs: &[u8],
    ) -> Result<Array, MlxError> {
        // Single-line form of the message; the text is identical to the
        // previous `\`-continued literal.
        const UNSUPPORTED: &str = "umap.knn on MLX: use rlx_umap::session::cosine_knn_mlx (MLX pairwise + CPU k-NN) — host k-NN cannot run inside mlx::compile";
        Err(MlxError(UNSUPPORTED.into()))
    }
}
/// Registers this module's MLX kernels with the global MLX op registry.
///
/// Currently this registers only the [`UMAP_KNN`] stub, which always returns
/// an error when executed.
pub fn register_mlx_kernels() {
    // The concrete `Arc<KnnForwardMlx>` is coerced to whatever `Arc` type
    // `register_mlx_kernel` expects at the call site.
    let knn_kernel = Arc::new(KnnForwardMlx);
    register_mlx_kernel(knn_kernel);
}