use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::tensor::Tensor;
use std::collections::HashMap;
/// Routing keys for operator dispatch, modeled after PyTorch's dispatch keys.
///
/// The `u8` discriminant doubles as the dispatch priority: higher values are
/// consulted first (see `DispatchKeySet::iter_desc`), so functionality
/// wrappers such as `Autograd` and `Tracer` outrank the plain backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(u8)]
pub enum DispatchKey {
    // Backend keys — lowest priorities, hit last during dispatch.
    Cpu = 0,
    Cuda = 1,
    Meta = 2,
    Sparse = 3,
    Quantized = 4,
    Nested = 5,
    // Functionality keys — higher priorities; in the tests these wrap the
    // call and re-dispatch to the remaining keys.
    Autocast = 6,
    Autograd = 7,
    Vmap = 8,
    Profiler = 9,
    Tracer = 10,
}
impl DispatchKey {
#[inline]
pub fn priority(self) -> u8 {
self as u8
}
pub const ALL: [DispatchKey; 11] = [
DispatchKey::Cpu,
DispatchKey::Cuda,
DispatchKey::Meta,
DispatchKey::Sparse,
DispatchKey::Quantized,
DispatchKey::Nested,
DispatchKey::Autocast,
DispatchKey::Autograd,
DispatchKey::Vmap,
DispatchKey::Profiler,
DispatchKey::Tracer,
];
}
/// A small bitset of `DispatchKey`s: bit `i` corresponds to the key whose
/// `priority()` is `i`. A `u16` is wide enough for the 11 defined keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchKeySet {
    bits: u16,
}
impl DispatchKeySet {
    /// The empty set (no keys).
    #[inline]
    pub const fn empty() -> Self {
        Self { bits: 0 }
    }

    /// The set containing every `DispatchKey`.
    ///
    /// Discriminants are contiguous in `0..DispatchKey::ALL.len()`, so the
    /// full set is a solid run of low bits; this is `const` and O(1),
    /// replacing the original per-key insert loop.
    #[inline]
    pub const fn all() -> Self {
        Self {
            bits: (1 << DispatchKey::ALL.len()) - 1,
        }
    }

    /// Builds a set from any iterator of keys; duplicates are harmless.
    pub fn from_iter<I: IntoIterator<Item = DispatchKey>>(keys: I) -> Self {
        let mut set = Self::empty();
        for k in keys {
            set = set.insert(k);
        }
        set
    }

    /// Whether `key` is a member of this set.
    #[inline]
    pub fn contains(self, key: DispatchKey) -> bool {
        (self.bits >> key.priority()) & 1 != 0
    }

    /// Returns a copy of the set with `key` added.
    #[inline]
    #[must_use]
    pub fn insert(self, key: DispatchKey) -> Self {
        Self {
            bits: self.bits | (1 << key.priority()),
        }
    }

    /// Returns a copy of the set with `key` removed.
    #[inline]
    #[must_use]
    pub fn remove(self, key: DispatchKey) -> Self {
        Self {
            bits: self.bits & !(1 << key.priority()),
        }
    }

    /// Set union (bitwise OR of the masks).
    #[inline]
    #[must_use]
    pub fn union(self, other: Self) -> Self {
        Self {
            bits: self.bits | other.bits,
        }
    }

    /// Set intersection (bitwise AND of the masks).
    #[inline]
    #[must_use]
    pub fn intersection(self, other: Self) -> Self {
        Self {
            bits: self.bits & other.bits,
        }
    }

    /// True when no key is present.
    #[inline]
    pub fn is_empty(self) -> bool {
        self.bits == 0
    }

    /// Number of keys in the set (popcount of the mask).
    #[inline]
    pub fn len(self) -> usize {
        self.bits.count_ones() as usize
    }

