// numrs/backend/dispatch.rs

//! Kernel dispatch table - Zero-cost runtime dispatch system
//!
//! This module implements a dispatch system that:
//! 1. Is initialized ONCE at startup
//! 2. Validates which backends are actually available (including WebGPU)
//! 3. Creates function pointers for each operation
//! 4. The hot path just calls `dispatch_table.matmul(a, b)` with no branches
//!
//! Architecture:
//! ```text
//! Startup:
//!   - Detect capabilities (SIMD, GPU, BLAS, WebGPU)
//!   - Validate each backend (GPU probe, BLAS test call)
//!   - Choose the best implementation per operation
//!   - Store function pointers in DispatchTable
//!
//! Runtime:
//!   - get_dispatch_table() → &'static DispatchTable
//!   - table.add(a, b) → calls the chosen kernel directly
//!   - ZERO overhead (direct call, no match/if)
//! ```
22
23use crate::array::Array;
24use crate::llo::reduction::ReductionKind;
25use crate::llo::ElementwiseKind;
26use anyhow::Result;
27use once_cell::sync::OnceCell;
28use std::fmt;
29use std::sync::RwLock;
30
31// ============================================================================
32// Runtime Capabilities (migrado desde runtime.rs)
33// ============================================================================
34
/// Runtime-detected capabilities (features actually available on this system).
///
/// NOTE(review): migrated from `runtime.rs`; not referenced by the code visible
/// in this module — the fields mirror the backend kinds that
/// `BackendValidation` probes. Confirm remaining consumers before removing.
#[derive(Debug, Clone, Copy)]
pub struct RuntimeCapabilities {
    pub has_simd: bool,
    pub has_gpu: bool,
    pub has_blas: bool,
    pub has_threads: bool,
    pub has_wasm_simd: bool,
    pub has_webgpu: bool,
}
45
46// ============================================================================
47// Kernel Function Signatures
48// ============================================================================
49
/// Signature for elementwise kernels (add, mul, etc).
pub type ElementwiseFn = fn(&Array, &Array, ElementwiseKind) -> Result<Array>;

/// Signature for reduction kernels (sum, mean, max, min, etc).
/// The `Option<usize>` selects an axis — presumably `None` reduces the whole
/// array; confirm against the concrete kernel implementations.
pub type ReductionFn = fn(&Array, Option<usize>, ReductionKind) -> Result<Array>;

/// Signature for matmul kernels.
pub type MatmulFn = fn(&Array, &Array) -> Result<Array>;

/// Signature for dot-product kernels (returns a scalar).
pub type DotFn = fn(&Array, &Array) -> Result<f32>;
61
62// ============================================================================
63// Dispatch Table - Almacena function pointers
64// ============================================================================
65
/// Dispatch table holding the function pointers selected at startup.
/// All fields are public for direct access without getter overhead.
#[derive(Clone, Copy)]
pub struct DispatchTable {
    /// Kernel for elementwise operations (add, mul, div, etc)
    pub elementwise: ElementwiseFn,

    /// Kernel for reductions (sum, mean, etc)
    pub reduction: ReductionFn,

    /// Kernel for matrix multiplication
    pub matmul: MatmulFn,

    /// Kernel for dot product
    pub dot: DotFn,

    /// Metadata: name of the backend chosen for elementwise
    pub elementwise_backend: &'static str,

    /// Metadata: name of the backend chosen for reduction
    pub reduction_backend: &'static str,

    /// Metadata: name of the backend chosen for matmul
    pub matmul_backend: &'static str,

    /// Metadata: name of the backend chosen for dot
    pub dot_backend: &'static str,
}
94
95impl fmt::Debug for DispatchTable {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        f.debug_struct("DispatchTable")
98            .field("elementwise_backend", &self.elementwise_backend)
99            .field("reduction_backend", &self.reduction_backend)
100            .field("matmul_backend", &self.matmul_backend)
101            .field("dot_backend", &self.dot_backend)
102            .finish()
103    }
104}
105
// Global dispatch table (initialized once at startup).
static DISPATCH_TABLE: OnceCell<DispatchTable> = OnceCell::new();

// Global adaptive lookup tables (populated by microbenchmarks, or by the
// heuristic path in `select_kernels`; consulted on every call by the
// `kernel_*_adaptive` dispatchers below).
static MATMUL_LOOKUP: OnceCell<crate::backend::microbench::AdaptiveLookupTable<MatmulFn>> =
    OnceCell::new();
static ELEMENTWISE_LOOKUP: OnceCell<
    crate::backend::microbench::AdaptiveLookupTable<ElementwiseFn>,
> = OnceCell::new();
static REDUCTION_LOOKUP: OnceCell<crate::backend::microbench::AdaptiveLookupTable<ReductionFn>> =
    OnceCell::new();

