use super::compute_backend::{
ComputeBackend, ComputeBackendGeneric, DeviceType, Operation, PerformanceMetrics, ReduceOp,
TransferDirection,
};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
/// SIMD instruction-set extensions reported by the host CPU.
///
/// The [`Default`] value has every flag cleared; that is also what
/// [`CpuSimdFeatures::detect`] returns on targets that are not x86/x86_64.
#[derive(Debug, Clone, Default)]
pub struct CpuSimdFeatures {
    /// `true` when the CPU reports AVX2 support.
    pub avx2: bool,
    /// `true` when the CPU reports AVX-512F support.
    pub avx512f: bool,
    /// `true` when the CPU reports FMA support.
    pub fma: bool,
    /// `true` when the CPU reports SSE4.1 support.
    pub sse41: bool,
}

impl CpuSimdFeatures {
    /// Probes the running CPU for the SIMD extensions tracked by this struct.
    ///
    /// On `x86` and `x86_64` targets every flag is queried at runtime through
    /// `is_x86_feature_detected!`. On all other targets no probing is done and
    /// the all-`false` default is returned.
    pub fn detect() -> Self {
        // The detection macro is available on 32-bit x86 as well, so both
        // architectures are covered (the same gate `get_info` uses).
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            Self {
                avx2: is_x86_feature_detected!("avx2"),
                avx512f: is_x86_feature_detected!("avx512f"),
                fma: is_x86_feature_detected!("fma"),
                sse41: is_x86_feature_detected!("sse4.1"),
            }
        }
        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        {
            Self::default()
        }
    }
}
/// CPU implementation of the compute backend traits.
pub struct CpuBackend {
    /// SIMD capabilities detected once in `new`. Nothing visible in this file
    /// reads the field afterwards, hence the `dead_code` allowance.
    #[allow(dead_code)]
    simd_features: CpuSimdFeatures,
    /// Metrics written by the most recently executed operation and returned by
    /// `get_metrics`.
    last_metrics: Arc<Mutex<PerformanceMetrics>>,
    /// Arbitrary key/value settings stored through `set_config`.
    config: Arc<Mutex<HashMap<String, Box<dyn Any + Send + Sync>>>>,
}
impl CpuBackend {
    /// Builds a CPU backend with freshly detected SIMD features, zeroed
    /// metrics and an empty configuration map.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` return
    /// type is kept for interface compatibility.
    pub fn new() -> RusTorchResult<Self> {
        Ok(Self {
            simd_features: CpuSimdFeatures::detect(),
            last_metrics: Arc::new(Mutex::new(PerformanceMetrics {
                execution_time_ns: 0,
                memory_bandwidth_gbps: 0.0,
                device_utilization: 0.0,
                memory_usage_bytes: 0,
            })),
            config: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Borrows the tensor's elements as one flat slice.
    ///
    /// # Errors
    ///
    /// Returns `ComputationError` when `as_slice` yields `None`, instead of
    /// panicking inside a function whose signature promises a `Result`.
    fn contiguous_data<T>(tensor: &Tensor<T>) -> RusTorchResult<&[T]>
    where
        T: Float + Clone + Send + Sync + 'static,
    {
        tensor.as_slice().ok_or_else(|| {
            RusTorchError::ComputationError(format!(
                "Tensor with shape {:?} cannot be viewed as a flat slice",
                tensor.shape()
            ))
        })
    }

    /// Stores the metrics of the operation that just finished.
    ///
    /// `work` is the numerator of the throughput figure (bytes moved, or
    /// floating-point operations for matmul); it is divided by
    /// `elapsed * 1e9` and written to `memory_bandwidth_gbps`.
    fn record_metrics(&self, elapsed: std::time::Duration, work: f64, memory_usage_bytes: usize) {
        let seconds = elapsed.as_secs_f64();
        // A coarse timer can report zero elapsed time; report 0 instead of
        // letting the division produce `inf` or `NaN`.
        let rate = if seconds > 0.0 {
            work / (seconds * 1e9)
        } else {
            0.0
        };
        // The metrics are plain values, so a poisoned lock (another thread
        // panicked while holding it) is safe to recover and overwrite.
        let mut metrics = self
            .last_metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        metrics.execution_time_ns = u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX);
        metrics.memory_bandwidth_gbps = rate;
        metrics.device_utilization = 100.0;
        metrics.memory_usage_bytes = memory_usage_bytes;
    }

    /// Shared implementation of the element-wise binary operations.
    ///
    /// `verb` is spliced into the shape-mismatch message ("add", "multiply")
    /// and `op` combines one element of `a` with the matching element of `b`.
    fn execute_elementwise<T, F>(
        &self,
        a: &Tensor<T>,
        b: &Tensor<T>,
        verb: &str,
        op: F,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + Clone + Send + Sync + 'static,
        F: Fn(T, T) -> T,
    {
        let start = Instant::now();
        if a.shape() != b.shape() {
            return Err(RusTorchError::IncompatibleShapes(format!(
                "Cannot {} tensors with shapes {:?} and {:?}",
                verb,
                a.shape(),
                b.shape()
            )));
        }
        let lhs = Self::contiguous_data(a)?;
        let rhs = Self::contiguous_data(b)?;
        let result_data: Vec<T> = lhs.iter().zip(rhs).map(|(&x, &y)| op(x, y)).collect();
        let result = Tensor::from_vec(result_data, a.shape().to_vec());
        let elapsed = start.elapsed();

        let data_size = a.shape().iter().product::<usize>() * std::mem::size_of::<T>();
        // Throughput counts both inputs; memory usage counts inputs plus output.
        self.record_metrics(elapsed, data_size as f64 * 2.0, data_size * 3);
        Ok(result)
    }

    /// Element-wise sum of two tensors with identical shapes.
    ///
    /// # Errors
    ///
    /// `IncompatibleShapes` when the shapes differ; `ComputationError` when
    /// either operand cannot be viewed as a flat slice.
    fn execute_add<T>(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>>
    where
        T: Float + Clone + Send + Sync + 'static,
    {
        self.execute_elementwise(a, b, "add", |x, y| x + y)
    }

    /// Element-wise product of two tensors with identical shapes.
    ///
    /// # Errors
    ///
    /// `IncompatibleShapes` when the shapes differ; `ComputationError` when
    /// either operand cannot be viewed as a flat slice.
    fn execute_multiply<T>(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>>
    where
        T: Float + Clone + Send + Sync + 'static,
    {
        self.execute_elementwise(a, b, "multiply", |x, y| x * y)
    }

    /// Matrix product delegated to `Tensor::matmul`.
    ///
    /// For this operation `memory_bandwidth_gbps` holds the achieved
    /// floating-point rate (operations / elapsed / 1e9), not a byte rate.
    ///
    /// # Errors
    ///
    /// `ComputationError` wrapping whatever `Tensor::matmul` reports.
    fn execute_matmul<T>(&self, a: &Tensor<T>, b: &Tensor<T>) -> RusTorchResult<Tensor<T>>
    where
        T: Float
            + Clone
            + Send
            + Sync
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive,
    {
        let start = Instant::now();
        let result = a.matmul(b).map_err(|e| {
            RusTorchError::ComputationError(format!("Matrix multiplication failed: {:?}", e))
        })?;
        let elapsed = start.elapsed();

        // Every output element costs one multiply and one add per step along
        // the contracted dimension (the last axis of `a`). For 2-D operands
        // this equals m * k * n * 2, and unlike indexing `shape()[0]`/`[1]`
        // it cannot panic for operands of another rank.
        let inner_dim = a.shape().last().copied().unwrap_or(0);
        let flops = 2.0 * result.numel() as f64 * inner_dim as f64;
        let bytes = (a.numel() + b.numel() + result.numel()) * std::mem::size_of::<T>();
        self.record_metrics(elapsed, flops, bytes);
        Ok(result)
    }

    /// Reduces every element of `input` to a single value with `op`.
    ///
    /// The result always has shape `[1]`. `Max`/`Min` of an empty tensor yield
    /// zero. Only full reductions are implemented: `axes` must be `None` or
    /// name every axis of `input`.
    ///
    /// # Errors
    ///
    /// `UnsupportedOperation` when `axes` requests a partial reduction (rather
    /// than silently reducing over everything); `ComputationError` when the
    /// input cannot be viewed as a flat slice.
    fn execute_reduce<T>(
        &self,
        input: &Tensor<T>,
        op: ReduceOp,
        axes: Option<Vec<usize>>,
    ) -> RusTorchResult<Tensor<T>>
    where
        T: Float + Clone + Send + Sync + 'static,
    {
        let start = Instant::now();

        if let Some(requested) = axes.as_deref() {
            let rank = input.shape().len();
            let in_range = requested.iter().all(|&axis| axis < rank);
            let covers_all = (0..rank).all(|axis| requested.contains(&axis));
            if !(in_range && covers_all) {
                return Err(RusTorchError::UnsupportedOperation(format!(
                    "Reduce over axes {:?} of a tensor with shape {:?} is not implemented \
                     for the unified CPU backend; only full reductions are supported",
                    requested,
                    input.shape()
                )));
            }
        }

        let values = Self::contiguous_data(input)?;
        let reduced = match op {
            ReduceOp::Sum => values.iter().fold(T::zero(), |acc, &x| acc + x),
            ReduceOp::Mean => {
                let total = values.iter().fold(T::zero(), |acc, &x| acc + x);
                let count = T::from(input.numel()).ok_or_else(|| {
                    RusTorchError::ComputationError(
                        "Element count is not representable in the tensor's element type"
                            .to_string(),
                    )
                })?;
                total / count
            }
            ReduceOp::Max => values
                .iter()
                .copied()
                // Incomparable pairs (NaN) are treated as equal.
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(T::zero()),
            ReduceOp::Min => values
                .iter()
                .copied()
                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(T::zero()),
            ReduceOp::Product => values.iter().fold(T::one(), |acc, &x| acc * x),
        };
        let result = Tensor::from_vec(vec![reduced], vec![1]);
        let elapsed = start.elapsed();

        let data_size = input.numel() * std::mem::size_of::<T>();
        self.record_metrics(
            elapsed,
            data_size as f64,
            data_size + std::mem::size_of::<T>(),
        );
        Ok(result)
    }
}
impl ComputeBackendGeneric for CpuBackend {
fn execute_operation<T>(&self, operation: &Operation<T>) -> RusTorchResult<Tensor<T>>
where
T: Clone + Send + Sync + 'static + num_traits::Float,
{
match operation {
Operation::Add { a, b } => {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(a) };
let b_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(b) };
let result = self.execute_add(a_f32, b_f32)?;
Ok(unsafe { std::mem::transmute::<Tensor<f32>, Tensor<T>>(result) })
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let a_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(a) };
let b_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(b) };
let result = self.execute_add(a_f64, b_f64)?;
Ok(unsafe { std::mem::transmute::<Tensor<f64>, Tensor<T>>(result) })
} else {
Err(RusTorchError::UnsupportedOperation(
"Add operation only supports f32 and f64 types".to_string(),
))
}
}
Operation::Multiply { a, b } => {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(a) };
let b_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(b) };
let result = self.execute_multiply(a_f32, b_f32)?;
Ok(unsafe { std::mem::transmute::<Tensor<f32>, Tensor<T>>(result) })
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let a_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(a) };
let b_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(b) };
let result = self.execute_multiply(a_f64, b_f64)?;
Ok(unsafe { std::mem::transmute::<Tensor<f64>, Tensor<T>>(result) })
} else {
Err(RusTorchError::UnsupportedOperation(
"Multiply operation only supports f32 and f64 types".to_string(),
))
}
}
Operation::MatMul { a, b } => {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let a_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(a) };
let b_f32 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(b) };
let result = self.execute_matmul(a_f32, b_f32)?;
Ok(unsafe { std::mem::transmute::<Tensor<f32>, Tensor<T>>(result) })
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let a_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(a) };
let b_f64 = unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(b) };
let result = self.execute_matmul(a_f64, b_f64)?;
Ok(unsafe { std::mem::transmute::<Tensor<f64>, Tensor<T>>(result) })
} else {
Err(RusTorchError::UnsupportedOperation(
"MatMul operation only supports f32 and f64 types".to_string(),
))
}
}
Operation::Conv2D { .. } => Err(RusTorchError::UnsupportedOperation(
"Conv2D not implemented for unified CPU backend yet".to_string(),
)),
Operation::Reduce {
input,
operation: op,
axes,
} => {
if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
let input_f32 =
unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f32>>(input) };
let result = self.execute_reduce(input_f32, *op, axes.clone())?;
Ok(unsafe { std::mem::transmute::<Tensor<f32>, Tensor<T>>(result) })
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
let input_f64 =
unsafe { std::mem::transmute::<&Tensor<T>, &Tensor<f64>>(input) };
let result = self.execute_reduce(input_f64, *op, axes.clone())?;
Ok(unsafe { std::mem::transmute::<Tensor<f64>, Tensor<T>>(result) })
} else {
Err(RusTorchError::UnsupportedOperation(
"Reduce operation only supports f32 and f64 types".to_string(),
))
}
}
}
}
fn memory_transfer<T>(
&self,
data: &[T],
_direction: TransferDirection,
) -> RusTorchResult<Vec<T>>
where
T: Clone + Send + Sync + 'static + num_traits::Float,
{
Ok(data.to_vec())
}
}
impl ComputeBackend for CpuBackend {
    /// Always reports [`DeviceType::Cpu`].
    fn device_type(&self) -> DeviceType {
        DeviceType::Cpu
    }