    /// Highest-priority key in the set, or `None` when empty.
    ///
    /// Delegates to `iter_desc`, whose first item is by construction the
    /// highest set bit; this removes the duplicated reverse scan (and its
    /// unreachable trailing `None`) from the original implementation.
    pub fn highest(self) -> Option<DispatchKey> {
        self.iter_desc().next()
    }

    /// Iterates the keys from highest priority to lowest by repeatedly
    /// clearing the top set bit of a local copy of the mask.
    pub fn iter_desc(self) -> impl Iterator<Item = DispatchKey> {
        let mut bits = self.bits;
        std::iter::from_fn(move || {
            if bits == 0 {
                return None;
            }
            // Index of the highest set bit of a non-zero u16.
            let top = 15 - bits.leading_zeros() as u8;
            bits &= !(1 << top);
            // Map the bit index back to its key; yields None (ending the
            // iterator) only if a bit outside the known keys were ever set.
            DispatchKey::ALL.iter().find(|k| k.priority() == top).copied()
        })
    }
}
impl Default for DispatchKeySet {
fn default() -> Self {
Self::empty()
}
}
/// Enables `DispatchKeySet::from([k1, k2, ...])` from a fixed-size array.
impl<const N: usize> From<[DispatchKey; N]> for DispatchKeySet {
    fn from(arr: [DispatchKey; N]) -> Self {
        // Fold each key into an initially-empty set; equivalent to the
        // inherent `from_iter` loop.
        arr.into_iter().fold(Self::empty(), Self::insert)
    }
}
/// Boxed operator-kernel signature shared by all registered kernels.
///
/// A kernel receives the operator inputs, the keyset of the originating call
/// (so a wrapper kernel can remove its own key and re-dispatch the rest), and
/// a reference back to the dispatcher for that re-entry.
pub type Kernel<T> = Box<
    dyn Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
        + Send
        + Sync,