// Backend override for benchmarking (forces a specific kernel inside the
// adaptive dispatchers; see `set_backend_override`).
static BACKEND_OVERRIDE: RwLock<Option<&'static str>> = RwLock::new(None);
120
121// ============================================================================
122// Backend Validation - Verifica que cada backend realmente funciona
123// ============================================================================
124
/// Per-backend validation results.
///
/// `*_available` means the backend is compiled in / plausibly present;
/// `*_validated` means an actual smoke test or probe succeeded at runtime.
#[derive(Debug, Clone)]
pub struct BackendValidation {
    pub simd_available: bool,
    pub simd_validated: bool,
    pub blas_available: bool,
    pub blas_validated: bool,
    pub gpu_available: bool,
    pub gpu_validated: bool,
    pub webgpu_available: bool,
    pub webgpu_validated: bool,
    pub metal_available: bool,
    pub metal_validated: bool,
}
139
/// Validates that each backend actually works (not just that it is compiled in).
///
/// Cheap smoke tests are run where possible (a 4-element SIMD add, a 2x2 BLAS
/// identity matmul); GPU-style backends (WebGPU, Metal) rely on cached probes.
/// Returns availability + validation flags for every backend kind.
pub fn validate_backends() -> BackendValidation {
    let mut validation = BackendValidation {
        simd_available: false,
        simd_validated: false,
        blas_available: false,
        blas_validated: false,
        gpu_available: false,
        gpu_validated: false,
        webgpu_available: false,
        webgpu_validated: false,
        metal_available: false,
        metal_validated: false,
    };

    // 1. SIMD validation: compile-time flag OR a runtime support check.
    validation.simd_available = cfg!(numrs_kernel_elementwise_simd)
        || crate::backend::cpu::simd::elementwise_simd_supported();

    if validation.simd_available {
        // Quick test: build small arrays and run an elementwise add.
        let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array::new(vec![4], vec![1.0, 1.0, 1.0, 1.0]);

        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            // WASM-SIMD specific smoke test.
            match crate::backend::cpu::simd::elementwise_simd(&a, &b, ElementwiseKind::Add) {
                Ok(result) => {
                    // Check the result is correct (1.0 + 1.0 == 2.0).
                    validation.simd_validated =
                        result.data.len() == 4 && (result.data[0] - 2.0).abs() < 0.001;
                }
                Err(_) => validation.simd_validated = false,
            }
        }

        // NOTE(review): when both cfgs are active this repeats the identical
        // test; harmless but redundant.
        #[cfg(numrs_kernel_elementwise_simd)]
        {
            match crate::backend::cpu::simd::elementwise_simd(&a, &b, ElementwiseKind::Add) {
                Ok(result) => {
                    // Check the result is correct (1.0 + 1.0 == 2.0).
                    validation.simd_validated =
                        result.data.len() == 4 && (result.data[0] - 2.0).abs() < 0.001;
                }
                Err(_) => validation.simd_validated = false,
            }
        }
    }

    // 2. BLAS validation.
    // CHANGE: always try to validate BLAS even if cfg says it is absent.
    // This works around feature-propagation issues between crates.
    validation.blas_available = cfg!(numrs_has_blas);

    // Aggressive attempt: call BLAS directly instead of trusting cfg alone.
    // NOTE(review): these two arrays are built even when `numrs_has_blas` is
    // off and then go unused — a tiny startup-time waste.
    let _a = Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    let _b = Array::new(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]);

    #[cfg(numrs_has_blas)]
    {
        // matmul_blas returns Array, not Result.
        let result = crate::backend::blas::matmul_blas(&_a, &_b);
        // Multiplying by the identity must reproduce `_a`.
        validation.blas_validated = result.data.len() == 4
            && (result.data[0] - 1.0).abs() < 0.001
            && (result.data[3] - 4.0).abs() < 0.001;

        if validation.blas_validated {
            validation.blas_available = true; // Force available when the test passed
        }
    }

    // 3. WebGPU validation (more involved - requires an async probe).
    #[cfg(target_arch = "wasm32")]
    {
        // IMPORTANT: Never auto-validate WebGPU on WASM at startup!
        // GPU operations require async JS coordination and will block if called synchronously.
        // pollster::block_on() doesn't work in WASM - it blocks the main thread forever.
        // WebGPU is only available after explicit async JS initialization (future work).
        validation.webgpu_available = false;
        validation.webgpu_validated = false;

        eprintln!("[numrs-dispatch] WebGPU disabled for WASM (async arch required for GPU ops)");
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        validation.webgpu_available = cfg!(numrs_kernel_elementwise_gpu);

        if validation.webgpu_available {
            // Use the existing (cached) probe.
            validation.webgpu_validated = crate::backend::webgpu::is_available_cached();

            // If the probe says it is available, a real operation would be the
            // definitive test.
            if validation.webgpu_validated {
                // TODO: add a functional WebGPU test once the backend is complete.
                // For now we trust the probe.
                #[cfg(debug_assertions)]
                eprintln!("[numrs-dispatch] WebGPU detected and validated via probe");
            }
        }
    }

    // 4. Metal validation (macOS only).
    validation.metal_available = cfg!(target_os = "macos");

    if validation.metal_available {
        // Use the (cached) Metal probe.
        validation.metal_validated = crate::backend::metal::is_available_cached();

        if validation.metal_validated {
            // NOTE(review): printed unconditionally, unlike the WebGPU message
            // above which is gated on debug_assertions — confirm if intended.
            eprintln!("[numrs-dispatch] Metal detected and validated via probe");
        }
    }

    // 5. Generic GPU (CUDA placeholder).
    validation.gpu_available = cfg!(numrs_kernel_matmul_gpu);
    // GPU validation pending until CUDA is implemented.
    validation.gpu_validated = false;

    validation
}
263
264// ============================================================================
265// Kernel Selection - Elige la mejor implementación por operación
266// ============================================================================
267
268/// Estrategia de selección basada en validation + benchmarks opcionales
269pub fn select_kernels(validation: &BackendValidation) -> DispatchTable {
270    // Definir todas las implementaciones disponibles
271
272    // --- MATMUL ---
273    // SIEMPRE usar kernel adaptativo que decide basado en tamaño
274    // Si probing está disabled, el kernel usa heurística estática
275    // Si probing está enabled, consulta la lookup table con microbenchmarks
276    let (matmul, mm_backend) = (kernel_matmul_adaptive as MatmulFn, "adaptive");
277
278    // --- ELEMENTWISE ---
279    // SIEMPRE adaptativo
280    let (elementwise, elem_backend) = (kernel_elementwise_adaptive as ElementwiseFn, "adaptive");
281
282    // --- REDUCTION ---
283    // SIEMPRE adaptativo
284    let (reduction, red_backend) = (kernel_reduction_adaptive as ReductionFn, "adaptive");
285
286    // --- DOT PRODUCT ---
287    let (dot, dot_backend) = {
288        #[cfg(feature = "blas-backend")]
289        {
290            if validation.blas_validated {
291                // BLAS sdot es la mejor opción (5-10x más rápido)
292                (kernel_dot_blas as DotFn, "blas")
293            } else if validation.simd_validated {
294                // SIMD con FMA (2-3x más rápido)
295                (kernel_dot_simd as DotFn, "cpu-simd")
296            } else {
297                // Scalar fallback
298                (kernel_dot_scalar as DotFn, "cpu-scalar")
299            }
300        }
301        #[cfg(not(feature = "blas-backend"))]
302        {
303            if validation.simd_validated {
304                // SIMD con FMA (2-3x más rápido)
305                (kernel_dot_simd as DotFn, "cpu-simd")
306            } else {
307                // Scalar fallback
308                (kernel_dot_scalar as DotFn, "cpu-scalar")
309            }
310        }
311    };
312
313    // Si probing está habilitado, refinar selección con microbenchmarks
314    // Si NO, inicializar lookup tables con heurística estática
315    let config = crate::backend::microbench::BenchConfig::from_env();
316
317    if config.enabled {
318        let (
319            elementwise,
320            elem_backend,
321            reduction,
322            red_backend,
323            matmul,
324            mm_backend,
325            dot,
326            dot_backend,
327        ) = refine_with_probing(
328            validation,
329            elementwise,
330            elem_backend,
331            reduction,
332            red_backend,
333            matmul,
334            mm_backend,
335            dot,
336            dot_backend,
337        );
338
339        DispatchTable {
340            elementwise,
341            reduction,
342            matmul,
343            dot,
344            elementwise_backend: elem_backend,
345            reduction_backend: red_backend,
346            matmul_backend: mm_backend,
347            dot_backend,
348        }
349    } else {
350        // Inicializar lookup tables con heurística (sin microbenchmarks)
351        #[cfg(debug_assertions)]
352        eprintln!("[numrs-dispatch] Initializing adaptive lookup tables (heuristic mode)");
353
354        let matmul_table = crate::backend::microbench::benchmark_matmul(validation, &config);
355        let elem_table = crate::backend::microbench::benchmark_elementwise(validation, &config);
356        let red_table = crate::backend::microbench::benchmark_reduction(validation, &config);
357
358        let _ = MATMUL_LOOKUP.set(matmul_table);
359        let _ = ELEMENTWISE_LOOKUP.set(elem_table);
360        let _ = REDUCTION_LOOKUP.set(red_table);
361
362        DispatchTable {
363            elementwise,
364            reduction,
365            matmul,
366            dot,
367            elementwise_backend: elem_backend,
368            reduction_backend: red_backend,
369            matmul_backend: mm_backend,
370            dot_backend,
371        }
372    }
373}
374
375/// Refina la selección ejecutando microbenchmarks entre candidatos
376/// Puebla las lookup tables globales para dispatch adaptativo
377#[allow(clippy::type_complexity)]
378#[allow(unused_variables)]
379fn refine_with_probing(
380    validation: &BackendValidation,
381    elementwise: ElementwiseFn,
382    elem_backend: &'static str,
383    reduction: ReductionFn,
384    red_backend: &'static str,
385    matmul: MatmulFn,
386    mm_backend: &'static str,
387    dot: DotFn,
388    dot_backend: &'static str,
389) -> (
390    ElementwiseFn,
391    &'static str,
392    ReductionFn,
393    &'static str,
394    MatmulFn,
395    &'static str,
396    DotFn,
397    &'static str,
398) {
399    eprintln!("[numrs-dispatch] Running microbenchmarks for adaptive kernel selection...");
400
401    let config = crate::backend::microbench::BenchConfig::from_env();
402
403    // Ejecutar microbenchmarks y crear lookup tables
404    let matmul_table = crate::backend::microbench::benchmark_matmul(validation, &config);
405    let elem_table = crate::backend::microbench::benchmark_elementwise(validation, &config);
406    let red_table = crate::backend::microbench::benchmark_reduction(validation, &config);
407
408    // Guardar las lookup tables en globals
409    let _ = MATMUL_LOOKUP.set(matmul_table);
410    let _ = ELEMENTWISE_LOOKUP.set(elem_table);
411    let _ = REDUCTION_LOOKUP.set(red_table);
412
413    eprintln!("[numrs-dispatch] Adaptive lookup tables created");
414    eprintln!("[numrs-dispatch] Kernels will select backend dynamically based on input size");
415
416    // Usar kernels adaptativos que consultan las lookup tables
417    (
418        kernel_elementwise_adaptive as ElementwiseFn,
419        "adaptive",
420        kernel_reduction_adaptive as ReductionFn,
421        "adaptive",
422        kernel_matmul_adaptive as MatmulFn,
423        "adaptive",
424        dot,
425        dot_backend,
426    )
427}
428
429// ============================================================================
430// Adaptive Kernels - Consultan lookup tables en runtime
431// ============================================================================
432
433/// Kernel adaptativo de matmul - consulta MATMUL_LOOKUP y llama al kernel apropiado
434fn kernel_matmul_adaptive(a: &Array, b: &Array) -> Result<Array> {
435    // Check for backend override (for benchmarking)
436    if let Ok(guard) = BACKEND_OVERRIDE.read() {
437        if let Some(backend) = *guard {
438            return match backend {
439                "scalar" => kernel_matmul_scalar(a, b),
440                "simd" => kernel_matmul_simd(a, b),
441                "blas" => kernel_matmul_blas_direct(a, b),
442                "webgpu" => kernel_matmul_webgpu(a, b),
443                "metal" => kernel_matmul_metal(a, b),
444                _ => kernel_matmul_blas_direct(a, b),
445            };
446        }
447    }
448
449    let size = a.shape[0] * b.shape[1]; // output matrix size
450
451    if let Some(lookup) = MATMUL_LOOKUP.get() {
452        // Usar lookup table del microbenchmark
453        let kernel = lookup.lookup(size);
454        return kernel(a, b);
455    }
456
457    // Fallback si no hay lookup table (no debería pasar)
458    kernel_matmul_blas_direct(a, b)
459}
460
461/// Kernel adaptativo de elementwise - consulta ELEMENTWISE_LOOKUP
462fn kernel_elementwise_adaptive(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
463    // Check for backend override (for benchmarking)
464    if let Ok(guard) = BACKEND_OVERRIDE.read() {
465        if let Some(backend) = *guard {
466            return match backend {
467                "scalar" => kernel_elementwise_scalar(a, b, kind),
468                "simd" => kernel_elementwise_simd(a, b, kind),
469                "webgpu" => kernel_elementwise_webgpu(a, b, kind),
470                "metal" => kernel_elementwise_metal(a, b, kind),
471                _ => kernel_elementwise_simd(a, b, kind),
472            };
473        }
474    }
475
476    let size = a.data.len();
477
478    if let Some(lookup) = ELEMENTWISE_LOOKUP.get() {
479        let kernel = lookup.lookup(size);
480        return kernel(a, b, kind);
481    }
482
483    // Fallback
484    kernel_elementwise_simd(a, b, kind)
485}
486
487/// Kernel adaptativo de reduction - consulta REDUCTION_LOOKUP
488fn kernel_reduction_adaptive(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
489    // Check for backend override (for benchmarking)
490    if let Ok(guard) = BACKEND_OVERRIDE.read() {
491        if let Some(backend) = *guard {
492            return match backend {
493                "scalar" => kernel_reduction_scalar(a, axis, kind),
494                "simd" => kernel_reduction_simd(a, axis, kind),
495                "blas" => kernel_reduction_blas(a, axis, kind),
496                _ => kernel_reduction_simd(a, axis, kind),
497            };
498        }
499    }
500
501    let size = a.data.len();
502
503    if let Some(lookup) = REDUCTION_LOOKUP.get() {
504        let kernel = lookup.lookup(size);
505        return kernel(a, axis, kind);
506    }
507
508    // Fallback
509    kernel_reduction_simd(a, axis, kind)
510}
511
512// ============================================================================
513// Kernel Implementations - Wrappers adaptativos que deciden internamente
514// ============================================================================
515
/// Elementwise kernel backed by Metal (macOS only).
///
/// On non-macOS builds this degrades to the WebGPU wrapper (which itself
/// falls back further when GPU support is not compiled in).
fn kernel_elementwise_metal(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        crate::backend::metal::elementwise_metal(a, b, kind)
    }

    #[cfg(not(target_os = "macos"))]
    {
        // Fallback when not building for macOS.
        kernel_elementwise_webgpu(a, b, kind)
    }
}
529
/// Elementwise kernel backed by WebGPU.
///
/// Falls back to the scalar kernel when GPU support is not compiled in.
fn kernel_elementwise_webgpu(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(numrs_kernel_elementwise_gpu)]
    {
        crate::backend::webgpu::elementwise_webgpu(a, b, kind)
    }

    #[cfg(not(numrs_kernel_elementwise_gpu))]
    {
        // Fallback when the GPU kernel is not compiled in.
        kernel_elementwise_scalar(a, b, kind)
    }
}
543
/// Elementwise kernel backed by CPU SIMD; scalar fallback when SIMD is not
/// compiled in.
pub fn kernel_elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(numrs_kernel_elementwise_simd)]
    {
        crate::backend::cpu::simd::elementwise_simd(a, b, kind)
    }

    #[cfg(not(numrs_kernel_elementwise_simd))]
    {
        kernel_elementwise_scalar(a, b, kind)
    }
}
556
/// Elementwise kernel backed by the scalar CPU implementation (always available).
fn kernel_elementwise_scalar(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    crate::backend::cpu::scalar::elementwise_scalar(a, b, kind)
}
561
/// Reduction kernel nominally backed by BLAS.
///
/// Currently BOTH cfg branches delegate to the SIMD wrapper: the BLAS reduce
/// is not implemented yet (see TODO), so this exists only so the "blas"
/// override name resolves to something.
fn kernel_reduction_blas(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    #[cfg(numrs_has_blas)]
    {
        // TODO: implement reduce with BLAS when available.
        // For now fall back to SIMD.
        kernel_reduction_simd(a, axis, kind)
    }

    #[cfg(not(numrs_has_blas))]
    {
        kernel_reduction_simd(a, axis, kind)
    }
}
576
/// Reduction kernel backed by CPU SIMD; scalar fallback when not compiled in.
pub fn kernel_reduction_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    #[cfg(numrs_kernel_sum_simd)]
    {
        crate::backend::cpu::simd::reduce_simd(a, axis, kind)
    }

    #[cfg(not(numrs_kernel_sum_simd))]
    {
        kernel_reduction_scalar(a, axis, kind)
    }
}
589
/// Reduction kernel backed by the scalar CPU implementation (always available).
fn kernel_reduction_scalar(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    crate::backend::cpu::scalar::reduce_scalar(a, axis, kind)
}
594
/// Matmul kernel calling BLAS directly (no internal decision-making);
/// SIMD wrapper as fallback when BLAS is not compiled in.
pub fn kernel_matmul_blas_direct(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_has_blas)]
    {
        // BLAS/MKL already has optimized internal parallelization.
        Ok(crate::backend::blas::matmul_blas(a, b))
    }

    #[cfg(not(numrs_has_blas))]
    {
        kernel_matmul_simd(a, b)
    }
}
608
/// Matmul kernel backed by Metal (macOS only); SIMD wrapper as fallback
/// elsewhere.
pub fn kernel_matmul_metal(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        crate::backend::metal::matmul_metal(a, b)
    }

    #[cfg(not(target_os = "macos"))]
    {
        // Not macOS: try SIMD before dropping to scalar.
        kernel_matmul_simd(a, b)
    }
}
622
/// Matmul kernel backed by WebGPU; SIMD wrapper as fallback when the GPU
/// kernel is not compiled in.
pub fn kernel_matmul_webgpu(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_kernel_matmul_gpu)]
    {
        Ok(crate::backend::webgpu::matmul_webgpu(a, b))
    }

    #[cfg(not(numrs_kernel_matmul_gpu))]
    {
        // Fall back to SIMD first if available.
        kernel_matmul_simd(a, b)
    }
}
636
/// Matmul kernel backed by CPU SIMD; scalar fallback when not compiled in.
pub fn kernel_matmul_simd(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_kernel_matmul_simd)]
    {
        Ok(crate::backend::cpu::simd::matmul_simd(a, b))
    }

    #[cfg(not(numrs_kernel_matmul_simd))]
    {
        kernel_matmul_scalar(a, b)
    }
}
649
/// Matmul kernel backed by the scalar CPU implementation (always available).
pub fn kernel_matmul_scalar(a: &Array, b: &Array) -> Result<Array> {
    // Parallel scalar implementation (Rayon-based, but no SIMD).
    Ok(crate::backend::cpu::matmul_scalar_parallel(a, b))
}
655
656// --- DOT PRODUCT KERNELS ---
657
/// Dot-product kernel backed by BLAS sdot (maximum throughput).
#[cfg(feature = "blas-backend")]
fn kernel_dot_blas(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::blas::dot_blas(a, b)
}
663
/// Dot-product kernel backed by CPU SIMD with FMA.
fn kernel_dot_simd(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::cpu::simd::dot_simd(a, b)
}
668
/// Dot-product kernel backed by the scalar CPU implementation (always available).
fn kernel_dot_scalar(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::cpu::scalar::dot_scalar(a, b)
}
673
674// ============================================================================
675// Public API - Inicialización y acceso al dispatch table
676// ============================================================================
677
678/// Inicializa el dispatch table (llamar al startup)
679pub fn init_dispatch_table() -> &'static DispatchTable {
680    DISPATCH_TABLE.get_or_init(|| {
681        #[cfg(debug_assertions)]
682        eprintln!("[numrs-dispatch] Initializing dispatch table...");
683
684        // 1. Validar backends
685        let validation = validate_backends();
686        #[cfg(debug_assertions)]
687        eprintln!("[numrs-dispatch] Validation results: {:?}", validation);
688
689        // 2. Seleccionar kernels
690        let table = select_kernels(&validation);
691
692        #[cfg(debug_assertions)]
693        {
694            eprintln!("[numrs-dispatch] Selected kernels:");
695            eprintln!("  - elementwise: {}", table.elementwise_backend);
696            eprintln!("  - reduction:   {}", table.reduction_backend);
697            eprintln!(
698                "  - matmul:      {} (validates: blas={}, metal={}, webgpu={}, simd={})",
699                table.matmul_backend,
700                validation.blas_validated,
701                validation.metal_validated,
702                validation.webgpu_validated,
703                validation.simd_validated
704            );
705            eprintln!("  - dot:         {}", table.dot_backend);
706        }
707
708        table
709    })
710}
711
712/// Obtiene el dispatch table (inicializa si es necesario)
713pub fn get_dispatch_table() -> &'static DispatchTable {
714    DISPATCH_TABLE.get_or_init(|| {
715        // Validar backends
716        let validation = validate_backends();
717
718        // Seleccionar kernels (internamente decide si hacer probing o no)
719        select_kernels(&validation)
720    })
721}
722
/// Force reinitialize the dispatch table (WASM only - for WebGPU JS integration).
///
/// Intended to be called after JavaScript finishes the async WebGPU setup, so
/// the next `get_dispatch_table()` rebuilds the table with WebGPU available.
///
/// NOTE(review): casting `&DISPATCH_TABLE` to `*mut` and mutating through it is
/// undefined behavior — mutation of a non-`mut` static through a shared
/// reference — regardless of WASM being single-threaded (the SAFETY comment
/// below argues about data races, which is not the relevant invariant).
/// Consider replacing the `OnceCell` static with an interior-mutable slot
/// (e.g. `RwLock<Option<DispatchTable>>`) so this can be done soundly.
#[cfg(target_arch = "wasm32")]
pub fn force_reinitialize_dispatch() {
    // For WASM, we need to recreate the dispatch table after JavaScript
    // has initialized WebGPU and set the availability flag
    unsafe {
        // SAFETY: In WASM there's no true multi-threading, so this is safe
        // We're using this to allow JavaScript to signal WebGPU is ready
        let ptr = &DISPATCH_TABLE as *const OnceCell<DispatchTable> as *mut OnceCell<DispatchTable>;
        (*ptr).take(); // Clear the existing table
    }
    // Next call to get_dispatch_table() will reinitialize with WebGPU available
}
736
/// Re-export for public use.
738pub use get_dispatch_table as table;
739
740/// Forzar un backend específico (para benchmarking)
741/// Valores válidos: "scalar", "simd", "blas", "webgpu", "metal"
742/// Usar `None` para restaurar comportamiento adaptativo
743pub fn set_backend_override(backend: Option<&'static str>) {
744    if let Ok(mut guard) = BACKEND_OVERRIDE.write() {
745        *guard = backend;
746    }
747}
748
749/// Obtener el backend override actual
750pub fn get_backend_override() -> Option<&'static str> {
751    BACKEND_OVERRIDE.read().ok().and_then(|guard| *guard)
752}
753
754// ============================================================================
755// Generic Dispatch Functions - Despacho basado en tipo en runtime
756// ============================================================================
757
758/// Dispatch genérico para operaciones elementwise que mantiene el tipo nativo
759///
760/// Esta función hace dispatch basado en el tipo T en runtime, llamando al
761/// kernel apropiado y manteniendo los datos en su tipo nativo (Vec<T>).
762///
763/// NOTA: Devuelve Array<f32> con dtype configurado para mantener compatibilidad
764/// con la API existente. Los datos se convierten al final.
765#[inline]
766pub fn dispatch_elementwise_generic<T>(
767    a: &Array<T>,
768    b: &Array<T>,
769    kind: ElementwiseKind,
770) -> Result<Array<T>>
771where
772    T: crate::array::DTypeValue,
773{
774    use std::any::TypeId;
775
776    // **ZERO-COPY OPTIMIZATION**: Materializar solo si necesario
777    // Los kernels CPU pueden trabajar con strides directamente
778    // Los kernels GPU/SIMD necesitan arrays contiguos
779    let needs_contiguous = should_materialize_for_backend(a, b);
780
781    let a_ref = if needs_contiguous && !a.is_contiguous() {
782        &a.to_contiguous()
783    } else {
784        a
785    };
786
787    let b_ref = if needs_contiguous && !b.is_contiguous() {
788        &b.to_contiguous()
789    } else {
790        b
791    };
792
793    // Despacho basado en tipo
794    if TypeId::of::<T>() == TypeId::of::<f32>() {
795        // f32: usar dispatch table legacy
796        let a_f32 = unsafe { &*(a_ref as *const Array<T> as *const Array<f32>) };
797        let b_f32 = unsafe { &*(b_ref as *const Array<T> as *const Array<f32>) };
798        let table = get_dispatch_table();
799        let result = (table.elementwise)(a_f32, b_f32, kind)?;
800        return Ok(unsafe { std::mem::transmute::<Array<f32>, Array<T>>(result) });
801    }
802
803    if TypeId::of::<T>() == TypeId::of::<f64>() {
804        // f64: ejecutar operación nativa en f64
805        let a_f64 = unsafe { &*(a_ref as *const Array<T> as *const Array<f64>) };
806        let b_f64 = unsafe { &*(b_ref as *const Array<T> as *const Array<f64>) };
807        let result_f64 = elementwise_f64_native(a_f64, b_f64, kind)?;
808        return Ok(unsafe { std::mem::transmute::<Array<f64>, Array<T>>(result_f64) });
809    }
810
811    if TypeId::of::<T>() == TypeId::of::<i32>() {
812        // i32: ejecutar operación nativa en i32
813        let a_i32 = unsafe { &*(a_ref as *const Array<T> as *const Array<i32>) };
814        let b_i32 = unsafe { &*(b_ref as *const Array<T> as *const Array<i32>) };
815        let result_i32 = elementwise_i32_native(a_i32, b_i32, kind)?;
816        return Ok(unsafe { std::mem::transmute::<Array<i32>, Array<T>>(result_i32) });
817    }
818
819    // Para otros tipos: fallback usando conversión a f32
820    let a_data: Vec<f32> = a_ref
821        .data
822        .iter()
823        .map(|&x| crate::array::DTypeValue::to_f32(x))
824        .collect();
825    let b_data: Vec<f32> = b_ref
826        .data
827        .iter()
828        .map(|&x| crate::array::DTypeValue::to_f32(x))
829        .collect();
830    let a_temp = Array::new(a_ref.shape.clone(), a_data);
831    let b_temp = Array::new(b_ref.shape.clone(), b_data);
832
833    let table = get_dispatch_table();
834    let result = (table.elementwise)(&a_temp, &b_temp, kind)?;
835    Ok(unsafe { std::mem::transmute::<Array<f32>, Array<T>>(result) })
836}
837
838/// Determina si los arrays deben materializarse para el backend actual
839///
840/// **FILOSOFÍA**: Materializar lo menos posible. Los kernels CPU stride-aware
841/// son suficientemente rápidos, y SIMD también puede trabajar con views.
842/// Solo materializar cuando GPU lo NECESITA (contiguous memory requirement).
843///
844/// Retorna true si:
845/// - Se va a usar GPU Y el array es MUY grande (>1M elementos)
846///
847/// Retorna false si:
848/// - CPU/SIMD: pueden trabajar con strides eficientemente
849/// - Arrays ya contiguos: no hay beneficio
850/// - Arrays pequeños/medianos: overhead de copia > beneficio
851#[inline]
852fn should_materialize_for_backend<T>(a: &Array<T>, b: &Array<T>) -> bool
853where
854    T: crate::array::DTypeValue,
855{
856    // Si ambos son contiguos, no hay nada que hacer
857    if a.is_contiguous() && b.is_contiguous() {
858        return false;
859    }
860
861    // GPU solo para arrays MUY grandes (>1M elementos)
862    #[cfg(feature = "webgpu")]
863    if crate::backend::webgpu::is_available_cached() {
864        let size: usize = a.shape.iter().product();
865        if size > 1_000_000 {
866            return true; // Materializar solo si es REALMENTE grande
867        }
868    }
869
870    // CPU/SIMD: trabajar con strides es eficiente, NO materializar
871    false
872}
873
874/// Operación elementwise nativa para f64 (sin conversión a f32)
875///
876/// **ZERO-COPY**: Esta función trabaja directamente con strides si están presentes,
877/// evitando materialización innecesaria para operaciones CPU.
878#[inline]
879fn elementwise_f64_native(
880    a: &Array<f64>,
881    b: &Array<f64>,
882    kind: ElementwiseKind,
883) -> Result<Array<f64>> {
884    let size: usize = a.shape.iter().product();
885    let mut result_data = Vec::with_capacity(size);
886
887    // Fast path: ambos arrays contiguos (sin strides)
888    if a.is_contiguous() && b.is_contiguous() {
889        match kind {
890            ElementwiseKind::Add => {
891                for i in 0..a.data.len() {
892                    result_data.push(a.data[i] + b.data[i]);
893                }
894            }
895            ElementwiseKind::Sub => {
896                for i in 0..a.data.len() {
897                    result_data.push(a.data[i] - b.data[i]);
898                }
899            }
900            ElementwiseKind::Mul => {
901                for i in 0..a.data.len() {
902                    result_data.push(a.data[i] * b.data[i]);
903                }
904            }
905            ElementwiseKind::Div => {
906                for i in 0..a.data.len() {
907                    result_data.push(a.data[i] / b.data[i]);
908                }
909            }
910            ElementwiseKind::Pow => {
911                for i in 0..a.data.len() {
912                    result_data.push(a.data[i].powf(b.data[i]));
913                }
914            }
915            _ => anyhow::bail!("Unsupported elementwise operation for f64: {:?}", kind),
916        }
917    } else {
918        // Stride-aware path: indexación con strides
919        let a_strides = a.get_strides();
920        let b_strides = b.get_strides();
921
922        let mut indices = vec![0usize; a.shape.len()];
923
924        for _ in 0..size {
925            // Calcular offsets con strides
926            let mut a_idx = a.offset as isize;
927            let mut b_idx = b.offset as isize;
928            for (i, &idx) in indices.iter().enumerate() {
929                a_idx += idx as isize * a_strides[i];
930                b_idx += idx as isize * b_strides[i];
931            }
932
933            // Bounds check
934            let a_idx_u = (a_idx as usize).min(a.data.len().saturating_sub(1));
935            let b_idx_u = (b_idx as usize).min(b.data.len().saturating_sub(1));
936
937            let val = match kind {
938                ElementwiseKind::Add => a.data[a_idx_u] + b.data[b_idx_u],
939                ElementwiseKind::Sub => a.data[a_idx_u] - b.data[b_idx_u],
940                ElementwiseKind::Mul => a.data[a_idx_u] * b.data[b_idx_u],
941                ElementwiseKind::Div => a.data[a_idx_u] / b.data[b_idx_u],
942                ElementwiseKind::Pow => a.data[a_idx_u].powf(b.data[b_idx_u]),
943                _ => anyhow::bail!("Unsupported elementwise operation for f64: {:?}", kind),
944            };
945            result_data.push(val);
946
947            // Incrementar índices (orden C)
948            for i in (0..a.shape.len()).rev() {
949                indices[i] += 1;
950                if indices[i] < a.shape[i] {
951                    break;
952                }
953                indices[i] = 0;
954            }
955        }
956    }
957
958    let mut result = Array::new(a.shape.clone(), result_data);
959    result.dtype = crate::array::DType::F64;
960    Ok(result)
961}
962
963/// Operación elementwise nativa para i32 (sin conversión a f32)
964///
965/// **ZERO-COPY**: Esta función trabaja directamente con strides si están presentes,
966/// evitando materialización innecesaria para operaciones CPU.
967#[inline]
968fn elementwise_i32_native(
969    a: &Array<i32>,
970    b: &Array<i32>,
971    kind: ElementwiseKind,
972) -> Result<Array<i32>> {
973    let size: usize = a.shape.iter().product();
974    let mut result_data = Vec::with_capacity(size);
975
976    // Fast path: ambos arrays contiguos (sin strides)
977    if a.is_contiguous() && b.is_contiguous() {
978        match kind {
979            ElementwiseKind::Add => {
980                for i in 0..a.data.len() {
981                    result_data.push(a.data[i] + b.data[i]);
982                }
983            }
984            ElementwiseKind::Sub => {
985                for i in 0..a.data.len() {
986                    result_data.push(a.data[i] - b.data[i]);
987                }
988            }
989            ElementwiseKind::Mul => {
990                for i in 0..a.data.len() {
991                    result_data.push(a.data[i] * b.data[i]);
992                }
993            }
994            ElementwiseKind::Div => {
995                for i in 0..a.data.len() {
996                    result_data.push(a.data[i] / b.data[i]);
997                }
998            }
999            ElementwiseKind::Pow => {
1000                for i in 0..a.data.len() {
1001                    result_data.push(a.data[i].pow(b.data[i] as u32));
1002                }
1003            }
1004            _ => anyhow::bail!("Unsupported elementwise operation for i32: {:?}", kind),
1005        }
1006    } else {
1007        // Stride-aware path: indexación con strides
1008        let a_strides = a.get_strides();
1009        let b_strides = b.get_strides();
1010
1011        let mut indices = vec![0usize; a.shape.len()];
1012
1013        for _ in 0..size {
1014            // Calcular offsets con strides
1015            let mut a_idx = a.offset as isize;
1016            let mut b_idx = b.offset as isize;
1017            for (i, &idx) in indices.iter().enumerate() {
1018                a_idx += idx as isize * a_strides[i];
1019                b_idx += idx as isize * b_strides[i];
1020            }
1021
1022            // Bounds check
1023            let a_idx_u = (a_idx as usize).min(a.data.len().saturating_sub(1));
1024            let b_idx_u = (b_idx as usize).min(b.data.len().saturating_sub(1));
1025
1026            let val = match kind {
1027                ElementwiseKind::Add => a.data[a_idx_u] + b.data[b_idx_u],
1028                ElementwiseKind::Sub => a.data[a_idx_u] - b.data[b_idx_u],
1029                ElementwiseKind::Mul => a.data[a_idx_u] * b.data[b_idx_u],
1030                ElementwiseKind::Div => a.data[a_idx_u] / b.data[b_idx_u],
1031                ElementwiseKind::Pow => a.data[a_idx_u].pow(b.data[b_idx_u] as u32),
1032                _ => anyhow::bail!("Unsupported elementwise operation for i32: {:?}", kind),
1033            };
1034            result_data.push(val);
1035
1036            // Incrementar índices (orden C)
1037            for i in (0..a.shape.len()).rev() {
1038                indices[i] += 1;
1039                if indices[i] < a.shape[i] {
1040                    break;
1041                }
1042                indices[i] = 0;
1043            }
1044        }
1045    }
1046
1047    let mut result = Array::new(a.shape.clone(), result_data);
1048    result.dtype = crate::array::DType::I32;
1049    Ok(result)
1050}
1051
1052// ============================================================================
1053
1054// ============================================================================
1055// Tests
1056// ============================================================================
1057
#[cfg(test)]
mod tests {
    use super::*;

