use super::quant;
use rayon::prelude::*;
use sapient_core::error::{Result, SapientError};
use sapient_core::{DType, Shape, Tensor, Q4_0_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE};
pub fn matmul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.ndim() < 2 || b_shape.ndim() < 2 {
return Err(SapientError::RankMismatch {
expected: 2,
got: a_shape.ndim().min(b_shape.ndim()),
});
}
let a_rank = a_shape.ndim();
let b_rank = b_shape.ndim();
let m = a_shape.dims()[a_rank - 2];
let k = a_shape.dims()[a_rank - 1];
let k2 = b_shape.dims()[b_rank - 2];
let n = b_shape.dims()[b_rank - 1];
if k != k2 {
return Err(SapientError::ShapeMismatch {
expected: vec![m, k, n],
got: vec![m, k2, n],
});
}
let batch: usize = a_shape.dims()[..a_rank - 2].iter().product();
let a_cow = a.to_f32_cow();
let a_data = a_cow.as_ref();
let b_cow = b.to_f32_cow();
let b_data = b_cow.as_ref();
let out_numel = batch * m * n;
let mut out_data = vec![0.0f32; out_numel];
let a_stride = m * k;
let b_stride = k * n;
let c_stride = m * n;
for bi in 0..batch {
let a_off = bi * a_stride;
let b_off = bi * b_stride;
let c_off = bi * c_stride;
unsafe {
matrixmultiply::sgemm(
m,
k,
n,
1.0,
a_data[a_off..].as_ptr(),
k as isize,
1,
b_data[b_off..].as_ptr(),
n as isize,
1,
0.0,
out_data[c_off..].as_mut_ptr(),
n as isize,
1,
);
}
}
let mut out_dims: Vec<usize> = if a_rank > 2 {
a_shape.dims()[..a_rank - 2].to_vec()
} else {
vec![]
};
out_dims.push(m);
out_dims.push(n);
Tensor::from_f32(&out_data, Shape::new(out_dims))
}
pub fn matmul_nt(x: &Tensor, w: &Tensor) -> Result<Tensor> {
let xd = x.shape().dims();
let wd = w.shape().dims();
if xd.len() != 2 || wd.len() != 2 {
return Err(SapientError::internal("matmul_nt expects 2-D tensors"));
}
let (m, k) = (xd[0], xd[1]);
let (n, k2) = (wd[0], wd[1]);
if k != k2 {
return Err(SapientError::ShapeMismatch {
expected: vec![m, k],
got: vec![n, k2],
});
}
match w.dtype() {
DType::Q4_0 => matmul_nt_q4_0(x, w, m, k, n),
DType::Q8_0 => matmul_nt_q8_0(x, w, m, k, n),
_ => matmul_nt_float(x, w, m, k, n),
}
}
fn matmul_nt_float(x: &Tensor, w: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
let x_cow = x.to_f32_cow();
let w_cow = w.to_f32_cow();
let x_data = x_cow.as_ref();
let w_data = w_cow.as_ref();
let mut out = vec![0.0f32; m * n];
unsafe {
matrixmultiply::sgemm(
m,
k,
n,
1.0,
x_data.as_ptr(),
k as isize,
1,
w_data.as_ptr(),
1,
k as isize,
0.0,
out.as_mut_ptr(),
n as isize,
1,
);
}
Tensor::from_f32(&out, Shape::new([m, n]))
}
fn matmul_nt_q4_0(x: &Tensor, w: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
if k % QUANT_BLOCK_SIZE != 0 {
return Err(SapientError::internal(
"Q4_0 matmul_nt: k must be a multiple of the block size (32)",
));
}
let x_cow = x.to_f32_cow();
let x_data: &[f32] = x_cow.as_ref();
let w_blocks: &[u8] = w.as_quant_blocks();
let row_bytes = k / QUANT_BLOCK_SIZE * Q4_0_BLOCK_BYTES;
let mut out = vec![0.0f32; m * n];
for i in 0..m {
let x_row = &x_data[i * k..(i + 1) * k];
out[i * n..(i + 1) * n]
.par_iter_mut()
.enumerate()
.for_each(|(j, slot)| {
*slot =
quant::dot_q4_0_row_f32(&w_blocks[j * row_bytes..(j + 1) * row_bytes], x_row);
});
}
Tensor::from_f32(&out, Shape::new([m, n]))
}
fn matmul_nt_q8_0(x: &Tensor, w: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
if k % QUANT_BLOCK_SIZE != 0 {
return Err(SapientError::internal(
"Q8_0 matmul_nt: k must be a multiple of the block size (32)",
));
}
let x_cow = x.to_f32_cow();
let x_data: &[f32] = x_cow.as_ref();
let w_blocks: &[u8] = w.as_quant_blocks();
let row_bytes = k / QUANT_BLOCK_SIZE * Q8_0_BLOCK_BYTES;
let mut out = vec![0.0f32; m * n];
for i in 0..m {
let x_row = &x_data[i * k..(i + 1) * k];
out[i * n..(i + 1) * n]
.par_iter_mut()
.enumerate()
.for_each(|(j, slot)| {
*slot =
quant::dot_q8_0_row_f32(&w_blocks[j * row_bytes..(j + 1) * row_bytes], x_row);
});
}
Tensor::from_f32(&out, Shape::new([m, n]))
}
pub fn gemm(
a: &Tensor,
b: &Tensor,
bias: Option<&Tensor>,
alpha: f32,
beta: f32,
trans_a: bool,
trans_b: bool,
) -> Result<Tensor> {
let a2 = if trans_a { a.t()? } else { a.clone() };
let b2 = if trans_b { b.t()? } else { b.clone() };
let a_shape = a2.shape();
let b_shape = b2.shape();
let m = a_shape.dims()[0];
let k = a_shape.dims()[1];
let k2 = b_shape.dims()[0];
let n = b_shape.dims()[1];
if k != k2 {
return Err(SapientError::ShapeMismatch {
expected: vec![m, k],
got: vec![k2, n],
});
}
let a_cow = a2.to_f32_cow();
let a_data = a_cow.as_ref();
let b_cow = b2.to_f32_cow();
let b_data = b_cow.as_ref();
let mut out = vec![0.0f32; m * n];
let a_strides = a2.strides();
let b_strides = b2.strides();
unsafe {
matrixmultiply::sgemm(
m,
k,
n,
alpha,
a_data.as_ptr(),
a_strides[0] as isize,
a_strides[1] as isize,
b_data.as_ptr(),
b_strides[0] as isize,
b_strides[1] as isize,
0.0,
out.as_mut_ptr(),
n as isize,
1,
);
}
if let Some(bias_t) = bias {
let bias_data = bias_t.as_f32_slice();
let b_len = bias_data.len();
if b_len != n && b_len != 1 {
return Err(SapientError::ShapeMismatch {
expected: vec![n],
got: vec![b_len],
});
}
for i in 0..m {
for j in 0..n {
let bv = if b_len == 1 {
bias_data[0]
} else {
bias_data[j]
};
out[i * n + j] += beta * bv;
}
}
}
Tensor::from_f32(&out, Shape::new([m, n]))
}
#[cfg(test)]
mod tests {
use super::*;
use sapient_core::DType;
#[test]
fn matmul_2x2() {
let a = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_f32(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
let c = matmul(&a, &b).unwrap();
let data = c.as_f32_slice();
assert!((data[0] - 19.0).abs() < 1e-5);
assert!((data[1] - 22.0).abs() < 1e-5);
assert!((data[2] - 43.0).abs() < 1e-5);
assert!((data[3] - 50.0).abs() < 1e-5);
}
#[test]
fn matmul_nt_q4_0_matches_float() {
let n_out = 4;
let k = 64;
let w_f32: Vec<f32> = (0..n_out * k)
.map(|i| (i as f32 % 16.0 - 8.0) * 0.05)
.collect();
let x_f32: Vec<f32> = (0..k).map(|i| (i as f32) * 0.01 - 0.3).collect();
let w_t = Tensor::from_f32(&w_f32, vec![n_out, k]).unwrap();
let x_t = Tensor::from_f32(&x_f32, vec![1, k]).unwrap();
let ref_out = matmul_nt(&x_t, &w_t).unwrap();
let ref_data = ref_out.as_f32_slice();
let w_blocks: Vec<u8> = w_f32
.chunks_exact(k)
.flat_map(super::quant::quantize_q4_0_row)
.collect();
let w_q = Tensor::from_quant_bytes(&w_blocks, vec![n_out, k], DType::Q4_0).unwrap();
let quant_out = matmul_nt(&x_t, &w_q).unwrap();
let quant_data = quant_out.as_f32_slice();
assert_eq!(ref_data.len(), quant_data.len());
for (i, (r, q)) in ref_data.iter().zip(quant_data).enumerate() {
assert!((r - q).abs() < 5e-3, "row {i}: ref={r} quant={q}");
}
}
#[test]
fn matmul_nt_linear() {
let x = Tensor::from_f32(&[1.0, 2.0], vec![1, 2]).unwrap();
let w = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
let y = matmul_nt(&x, &w).unwrap();
let d = y.as_f32_slice();
assert_eq!(y.shape().dims(), &[1, 3]);
assert!((d[0] - 5.0).abs() < 1e-5, "got {d:?}");
assert!((d[1] - 11.0).abs() < 1e-5, "got {d:?}");
assert!((d[2] - 17.0).abs() < 1e-5, "got {d:?}");
}
#[test]
fn matmul_nt_linear_f16_weight() {
let x = Tensor::from_f32(&[1.0, 2.0], vec![1, 2]).unwrap();
let bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.flat_map(|&v| half::f16::from_f32(v).to_le_bytes())
.collect();
let w = Tensor::from_f16_bytes(&bytes, vec![3, 2]).unwrap();
let y = matmul_nt(&x, &w).unwrap();
let d = y.as_f32_slice();
assert!((d[0] - 5.0).abs() < 1e-2, "got {d:?}");
assert!((d[1] - 11.0).abs() < 1e-2, "got {d:?}");
assert!((d[2] - 17.0).abs() < 1e-2, "got {d:?}");
}
#[test]
fn matmul_rank_mismatch() {
let a = Tensor::zeros(vec![4], DType::F32).unwrap();
let b = Tensor::zeros(vec![4], DType::F32).unwrap();
assert!(matmul(&a, &b).is_err());
}
#[test]
fn gemm_with_bias() {
let a = Tensor::from_f32(&[1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
let b = Tensor::from_f32(&[2.0, 3.0, 4.0, 5.0], vec![2, 2]).unwrap();
let bias = Tensor::from_f32(&[1.0, 1.0], vec![2]).unwrap();
let c = gemm(&a, &b, Some(&bias), 1.0, 1.0, false, false).unwrap();
let d = c.as_f32_slice();
assert!((d[0] - 3.0).abs() < 1e-5, "got {}", d[0]);
assert!((d[1] - 4.0).abs() < 1e-5, "got {}", d[1]);
}
}