
axonml_core/backends/cpu.rs

//! CPU Backend - Host Memory Operations
//!
//! Provides the CPU implementation for tensor operations using host memory.
//! This is the default backend that is always available.
//!
//! # Key Features
//! - SIMD-optimized operations where possible
//! - Multi-threaded execution via rayon
//! - matrixmultiply crate for optimized GEMM operations
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
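//!
//! # Example
//!
//! A minimal usage sketch; the crate path below is assumed from this file's
//! location, hence the `ignore` attribute:
//!
//! ```ignore
//! use axonml_core::backends::cpu::CpuBackend;
//!
//! let a = [1.0_f32, 2.0, 3.0];
//! let b = [4.0_f32, 5.0, 6.0];
//! let mut c = [0.0_f32; 3];
//! CpuBackend::add(&mut c, &a, &b);
//! assert_eq!(c, [5.0, 7.0, 9.0]);
//! ```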

use super::Backend;
use crate::device::DeviceCapabilities;
use crate::dtype::{Float, Numeric, Scalar};
use rayon::prelude::*;
use sysinfo::System;

/// Threshold for using parallel processing (in elements)
const PARALLEL_THRESHOLD: usize = 4096;

// =============================================================================
// CPU Backend Struct
// =============================================================================

/// CPU backend for tensor operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

impl CpuBackend {
    /// Creates a new CPU backend.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

// =============================================================================
// Backend Trait Implementation
// =============================================================================

impl Backend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn is_available(&self) -> bool {
        true // CPU is always available
    }

    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: "CPU".to_string(),
            total_memory: get_system_memory(),
            available_memory: get_available_memory(),
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: num_cpus(),
            compute_capability: None,
        }
    }

    fn allocate(&self, size: usize) -> *mut u8 {
        if size == 0 {
            return std::ptr::null_mut();
        }
        // 64-byte alignment keeps buffers cache-line- and SIMD-friendly.
        // Validate the layout instead of using the unchecked constructor so
        // a pathological `size` cannot trigger undefined behavior.
        let layout = std::alloc::Layout::from_size_align(size, 64)
            .expect("invalid allocation layout");
        unsafe { std::alloc::alloc(layout) }
    }

    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if ptr.is_null() || size == 0 {
            return;
        }
        // Must match the layout used in `allocate`.
        let layout = std::alloc::Layout::from_size_align(size, 64)
            .expect("invalid allocation layout");
        unsafe { std::alloc::dealloc(ptr, layout) };
    }

    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        // For CPU, this is just a memory copy
        unsafe {
            std::ptr::copy_nonoverlapping(src, dst, size);
        }
    }

    fn synchronize(&self) {
        // No-op for CPU - operations are synchronous
    }
}

// =============================================================================
// Helper Functions
// =============================================================================

/// Returns the total system memory in bytes.
fn get_system_memory() -> usize {
    // Refresh only memory statistics; `System::new_all` would also scan
    // processes, disks, and networks, which is needlessly expensive here.
    let mut sys = System::new();
    sys.refresh_memory();
    sys.total_memory() as usize
}

/// Returns the available system memory in bytes.
fn get_available_memory() -> usize {
    let mut sys = System::new();
    sys.refresh_memory();
    sys.available_memory() as usize
}

/// Returns the number of CPU cores.
fn num_cpus() -> usize {
    std::thread::available_parallelism()
        .map(std::num::NonZeroUsize::get)
        .unwrap_or(1)
}

// =============================================================================
// Element-wise Operations
// =============================================================================

impl CpuBackend {
    /// Adds two slices element-wise with optional parallelization.
    pub fn add<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val + *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + b[i];
            }
        }
    }

    /// Subtracts two slices element-wise with optional parallelization.
    pub fn sub<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val - *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] - b[i];
            }
        }
    }

    /// Multiplies two slices element-wise with optional parallelization.
    pub fn mul<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val * *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * b[i];
            }
        }
    }

    /// Divides two slices element-wise with optional parallelization.
    pub fn div<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut()
                .zip(a.par_iter().zip(b.par_iter()))
                .for_each(|(d, (a_val, b_val))| {
                    *d = *a_val / *b_val;
                });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] / b[i];
            }
        }
    }

    /// Adds a scalar to each element with optional parallelization.
    pub fn add_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val + scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] + scalar;
            }
        }
    }

    /// Multiplies each element by a scalar with optional parallelization.
    pub fn mul_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * scalar;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * scalar;
            }
        }
    }

    /// Negates each element with optional parallelization.
    pub fn neg<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::zero() - *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::zero() - a[i];
            }
        }
    }

    /// Computes absolute value of each element with optional parallelization.
    pub fn abs<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val < T::zero() {
                    T::zero() - *a_val
                } else {
                    *a_val
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] < T::zero() {
                    T::zero() - a[i]
                } else {
                    a[i]
                };
            }
        }
    }
}

// =============================================================================
// Activation Functions
// =============================================================================

impl CpuBackend {
    /// Applies `ReLU` activation: max(0, x) with optional parallelization.
    pub fn relu<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = if *a_val > T::zero() {
                    *a_val
                } else {
                    T::zero()
                };
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
            }
        }
    }

    /// Applies sigmoid activation: 1 / (1 + exp(-x)) with optional parallelization.
    pub fn sigmoid<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = T::one() / (T::one() + (-*a_val).exp_value());
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
            }
        }
    }

    /// Applies tanh activation with optional parallelization.
    pub fn tanh<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.tanh_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].tanh_value();
            }
        }
    }

    /// Applies exponential function with optional parallelization.
    pub fn exp<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.exp_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].exp_value();
            }
        }
    }

    /// Applies natural logarithm with optional parallelization.
    pub fn ln<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.ln_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].ln_value();
            }
        }
    }

    /// Applies square root with optional parallelization.
    pub fn sqrt<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = a_val.sqrt_value();
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i].sqrt_value();
            }
        }
    }

    /// Squares each element with optional parallelization.
    pub fn square<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
        debug_assert_eq!(a.len(), dst.len());

        if dst.len() >= PARALLEL_THRESHOLD {
            dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
                *d = *a_val * *a_val;
            });
        } else {
            for i in 0..dst.len() {
                dst[i] = a[i] * a[i];
            }
        }
    }
}

// =============================================================================
// Reduction Operations
// =============================================================================

impl CpuBackend {
    /// Computes the sum of all elements.
    pub fn sum<T: Numeric>(a: &[T]) -> T {
        let mut result = T::zero();
        for &val in a {
            result = result + val;
        }
        result
    }

    /// Computes the product of all elements.
    pub fn prod<T: Numeric>(a: &[T]) -> T {
        let mut result = T::one();
        for &val in a {
            result = result * val;
        }
        result
    }

    /// Finds the maximum element.
    pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val > result {
                result = val;
            }
        }
        Some(result)
    }

    /// Finds the minimum element.
    pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let mut result = a[0];
        for &val in &a[1..] {
            if val < result {
                result = val;
            }
        }
        Some(result)
    }

    /// Computes the mean of all elements.
    pub fn mean<T: Float>(a: &[T]) -> Option<T> {
        if a.is_empty() {
            return None;
        }

        let sum = Self::sum(a);
        let len = T::from(a.len()).unwrap_or(T::one());
        Some(sum / len)
    }

    /// Finds the index of the maximum element.
    pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut max_idx = 0;
        let mut max_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }
        Some(max_idx)
    }

    /// Finds the index of the minimum element.
    pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
        if a.is_empty() {
            return None;
        }

        let mut min_idx = 0;
        let mut min_val = a[0];
        for (i, &val) in a.iter().enumerate().skip(1) {
            if val < min_val {
                min_val = val;
                min_idx = i;
            }
        }
        Some(min_idx)
    }
}

// =============================================================================
// Matrix Operations
// =============================================================================

impl CpuBackend {
    /// Performs matrix multiplication: C = A @ B.
    ///
    /// A is (m x k), B is (k x n), C is (m x n).
    /// Uses optimized GEMM from the matrixmultiply crate for f32/f64 and
    /// falls back to a cache-efficient tiled implementation for other types.
    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        // Dispatch to the optimized matrixmultiply GEMM kernels for f32 and f64
        use std::any::TypeId;
        if TypeId::of::<T>() == TypeId::of::<f32>() {
            // SAFETY: We verified T is f32, so the casts are safe
            unsafe {
                let a_f32: &[f32] = &*(a as *const [T] as *const [f32]);
                let b_f32: &[f32] = &*(b as *const [T] as *const [f32]);
                let c_f32: &mut [f32] = &mut *(c as *mut [T] as *mut [f32]);
                Self::matmul_f32(c_f32, a_f32, b_f32, m, n, k);
            }
            return;
        }

        if TypeId::of::<T>() == TypeId::of::<f64>() {
            // SAFETY: We verified T is f64, so the casts are safe
            unsafe {
                let a_f64: &[f64] = &*(a as *const [T] as *const [f64]);
                let b_f64: &[f64] = &*(b as *const [T] as *const [f64]);
                let c_f64: &mut [f64] = &mut *(c as *mut [T] as *mut [f64]);
                Self::matmul_f64(c_f64, a_f64, b_f64, m, n, k);
            }
            return;
        }

        // Fallback: Use cache-efficient tiled matrix multiplication.
        // Block size chosen for a typical L1 cache (32KB).
        const BLOCK_SIZE: usize = 64;

        // Initialize C to zero
        for val in c.iter_mut() {
            *val = T::zero();
        }

        // Tiled matrix multiplication for better cache locality
        for i0 in (0..m).step_by(BLOCK_SIZE) {
            let i_end = (i0 + BLOCK_SIZE).min(m);
            for p0 in (0..k).step_by(BLOCK_SIZE) {
                let p_end = (p0 + BLOCK_SIZE).min(k);
                for j0 in (0..n).step_by(BLOCK_SIZE) {
                    let j_end = (j0 + BLOCK_SIZE).min(n);

                    // Compute block C[i0:i_end, j0:j_end] += A[i0:i_end, p0:p_end] @ B[p0:p_end, j0:j_end]
                    for i in i0..i_end {
                        for p in p0..p_end {
                            let a_val = a[i * k + p];
                            for j in j0..j_end {
                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
                            }
                        }
                    }
                }
            }
        }
    }

    /// Performs optimized f32 matrix multiplication using the matrixmultiply crate.
    ///
    /// C = alpha * A @ B + beta * C
    pub fn sgemm(
        c: &mut [f32],
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        beta: f32,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::sgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, // A: row-major (m x k)
                b.as_ptr(),
                n as isize,
                1, // B: row-major (k x n)
                beta,
                c.as_mut_ptr(),
                n as isize,
                1, // C: row-major (m x n)
            );
        }
    }

    /// Performs optimized f64 matrix multiplication using the matrixmultiply crate.
    ///
    /// C = alpha * A @ B + beta * C
    pub fn dgemm(
        c: &mut [f64],
        a: &[f64],
        b: &[f64],
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        beta: f64,
    ) {
        debug_assert_eq!(a.len(), m * k);
        debug_assert_eq!(b.len(), k * n);
        debug_assert_eq!(c.len(), m * n);

        unsafe {
            matrixmultiply::dgemm(
                m,
                k,
                n,
                alpha,
                a.as_ptr(),
                k as isize,
                1, // A: row-major (m x k)
                b.as_ptr(),
                n as isize,
                1, // B: row-major (k x n)
                beta,
                c.as_mut_ptr(),
                n as isize,
                1, // C: row-major (m x n)
            );
        }
    }

    /// Performs f32 matrix multiplication: C = A @ B using optimized GEMM.
    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Performs f64 matrix multiplication: C = A @ B using optimized GEMM.
    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
    }

    /// Transposes a matrix.
    ///
    /// `src` is (rows x cols); `dst` receives the (cols x rows) transpose.
    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
        debug_assert_eq!(src.len(), rows * cols);
        debug_assert_eq!(dst.len(), rows * cols);

        for i in 0..rows {
            for j in 0..cols {
                dst[j * rows + i] = src[i * cols + j];
            }
        }
    }

    /// Computes the dot product of two vectors.
    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
        debug_assert_eq!(a.len(), b.len());

        let mut sum = T::zero();
        for i in 0..a.len() {
            sum = sum + a[i] * b[i];
        }
        sum
    }
}

// =============================================================================
// Comparison Operations
// =============================================================================

impl CpuBackend {
    /// Element-wise equality comparison.
    pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] == b[i];
        }
    }

    /// Element-wise less-than comparison.
    pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] < b[i];
        }
    }

    /// Element-wise greater-than comparison.
    pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), dst.len());

        for i in 0..dst.len() {
            dst[i] = a[i] > b[i];
        }
    }
}

// =============================================================================
// Fill Operations
// =============================================================================

impl CpuBackend {
    /// Fills a slice with a value.
    pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
        for elem in dst.iter_mut() {
            *elem = value;
        }
    }

    /// Fills a slice with zeros.
    pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
        Self::fill(dst, T::zeroed());
    }

    /// Copies from source to destination.
    pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
        debug_assert_eq!(dst.len(), src.len());
        dst.copy_from_slice(src);
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add(&mut c, &a, &b);
        assert_eq!(c, [5.0, 7.0, 9.0]);
    }
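
    // Covers the scalar and subtraction paths with hand-checked values.
    #[test]
    fn test_add_scalar() {
        let a = [1.0_f32, 2.0, 3.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::add_scalar(&mut c, &a, 10.0);
        assert_eq!(c, [11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sub() {
        let a = [4.0_f32, 5.0, 6.0];
        let b = [1.0_f32, 2.0, 3.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::sub(&mut c, &a, &b);
        assert_eq!(c, [3.0, 3.0, 3.0]);
    }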

    #[test]
    fn test_mul() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [2.0_f32, 2.0, 2.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::mul(&mut c, &a, &b);
        assert_eq!(c, [4.0, 6.0, 8.0]);
    }
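
    // A few more element-wise checks (div, neg, abs) with hand-computed
    // expectations.
    #[test]
    fn test_div() {
        let a = [4.0_f32, 6.0, 8.0];
        let b = [2.0_f32, 2.0, 2.0];
        let mut c = [0.0_f32; 3];

        CpuBackend::div(&mut c, &a, &b);
        assert_eq!(c, [2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_neg_abs() {
        let a = [1.0_f32, -2.0, 0.0];
        let mut n = [0.0_f32; 3];
        let mut m = [0.0_f32; 3];

        CpuBackend::neg(&mut n, &a);
        CpuBackend::abs(&mut m, &a);
        assert_eq!(n, [-1.0, 2.0, 0.0]);
        assert_eq!(m, [1.0, 2.0, 0.0]);
    }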

    #[test]
    fn test_relu() {
        let a = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut b = [0.0_f32; 4];

        CpuBackend::relu(&mut b, &a);
        assert_eq!(b, [0.0, 0.0, 1.0, 2.0]);
    }
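
    // Activation and transcendental checks at points where the f32 results
    // are exact: sigmoid(0) = 0.5, exp(0) = 1, ln(1) = 0, sqrt(4) = 2.
    #[test]
    fn test_sigmoid_at_zero() {
        let a = [0.0_f32];
        let mut b = [0.0_f32];

        CpuBackend::sigmoid(&mut b, &a);
        assert_eq!(b, [0.5]);
    }

    #[test]
    fn test_exp_ln_sqrt() {
        let mut out = [0.0_f32];

        CpuBackend::exp(&mut out, &[0.0_f32]);
        assert_eq!(out, [1.0]);
        CpuBackend::ln(&mut out, &[1.0_f32]);
        assert_eq!(out, [0.0]);
        CpuBackend::sqrt(&mut out, &[4.0_f32]);
        assert_eq!(out, [2.0]);
    }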

    #[test]
    fn test_sum() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::sum(&a), 10.0);
    }
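
    // prod and mean over the same fixture: 24 = 1*2*3*4 and 2.5 = 10/4,
    // both exact in f32.
    #[test]
    fn test_prod_mean() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(CpuBackend::prod(&a), 24.0);
        assert_eq!(CpuBackend::mean(&a), Some(2.5));
    }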

    #[test]
    fn test_max_min() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::max(&a), Some(4.0));
        assert_eq!(CpuBackend::min(&a), Some(1.0));
    }

    #[test]
    fn test_argmax() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmax(&a), Some(1));
    }
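
    // argmin mirrors argmax; both return None on empty input.
    #[test]
    fn test_argmin() {
        let a = [1.0_f32, 4.0, 2.0, 3.0];
        assert_eq!(CpuBackend::argmin(&a), Some(0));
        assert_eq!(CpuBackend::argmin::<f32>(&[]), None);
        assert_eq!(CpuBackend::argmax::<f32>(&[]), None);
    }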

    #[test]
    fn test_matmul() {
        // A = [[1, 2], [3, 4]] (2x2)
        // B = [[5, 6], [7, 8]] (2x2)
        // C = [[19, 22], [43, 50]]
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [0.0_f32; 4];

        CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
    }
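
    // sgemm with alpha = 1, beta = 1 accumulates into C:
    // C = A @ B + C_initial = [[19, 22], [43, 50]] + [[1, 1], [1, 1]].
    #[test]
    fn test_sgemm_accumulate() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let mut c = [1.0_f32; 4];

        CpuBackend::sgemm(&mut c, &a, &b, 2, 2, 2, 1.0, 1.0);
        assert_eq!(c, [20.0, 23.0, 44.0, 51.0]);
    }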

    #[test]
    fn test_transpose() {
        // A = [[1, 2, 3], [4, 5, 6]] (2x3)
        // B = [[1, 4], [2, 5], [3, 6]] (3x2)
        let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = [0.0_f32; 6];

        CpuBackend::transpose(&mut b, &a, 2, 3);
        assert_eq!(b, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
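
    // The generic `matmul` falls back to the tiled kernel for non-float
    // types; this sketch assumes `i32` implements `Numeric` in `crate::dtype`.
    #[test]
    fn test_matmul_i32_fallback() {
        let a = [1_i32, 2, 3, 4];
        let b = [5_i32, 6, 7, 8];
        let mut c = [0_i32; 4];

        CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
        assert_eq!(c, [19, 22, 43, 50]);
    }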

    #[test]
    fn test_dot() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(CpuBackend::dot(&a, &b), 32.0);
    }
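
    // Comparison kernels write boolean masks element-wise.
    #[test]
    fn test_comparisons() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [3.0_f32, 2.0, 1.0];
        let mut mask = [false; 3];

        CpuBackend::eq(&mut mask, &a, &b);
        assert_eq!(mask, [false, true, false]);
        CpuBackend::lt(&mut mask, &a, &b);
        assert_eq!(mask, [true, false, false]);
        CpuBackend::gt(&mut mask, &a, &b);
        assert_eq!(mask, [false, false, true]);
    }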

    #[test]
    fn test_fill() {
        let mut a = [0.0_f32; 5];
        CpuBackend::fill(&mut a, 42.0);
        assert_eq!(a, [42.0; 5]);
    }
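
    // Round-trips raw bytes through the Backend allocate/copy/deallocate
    // path, using only the trait methods defined above.
    #[test]
    fn test_alloc_copy_roundtrip() {
        let backend = CpuBackend::new();
        let src = [1_u8, 2, 3, 4];
        let mut dst = [0_u8; 4];

        let ptr = backend.allocate(4);
        assert!(!ptr.is_null());
        backend.copy_to_device(ptr, src.as_ptr(), 4);
        backend.copy_to_host(dst.as_mut_ptr(), ptr, 4);
        backend.deallocate(ptr, 4);
        assert_eq!(dst, src);
    }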
}