>;
/// Operator dispatch table mapping `(op name, dispatch key)` pairs to kernels.
pub struct Dispatcher<T: Float> {
    // One kernel per (op, key) pair; the order in which keys are tried is
    // decided by the caller's DispatchKeySet, not by this map.
    kernels: HashMap<(String, DispatchKey), Kernel<T>>,
}
impl<T: Float> Dispatcher<T> {
pub fn new() -> Self {
Self {
kernels: HashMap::new(),
}
}
pub fn register<F>(&mut self, op_name: impl Into<String>, key: DispatchKey, kernel: F)
where
F: Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
+ Send
+ Sync
+ 'static,
{
self.kernels.insert((op_name.into(), key), Box::new(kernel));
}
pub fn has_kernel(&self, op_name: &str, key: DispatchKey) -> bool {
self.kernels.contains_key(&(op_name.to_string(), key))
}
pub fn kernel_count(&self) -> usize {
self.kernels.len()
}
pub fn call(
&self,
op_name: &str,
inputs: &[Tensor<T>],
keyset: DispatchKeySet,
) -> FerrotorchResult<Tensor<T>> {
if keyset.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"Dispatcher::call({op_name}): empty keyset — no backend to run on"
),
});
}
for key in keyset.iter_desc() {
if let Some(kernel) = self.kernels.get(&(op_name.to_string(), key)) {
return kernel(inputs, keyset, self);
}
}
Err(FerrotorchError::InvalidArgument {
message: format!(
"Dispatcher::call({op_name}): no kernel registered for any key in {keyset:?}"
),
})
}
pub fn call_direct(
&self,
op_name: &str,
inputs: &[Tensor<T>],
keyset: DispatchKeySet,
key: DispatchKey,
) -> FerrotorchResult<Tensor<T>> {
match self.kernels.get(&(op_name.to_string(), key)) {
Some(kernel) => kernel(inputs, keyset, self),
None => Err(FerrotorchError::InvalidArgument {
message: format!(
"Dispatcher::call_direct({op_name}, {key:?}): no kernel registered"
),
}),
}
}
}
impl<T: Float> Default for Dispatcher<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Float> std::fmt::Debug for Dispatcher<T> {
    /// Kernels are opaque boxed closures, so only their count is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.kernels.len();
        f.debug_struct("Dispatcher")
            .field("kernel_count", &count)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    /// Test fixture: builds a CPU `f32` tensor with `requires_grad = false`.
    fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
    }

    // Functionality keys must outrank backend keys, and Cuda outranks Cpu.
    #[test]
    fn dispatch_key_priority_ordering() {
        assert!(DispatchKey::Tracer.priority() > DispatchKey::Autograd.priority());
        assert!(DispatchKey::Autograd.priority() > DispatchKey::Autocast.priority());
        assert!(DispatchKey::Autocast.priority() > DispatchKey::Cpu.priority());
        assert!(DispatchKey::Cuda.priority() > DispatchKey::Cpu.priority());
    }

    // ALL has exactly the 11 variants, each listed once.
    #[test]
    fn dispatch_key_all_contains_every_key() {
        assert_eq!(DispatchKey::ALL.len(), 11);
        for k in &DispatchKey::ALL {
            let count = DispatchKey::ALL.iter().filter(|&other| other == k).count();
            assert_eq!(count, 1, "duplicate key {k:?}");
        }
    }

    // The empty set reports empty/zero-length and contains nothing.
    #[test]
    fn dispatch_key_set_empty() {
        let set = DispatchKeySet::empty();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.highest(), None);
        assert!(!set.contains(DispatchKey::Cpu));
    }

    // insert() adds keys without affecting absent ones.
    #[test]
    fn dispatch_key_set_insert_and_contains() {
        let set = DispatchKeySet::empty()
            .insert(DispatchKey::Cpu)
            .insert(DispatchKey::Autograd);
        assert_eq!(set.len(), 2);
        assert!(set.contains(DispatchKey::Cpu));
        assert!(set.contains(DispatchKey::Autograd));
        assert!(!set.contains(DispatchKey::Cuda));
    }

    // remove() drops only the named key (sets are copy-on-write values).
    #[test]
    fn dispatch_key_set_remove() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let without_autograd = set.remove(DispatchKey::Autograd);
        assert_eq!(without_autograd.len(), 1);
        assert!(without_autograd.contains(DispatchKey::Cpu));
        assert!(!without_autograd.contains(DispatchKey::Autograd));
    }

    // highest() picks the maximum-priority member.
    #[test]
    fn dispatch_key_set_highest() {
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Autograd,
            DispatchKey::Profiler,
        ]);
        assert_eq!(set.highest(), Some(DispatchKey::Profiler));
    }

    // iter_desc() yields members strictly from highest to lowest priority,
    // regardless of insertion order.
    #[test]
    fn dispatch_key_set_iter_desc_gives_priority_order() {
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cuda,
        ]);
        let order: Vec<_> = set.iter_desc().collect();
        assert_eq!(
            order,
            vec![
                DispatchKey::Tracer,
                DispatchKey::Autograd,
                DispatchKey::Cuda,
                DispatchKey::Cpu,
            ]
        );
    }

    // union/intersection behave as set OR / AND.
    #[test]
    fn dispatch_key_set_union_and_intersection() {
        let a = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let b = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Quantized]);
        let u = a.union(b);
        assert_eq!(u.len(), 3);
        assert!(u.contains(DispatchKey::Cpu));
        assert!(u.contains(DispatchKey::Autograd));
        assert!(u.contains(DispatchKey::Quantized));
        let i = a.intersection(b);
        assert_eq!(i.len(), 1);
        assert!(i.contains(DispatchKey::Autograd));
    }

    // all() matches DispatchKey::ALL exactly.
    #[test]
    fn dispatch_key_set_all_contains_every_key() {
        let set = DispatchKeySet::all();
        assert_eq!(set.len(), 11);
        for &k in &DispatchKey::ALL {
            assert!(set.contains(k));
        }
    }

    // The const-generic From<[DispatchKey; N]> impl accepts array literals.
    #[test]
    fn dispatch_key_set_from_array_literal() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        assert_eq!(set.len(), 2);
    }

    // Registration is keyed by the exact (op, key) pair.
    #[test]
    fn dispatcher_register_and_has_kernel() {
        let mut d = Dispatcher::<f32>::new();
        assert_eq!(d.kernel_count(), 0);
        assert!(!d.has_kernel("add", DispatchKey::Cpu));
        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));
        assert_eq!(d.kernel_count(), 1);
        assert!(d.has_kernel("add", DispatchKey::Cpu));
        assert!(!d.has_kernel("add", DispatchKey::Cuda));
        assert!(!d.has_kernel("sub", DispatchKey::Cpu));
    }

    // An empty keyset is rejected with the "empty keyset" error message.
    #[test]
    fn dispatcher_call_empty_keyset_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let result = d.call("add", &[t], DispatchKeySet::empty());
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("empty keyset"));
    }

    // A non-empty keyset with no matching registration errors with
    // "no kernel registered".
    #[test]
    fn dispatcher_call_no_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("no kernel registered"));
    }

    // When both Cpu and Autograd kernels exist, call() runs only the
    // higher-priority Autograd kernel (counters track invocations).
    #[test]
    fn dispatcher_call_picks_highest_priority_key() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));
        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, _, _| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();
        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);
    }

    // A wrapper kernel strips its own key and re-dispatches, so both the
    // Autograd and Cpu kernels run exactly once.
    #[test]
    fn dispatcher_redispatch_chains_through_keys() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));
        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();
        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
    }

    // Keys in the keyset without a registered kernel are skipped silently;
    // dispatch falls through to the Cpu kernel.
    #[test]
    fn dispatcher_skips_keys_without_kernel() {
        let mut d = Dispatcher::<f32>::new();
        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));
        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset).unwrap();
        assert_eq!(result.shape(), &[2]);
    }

    // call_direct() targets a specific key even when a higher-priority
    // kernel is registered.
    #[test]
    fn dispatcher_call_direct_bypasses_priority() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let cpu_count = Arc::new(AtomicUsize::new(0));
        let cuda_count = Arc::new(AtomicUsize::new(0));
        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let cuda_c = Arc::clone(&cuda_count);
        d.register("add", DispatchKey::Cuda, move |inputs, _, _| {
            cuda_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        d.call("add", &[t.clone()], keyset).unwrap();
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);
        d.call_direct("add", &[t], keyset, DispatchKey::Cpu).unwrap();
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
    }

    // call_direct() on an unregistered pair is an error.
    #[test]
    fn dispatcher_call_direct_missing_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call_direct("add", &[t], keyset, DispatchKey::Cpu);
        assert!(result.is_err());
    }

    // End-to-end: Tracer wraps Autograd wraps Cpu; the shared log records
    // the exact top-down execution order.
    #[test]
    fn dispatcher_full_three_layer_stack() {
        use std::sync::Mutex;
        use std::sync::Arc;
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));
        let mut d = Dispatcher::<f32>::new();
        let log_c = Arc::clone(&log);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            log_c.lock().unwrap().push("cpu");
            Ok(inputs[0].clone())
        });
        let log_a = Arc::clone(&log);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            log_a.lock().unwrap().push("autograd");
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });
        let log_t = Arc::clone(&log);
        d.register("add", DispatchKey::Tracer, move |inputs, keyset, disp| {
            log_t.lock().unwrap().push("tracer");
            let rest = keyset.remove(DispatchKey::Tracer);
            disp.call("add", inputs, rest)
        });
        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cpu,
        ]);
        d.call("add", &[t], keyset).unwrap();
        let final_log = log.lock().unwrap();
        assert_eq!(*final_log, vec!["tracer", "autograd", "cpu"]);
    }
}