/// Computes `v_block = coef · state_blockᵀ` with a single dense GEMM call.
///
/// All three buffers are contiguous and row-major:
///
/// * `coef` is a `k_rows × d` matrix (`coef[k * d + j]`).
/// * `state_block` is an `m_len × d` matrix (`state_block[m * d + j]`). It is
///   read with swapped strides, so it acts as its `d × m_len` transpose without
///   being copied.
/// * `v_block` is the `k_rows × m_len` output (`v_block[k * m_len + m]`). Its
///   previous contents are overwritten because the call uses `beta = 0`.
///
/// If any of `k_rows`, `d` or `m_len` is zero, the function returns immediately.
/// It does not validate slice lengths in that case and leaves `v_block`
/// untouched. For `d == 0` this means `v_block` is *not* zeroed, even though the
/// mathematical product would be the zero matrix.
///
/// # Panics
///
/// Panics if any slice length differs from the size implied by the dimensions:
/// `coef.len() != k_rows * d`, `state_block.len() != m_len * d`, or
/// `v_block.len() != k_rows * m_len`. It also panics if one of those products
/// overflows `usize`. These checks run in release builds as well, because the
/// unsafe `dgemm` call trusts the dimensions.
#[inline]
pub(crate) fn gemm_block(
    coef: &[f64],
    state_block: &[f64],
    k_rows: usize,
    d: usize,
    m_len: usize,
    v_block: &mut [f64],
) {
    if k_rows == 0 || d == 0 || m_len == 0 {
        return;
    }

    // These must be real asserts, not `debug_assert!`. They are the only
    // protection against a caller's length bug turning into out-of-bounds
    // raw-pointer access inside `dgemm`. `checked_mul` rejects overflowing
    // products instead of letting a wrapped value match a length by accident.
    assert!(
        k_rows.checked_mul(d) == Some(coef.len()),
        "gemm_block: coef slice must be exactly k_rows*d"
    );
    assert!(
        m_len.checked_mul(d) == Some(state_block.len()),
        "gemm_block: state_block must be exactly m_len*d"
    );
    assert!(
        k_rows.checked_mul(m_len) == Some(v_block.len()),
        "gemm_block: v_block must be exactly k_rows*m_len"
    );

    // All dimensions are now >= 1. So `d <= coef.len()` and `m_len <= v_block.len()`,
    // and slice lengths never exceed `isize::MAX`. Neither conversion can fail.
    let d_stride = isize::try_from(d).expect("gemm_block: d is bounded by coef.len()");
    let m_len_stride =
        isize::try_from(m_len).expect("gemm_block: m_len is bounded by v_block.len()");

    // SAFETY:
    // * `coef` has exactly `k_rows * d` elements. It is read as a `k_rows × d`
    //   matrix with row stride `d` and column stride 1, so the largest offset
    //   touched is `(k_rows - 1) * d + (d - 1) < k_rows * d`.
    // * `state_block` has exactly `m_len * d` elements. It is read as a
    //   `d × m_len` matrix with row stride 1 and column stride `d`, so the
    //   largest offset is `(d - 1) + (m_len - 1) * d < m_len * d`.
    // * `v_block` has exactly `k_rows * m_len` elements. It is written as a
    //   `k_rows × m_len` matrix with row stride `m_len` and column stride 1,
    //   so every offset is `< k_rows * m_len`.
    // * `v_block` is an exclusive `&mut` borrow, so it cannot alias `coef` or
    //   `state_block`.
    unsafe {
        matrixmultiply::dgemm(
            k_rows,
            d,
            m_len,
            1.0,
            coef.as_ptr(),
            d_stride,
            1,
            state_block.as_ptr(),
            1,
            d_stride,
            0.0,
            v_block.as_mut_ptr(),
            m_len_stride,
            1,
        );
    }
}
#[cfg(test)]
// Test inputs are generated from small integer indices, so `usize as f64` is exact here.
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::gemm_block;

    /// Reference implementation of `v = coef · stateᵀ` for row-major `coef`
    /// (`k_rows × d`) and `state` (`m_len × d`). The result is row-major
    /// `k_rows × m_len`. Requires `d > 0`.
    fn naive_reference(coef: &[f64], state: &[f64], d: usize) -> Vec<f64> {
        coef.chunks_exact(d)
            .flat_map(|coef_row| {
                state.chunks_exact(d).map(move |state_row| {
                    coef_row
                        .iter()
                        .zip(state_row)
                        .map(|(a, b)| a * b)
                        .sum::<f64>()
                })
            })
            .collect()
    }

    /// Asserts element-wise agreement within a tolerance.
    /// The tolerance allows for a different summation order or FMA use inside the GEMM kernel.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "output length mismatch");
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "gemm_block[{i}] = {got} but expected {want}",
            );
        }
    }

    /// Builds deterministic `coef` and `state` inputs for the given shape.
    fn inputs(k_rows: usize, d: usize, m_len: usize) -> (Vec<f64>, Vec<f64>) {
        let coef = (0..k_rows * d).map(|i| (i as f64) * 0.1).collect();
        let state = (0..m_len * d).map(|i| (i as f64) * 0.01 - 0.5).collect();
        (coef, state)
    }

    #[test]
    fn gemm_block_matches_naive_reference_small() {
        const K: usize = 5;
        const D: usize = 3;
        const M_LEN: usize = 4;
        let (coef, state) = inputs(K, D, M_LEN);
        let mut v = [0.0_f64; K * M_LEN];
        gemm_block(&coef, &state, K, D, M_LEN, &mut v);
        assert_close(&v, &naive_reference(&coef, &state, D));
    }

    #[test]
    fn gemm_block_overwrites_output_for_various_shapes() {
        // Includes degenerate single-row, single-column and `d == 1` shapes,
        // where stride mistakes tend to hide.
        for &(k_rows, d, m_len) in &[(1, 1, 1), (1, 7, 2), (6, 1, 3), (4, 9, 1), (8, 5, 7)] {
            let (coef, state) = inputs(k_rows, d, m_len);
            // Stale contents must be replaced, not accumulated into.
            let mut v = vec![1.0e6_f64; k_rows * m_len];
            gemm_block(&coef, &state, k_rows, d, m_len, &mut v);
            assert_close(&v, &naive_reference(&coef, &state, d));
        }
    }

    #[test]
    fn gemm_block_zero_dimensions_no_op() {
        // k_rows == 0: there is nothing to compute. Input lengths are not
        // validated on this early-return path.
        let mut empty_rows: Vec<f64> = Vec::new();
        gemm_block(&[], &[], 0, 5, 3, &mut empty_rows);
        assert!(empty_rows.is_empty());

        // d == 0: the function returns early and leaves the output untouched.
        // The output is not zeroed.
        let mut untouched = vec![1.0_f64; 15];
        gemm_block(&[], &[], 5, 0, 3, &mut untouched);
        assert!(untouched.iter().all(|&x| x == 1.0));

        // m_len == 0: there is nothing to compute.
        let mut empty_cols: Vec<f64> = Vec::new();
        gemm_block(&[], &[], 5, 3, 0, &mut empty_cols);
        assert!(empty_cols.is_empty());
    }

    // Only run where the length precondition is certain to be checked. Without
    // the check, this mismatched call would hand undersized buffers to `dgemm`.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "gemm_block: v_block must be exactly k_rows*m_len")]
    fn gemm_block_rejects_mismatched_output_length() {
        let coef = [0.0_f64; 2 * 3];
        let state = [0.0_f64; 4 * 3];
        let mut v = [0.0_f64; 2 * 4 - 1];
        gemm_block(&coef, &state, 2, 3, 4, &mut v);
    }
}