1use crate::dtype::Float;
54use crate::error::{FerrotorchError, FerrotorchResult};
55use crate::tensor::Tensor;
56
57use std::collections::HashMap;
58
/// A dispatch layer / backend identifier.
///
/// The `u8` discriminant doubles as the key's priority: larger values are
/// tried first during dispatch, so the wrapper layers (`Tracer`, `Profiler`,
/// `Vmap`, `Autograd`, `Autocast`) sit above the concrete backend keys
/// (`Cpu`, `Cuda`, `Meta`, `Sparse`, `Quantized`, `Nested`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(u8)]
pub enum DispatchKey {
    Cpu = 0,
    Cuda = 1,
    Meta = 2,
    Sparse = 3,
    Quantized = 4,
    Nested = 5,
    Autocast = 6,
    Autograd = 7,
    Vmap = 8,
    Profiler = 9,
    /// Highest priority: a tracer kernel runs before every other layer.
    Tracer = 10,
}

impl DispatchKey {
    /// This key's priority (its discriminant). Higher values dispatch first.
    #[inline]
    pub fn priority(self) -> u8 {
        self as u8
    }

    /// Every key, listed in ascending priority order.
    pub const ALL: [DispatchKey; 11] = [
        DispatchKey::Cpu,
        DispatchKey::Cuda,
        DispatchKey::Meta,
        DispatchKey::Sparse,
        DispatchKey::Quantized,
        DispatchKey::Nested,
        DispatchKey::Autocast,
        DispatchKey::Autograd,
        DispatchKey::Vmap,
        DispatchKey::Profiler,
        DispatchKey::Tracer,
    ];
}

/// An immutable (copy-on-write-style) set of [`DispatchKey`]s.
///
/// Bit `i` of `bits` is set iff the key with priority `i` is present, so
/// every set operation is O(1) bit arithmetic on a single `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchKeySet {
    bits: u16,
}

impl DispatchKeySet {
    /// The set containing no keys.
    #[inline]
    pub const fn empty() -> Self {
        Self { bits: 0 }
    }

    /// The set containing every key in [`DispatchKey::ALL`].
    pub fn all() -> Self {
        let mut set = Self::empty();
        for &k in &DispatchKey::ALL {
            set = set.insert(k);
        }
        set
    }

    /// Builds a set from any iterator of keys (duplicates are harmless).
    pub fn from_iter<I: IntoIterator<Item = DispatchKey>>(keys: I) -> Self {
        let mut set = Self::empty();
        for k in keys {
            set = set.insert(k);
        }
        set
    }

    /// Returns `true` if `key` is in the set.
    #[inline]
    pub fn contains(self, key: DispatchKey) -> bool {
        (self.bits >> key.priority()) & 1 != 0
    }

    /// Returns a copy of the set with `key` added.
    #[inline]
    #[must_use]
    pub fn insert(self, key: DispatchKey) -> Self {
        Self {
            bits: self.bits | (1 << key.priority()),
        }
    }

    /// Returns a copy of the set with `key` removed.
    #[inline]
    #[must_use]
    pub fn remove(self, key: DispatchKey) -> Self {
        Self {
            bits: self.bits & !(1 << key.priority()),
        }
    }

    /// Set union: keys present in `self`, `other`, or both.
    #[inline]
    #[must_use]
    pub fn union(self, other: Self) -> Self {
        Self {
            bits: self.bits | other.bits,
        }
    }

    /// Set intersection: keys present in both `self` and `other`.
    #[inline]
    #[must_use]
    pub fn intersection(self, other: Self) -> Self {
        Self {
            bits: self.bits & other.bits,
        }
    }

    /// Returns `true` if no keys are present.
    #[inline]
    pub fn is_empty(self) -> bool {
        self.bits == 0
    }

    /// Number of keys in the set (popcount of the bit mask).
    #[inline]
    pub fn len(self) -> usize {
        self.bits.count_ones() as usize
    }

    /// The highest-priority key in the set, or `None` if the set is empty.
    pub fn highest(self) -> Option<DispatchKey> {
        // Delegate to the descending iterator instead of duplicating its
        // highest-set-bit logic with a second linear scan over `ALL`.
        self.iter_desc().next()
    }

