// ferrotorch_core/dispatch.rs
1//! Multi-dispatch key system for composable tensor backends. CL-397.
2//!
3//! Mirrors PyTorch's `DispatchKey` / `DispatchKeySet` / `Dispatcher`
4//! architecture: every tensor carries a set of active dispatch keys
5//! (e.g. `Autograd`, `Quantized`, `Sparse`, `CPU`, `CUDA`), and when
6//! an op is invoked the dispatcher picks the kernel registered for
7//! the **highest-priority** active key.
8//!
9//! This enables layered semantics without hard-coding each
10//! combination in every op:
11//!
12//! - `Autograd` kernels record a backward node and forward to the
13//!   next layer.
14//! - `Quantized` kernels dequantize, forward, and re-quantize.
15//! - `Sparse` kernels call the sparse backend when the tensor is a
16//!   sparse view.
17//! - `CPU` / `CUDA` are the terminal "backend" keys that actually
18//!   run the math.
19//!
20//! The dispatcher walks the set from highest to lowest priority,
21//! picks the first registered kernel, and runs it. The kernel can
22//! mask its own key off and call the dispatcher again to delegate
23//! to the next layer ("redispatch" in PyTorch terminology).
24//!
25//! # Example
26//!
27//! ```ignore
28//! use ferrotorch_core::dispatch::{DispatchKey, DispatchKeySet, Dispatcher};
29//!
30//! let mut dispatcher = Dispatcher::<f32>::new();
31//!
32//! // Register a CPU kernel for the "add" op.
33//! dispatcher.register("add", DispatchKey::Cpu, |inputs, _keyset, _disp| {
34//!     // Actually do the addition...
35//!     Ok(inputs[0].clone())
36//! });
37//!
38//! // Layer an autograd kernel on top that records a backward node
39//! // and redispatches with Autograd masked off.
40//! dispatcher.register("add", DispatchKey::Autograd, |inputs, keyset, disp| {
41//!     // ... record backward ...
42//!     let remaining = keyset.remove(DispatchKey::Autograd);
43//!     disp.call("add", inputs, remaining)
44//! });
45//!
46//! // Call the op with a keyset that has both Autograd and CPU set.
47//! // The dispatcher picks Autograd first (higher priority), which
48//! // then redispatches to Cpu.
49//! let keyset = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Cpu]);
50//! let result = dispatcher.call("add", &[tensor], keyset).unwrap();
51//! ```
52
53use crate::dtype::Float;
54use crate::error::{FerrotorchError, FerrotorchResult};
55use crate::tensor::Tensor;
56
57use std::collections::HashMap;
58
/// One of the 11 defined dispatch keys (the `u16` bitmask in
/// [`DispatchKeySet`] leaves room for up to 16), ordered from lowest
/// to highest priority. The `u8` repr matches the bit position in
/// [`DispatchKeySet`]'s internal `u16` bitmask, so the priority
/// ordering is both the enum declaration order and the numeric
/// order of the discriminants.
///
/// Keys are resolved highest-priority-first: the dispatcher walks
/// from the largest discriminant down and picks the first key that
/// has a registered kernel for the op.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(u8)]
pub enum DispatchKey {
    /// Backend: CPU — terminal key for CPU kernels.
    Cpu = 0,
    /// Backend: CUDA — terminal key for CUDA kernels.
    Cuda = 1,
    /// Backend: Meta device — shape-only dry runs, no data.
    Meta = 2,
    /// Tensor contains sparse data. Sparse kernels intercept ops
    /// and either call a sparse-specific backend or densify and
    /// redispatch.
    Sparse = 3,
    /// Tensor contains quantized values. Quantized kernels
    /// dequantize, redispatch, and requantize (for ops without
    /// native quantized kernels).
    Quantized = 4,
    /// Tensor is a nested/jagged tensor. Nested kernels iterate
    /// per-component and redispatch to the backend.
    Nested = 5,
    /// Auto-mixed-precision: cast inputs to the autocast dtype
    /// before redispatching. Higher priority than Quantized so
    /// AMP happens before quantization layering.
    Autocast = 6,
    /// Autograd: record a backward node and redispatch with
    /// Autograd masked off. Highest-priority non-profiling key so
    /// the backward graph sees the post-dispatch view of each op.
    Autograd = 7,
    /// Vmap (batched tensor): intercept ops and apply them over
    /// the batch dimension. Stacks above Autograd so batched
    /// forwards still see autograd semantics.
    Vmap = 8,
    /// Profiler: record an entry in the active profiler before
    /// redispatching. Sits above Vmap so the profiler sees the
    /// outer call exactly once regardless of batching.
    Profiler = 9,
    /// Tracer: emit an IR node into the active JIT trace.
    /// Highest priority so tracing happens before any other
    /// layering transforms the op.
    Tracer = 10,
}
109
110impl DispatchKey {
111    /// The numeric priority of this key. Larger = higher priority.
112    #[inline]
113    pub fn priority(self) -> u8 {
114        self as u8
115    }
116
117    /// All 11 defined keys, in priority order (lowest to highest).
118    /// Useful for iterating the full set.
119    pub const ALL: [DispatchKey; 11] = [
120        DispatchKey::Cpu,
121        DispatchKey::Cuda,
122        DispatchKey::Meta,
123        DispatchKey::Sparse,
124        DispatchKey::Quantized,
125        DispatchKey::Nested,
126        DispatchKey::Autocast,
127        DispatchKey::Autograd,
128        DispatchKey::Vmap,
129        DispatchKey::Profiler,
130        DispatchKey::Tracer,
131    ];
132}
133
/// A set of active [`DispatchKey`]s, stored as a `u16` bitmask for
/// constant-time membership testing and iteration.
///
/// Bit `i` of the mask corresponds to the key whose discriminant is
/// `i` (see [`DispatchKey::priority`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchKeySet {
    // Invariant: the public API only sets bits 0..=10, i.e. bits
    // that correspond to keys in `DispatchKey::ALL`.
    bits: u16,
}
140
141impl DispatchKeySet {
142    /// An empty set.
143    #[inline]
144    pub const fn empty() -> Self {
145        Self { bits: 0 }
146    }
147
148    /// A set containing every defined key.
149    pub fn all() -> Self {
150        let mut set = Self::empty();
151        for &k in &DispatchKey::ALL {
152            set = set.insert(k);
153        }
154        set
155    }
156
157    /// Construct a set from an iterable of keys.
158    pub fn from_iter<I: IntoIterator<Item = DispatchKey>>(keys: I) -> Self {
159        let mut set = Self::empty();
160        for k in keys {
161            set = set.insert(k);
162        }
163        set
164    }
165
166    /// Returns true if `key` is in this set.
167    #[inline]
168    pub fn contains(self, key: DispatchKey) -> bool {
169        (self.bits >> key.priority()) & 1 != 0
170    }
171
172    /// Returns a new set with `key` added.
173    #[inline]
174    #[must_use]
175    pub fn insert(self, key: DispatchKey) -> Self {
176        Self {
177            bits: self.bits | (1 << key.priority()),
178        }
179    }
180
181    /// Returns a new set with `key` removed.
182    #[inline]
183    #[must_use]
184    pub fn remove(self, key: DispatchKey) -> Self {
185        Self {
186            bits: self.bits & !(1 << key.priority()),
187        }
188    }
189
190    /// Union of two sets.
191    #[inline]
192    #[must_use]
193    pub fn union(self, other: Self) -> Self {
194        Self {
195            bits: self.bits | other.bits,
196        }
197    }
198
199    /// Intersection of two sets.
200    #[inline]
201    #[must_use]
202    pub fn intersection(self, other: Self) -> Self {
203        Self {
204            bits: self.bits & other.bits,
205        }
206    }
207
208    /// Returns true if this set has no keys.
209    #[inline]
210    pub fn is_empty(self) -> bool {
211        self.bits == 0
212    }
213
214    /// Number of keys in this set.
215    #[inline]
216    pub fn len(self) -> usize {
217        self.bits.count_ones() as usize
218    }
219
220    /// Highest-priority key in this set, or `None` if empty. This
221    /// is the "next" key the dispatcher will resolve.
222    pub fn highest(self) -> Option<DispatchKey> {
223        if self.bits == 0 {
224            return None;
225        }
226        // Walk keys from highest to lowest discriminant and return
227        // the first one present.
228        for &k in DispatchKey::ALL.iter().rev() {
229            if self.contains(k) {
230                return Some(k);
231            }
232        }
233        None
234    }
235
236    /// Returns an iterator over all keys in the set, in
237    /// **descending** priority order (highest first).
238    pub fn iter_desc(self) -> impl Iterator<Item = DispatchKey> {
239        let mut bits = self.bits;
240        std::iter::from_fn(move || {
241            if bits == 0 {
242                return None;
243            }
244            // Find the highest set bit.
245            let top = 15 - bits.leading_zeros() as u8;
246            bits &= !(1 << top);
247            // Map bit position back to a DispatchKey if valid.
248            DispatchKey::ALL.iter().find(|k| k.priority() == top).copied()
249        })
250    }
251}
252
253impl Default for DispatchKeySet {
254    fn default() -> Self {
255        Self::empty()
256    }
257}
258
259impl<const N: usize> From<[DispatchKey; N]> for DispatchKeySet {
260    fn from(arr: [DispatchKey; N]) -> Self {
261        Self::from_iter(arr)
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Kernel type and Dispatcher
267// ---------------------------------------------------------------------------
268
/// A dispatched kernel: takes the op's input tensors, the
/// currently-active keyset (after all higher-priority keys have
/// been resolved), and a reference to the dispatcher so the kernel
/// can redispatch to a lower-priority key.
///
/// The kernel receives the *full* keyset, not just its own key, so
/// it can mask keys off (see [`DispatchKeySet::remove`]) before
/// calling back into the dispatcher.
///
/// Kernels return a single output tensor. Ops with multiple
/// outputs are not yet supported by this dispatcher — they'd need
/// a separate `KernelMulti` variant.
pub type Kernel<T> = Box<
    dyn Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
        + Send
        + Sync,
>;
282
/// A kernel registration table keyed by `(op_name, dispatch_key)`.
/// Looking up a kernel is a single HashMap probe. Note that the
/// owned `String` in the key means each lookup must build an owned
/// probe tuple (a `&str` cannot be borrowed into the tuple key).
///
/// `T` is the scalar dtype the dispatcher operates on (f32 / f64).
/// Different dispatchers are typically held per-dtype.
pub struct Dispatcher<T: Float> {
    // (op name, key) → boxed kernel closure.
    kernels: HashMap<(String, DispatchKey), Kernel<T>>,
}
291
292impl<T: Float> Dispatcher<T> {
293    /// Create an empty dispatcher with no registered kernels.
294    pub fn new() -> Self {
295        Self {
296            kernels: HashMap::new(),
297        }
298    }
299
300    /// Register a kernel for `(op_name, key)`. Overwrites any
301    /// existing registration for the same pair.
302    pub fn register<F>(&mut self, op_name: impl Into<String>, key: DispatchKey, kernel: F)
303    where
304        F: Fn(&[Tensor<T>], DispatchKeySet, &Dispatcher<T>) -> FerrotorchResult<Tensor<T>>
305            + Send
306            + Sync
307            + 'static,
308    {
309        self.kernels.insert((op_name.into(), key), Box::new(kernel));
310    }
311
312    /// Returns true if a kernel is registered for `(op_name, key)`.
313    pub fn has_kernel(&self, op_name: &str, key: DispatchKey) -> bool {
314        self.kernels.contains_key(&(op_name.to_string(), key))
315    }
316
317    /// Number of registered kernels.
318    pub fn kernel_count(&self) -> usize {
319        self.kernels.len()
320    }
321
322    /// Call `op_name` with `inputs` and the given active keyset.
323    /// Walks the keyset in descending priority order, picks the
324    /// first key that has a kernel registered for the op, and runs
325    /// it. The kernel receives the full `keyset` (not just its
326    /// own key) so it can decide which keys to mask off before
327    /// redispatching.
328    ///
329    /// # Errors
330    ///
331    /// Returns [`FerrotorchError::InvalidArgument`] if no kernel
332    /// is registered for any active key in the set, or if the set
333    /// is empty.
334    pub fn call(
335        &self,
336        op_name: &str,
337        inputs: &[Tensor<T>],
338        keyset: DispatchKeySet,
339    ) -> FerrotorchResult<Tensor<T>> {
340        if keyset.is_empty() {
341            return Err(FerrotorchError::InvalidArgument {
342                message: format!(
343                    "Dispatcher::call({op_name}): empty keyset — no backend to run on"
344                ),
345            });
346        }
347        for key in keyset.iter_desc() {
348            if let Some(kernel) = self.kernels.get(&(op_name.to_string(), key)) {
349                return kernel(inputs, keyset, self);
350            }
351        }
352        Err(FerrotorchError::InvalidArgument {
353            message: format!(
354                "Dispatcher::call({op_name}): no kernel registered for any key in {keyset:?}"
355            ),
356        })
357    }
358
359    /// Call `op_name` with the kernel for a specific `key`,
360    /// bypassing priority resolution. Returns an error if no
361    /// kernel is registered for that key.
362    ///
363    /// Primarily useful for testing and for kernels that want to
364    /// forward directly to a specific lower-priority layer.
365    pub fn call_direct(
366        &self,
367        op_name: &str,
368        inputs: &[Tensor<T>],
369        keyset: DispatchKeySet,
370        key: DispatchKey,
371    ) -> FerrotorchResult<Tensor<T>> {
372        match self.kernels.get(&(op_name.to_string(), key)) {
373            Some(kernel) => kernel(inputs, keyset, self),
374            None => Err(FerrotorchError::InvalidArgument {
375                message: format!(
376                    "Dispatcher::call_direct({op_name}, {key:?}): no kernel registered"
377                ),
378            }),
379        }
380    }
381}
382
383impl<T: Float> Default for Dispatcher<T> {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
389impl<T: Float> std::fmt::Debug for Dispatcher<T> {
390    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391        f.debug_struct("Dispatcher")
392            .field("kernel_count", &self.kernels.len())
393            .finish()
394    }
395}
396
397// ---------------------------------------------------------------------------
398// Tests
399// ---------------------------------------------------------------------------
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    // Helper: build a small CPU tensor. The trailing `false` flag is
    // presumably requires_grad — TODO confirm against
    // `Tensor::from_storage`'s signature.
    fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
    }

    // ── DispatchKey priority ────────────────────────────────────────

    #[test]
    fn dispatch_key_priority_ordering() {
        assert!(DispatchKey::Tracer.priority() > DispatchKey::Autograd.priority());
        assert!(DispatchKey::Autograd.priority() > DispatchKey::Autocast.priority());
        assert!(DispatchKey::Autocast.priority() > DispatchKey::Cpu.priority());
        assert!(DispatchKey::Cuda.priority() > DispatchKey::Cpu.priority());
    }

    #[test]
    fn dispatch_key_all_contains_every_key() {
        assert_eq!(DispatchKey::ALL.len(), 11);
        // Each key appears exactly once.
        for k in &DispatchKey::ALL {
            let count = DispatchKey::ALL.iter().filter(|&other| other == k).count();
            assert_eq!(count, 1, "duplicate key {k:?}");
        }
    }

    // ── DispatchKeySet membership ───────────────────────────────────

    #[test]
    fn dispatch_key_set_empty() {
        let set = DispatchKeySet::empty();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.highest(), None);
        assert!(!set.contains(DispatchKey::Cpu));
    }

    #[test]
    fn dispatch_key_set_insert_and_contains() {
        let set = DispatchKeySet::empty()
            .insert(DispatchKey::Cpu)
            .insert(DispatchKey::Autograd);
        assert_eq!(set.len(), 2);
        assert!(set.contains(DispatchKey::Cpu));
        assert!(set.contains(DispatchKey::Autograd));
        assert!(!set.contains(DispatchKey::Cuda));
    }

    #[test]
    fn dispatch_key_set_remove() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let without_autograd = set.remove(DispatchKey::Autograd);
        assert_eq!(without_autograd.len(), 1);
        assert!(without_autograd.contains(DispatchKey::Cpu));
        assert!(!without_autograd.contains(DispatchKey::Autograd));
    }

    #[test]
    fn dispatch_key_set_highest() {
        // Profiler has the largest discriminant of the three.
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Autograd,
            DispatchKey::Profiler,
        ]);
        assert_eq!(set.highest(), Some(DispatchKey::Profiler));
    }

    #[test]
    fn dispatch_key_set_iter_desc_gives_priority_order() {
        let set = DispatchKeySet::from([
            DispatchKey::Cpu,
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cuda,
        ]);
        let order: Vec<_> = set.iter_desc().collect();
        assert_eq!(
            order,
            vec![
                DispatchKey::Tracer,
                DispatchKey::Autograd,
                DispatchKey::Cuda,
                DispatchKey::Cpu,
            ]
        );
    }

    #[test]
    fn dispatch_key_set_union_and_intersection() {
        let a = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        let b = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Quantized]);
        let u = a.union(b);
        assert_eq!(u.len(), 3);
        assert!(u.contains(DispatchKey::Cpu));
        assert!(u.contains(DispatchKey::Autograd));
        assert!(u.contains(DispatchKey::Quantized));

        let i = a.intersection(b);
        assert_eq!(i.len(), 1);
        assert!(i.contains(DispatchKey::Autograd));
    }

    #[test]
    fn dispatch_key_set_all_contains_every_key() {
        let set = DispatchKeySet::all();
        assert_eq!(set.len(), 11);
        for &k in &DispatchKey::ALL {
            assert!(set.contains(k));
        }
    }

    #[test]
    fn dispatch_key_set_from_array_literal() {
        let set = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        assert_eq!(set.len(), 2);
    }

    // ── Dispatcher registration and lookup ──────────────────────────

    #[test]
    fn dispatcher_register_and_has_kernel() {
        let mut d = Dispatcher::<f32>::new();
        assert_eq!(d.kernel_count(), 0);
        assert!(!d.has_kernel("add", DispatchKey::Cpu));

        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));
        assert_eq!(d.kernel_count(), 1);
        assert!(d.has_kernel("add", DispatchKey::Cpu));
        assert!(!d.has_kernel("add", DispatchKey::Cuda));
        assert!(!d.has_kernel("sub", DispatchKey::Cpu));
    }

    #[test]
    fn dispatcher_call_empty_keyset_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let result = d.call("add", &[t], DispatchKeySet::empty());
        assert!(result.is_err());
        // Error message text is part of the contract asserted here.
        assert!(format!("{}", result.unwrap_err()).contains("empty keyset"));
    }

    #[test]
    fn dispatcher_call_no_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("no kernel registered"));
    }

    #[test]
    fn dispatcher_call_picks_highest_priority_key() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        // Track which kernel was called by name.
        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, _, _| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });

        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();

        // Autograd is higher priority, so it should be called.
        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn dispatcher_redispatch_chains_through_keys() {
        // Autograd kernel masks itself off and calls down to Cpu.
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let cpu_count = Arc::new(AtomicUsize::new(0));
        let autograd_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let ag_c = Arc::clone(&autograd_count);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            ag_c.fetch_add(1, Ordering::Relaxed);
            // Mask off autograd and redispatch.
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });

        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Autograd]);
        d.call("add", &[t], keyset).unwrap();

        // Both layers must have run exactly once.
        assert_eq!(autograd_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatcher_skips_keys_without_kernel() {
        // Register Cpu only. A keyset that includes Autograd + Cpu
        // should still resolve because Autograd has no kernel but
        // Cpu does.
        let mut d = Dispatcher::<f32>::new();
        d.register("add", DispatchKey::Cpu, |inputs, _, _| Ok(inputs[0].clone()));

        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([DispatchKey::Autograd, DispatchKey::Cpu]);
        let result = d.call("add", &[t], keyset).unwrap();
        assert_eq!(result.shape(), &[2]);
    }

    #[test]
    fn dispatcher_call_direct_bypasses_priority() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let cpu_count = Arc::new(AtomicUsize::new(0));
        let cuda_count = Arc::new(AtomicUsize::new(0));

        let mut d = Dispatcher::<f32>::new();
        let cpu_c = Arc::clone(&cpu_count);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            cpu_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });
        let cuda_c = Arc::clone(&cuda_count);
        d.register("add", DispatchKey::Cuda, move |inputs, _, _| {
            cuda_c.fetch_add(1, Ordering::Relaxed);
            Ok(inputs[0].clone())
        });

        // call() with both keys → Cuda (higher priority).
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu, DispatchKey::Cuda]);
        d.call("add", &[t.clone()], keyset).unwrap();
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
        assert_eq!(cpu_count.load(Ordering::Relaxed), 0);

        // call_direct(Cpu) → forces Cpu kernel.
        d.call_direct("add", &[t], keyset, DispatchKey::Cpu).unwrap();
        assert_eq!(cpu_count.load(Ordering::Relaxed), 1);
        assert_eq!(cuda_count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatcher_call_direct_missing_kernel_errors() {
        let d = Dispatcher::<f32>::new();
        let t = make_tensor(vec![1.0], vec![1]);
        let keyset = DispatchKeySet::from([DispatchKey::Cpu]);
        let result = d.call_direct("add", &[t], keyset, DispatchKey::Cpu);
        assert!(result.is_err());
    }

    #[test]
    fn dispatcher_full_three_layer_stack() {
        // Realistic chain: Tracer → Autograd → Cpu.
        // Tracer emits an IR node marker and redispatches.
        // Autograd records a backward marker and redispatches.
        // Cpu does the actual math.
        use std::sync::Mutex;
        use std::sync::Arc;

        // Shared log records the order in which the layers fire.
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));

        let mut d = Dispatcher::<f32>::new();

        let log_c = Arc::clone(&log);
        d.register("add", DispatchKey::Cpu, move |inputs, _, _| {
            log_c.lock().unwrap().push("cpu");
            Ok(inputs[0].clone())
        });

        let log_a = Arc::clone(&log);
        d.register("add", DispatchKey::Autograd, move |inputs, keyset, disp| {
            log_a.lock().unwrap().push("autograd");
            let rest = keyset.remove(DispatchKey::Autograd);
            disp.call("add", inputs, rest)
        });

        let log_t = Arc::clone(&log);
        d.register("add", DispatchKey::Tracer, move |inputs, keyset, disp| {
            log_t.lock().unwrap().push("tracer");
            let rest = keyset.remove(DispatchKey::Tracer);
            disp.call("add", inputs, rest)
        });

        let t = make_tensor(vec![1.0, 2.0], vec![2]);
        let keyset = DispatchKeySet::from([
            DispatchKey::Tracer,
            DispatchKey::Autograd,
            DispatchKey::Cpu,
        ]);
        d.call("add", &[t], keyset).unwrap();

        // Layers must fire in strict descending-priority order.
        let final_log = log.lock().unwrap();
        assert_eq!(*final_log, vec!["tracer", "autograd", "cpu"]);
    }
}