#[cfg(feature = "cpu-accelerate")]
extern crate accelerate_src;
#[cfg(feature = "cpu")]
extern crate blas_src;
#[cfg(all(feature = "cpu", not(feature = "cpu-accelerate")))]
extern crate openblas_src;
use std::path::Path;
use std::sync::Arc;
use safetensors::SafeTensors;
use super::{BatchInputs, Driver};
use crate::backend::Encoding;
use crate::backend::arch::classic_bert::{
ClassicBertArch, ClassicBertLayerWeights, ClassicBertWeights,
};
use crate::backend::arch::modern_bert::{
ModernBertArch, ModernBertLayerWeights, ModernBertWeights, RopeCache,
};
/// An `f32` tensor buffer that is either heap-owned or borrowed zero-copy from a
/// memory-mapped weights file.
///
/// Only the `Owned` variant can be mutated: `DerefMut` and the `resize`/`clear`/
/// `extend` helpers in this file panic on `Mapped`.
pub enum MmapTensor {
    /// Heap-allocated, mutable data.
    Owned(Vec<f32>),
    /// Read-only view of `len` `f32` values starting `offset` bytes into `mmap`.
    ///
    /// `Deref` reinterprets these bytes as `f32`s, so `offset` must keep the data
    /// 4-byte aligned and `offset + len * 4` must lie within the mapping.
    Mapped {
        // Shared handle that keeps the mapping alive while this tensor exists.
        mmap: Arc<memmap2::Mmap>,
        // Byte offset of the first element within the mapping.
        offset: usize,
        // Number of `f32` elements (not bytes).
        len: usize,
    },
}
/// Exposes the tensor's contents as an `f32` slice, whether owned or memory-mapped.
impl std::ops::Deref for MmapTensor {
    type Target = [f32];
    fn deref(&self) -> &[f32] {
        match self {
            Self::Owned(v) => v,
            #[expect(unsafe_code, reason = "reinterpret aligned mmap bytes as f32 slice")]
            Self::Mapped { mmap, offset, len } => {
                // Bounds-checked byte window covering exactly `len` f32 values; this
                // panics rather than reading out of bounds if the fields are inconsistent.
                let bytes = &mmap[*offset..*offset + *len * 4];
                // SAFETY: `bytes` is exactly `len * 4` bytes and is borrowed from `self`,
                // whose `Arc` keeps the mapping alive for the returned slice's lifetime.
                // Every bit pattern is a valid `f32`. The remaining requirement is 4-byte
                // alignment, which the `expect` reason below says is checked in
                // `load_tensor_mmap` before `Mapped` is constructed.
                // NOTE(review): the variant's fields are public, so code elsewhere could
                // build a misaligned `Mapped`; also, bytes are read in native order, which
                // matches the little-endian safetensors encoding only on little-endian
                // targets — confirm both invariants are upheld by all constructors.
                #[expect(
                    clippy::cast_ptr_alignment,
                    reason = "alignment verified in load_tensor_mmap before constructing Mapped"
                )]
                unsafe {
                    std::slice::from_raw_parts(bytes.as_ptr().cast::<f32>(), *len)
                }
            }
        }
    }
}
impl std::ops::DerefMut for MmapTensor {
fn deref_mut(&mut self) -> &mut [f32] {
match self {
Self::Owned(v) => v,
Self::Mapped { .. } => {
panic!("cannot mutate a memory-mapped tensor — only Owned tensors are mutable")
}
}
}
}
impl MmapTensor {
    /// Returns the owned backing `Vec`.
    ///
    /// # Panics
    /// Panics with `cannot {action} a memory-mapped tensor` for `Mapped` tensors.
    fn owned_vec_mut(&mut self, action: &str) -> &mut Vec<f32> {
        match self {
            Self::Owned(buffer) => buffer,
            Self::Mapped { .. } => panic!("cannot {action} a memory-mapped tensor"),
        }
    }

    /// Resizes an owned tensor to `new_len`, filling new slots with `value`.
    fn resize(&mut self, new_len: usize, value: f32) {
        self.owned_vec_mut("resize").resize(new_len, value);
    }

    /// Removes all elements from an owned tensor.
    fn clear(&mut self) {
        self.owned_vec_mut("clear").clear();
    }

