axonml_core/backends/
cpu.rs

1//! CPU Backend - Host Memory Operations
2//!
3//! Provides the CPU implementation for tensor operations using host memory.
4//! This is the default backend that is always available.
5//!
6//! # Key Features
7//! - SIMD-optimized operations where possible
8//! - Multi-threaded execution via rayon
9//! - matrixmultiply crate for optimized GEMM operations
10//!
11//! @version 0.1.0
12//! @author `AutomataNexus` Development Team
13
14use super::Backend;
15use crate::device::DeviceCapabilities;
16use crate::dtype::{Float, Numeric, Scalar};
17use sysinfo::System;
18
19// =============================================================================
20// CPU Backend Struct
21// =============================================================================
22
/// Zero-sized handle for the host-memory (CPU) tensor backend.
///
/// The type holds no state, so it is freely copyable and every value is
/// interchangeable with every other.
#[derive(Clone, Copy, Debug, Default)]
pub struct CpuBackend;

impl CpuBackend {
    /// Returns a CPU backend handle; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> CpuBackend {
        CpuBackend
    }
}
34
35// =============================================================================
36// Backend Trait Implementation
37// =============================================================================
38
39impl Backend for CpuBackend {
40    fn name(&self) -> &'static str {
41        "cpu"
42    }
43
44    fn is_available(&self) -> bool {
45        true // CPU is always available
46    }
47
48    fn capabilities(&self) -> DeviceCapabilities {
49        DeviceCapabilities {
50            name: "CPU".to_string(),
51            total_memory: get_system_memory(),
52            available_memory: get_available_memory(),
53            supports_f16: true,
54            supports_f64: true,
55            max_threads_per_block: num_cpus(),
56            compute_capability: None,
57        }
58    }
59
60    fn allocate(&self, size: usize) -> *mut u8 {
61        if size == 0 {
62            return std::ptr::null_mut();
63        }
64        unsafe {
65            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
66            std::alloc::alloc(layout)
67        }
68    }
69
70    fn deallocate(&self, ptr: *mut u8, size: usize) {
71        if ptr.is_null() || size == 0 {
72            return;
73        }
74        unsafe {
75            let layout = std::alloc::Layout::from_size_align_unchecked(size, 64);
76            std::alloc::dealloc(ptr, layout);
77        }
78    }
79
80    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
81        // For CPU, this is just a memory copy
82        unsafe {
83            std::ptr::copy_nonoverlapping(src, dst, size);
84        }
85    }
86
87    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
88        // For CPU, this is just a memory copy
89        unsafe {
90            std::ptr::copy_nonoverlapping(src, dst, size);
91        }
92    }
93
94    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
95        // For CPU, this is just a memory copy
96        unsafe {
97            std::ptr::copy_nonoverlapping(src, dst, size);
98        }
99    }
100
101    fn synchronize(&self) {
102        // No-op for CPU - operations are synchronous
103    }
104}
105
106// =============================================================================
107// Helper Functions
108// =============================================================================
109
110/// Returns the total system memory in bytes.
111fn get_system_memory() -> usize {
112    let sys = System::new_all();
113    sys.total_memory() as usize
114}
115
116/// Returns the available system memory in bytes.
117fn get_available_memory() -> usize {
118    let sys = System::new_all();
119    sys.available_memory() as usize
120}
121
/// Returns the parallelism the standard library reports for this process
/// (`std::thread::available_parallelism`), or `1` when that query fails.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
128
129// =============================================================================
130// Element-wise Operations
131// =============================================================================
132
133impl CpuBackend {
134    /// Adds two slices element-wise.
135    pub fn add<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
136        debug_assert_eq!(a.len(), b.len());
137        debug_assert_eq!(a.len(), dst.len());
138
139        for i in 0..dst.len() {
140            dst[i] = a[i] + b[i];
141        }
142    }
143
144    /// Subtracts two slices element-wise.
145    pub fn sub<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
146        debug_assert_eq!(a.len(), b.len());
147        debug_assert_eq!(a.len(), dst.len());
148
149        for i in 0..dst.len() {
150            dst[i] = a[i] - b[i];
151        }
152    }
153
154    /// Multiplies two slices element-wise.
155    pub fn mul<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
156        debug_assert_eq!(a.len(), b.len());
157        debug_assert_eq!(a.len(), dst.len());
158
159        for i in 0..dst.len() {
160            dst[i] = a[i] * b[i];
161        }
162    }
163
164    /// Divides two slices element-wise.
165    pub fn div<T: Numeric>(dst: &mut [T], a: &[T], b: &[T]) {
166        debug_assert_eq!(a.len(), b.len());
167        debug_assert_eq!(a.len(), dst.len());
168
169        for i in 0..dst.len() {
170            dst[i] = a[i] / b[i];
171        }
172    }
173
174    /// Adds a scalar to each element.
175    pub fn add_scalar<T: Numeric>(dst: &mut [T], a: &[T], scalar: T) {
176        debug_assert_eq!(a.len(), dst.len());
177
178        for i in 0..dst.len() {
179            dst[i] = a[i] + scalar;
180        }
181    }
182
183    /// Multiplies each element by a scalar.
184    pub fn mul_scalar<T: Numeric>(dst: &mut [T], a: &[T], scalar: T) {
185        debug_assert_eq!(a.len(), dst.len());
186
187        for i in 0..dst.len() {
188            dst[i] = a[i] * scalar;
189        }
190    }
191
192    /// Negates each element.
193    pub fn neg<T: Numeric>(dst: &mut [T], a: &[T]) {
194        debug_assert_eq!(a.len(), dst.len());
195
196        for i in 0..dst.len() {
197            dst[i] = T::zero() - a[i];
198        }
199    }
200
201    /// Computes absolute value of each element.
202    pub fn abs<T: Numeric>(dst: &mut [T], a: &[T]) {
203        debug_assert_eq!(a.len(), dst.len());
204
205        for i in 0..dst.len() {
206            dst[i] = if a[i] < T::zero() {
207                T::zero() - a[i]
208            } else {
209                a[i]
210            };
211        }
212    }
213}
214
215// =============================================================================
216// Activation Functions
217// =============================================================================
218
219impl CpuBackend {
220    /// Applies `ReLU` activation: max(0, x).
221    pub fn relu<T: Float>(dst: &mut [T], a: &[T]) {
222        debug_assert_eq!(a.len(), dst.len());
223
224        for i in 0..dst.len() {
225            dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
226        }
227    }
228
229    /// Applies sigmoid activation: 1 / (1 + exp(-x)).
230    pub fn sigmoid<T: Float>(dst: &mut [T], a: &[T]) {
231        debug_assert_eq!(a.len(), dst.len());
232
233        for i in 0..dst.len() {
234            dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
235        }
236    }
237
238    /// Applies tanh activation.
239    pub fn tanh<T: Float>(dst: &mut [T], a: &[T]) {
240        debug_assert_eq!(a.len(), dst.len());
241
242        for i in 0..dst.len() {
243            dst[i] = a[i].tanh_value();
244        }
245    }
246
247    /// Applies exponential function.
248    pub fn exp<T: Float>(dst: &mut [T], a: &[T]) {
249        debug_assert_eq!(a.len(), dst.len());
250
251        for i in 0..dst.len() {
252            dst[i] = a[i].exp_value();
253        }
254    }
255
256    /// Applies natural logarithm.
257    pub fn ln<T: Float>(dst: &mut [T], a: &[T]) {
258        debug_assert_eq!(a.len(), dst.len());
259
260        for i in 0..dst.len() {
261            dst[i] = a[i].ln_value();
262        }
263    }
264
265    /// Applies square root.
266    pub fn sqrt<T: Float>(dst: &mut [T], a: &[T]) {
267        debug_assert_eq!(a.len(), dst.len());
268
269        for i in 0..dst.len() {
270            dst[i] = a[i].sqrt_value();
271        }
272    }
273
274    /// Squares each element.
275    pub fn square<T: Numeric>(dst: &mut [T], a: &[T]) {
276        debug_assert_eq!(a.len(), dst.len());
277
278        for i in 0..dst.len() {
279            dst[i] = a[i] * a[i];
280        }
281    }
282}
283
284// =============================================================================
285// Reduction Operations
286// =============================================================================
287
288impl CpuBackend {
289    /// Computes the sum of all elements.
290    pub fn sum<T: Numeric>(a: &[T]) -> T {
291        let mut result = T::zero();
292        for &val in a {
293            result = result + val;
294        }
295        result
296    }
297
298    /// Computes the product of all elements.
299    pub fn prod<T: Numeric>(a: &[T]) -> T {
300        let mut result = T::one();
301        for &val in a {
302            result = result * val;
303        }
304        result
305    }
306
307    /// Finds the maximum element.
308    pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
309        if a.is_empty() {
310            return None;
311        }
312
313        let mut result = a[0];
314        for &val in &a[1..] {
315            if val > result {
316                result = val;
317            }
318        }
319        Some(result)
320    }
321
322    /// Finds the minimum element.
323    pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
324        if a.is_empty() {
325            return None;
326        }
327
328        let mut result = a[0];
329        for &val in &a[1..] {
330            if val < result {
331                result = val;
332            }
333        }
334        Some(result)
335    }
336
337    /// Computes the mean of all elements.
338    pub fn mean<T: Float>(a: &[T]) -> Option<T> {
339        if a.is_empty() {
340            return None;
341        }
342
343        let sum = Self::sum(a);
344        let len = T::from(a.len()).unwrap_or(T::one());
345        Some(sum / len)
346    }
347
348    /// Finds the index of the maximum element.
349    pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
350        if a.is_empty() {
351            return None;
352        }
353
354        let mut max_idx = 0;
355        let mut max_val = a[0];
356        for (i, &val) in a.iter().enumerate().skip(1) {
357            if val > max_val {
358                max_val = val;
359                max_idx = i;
360            }
361        }
362        Some(max_idx)
363    }
364
365    /// Finds the index of the minimum element.
366    pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
367        if a.is_empty() {
368            return None;
369        }
370
371        let mut min_idx = 0;
372        let mut min_val = a[0];
373        for (i, &val) in a.iter().enumerate().skip(1) {
374            if val < min_val {
375                min_val = val;
376                min_idx = i;
377            }
378        }
379        Some(min_idx)
380    }
381}
382
383// =============================================================================
384// Matrix Operations
385// =============================================================================
386
387impl CpuBackend {
388    /// Performs matrix multiplication: C = A @ B.
389    ///
390    /// A is (m x k), B is (k x n), C is (m x n).
391    /// Uses optimized GEMM from matrixmultiply crate for f32/f64,
392    /// falls back to cache-efficient tiled implementation for other types.
393    pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
394        debug_assert_eq!(a.len(), m * k);
395        debug_assert_eq!(b.len(), k * n);
396        debug_assert_eq!(c.len(), m * n);
397
398        // Use cache-efficient tiled matrix multiplication
399        // Block size chosen for typical L1 cache (32KB)
400        const BLOCK_SIZE: usize = 64;
401
402        // Initialize C to zero
403        for val in c.iter_mut() {
404            *val = T::zero();
405        }
406
407        // Tiled matrix multiplication for better cache locality
408        for i0 in (0..m).step_by(BLOCK_SIZE) {
409            let i_end = (i0 + BLOCK_SIZE).min(m);
410            for p0 in (0..k).step_by(BLOCK_SIZE) {
411                let p_end = (p0 + BLOCK_SIZE).min(k);
412                for j0 in (0..n).step_by(BLOCK_SIZE) {
413                    let j_end = (j0 + BLOCK_SIZE).min(n);
414
415                    // Compute block C[i0:i_end, j0:j_end] += A[i0:i_end, p0:p_end] @ B[p0:p_end, j0:j_end]
416                    for i in i0..i_end {
417                        for p in p0..p_end {
418                            let a_val = a[i * k + p];
419                            for j in j0..j_end {
420                                c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
421                            }
422                        }
423                    }
424                }
425            }
426        }
427    }
428
429    /// Performs optimized f32 matrix multiplication using matrixmultiply crate.
430    ///
431    /// C = alpha * A @ B + beta * C
432    pub fn sgemm(
433        c: &mut [f32],
434        a: &[f32],
435        b: &[f32],
436        m: usize,
437        n: usize,
438        k: usize,
439        alpha: f32,
440        beta: f32,
441    ) {
442        debug_assert_eq!(a.len(), m * k);
443        debug_assert_eq!(b.len(), k * n);
444        debug_assert_eq!(c.len(), m * n);
445
446        unsafe {
447            matrixmultiply::sgemm(
448                m,
449                k,
450                n,
451                alpha,
452                a.as_ptr(),
453                k as isize,
454                1, // A: row-major (m x k)
455                b.as_ptr(),
456                n as isize,
457                1, // B: row-major (k x n)
458                beta,
459                c.as_mut_ptr(),
460                n as isize,
461                1, // C: row-major (m x n)
462            );
463        }
464    }
465
466    /// Performs optimized f64 matrix multiplication using matrixmultiply crate.
467    ///
468    /// C = alpha * A @ B + beta * C
469    pub fn dgemm(
470        c: &mut [f64],
471        a: &[f64],
472        b: &[f64],
473        m: usize,
474        n: usize,
475        k: usize,
476        alpha: f64,
477        beta: f64,
478    ) {
479        debug_assert_eq!(a.len(), m * k);
480        debug_assert_eq!(b.len(), k * n);
481        debug_assert_eq!(c.len(), m * n);
482
483        unsafe {
484            matrixmultiply::dgemm(
485                m,
486                k,
487                n,
488                alpha,
489                a.as_ptr(),
490                k as isize,
491                1, // A: row-major (m x k)
492                b.as_ptr(),
493                n as isize,
494                1, // B: row-major (k x n)
495                beta,
496                c.as_mut_ptr(),
497                n as isize,
498                1, // C: row-major (m x n)
499            );
500        }
501    }
502
503    /// Performs f32 matrix multiplication: C = A @ B using optimized GEMM.
504    pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
505        Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
506    }
507
508    /// Performs f64 matrix multiplication: C = A @ B using optimized GEMM.
509    pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
510        Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
511    }
512
513    /// Transposes a matrix.
514    ///
515    /// A is (rows x cols), B is (cols x rows).
516    pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
517        debug_assert_eq!(src.len(), rows * cols);
518        debug_assert_eq!(dst.len(), rows * cols);
519
520        for i in 0..rows {
521            for j in 0..cols {
522                dst[j * rows + i] = src[i * cols + j];
523            }
524        }
525    }
526
527    /// Computes dot product of two vectors.
528    pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
529        debug_assert_eq!(a.len(), b.len());
530
531        let mut sum = T::zero();
532        for i in 0..a.len() {
533            sum = sum + a[i] * b[i];
534        }
535        sum
536    }
537}
538
539// =============================================================================
540// Comparison Operations
541// =============================================================================
542
543impl CpuBackend {
544    /// Element-wise equality comparison.
545    pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
546        debug_assert_eq!(a.len(), b.len());
547        debug_assert_eq!(a.len(), dst.len());
548
549        for i in 0..dst.len() {
550            dst[i] = a[i] == b[i];
551        }
552    }
553
554    /// Element-wise less-than comparison.
555    pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
556        debug_assert_eq!(a.len(), b.len());
557        debug_assert_eq!(a.len(), dst.len());
558
559        for i in 0..dst.len() {
560            dst[i] = a[i] < b[i];
561        }
562    }
563
564    /// Element-wise greater-than comparison.
565    pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
566        debug_assert_eq!(a.len(), b.len());
567        debug_assert_eq!(a.len(), dst.len());
568
569        for i in 0..dst.len() {
570            dst[i] = a[i] > b[i];
571        }
572    }
573}
574
575// =============================================================================
576// Fill Operations
577// =============================================================================
578
579impl CpuBackend {
580    /// Fills a slice with a value.
581    pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
582        for elem in dst.iter_mut() {
583            *elem = value;
584        }
585    }
586
587    /// Fills a slice with zeros.
588    pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
589        Self::fill(dst, T::zeroed());
590    }
591
592    /// Copies from source to destination.
593    pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
594        debug_assert_eq!(dst.len(), src.len());
595        dst.copy_from_slice(src);
596    }
597}
598
599// =============================================================================
600// Tests
601// =============================================================================
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise sum of two three-element vectors.
    #[test]
    fn test_add() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        let mut out = [0.0_f32; 3];

        CpuBackend::add(&mut out, &lhs, &rhs);

        assert_eq!(out, [5.0, 7.0, 9.0]);
    }

    /// Element-wise product against a constant vector of twos.
    #[test]
    fn test_mul() {
        let lhs = [2.0_f32, 3.0, 4.0];
        let rhs = [2.0_f32, 2.0, 2.0];
        let mut out = [0.0_f32; 3];

        CpuBackend::mul(&mut out, &lhs, &rhs);

        assert_eq!(out, [4.0, 6.0, 8.0]);
    }

    /// Negative and zero inputs map to zero; positives pass through.
    #[test]
    fn test_relu() {
        let input = [-1.0_f32, 0.0, 1.0, 2.0];
        let mut out = [0.0_f32; 4];

        CpuBackend::relu(&mut out, &input);

        assert_eq!(out, [0.0, 0.0, 1.0, 2.0]);
    }

    /// 1 + 2 + 3 + 4 = 10.
    #[test]
    fn test_sum() {
        let values = [1.0_f32, 2.0, 3.0, 4.0];

        assert_eq!(CpuBackend::sum(&values), 10.0);
    }

    /// Extrema of an unsorted vector.
    #[test]
    fn test_max_min() {
        let values = [1.0_f32, 4.0, 2.0, 3.0];

        assert_eq!(CpuBackend::max(&values), Some(4.0));
        assert_eq!(CpuBackend::min(&values), Some(1.0));
    }

    /// The largest value sits at index 1.
    #[test]
    fn test_argmax() {
        let values = [1.0_f32, 4.0, 2.0, 3.0];

        assert_eq!(CpuBackend::argmax(&values), Some(1));
    }

    /// [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
    #[test]
    fn test_matmul() {
        let lhs = [1.0_f32, 2.0, 3.0, 4.0];
        let rhs = [5.0_f32, 6.0, 7.0, 8.0];
        let mut product = [0.0_f32; 4];

        CpuBackend::matmul(&mut product, &lhs, &rhs, 2, 2, 2);

        assert_eq!(product, [19.0, 22.0, 43.0, 50.0]);
    }

    /// [[1, 2, 3], [4, 5, 6]] (2x3) becomes [[1, 4], [2, 5], [3, 6]] (3x2).
    #[test]
    fn test_transpose() {
        let matrix = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut transposed = [0.0_f32; 6];

        CpuBackend::transpose(&mut transposed, &matrix, 2, 3);

        assert_eq!(transposed, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];

        assert_eq!(CpuBackend::dot(&lhs, &rhs), 32.0);
    }

    /// Every slot receives the fill value.
    #[test]
    fn test_fill() {
        let mut buffer = [0.0_f32; 5];

        CpuBackend::fill(&mut buffer, 42.0);

        assert_eq!(buffer, [42.0; 5]);
    }
}