use std::fmt;
/// Errors surfaced by compute-backend implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not available on this backend.
    Unsupported(String),
    /// The underlying device reported a failure.
    DeviceError(String),
    /// A caller-supplied argument was rejected.
    InvalidArgument(String),
    /// A device allocation could not be satisfied.
    OutOfMemory,
    /// The backend was used before `init` succeeded.
    NotInitialized,
}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Unit variants are fixed strings; payload variants append a detail message.
        match self {
            Self::OutOfMemory => f.write_str("out of device memory"),
            Self::NotInitialized => f.write_str("backend not initialized"),
            Self::Unsupported(detail) => write!(f, "unsupported operation: {detail}"),
            Self::DeviceError(detail) => write!(f, "device error: {detail}"),
            Self::InvalidArgument(detail) => write!(f, "invalid argument: {detail}"),
        }
    }
}

impl std::error::Error for BackendError {}

/// Convenience alias used by every fallible backend API.
pub type BackendResult<T> = Result<T, BackendError>;
/// BLAS-style transpose selector for a matrix operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendTranspose {
    /// Use the operand as stored.
    NoTrans,
    /// Use the transpose of the operand.
    Trans,
    /// Use the conjugate transpose of the operand.
    ConjTrans,
}

impl fmt::Display for BackendTranspose {
    /// Formats as the single-letter BLAS code: `N`, `T` or `C`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            Self::NoTrans => "N",
            Self::Trans => "T",
            Self::ConjTrans => "C",
        };
        f.write_str(code)
    }
}
/// Reduction operator applied along one axis of a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
    Mean,
}

impl fmt::Display for ReduceOp {
    /// Formats as the lowercase operator name (e.g. `sum`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Sum => "sum",
            Self::Max => "max",
            Self::Min => "min",
            Self::Mean => "mean",
        })
    }
}
/// Element-wise unary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Abs,
    Neg,
}

impl fmt::Display for UnaryOp {
    /// Formats as the lowercase operator name (e.g. `relu`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Relu => "relu",
            Self::Sigmoid => "sigmoid",
            Self::Tanh => "tanh",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Sqrt => "sqrt",
            Self::Abs => "abs",
            Self::Neg => "neg",
        })
    }
}
/// Element-wise binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
}

impl fmt::Display for BinaryOp {
    /// Formats as the lowercase operator name (e.g. `add`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Max => "max",
            Self::Min => "min",
        })
    }
}
/// Hardware-abstraction layer for tensor compute.
///
/// Implementations execute kernels on buffers identified by opaque `u64`
/// device pointers obtained from [`ComputeBackend::alloc`]. Matrix
/// dimensions, leading dimensions and batch strides are expressed in
/// elements, following BLAS conventions.
pub trait ComputeBackend: Send + Sync + fmt::Debug {
    /// Short human-readable identifier for this backend.
    fn name(&self) -> &str;

    /// Prepares the backend for use; must succeed before other calls.
    fn init(&mut self) -> BackendResult<()>;

    /// Reports whether [`ComputeBackend::init`] has completed successfully.
    fn is_initialized(&self) -> bool;

    /// General matrix multiply: `C = alpha * op(A) * op(B) + beta * C`,
    /// where `op` is chosen by the transpose flags and `C` is `m x n`.
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()>;

    /// 2-D convolution forward pass.
    ///
    /// NOTE(review): the layout of `input_shape`/`filter_shape` (NCHW vs
    /// NHWC) is not fixed by this file — confirm against implementations.
    #[allow(clippy::too_many_arguments)]
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()>;

    /// Scaled dot-product attention over `batch * heads` problem instances,
    /// optionally applying a causal mask.
    #[allow(clippy::too_many_arguments)]
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()>;

    /// Reduces `input` along `axis` with the given reduction operator.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()>;

    /// Applies `op` element-wise to `n` elements of `input`.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;

    /// Applies `op` element-wise to `n` pairs drawn from `a` and `b`.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;

