1use crate::array::Array;
24use crate::llo::reduction::ReductionKind;
25use crate::llo::ElementwiseKind;
26use anyhow::Result;
27use once_cell::sync::OnceCell;
28use std::fmt;
29use std::sync::RwLock;
30
/// Capability flags for the current runtime environment.
///
/// NOTE(review): no constructor for this struct is visible in this chunk;
/// presumably it is populated by platform-detection code elsewhere — verify
/// against the rest of the crate.
#[derive(Debug, Clone, Copy)]
pub struct RuntimeCapabilities {
    pub has_simd: bool,      // native CPU SIMD
    pub has_gpu: bool,       // GPU compute backend
    pub has_blas: bool,      // linked BLAS library
    pub has_threads: bool,   // multithreading support
    pub has_wasm_simd: bool, // WASM simd128
    pub has_webgpu: bool,    // WebGPU backend
}
45
/// Binary elementwise kernel: combines two arrays element-by-element
/// according to the given `ElementwiseKind`.
pub type ElementwiseFn = fn(&Array, &Array, ElementwiseKind) -> Result<Array>;

/// Reduction kernel over an optional axis (axis semantics are defined by the
/// backend implementations).
pub type ReductionFn = fn(&Array, Option<usize>, ReductionKind) -> Result<Array>;

/// Matrix-multiplication kernel.
pub type MatmulFn = fn(&Array, &Array) -> Result<Array>;

/// Dot-product kernel returning a scalar.
pub type DotFn = fn(&Array, &Array) -> Result<f32>;
61
/// Function pointers for the selected kernel of each operation family, plus
/// human-readable backend names for logging and diagnostics.
///
/// `Debug` is implemented manually because function pointers have no
/// meaningful debug representation.
#[derive(Clone, Copy)]
pub struct DispatchTable {
    // Selected kernels.
    pub elementwise: ElementwiseFn,

    pub reduction: ReductionFn,

    pub matmul: MatmulFn,

    pub dot: DotFn,

    // Backend names matching the kernels above (e.g. "adaptive", "blas",
    // "cpu-simd").
    pub elementwise_backend: &'static str,

    pub reduction_backend: &'static str,

    pub matmul_backend: &'static str,

    pub dot_backend: &'static str,
}
94
impl fmt::Debug for DispatchTable {
    // Only the backend-name fields are shown; the function pointers carry no
    // useful debug information.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DispatchTable")
            .field("elementwise_backend", &self.elementwise_backend)
            .field("reduction_backend", &self.reduction_backend)
            .field("matmul_backend", &self.matmul_backend)
            .field("dot_backend", &self.dot_backend)
            .finish()
    }
}
105
// Process-wide dispatch table, built once on first access.
static DISPATCH_TABLE: OnceCell<DispatchTable> = OnceCell::new();

// Per-operation adaptive lookup tables mapping input size to the preferred
// kernel; populated in `select_kernels` / `refine_with_probing`.
static MATMUL_LOOKUP: OnceCell<crate::backend::microbench::AdaptiveLookupTable<MatmulFn>> =
    OnceCell::new();
static ELEMENTWISE_LOOKUP: OnceCell<
    crate::backend::microbench::AdaptiveLookupTable<ElementwiseFn>,
> = OnceCell::new();
static REDUCTION_LOOKUP: OnceCell<crate::backend::microbench::AdaptiveLookupTable<ReductionFn>> =
    OnceCell::new();

