use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use rlx_ir::{DType, Shape};
/// Generates the borrowed CPU tensor views `CpuTensorRef` and `CpuTensorMut`
/// together with their typed accessors.
///
/// Each input row has the form
/// `Variant => element_type, as_x, as_x_mut, expect_x, expect_x_mut,`:
///
/// * `Variant` becomes an enum variant of both views and must also be a
///   variant of `DType`, because error messages print `DType::Variant`.
/// * `element_type` is the Rust type of the elements in the data slice.
/// * `as_x` / `expect_x` are generated on `CpuTensorRef`,
///   `as_x_mut` / `expect_x_mut` on `CpuTensorMut`.
///
/// `Shape` and `DType` are resolved at the invocation site.
macro_rules! dtype_variants {
    (
        $(
            $variant:ident => $rust_ty:ty,
            $as_method:ident, $as_mut_method:ident,
            $expect_method:ident, $expect_mut_method:ident,
        )*
    ) => {
        /// Read-only view of a tensor: a typed data slice plus its shape.
        pub enum CpuTensorRef<'a> {
            $(
                $variant { data: &'a [$rust_ty], shape: &'a Shape },
            )*
        }

        /// Mutable view of a tensor: a typed mutable data slice plus its shape.
        pub enum CpuTensorMut<'a> {
            $(
                $variant { data: &'a mut [$rust_ty], shape: &'a Shape },
            )*
        }

        impl<'a> CpuTensorRef<'a> {
            /// Returns the shape attached to this view.
            ///
            /// The reference lives for `'a` (the borrow of the underlying
            /// tensor), not merely for the borrow of `self`, so it may
            /// outlive the view value.
            pub fn shape(&self) -> &'a Shape {
                match self {
                    // `shape` binds as `&&'a Shape`; copy the inner reference out.
                    $( Self::$variant { shape, .. } => *shape, )*
                }
            }

            /// Returns the dtype reported by this view's shape.
            pub fn dtype(&self) -> DType { self.shape().dtype() }

            $(
                /// Returns the data slice if the view holds the variant this
                /// accessor was generated for, otherwise `None`.
                ///
                /// The slice lives for `'a`, independent of the `&self` borrow.
                pub fn $as_method(&self) -> Option<&'a [$rust_ty]> {
                    if let Self::$variant { data, .. } = self { Some(*data) } else { None }
                }

                /// Like the matching `as_*` accessor, but turns a variant
                /// mismatch into an error.
                ///
                /// # Errors
                /// Returns `"{role}: expected <wanted dtype>, got <actual dtype>"`
                /// when the view holds a different variant; `role` is a
                /// caller-chosen label for the tensor.
                pub fn $expect_method(&self, role: &str) -> Result<&'a [$rust_ty], String> {
                    self.$as_method().ok_or_else(|| format!(
                        "{role}: expected {:?}, got {:?}",
                        DType::$variant, self.dtype()))
                }
            )*
        }

        impl<'a> CpuTensorMut<'a> {
            /// Returns the shape attached to this view.
            ///
            /// The reference lives for `'a`, so it stays usable after the
            /// view has been consumed by one of the `*_mut` accessors.
            pub fn shape(&self) -> &'a Shape {
                match self {
                    // The shape is a shared reference, so it can be copied
                    // out even though the data reference is unique.
                    $( Self::$variant { shape, .. } => *shape, )*
                }
            }

            /// Returns the dtype reported by this view's shape.
            pub fn dtype(&self) -> DType { self.shape().dtype() }

            $(
                /// Consumes the view and returns the mutable data slice if it
                /// holds the variant this accessor was generated for,
                /// otherwise `None`.
                pub fn $as_mut_method(self) -> Option<&'a mut [$rust_ty]> {
                    if let Self::$variant { data, .. } = self { Some(data) } else { None }
                }

                /// Like the matching `as_*_mut` accessor, but turns a variant
                /// mismatch into an error.
                ///
                /// # Errors
                /// Returns `"{role}: expected <wanted dtype>, got <actual dtype>"`
                /// when the view holds a different variant.
                pub fn $expect_mut_method(self, role: &str) -> Result<&'a mut [$rust_ty], String> {
                    // Read the dtype first: the accessor below consumes `self`.
                    let dt = self.dtype();
                    self.$as_mut_method().ok_or_else(|| format!(
                        "{role}: expected {:?}, got {dt:?}", DType::$variant))
                }
            )*
        }
    };
}
// Instantiates `CpuTensorRef` / `CpuTensorMut` and their typed accessors.
// Row format: `Variant => element type, as_*, as_*_mut, expect_*, expect_*_mut,`.
// Only `//` comments are used here: a `///` doc comment would become part of
// the macro input tokens.
dtype_variants! {
F32 => f32, as_f32, as_f32_mut, expect_f32, expect_f32_mut,
F64 => f64, as_f64, as_f64_mut, expect_f64, expect_f64_mut,
F16 => half::f16, as_f16, as_f16_mut, expect_f16, expect_f16_mut,
BF16 => half::bf16, as_bf16, as_bf16_mut, expect_bf16, expect_bf16_mut,
I8 => i8, as_i8, as_i8_mut, expect_i8, expect_i8_mut,
I16 => i16, as_i16, as_i16_mut, expect_i16, expect_i16_mut,
I32 => i32, as_i32, as_i32_mut, expect_i32, expect_i32_mut,
I64 => i64, as_i64, as_i64_mut, expect_i64, expect_i64_mut,
U8 => u8, as_u8, as_u8_mut, expect_u8, expect_u8_mut,
U32 => u32, as_u32, as_u32_mut, expect_u32, expect_u32_mut,
// `Bool` data is exposed as `u8`, the same Rust element type as the `U8`
// variant; the two are told apart only by the enum variant.
Bool => u8, as_bool, as_bool_mut, expect_bool, expect_bool_mut,
}
/// A named CPU kernel that can be stored in a [`CpuKernelRegistry`].
///
/// The `Send + Sync` supertraits allow implementations to be shared between
/// threads behind an `Arc`, which is how the registry stores them.
pub trait CpuKernel: Send + Sync {
/// Name of the kernel; [`CpuKernelRegistry::register`] uses it as the
/// registry key.
fn name(&self) -> &str;
/// Executes the kernel.
///
/// * `inputs` - read-only tensor views.
/// * `output` - mutable tensor view, passed by value.
/// * `attrs` - opaque bytes; their encoding is not defined in this file
///   (presumably kernel-specific attributes — TODO confirm with callers).
///
/// # Errors
/// Returns a `String` describing why execution failed.
fn execute(
&self,
inputs: &[CpuTensorRef<'_>],
output: CpuTensorMut<'_>,
attrs: &[u8],
) -> Result<(), String>;
}
/// Name-keyed collection of shared [`CpuKernel`] implementations.
///
/// The map sits behind an `RwLock`, so registration and lookup both work
/// through a shared `&self`.
pub struct CpuKernelRegistry {
/// Registered kernels, keyed by the name each kernel reported when it was
/// registered.
kernels: RwLock<HashMap<String, Arc<dyn CpuKernel>>>,
}
impl CpuKernelRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            kernels: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `k` under the name returned by `k.name()`.
    ///
    /// If a kernel with the same name is already present, it is replaced and
    /// a warning is written to stderr.
    pub fn register(&self, k: Arc<dyn CpuKernel>) {
        let name = k.name().to_string();
        // One `insert` both stores the kernel and tells us whether an entry
        // was displaced, so the key is hashed only once. The write guard is a
        // temporary of this statement: the lock is released before the stderr
        // write below and before the displaced kernel is dropped, so neither
        // can panic while the lock is held.
        let displaced = self
            .kernels
            .write()
            // No operation performed under the lock can leave the map
            // half-updated, so a poisoned guard is still safe to use.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(name.clone(), k);
        if displaced.is_some() {
            eprintln!(
                "rlx-cpu: CpuKernel '{name}' was already registered — \
                 replacing the previous entry"
            );
        }
    }

