use anyhow::Result;
use sapient_backends_cpu::kernels::{self, attention, layernorm, matmul, rope};
use sapient_core::error::SapientError;
use sapient_core::{DType, Shape, Tensor};
fn map_err<T>(result: std::result::Result<T, SapientError>) -> Result<T> {
result.map_err(|e| anyhow::anyhow!("{e}"))
}
pub fn embed_tokens(weight: &Tensor, input_ids: &[u32]) -> Result<Tensor> {
let hidden = weight.shape().dims()[1];
let seq_len = input_ids.len();
let w_cow = weight.to_f32_cow();
let w = w_cow.as_ref();
let mut out = vec![0.0f32; seq_len * hidden];
for (i, &id) in input_ids.iter().enumerate() {
let row = id as usize * hidden;
if row + hidden > w.len() {
anyhow::bail!("token id {id} out of vocab range");
}
out[i * hidden..(i + 1) * hidden].copy_from_slice(&w[row..row + hidden]);
}
Tensor::from_f32(&out, Shape::new([1, seq_len, hidden])).map_err(|e| anyhow::anyhow!("{e}"))
}
/// Applies a linear projection (no bias term) to every position of a 3-D input.
///
/// `x` must be `[batch, seq, in_dim]` and `weight` must be 2-D with shape
/// `[out_dim, in_dim]`. The input is flattened to `[batch * seq, in_dim]`,
/// multiplied against `weight` with `matmul::matmul_nt` (presumably
/// `x · weightᵀ`, going by the `_nt` suffix; confirm against the kernel), and
/// reshaped back to `[batch, seq, out_dim]`.
///
/// # Errors
/// Returns an error if either tensor has the wrong rank, if the weight's
/// second dimension differs from `in_dim`, or if reshape/matmul fails.
pub fn linear_3d(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
    let x_dims: &[usize] = x.shape().dims();
    let (batch, seq, in_dim) = match *x_dims {
        [b, s, k] => (b, s, k),
        _ => anyhow::bail!("linear_3d expects [batch, seq, hidden]"),
    };
    let w_shape: &[usize] = weight.shape().dims();
    let (out_dim, w_in) = match *w_shape {
        [o, k] => (o, k),
        _ => anyhow::bail!("linear weight must be 2-D"),
    };
    if w_in != in_dim {
        anyhow::bail!("linear weight in_dim mismatch: {w_in} vs {in_dim}");
    }
    let rows = map_err(x.reshape(vec![batch * seq, in_dim]))?;
    let projected = map_err(matmul::matmul_nt(&rows, weight))?;
    map_err(projected.reshape(vec![batch, seq, out_dim]))
}
pub fn split_heads(x: &Tensor, n_heads: usize, head_dim: usize) -> Result<Tensor> {
let seq = x.shape().dims()[1];
permute(
&map_err(x.reshape(vec![1, seq, n_heads, head_dim]))?,
&[0, 2, 1, 3],
)
}
pub fn merge_heads(x: &Tensor) -> Result<Tensor> {
let d = x.shape().dims();
let (n_heads, seq, head_dim) = (d[1], d[2], d[3]);
permute(x, &[0, 2, 1, 3])?
.reshape(vec![1, seq, n_heads * head_dim])
.map_err(|e| anyhow::anyhow!("{e}"))
}
pub fn permute(x: &Tensor, order: &[usize]) -> Result<Tensor> {
let dims = x.shape().dims();
if order.len() != dims.len() {
anyhow::bail!("permute rank mismatch");
}
let new_dims: Vec<usize> = order.iter().map(|&i| dims[i]).collect();
let src = x.as_f32_slice();
let mut out = vec![0.0f32; src.len()];
#[allow(clippy::too_many_arguments)]
fn recurse(
dims: &[usize],
order: &[usize],
src: &[f32],
out: &mut [f32],
src_strides: &[usize],
dst_strides: &[usize],
idx: &mut [usize],
depth: usize,
) {
if depth == dims.len() {
let src_off: usize = idx
.iter()
.zip(src_strides.iter())
.map(|(&i, &s)| i * s)
.sum();
let dst_off: usize = order
.iter()
.enumerate()
.map(|(dst_ax, &src_ax)| idx[src_ax] * dst_strides[dst_ax])
.sum();
out[dst_off] = src[src_off];
return;
}
for i in 0..dims[depth] {
idx[depth] = i;
recurse(
dims,
order,
src,
out,
src_strides,
dst_strides,
idx,
depth + 1,
);
}
}
let src_strides = strides_for(dims);
let dst_strides = strides_for(&new_dims);
let mut idx = vec![0usize; dims.len()];
recurse(
dims,
order,
src,
&mut out,
&src_strides,
&dst_strides,
&mut idx,
0,
);
Tensor::from_f32(&out, Shape::new(new_dims)).map_err(|e| anyhow::anyhow!("{e}"))
}
/// Row-major (C-order) element strides for a tensor with the given `dims`.
///
/// The last axis has stride 1, and each earlier axis's stride is the product
/// of all later extents. An empty `dims` yields an empty vector.
fn strides_for(dims: &[usize]) -> Vec<usize> {
    match dims.split_first() {
        None => Vec::new(),
        Some((_, trailing)) => {
            // Build strides innermost-first as running products of the trailing
            // extents, then flip them into axis order.
            let mut strides: Vec<usize> = std::iter::once(1usize)
                .chain(trailing.iter().rev().scan(1usize, |running, &extent| {
                    *running *= extent;
                    Some(*running)
                }))
                .collect();
            strides.reverse();
            strides
        }
    }
}
/// Encodes 32 `f32` values as one 34-byte Q8_0 block.
///
/// Layout: bytes `0..2` hold the block scale `max|x| / 127` as a little-endian
/// `f16`. Bytes `2..34` hold each value multiplied by the reciprocal of that
/// scale, rounded, clamped to `[-127, 127]`, and stored as `i8`. An all-zero
/// input produces a zero scale and zero quants.
///
/// The quants use the unrounded `f32` scale; only the stored scale is rounded
/// to `f16`.
#[inline]
fn quantize_f32_to_q8_0_block(data: &[f32]) -> [u8; 34] {
    debug_assert_eq!(data.len(), 32, "Q8_0 block must have exactly 32 elements");
    let amax = data.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
    let delta = amax / 127.0;
    let inv_delta = if delta > 0.0 { delta.recip() } else { 0.0 };
    let mut out = [0u8; 34];
    out[..2].copy_from_slice(&half::f16::from_f32(delta).to_le_bytes());
    for (i, x) in data.iter().enumerate() {
        let q = (x * inv_delta).round().clamp(-127.0, 127.0);
        out[i + 2] = q as i8 as u8;
    }
    out
}
/// Appends `new_k` to a fixed-capacity KV cache and returns the filled prefix.
///
/// `cache` is `[b, h, max_seq, hd]` and `new_k` is `[b, h, new_seq, hd]`;
/// axes 0, 1 and 3 must match. The first `current_seq_len` positions of
/// `cache` along axis 2 are treated as already filled. If the new positions
/// do not fit, the first `shift` positions are discarded by sliding the rest
/// towards position 0, and the new data is written after them.
///
/// Q8_0 caches are handled by `update_kv_cache_q8`. Otherwise the cache is
/// updated in place through its mutable f32 view, with `new_k`'s f32 view
/// read as contiguous `[b, h, new_seq, hd]` (assumed, not checked). Returns
/// `cache.slice_axis(2, 0, total)` with
/// `total = min(current_seq_len + new_seq, max_seq)`.
///
/// # Errors
/// Returns an error if either tensor is not 4-D, if their batch/head/head-dim
/// axes differ, if `current_seq_len` or `new_seq` exceeds the cache capacity,
/// or if the cache's f32 view or the final slice cannot be obtained.
pub fn update_kv_cache(
    cache: &mut Tensor,
    current_seq_len: usize,
    new_k: &Tensor,
) -> Result<Tensor> {
    let cd = cache.shape().dims().to_vec();
    let nd = new_k.shape().dims().to_vec();
    if cd.len() != 4 || nd.len() != 4 {
        anyhow::bail!("update_kv_cache expects 4-D tensors");
    }
    if cd[0] != nd[0] || cd[1] != nd[1] || cd[3] != nd[3] {
        anyhow::bail!("update_kv_cache shape mismatch");
    }
    let max_seq = cd[2];
    // A larger value would make the shift below read `(si + shift)` past this
    // head's slot: into the next head's data, or out of bounds for the last one.
    if current_seq_len > max_seq {
        anyhow::bail!(
            "current_seq_len {} exceeds max cache size {}",
            current_seq_len,
            max_seq
        );
    }
    if cache.dtype() == DType::Q8_0 {
        return update_kv_cache_q8(cache, &cd, &nd, current_seq_len, new_k);
    }
    let new_seq = nd[2];
    if new_seq > max_seq {
        anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
    }
    // Oldest positions evicted to make room; 0 when everything fits.
    let shift = (current_seq_len + new_seq).saturating_sub(max_seq);
    // Positions retained after eviction; the new tokens are written right after them.
    let keep_seq = current_seq_len - shift;
    // Equals min(current_seq_len + new_seq, max_seq).
    let total_seq = keep_seq + new_seq;

    let (b_sz, h, hd) = (cd[0], cd[1], cd[3]);
    let new_k_slice = new_k.as_f32_slice();
    let cache_strides = cache.strides().to_vec();
    {
        let cache_slice = cache.as_f32_slice_mut()?;
        if shift > 0 {
            for bi in 0..b_sz {
                for hi in 0..h {
                    let head_base = bi * cache_strides[0] + hi * cache_strides[1];
                    for si in 0..keep_seq {
                        let src_idx = head_base + (si + shift) * cache_strides[2];
                        let dst_idx = head_base + si * cache_strides[2];
                        cache_slice.copy_within(src_idx..src_idx + hd, dst_idx);
                    }
                }
            }
        }
        for bi in 0..b_sz {
            for hi in 0..h {
                let insert_base = bi * cache_strides[0]
                    + hi * cache_strides[1]
                    + keep_seq * cache_strides[2];
                let new_base = (bi * h + hi) * new_seq * hd;
                for si in 0..new_seq {
                    let c_idx = insert_base + si * cache_strides[2];
                    let n_idx = new_base + si * hd;
                    cache_slice[c_idx..c_idx + hd]
                        .copy_from_slice(&new_k_slice[n_idx..n_idx + hd]);
                }
            }
        }
    }
    cache
        .slice_axis(2, 0, total_seq)
        .map_err(|e| anyhow::anyhow!("{e}"))
}
/// Q8_0 variant of `update_kv_cache`.
///
/// The cache bytes are copied out, the retained window is shifted and the new
/// positions are quantized into it, and the buffer is written back to `*cache`
/// as a fresh `Q8_0` tensor of shape `[b, h, max_seq, hd]`. The filled prefix
/// of `min(current_seq_len + new_seq, max_seq)` positions is returned
/// dequantized to f32. Each cache position occupies `hd / 32` consecutive
/// 34-byte blocks (2-byte little-endian `f16` scale followed by 32 `i8` quants),
/// addressed as `((bi * h + hi) * max_seq + si) * bytes_per_pos`.
///
/// # Errors
/// Returns an error if `new_seq` or `current_seq_len` exceeds `max_seq`, if
/// `hd` is not a multiple of 32, if the cache byte buffer is shorter than this
/// layout needs, or if tensor construction fails.
fn update_kv_cache_q8(
    cache: &mut Tensor,
    cd: &[usize],
    nd: &[usize],
    current_seq_len: usize,
    new_k: &Tensor,
) -> Result<Tensor> {
    // Values per Q8_0 block, and encoded bytes per block (f16 scale + 32 quants).
    const QK: usize = 32;
    const BLOCK_BYTES: usize = 34;

    let (b_sz, h, max_seq, hd) = (cd[0], cd[1], cd[2], cd[3]);
    let new_seq = nd[2];
    if new_seq > max_seq {
        anyhow::bail!("new tokens {} exceeds max cache size {}", new_seq, max_seq);
    }
    // A larger value would shift reads past each head's slot.
    if current_seq_len > max_seq {
        anyhow::bail!(
            "current_seq_len {} exceeds max cache size {}",
            current_seq_len,
            max_seq
        );
    }
    // `hd / QK` would otherwise silently drop the trailing `hd % QK` values.
    if hd % QK != 0 {
        anyhow::bail!("Q8_0 KV cache head dim {} is not a multiple of {}", hd, QK);
    }
    let blocks_per_head = hd / QK;
    let bytes_per_pos = blocks_per_head * BLOCK_BYTES;
    // Oldest positions evicted to make room; 0 when everything fits.
    let shift = (current_seq_len + new_seq).saturating_sub(max_seq);
    // Positions retained after eviction; the new tokens are written right after them.
    let keep_seq = current_seq_len - shift;
    // Equals min(current_seq_len + new_seq, max_seq).
    let total_seq = keep_seq + new_seq;

    let mut cache_bytes: Vec<u8> = cache.as_bytes().to_vec();
    let needed = b_sz * h * max_seq * bytes_per_pos;
    if cache_bytes.len() < needed {
        anyhow::bail!(
            "Q8_0 KV cache has {} bytes, layout needs {}",
            cache_bytes.len(),
            needed
        );
    }
    let new_k_f32 = new_k.to_f32_vec();
    let pos_off = |bi: usize, hi: usize, si: usize| -> usize {
        ((bi * h + hi) * max_seq + si) * bytes_per_pos
    };

    if shift > 0 {
        for bi in 0..b_sz {
            for hi in 0..h {
                for si in 0..keep_seq {
                    let src = pos_off(bi, hi, si + shift);
                    let dst = pos_off(bi, hi, si);
                    cache_bytes.copy_within(src..src + bytes_per_pos, dst);
                }
            }
        }
    }

    // Quantize each new position block-by-block into its slot after the
    // retained window (new_k's f32 data is read as contiguous [b, h, new_seq, hd]).
    for bi in 0..b_sz {
        for hi in 0..h {
            for si in 0..new_seq {
                let dst_start = pos_off(bi, hi, keep_seq + si);
                let src_start = ((bi * h + hi) * new_seq + si) * hd;
                let src_f32 = &new_k_f32[src_start..src_start + hd];
                for (blk, values) in src_f32.chunks_exact(QK).enumerate() {
                    let off = dst_start + blk * BLOCK_BYTES;
                    cache_bytes[off..off + BLOCK_BYTES]
                        .copy_from_slice(&quantize_f32_to_q8_0_block(values));
                }
            }
        }
    }

    *cache = Tensor::from_quant_bytes(&cache_bytes, vec![b_sz, h, max_seq, hd], DType::Q8_0)
        .map_err(|e| anyhow::anyhow!("{e}"))?;

    // Dequantize the filled prefix: value = (i8 quant) * (f16 scale).
    let mut out_f32 = vec![0.0f32; b_sz * h * total_seq * hd];
    for bi in 0..b_sz {
        for hi in 0..h {
            for si in 0..total_seq {
                let src_start = pos_off(bi, hi, si);
                let dst_start = ((bi * h + hi) * total_seq + si) * hd;
                let encoded = &cache_bytes[src_start..src_start + bytes_per_pos];
                let decoded = &mut out_f32[dst_start..dst_start + hd];
                for (block, values) in encoded
                    .chunks_exact(BLOCK_BYTES)
                    .zip(decoded.chunks_exact_mut(QK))
                {
                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
                    for (v, &q) in values.iter_mut().zip(&block[2..]) {
                        *v = q as i8 as f32 * d;
                    }
                }
            }
        }
    }
    Tensor::from_f32(&out_f32, Shape::new(vec![b_sz, h, total_seq, hd]))
        .map_err(|e| anyhow::anyhow!("{e}"))
}
/// Delegates to `rope::apply_rope` for `x` at the given `positions` and
/// `base`, converting the kernel error into `anyhow::Error`.
pub fn apply_rope_positions(x: &Tensor, positions: &[usize], base: f32) -> Result<Tensor> {
    let rotated = rope::apply_rope(x, positions, base);
    map_err(rotated)
}
/// Delegates to `rope::apply_rope_partial` for `x` at the given `positions`,
/// `base` and `rotary_dim`, converting the kernel error into `anyhow::Error`.
pub fn apply_rope_partial(
    x: &Tensor,
    positions: &[usize],
    base: f32,
    rotary_dim: usize,
) -> Result<Tensor> {
    let rotated = rope::apply_rope_partial(x, positions, base, rotary_dim);
    map_err(rotated)
}
/// Adds a 1-D `bias` along the last axis of `y`, returning a new f32 tensor.
///
/// `bias` is read through its f32 view and must have exactly as many elements
/// as `y`'s last dimension. Element `i` of `y`'s flat f32 data receives
/// `bias[i % n]`.
///
/// # Errors
/// Returns an error if `y` has no dimensions, if the bias length differs from
/// the last dimension, or if the output tensor cannot be constructed.
pub fn add_bias_last_dim(y: &Tensor, bias: &Tensor) -> Result<Tensor> {
    let dims = y.shape().dims().to_vec();
    let n = match dims.last() {
        Some(&last) => last,
        None => anyhow::bail!("empty tensor"),
    };
    let bias_vals = bias.to_f32_cow();
    let b = bias_vals.as_ref();
    if b.len() != n {
        anyhow::bail!("bias length {} does not match last dim {n}", b.len());
    }
    let mut data = y.as_f32_slice().to_vec();
    // Cycling the bias walks it once per row of the last axis.
    data.iter_mut()
        .zip(b.iter().cycle())
        .for_each(|(v, &bv)| *v += bv);
    map_err(Tensor::from_f32(&data, Shape::new(dims)))
}
/// Delegates to `layernorm::rms_norm` with a required `weight` and `eps`,
/// converting the kernel error into `anyhow::Error`.
pub fn rms_norm(x: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
    let gain = Some(weight);
    map_err(layernorm::rms_norm(x, gain, eps))
}
/// Delegates to `layernorm::layer_norm` with a required `weight`, optional
/// `bias` and `eps`, converting the kernel error into `anyhow::Error`.
///
/// The axis argument passed is `-1` (presumably "last axis"; confirm against
/// the kernel).
pub fn layer_norm(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>, eps: f32) -> Result<Tensor> {
    let axis = -1;
    let normed = layernorm::layer_norm(x, Some(weight), bias, axis, eps);
    map_err(normed)
}
/// Delegates to `kernels::elementwise::silu`, converting the kernel error
/// into `anyhow::Error`.
pub fn silu(x: &Tensor) -> Result<Tensor> {
    let activated = kernels::elementwise::silu(x);
    map_err(activated)
}
/// Delegates to `kernels::elementwise::gelu`, converting the kernel error
/// into `anyhow::Error`.
pub fn gelu(x: &Tensor) -> Result<Tensor> {
    let activated = kernels::elementwise::gelu(x);
    map_err(activated)
}
/// Delegates to `kernels::elementwise::add`, converting the kernel error
/// into `anyhow::Error`.
pub fn add(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let sum = kernels::elementwise::add(a, b);
    map_err(sum)
}
/// Delegates to `kernels::elementwise::mul`, converting the kernel error
/// into `anyhow::Error`.
pub fn mul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let product = kernels::elementwise::mul(a, b);
    map_err(product)
}
/// Runs `attention::scaled_dot_product_attention` with an optional causal mask.
///
/// When `causal` is set, the mask comes from `attention::causal_mask`, using
/// the size of axis 2 of `q` and of `k` (presumably the query and key
/// lengths in a `[batch, heads, seq, head_dim]` layout; confirm). `n_kv_heads`
/// is forwarded unchanged and the fifth kernel argument is `None`.
pub fn gqa_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    n_kv_heads: usize,
    causal: bool,
) -> Result<Tensor> {
    let mask = causal.then(|| {
        let (q_len, k_len) = (q.shape().dims()[2], k.shape().dims()[2]);
        attention::causal_mask(q_len, k_len)
    });
    let attended =
        attention::scaled_dot_product_attention(q, k, v, mask.as_ref(), None, n_kv_heads);
    map_err(attended)
}
/// Computes vocabulary logits for every position of `hidden`.
///
/// `hidden`'s f32 data is reinterpreted as a `[seq, hidden_size]` matrix
/// (taken from dims 1 and 2) and multiplied against `lm_head` with
/// `matmul::matmul_nt`. The flat result is split into `seq` rows of
/// `lm_head.shape().dims()[0]` values each.
///
/// # Errors
/// Returns an error if the intermediate tensor or the matmul fails.
pub fn all_logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<Vec<f32>>> {
    let shape = hidden.shape().dims();
    let (hidden_size, seq) = (shape[2], shape[1]);
    let vocab_size = lm_head.shape().dims()[0];
    let rows = map_err(Tensor::from_f32(
        hidden.as_f32_slice(),
        Shape::new([seq, hidden_size]),
    ))?;
    let logits = map_err(matmul::matmul_nt(&rows, lm_head))?;
    let flat = logits.as_f32_slice();
    Ok((0..seq)
        .map(|row| flat[row * vocab_size..(row + 1) * vocab_size].to_vec())
        .collect())
}
/// Computes vocabulary logits for the last position of `hidden`.
///
/// `hidden` must be 3-D. Its f32 data is read as rows of `dims[2]` values, and
/// row `dims[1] - 1` (the last position of the first batch entry, assuming
/// row-major layout) is multiplied against `lm_head` with `matmul::matmul_nt`.
///
/// # Errors
/// Returns an error if `hidden` is not 3-D, has no positions, or if tensor
/// construction or the matmul fails.
pub fn logits_from_hidden(hidden: &Tensor, lm_head: &Tensor) -> Result<Vec<f32>> {
    let dims = hidden.shape().dims();
    if dims.len() != 3 {
        anyhow::bail!(
            "logits_from_hidden expects [batch, seq, hidden], got rank {}",
            dims.len()
        );
    }
    let (seq, hidden_size) = (dims[1], dims[2]);
    // `seq - 1` below would underflow (and panic) on an empty sequence.
    if seq == 0 {
        anyhow::bail!("logits_from_hidden: hidden state has no positions");
    }
    let h = hidden.as_f32_slice();
    let last = &h[(seq - 1) * hidden_size..seq * hidden_size];
    let h_last = map_err(Tensor::from_f32(last, Shape::new([1, hidden_size])))?;
    let logits = map_err(matmul::matmul_nt(&h_last, lm_head))?;
    Ok(logits.as_f32_slice().to_vec())
}
pub fn mean_pool_hidden(hidden: &Tensor) -> Result<Vec<f32>> {
let dims = hidden.shape().dims();
let (seq, hidden_size) = (dims[1], dims[2]);
let h = hidden.as_f32_slice();
let mut out = vec![0.0f32; hidden_size];
for t in 0..seq {
for i in 0..hidden_size {
out[i] += h[t * hidden_size + i];
}
}
let n = seq as f32;
for v in &mut out {
*v /= n;
}
Ok(out)
}