    /// Strided batched GEMM: one GEMM per batch, with operand `i` located
    /// `i * stride_{a,b,c}` elements past the corresponding base pointer.
    ///
    /// The default implementation issues `batch_count` sequential
    /// [`ComputeBackend::gemm`] calls; backends with a native batched kernel
    /// should override it.
    ///
    /// NOTE(review): strides are converted to byte offsets assuming 4-byte
    /// (f32) elements — confirm this matches the element type the kernels
    /// actually operate on.
    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        stride_a: usize,
        b_ptr: u64,
        ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        const ELEM_BYTES: u64 = 4;
        // Byte offset of batch `i` for an operand with the given element stride.
        let offset = |stride: usize, i: u64| i * stride as u64 * ELEM_BYTES;
        for batch in 0..batch_count {
            let i = batch as u64;
            self.gemm(
                trans_a,
                trans_b,
                m,
                n,
                k,
                alpha,
                a_ptr + offset(stride_a, i),
                lda,
                b_ptr + offset(stride_b, i),
                ldb,
                beta,
                c_ptr + offset(stride_c, i),
                ldc,
            )?;
        }
        Ok(())
    }

    /// Blocks until all previously issued device work has finished.
    fn synchronize(&self) -> BackendResult<()>;

    /// Allocates `bytes` of device memory and returns an opaque pointer.
    fn alloc(&self, bytes: usize) -> BackendResult<u64>;

    /// Releases memory previously returned by [`ComputeBackend::alloc`].
    fn free(&self, ptr: u64) -> BackendResult<()>;

    /// Copies host bytes in `src` into the device buffer at `dst`.
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;

    /// Copies device bytes at `src` into the host slice `dst`.
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn backend_error_display() {
        // Every variant renders to its fixed user-facing string.
        let cases = [
            (
                BackendError::Unsupported("foo".into()),
                "unsupported operation: foo",
            ),
            (BackendError::DeviceError("bar".into()), "device error: bar"),
            (
                BackendError::InvalidArgument("baz".into()),
                "invalid argument: baz",
            ),
            (BackendError::OutOfMemory, "out of device memory"),
            (BackendError::NotInitialized, "backend not initialized"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn backend_error_is_std_error() {
        // BackendError must be usable as a boxed std error.
        let boxed: Box<dyn std::error::Error> =
            Box::new(BackendError::DeviceError("test".into()));
        assert!(boxed.to_string().contains("test"));
    }

    #[test]
    fn backend_transpose_display_and_values() {
        for (trans, code) in [
            (BackendTranspose::NoTrans, "N"),
            (BackendTranspose::Trans, "T"),
            (BackendTranspose::ConjTrans, "C"),
        ] {
            assert_eq!(trans.to_string(), code);
        }
        assert_eq!(BackendTranspose::NoTrans, BackendTranspose::NoTrans);
        assert_ne!(BackendTranspose::NoTrans, BackendTranspose::Trans);
    }

    #[test]
    fn reduce_op_display_and_coverage() {
        for (op, name) in [
            (ReduceOp::Sum, "sum"),
            (ReduceOp::Max, "max"),
            (ReduceOp::Min, "min"),
            (ReduceOp::Mean, "mean"),
        ] {
            assert_eq!(op.to_string(), name);
        }
    }

    #[test]
    fn unary_op_display_and_coverage() {
        for (op, name) in [
            (UnaryOp::Relu, "relu"),
            (UnaryOp::Sigmoid, "sigmoid"),
            (UnaryOp::Tanh, "tanh"),
            (UnaryOp::Exp, "exp"),
            (UnaryOp::Log, "log"),
            (UnaryOp::Sqrt, "sqrt"),
            (UnaryOp::Abs, "abs"),
            (UnaryOp::Neg, "neg"),
        ] {
            assert_eq!(op.to_string(), name);
        }
    }

    #[test]
    fn binary_op_display_and_coverage() {
        for (op, name) in [
            (BinaryOp::Add, "add"),
            (BinaryOp::Sub, "sub"),
            (BinaryOp::Mul, "mul"),
            (BinaryOp::Div, "div"),
            (BinaryOp::Max, "max"),
            (BinaryOp::Min, "min"),
        ] {
            assert_eq!(op.to_string(), name);
        }
    }

    /// No-op backend that counts `gemm` invocations so the default
    /// `batched_gemm` fan-out can be observed.
    #[derive(Debug)]
    struct MockBackend {
        gemm_calls: AtomicUsize,
    }

    impl MockBackend {
        fn new() -> Self {
            Self {
                gemm_calls: AtomicUsize::new(0),
            }
        }

        fn gemm_count(&self) -> usize {
            self.gemm_calls.load(Ordering::Relaxed)
        }
    }

    impl ComputeBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn init(&mut self) -> BackendResult<()> {
            Ok(())
        }

        fn is_initialized(&self) -> bool {
            true
        }

        fn gemm(
            &self,
            _trans_a: BackendTranspose,
            _trans_b: BackendTranspose,
            _m: usize,
            _n: usize,
            _k: usize,
            _alpha: f64,
            _a_ptr: u64,
            _lda: usize,
            _b_ptr: u64,
            _ldb: usize,
            _beta: f64,
            _c_ptr: u64,
            _ldc: usize,
        ) -> BackendResult<()> {
            self.gemm_calls.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn conv2d_forward(
            &self,
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: &[usize],
            _: &[usize],
        ) -> BackendResult<()> {
            Ok(())
        }

        fn attention(
            &self,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: f64,
            _: bool,
        ) -> BackendResult<()> {
            Ok(())
        }

        fn reduce(&self, _: ReduceOp, _: u64, _: u64, _: &[usize], _: usize) -> BackendResult<()> {
            Ok(())
        }

        fn unary(&self, _: UnaryOp, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }

        fn binary(&self, _: BinaryOp, _: u64, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }

        fn synchronize(&self) -> BackendResult<()> {
            Ok(())
        }

        fn alloc(&self, _: usize) -> BackendResult<u64> {
            Ok(0)
        }

        fn free(&self, _: u64) -> BackendResult<()> {
            Ok(())
        }

        fn copy_htod(&self, _: u64, _: &[u8]) -> BackendResult<()> {
            Ok(())
        }

        fn copy_dtoh(&self, _: &mut [u8], _: u64) -> BackendResult<()> {
            Ok(())
        }
    }

    #[test]
    fn batched_gemm_zero_batch_is_noop() {
        // batch_count == 0 must succeed without touching gemm.
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            0,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_count(), 0);
    }

    #[test]
    fn batched_gemm_default_calls_gemm_n_times() {
        // The default implementation fans out to exactly one gemm per batch.
        let backend = MockBackend::new();
        let batch_count = 7;
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::Trans,
            8,
            8,
            8,
            1.0,
            1000,
            8,
            64,
            2000,
            8,
            64,
            0.0,
            3000,
            8,
            64,
            batch_count,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_count(), batch_count);
    }

    #[test]
    fn batched_gemm_single_batch() {
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            16,
            16,
            16,
            1.0,
            0,
            16,
            256,
            0,
            16,
            256,
            1.0,
            0,
            16,
            256,
            1,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_count(), 1);
    }

    #[test]
    fn enum_clone_and_hash() {
        use std::collections::HashSet;
        // Hash + Eq allow the op enums to key sets/maps.
        let set: HashSet<ReduceOp> = [ReduceOp::Sum, ReduceOp::Max].into_iter().collect();
        assert!(set.contains(&ReduceOp::Sum));
        assert!(!set.contains(&ReduceOp::Min));
        // Copy semantics: the source stays usable after a bitwise copy.
        let op = UnaryOp::Relu;
        let op_copy = op;
        assert_eq!(op, op_copy);
        let bop = BinaryOp::Add;
        let bop_copy = bop;
        assert_eq!(bop, bop_copy);
        let trans = BackendTranspose::ConjTrans;
        let trans_copy = trans;
        assert_eq!(trans, trans_copy);
    }
}