    /// Returns a shared handle to the kernel registered as `name`, or `None`
    /// if no such kernel exists.
    pub fn lookup(&self, name: &str) -> Option<Arc<dyn CpuKernel>> {
        self.kernels
            .read()
            // See `register`: the map is always consistent, so recover from
            // poisoning instead of panicking on every later lookup.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(name)
            .cloned()
    }
}
impl Default for CpuKernelRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide kernel registry.
///
/// The registry is built with [`CpuKernelRegistry::new`] on the first call;
/// every call returns a reference to that same instance.
pub fn global_cpu_kernels() -> &'static CpuKernelRegistry {
    static GLOBAL_REGISTRY: OnceLock<CpuKernelRegistry> = OnceLock::new();
    GLOBAL_REGISTRY.get_or_init(CpuKernelRegistry::new)
}
/// Registers `k` in the registry returned by [`global_cpu_kernels`].
pub fn register_cpu_kernel(k: Arc<dyn CpuKernel>) {
    let registry = global_cpu_kernels();
    registry.register(k);
}
/// Looks up `name` in the registry returned by [`global_cpu_kernels`].
///
/// Returns `None` when no kernel is registered under that name.
pub fn lookup_cpu_kernel(name: &str) -> Option<Arc<dyn CpuKernel>> {
    let registry = global_cpu_kernels();
    registry.lookup(name)
}