use std::fmt;
/// Error type returned by every fallible [`ComputeBackend`] operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not supported; the payload describes it.
    Unsupported(String),
    /// A failure reported by the device; the payload carries the details.
    DeviceError(String),
    /// An argument was rejected; the payload explains why.
    InvalidArgument(String),
    /// Rendered as "out of device memory".
    OutOfMemory,
    /// Rendered as "backend not initialized".
    NotInitialized,
}

impl fmt::Display for BackendError {
    /// Renders a short, lowercase description. Variants that carry a message
    /// append it after a `": "` separator.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            BackendError::Unsupported(ref msg) => write!(f, "unsupported operation: {msg}"),
            BackendError::DeviceError(ref msg) => write!(f, "device error: {msg}"),
            BackendError::InvalidArgument(ref msg) => write!(f, "invalid argument: {msg}"),
            // Unit variants have nothing to interpolate, so the literal is written directly.
            BackendError::OutOfMemory => f.write_str("out of device memory"),
            BackendError::NotInitialized => f.write_str("backend not initialized"),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (which returns `None`) applies.
impl std::error::Error for BackendError {}

/// Shorthand for results whose error type is [`BackendError`].
pub type BackendResult<T> = Result<T, BackendError>;
/// Transpose mode for a matrix operand; it is the type of the `trans_a` /
/// `trans_b` arguments of [`ComputeBackend::gemm`].
///
/// `Display` renders the one-letter codes `N`, `T` and `C`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendTranspose {
    /// No transpose (`N`).
    NoTrans,
    /// Transpose (`T`).
    Trans,
    /// Conjugate transpose (`C`).
    ConjTrans,
}

impl fmt::Display for BackendTranspose {
    /// Writes the one-letter code through [`fmt::Formatter::pad`], so width,
    /// fill, alignment and precision flags (e.g. `{:>3}`) are honored. A plain
    /// `write!` with a literal would silently ignore them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            Self::NoTrans => "N",
            Self::Trans => "T",
            Self::ConjTrans => "C",
        };
        f.pad(code)
    }
}
/// Reduction selected by the `op` argument of [`ComputeBackend::reduce`],
/// which reduces along a single `axis`.
///
/// `Display` renders the lowercase name (`sum`, `max`, `min`, `mean`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
    Mean,
}

impl fmt::Display for ReduceOp {
    /// Writes the lowercase name through [`fmt::Formatter::pad`], so width,
    /// fill, alignment and precision flags are honored (a plain `write!` with a
    /// literal ignores them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Sum => "sum",
            Self::Max => "max",
            Self::Min => "min",
            Self::Mean => "mean",
        };
        f.pad(name)
    }
}
/// Operation selected by the `op` argument of [`ComputeBackend::unary`].
///
/// `Display` renders the lowercase name (e.g. `relu`, `sigmoid`, `neg`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Relu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Abs,
    Neg,
}

impl fmt::Display for UnaryOp {
    /// Writes the lowercase name through [`fmt::Formatter::pad`], so width,
    /// fill, alignment and precision flags are honored (a plain `write!` with a
    /// literal ignores them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Relu => "relu",
            Self::Sigmoid => "sigmoid",
            Self::Tanh => "tanh",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Sqrt => "sqrt",
            Self::Abs => "abs",
            Self::Neg => "neg",
        };
        f.pad(name)
    }
}
/// Operation selected by the `op` argument of [`ComputeBackend::binary`].
///
/// `Display` renders the lowercase name (`add`, `sub`, `mul`, `div`, `max`, `min`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
}