    /// The dispatch table must come up with a named backend per operation.
    #[test]
    fn test_dispatch_table_initialization() {
        let table = init_dispatch_table();

        // Every backend label must be populated after initialization.
        assert!(!table.elementwise_backend.is_empty());
        assert!(!table.reduction_backend.is_empty());
        assert!(!table.matmul_backend.is_empty());

        println!("Dispatch table: {:?}", table);
    }

    /// Backend validation must complete without panicking on any host.
    #[test]
    fn test_backend_validation() {
        let validation = validate_backends();

        println!("Backend validation: {:?}", validation);

        // The scalar backend is always available, so there is no backend
        // combination to assert on. The previous `... || true` assertion was
        // vacuous; instead verify the report is Debug-renderable, which is
        // what this smoke test can honestly check.
        let rendered = format!("{:?}", validation);
        assert!(!rendered.is_empty());
    }

    /// End-to-end: an elementwise Add through the dispatch table.
    #[test]
    fn test_elementwise_dispatch() {
        let table = get_dispatch_table();

        let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array::new(vec![4], vec![1.0, 1.0, 1.0, 1.0]);

        let result = (table.elementwise)(&a, &b, ElementwiseKind::Add)
            .expect("elementwise Add via dispatch table should succeed");
        assert_eq!(result.data, vec![2.0, 3.0, 4.0, 5.0]);

        println!(
            "Elementwise test passed using: {}",
            table.elementwise_backend
        );
    }
}