// Optional user-forced backend name (e.g. "simd", "blas"); the adaptive
// kernels consult this before their lookup tables.
static BACKEND_OVERRIDE: RwLock<Option<&'static str>> = RwLock::new(None);
120
/// Result of probing each backend.
///
/// `*_available` means the backend is compiled in / supported by the platform;
/// `*_validated` means a small known-answer smoke test also succeeded.
#[derive(Debug, Clone)]
pub struct BackendValidation {
    pub simd_available: bool,
    pub simd_validated: bool,
    pub blas_available: bool,
    pub blas_validated: bool,
    pub gpu_available: bool,
    pub gpu_validated: bool,
    pub webgpu_available: bool,
    pub webgpu_validated: bool,
    pub metal_available: bool,
    pub metal_validated: bool,
}
139
/// Probe every compiled-in backend and record, per backend, whether it is
/// available and whether a tiny known-answer computation against it succeeded.
pub fn validate_backends() -> BackendValidation {
    // Start pessimistic: everything off until proven otherwise.
    let mut validation = BackendValidation {
        simd_available: false,
        simd_validated: false,
        blas_available: false,
        blas_validated: false,
        gpu_available: false,
        gpu_validated: false,
        webgpu_available: false,
        webgpu_validated: false,
        metal_available: false,
        metal_validated: false,
    };

    // SIMD counts as available if the custom `numrs_kernel_elementwise_simd`
    // cfg was set at build time or the runtime feature probe succeeds.
    validation.simd_available = cfg!(numrs_kernel_elementwise_simd)
        || crate::backend::cpu::simd::elementwise_simd_supported();

    if validation.simd_available {
        // Smoke test: [1,2,3,4] + [1,1,1,1] must start with 2.0.
        let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array::new(vec![4], vec![1.0, 1.0, 1.0, 1.0]);

        // NOTE(review): on a wasm32+simd128 build where
        // `numrs_kernel_elementwise_simd` is also set, both of these blocks
        // compile and the second result overwrites the first (they run the
        // same check, so this is redundant rather than wrong).
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            match crate::backend::cpu::simd::elementwise_simd(&a, &b, ElementwiseKind::Add) {
                Ok(result) => {
                    validation.simd_validated =
                        result.data.len() == 4 && (result.data[0] - 2.0).abs() < 0.001;
                }
                Err(_) => validation.simd_validated = false,
            }
        }

        #[cfg(numrs_kernel_elementwise_simd)]
        {
            match crate::backend::cpu::simd::elementwise_simd(&a, &b, ElementwiseKind::Add) {
                Ok(result) => {
                    validation.simd_validated =
                        result.data.len() == 4 && (result.data[0] - 2.0).abs() < 0.001;
                }
                Err(_) => validation.simd_validated = false,
            }
        }
    }

    validation.blas_available = cfg!(numrs_has_blas);

    // Underscore-prefixed: only read by the cfg(numrs_has_blas) block below.
    let _a = Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    let _b = Array::new(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]);

    #[cfg(numrs_has_blas)]
    {
        // Smoke test: A * I must reproduce A (spot-check corners).
        let result = crate::backend::blas::matmul_blas(&_a, &_b);
        validation.blas_validated = result.data.len() == 4
            && (result.data[0] - 1.0).abs() < 0.001
            && (result.data[3] - 4.0).abs() < 0.001;

        // NOTE(review): redundant — blas_available is already true in any
        // build where this cfg block compiles.
        if validation.blas_validated {
            validation.blas_available = true; }
    }

    #[cfg(target_arch = "wasm32")]
    {
        // WebGPU is deliberately disabled on WASM (see message below).
        validation.webgpu_available = false;
        validation.webgpu_validated = false;

        eprintln!("[numrs-dispatch] WebGPU disabled for WASM (async arch required for GPU ops)");
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        validation.webgpu_available = cfg!(numrs_kernel_elementwise_gpu);

        if validation.webgpu_available {
            // Runtime adapter probe, cached by the webgpu backend.
            validation.webgpu_validated = crate::backend::webgpu::is_available_cached();

            if validation.webgpu_validated {
                #[cfg(debug_assertions)]
                eprintln!("[numrs-dispatch] WebGPU detected and validated via probe");
            }
        }
    }

    validation.metal_available = cfg!(target_os = "macos");

    if validation.metal_available {
        // Runtime device probe, cached by the metal backend.
        validation.metal_validated = crate::backend::metal::is_available_cached();

        if validation.metal_validated {
            eprintln!("[numrs-dispatch] Metal detected and validated via probe");
        }
    }

    // GPU matmul is reported as available when compiled in, but is never
    // smoke-tested here, so gpu_validated is always false.
    validation.gpu_available = cfg!(numrs_kernel_matmul_gpu);
    validation.gpu_validated = false;

    validation
}
263
264pub fn select_kernels(validation: &BackendValidation) -> DispatchTable {
270 let (matmul, mm_backend) = (kernel_matmul_adaptive as MatmulFn, "adaptive");
277
278 let (elementwise, elem_backend) = (kernel_elementwise_adaptive as ElementwiseFn, "adaptive");
281
282 let (reduction, red_backend) = (kernel_reduction_adaptive as ReductionFn, "adaptive");
285
286 let (dot, dot_backend) = {
288 #[cfg(feature = "blas-backend")]
289 {
290 if validation.blas_validated {
291 (kernel_dot_blas as DotFn, "blas")
293 } else if validation.simd_validated {
294 (kernel_dot_simd as DotFn, "cpu-simd")
296 } else {
297 (kernel_dot_scalar as DotFn, "cpu-scalar")
299 }
300 }
301 #[cfg(not(feature = "blas-backend"))]
302 {
303 if validation.simd_validated {
304 (kernel_dot_simd as DotFn, "cpu-simd")
306 } else {
307 (kernel_dot_scalar as DotFn, "cpu-scalar")
309 }
310 }
311 };
312
313 let config = crate::backend::microbench::BenchConfig::from_env();
316
317 if config.enabled {
318 let (
319 elementwise,
320 elem_backend,
321 reduction,
322 red_backend,
323 matmul,
324 mm_backend,
325 dot,
326 dot_backend,
327 ) = refine_with_probing(
328 validation,
329 elementwise,
330 elem_backend,
331 reduction,
332 red_backend,
333 matmul,
334 mm_backend,
335 dot,
336 dot_backend,
337 );
338
339 DispatchTable {
340 elementwise,
341 reduction,
342 matmul,
343 dot,
344 elementwise_backend: elem_backend,
345 reduction_backend: red_backend,
346 matmul_backend: mm_backend,
347 dot_backend,
348 }
349 } else {
350 #[cfg(debug_assertions)]
352 eprintln!("[numrs-dispatch] Initializing adaptive lookup tables (heuristic mode)");
353
354 let matmul_table = crate::backend::microbench::benchmark_matmul(validation, &config);
355 let elem_table = crate::backend::microbench::benchmark_elementwise(validation, &config);
356 let red_table = crate::backend::microbench::benchmark_reduction(validation, &config);
357
358 let _ = MATMUL_LOOKUP.set(matmul_table);
359 let _ = ELEMENTWISE_LOOKUP.set(elem_table);
360 let _ = REDUCTION_LOOKUP.set(red_table);
361
362 DispatchTable {
363 elementwise,
364 reduction,
365 matmul,
366 dot,
367 elementwise_backend: elem_backend,
368 reduction_backend: red_backend,
369 matmul_backend: mm_backend,
370 dot_backend,
371 }
372 }
373}
374
/// Run microbenchmarks and install the resulting adaptive lookup tables.
///
/// Everything except the dot kernel is replaced by the adaptive kernels; the
/// incoming elementwise/reduction/matmul arguments are intentionally unused
/// (hence `#[allow(unused_variables)]`) and exist so the signature mirrors
/// the heuristic selection in `select_kernels`.
#[allow(clippy::type_complexity)]
#[allow(unused_variables)]
fn refine_with_probing(
    validation: &BackendValidation,
    elementwise: ElementwiseFn,
    elem_backend: &'static str,
    reduction: ReductionFn,
    red_backend: &'static str,
    matmul: MatmulFn,
    mm_backend: &'static str,
    dot: DotFn,
    dot_backend: &'static str,
) -> (
    ElementwiseFn,
    &'static str,
    ReductionFn,
    &'static str,
    MatmulFn,
    &'static str,
    DotFn,
    &'static str,
) {
    eprintln!("[numrs-dispatch] Running microbenchmarks for adaptive kernel selection...");

    let config = crate::backend::microbench::BenchConfig::from_env();

    // Benchmark each operation family across input sizes and backends.
    let matmul_table = crate::backend::microbench::benchmark_matmul(validation, &config);
    let elem_table = crate::backend::microbench::benchmark_elementwise(validation, &config);
    let red_table = crate::backend::microbench::benchmark_reduction(validation, &config);

    // `set` fails only if a table was already installed; either way one exists.
    let _ = MATMUL_LOOKUP.set(matmul_table);
    let _ = ELEMENTWISE_LOOKUP.set(elem_table);
    let _ = REDUCTION_LOOKUP.set(red_table);

    eprintln!("[numrs-dispatch] Adaptive lookup tables created");
    eprintln!("[numrs-dispatch] Kernels will select backend dynamically based on input size");

    // Dot has no adaptive kernel, so its selection passes through unchanged.
    (
        kernel_elementwise_adaptive as ElementwiseFn,
        "adaptive",
        kernel_reduction_adaptive as ReductionFn,
        "adaptive",
        kernel_matmul_adaptive as MatmulFn,
        "adaptive",
        dot,
        dot_backend,
    )
}
428
429fn kernel_matmul_adaptive(a: &Array, b: &Array) -> Result<Array> {
435 if let Ok(guard) = BACKEND_OVERRIDE.read() {
437 if let Some(backend) = *guard {
438 return match backend {
439 "scalar" => kernel_matmul_scalar(a, b),
440 "simd" => kernel_matmul_simd(a, b),
441 "blas" => kernel_matmul_blas_direct(a, b),
442 "webgpu" => kernel_matmul_webgpu(a, b),
443 "metal" => kernel_matmul_metal(a, b),
444 _ => kernel_matmul_blas_direct(a, b),
445 };
446 }
447 }
448
449 let size = a.shape[0] * b.shape[1]; if let Some(lookup) = MATMUL_LOOKUP.get() {
452 let kernel = lookup.lookup(size);
454 return kernel(a, b);
455 }
456
457 kernel_matmul_blas_direct(a, b)
459}
460
461fn kernel_elementwise_adaptive(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
463 if let Ok(guard) = BACKEND_OVERRIDE.read() {
465 if let Some(backend) = *guard {
466 return match backend {
467 "scalar" => kernel_elementwise_scalar(a, b, kind),
468 "simd" => kernel_elementwise_simd(a, b, kind),
469 "webgpu" => kernel_elementwise_webgpu(a, b, kind),
470 "metal" => kernel_elementwise_metal(a, b, kind),
471 _ => kernel_elementwise_simd(a, b, kind),
472 };
473 }
474 }
475
476 let size = a.data.len();
477
478 if let Some(lookup) = ELEMENTWISE_LOOKUP.get() {
479 let kernel = lookup.lookup(size);
480 return kernel(a, b, kind);
481 }
482
483 kernel_elementwise_simd(a, b, kind)
485}
486
487fn kernel_reduction_adaptive(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
489 if let Ok(guard) = BACKEND_OVERRIDE.read() {
491 if let Some(backend) = *guard {
492 return match backend {
493 "scalar" => kernel_reduction_scalar(a, axis, kind),
494 "simd" => kernel_reduction_simd(a, axis, kind),
495 "blas" => kernel_reduction_blas(a, axis, kind),
496 _ => kernel_reduction_simd(a, axis, kind),
497 };
498 }
499 }
500
501 let size = a.data.len();
502
503 if let Some(lookup) = REDUCTION_LOOKUP.get() {
504 let kernel = lookup.lookup(size);
505 return kernel(a, axis, kind);
506 }
507
508 kernel_reduction_simd(a, axis, kind)
510}
511
/// Metal elementwise kernel; on non-macOS builds falls through to WebGPU.
fn kernel_elementwise_metal(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        crate::backend::metal::elementwise_metal(a, b, kind)
    }

    #[cfg(not(target_os = "macos"))]
    {
        kernel_elementwise_webgpu(a, b, kind)
    }
}
529
/// WebGPU elementwise kernel; falls back to scalar when the GPU elementwise
/// kernel was not compiled in.
fn kernel_elementwise_webgpu(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(numrs_kernel_elementwise_gpu)]
    {
        crate::backend::webgpu::elementwise_webgpu(a, b, kind)
    }

    #[cfg(not(numrs_kernel_elementwise_gpu))]
    {
        kernel_elementwise_scalar(a, b, kind)
    }
}
543
/// CPU SIMD elementwise kernel; falls back to scalar when SIMD support was
/// not compiled in.
pub fn kernel_elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    #[cfg(numrs_kernel_elementwise_simd)]
    {
        crate::backend::cpu::simd::elementwise_simd(a, b, kind)
    }

    #[cfg(not(numrs_kernel_elementwise_simd))]
    {
        kernel_elementwise_scalar(a, b, kind)
    }
}
556
/// Portable scalar elementwise kernel — the universal fallback.
fn kernel_elementwise_scalar(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    crate::backend::cpu::scalar::elementwise_scalar(a, b, kind)
}
561
562fn kernel_reduction_blas(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
564 #[cfg(numrs_has_blas)]
565 {
566 kernel_reduction_simd(a, axis, kind)
569 }
570
571 #[cfg(not(numrs_has_blas))]
572 {
573 kernel_reduction_simd(a, axis, kind)
574 }
575}
576
/// CPU SIMD reduction kernel; falls back to scalar when SIMD reductions were
/// not compiled in.
pub fn kernel_reduction_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    #[cfg(numrs_kernel_sum_simd)]
    {
        crate::backend::cpu::simd::reduce_simd(a, axis, kind)
    }

    #[cfg(not(numrs_kernel_sum_simd))]
    {
        kernel_reduction_scalar(a, axis, kind)
    }
}
589
/// Portable scalar reduction kernel — the universal fallback.
fn kernel_reduction_scalar(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    crate::backend::cpu::scalar::reduce_scalar(a, axis, kind)
}
594
/// BLAS matmul when compiled in (the BLAS call is infallible, hence the `Ok`
/// wrap); otherwise falls back to the SIMD matmul.
pub fn kernel_matmul_blas_direct(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_has_blas)]
    {
        Ok(crate::backend::blas::matmul_blas(a, b))
    }

    #[cfg(not(numrs_has_blas))]
    {
        kernel_matmul_simd(a, b)
    }
}
608
/// Metal matmul kernel; on non-macOS builds falls back to the SIMD matmul.
pub fn kernel_matmul_metal(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        crate::backend::metal::matmul_metal(a, b)
    }

    #[cfg(not(target_os = "macos"))]
    {
        kernel_matmul_simd(a, b)
    }
}
622
/// WebGPU matmul kernel when compiled in (infallible, hence the `Ok` wrap);
/// otherwise falls back to the SIMD matmul.
pub fn kernel_matmul_webgpu(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_kernel_matmul_gpu)]
    {
        Ok(crate::backend::webgpu::matmul_webgpu(a, b))
    }

    #[cfg(not(numrs_kernel_matmul_gpu))]
    {
        kernel_matmul_simd(a, b)
    }
}
636
/// CPU SIMD matmul kernel; falls back to the scalar matmul when SIMD matmul
/// was not compiled in.
pub fn kernel_matmul_simd(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(numrs_kernel_matmul_simd)]
    {
        Ok(crate::backend::cpu::simd::matmul_simd(a, b))
    }

    #[cfg(not(numrs_kernel_matmul_simd))]
    {
        kernel_matmul_scalar(a, b)
    }
}
649
/// Scalar matmul — delegates to the parallel scalar implementation
/// (infallible, hence the `Ok` wrap).
pub fn kernel_matmul_scalar(a: &Array, b: &Array) -> Result<Array> {
    Ok(crate::backend::cpu::matmul_scalar_parallel(a, b))
}
655
/// BLAS dot product; only compiled with the `blas-backend` feature.
#[cfg(feature = "blas-backend")]
fn kernel_dot_blas(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::blas::dot_blas(a, b)
}
663
/// CPU SIMD dot product.
fn kernel_dot_simd(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::cpu::simd::dot_simd(a, b)
}
668
/// Portable scalar dot product — the universal fallback.
fn kernel_dot_scalar(a: &Array, b: &Array) -> Result<f32> {
    crate::backend::cpu::scalar::dot_scalar(a, b)
}
673
/// Build (once) and return the global dispatch table, with verbose logging in
/// debug builds. Subsequent calls return the cached table.
pub fn init_dispatch_table() -> &'static DispatchTable {
    DISPATCH_TABLE.get_or_init(|| {
        #[cfg(debug_assertions)]
        eprintln!("[numrs-dispatch] Initializing dispatch table...");

        // Probe all backends first; the results drive kernel selection.
        let validation = validate_backends();
        #[cfg(debug_assertions)]
        eprintln!("[numrs-dispatch] Validation results: {:?}", validation);

        let table = select_kernels(&validation);

        #[cfg(debug_assertions)]
        {
            eprintln!("[numrs-dispatch] Selected kernels:");
            eprintln!("  - elementwise: {}", table.elementwise_backend);
            eprintln!("  - reduction: {}", table.reduction_backend);
            eprintln!(
                "  - matmul: {} (validates: blas={}, metal={}, webgpu={}, simd={})",
                table.matmul_backend,
                validation.blas_validated,
                validation.metal_validated,
                validation.webgpu_validated,
                validation.simd_validated
            );
            eprintln!("  - dot: {}", table.dot_backend);
        }

        table
    })
}
711
712pub fn get_dispatch_table() -> &'static DispatchTable {
714 DISPATCH_TABLE.get_or_init(|| {
715 let validation = validate_backends();
717
718 select_kernels(&validation)
720 })
721}
722
/// Clear the cached dispatch table so the next access rebuilds it (WASM only).
///
/// WARNING(review): this casts a shared reference to a non-`mut` `static`
/// into a `*mut` and mutates through it, which is undefined behavior in Rust
/// regardless of target. It presumably "works" because wasm32 is
/// single-threaded, but consider replacing the OnceCell with e.g.
/// `RwLock<Option<DispatchTable>>` so reinitialization is sound.
#[cfg(target_arch = "wasm32")]
pub fn force_reinitialize_dispatch() {
    // SAFETY: (unsound — see warning above) relies on no concurrent access
    // to DISPATCH_TABLE while the cell is taken.
    unsafe {
        let ptr = &DISPATCH_TABLE as *const OnceCell<DispatchTable> as *mut OnceCell<DispatchTable>;
        (*ptr).take(); }
    }
