use serde::{Deserialize, Serialize};
#[cfg(feature = "trueno-integration")]
use trueno::{Matrix, Vector};
/// Execution backend that the selection heuristics in this file can return.
///
/// The type is `Copy`, comparable for equality, and (de)serializable through serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// The variants `SIMD` and `GPU` are spelled as upper-case acronyms, so the
// corresponding clippy lint is silenced for this enum.
#[allow(clippy::upper_case_acronyms)]
pub enum Backend {
/// Scalar execution path; the size-threshold heuristics return it for the smallest inputs.
Scalar,
/// SIMD execution path; also the fallback of the cost-model based selection.
SIMD,
/// GPU execution path.
GPU,
}
impl std::fmt::Display for Backend {
    /// Writes the variant name exactly as it is spelled in the enum
    /// (`Scalar`, `SIMD` or `GPU`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Backend::Scalar => "Scalar",
            Backend::SIMD => "SIMD",
            Backend::GPU => "GPU",
        };
        // Emit the raw label; width/alignment flags are not applied.
        f.write_str(label)
    }
}
/// Coarse complexity class of an operation, consumed by
/// `BackendSelector::select_with_moe` to pick its size thresholds.
///
/// The derived ordering follows declaration order: `Low < Medium < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpComplexity {
/// Lowest class; `select_with_moe` never returns the GPU backend for it.
Low,
/// Middle class; GPU is chosen above 100_000 elements.
Medium,
/// Highest class; GPU is chosen above 10_000 elements.
High,
}
/// Chooses a [`Backend`] for an operation, either from a simple
/// transfer-time vs. compute-time cost model or from fixed size thresholds.
pub struct BackendSelector {
// Divisor applied to a byte count to estimate transfer time (default 32e9),
// i.e. it is used as bytes per second.
pcie_bandwidth: f64,
// Divisor applied to a raw FLOP count to estimate compute time (default 20e12).
// NOTE(review): despite the name, the arithmetic treats this as FLOP/s, not GFLOP/s.
gpu_gflops: f64,
// GPU is selected only when estimated compute time exceeds this multiple of
// the estimated transfer time (default 5.0).
min_dispatch_ratio: f64,
}
impl Default for BackendSelector {
fn default() -> Self {
Self {
pcie_bandwidth: 32e9, gpu_gflops: 20e12, min_dispatch_ratio: 5.0, }
}
}
impl BackendSelector {
    /// Creates a selector with the default hardware model (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the transfer rate used to estimate host/device copy time.
    ///
    /// # Panics
    /// Panics if `bandwidth` is not strictly greater than zero (NaN included).
    pub fn with_pcie_bandwidth(mut self, bandwidth: f64) -> Self {
        assert!(bandwidth > 0.0, "PCIe bandwidth must be > 0");
        self.pcie_bandwidth = bandwidth;
        self
    }

    /// Replaces the compute rate used to estimate GPU execution time.
    ///
    /// NOTE(review): the value is divided into a raw FLOP count, so it is
    /// effectively FLOP/s (the default is `20e12`), not GFLOP/s.
    ///
    /// # Panics
    /// Panics if `gflops` is not strictly greater than zero (NaN included).
    pub fn with_gpu_gflops(mut self, gflops: f64) -> Self {
        assert!(gflops > 0.0, "GPU GFLOPS must be > 0");
        self.gpu_gflops = gflops;
        self
    }

    /// Replaces the minimum compute-time / transfer-time ratio required
    /// before [`select_backend`](Self::select_backend) picks the GPU.
    ///
    /// The value is stored without validation: a NaN ratio makes the GPU
    /// comparison always false, and a non-positive ratio makes it true for
    /// any operation with non-zero estimated compute time.
    pub fn with_min_dispatch_ratio(mut self, ratio: f64) -> Self {
        self.min_dispatch_ratio = ratio;
        self
    }