    /// The CPU backend is unconditionally available.
    fn is_available(&self) -> bool {
        true
    }

    /// No initialization work is performed; always succeeds.
    fn initialize(&mut self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Returns a copy of the metrics recorded by the last executed operation.
    fn get_metrics(&self) -> PerformanceMetrics {
        // The metrics are plain values, so a poisoned lock can be read safely
        // instead of propagating another thread's panic.
        self.last_metrics
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Describes the device as a map with three entries:
    ///
    /// * `"device_name"`: `String`, always `"CPU"`.
    /// * `"simd_features"`: `String`, the best SIMD level found.
    /// * `"cores"`: `usize`, as reported by `num_cpus::get()`.
    fn get_info(&self) -> HashMap<String, Box<dyn Any + Send + Sync>> {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        let simd_label = if std::arch::is_x86_feature_detected!("avx2") {
            "AVX2"
        } else if std::arch::is_x86_feature_detected!("avx") {
            "AVX"
        } else if std::arch::is_x86_feature_detected!("sse4.1") {
            "SSE4.1"
        } else {
            "None"
        };
        #[cfg(target_family = "wasm")]
        let simd_label = "WASM";
        // Previously every non-x86 target (e.g. aarch64) was labelled "WASM".
        // Report NEON when the build enables it, otherwise no SIMD level.
        #[cfg(not(any(
            target_arch = "x86",
            target_arch = "x86_64",
            target_family = "wasm"
        )))]
        let simd_label = if cfg!(target_feature = "neon") {
            "NEON"
        } else {
            "None"
        };

        let mut info: HashMap<String, Box<dyn Any + Send + Sync>> = HashMap::new();
        info.insert("device_name".to_string(), Box::new("CPU".to_string()));
        info.insert(
            "simd_features".to_string(),
            Box::new(simd_label.to_string()),
        );
        info.insert("cores".to_string(), Box::new(num_cpus::get()));
        info
    }

    /// Nothing to wait for in this backend; always succeeds.
    fn synchronize(&self) -> RusTorchResult<()> {
        Ok(())
    }

    /// Returns a fixed placeholder (`usize::MAX / 2`); the system is not
    /// queried for its actual free memory.
    fn available_memory(&self) -> RusTorchResult<usize> {
        Ok(usize::MAX / 2)
    }

    /// Stores `value` under `key`, replacing any previous entry.
    fn set_config(&mut self, key: &str, value: Box<dyn Any + Send + Sync>) -> RusTorchResult<()> {
        // A map insert leaves the map consistent even if an earlier holder of
        // the lock panicked, so a poisoned lock is recovered.
        self.config
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(key.to_string(), value);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::compute_backend::Operation;

    /// A new backend reports itself as an available CPU device.
    #[test]
    fn test_unified_cpu_backend_creation() {
        let cpu = CpuBackend::new().unwrap();
        assert_eq!(cpu.device_type(), DeviceType::Cpu);
        assert!(cpu.is_available());
    }

    /// Detection must run without panicking; the result is only printed.
    #[test]
    fn test_simd_feature_detection() {
        let features = CpuSimdFeatures::detect();
        println!("SIMD features: {:?}", features);
    }

    /// Element-wise addition yields the expected values and records timing.
    #[test]
    fn test_unified_cpu_add_operation() {
        let cpu = CpuBackend::new().unwrap();
        let add = Operation::Add {
            a: Tensor::from_vec(vec![1.0f32, 2.0, 3.0], vec![3]),
            b: Tensor::from_vec(vec![4.0f32, 5.0, 6.0], vec![3]),
        };

        let sum = cpu.execute_operation(&add).unwrap();
        let expected = vec![5.0f32, 7.0, 9.0];
        assert_eq!(sum.as_slice().unwrap(), &expected);

        let metrics = cpu.get_metrics();
        assert!(metrics.execution_time_ns > 0);
    }

    /// A 2x2 by 2x2 product keeps the 2x2 shape and records metrics.
    #[test]
    fn test_unified_cpu_matmul_operation() {
        let cpu = CpuBackend::new().unwrap();
        let matmul = Operation::MatMul {
            a: Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]),
            b: Tensor::from_vec(vec![5.0f32, 6.0, 7.0, 8.0], vec![2, 2]),
        };

        let product = cpu.execute_operation(&matmul).unwrap();
        assert_eq!(product.shape(), &[2, 2]);

        let metrics = cpu.get_metrics();
        assert!(metrics.execution_time_ns > 0);
        assert!(metrics.memory_usage_bytes > 0);
    }

    /// Full `Sum` and `Mean` reductions over a four-element vector.
    #[test]
    fn test_unified_cpu_reduce_operation() {
        let cpu = CpuBackend::new().unwrap();
        let values = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![4]);

        let sum_op = Operation::Reduce {
            input: values.clone(),
            operation: ReduceOp::Sum,
            axes: None,
        };
        let total = cpu.execute_operation(&sum_op).unwrap();
        assert_eq!(total.as_slice().unwrap(), &[10.0f32]);

        let mean_op = Operation::Reduce {
            input: values,
            operation: ReduceOp::Mean,
            axes: None,
        };
        let average = cpu.execute_operation(&mean_op).unwrap();
        assert_eq!(average.as_slice().unwrap(), &[2.5f32]);
    }
}