pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
pub mod dx12;
mod error;
pub mod tensor_parallel;
pub mod metal;
#[cfg(feature = "vulkan")]
pub mod vulkan;
#[cfg(feature = "hailo")]
pub mod hailo;
pub use error::BackendError;
use crate::tensor::{DType, Tensor};
/// Convenience alias: every fallible backend operation returns `BackendError` on failure.
pub type BackendResult<T> = Result<T, BackendError>;
/// Hardware compute backend: the set of tensor kernels needed for inference.
///
/// Element-wise and matrix ops write into a caller-provided `out` tensor.
/// The attention variants (`flash_attention`, `attention_cached`,
/// `attention_turboquant`) have portable default implementations, so a
/// backend only overrides them when it has a faster native path.
pub trait Backend: Send + Sync {
    /// Short identifier for this backend (for logging / selection).
    fn name(&self) -> &str;
    /// Whether this backend can actually run on the current machine.
    fn is_available(&self) -> bool;
    /// Allocate a tensor of the given shape and dtype owned by this backend.
    fn alloc(&self, shape: &[usize], dtype: DType) -> BackendResult<Tensor>;
    /// Copy `tensor` into this backend's memory, returning the new tensor.
    fn copy_to(&self, tensor: &Tensor) -> BackendResult<Tensor>;
    /// Element-wise `out = a + b`.
    fn add(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Element-wise `out = a * b`.
    fn mul(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Element-wise `out = a * scalar`.
    fn scale(&self, a: &Tensor, scalar: f32, out: &mut Tensor) -> BackendResult<()>;
    /// SiLU activation of `x` into `out`.
    fn silu(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// GELU activation of `x` into `out`.
    fn gelu(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Softmax of `x` into `out`.
    fn softmax(&self, x: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// RMS normalization of `x`, scaled by `weight`, with epsilon `eps`.
    fn rms_norm(
        &self,
        x: &Tensor,
        weight: &Tensor,
        eps: f32,
        out: &mut Tensor,
    ) -> BackendResult<()>;
    /// Matrix-matrix product into `out`.
    fn matmul(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Matrix-vector product into `out`.
    fn matvec(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Vector-matrix product into `out`.
    fn vec_mat(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Dequantize `src` into `out`.
    fn dequantize(&self, src: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Quantized-operand variant of [`Backend::matvec`].
    fn matvec_q(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Quantized-operand variant of [`Backend::vec_mat`].
    fn vec_mat_q(&self, a: &Tensor, b: &Tensor, out: &mut Tensor) -> BackendResult<()>;
    /// Apply rotary position embeddings to `q` and `k` in place for position
    /// `pos`. `use_neox` selects the NeoX-style rotation layout.
    fn rope(
        &self,
        q: &mut Tensor,
        k: &mut Tensor,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        use_neox: bool,
    ) -> BackendResult<()>;
    /// Scaled dot-product attention over contiguous `q`/`k`/`v`.
    fn attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        out: &mut Tensor,
        scale: f32,
    ) -> BackendResult<()>;
    /// Flash-attention entry point. The default delegates to
    /// [`Backend::attention`] and ignores `_causal`; backends with a fused
    /// kernel should override it.
    fn flash_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        out: &mut Tensor,
        scale: f32,
        _causal: bool,
    ) -> BackendResult<()> {
        self.attention(q, k, v, out, scale)
    }
    /// Attention against KV caches laid out as `[num_kv_heads, max_seq_len,
    /// head_dim]` where only the first `kv_len` positions per head are valid.
    ///
    /// The default implementation compacts each cache into a contiguous
    /// `[num_kv_heads, kv_len, head_dim]` tensor and delegates to
    /// [`Backend::attention`].
    fn attention_cached(
        &self,
        q: &Tensor,
        k_cache: &Tensor,
        v_cache: &Tensor,
        out: &mut Tensor,
        scale: f32,
        kv_len: usize,
    ) -> BackendResult<()> {
        let num_kv_heads = k_cache.shape()[0];
        let max_seq_len = k_cache.shape()[1];
        let head_dim = k_cache.shape()[2];
        debug_assert!(kv_len <= max_seq_len, "kv_len exceeds cache capacity");

        // Extract the first `kv_len` positions of every head from a cache
        // shaped [heads, max_seq_len, dim]. Positions 0..kv_len of one head
        // are contiguous in both source and destination, so each head is a
        // single block copy.
        fn compact(
            cache: &Tensor,
            num_kv_heads: usize,
            max_seq_len: usize,
            head_dim: usize,
            kv_len: usize,
        ) -> BackendResult<Tensor> {
            let mut contig = Tensor::zeros(vec![num_kv_heads, kv_len, head_dim], DType::F32);
            {
                let src = cache.as_f32()?;
                let dst = contig.as_f32_mut()?;
                let span = kv_len * head_dim;
                for h in 0..num_kv_heads {
                    let src_off = h * max_seq_len * head_dim;
                    dst[h * span..(h + 1) * span]
                        .copy_from_slice(&src[src_off..src_off + span]);
                }
            }
            Ok(contig)
        }

        let k_contig = compact(k_cache, num_kv_heads, max_seq_len, head_dim, kv_len)?;
        let v_contig = compact(v_cache, num_kv_heads, max_seq_len, head_dim, kv_len)?;
        self.attention(q, &k_contig, &v_contig, out, scale)
    }
    /// Attention read directly from a TurboQuant-compressed KV cache.
    /// The default delegates to the cache's own `attention_layer`.
    fn attention_turboquant(
        &self,
        queries: &[f32],
        tq_cache: &crate::model::kv_turboquant::TurboQuantKVCache,
        layer_idx: usize,
        num_heads: usize,
        scale: f32,
    ) -> BackendResult<Vec<f32>> {
        Ok(tq_cache.attention_layer(layer_idx, queries, num_heads, scale))
    }
}
/// Construct the default backend: the portable CPU implementation.
pub fn default_backend() -> Box<dyn Backend> {
    let cpu = cpu::CpuBackend::new();
    Box::new(cpu)
}
/// Minimal single-sequence inference interface a GPU engine must expose
/// to be driven through `GpuModelWrapper`.
pub trait GpuInference: Send {
    /// Process `token_id` and return the logits vector for it.
    fn forward(&mut self, token_id: u32) -> BackendResult<Vec<f32>>;
    /// Process `token_id` for prompt ingestion only (no logits returned).
    fn prefill_token(&mut self, token_id: u32) -> BackendResult<()>;
    /// Clear sequence state (presumably the device-side KV cache) so a new
    /// sequence can start from position 0.
    fn reset(&mut self);
    /// Number of tokens processed so far in the current sequence.
    fn position(&self) -> usize;
}
/// Adapts a [`GpuInference`] engine to the crate's `Model` trait.
///
/// The engine sits behind a `Mutex` because `Model::forward` takes `&self`
/// while the engine requires `&mut` access.
pub struct GpuModelWrapper<T: GpuInference> {
    // Interior mutability: inference mutates engine state behind &self.
    gpu: std::sync::Mutex<T>,
    config: crate::model::ModelConfig,
    architecture: crate::model::Architecture,
}
impl<T: GpuInference> GpuModelWrapper<T> {
    /// Wrap a GPU inference engine together with its model metadata.
    pub fn new(
        gpu: T,
        config: crate::model::ModelConfig,
        architecture: crate::model::Architecture,
    ) -> Self {
        // The engine goes behind a mutex so Model::forward (&self) can
        // obtain exclusive access to it.
        let gpu = std::sync::Mutex::new(gpu);
        Self { gpu, config, architecture }
    }
}
impl<T: GpuInference + 'static> crate::model::Model for GpuModelWrapper<T> {
    /// Run the model over `tokens`, returning the logits for the last token.
    ///
    /// All tokens but the last are prefilled (state only); the final token
    /// produces the logits. On success `ctx.position` and the KV-cache
    /// sequence length advance by `tokens.len()`.
    fn forward(
        &self,
        tokens: &[u32],
        ctx: &mut crate::model::InferenceContext,
    ) -> crate::model::ModelResult<crate::tensor::Tensor> {
        // Validate input before touching GPU state, so a bad call leaves
        // the engine untouched (the old code reset the engine first).
        let (&last, prefix) = tokens.split_last().ok_or_else(|| {
            crate::model::ModelError::ConfigError("No tokens to process".into())
        })?;
        let mut gpu = self.gpu.lock().map_err(|e| {
            crate::model::ModelError::ConfigError(format!(
                "GPU inference lock poisoned: {}", e
            ))
        })?;
        // A fresh context (position 0) against a warm engine means a new
        // sequence: clear the device-side state first.
        if ctx.position == 0 && gpu.position() > 0 {
            gpu.reset();
        }
        for &token in prefix {
            gpu.prefill_token(token)?;
        }
        let logits = gpu.forward(last)?;
        ctx.position += tokens.len();
        ctx.kv_cache.seq_len = ctx.position;
        crate::tensor::Tensor::from_f32(&logits, vec![logits.len()])
            .map_err(|e| e.into())
    }
    fn config(&self) -> &crate::model::ModelConfig {
        &self.config
    }
    fn architecture(&self) -> crate::model::Architecture {
        self.architecture
    }
}