//! CPU backend: rayon-parallel element-wise kernels, serial reductions, and
//! `matrixmultiply`-backed GEMM, all operating on 64-byte-aligned host buffers.
use super::Backend;
use crate::device::DeviceCapabilities;
use crate::dtype::{Float, Numeric, Scalar};
use rayon::prelude::*;
use sysinfo::System;
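/// Element count above which element-wise kernels fan out across rayon's
/// thread pool; below it, the per-task overhead outweighs the parallelism.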
const PARALLEL_THRESHOLD: usize = 4096;
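/// Backend that executes all operations on the host CPU, using rayon for
/// data parallelism and `matrixmultiply` for f32/f64 matrix products.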
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;
impl CpuBackend {
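/// Creates a new CPU backend.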
#[must_use]
pub const fn new() -> Self {
Self
}
}
impl Backend for CpuBackend {
fn name(&self) -> &'static str {
"cpu"
}
fn is_available(&self) -> bool {
true
}
fn capabilities(&self) -> DeviceCapabilities {
DeviceCapabilities {
name: "CPU".to_string(),
total_memory: get_system_memory(),
available_memory: get_available_memory(),
supports_f16: true,
supports_f64: true,
max_threads_per_block: num_cpus(),
compute_capability: None,
}
}
fn allocate(&self, size: usize) -> *mut u8 {
if size == 0 {
return std::ptr::null_mut();
}
// Round the size up to a multiple of 64 bytes so every allocation is
// cache-line aligned and suitable for aligned SIMD loads/stores.
let aligned_size = (size + 63) & !63;
unsafe {
// SAFETY: 64 is a nonzero power of two and `aligned_size` is a nonzero
// multiple of 64, so the layout invariants hold. Callers must check for
// a null return, which signals allocation failure.
let layout = std::alloc::Layout::from_size_align_unchecked(aligned_size, 64);
std::alloc::alloc(layout)
}
}
fn deallocate(&self, ptr: *mut u8, size: usize) {
if ptr.is_null() || size == 0 {
return;
}
// Recompute the same rounded-up layout used in `allocate`; passing a
// different layout to `dealloc` would be undefined behavior.
let aligned_size = (size + 63) & !63;
unsafe {
let layout = std::alloc::Layout::from_size_align_unchecked(aligned_size, 64);
std::alloc::dealloc(ptr, layout);
}
}
// For the CPU backend, "device" memory is ordinary host memory, so all
// three copy directions reduce to a non-overlapping memcpy. Callers must
// guarantee that `src` and `dst` are valid for `size` bytes and disjoint.
fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
unsafe {
std::ptr::copy_nonoverlapping(src, dst, size);
}
}
fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
unsafe {
std::ptr::copy_nonoverlapping(src, dst, size);
}
}
fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
unsafe {
std::ptr::copy_nonoverlapping(src, dst, size);
}
}
fn synchronize(&self) {
// CPU operations complete eagerly, so there is nothing to wait on.
}
}
fn get_system_memory() -> usize {
// Refresh only the memory statistics; `System::new_all` would also scan
// processes, disks, and networks, which is wasted work here.
let mut sys = System::new();
sys.refresh_memory();
sys.total_memory() as usize
}
fn get_available_memory() -> usize {
let mut sys = System::new();
sys.refresh_memory();
sys.available_memory() as usize
}
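/// Number of logical CPUs, falling back to 1 if detection fails.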
fn num_cpus() -> usize {
std::thread::available_parallelism()
.map(std::num::NonZeroUsize::get)
.unwrap_or(1)
}
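// Element-wise arithmetic. Each kernel runs serially for small inputs and
// switches to rayon once `dst.len()` reaches `PARALLEL_THRESHOLD`.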
impl CpuBackend {
pub fn add<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(d, (a_val, b_val))| {
*d = *a_val + *b_val;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] + b[i];
}
}
}
pub fn sub<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(d, (a_val, b_val))| {
*d = *a_val - *b_val;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] - b[i];
}
}
}
pub fn mul<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(d, (a_val, b_val))| {
*d = *a_val * *b_val;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] * b[i];
}
}
}
pub fn div<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(d, (a_val, b_val))| {
*d = *a_val / *b_val;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] / b[i];
}
}
}
pub fn add_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = *a_val + scalar;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] + scalar;
}
}
}
pub fn mul_scalar<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T], scalar: T) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = *a_val * scalar;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] * scalar;
}
}
}
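/// Element-wise negation, written as `0 - x` so it works for `Numeric`
/// types that do not implement `Neg`.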
pub fn neg<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = T::zero() - *a_val;
});
} else {
for i in 0..dst.len() {
dst[i] = T::zero() - a[i];
}
}
}
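/// Element-wise absolute value, branching on sign rather than calling a
/// float-only `abs` so the kernel also works for signed integers.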
pub fn abs<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = if *a_val < T::zero() {
T::zero() - *a_val
} else {
*a_val
};
});
} else {
for i in 0..dst.len() {
dst[i] = if a[i] < T::zero() {
T::zero() - a[i]
} else {
a[i]
};
}
}
}
}
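// Unary floating-point kernels: activations and transcendental functions.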
impl CpuBackend {
pub fn relu<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = if *a_val > T::zero() {
*a_val
} else {
T::zero()
};
});
} else {
for i in 0..dst.len() {
dst[i] = if a[i] > T::zero() { a[i] } else { T::zero() };
}
}
}
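/// Logistic sigmoid: `1 / (1 + e^(-x))`.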
pub fn sigmoid<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = T::one() / (T::one() + (-*a_val).exp_value());
});
} else {
for i in 0..dst.len() {
dst[i] = T::one() / (T::one() + (-a[i]).exp_value());
}
}
}
pub fn tanh<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = a_val.tanh_value();
});
} else {
for i in 0..dst.len() {
dst[i] = a[i].tanh_value();
}
}
}
pub fn exp<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = a_val.exp_value();
});
} else {
for i in 0..dst.len() {
dst[i] = a[i].exp_value();
}
}
}
pub fn ln<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = a_val.ln_value();
});
} else {
for i in 0..dst.len() {
dst[i] = a[i].ln_value();
}
}
}
pub fn sqrt<T: Float + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = a_val.sqrt_value();
});
} else {
for i in 0..dst.len() {
dst[i] = a[i].sqrt_value();
}
}
}
pub fn square<T: Numeric + Sync + Send>(dst: &mut [T], a: &[T]) {
debug_assert_eq!(a.len(), dst.len());
if dst.len() >= PARALLEL_THRESHOLD {
dst.par_iter_mut().zip(a.par_iter()).for_each(|(d, a_val)| {
*d = *a_val * *a_val;
});
} else {
for i in 0..dst.len() {
dst[i] = a[i] * a[i];
}
}
}
}
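// Full reductions over a slice, computed serially.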
impl CpuBackend {
pub fn sum<T: Numeric>(a: &[T]) -> T {
let mut result = T::zero();
for &val in a {
result = result + val;
}
result
}
pub fn prod<T: Numeric>(a: &[T]) -> T {
let mut result = T::one();
for &val in a {
result = result * val;
}
result
}
pub fn max<T: Numeric>(a: &[T]) -> Option<T> {
if a.is_empty() {
return None;
}
let mut result = a[0];
for &val in &a[1..] {
if val > result {
result = val;
}
}
Some(result)
}
pub fn min<T: Numeric>(a: &[T]) -> Option<T> {
if a.is_empty() {
return None;
}
let mut result = a[0];
for &val in &a[1..] {
if val < result {
result = val;
}
}
Some(result)
}
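/// Arithmetic mean. Falls back to dividing by one if the length cannot be
/// represented in `T`, rather than panicking.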
pub fn mean<T: Float>(a: &[T]) -> Option<T> {
if a.is_empty() {
return None;
}
let sum = Self::sum(a);
let len = T::from(a.len()).unwrap_or(T::one());
Some(sum / len)
}
pub fn argmax<T: Numeric>(a: &[T]) -> Option<usize> {
if a.is_empty() {
return None;
}
let mut max_idx = 0;
let mut max_val = a[0];
for (i, &val) in a.iter().enumerate().skip(1) {
if val > max_val {
max_val = val;
max_idx = i;
}
}
Some(max_idx)
}
pub fn argmin<T: Numeric>(a: &[T]) -> Option<usize> {
if a.is_empty() {
return None;
}
let mut min_idx = 0;
let mut min_val = a[0];
for (i, &val) in a.iter().enumerate().skip(1) {
if val < min_val {
min_val = val;
min_idx = i;
}
}
Some(min_idx)
}
}
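// Dense linear algebra: GEMM, transpose, and dot product on row-major data.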
impl CpuBackend {
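/// Row-major matrix multiply: `c (m x n) = a (m x k) * b (k x n)`.
/// f32 and f64 are routed to `matrixmultiply`; other numeric types use a
/// cache-blocked triple loop.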
pub fn matmul<T: Numeric>(c: &mut [T], a: &[T], b: &[T], m: usize, n: usize, k: usize) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
use std::any::TypeId;
// Dispatch f32/f64 to the tuned matrixmultiply kernels; everything else
// falls through to the blocked generic loop below.
if TypeId::of::<T>() == TypeId::of::<f32>() {
unsafe {
// SAFETY: TypeId equality proves T is exactly f32, so reinterpreting
// the slices preserves layout and validity.
let a_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(a) as *const [f32]);
let b_f32: &[f32] = &*(std::ptr::from_ref::<[T]>(b) as *const [f32]);
let c_f32: &mut [f32] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f32]);
Self::matmul_f32(c_f32, a_f32, b_f32, m, n, k);
}
return;
}
if TypeId::of::<T>() == TypeId::of::<f64>() {
unsafe {
// SAFETY: as above, T is exactly f64.
let a_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(a) as *const [f64]);
let b_f64: &[f64] = &*(std::ptr::from_ref::<[T]>(b) as *const [f64]);
let c_f64: &mut [f64] = &mut *(std::ptr::from_mut::<[T]>(c) as *mut [f64]);
Self::matmul_f64(c_f64, a_f64, b_f64, m, n, k);
}
return;
}
// Cache-blocked i-p-j loop: hoisting `a[i * k + p]` out of the innermost
// loop gives a streaming access pattern over both B and C.
const BLOCK_SIZE: usize = 64;
for val in c.iter_mut() {
*val = T::zero();
}
for i0 in (0..m).step_by(BLOCK_SIZE) {
let i_end = (i0 + BLOCK_SIZE).min(m);
for p0 in (0..k).step_by(BLOCK_SIZE) {
let p_end = (p0 + BLOCK_SIZE).min(k);
for j0 in (0..n).step_by(BLOCK_SIZE) {
let j_end = (j0 + BLOCK_SIZE).min(n);
for i in i0..i_end {
for p in p0..p_end {
let a_val = a[i * k + p];
for j in j0..j_end {
c[i * n + j] = c[i * n + j] + a_val * b[p * n + j];
}
}
}
}
}
}
}
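/// Row-major single-precision GEMM: `c = alpha * a * b + beta * c`.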
pub fn sgemm(
c: &mut [f32],
a: &[f32],
b: &[f32],
m: usize,
n: usize,
k: usize,
alpha: f32,
beta: f32,
) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
unsafe {
// matrixmultiply takes (row stride, column stride) pairs; for row-major
// data the row stride is the row length and the column stride is 1.
matrixmultiply::sgemm(
m,
k,
n,
alpha,
a.as_ptr(),
k as isize,
1,
b.as_ptr(),
n as isize,
1,
beta,
c.as_mut_ptr(),
n as isize,
1,
);
}
}
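/// Row-major double-precision GEMM: `c = alpha * a * b + beta * c`.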
pub fn dgemm(
c: &mut [f64],
a: &[f64],
b: &[f64],
m: usize,
n: usize,
k: usize,
alpha: f64,
beta: f64,
) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(c.len(), m * n);
unsafe {
matrixmultiply::dgemm(
m,
k,
n,
alpha,
a.as_ptr(),
k as isize,
1,
b.as_ptr(),
n as isize,
1,
beta,
c.as_mut_ptr(),
n as isize,
1,
);
}
}
pub fn matmul_f32(c: &mut [f32], a: &[f32], b: &[f32], m: usize, n: usize, k: usize) {
Self::sgemm(c, a, b, m, n, k, 1.0, 0.0);
}
pub fn matmul_f64(c: &mut [f64], a: &[f64], b: &[f64], m: usize, n: usize, k: usize) {
Self::dgemm(c, a, b, m, n, k, 1.0, 0.0);
}
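/// Out-of-place transpose of a row-major `rows x cols` matrix into `dst`.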
pub fn transpose<T: Scalar>(dst: &mut [T], src: &[T], rows: usize, cols: usize) {
debug_assert_eq!(src.len(), rows * cols);
debug_assert_eq!(dst.len(), rows * cols);
for i in 0..rows {
for j in 0..cols {
dst[j * rows + i] = src[i * cols + j];
}
}
}
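/// Dot product: the sum of `a[i] * b[i]`.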
pub fn dot<T: Numeric>(a: &[T], b: &[T]) -> T {
debug_assert_eq!(a.len(), b.len());
let mut sum = T::zero();
for i in 0..a.len() {
sum = sum + a[i] * b[i];
}
sum
}
}
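// Element-wise comparisons producing boolean masks.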
impl CpuBackend {
pub fn eq<T: Scalar + PartialEq>(dst: &mut [bool], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
for i in 0..dst.len() {
dst[i] = a[i] == b[i];
}
}
pub fn lt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
for i in 0..dst.len() {
dst[i] = a[i] < b[i];
}
}
pub fn gt<T: Numeric>(dst: &mut [bool], a: &[T], b: &[T]) {
debug_assert_eq!(a.len(), b.len());
debug_assert_eq!(a.len(), dst.len());
for i in 0..dst.len() {
dst[i] = a[i] > b[i];
}
}
}
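// Buffer initialization and copying.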
impl CpuBackend {
pub fn fill<T: Scalar>(dst: &mut [T], value: T) {
dst.fill(value);
}
pub fn fill_zeros<T: Scalar>(dst: &mut [T]) {
Self::fill(dst, T::zeroed());
}
pub fn copy<T: Scalar>(dst: &mut [T], src: &[T]) {
debug_assert_eq!(dst.len(), src.len());
dst.copy_from_slice(src);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_add() {
let a = [1.0_f32, 2.0, 3.0];
let b = [4.0_f32, 5.0, 6.0];
let mut c = [0.0_f32; 3];
CpuBackend::add(&mut c, &a, &b);
assert_eq!(c, [5.0, 7.0, 9.0]);
}
#[test]
fn test_mul() {
let a = [2.0_f32, 3.0, 4.0];
let b = [2.0_f32, 2.0, 2.0];
let mut c = [0.0_f32; 3];
CpuBackend::mul(&mut c, &a, &b);
assert_eq!(c, [4.0, 6.0, 8.0]);
}
#[test]
fn test_relu() {
let a = [-1.0_f32, 0.0, 1.0, 2.0];
let mut b = [0.0_f32; 4];
CpuBackend::relu(&mut b, &a);
assert_eq!(b, [0.0, 0.0, 1.0, 2.0]);
}
#[test]
fn test_sum() {
let a = [1.0_f32, 2.0, 3.0, 4.0];
assert_eq!(CpuBackend::sum(&a), 10.0);
}
#[test]
fn test_max_min() {
let a = [1.0_f32, 4.0, 2.0, 3.0];
assert_eq!(CpuBackend::max(&a), Some(4.0));
assert_eq!(CpuBackend::min(&a), Some(1.0));
}
#[test]
fn test_argmax() {
let a = [1.0_f32, 4.0, 2.0, 3.0];
assert_eq!(CpuBackend::argmax(&a), Some(1));
}
#[test]
fn test_matmul() {
let a = [1.0_f32, 2.0, 3.0, 4.0];
let b = [5.0_f32, 6.0, 7.0, 8.0];
let mut c = [0.0_f32; 4];
CpuBackend::matmul(&mut c, &a, &b, 2, 2, 2);
assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_transpose() {
let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut b = [0.0_f32; 6];
CpuBackend::transpose(&mut b, &a, 2, 3);
assert_eq!(b, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
#[test]
fn test_dot() {
let a = [1.0_f32, 2.0, 3.0];
let b = [4.0_f32, 5.0, 6.0];
assert_eq!(CpuBackend::dot(&a, &b), 32.0);
}
#[test]
fn test_fill() {
let mut a = [0.0_f32; 5];
CpuBackend::fill(&mut a, 42.0);
assert_eq!(a, [42.0; 5]);
}
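// Two additional checks: sigmoid at zero and one (known values of the
// formula documented above), and the alpha/beta path of `sgemm`, which
// `matmul` never exercises since it always passes alpha = 1, beta = 0.
#[test]
fn test_sigmoid() {
let a = [0.0_f32, 1.0];
let mut b = [0.0_f32; 2];
CpuBackend::sigmoid(&mut b, &a);
assert!((b[0] - 0.5).abs() < 1e-6);
assert!((b[1] - 0.731_058_6).abs() < 1e-6);
}
#[test]
fn test_sgemm_alpha_beta() {
// With A = B = I, the result is alpha * I + beta * C0.
let a = [1.0_f32, 0.0, 0.0, 1.0];
let b = [1.0_f32, 0.0, 0.0, 1.0];
let mut c = [1.0_f32; 4];
CpuBackend::sgemm(&mut c, &a, &b, 2, 2, 2, 2.0, 1.0);
assert_eq!(c, [3.0, 1.0, 1.0, 3.0]);
}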
}