    /// Appends every value yielded by `iter` to an owned tensor.
    fn extend<I: IntoIterator<Item = f32>>(&mut self, iter: I) {
        self.owned_vec_mut("extend").extend(iter);
    }
}
/// Wraps an existing vector as an owned (mutable) tensor without copying.
impl From<Vec<f32>> for MmapTensor {
    fn from(values: Vec<f32>) -> Self {
        Self::Owned(values)
    }
}
#[expect(unsafe_code, reason = "MmapTensor components are all Send + Sync")]
unsafe impl Send for MmapTensor {}
#[expect(unsafe_code, reason = "MmapTensor components are all Send + Sync")]
unsafe impl Sync for MmapTensor {}
/// Stateless CPU implementation of the backend `Driver` trait.
pub struct CpuDriver;
impl CpuDriver {
    /// Creates a CPU driver. Construction always succeeds.
    pub fn new() -> crate::Result<Self> {
        Ok(CpuDriver)
    }
}
/// Requests that the system BLAS library run single-threaded.
///
/// - On macOS: looks up `BLASSetThreading` at runtime with `dlsym(RTLD_DEFAULT, ..)`
///   and, if the symbol is found, calls it with the single-threaded mode value. If the
///   symbol is absent, nothing happens. The function's return code is ignored.
/// - On other targets with the `cpu` feature: calls `openblas_set_num_threads(1)`.
/// - Otherwise: does nothing.
///
/// NOTE(review): on macOS only `BLASSetThreading` is attempted, even in builds that
/// link OpenBLAS (`cpu` without `cpu-accelerate`); confirm that configuration is
/// intended. Presumably this is called so outer-level parallelism does not
/// oversubscribe cores with BLAS worker threads — verify with callers.
#[expect(unsafe_code, reason = "FFI call to system BLAS thread control API")]
pub fn force_single_threaded_blas() {
    #[cfg(target_os = "macos")]
    {
        // Mode value passed to `BLASSetThreading` for single-threaded operation.
        const BLAS_THREADING_SINGLE_THREADED: std::ffi::c_uint = 1;
        type BLASSetThreadingFn = unsafe extern "C" fn(std::ffi::c_uint) -> std::ffi::c_int;
        // Resolve dynamically; `dlsym` returns null when the symbol is not loaded.
        let sym = unsafe { libc::dlsym(libc::RTLD_DEFAULT, c"BLASSetThreading".as_ptr()) };
        if !sym.is_null() {
            // SAFETY (assumed): the symbol has the C signature declared above —
            // NOTE(review): confirm against Apple's Accelerate headers.
            let func: BLASSetThreadingFn = unsafe { std::mem::transmute(sym) };
            unsafe { func(BLAS_THREADING_SINGLE_THREADED) };
        }
    }
    #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
    {
        // Provided by the OpenBLAS library linked in for the `cpu` feature.
        unsafe extern "C" {
            fn openblas_set_num_threads(num: std::ffi::c_int);
        }
        unsafe {
            openblas_set_num_threads(1);
        }
    }
}
/// Hyperparameters for a classic (BERT-style) encoder, read from `config.json`.
pub struct ClassicBertConfig {
    /// Width of the hidden representation.
    pub hidden_size: usize,
    /// Width of the feed-forward intermediate layer.
    pub intermediate_size: usize,
    /// Number of encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// Epsilon added to the variance in layer norm. `from_json` reads
    /// `layer_norm_eps` or `layer_norm_epsilon` and defaults to `1e-12`.
    pub layer_norm_eps: f32,
    /// Maximum supported sequence length; `from_json` defaults it to `512`.
    pub max_position_embeddings: usize,
    /// Vocabulary size as declared in `config.json`.
    pub vocab_size: usize,
}
impl ClassicBertConfig {
#[expect(
clippy::cast_possible_truncation,
reason = "config ints are small positive values"
)]
pub fn from_json(json: &serde_json::Value) -> crate::Result<Self> {
let get_usize = |key: &str| -> crate::Result<usize> {
json.get(key)
.and_then(serde_json::Value::as_u64)
.map(|v| v as usize)
.ok_or_else(|| crate::Error::Cpu(format!("config.json missing or invalid: {key}")))
};
let get_f64 = |key: &str| -> crate::Result<f64> {
json.get(key)
.and_then(serde_json::Value::as_f64)
.ok_or_else(|| crate::Error::Cpu(format!("config.json missing or invalid: {key}")))
};
Ok(Self {
hidden_size: get_usize("hidden_size")?,
intermediate_size: get_usize("intermediate_size")?,
num_hidden_layers: get_usize("num_hidden_layers")?,
num_attention_heads: get_usize("num_attention_heads")?,
layer_norm_eps: get_f64("layer_norm_eps")
.or_else(|_| get_f64("layer_norm_epsilon"))
.unwrap_or(1e-12) as f32,
max_position_embeddings: get_usize("max_position_embeddings").unwrap_or(512),
vocab_size: get_usize("vocab_size")?,
})
}
}
fn load_tensor_flat(tensors: &SafeTensors<'_>, name: &str) -> crate::Result<Vec<f32>> {
let tensor = tensors
.tensor(name)
.map_err(|_| crate::Error::Cpu(format!("missing weight: {name}")))?;
let data: Vec<f32> = tensor
.data()
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(data)
}
fn load_tensor_mmap(
tensors: &SafeTensors<'_>,
name: &str,
mmap: &Arc<memmap2::Mmap>,
) -> crate::Result<MmapTensor> {
let tensor = tensors
.tensor(name)
.map_err(|_| crate::Error::Cpu(format!("missing weight: {name}")))?;
let data = tensor.data();
let ptr = data.as_ptr();
let byte_len = data.len();
let float_len = byte_len / 4;
if (ptr as usize).is_multiple_of(4) {
let mmap_start = mmap.as_ptr() as usize;
let tensor_start = ptr as usize;
let offset = tensor_start - mmap_start;
Ok(MmapTensor::Mapped {
mmap: Arc::clone(mmap),
offset,
len: float_len,
})
} else {
let data: Vec<f32> = data
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(MmapTensor::Owned(data))
}
}
impl CpuDriver {
#[expect(unsafe_code, reason = "memmap2::Mmap::map requires unsafe")]
#[expect(
clippy::too_many_lines,
reason = "weight loading with QKV fusion is inherently sequential"
)]
pub fn load_classic_bert_weights(
&self,
weights_path: &Path,
config: &ClassicBertConfig,
) -> crate::Result<(ClassicBertArch<MmapTensor>, Arc<memmap2::Mmap>)> {
let file = std::fs::File::open(weights_path).map_err(|e| crate::Error::Io {
path: weights_path.display().to_string(),
source: e,
})?;
let mmap =
Arc::new(
unsafe { memmap2::Mmap::map(&file) }.map_err(|e| crate::Error::Io {
path: weights_path.display().to_string(),
source: e,
})?,
);
let tensors = SafeTensors::deserialize(&mmap)
.map_err(|e| crate::Error::Cpu(format!("safetensors parse: {e}")))?;
let hidden = config.hidden_size;
let num_layers = config.num_hidden_layers;
let num_heads = config.num_attention_heads;
let head_dim = hidden / num_heads;
let intermediate = config.intermediate_size;
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let prefix = format!("encoder.layer.{i}");
let q_w = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.query.weight"))?;
let k_w = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.key.weight"))?;
let v_w = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.value.weight"))?;
let q_b = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.query.bias"))?;
let k_b = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.key.bias"))?;
let v_b = load_tensor_flat(&tensors, &format!("{prefix}.attention.self.value.bias"))?;
let mut fused_qkv_w = Vec::with_capacity(3 * hidden * hidden);
fused_qkv_w.extend_from_slice(&q_w);
fused_qkv_w.extend_from_slice(&k_w);
fused_qkv_w.extend_from_slice(&v_w);
let mut fused_qkv_b = Vec::with_capacity(3 * hidden);
fused_qkv_b.extend_from_slice(&q_b);
fused_qkv_b.extend_from_slice(&k_b);
fused_qkv_b.extend_from_slice(&v_b);
layers.push(ClassicBertLayerWeights {
qkv_weight: MmapTensor::Owned(fused_qkv_w),
qkv_bias: MmapTensor::Owned(fused_qkv_b),
output_weight: load_tensor_mmap(
&tensors,
&format!("{prefix}.attention.output.dense.weight"),
&mmap,
)?,
output_bias: load_tensor_mmap(
&tensors,
&format!("{prefix}.attention.output.dense.bias"),
&mmap,
)?,
output_ln_weight: load_tensor_mmap(
&tensors,
&format!("{prefix}.attention.output.LayerNorm.weight"),
&mmap,
)?,
output_ln_bias: load_tensor_mmap(
&tensors,
&format!("{prefix}.attention.output.LayerNorm.bias"),
&mmap,
)?,
ffn_inter_weight: load_tensor_mmap(
&tensors,
&format!("{prefix}.intermediate.dense.weight"),
&mmap,
)?,
ffn_inter_bias: load_tensor_mmap(
&tensors,
&format!("{prefix}.intermediate.dense.bias"),
&mmap,
)?,
ffn_out_weight: load_tensor_mmap(
&tensors,
&format!("{prefix}.output.dense.weight"),
&mmap,
)?,
ffn_out_bias: load_tensor_mmap(
&tensors,
&format!("{prefix}.output.dense.bias"),
&mmap,
)?,
ffn_ln_weight: load_tensor_mmap(
&tensors,
&format!("{prefix}.output.LayerNorm.weight"),
&mmap,
)?,
ffn_ln_bias: load_tensor_mmap(
&tensors,
&format!("{prefix}.output.LayerNorm.bias"),
&mmap,
)?,
});
}
let weights = ClassicBertWeights {
word_embeddings: load_tensor_mmap(
&tensors,
"embeddings.word_embeddings.weight",
&mmap,
)?,
position_embeddings: load_tensor_mmap(
&tensors,
"embeddings.position_embeddings.weight",
&mmap,
)?,
token_type_embeddings: load_tensor_mmap(
&tensors,
"embeddings.token_type_embeddings.weight",
&mmap,
)?,
emb_ln_weight: load_tensor_mmap(&tensors, "embeddings.LayerNorm.weight", &mmap)?,
emb_ln_bias: load_tensor_mmap(&tensors, "embeddings.LayerNorm.bias", &mmap)?,
layers,
num_heads,
head_dim,
hidden_dim: hidden,
intermediate_dim: intermediate,
layer_norm_eps: config.layer_norm_eps,
};
Ok((ClassicBertArch { weights }, mmap))
}
pub fn load_modern_bert_weights(
&self,
weights_path: &Path,
config: &ModernBertConfig,
) -> crate::Result<(ModernBertArch<MmapTensor>, Arc<memmap2::Mmap>)> {
let file = std::fs::File::open(weights_path).map_err(|e| crate::Error::Io {
path: weights_path.display().to_string(),
source: e,
})?;
#[expect(unsafe_code, reason = "memmap2 requires unsafe for mmap")]
let mmap =
Arc::new(
unsafe { memmap2::Mmap::map(&file) }.map_err(|e| crate::Error::Io {
path: weights_path.display().to_string(),
source: e,
})?,
);
let tensors = SafeTensors::deserialize(&mmap)
.map_err(|e| crate::Error::Cpu(format!("safetensors parse: {e}")))?;
let hidden = config.hidden_size;
let num_layers = config.num_hidden_layers;
let num_heads = config.num_attention_heads;
let head_dim = hidden / num_heads;
let intermediate = config.intermediate_size;
let global_attn_every_n = config.global_attn_every_n_layers;
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let qkv_weight =
load_tensor_mmap(&tensors, &format!("layers.{i}.attn.Wqkv.weight"), &mmap)?;
let output_weight =
load_tensor_mmap(&tensors, &format!("layers.{i}.attn.Wo.weight"), &mmap)?;
let attn_norm_weight = if i == 0 {
None
} else {
Some(load_tensor_mmap(
&tensors,
&format!("layers.{i}.attn_norm.weight"),
&mmap,
)?)
};
let mlp_wi_weight =
load_tensor_mmap(&tensors, &format!("layers.{i}.mlp.Wi.weight"), &mmap)?;
let mlp_wo_weight =
load_tensor_mmap(&tensors, &format!("layers.{i}.mlp.Wo.weight"), &mmap)?;
let mlp_norm_weight =
load_tensor_mmap(&tensors, &format!("layers.{i}.mlp_norm.weight"), &mmap)?;
let is_global = i % global_attn_every_n == 0;
layers.push(ModernBertLayerWeights {
qkv_weight,
output_weight,
attn_norm_weight,
mlp_wi_weight,
mlp_wo_weight,
mlp_norm_weight,
is_global,
});
}
let tok_embeddings = load_tensor_mmap(&tensors, "embeddings.tok_embeddings.weight", &mmap)?;
let emb_norm_weight = load_tensor_mmap(&tensors, "embeddings.norm.weight", &mmap)?;
let final_norm_weight = load_tensor_mmap(&tensors, "final_norm.weight", &mmap)?;
let zero_bias = MmapTensor::Owned(vec![0.0f32; hidden]);
let weights = ModernBertWeights {
tok_embeddings,
emb_norm_weight,
final_norm_weight,
zero_bias,
layers,
num_heads,
head_dim,
hidden_dim: hidden,
intermediate_dim: intermediate,
layer_norm_eps: config.norm_eps,
local_window: config.local_attention,
};
let max_seq = config.max_position_embeddings;
let global_rope = build_rope_cache_cpu(head_dim, max_seq, config.global_rope_theta);
let local_rope = build_rope_cache_cpu(head_dim, max_seq, config.local_rope_theta);
let arch = ModernBertArch {
weights,
global_rope,
local_rope,
};
Ok((arch, mmap))
}
}
/// Hyperparameters for a ModernBERT encoder, read from `config.json`.
pub struct ModernBertConfig {
    /// Width of the hidden representation.
    pub hidden_size: usize,
    /// Width of the feed-forward intermediate layer.
    pub intermediate_size: usize,
    /// Number of encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// Period of global-attention layers: the loader marks layer `i` global when
    /// `i % global_attn_every_n_layers == 0`.
    pub global_attn_every_n_layers: usize,
    /// Window size used for local (non-global) attention layers.
    pub local_attention: usize,
    /// RoPE base used to build the global-attention rotary cache.
    pub global_rope_theta: f32,
    /// RoPE base used to build the local-attention rotary cache.
    pub local_rope_theta: f32,
    /// Normalization epsilon; `from_json` defaults it to `1e-5`.
    pub norm_eps: f32,
    /// Number of positions the rotary caches are precomputed for.
    pub max_position_embeddings: usize,
    /// Vocabulary size as declared in `config.json`.
    pub vocab_size: usize,
}
impl ModernBertConfig {
#[expect(
clippy::cast_possible_truncation,
reason = "HuggingFace config ints always fit in usize; f64 rope/eps values fit in f32"
)]
pub fn from_json(json: &serde_json::Value) -> crate::Result<Self> {
let get_usize = |key: &str| -> crate::Result<usize> {
json.get(key)
.and_then(serde_json::Value::as_u64)
.map(|v| v as usize)
.ok_or_else(|| crate::Error::Cpu(format!("config.json missing or invalid: {key}")))
};
let get_f64 = |key: &str| -> crate::Result<f64> {
json.get(key)
.and_then(serde_json::Value::as_f64)
.ok_or_else(|| crate::Error::Cpu(format!("config.json missing or invalid: {key}")))
};
Ok(Self {
hidden_size: get_usize("hidden_size")?,
intermediate_size: get_usize("intermediate_size")?,
num_hidden_layers: get_usize("num_hidden_layers")?,
num_attention_heads: get_usize("num_attention_heads")?,
global_attn_every_n_layers: get_usize("global_attn_every_n_layers")?,
local_attention: get_usize("local_attention")?,
global_rope_theta: get_f64("global_rope_theta")? as f32,
local_rope_theta: get_f64("local_rope_theta")? as f32,
norm_eps: get_f64("norm_eps").unwrap_or(1e-5) as f32,
max_position_embeddings: get_usize("max_position_embeddings")?,
vocab_size: get_usize("vocab_size")?,
})
}
}
/// Precomputes rotary-embedding tables for positions `0..max_seq`.
///
/// Both tables are row-major `[max_seq, head_dim / 2]`. Entry `[pos, d]` holds the
/// cosine or sine of `pos / theta^(2d / head_dim)`.
fn build_rope_cache_cpu(head_dim: usize, max_seq: usize, theta: f32) -> RopeCache<MmapTensor> {
    let half_dim = head_dim / 2;
    let (cos_table, sin_table): (Vec<f32>, Vec<f32>) = (0..max_seq)
        .flat_map(|pos| (0..half_dim).map(move |d| (pos, d)))
        .map(|(pos, d)| {
            let angle = (pos as f32) / theta.powf(2.0 * d as f32 / head_dim as f32);
            (angle.cos(), angle.sin())
        })
        .unzip();
    RopeCache {
        cos: MmapTensor::Owned(cos_table),
        sin: MmapTensor::Owned(sin_table),
    }
}
/// GELU activation using the tanh approximation:
/// `0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`.
fn gelu_scalar(x: f32) -> f32 {
    let sqrt_two_over_pi = (2.0 / std::f32::consts::PI).sqrt();
    let cubic_term = 0.044_715 * x * x * x;
    let inner = sqrt_two_over_pi * (x + cubic_term);
    let half_x = x * 0.5;
    half_x * (1.0 + inner.tanh())
}
/// Replaces `vals` with its softmax. The maximum is subtracted before
/// exponentiating so large inputs do not overflow.
fn softmax_inplace(vals: &mut [f32]) {
    let peak = vals.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    let mut total = 0.0_f32;
    vals.iter_mut().for_each(|v| {
        *v = (*v - peak).exp();
        total += *v;
    });
    let scale = 1.0 / total;
    vals.iter_mut().for_each(|v| *v *= scale);
}
#[expect(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
reason = "dimension values and token IDs are small positive integers"
)]
impl Driver for CpuDriver {
type Tensor = MmapTensor;
fn new_for_clone() -> crate::Result<Self> {
Ok(Self)
}
fn alloc_zeros(&self, n: usize) -> crate::Result<MmapTensor> {
Ok(MmapTensor::Owned(vec![0.0; n]))
}
fn clone_tensor(&self, tensor: &MmapTensor, _n: usize) -> crate::Result<MmapTensor> {
Ok(MmapTensor::Owned(tensor.to_vec()))
}
fn prepare_batch(
&self,
encodings: &[Encoding],
max_seq: usize,
) -> crate::Result<BatchInputs<MmapTensor>> {
let batch = encodings.len();
let total = batch * max_seq;
let mut input_ids = vec![0.0_f32; total];
let mut token_type_ids = vec![0.0_f32; total];
let mut position_ids = vec![0.0_f32; total];
let mut attn_mask_int = vec![0.0_f32; total];
for (b, enc) in encodings.iter().enumerate() {
let seq_len = enc.input_ids.len();
let offset = b * max_seq;
for (i, &id) in enc.input_ids.iter().enumerate() {
input_ids[offset + i] = id as f32;
}
for (i, &id) in enc.token_type_ids.iter().enumerate() {
token_type_ids[offset + i] = id as f32;
}
for i in 0..seq_len {
position_ids[offset + i] = i as f32;
}
for (i, &m) in enc.attention_mask.iter().enumerate() {
attn_mask_int[offset + i] = m as f32;
}
}
let float_mask: Vec<f32> = attn_mask_int
.iter()
.map(|&m| if m > 0.5 { 0.0 } else { -1e9 })
.collect();
let pooling_mask: Vec<f32> = attn_mask_int
.iter()
.map(|&m| if m > 0.5 { 1.0 } else { 0.0 })
.collect();
let seq_lengths: Vec<usize> = encodings.iter().map(|e| e.input_ids.len()).collect();
let total_tokens: usize = seq_lengths.iter().sum();
Ok(BatchInputs {
input_ids: MmapTensor::Owned(input_ids),
attention_mask: MmapTensor::Owned(attn_mask_int),
token_type_ids: MmapTensor::Owned(token_type_ids),
position_ids: MmapTensor::Owned(position_ids),
float_mask: MmapTensor::Owned(float_mask),
pooling_mask: MmapTensor::Owned(pooling_mask),
batch,
max_seq,
total_tokens,
seq_lengths,
cu_seqlens: None, })
}
fn pad_to_batch(
&self,
flat: &MmapTensor,
padded: &mut MmapTensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> crate::Result<()> {
let batch = seq_lengths.len();
padded.resize(batch * max_seq * dim, 0.0);
padded.fill(0.0);
let mut offset = 0;
for (b, &len) in seq_lengths.iter().enumerate() {
for t in 0..len {
let src = (offset + t) * dim;
let dst = (b * max_seq + t) * dim;
padded[dst..dst + dim].copy_from_slice(&flat[src..src + dim]);
}
offset += len;
}
Ok(())
}
fn unpad_from_batch(
&self,
padded: &MmapTensor,
flat: &mut MmapTensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> crate::Result<()> {
let total_tokens: usize = seq_lengths.iter().sum();
flat.resize(total_tokens * dim, 0.0);
let mut offset = 0;
for (b, &len) in seq_lengths.iter().enumerate() {
for t in 0..len {
let src = (b * max_seq + t) * dim;
let dst = (offset + t) * dim;
flat[dst..dst + dim].copy_from_slice(&padded[src..src + dim]);
}
offset += len;
}
Ok(())
}
fn embedding_lookup(
&self,
word_ids: &MmapTensor,
embedding_table: &MmapTensor,
seq_len: usize,
hidden: usize,
) -> crate::Result<MmapTensor> {
let mut output = vec![0.0; seq_len * hidden];
for (i, &wid) in word_ids.iter().enumerate().take(seq_len) {
let id = wid as usize;
let src_start = id * hidden;
let dst_start = i * hidden;
output[dst_start..dst_start + hidden]
.copy_from_slice(&embedding_table[src_start..src_start + hidden]);
}
Ok(MmapTensor::Owned(output))
}
fn add_embeddings(
&self,
hidden: &mut MmapTensor,
table: &MmapTensor,
ids: &MmapTensor,
seq_len: usize,
hidden_dim: usize,
) -> crate::Result<()> {
for (i, &id_f) in ids.iter().enumerate().take(seq_len) {
let id = id_f as usize;
let tbl_start = id * hidden_dim;
let hid_start = i * hidden_dim;
for j in 0..hidden_dim {
hidden[hid_start + j] += table[tbl_start + j];
}
}
Ok(())
}
fn layer_norm(
&self,
output: &mut MmapTensor,
input: &MmapTensor,
weight: &MmapTensor,
bias: &MmapTensor,
rows: usize,
cols: usize,
eps: f32,
) -> crate::Result<()> {
output.resize(rows * cols, 0.0);
for r in 0..rows {
let base = r * cols;
let row = &input[base..base + cols];
let mean: f32 = row.iter().sum::<f32>() / cols as f32;
let var: f32 = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
let inv_std = 1.0 / (var + eps).sqrt();
for c in 0..cols {
output[base + c] = (row[c] - mean) * inv_std * weight[c] + bias[c];
}
}
Ok(())
}
#[expect(
clippy::many_single_char_names,
reason = "a, b, m, n, k are standard GEMM parameter names from BLAS"
)]
fn gemm(
&self,
a: &MmapTensor,
b: &MmapTensor,
output: &mut MmapTensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
) -> crate::Result<()> {
output.resize(m * n, 0.0);
let a_view = ndarray::ArrayView2::from_shape((m, k), a)
.map_err(|e| crate::Error::Cpu(format!("GEMM a shape error: {e}")))?;
if transpose_b {
let b_view = ndarray::ArrayView2::from_shape((n, k), b)
.map_err(|e| crate::Error::Cpu(format!("GEMM b shape error: {e}")))?;
let bt = b_view.t();
let mut c = ndarray::Array2::zeros((m, n));
ndarray::linalg::general_mat_mul(1.0, &a_view, &bt, 0.0, &mut c);
output.clear();
output.extend(c.iter().copied());
} else {
let b_view = ndarray::ArrayView2::from_shape((k, n), b)
.map_err(|e| crate::Error::Cpu(format!("GEMM b shape error: {e}")))?;
let mut c = ndarray::Array2::zeros((m, n));
ndarray::linalg::general_mat_mul(1.0, &a_view, &b_view, 0.0, &mut c);
output.clear();
output.extend(c.iter().copied());
}
Ok(())
}
#[expect(
clippy::many_single_char_names,
reason = "a, b, m, n, k are standard GEMM parameter names from BLAS"
)]
fn gemm_batched(
&self,
a: &MmapTensor,
b: &MmapTensor,
output: &mut MmapTensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
stride_a: usize,
stride_b: usize,
stride_c: usize,
batch_count: usize,
) -> crate::Result<()> {
output.resize(batch_count * stride_c, 0.0);
for batch in 0..batch_count {
let a_off = batch * stride_a;
let b_off = batch * stride_b;
let c_off = batch * stride_c;
let a_slice = &a[a_off..a_off + m * k];
let b_slice = if transpose_b {
&b[b_off..b_off + n * k]
} else {
&b[b_off..b_off + k * n]
};
let a_view = ndarray::ArrayView2::from_shape((m, k), a_slice)
.map_err(|e| crate::Error::Cpu(format!("batched GEMM a shape: {e}")))?;
let mut c = ndarray::Array2::zeros((m, n));
if transpose_b {
let b_view = ndarray::ArrayView2::from_shape((n, k), b_slice)
.map_err(|e| crate::Error::Cpu(format!("batched GEMM b shape: {e}")))?;
ndarray::linalg::general_mat_mul(1.0, &a_view, &b_view.t(), 0.0, &mut c);
} else {
let b_view = ndarray::ArrayView2::from_shape((k, n), b_slice)
.map_err(|e| crate::Error::Cpu(format!("batched GEMM b shape: {e}")))?;
ndarray::linalg::general_mat_mul(1.0, &a_view, &b_view, 0.0, &mut c);
}
output[c_off..c_off + m * n].copy_from_slice(c.as_slice().unwrap());
}
Ok(())
}
fn fused_scale_mask_softmax(
&self,
scores: &mut MmapTensor,
mask: &MmapTensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
) -> crate::Result<()> {
for b in 0..batch {
for h in 0..num_heads {
for q in 0..seq_len {
let row_off = ((b * num_heads + h) * seq_len + q) * seq_len;
let row = &mut scores[row_off..row_off + seq_len];
for kk in 0..seq_len {
row[kk] = row[kk] * scale + mask[b * seq_len + kk];
}
softmax_inplace(row);
}
}
}
Ok(())
}
fn fused_scale_mask_softmax_windowed(
&self,
scores: &mut MmapTensor,
mask: &MmapTensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
window_size: usize,
) -> crate::Result<()> {
let half_window = window_size / 2;
for b in 0..batch {
for h in 0..num_heads {
for q in 0..seq_len {
let row_off = ((b * num_heads + h) * seq_len + q) * seq_len;
let row = &mut scores[row_off..row_off + seq_len];
for kk in 0..seq_len {
let dist = q.abs_diff(kk);
let window_mask = if dist > half_window { -1e9 } else { 0.0 };
row[kk] = row[kk] * scale + mask[b * seq_len + kk] + window_mask;
}
softmax_inplace(row);
}
}
}
Ok(())
}
fn build_attn_mask(
&self,
output: &mut MmapTensor,
int_mask: &MmapTensor,
n: usize,
) -> crate::Result<()> {
output.resize(n, 0.0);
for i in 0..n {
output[i] = if int_mask[i] > 0.5 { 0.0 } else { -1e9 };
}
Ok(())
}
fn qkv_split(
&self,
q: &mut MmapTensor,
k: &mut MmapTensor,
v: &mut MmapTensor,
qkv: &MmapTensor,
batch: usize,
seq: usize,
hidden: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()> {
let total_head = batch * num_heads * seq * head_dim;
q.resize(total_head, 0.0);
k.resize(total_head, 0.0);
v.resize(total_head, 0.0);
for b in 0..batch {
for s in 0..seq {
let src_row = (b * seq + s) * 3 * hidden;
for h in 0..num_heads {
for d in 0..head_dim {
let src_q = src_row + h * head_dim + d;
let src_k = src_row + hidden + h * head_dim + d;
let src_v = src_row + 2 * hidden + h * head_dim + d;
let dst = (b * num_heads + h) * seq * head_dim + s * head_dim + d;
q[dst] = qkv[src_q];
k[dst] = qkv[src_k];
v[dst] = qkv[src_v];
}
}
}
}
Ok(())
}
fn attn_reshape(
&self,
output: &mut MmapTensor,
input: &MmapTensor,
batch: usize,
seq: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()> {
let hidden = num_heads * head_dim;
let total = batch * seq * hidden;
output.resize(total, 0.0);
for b in 0..batch {
for s in 0..seq {
for h in 0..num_heads {
let src_off = (b * num_heads + h) * seq * head_dim + s * head_dim;
let dst_off = (b * seq + s) * hidden + h * head_dim;
output[dst_off..dst_off + head_dim]
.copy_from_slice(&input[src_off..src_off + head_dim]);
}
}
}
Ok(())
}
fn apply_rope(
&self,
qk: &mut MmapTensor,
cos: &MmapTensor,
sin: &MmapTensor,
num_rows: usize,
seq_len: usize,
head_dim: usize,
_num_heads: usize,
) -> crate::Result<()> {
let half = head_dim / 2;
for row_idx in 0..num_rows {
let pos = row_idx % seq_len;
let base = row_idx * head_dim;
let cache_base = pos * half;
for d in 0..half {
let first = qk[base + d];
let second = qk[base + d + half];
let c = cos[cache_base + d];
let sn = sin[cache_base + d];
qk[base + d] = first * c - second * sn;
qk[base + d + half] = first * sn + second * c;
}
}
Ok(())
}
fn split_gate_value(
&self,
first: &mut MmapTensor,
second: &mut MmapTensor,
input: &MmapTensor,
rows: usize,
cols: usize,
) -> crate::Result<()> {
first.resize(rows * cols, 0.0);
second.resize(rows * cols, 0.0);
for r in 0..rows {
let src = r * 2 * cols;
let dst = r * cols;
first[dst..dst + cols].copy_from_slice(&input[src..src + cols]);
second[dst..dst + cols].copy_from_slice(&input[src + cols..src + 2 * cols]);
}
Ok(())
}
fn gelu(&self, x: &mut MmapTensor, n: usize) -> crate::Result<()> {
for v in x.iter_mut().take(n) {
*v = gelu_scalar(*v);
}
Ok(())
}
fn swiglu(
&self,
value: &MmapTensor,
gate: &MmapTensor,
output: &mut MmapTensor,
n: usize,
) -> crate::Result<()> {
output.resize(n, 0.0);
for i in 0..n {
let g = gate[i];
let silu = g / (1.0 + (-g).exp());
output[i] = value[i] * silu;
}
Ok(())
}
fn geglu(
&self,
value: &MmapTensor,
gate: &MmapTensor,
output: &mut MmapTensor,
n: usize,
) -> crate::Result<()> {
output.resize(n, 0.0);
for i in 0..n {
output[i] = gelu_scalar(value[i]) * gate[i];
}
Ok(())
}
fn fused_bias_gelu(
&self,
x: &mut MmapTensor,
bias: &MmapTensor,
rows: usize,
cols: usize,
) -> crate::Result<()> {
for r in 0..rows {
let base = r * cols;
for c in 0..cols {
x[base + c] = gelu_scalar(x[base + c] + bias[c]);
}
}
Ok(())
}
fn fused_bias_residual(
&self,
output: &mut MmapTensor,
input: &MmapTensor,
bias: &MmapTensor,
residual: &MmapTensor,
n: usize,
cols: usize,
) -> crate::Result<()> {
output.resize(n, 0.0);
for i in 0..n {
output[i] = input[i] + bias[i % cols] + residual[i];
}
Ok(())
}
fn fused_residual_layernorm(
&self,
output: &mut MmapTensor,
hidden: &MmapTensor,
residual: &MmapTensor,
weight: &MmapTensor,
bias: &MmapTensor,
rows: usize,
cols: usize,
eps: f32,
) -> crate::Result<()> {
output.resize(rows * cols, 0.0);
for r in 0..rows {
let base = r * cols;
let mean: f32 = (0..cols)
.map(|c| hidden[base + c] + residual[base + c])
.sum::<f32>()
/ cols as f32;
let var: f32 = (0..cols)
.map(|c| {
let v = hidden[base + c] + residual[base + c] - mean;
v * v
})
.sum::<f32>()
/ cols as f32;
let inv_std = 1.0 / (var + eps).sqrt();
for c in 0..cols {
let v = hidden[base + c] + residual[base + c];
output[base + c] = (v - mean) * inv_std * weight[c] + bias[c];
}
}
Ok(())
}
fn residual_add(
&self,
output: &mut MmapTensor,
hidden: &MmapTensor,
residual: &MmapTensor,
n: usize,
) -> crate::Result<()> {
output.resize(n, 0.0);
for i in 0..n {
output[i] = hidden[i] + residual[i];
}
Ok(())
}
fn add_bias(
&self,
x: &mut MmapTensor,
bias: &MmapTensor,
rows: usize,
cols: usize,
) -> crate::Result<()> {
for r in 0..rows {
let base = r * cols;
for c in 0..cols {
x[base + c] += bias[c];
}
}
Ok(())
}
fn cls_pool(
&self,
output: &mut MmapTensor,
hidden: &MmapTensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> crate::Result<()> {
output.resize(batch * hidden_dim, 0.0);
for b in 0..batch {
let src = b * seq * hidden_dim;
let dst = b * hidden_dim;
output[dst..dst + hidden_dim].copy_from_slice(&hidden[src..src + hidden_dim]);
}
Ok(())
}
fn mean_pool(
&self,
output: &mut MmapTensor,
hidden: &MmapTensor,
mask: &MmapTensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> crate::Result<()> {
output.resize(batch * hidden_dim, 0.0);
for b in 0..batch {
let mask_sum: f32 = (0..seq).map(|s| mask[b * seq + s]).sum();
let inv_sum = if mask_sum > 0.0 { 1.0 / mask_sum } else { 0.0 };
for d in 0..hidden_dim {
let mut sum = 0.0_f32;
for s in 0..seq {
sum += hidden[(b * seq + s) * hidden_dim + d] * mask[b * seq + s];
}
output[b * hidden_dim + d] = sum * inv_sum;
}
}
Ok(())
}
fn l2_normalize(&self, data: &mut MmapTensor, rows: usize, cols: usize) -> crate::Result<()> {
for r in 0..rows {
let base = r * cols;
let row = &data[base..base + cols];
let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
let inv_norm = if norm > 0.0 { 1.0 / norm } else { 0.0 };
for c in 0..cols {
data[base + c] *= inv_norm;
}
}
Ok(())
}
#[expect(
clippy::cast_possible_wrap,
reason = "seq/window indices are small ML dimensions that fit in isize"
)]
fn banded_qk(
&self,
q: &MmapTensor,
k: &MmapTensor,
scores: &mut MmapTensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_qk: usize,
stride_scores: usize,
) -> crate::Result<()> {
let half_w = window / 2;
for h in 0..batch_heads {
for i in 0..seq {
for w in 0..window {
let k_pos = i as isize - half_w as isize + w as isize;
if k_pos < 0 || k_pos >= seq as isize {
scores[h * stride_scores + i * window + w] = -1e9;
} else {
let mut dot = 0.0_f32;
for d in 0..head_dim {
dot += q[h * stride_qk + i * head_dim + d]
* k[h * stride_qk + k_pos as usize * head_dim + d];
}
scores[h * stride_scores + i * window + w] = dot;
}
}
}
}
Ok(())
}
#[expect(
clippy::cast_possible_wrap,
reason = "seq/window indices are small ML dimensions that fit in isize"
)]
fn banded_sv(
&self,
scores: &MmapTensor,
v: &MmapTensor,
output: &mut MmapTensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_scores: usize,
stride_v: usize,
stride_out: usize,
) -> crate::Result<()> {
let half_w = window / 2;
for h in 0..batch_heads {
for i in 0..seq {
for d in 0..head_dim {
let mut sum = 0.0_f32;
for w in 0..window {
let v_pos = i as isize - half_w as isize + w as isize;
if v_pos >= 0 && v_pos < seq as isize {
sum += scores[h * stride_scores + i * window + w]
* v[h * stride_v + v_pos as usize * head_dim + d];
}
}
output[h * stride_out + i * head_dim + d] = sum;
}
}
}
Ok(())
}
fn banded_softmax(
&self,
scores: &mut MmapTensor,
total_rows: usize,
window: usize,
scale: f32,
) -> crate::Result<()> {
for r in 0..total_rows {
let row = &mut scores[r * window..(r + 1) * window];
for v in row.iter_mut() {
*v *= scale;
}
let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0_f32;
for v in row.iter_mut() {
*v = (*v - max).exp();
sum += *v;
}
let inv = 1.0 / sum.max(1e-12);
for v in row.iter_mut() {
*v *= inv;
}
}
Ok(())
}
fn to_host(
&self,
tensor: &MmapTensor,
batch: usize,
dim: usize,
) -> crate::Result<Vec<Vec<f32>>> {
let mut results = Vec::with_capacity(batch);
for b in 0..batch {
results.push(tensor[b * dim..(b + 1) * dim].to_vec());
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_driver_creates() {
        let _driver = CpuDriver::new().unwrap();
    }

    #[test]
    fn driver_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CpuDriver>();
        assert_send_sync::<MmapTensor>();
    }

    #[test]
    fn gelu_smoke_test() {
        let driver = CpuDriver::new().unwrap();
        let mut tensor: MmapTensor = vec![0.0_f32, 1.0, -1.0, 2.0].into();
        driver.gelu(&mut tensor, 4).unwrap();
        assert!(
            tensor[0].abs() < 1e-4,
            "GELU(0) should be ~0, got {}",
            tensor[0]
        );
        assert!(
            (tensor[1] - 0.8412).abs() < 0.01,
            "GELU(1) should be ~0.8412, got {}",
            tensor[1]
        );
        assert!(
            (tensor[2] - (-0.1588)).abs() < 0.01,
            "GELU(-1) should be ~-0.1588, got {}",
            tensor[2]
        );
    }

    #[test]
    fn layer_norm_smoke_test() {
        let driver = CpuDriver::new().unwrap();
        let input: MmapTensor = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0].into();
        let weight: MmapTensor = vec![1.0, 1.0, 1.0].into();
        let bias: MmapTensor = vec![0.0, 0.0, 0.0].into();
        let mut output: MmapTensor = vec![].into();
        driver
            .layer_norm(&mut output, &input, &weight, &bias, 2, 3, 1e-5)
            .unwrap();
        let row0_mean: f32 = output[0..3].iter().sum::<f32>() / 3.0;
        assert!(
            row0_mean.abs() < 1e-5,
            "layer norm row mean should be ~0, got {row0_mean}"
        );
    }

    #[test]
    fn gemm_smoke_test() {
        let driver = CpuDriver::new().unwrap();
        let a: MmapTensor = vec![1.0, 2.0, 3.0, 4.0].into();
        let b: MmapTensor = vec![5.0, 6.0, 7.0, 8.0].into();
        let mut output: MmapTensor = vec![].into();
        driver.gemm(&a, &b, &mut output, 2, 2, 2, false).unwrap();
        assert!((output[0] - 19.0).abs() < 1e-4, "got {}", output[0]);
        assert!((output[1] - 22.0).abs() < 1e-4, "got {}", output[1]);
        assert!((output[2] - 43.0).abs() < 1e-4, "got {}", output[2]);
        assert!((output[3] - 50.0).abs() < 1e-4, "got {}", output[3]);
    }

    #[test]
    fn gemm_transpose_b_test() {
        let driver = CpuDriver::new().unwrap();
        let a: MmapTensor = vec![1.0, 2.0, 3.0, 4.0].into();
        let b: MmapTensor = vec![5.0, 6.0, 7.0, 8.0].into();
        let mut output: MmapTensor = vec![].into();
        driver.gemm(&a, &b, &mut output, 2, 2, 2, true).unwrap();
        assert!((output[0] - 17.0).abs() < 1e-4, "got {}", output[0]);
        assert!((output[1] - 23.0).abs() < 1e-4, "got {}", output[1]);
        assert!((output[2] - 39.0).abs() < 1e-4, "got {}", output[2]);
        assert!((output[3] - 53.0).abs() < 1e-4, "got {}", output[3]);
    }

    #[test]
    fn embedding_lookup_test() {
        let driver = CpuDriver::new().unwrap();
        // 5 rows of width 2.
        let table: MmapTensor = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0].into();
        let ids: MmapTensor = vec![2.0, 0.0, 4.0].into();
        let output = driver.embedding_lookup(&ids, &table, 3, 2).unwrap();
        assert_eq!(&*output, &[0.5, 0.6, 0.1, 0.2, 0.9, 1.0]);
    }

    #[test]
    fn cls_pool_test() {
        let driver = CpuDriver::new().unwrap();
        // batch 2, seq 3, hidden_dim 2.
        let hidden: MmapTensor = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]
        .into();
        let mut output: MmapTensor = vec![].into();
        driver.cls_pool(&mut output, &hidden, 2, 3, 2).unwrap();
        assert_eq!(&*output, &[1.0, 2.0, 7.0, 8.0]);
    }

    #[test]
    fn softmax_inplace_normalizes() {
        let mut vals = [1.0_f32, 2.0, 3.0];
        softmax_inplace(&mut vals);
        let sum: f32 = vals.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "softmax should sum to 1, got {sum}");
        assert!(vals[0] < vals[1] && vals[1] < vals[2]);
    }

    #[test]
    fn l2_normalize_test() {
        let driver = CpuDriver::new().unwrap();
        let mut data: MmapTensor = vec![3.0, 4.0, 0.0, 0.0].into();
        driver.l2_normalize(&mut data, 2, 2).unwrap();
        assert!((data[0] - 0.6).abs() < 1e-6, "got {}", data[0]);
        assert!((data[1] - 0.8).abs() < 1e-6, "got {}", data[1]);
        assert_eq!(&data[2..], &[0.0, 0.0], "all-zero rows must stay zero");
    }

    #[test]
    fn mean_pool_respects_mask() {
        let driver = CpuDriver::new().unwrap();
        // batch 2, seq 2, hidden_dim 2; the second sequence is fully masked out.
        let hidden: MmapTensor = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0].into();
        let mask: MmapTensor = vec![1.0, 0.0, 0.0, 0.0].into();
        let mut output: MmapTensor = vec![].into();
        driver
            .mean_pool(&mut output, &hidden, &mask, 2, 2, 2)
            .unwrap();
        assert_eq!(&*output, &[1.0, 2.0, 0.0, 0.0]);
    }

    #[test]
    fn pad_unpad_round_trip() {
        let driver = CpuDriver::new().unwrap();
        let flat: MmapTensor = vec![1.0, 2.0, 3.0].into();
        let mut padded: MmapTensor = vec![].into();
        driver
            .pad_to_batch(&flat, &mut padded, &[1, 2], 2, 1)
            .unwrap();
        assert_eq!(&*padded, &[1.0, 0.0, 2.0, 3.0]);
        let mut unpadded: MmapTensor = vec![].into();
        driver
            .unpad_from_batch(&padded, &mut unpadded, &[1, 2], 2, 1)
            .unwrap();
        assert_eq!(&*unpadded, &[1.0, 2.0, 3.0]);
    }
}