    /// Picks a backend from a transfer-vs-compute cost model.
    ///
    /// Transfer time is estimated as `data_bytes / pcie_bandwidth` and compute
    /// time as `flops / gpu_gflops`. Returns [`Backend::GPU`] when compute time
    /// is strictly greater than `min_dispatch_ratio` times the transfer time,
    /// otherwise [`Backend::SIMD`]. This method never returns `Backend::Scalar`.
    pub fn select_backend(&self, data_bytes: usize, flops: u64) -> Backend {
        let transfer_s = data_bytes as f64 / self.pcie_bandwidth;
        let compute_s = flops as f64 / self.gpu_gflops;
        if compute_s > self.min_dispatch_ratio * transfer_s {
            Backend::GPU
        } else {
            Backend::SIMD
        }
    }

    /// Picks a backend for an `m x k` by `k x n` matrix multiplication.
    ///
    /// The cost model is fed `(m*k + k*n + m*n) * 4` bytes (4 bytes per
    /// element) and `2*m*n*k` floating-point operations. All products
    /// saturate instead of overflowing, so very large dimensions no longer
    /// panic in debug builds or wrap to a bogus small value in release builds.
    pub fn select_for_matmul(&self, m: usize, n: usize, k: usize) -> Backend {
        const BYTES_PER_ELEMENT: usize = 4;

        let elements = m
            .saturating_mul(k)
            .saturating_add(k.saturating_mul(n))
            .saturating_add(m.saturating_mul(n));
        let data_bytes = elements.saturating_mul(BYTES_PER_ELEMENT);

        // Widen before multiplying: `usize` is at most 64 bits wide, so the
        // casts are lossless, and the product is computed in `u64`.
        let flops = 2u64
            .saturating_mul(m as u64)
            .saturating_mul(n as u64)
            .saturating_mul(k as u64);

        self.select_backend(data_bytes, flops)
    }

    /// Picks a backend for a vector operation over `n` elements that performs
    /// `ops_per_element` floating-point operations per element.
    ///
    /// The cost model is fed `n * 3 * 4` bytes (three buffers of 4-byte
    /// elements; presumably two inputs and one output) and
    /// `n * ops_per_element` operations. Both products saturate instead of
    /// overflowing.
    pub fn select_for_vector_op(&self, n: usize, ops_per_element: u64) -> Backend {
        const BUFFERS: usize = 3;
        const BYTES_PER_ELEMENT: usize = 4;

        let data_bytes = n.saturating_mul(BUFFERS * BYTES_PER_ELEMENT);
        // Lossless widening, see `select_for_matmul`.
        let flops = (n as u64).saturating_mul(ops_per_element);

        self.select_backend(data_bytes, flops)
    }

    /// Picks a backend for an element-wise operation over `n` elements:
    /// [`Backend::SIMD`] above 1_000_000 elements, [`Backend::Scalar`]
    /// otherwise. The GPU is never selected here.
    pub fn select_for_elementwise(&self, n: usize) -> Backend {
        if n > 1_000_000 {
            Backend::SIMD
        } else {
            Backend::Scalar
        }
    }

    /// Picks a backend from fixed size thresholds that depend on `complexity`.
    ///
    /// | complexity | GPU above | SIMD above | otherwise |
    /// |------------|-----------|------------|-----------|
    /// | `Low`      | never     | 1_000_000  | Scalar    |
    /// | `Medium`   | 100_000   | 10_000     | Scalar    |
    /// | `High`     | 10_000    | 1_000      | Scalar    |
    ///
    /// All comparisons are strict (`data_size > threshold`).
    pub fn select_with_moe(&self, complexity: OpComplexity, data_size: usize) -> Backend {
        // (GPU threshold, SIMD threshold); `None` means the GPU is never used.
        let (gpu_above, simd_above): (Option<usize>, usize) = match complexity {
            OpComplexity::Low => (None, 1_000_000),
            OpComplexity::Medium => (Some(100_000), 10_000),
            OpComplexity::High => (Some(10_000), 1_000),
        };

        match gpu_above {
            Some(limit) if data_size > limit => Backend::GPU,
            _ if data_size > simd_above => Backend::SIMD,
            _ => Backend::Scalar,
        }
    }