    /// Iterates over the keys in the set from highest to lowest priority.
    pub fn iter_desc(self) -> impl Iterator<Item = DispatchKey> {
        let mut bits = self.bits;
        std::iter::from_fn(move || {
            if bits == 0 {
                return None;
            }
            // Index of the highest set bit. `bits != 0` guarantees
            // `leading_zeros() <= 15`, so the subtraction cannot underflow.
            let top = 15 - bits.leading_zeros() as u8;
            bits &= !(1 << top);
            DispatchKey::ALL.iter().find(|k| k.priority() == top).copied()
        })
    }
}

impl Default for DispatchKeySet {
    /// Defaults to the empty set.
    fn default() -> Self {
        Self::empty()
    }
}

impl<const N: usize> From<[DispatchKey; N]> for DispatchKeySet {
    /// Builds a set from an array literal, e.g.
    /// `DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd])`.
    fn from(arr: [DispatchKey; N]) -> Self {
        Self::from_iter(arr)
    }
}

impl FromIterator<DispatchKey> for DispatchKeySet {
    /// Enables `collect::<DispatchKeySet>()` from any key iterator,
    /// complementing the inherent `from_iter` (kept for compatibility).
    fn from_iter<I: IntoIterator<Item = DispatchKey>>(iter: I) -> Self {
        let mut set = Self::empty();
        for k in iter {
            set = set.insert(k);
        }
        set
    }
}
264
/// A boxed, thread-safe kernel function for one `(op, DispatchKey)` slot.
///
/// A kernel receives the operator inputs, the keyset currently being
/// dispatched (so a layer kernel can strip its own key and re-dispatch the
/// remainder), and a reference to the dispatcher itself for re-entrant calls.
pub type Kernel<T> = Box<
    dyn Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
        + Send
        + Sync,
