axonml_core/backends/cpu.rs

//! CPU Backend - Host Memory Operations
//!
//! # File
//! `crates/axonml-core/src/backends/cpu.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use super::Backend;
use crate::device::DeviceCapabilities;
use crate::dtype::{Float, Numeric, Scalar};
use rayon::prelude::*;
use sysinfo::System;

/// Threshold for using parallel processing (in elements)
const PARALLEL_THRESHOLD: usize = 4096;

// =============================================================================
// CPU Backend Struct
// =============================================================================

/// CPU backend for tensor operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

impl CpuBackend {
    /// Creates a new CPU backend.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

// =============================================================================
// Backend Trait Implementation
// =============================================================================

impl Backend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn is_available(&self) -> bool {
        true // CPU is always available
    }

    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: "CPU".to_string(),
            total_memory: get_system_memory(),
            available_memory: get_available_memory(),
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: num_cpus(),
            compute_capability: None,
        }
    }

    fn allocate(&self, size: usize) -> *mut u8 {
        if size == 0 {
            return std::ptr::null_mut();
        }
        unsafe {
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::alloc(layout)
        }
    }

    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if ptr.is_null() || size == 0 {
            return;
        }
        unsafe {
            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
            std::alloc::dealloc(ptr, layout);
        }
    }

    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn synchronize(&self) {
        // No-op for CPU - operations are synchronous
    }
}
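
// Illustrative usage sketch (added commentary, not part of the original backend
// API): a host round trip through the raw allocation interface above. The
// 4-element f32 buffer and its contents are arbitrary; data is staged into a
// 64-byte-aligned allocation and copied back out unchanged.
#[cfg(test)]
mod allocation_usage_example {
    use super::*;

    #[test]
    fn host_round_trip() {
        let backend = CpuBackend::new();
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let size = std::mem::size_of_val(&data);

        // Stage the host buffer into backend-owned memory, then read it back.
        let ptr = backend.allocate(size);
        assert!(!ptr.is_null());
        backend.copy_to_device(ptr, data.as_ptr().cast::<u8>(), size);

        let mut out = [0.0_f32; 4];
        backend.copy_to_host(out.as_mut_ptr().cast::<u8>(), ptr, size);
        backend.deallocate(ptr, size);

        assert_eq!(out, data);
    }
}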

// =============================================================================
// Helper Functions
// =============================================================================

/// Returns the total system memory in bytes.
fn get_system_memory() -> usize {
    let sys = System::new_all();
    sys.total_memory() as usize
}

/// Returns the available system memory in bytes.
fn get_available_memory() -> usize {
    let sys = System::new_all();
    sys.available_memory() as usize
}

/// Returns the number of CPU cores.
fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(std::num::NonZeroUsize::get)
        .unwrap_or(1)
}

// =============================================================================
// Element-wise Operations
// =============================================================================

impl CpuBackend {
    /// Adds two slices element-wise with optional parallelization.
    pub fn add<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val + *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + b[i];
            }
        }
    }

    /// Subtracts two slices element-wise with optional parallelization.
    pub fn sub<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val - *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] - b[i];
            }
        }
    }

    /// Multiplies two slices element-wise with optional parallelization.
    pub fn mul<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val * *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * b[i];
            }
        }
    }

    /// Divides two slices element-wise with optional parallelization.
    pub fn div<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val / *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] / b[i];
            }
        }
    }

    /// Adds a scalar to each element with optional parallelization.
    pub fn add_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val + scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + scalar;
            }
        }
    }

    /// Multiplies each element by a scalar with optional parallelization.
    pub fn mul_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * scalar;
            }
        }
    }

    /// Negates each element with optional parallelization.
    pub fn neg<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::zero() - *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::zero() - a[i];
            }
        }
    }

    /// Computes absolute value of each element with optional parallelization.
    pub fn abs<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val < T::zero() {
                    T::zero() - *a_val
                } else {
                    *a_val
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] < T::zero() {
                    T::zero() - a[i]
                } else {
                    a[i]
                };
            }
        }
    }
}
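
// Illustrative usage sketch (added commentary, not part of the original file):
// chains two of the element-wise helpers above on arbitrary values. Slices this
// small stay below PARALLEL_THRESHOLD, so the serial path runs.
#[cfg(test)]
mod elementwise_usage_example {
    use super::*;

    #[test]
    fn mul_then_add_scalar() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [0.5_f32; 4];
        let mut scaled = [0.0_f32; 4];
        let mut shifted = [0.0_f32; 4];

        // scaled = a * b, then shifted = scaled + 1.0
        CpuBackend::mul(&mut scaled, &a, &b);
        CpuBackend::add_scalar(&mut shifted, &scaled, 1.0);

        assert_eq!(shifted, [1.5, 2.0, 2.5, 3.0]);
    }
}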

// =============================================================================
// Activation Functions
// =============================================================================

impl CpuBackend {
    /// Applies `ReLU` activation: max(0, x) with optional parallelization.
    pub fn relu<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val > T::zero() {
                    *a_val
                } else {
                    T::zero()
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
            }
        }
    }

    /// Applies sigmoid activation: 1 / (1 + exp(-x)) with optional parallelization.
    pub fn sigmoid<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::one() / (T::one() + (-*a_val).exp_value());
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
            }
        }
    }

    /// Applies tanh activation with optional parallelization.
    pub fn tanh<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.tanh_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].tanh_value();
            }
        }
    }

    /// Applies exponential function with optional parallelization.
    pub fn exp<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.exp_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].exp_value();
            }
        }
    }

    /// Applies natural logarithm with optional parallelization.
    pub fn ln<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.ln_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].ln_value();
            }
        }
    }

    /// Applies square root with optional parallelization.
    pub fn sqrt<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.sqrt_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].sqrt_value();
            }
        }
    }

    /// Squares each element with optional parallelization.
    pub fn square<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * a[i];
            }
        }
    }
}
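
// Illustrative usage sketch (added commentary, not part of the original file):
// checks a few fixed points of the activations above for f32 inputs - relu
// clamps negatives to zero, sigmoid(0) = 0.5, and tanh(0) = 0.
#[cfg(test)]
mod activation_usage_example {
    use super::*;

    #[test]
    fn fixed_points() {
        let x = [-2.0_f32, 0.0, 2.0];
        let mut r = [0.0_f32; 3];
        let mut s = [0.0_f32; 3];
        let mut t = [0.0_f32; 3];

        CpuBackend::relu(&mut r, &x);
        CpuBackend::sigmoid(&mut s, &x);
        CpuBackend::tanh(&mut t, &x);

        assert_eq!(r, [0.0, 0.0, 2.0]);
        assert!((s[1] - 0.5).abs() < 1e-6);
        assert!(t[1].abs() < 1e-6);
    }
}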

// =============================================================================
// Reduction Operations
// =============================================================================

impl CpuBackend {
    /// Computes the sum of all elements.
    pub fn sum<T: Numeric>(a: &[T]) -> T {
        let mut result = T::zero();
        for &val in a {
            result = result + val;
        }
        result
    }

    /// Computes the product of all elements.
    pub fn prod<T: Numeric>(a: &[T]) -> T {
        let mut result = T::one();
        for &val in a {
            result = result * val;
        }
        result
    }

    /// Finds the maximum element.
    pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val > result {
                result = val;
            }
        }
        Some(result)
    }

    /// Finds the minimum element.
    pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val < result {
                result = val;
            }
        }
        Some(result)
    }

    /// Computes the mean of all elements.
    pub fn mean<T: Float>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let sum = Self::sum(a);
        let len = T::from(a.len()).unwrap_or(T::one());
        Some(sum / len)
    }

    /// Finds the index of the maximum element.
    pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut max_idx = 0;
        let mut max_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }
        Some(max_idx)
    }

    /// Finds the index of the minimum element.
    pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut min_idx = 0;
        let mut min_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val < min_val {
                min_val = val;
                min_idx = i;
            }
        }
        Some(min_idx)
    }
}
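
// Illustrative usage sketch (added commentary, not part of the original file):
// runs the scalar reductions above on a small arbitrary slice.
#[cfg(test)]
mod reduction_usage_example {
    use super::*;

    #[test]
    fn scalar_reductions() {
        let a = [3.0_f32, 1.0, 4.0, 2.0];

        assert_eq!(CpuBackend::prod(&a), 24.0);
        assert_eq!(CpuBackend::mean(&a), Some(2.5));
        assert_eq!(CpuBackend::argmin(&a), Some(1));
    }
}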

// =============================================================================
// Matrix Operations
// =============================================================================

impl CpuBackend {
    /// Performs matrix multiplication: C = A @ B.
    ///
    /// A is (m x k), B is (k x n), C is (m x n).
    /// Uses optimized GEMM from matrixmultiply crate for f32/f64,
    /// falls back to cache-efficient tiled implementation for other types.
    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        // Use optimized BLAS routines for f32 and f64
        use std::any::TypeId;
        if TypeId::of::<T>() == TypeId::of::<f32>() {
            // SAFETY: We verified T is f32, so the casts are safe
            unsafe {
                let a_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(a) as *const [f32]);
                let b_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(b) as *const [f32]);
                let c_f32: &mut [f32] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f32]);
                Self::matmul_f32(c_f32, a_f32, b_f32, m, n, k);
            }
            return;
        }

        if TypeId::of::<T>() == TypeId::of::<f64>() {
            // SAFETY: We verified T is f64, so the casts are safe
            unsafe {
                let a_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(a) as *const [f64]);
                let b_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(b) as *const [f64]);
                let c_f64: &mut [f64] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f64]);
                Self::matmul_f64(c_f64, a_f64, b_f64, m, n, k);
            }
            return;
        }

        // Fallback: Use cache-efficient tiled matrix multiplication
        // Block size chosen for typical L1 cache (32KB)
        const BLOCK_SIZE: usize = 64;

        // Initialize C to zero
        for val in c.iter_mut() {
            *val = T::zero();
        }

        // Tiled matrix multiplication for better cache locality
        for i0 in (0..m).step_by(BLOCK_SIZE) {
            let i_end = (i0 + BLOCK_SIZE).min(m);
            for p0 in (0..k).step_by(BLOCK_SIZE) {
                let p_end = (p0 + BLOCK_SIZE).min(k);
                for j0 in (0..n).step_by(BLOCK_SIZE) {
                    let j_end = (j0 + BLOCK_SIZE).min(n);

                    // Compute block C[i0:i_end, j0:j_end] += A[i0:i_end, p0:p_end] @ B[p0:p_end, j0:j_end]
                    for i in i0..i_end {
                        for p in p0..p_end {
                            let a_val = a[i * k + p];
                            for j in j0..j_end {
                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
                            }
                        }
                    }
                }
            }
        }
    }

    /// Performs optimized f32 matrix multiplication using matrixmultiply crate.
    ///
    /// C = alpha * A @ B + beta * C
    pub fn sgemm(
        c: &mut [f32],
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        beta: f32,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, // A: row-major (m x k)
                b.as_ptr(),
                n as isize,
                1, // B: row-major (k x n)
                beta,
                c.as_mut_ptr(),
                n as isize,
                1, // C: row-major (m x n)
            );
        }
    }

    /// Performs optimized f64 matrix multiplication using matrixmultiply crate.
    ///
    /// C = alpha * A @ B + beta * C
    pub fn dgemm(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        beta: f64,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, // A: row-major (m x k)
                b.as_ptr(),
                n as isize,
                1, // B: row-major (k x n)
                beta,
                c.as_mut_ptr(),
                n as isize,
                1, // C: row-major (m x n)
            );
        }
    }

    /// Performs f32 matrix multiplication: C = A @ B using optimized GEMM.
    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Performs f64 matrix multiplication: C = A @ B using optimized GEMM.
    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Transposes a matrix.
    ///
    /// A is (rows x cols), B is (cols x rows).
    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
        debug_assert_eq!(src.len(), rows * cols);
        debug_assert_eq!(dst.len(), rows * cols);

        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
    }

    /// Computes dot product of two vectors.
    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
        debug_assert_eq!(a.len(), b.len());

        let mut sum = T::zero();
        for i in 0..a.len() {
            sum = sum + a[i] * b[i];
        }
        sum
    }
}
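
// Illustrative usage sketch (added commentary, not part of the original file):
// demonstrates the C = alpha * A @ B + beta * C contract of `sgemm` above with
// small 2x2 matrices, i.e. 2 * (A @ B) + 3 * C_initial.
#[cfg(test)]
mod gemm_usage_example {
    use super::*;

    #[test]
    fn sgemm_alpha_beta() {
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]], A @ B = [[19, 22], [43, 50]]
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [1.0_f32; 4]; // initial C contributes via beta

        CpuBackend::sgemm(&mut c, &a, &b, 2, 2, 2, 2.0, 3.0);
        assert_eq!(c, [41.0, 47.0, 89.0, 103.0]);
    }
}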

// =============================================================================
// Comparison Operations
// =============================================================================

impl CpuBackend {
    /// Element-wise equality comparison.
    pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] == b[i];
        }
    }

    /// Element-wise less-than comparison.
    pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] < b[i];
        }
    }

    /// Element-wise greater-than comparison.
    pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] > b[i];
        }
    }
}
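
// Illustrative usage sketch (added commentary, not part of the original file):
// applies the element-wise comparisons above to two arbitrary slices, writing
// the results into a boolean mask.
#[cfg(test)]
mod comparison_usage_example {
    use super::*;

    #[test]
    fn elementwise_comparisons() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [3.0_f32, 2.0, 1.0];
        let mut mask = [false; 3];

        CpuBackend::lt(&mut mask, &a, &b);
        assert_eq!(mask, [true, false, false]);

        CpuBackend::eq(&mut mask, &a, &b);
        assert_eq!(mask, [false, true, false]);
    }
}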

// =============================================================================
// Fill Operations
// =============================================================================

impl CpuBackend {
    /// Fills a slice with a value.
    pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
        for elem in dst.iter_mut() {
            *elem = value;
        }
    }

    /// Fills a slice with zeros.
    pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
        Self::fill(dst, T::zeroed());
    }

    /// Copies from source to destination.
    pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
        debug_assert_eq!(dst.len(), src.len());
        dst.copy_from_slice(src);
    }
}
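
// Illustrative usage sketch (added commentary, not part of the original file):
// uses the fill and copy helpers above on a small arbitrary buffer.
#[cfg(test)]
mod fill_usage_example {
    use super::*;

    #[test]
    fn fill_then_copy() {
        let mut src = [0.0_f32; 4];
        let mut dst = [7.0_f32; 4];

        CpuBackend::fill(&mut src, 3.0);
        CpuBackend::copy(&mut dst, &src);
        assert_eq!(dst, [3.0; 4]);

        CpuBackend::fill_zeros(&mut dst);
        assert_eq!(dst, [0.0; 4]);
    }
}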

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add(&mut c, &a, &b);
        assert_eq!(c, [5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_mul() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [2.0_f32, 2.0, 2.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::mul(&mut c, &a, &b);
        assert_eq!(c, [4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_relu() {
        let a = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut b = [0.0_f32; 4];

        CpuBackend::relu(&mut b, &a);
        assert_eq!(b, [0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sum() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::sum(&a), 10.0);
    }

    #[test]
    fn test_max_min() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::max(&a), Some(4.0));
        assert_eq!(CpuBackend::min(&a), Some(1.0));
    }

    #[test]
    fn test_argmax() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmax(&a), Some(1));
    }

    #[test]
    fn test_matmul() {
        // A = [[1, 2], [3, 4]] (2x2)
        // B = [[5, 6], [7, 8]] (2x2)
        // C = [[19, 22], [43, 50]]
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [0.0_f32; 4];

        CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_transpose() {
        // A = [[1, 2, 3], [4, 5, 6]] (2x3)
        // B = [[1, 4], [2, 5], [3, 6]] (3x2)
        let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = [0.0_f32; 6];

        CpuBackend::transpose(&mut b, &a, 2, 3);
        assert_eq!(b, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_dot() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(CpuBackend::dot(&a, &b), 32.0);
    }

    #[test]
    fn test_fill() {
        let mut a = [0.0_f32; 5];
        CpuBackend::fill(&mut a, 42.0);
        assert_eq!(a, [42.0; 5]);
    }
}