impl fmt::Display for BinaryOp {
    /// Writes the lowercase name through [`fmt::Formatter::pad`], so width,
    /// fill, alignment and precision flags are honored (a plain `write!` with a
    /// literal ignores them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Max => "max",
            Self::Min => "min",
        };
        f.pad(name)
    }
}
/// Interface implemented by compute backends.
///
/// The methods fall into three groups:
/// - lifecycle: `name`, `init`, `is_initialized`, `synchronize`;
/// - memory management: `alloc`, `free`, `copy_htod`, `copy_dtoh`;
/// - compute kernels: `gemm`, `conv2d_forward`, `attention`, `reduce`, `unary`, `binary`.
///
/// Every fallible method returns a [`BackendResult`].
///
/// Buffers are passed as opaque `u64` handles. NOTE(review): these are presumably
/// handles returned by [`alloc`](Self::alloc). Element type, layout and ownership
/// rules are not fixed by this trait; confirm against implementations.
///
/// The `Send + Sync` supertraits allow a backend to be shared between threads.
/// `Debug` allows it to appear in diagnostics.
pub trait ComputeBackend: Send + Sync + fmt::Debug {
    // ---- lifecycle ----

    /// Returns this backend's name.
    fn name(&self) -> &str;

    /// Initializes the backend. It takes `&mut self`, so it cannot be called
    /// through a shared reference.
    fn init(&mut self) -> BackendResult<()>;

    /// Reports whether the backend is initialized. This is presumably `true`
    /// once [`init`](Self::init) has succeeded; confirm per implementation.
    fn is_initialized(&self) -> bool;

    /// Synchronizes with the backend. NOTE(review): this presumably blocks until
    /// all previously issued work has finished; confirm per implementation.
    fn synchronize(&self) -> BackendResult<()>;

    // ---- memory management ----

    /// Allocates `bytes` bytes and returns a handle to the allocation.
    fn alloc(&self, bytes: usize) -> BackendResult<u64>;

    /// Releases the allocation identified by `ptr`. The handle is presumably
    /// one returned by [`alloc`](Self::alloc).
    fn free(&self, ptr: u64) -> BackendResult<()>;

    /// Copies the host bytes `src` into the buffer identified by `dst`.
    /// Presumably `src.len()` bytes are copied.
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;

    /// Copies from the buffer identified by `src` into the host slice `dst`.
    /// Presumably `dst.len()` bytes are copied.
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;

    // ---- compute kernels ----

    /// General matrix multiply on the buffers `a_ptr`, `b_ptr` and `c_ptr`.
    ///
    /// NOTE(review): this presumably follows the BLAS GEMM contract
    /// `C = alpha * op(A) * op(B) + beta * C`:
    /// - `op` is selected by `trans_a` / `trans_b`;
    /// - `op(A)` is `m x k`, `op(B)` is `k x n` and `C` is `m x n`;
    /// - `lda` / `ldb` / `ldc` are the leading dimensions.
    ///
    /// Storage order and element type are not specified here; confirm against
    /// implementations.
    // The flat argument list is deliberate: bundling it into a struct would
    // change the trait for every implementor.
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()>;

    /// Forward 2-D convolution of the input at `input_ptr` with the filter at
    /// `filter_ptr`, writing to `output_ptr`.
    ///
    /// Each tensor's dimensions come from the matching `*_shape` slice, and
    /// `stride` and `padding` are also slices. NOTE(review): dimension order
    /// (e.g. NCHW vs NHWC) and the expected slice lengths are not fixed by this
    /// signature; confirm against implementations.
    // The flat argument list is deliberate: bundling it into a struct would
    // change the trait for every implementor.
    #[allow(clippy::too_many_arguments)]
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()>;

    /// Attention over the query/key/value buffers `q_ptr`, `k_ptr` and `v_ptr`,
    /// writing to `o_ptr`.
    ///
    /// The dimensions are given by `batch`, `heads`, `seq_q`, `seq_kv` and
    /// `head_dim`. NOTE(review): this presumably computes
    /// `softmax(scale * Q K^T) V` per batch and head, applying a causal mask when
    /// `causal` is set; confirm against implementations.
    // The flat argument list is deliberate: bundling it into a struct would
    // change the trait for every implementor.
    #[allow(clippy::too_many_arguments)]
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()>;