    /// Converts this file's [`Backend`] into the `trueno` backend enum.
    ///
    /// `Backend::SIMD` maps to `trueno::Backend::Auto`; the other two variants
    /// map to the variant of the same name.
    #[cfg(feature = "trueno-integration")]
    pub fn to_trueno_backend(backend: Backend) -> trueno::Backend {
        match backend {
            Backend::Scalar => trueno::Backend::Scalar,
            Backend::SIMD => trueno::Backend::Auto,
            Backend::GPU => trueno::Backend::GPU,
        }
    }

    /// Adds two equally sized `f32` slices element-wise using `trueno`.
    ///
    /// # Errors
    /// Returns `"Vector lengths must match"` when the slices differ in length,
    /// or a `"Trueno error: ..."` message when the underlying addition fails.
    #[cfg(feature = "trueno-integration")]
    pub fn vector_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, String> {
        if a.len() != b.len() {
            return Err("Vector lengths must match".to_string());
        }

        // The selection is advisory only: it is computed but not forwarded to
        // trueno, which is called through its default `add` entry point.
        let _backend = self.select_with_moe(OpComplexity::Low, a.len());

        let vec_a: Vector<f32> = Vector::from_slice(a);
        let vec_b: Vector<f32> = Vector::from_slice(b);
        vec_a
            .add(&vec_b)
            .map(|result| result.as_slice().to_vec())
            .map_err(|e| format!("Trueno error: {}", e))
    }

    /// Multiplies an `m x k` matrix `a` by a `k x n` matrix `b` using `trueno`
    /// and returns the product as a flat `Vec<f32>`.
    ///
    /// # Errors
    /// Returns an error when `m * k` or `k * n` overflows `usize`, when
    /// `a.len() != m * k` or `b.len() != k * n`, or when `trueno` fails to
    /// build either matrix or to multiply them.
    #[cfg(feature = "trueno-integration")]
    pub fn matrix_multiply(
        &self,
        a: &[f32],
        b: &[f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<Vec<f32>, String> {
        // Checked products: an overflowing dimension pair can never describe a
        // real slice, so report it instead of panicking or wrapping.
        let expected_a = m
            .checked_mul(k)
            .ok_or_else(|| format!("Matrix A dimensions overflow: {} x {}", m, k))?;
        if a.len() != expected_a {
            return Err(format!(
                "Matrix A size mismatch: expected {}, got {}",
                expected_a,
                a.len()
            ));
        }

        let expected_b = k
            .checked_mul(n)
            .ok_or_else(|| format!("Matrix B dimensions overflow: {} x {}", k, n))?;
        if b.len() != expected_b {
            return Err(format!(
                "Matrix B size mismatch: expected {}, got {}",
                expected_b,
                b.len()
            ));
        }

        // Advisory only, see `vector_add`.
        let _backend = self.select_for_matmul(m, n, k);

        let mat_a: Matrix<f32> = Matrix::from_vec(m, k, a.to_vec())
            .map_err(|e| format!("Trueno error creating matrix A: {}", e))?;
        let mat_b: Matrix<f32> = Matrix::from_vec(k, n, b.to_vec())
            .map_err(|e| format!("Trueno error creating matrix B: {}", e))?;

        mat_a
            .matmul(&mat_b)
            .map(|result| result.as_slice().to_vec())
            .map_err(|e| format!("Trueno error in matmul: {}", e))
    }
}
// Unit tests for this module are loaded from `backend_tests.rs` and are only
// compiled in test builds.
#[cfg(test)]
#[path = "backend_tests.rs"]
mod tests;