736
737pub use get_dispatch_table as table;
739
740pub fn set_backend_override(backend: Option<&'static str>) {
744 if let Ok(mut guard) = BACKEND_OVERRIDE.write() {
745 *guard = backend;
746 }
747}
748
749pub fn get_backend_override() -> Option<&'static str> {
751 BACKEND_OVERRIDE.read().ok().and_then(|guard| *guard)
752}
753
/// Generic elementwise entry point: specializes by element type.
///
/// f32 routes through the dispatch table; f64 and i32 use native loops; every
/// other `DTypeValue` is converted to f32, computed via the dispatch table,
/// and the f32 result transmuted back to `Array<T>`.
///
/// NOTE(review): in that final fallback the result's buffer still holds f32
/// values when it is transmuted to `Array<T>` — confirm `Array<T>`'s
/// layout/dtype expectations make this sound for non-f32 `T`.
#[inline]
pub fn dispatch_elementwise_generic<T>(
    a: &Array<T>,
    b: &Array<T>,
    kind: ElementwiseKind,
) -> Result<Array<T>>
where
    T: crate::array::DTypeValue,
{
    use std::any::TypeId;

    // Large strided inputs headed for GPU backends get copied to contiguous
    // buffers first; see should_materialize_for_backend.
    let needs_contiguous = should_materialize_for_backend(a, b);

    // NOTE(review): `&a.to_contiguous()` borrows a temporary created inside an
    // `if` arm; this depends on temporary-lifetime rules extending the value
    // through the branch — verify it compiles on the project's MSRV.
    let a_ref = if needs_contiguous && !a.is_contiguous() {
        &a.to_contiguous()
    } else {
        a
    };

    let b_ref = if needs_contiguous && !b.is_contiguous() {
        &b.to_contiguous()
    } else {
        b
    };

    if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: guarded by the TypeId check, so `T` is exactly f32 and the
        // reference casts / transmute below are identity conversions.
        let a_f32 = unsafe { &*(a_ref as *const Array<T> as *const Array<f32>) };
        let b_f32 = unsafe { &*(b_ref as *const Array<T> as *const Array<f32>) };
        let table = get_dispatch_table();
        let result = (table.elementwise)(a_f32, b_f32, kind)?;
        return Ok(unsafe { std::mem::transmute::<Array<f32>, Array<T>>(result) });
    }

    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: guarded by the TypeId check (`T` == f64), as above.
        let a_f64 = unsafe { &*(a_ref as *const Array<T> as *const Array<f64>) };
        let b_f64 = unsafe { &*(b_ref as *const Array<T> as *const Array<f64>) };
        let result_f64 = elementwise_f64_native(a_f64, b_f64, kind)?;
        return Ok(unsafe { std::mem::transmute::<Array<f64>, Array<T>>(result_f64) });
    }

    if TypeId::of::<T>() == TypeId::of::<i32>() {
        // SAFETY: guarded by the TypeId check (`T` == i32), as above.
        let a_i32 = unsafe { &*(a_ref as *const Array<T> as *const Array<i32>) };
        let b_i32 = unsafe { &*(b_ref as *const Array<T> as *const Array<i32>) };
        let result_i32 = elementwise_i32_native(a_i32, b_i32, kind)?;
        return Ok(unsafe { std::mem::transmute::<Array<i32>, Array<T>>(result_i32) });
    }

    // Fallback for all remaining dtypes: convert both inputs to f32 and use
    // the f32 dispatch path.
    let a_data: Vec<f32> = a_ref
        .data
        .iter()
        .map(|&x| crate::array::DTypeValue::to_f32(x))
        .collect();
    let b_data: Vec<f32> = b_ref
        .data
        .iter()
        .map(|&x| crate::array::DTypeValue::to_f32(x))
        .collect();
    let a_temp = Array::new(a_ref.shape.clone(), a_data);
    let b_temp = Array::new(b_ref.shape.clone(), b_data);

    let table = get_dispatch_table();
    let result = (table.elementwise)(&a_temp, &b_temp, kind)?;
    Ok(unsafe { std::mem::transmute::<Array<f32>, Array<T>>(result) })
}
837
838#[inline]
852fn should_materialize_for_backend<T>(a: &Array<T>, b: &Array<T>) -> bool
853where
854 T: crate::array::DTypeValue,
855{
856 if a.is_contiguous() && b.is_contiguous() {
858 return false;
859 }
860
861 #[cfg(feature = "webgpu")]
863 if crate::backend::webgpu::is_available_cached() {
864 let size: usize = a.shape.iter().product();
865 if size > 1_000_000 {
866 return true; }
868 }
869
870 false
872}
873
/// f64 elementwise ops computed natively (no f32 round-trip), since the
/// dispatch-table kernels operate on `Array<f32>` only.
#[inline]
fn elementwise_f64_native(
    a: &Array<f64>,
    b: &Array<f64>,
    kind: ElementwiseKind,
) -> Result<Array<f64>> {
    // One output element per logical position in `a`'s shape.
    let size: usize = a.shape.iter().product();
    let mut result_data = Vec::with_capacity(size);

    if a.is_contiguous() && b.is_contiguous() {
        // Fast path: both buffers dense, iterate linearly.
        // NOTE(review): indexes `b.data` by `a`'s length — assumes the caller
        // passes equal-length arrays; panics otherwise.
        match kind {
            ElementwiseKind::Add => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] + b.data[i]);
                }
            }
            ElementwiseKind::Sub => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] - b.data[i]);
                }
            }
            ElementwiseKind::Mul => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] * b.data[i]);
                }
            }
            ElementwiseKind::Div => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] / b.data[i]);
                }
            }
            ElementwiseKind::Pow => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i].powf(b.data[i]));
                }
            }
            _ => anyhow::bail!("Unsupported elementwise operation for f64: {:?}", kind),
        }
    } else {
        // Strided path: walk the logical index space and resolve each
        // element's physical position via offset + strides.
        let a_strides = a.get_strides();
        let b_strides = b.get_strides();

        // Odometer-style multi-dimensional counter over `a.shape`.
        let mut indices = vec![0usize; a.shape.len()];

        for _ in 0..size {
            // Physical offsets for the current logical index.
            let mut a_idx = a.offset as isize;
            let mut b_idx = b.offset as isize;
            for (i, &idx) in indices.iter().enumerate() {
                a_idx += idx as isize * a_strides[i];
                b_idx += idx as isize * b_strides[i];
            }

            // Clamp to the last valid element rather than panicking.
            // NOTE(review): this silently masks out-of-range offsets that
            // would indicate a stride/shape bug — consider asserting instead.
            let a_idx_u = (a_idx as usize).min(a.data.len().saturating_sub(1));
            let b_idx_u = (b_idx as usize).min(b.data.len().saturating_sub(1));

            let val = match kind {
                ElementwiseKind::Add => a.data[a_idx_u] + b.data[b_idx_u],
                ElementwiseKind::Sub => a.data[a_idx_u] - b.data[b_idx_u],
                ElementwiseKind::Mul => a.data[a_idx_u] * b.data[b_idx_u],
                ElementwiseKind::Div => a.data[a_idx_u] / b.data[b_idx_u],
                ElementwiseKind::Pow => a.data[a_idx_u].powf(b.data[b_idx_u]),
                _ => anyhow::bail!("Unsupported elementwise operation for f64: {:?}", kind),
            };
            result_data.push(val);

            // Advance the odometer: increment the innermost dimension and
            // carry into outer dimensions on overflow.
            for i in (0..a.shape.len()).rev() {
                indices[i] += 1;
                if indices[i] < a.shape[i] {
                    break;
                }
                indices[i] = 0;
            }
        }
    }

    // Result is always dense with `a`'s shape; tag it as F64 explicitly.
    let mut result = Array::new(a.shape.clone(), result_data);
    result.dtype = crate::array::DType::F64;
    Ok(result)
}
962
/// i32 elementwise ops computed natively (no f32 round-trip).
///
/// NOTE(review): `Div` uses integer division and panics on division by zero;
/// `Pow` casts the exponent with `as u32`, so negative exponents wrap to very
/// large values — confirm both are the intended semantics.
#[inline]
fn elementwise_i32_native(
    a: &Array<i32>,
    b: &Array<i32>,
    kind: ElementwiseKind,
) -> Result<Array<i32>> {
    // One output element per logical position in `a`'s shape.
    let size: usize = a.shape.iter().product();
    let mut result_data = Vec::with_capacity(size);

    if a.is_contiguous() && b.is_contiguous() {
        // Fast path: both buffers dense, iterate linearly.
        // NOTE(review): indexes `b.data` by `a`'s length — assumes the caller
        // passes equal-length arrays; panics otherwise.
        match kind {
            ElementwiseKind::Add => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] + b.data[i]);
                }
            }
            ElementwiseKind::Sub => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] - b.data[i]);
                }
            }
            ElementwiseKind::Mul => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] * b.data[i]);
                }
            }
            ElementwiseKind::Div => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i] / b.data[i]);
                }
            }
            ElementwiseKind::Pow => {
                for i in 0..a.data.len() {
                    result_data.push(a.data[i].pow(b.data[i] as u32));
                }
            }
            _ => anyhow::bail!("Unsupported elementwise operation for i32: {:?}", kind),
        }
    } else {
        // Strided path: walk the logical index space and resolve each
        // element's physical position via offset + strides.
        let a_strides = a.get_strides();
        let b_strides = b.get_strides();

        // Odometer-style multi-dimensional counter over `a.shape`.
        let mut indices = vec![0usize; a.shape.len()];

        for _ in 0..size {
            // Physical offsets for the current logical index.
            let mut a_idx = a.offset as isize;
            let mut b_idx = b.offset as isize;
            for (i, &idx) in indices.iter().enumerate() {
                a_idx += idx as isize * a_strides[i];
                b_idx += idx as isize * b_strides[i];
            }

            // Clamp to the last valid element rather than panicking.
            // NOTE(review): silently masks out-of-range offsets that would
            // indicate a stride/shape bug — consider asserting instead.
            let a_idx_u = (a_idx as usize).min(a.data.len().saturating_sub(1));
            let b_idx_u = (b_idx as usize).min(b.data.len().saturating_sub(1));

            let val = match kind {
                ElementwiseKind::Add => a.data[a_idx_u] + b.data[b_idx_u],
                ElementwiseKind::Sub => a.data[a_idx_u] - b.data[b_idx_u],
                ElementwiseKind::Mul => a.data[a_idx_u] * b.data[b_idx_u],
                ElementwiseKind::Div => a.data[a_idx_u] / b.data[b_idx_u],
                ElementwiseKind::Pow => a.data[a_idx_u].pow(b.data[b_idx_u] as u32),
                _ => anyhow::bail!("Unsupported elementwise operation for i32: {:?}", kind),
            };
            result_data.push(val);

            // Advance the odometer: increment the innermost dimension and
            // carry into outer dimensions on overflow.
            for i in (0..a.shape.len()).rev() {
                indices[i] += 1;
                if indices[i] < a.shape[i] {
                    break;
                }
                indices[i] = 0;
            }
        }
    }

    // Result is always dense with `a`'s shape; tag it as I32 explicitly.
    let mut result = Array::new(a.shape.clone(), result_data);
    result.dtype = crate::array::DType::I32;
    Ok(result)
}
1051
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_table_initialization() {
        let table = init_dispatch_table();

        // Every selected kernel must report a non-empty backend name.
        assert!(!table.elementwise_backend.is_empty());
        assert!(!table.reduction_backend.is_empty());
        assert!(!table.matmul_backend.is_empty());
        assert!(!table.dot_backend.is_empty());

        println!("Dispatch table: {:?}", table);
    }

    #[test]
    fn test_backend_validation() {
        let validation = validate_backends();

        println!("Backend validation: {:?}", validation);

        // The original assertion here was `a || b || true`, which can never
        // fail. Instead, check the invariant validate_backends upholds:
        // a backend can only be *validated* if it was detected as *available*.
        assert!(!validation.simd_validated || validation.simd_available);
        assert!(!validation.blas_validated || validation.blas_available);
        assert!(!validation.webgpu_validated || validation.webgpu_available);
        assert!(!validation.metal_validated || validation.metal_available);
        // GPU matmul is never smoke-tested, so it must never be validated.
        assert!(!validation.gpu_validated);
    }

    #[test]
    fn test_elementwise_dispatch() {
        let table = get_dispatch_table();

        let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array::new(vec![4], vec![1.0, 1.0, 1.0, 1.0]);

        let result = (table.elementwise)(&a, &b, ElementwiseKind::Add);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.data, vec![2.0, 3.0, 4.0, 5.0]);

        println!(
            "Elementwise test passed using: {}",
            table.elementwise_backend
        );
    }
}