    /// Reduces the tensor at `input_ptr` along `axis` using `op`, writing the
    /// result to `output_ptr`. The tensor's dimensions are given by `shape`.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()>;

    /// Applies `op` to the buffer at `input_ptr`, writing to `output_ptr`.
    /// `n` is presumably the element count.
    fn unary(
        &self,
        op: UnaryOp,
        input_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;

    /// Combines the buffers at `a_ptr` and `b_ptr` with `op`, writing to
    /// `output_ptr`. `n` is presumably the element count.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Asserts that every `(value, expected)` pair renders to `expected` through `Display`.
    /// A table of pairs cannot silently drop entries the way zipping two arrays can.
    fn check_display<T: std::fmt::Display>(cases: &[(T, &str)]) {
        for (value, expected) in cases {
            assert_eq!(value.to_string(), *expected);
        }
    }

    /// Asserts that a bitwise copy equals the original. Using `value` after the
    /// move also proves `T: Copy` at compile time.
    fn check_copy<T: Copy + PartialEq + std::fmt::Debug>(value: T) {
        let copied = value;
        assert_eq!(value, copied);
    }

    #[test]
    fn backend_error_display() {
        check_display(&[
            (BackendError::Unsupported("foo".into()), "unsupported operation: foo"),
            (BackendError::DeviceError("bar".into()), "device error: bar"),
            (BackendError::InvalidArgument("baz".into()), "invalid argument: baz"),
            (BackendError::OutOfMemory, "out of device memory"),
            (BackendError::NotInitialized, "backend not initialized"),
        ]);
    }

    #[test]
    fn backend_error_is_std_error() {
        let boxed: Box<dyn std::error::Error> =
            Box::new(BackendError::DeviceError("test".into()));
        assert!(boxed.to_string().contains("test"));
    }

    #[test]
    fn backend_transpose_display_and_values() {
        check_display(&[
            (BackendTranspose::NoTrans, "N"),
            (BackendTranspose::Trans, "T"),
            (BackendTranspose::ConjTrans, "C"),
        ]);
        let no_trans = BackendTranspose::NoTrans;
        assert_eq!(no_trans, BackendTranspose::NoTrans);
        assert_ne!(no_trans, BackendTranspose::Trans);
    }

    #[test]
    fn reduce_op_display_and_coverage() {
        check_display(&[
            (ReduceOp::Sum, "sum"),
            (ReduceOp::Max, "max"),
            (ReduceOp::Min, "min"),
            (ReduceOp::Mean, "mean"),
        ]);
    }

    #[test]
    fn unary_op_display_and_coverage() {
        check_display(&[
            (UnaryOp::Relu, "relu"),
            (UnaryOp::Sigmoid, "sigmoid"),
            (UnaryOp::Tanh, "tanh"),
            (UnaryOp::Exp, "exp"),
            (UnaryOp::Log, "log"),
            (UnaryOp::Sqrt, "sqrt"),
            (UnaryOp::Abs, "abs"),
            (UnaryOp::Neg, "neg"),
        ]);
    }

    #[test]
    fn binary_op_display_and_coverage() {
        check_display(&[
            (BinaryOp::Add, "add"),
            (BinaryOp::Sub, "sub"),
            (BinaryOp::Mul, "mul"),
            (BinaryOp::Div, "div"),
            (BinaryOp::Max, "max"),
            (BinaryOp::Min, "min"),
        ]);
    }

    #[test]
    fn enum_clone_and_hash() {
        let seen = HashSet::from([ReduceOp::Sum, ReduceOp::Max]);
        assert!(seen.contains(&ReduceOp::Sum));
        assert!(!seen.contains(&ReduceOp::Min));

        check_copy(UnaryOp::Relu);
        check_copy(BinaryOp::Add);
        check_copy(BackendTranspose::ConjTrans);
    }
}