mod error;
pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "vulkan")]
pub mod vulkan;
pub use error::BackendError;
use crate::tensor::{DType, Tensor};
/// Result alias for backend operations: `Ok(T)` on success, [`BackendError`] on failure.
pub type BackendResult<T> = Result<T, BackendError>;
/// Interface a compute backend must provide to run tensor operations.
///
/// Implementors are required to be `Send + Sync`. Every fallible operation
/// reports failure through [`BackendResult`]. Operations that produce a tensor
/// result write it into a caller-supplied `out` tensor unless noted otherwise.
///
/// NOTE(review): shape, dtype and layout requirements are not expressed in
/// these signatures; confirm them against the concrete implementations.
pub trait Backend: Send + Sync {
    // ---- identification -------------------------------------------------

    /// Name of this backend.
    fn name(&self) -> &str;

    /// Reports whether this backend can be used.
    fn is_available(&self) -> bool;

    // ---- tensor creation ------------------------------------------------

    /// Creates a tensor for the given `shape` and `dtype`.
    fn alloc(&self, shape: &[usize], dtype: DType) -> BackendResult<Tensor>;

    /// Returns a new tensor derived from `tensor`.
    ///
    /// Presumably a copy placed in this backend's memory; verify with implementors.
    fn copy_to(&self, tensor: &Tensor) -> BackendResult<Tensor>;

    // ---- arithmetic -----------------------------------------------------

    /// Addition of `a` and `b`, written to `out`.
    fn add(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Multiplication of `a` and `b`, written to `out`.
    ///
    /// Presumably element-wise (matrix products have their own methods); TODO confirm.
    fn mul(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Scales `a` by `scalar`, written to `out`.
    fn scale(&self, a: &Tensor, scalar: f32, out: &mut Tensor) -> BackendResult<()>;

    // ---- activations and normalisation -----------------------------------

    /// SiLU activation of `x`, written to `out`.
    fn silu(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// GELU activation of `x`, written to `out`.
    fn gelu(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Softmax of `x`, written to `out`.
    fn softmax(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// RMS normalisation of `x` using `weight` and `eps`, written to `out`.
    fn rms_norm(
        &self,
        x: &Tensor,
        weight: &Tensor,
        eps: f32,
        out: &mut Tensor,
    ) -> BackendResult<()>;

    // ---- products -------------------------------------------------------

    /// Matrix-matrix product of `a` and `b`, written to `out`.
    fn matmul(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Matrix-vector product of `a` and `b`, written to `out`.
    fn matvec(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Vector-matrix product of `a` and `b`, written to `out`.
    fn vec_mat(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    // ---- quantised data -------------------------------------------------

    /// Dequantizes `src` into `out`.
    fn dequantize(&self, src: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Counterpart of [`Backend::matvec`] with a `_q` suffix.
    ///
    /// Presumably operates on quantized input; which operand is quantized is
    /// not visible here. TODO confirm.
    fn matvec_q(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    /// Counterpart of [`Backend::vec_mat`] with a `_q` suffix.
    ///
    /// Presumably operates on quantized input; which operand is quantized is
    /// not visible here. TODO confirm.
    fn vec_mat_q(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;

    // ---- attention ------------------------------------------------------

    /// Applies RoPE to `q` and `k`.
    ///
    /// Both tensors are borrowed mutably and there is no separate output
    /// parameter. `pos`, `freq_base`, `freq_scale` and `use_neox` configure
    /// the operation; their exact meaning is defined by the implementation.
    fn rope(
        &self,
        q: &mut Tensor,
        k: &mut Tensor,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        use_neox: bool,
    ) -> BackendResult<()>;

    /// Attention over `q`, `k` and `v` with the given `scale`, written to `out`.
    fn attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        out: &mut Tensor,
        scale: f32,
    ) -> BackendResult<()>;

    /// Attention entry point that backends may override.
    ///
    /// The provided implementation forwards to [`Backend::attention`] and does
    /// not use `_causal`, so the result is the same for either value of that flag.
    fn flash_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        out: &mut Tensor,
        scale: f32,
        _causal: bool,
    ) -> BackendResult<()> {
        // Fallback path: hand the call straight to the regular attention routine.
        Self::attention(self, q, k, v, out, scale)
    }
}
/// Creates the backend handed out when the caller expresses no preference.
///
/// This always constructs a [`cpu::CpuBackend`] and returns it boxed as a
/// `dyn Backend`; the feature-gated `cuda` and `vulkan` modules are not
/// consulted here.
pub fn default_backend() -> Box<dyn Backend> {
    let cpu_backend = cpu::CpuBackend::new();
    Box::new(cpu_backend)
}