>;
282
/// Operator dispatcher: maps `(operator name, DispatchKey)` pairs to kernels
/// and routes each call to the highest-priority registered kernel present in
/// the caller's keyset.
pub struct Dispatcher<T: Float> {
    // Kernel table keyed by (op name, dispatch key); one kernel per pair.
    kernels: HashMap<(String, DispatchKey), Kernel<T>>,
}
291
292impl<T: Float> Dispatcher<T> {
293 pub fn new() -> Self {
295 Self {
296 kernels: HashMap::new(),
297 }
298 }
299
300 pub fn register<F>(&mut self, op_name: impl Into<String>, key: DispatchKey, kernel: F)
303 where
304 F: Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
305 + Send
306 + Sync
307 + 'static,
308 {
309 self.kernels.insert((op_name.into(), key), Box::new(kernel));
310 }
311
312 pub fn has_kernel(&self, op_name: &str, key: DispatchKey) -> bool {
314 self.kernels.contains_key(&(op_name.to_string(), key))
315 }
316
317 pub fn kernel_count(&self) -> usize {
319 self.kernels.len()
320 }
321
322 pub fn call(
335 &self,
336 op_name: &str,
337 inputs: &[Tensor<T>],
338 keyset: DispatchKeySet,
339 ) -> FerrotorchResult<Tensor<T>> {
340 if keyset.is_empty() {
341 return Err(FerrotorchError::InvalidArgument {
342 message: format!(
343 "Dispatcher::call({op_name}): empty keyset — no backend to run on"
344 ),
345 });
346 }
347 for key in keyset.iter_desc() {
348 if let Some(kernel) = self.kernels.get(&(op_name.to_string(), key)) {
349 return kernel(inputs, keyset, self);
350 }
351 }
352 Err(FerrotorchError::InvalidArgument {
353 message: format!(
354 "Dispatcher::call({op_name}): no kernel registered for any key in {keyset:?}"
355 ),
356 })
357 }
358
359 pub fn call_direct(
366 &self,
367 op_name: &str,
368 inputs: &[Tensor<T>],
369 keyset: DispatchKeySet,
370 key: DispatchKey,
371 ) -> FerrotorchResult<Tensor<T>> {
372 match self.kernels.get(&(op_name.to_string(), key)) {
373 Some(kernel) => kernel(inputs, keyset, self),
374 None => Err(FerrotorchError::InvalidArgument {
375 message: format!(
376 "Dispatcher::call_direct({op_name}, {key:?}): no kernel registered"
377 ),
378 }),
379 }
380 }
381}
382
impl<T: Float> Default for Dispatcher<T> {
    /// Equivalent to [`Dispatcher::new`]: an empty kernel table.
    fn default() -> Self {
        Self::new()
    }
}
388
// Manual Debug impl: kernels are opaque boxed closures (not Debug), so only
// the number of registered kernels is reported.
impl<T: Float> std::fmt::Debug for Dispatcher<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dispatcher")
            .field("kernel_count", &self.kernels.len())
            .finish()
    }
}
396
// Unit tests covering: key priority ordering, DispatchKeySet bit-set
// operations, and dispatcher registration, priority-ordered dispatch,
// re-dispatch chaining, and direct (priority-bypassing) calls.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    // Helper: builds a CPU-backed f32 tensor with requires_grad = false.
    fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
    }

    // Wrapper layers must outrank backends: Tracer > Autograd > Autocast > Cpu.
    #[test]
    fn dispatch_key_priority_ordering() {
        assert!(DispatchKey::Tracer.priority() > DispatchKey::Autograd.priority());
        assert!(DispatchKey::Autograd.priority() > DispatchKey::Autocast.priority());
        assert!(DispatchKey::Autocast.priority() > DispatchKey::Cpu.priority());
        assert!(DispatchKey::Cuda.priority() > DispatchKey::Cpu.priority());
    }

    // ALL must list each of the 11 keys exactly once (no duplicates).
    #[test]
    fn dispatch_key_all_contains_every_key() {
        assert_eq!(DispatchKey::ALL.len(), 11);
        for k in &DispatchKey::ALL {
            let count = DispatchKey::ALL.iter().filter(|&other| other == k).count();
            assert_eq!(count, 1, "duplicate key {k:?}");
        }
    }

    // An empty set reports empty on every accessor.
    #[test]
    fn dispatch_key_set_empty() {
        let set = DispatchKeySet::empty();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.highest(), None);
        assert!(!set.contains(DispatchKey::Cpu));
    }

    #[test]
    fn dispatch_key_set_insert_and_contains() {
        let set = DispatchKeySet::empty()
            .insert(DispatchKey::Cpu)
            .insert(DispatchKey::Autograd);
        assert_eq!(set.len(), 2);
        assert!(set.contains(DispatchKey::Cpu));
        assert!(set.contains(DispatchKey::Autograd));
        assert!(!set.contains(DispatchKey::Cuda));
    }

    // remove() is persistent-style: it returns a new set.
    #[test]
    fn dispatch_key_set_remove() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let without_autograd = set.remove(DispatchKey::Autograd);
        assert_eq!(without_autograd.len(), 1);
        assert!(without_autograd.contains(DispatchKey::Cpu));
        assert!(!without_autograd.contains(DispatchKey::Autograd));
    }

    #[test]
    fn dispatch_key_set_highest() {
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Autograd,
            DispatchKey::Profiler,
        ]);
        assert_eq!(set.highest(), Some(DispatchKey::Profiler));
    }

    // iter_desc yields strictly descending priority, regardless of the
    // order keys were inserted in.
    #[test]
    fn dispatch_key_set_iter_desc_gives_priority_order() {
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cuda,
        ]);
        let order: Vec<_> = set.iter_desc().collect();
        assert_eq!(
            order,
            vec![
                DispatchKey::Tracer,
                DispatchKey::Autograd,
                DispatchKey::Cuda,
                DispatchKey::Cpu,
            ]
        );
    }

    #[test]
    fn dispatch_key_set_union_and_intersection() {
        let a = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let b = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Quantized]);
        let u = a.union(b);
        assert_eq!(u.len(), 3);
        assert!(u.contains(DispatchKey::Cpu));
        assert!(u.contains(DispatchKey::Autograd));
        assert!(u.contains(DispatchKey::Quantized));

        let i = a.intersection(b);
        assert_eq!(i.len(), 1);
        assert!(i.contains(DispatchKey::Autograd));
    }

    #[test]
    fn dispatch_key_set_all_contains_every_key() {
        let set = DispatchKeySet::all();
        assert_eq!(set.len(), 11);
        for &k in &DispatchKey::ALL {
            assert!(set.contains(k));
        }
    }

    // The From<[DispatchKey; N]> impl accepts array literals of any length.
    #[test]
    fn dispatch_key_set_from_array_literal() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        assert_eq!(set.len(), 2);
    }

    // Registration is per (op, key) pair: neither a different key nor a
    // different op name should report a kernel.
    #[test]
    fn dispatcher_register_and_has_kernel() {
        let mut d = Dispatcher::<f32>::new();
        assert_eq!(d.kernel_count(), 0);
        assert!(!d.has_kernel("add", DispatchKey::Cpu));

        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));
        assert_eq!(d.kernel_count(), 1);
        assert!(d.has_kernel("add", DispatchKey::Cpu));
        assert!(!d.has_kernel("add", DispatchKey::Cuda));
        assert!(!d.has_kernel("sub", DispatchKey::Cpu));
    }

    #[test]
    fn dispatcher_call_empty_keyset_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let result = d.call("add", &[t], DispatchKeySet::empty());
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("empty keyset"));
    }

    #[test]
    fn dispatcher_call_no_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("no kernel registered"));
    }

    // With both Cpu and Autograd registered, only the higher-priority
    // Autograd kernel should run (it does not re-dispatch here).
    #[test]
    fn dispatcher_call_picks_highest_priority_key() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, _, _| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });

        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();

        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);
    }

    // A layer kernel can strip its own key and re-dispatch: Autograd runs
    // first, then falls through to Cpu via Dispatcher::call.
    #[test]
    fn dispatcher_redispatch_chains_through_keys() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });

        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();

        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
    }

    // Keys in the keyset with no registered kernel are skipped silently.
    #[test]
    fn dispatcher_skips_keys_without_kernel() {
        let mut d = Dispatcher::<f32>::new();
        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));

        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset).unwrap();
        assert_eq!(result.shape(), &[2]);
    }

    // call_direct targets an exact key even when a higher-priority kernel
    // exists for the same op.
    #[test]
    fn dispatcher_call_direct_bypasses_priority() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let cpu_count = Arc::new(AtomicUsize::new(0));
        let cuda_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let cuda_c = Arc::clone(&cuda_count);
        d.register("add", DispatchKey::Cuda, move |inputs, _, _| {
            cuda_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });

        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        d.call("add", &[t.clone()], keyset).unwrap();
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);

        d.call_direct("add", &[t], keyset, DispatchKey::Cpu).unwrap();
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatcher_call_direct_missing_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call_direct("add", &[t], keyset, DispatchKey::Cpu);
        assert!(result.is_err());
    }

    // Full stack: Tracer -> Autograd -> Cpu, each layer logging then
    // stripping its own key and re-dispatching the remainder.
    #[test]
    fn dispatcher_full_three_layer_stack() {
        use std::sync::Mutex;
        use std::sync::Arc;

        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));

        let mut d = Dispatcher::<f32>::new();

        let log_c = Arc::clone(&log);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            log_c.lock().unwrap().push("cpu");
            Ok(inputs[0].clone())
        });

        let log_a = Arc::clone(&log);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            log_a.lock().unwrap().push("autograd");
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });

        let log_t = Arc::clone(&log);
        d.register("add", DispatchKey::Tracer, move |inputs, keyset, disp| {
            log_t.lock().unwrap().push("tracer");
            let rest = keyset.remove(DispatchKey::Tracer);
            disp.call("add", inputs, rest)
        });

        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cpu,
        ]);
        d.call("add", &[t], keyset).unwrap();

        // Order proves priority-driven layering, highest first.
        let final_log = log.lock().unwrap();
        assert_eq!(*final_log, vec!["tracer", "autograd", "cpu"]);
    }
}