use super::super::dot::i8xi8_dot_i32;
/// Multiplies two row-major `i8` matrices with `i32` accumulation: `out = a * b`.
///
/// - `a` is `m x k`; row `i` starts at `a.add(i * lda)`.
/// - `b` is `k x n`; element `(p, j)` lives at `b.add(p * ldb + j)`.
/// - `out` is `m x n`; element `(i, j)` is written to `out.add(i * ldc + j)`.
///
/// Every `out[i * ldc + j]` with `i < m`, `j < n` is overwritten (not accumulated
/// into). Entries at column offsets `n..ldc` of each output row are never touched.
///
/// Because `b` is row-major, its columns are strided in memory. Each column is
/// copied once into a contiguous scratch buffer and reused for all `m` rows of `a`,
/// so every inner product is handed to [`i8xi8_dot_i32`] as two contiguous runs
/// of `k` elements.
///
/// If `m == 0` or `n == 0` the function returns immediately without reading
/// `a` or `b`.
///
/// # Safety
///
/// The caller must ensure that:
/// - `a` is valid for reads at element offsets `i * lda + p` for all `i < m`, `p < k`;
/// - `b` is valid for reads at element offsets `p * ldb + j` for all `p < k`, `j < n`;
/// - `out` is valid for writes at element offsets `i * ldc + j` for all `i < m`, `j < n`;
/// - none of those offset computations overflow, and each stays within the
///   allocation its pointer belongs to (the requirements of `<*const T>::add`).
///
/// NOTE(review): this assumes `i8xi8_dot_i32(x, y, len)` reads exactly `len`
/// elements from each pointer and has no further preconditions — confirm against
/// the `dot` module.
#[allow(clippy::too_many_arguments)] // BLAS-style signature: 3 pointers, 3 dims, 3 strides.
pub unsafe fn matmul_i8_to_i32(
    a: *const i8,
    b: *const i8,
    out: *mut i32,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
) {
    // Empty output: skip the scratch allocation and the reads of `b` entirely.
    if m == 0 || n == 0 {
        return;
    }

    // Contiguous copy of the current column of `b`, reused across every row of `a`.
    let mut b_col = vec![0i8; k];
    for j in 0..n {
        for (p, slot) in b_col.iter_mut().enumerate() {
            // SAFETY: `p < k` and `j < n`, so offset `p * ldb + j` is readable per
            // the function's safety contract.
            *slot = unsafe { *b.add(p * ldb + j) };
        }
        for i in 0..m {
            // SAFETY: `i < m`, so row `i` of `a` has `k` readable elements; `b_col`
            // holds exactly `k` elements; and offset `i * ldc + j` (with `j < n`) is
            // writable per the function's safety contract.
            unsafe {
                let a_row = a.add(i * lda);
                *out.add(i * ldc + j) = i8xi8_dot_i32(a_row, b_col.as_ptr(), k);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive reference product: returns the dense `m x n` (row stride `n`) result
    /// of `a * b`, where `a` has row stride `lda` and `b` has row stride `ldb`.
    fn reference_matmul(
        a: &[i8],
        b: &[i8],
        m: usize,
        n: usize,
        k: usize,
        lda: usize,
        ldb: usize,
    ) -> Vec<i32> {
        let mut c = vec![0i32; m * n];
        for i in 0..m {
            for j in 0..n {
                c[i * n + j] = (0..k)
                    .map(|p| i32::from(a[i * lda + p]) * i32::from(b[p * ldb + j]))
                    .sum();
            }
        }
        c
    }

    #[test]
    fn test_matmul_i8_to_i32_basic() {
        let a: Vec<i8> = vec![1, 2, 3, 4];
        let b: Vec<i8> = vec![5, 6, 7, 8];
        let mut c = [0i32; 4];
        unsafe {
            matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 2, 2, 2, 2, 2);
        }
        assert_eq!(c, [19, 22, 43, 50]);
    }

    #[test]
    fn test_matmul_i8_to_i32_negative() {
        let a: Vec<i8> = vec![-1, 2, 3, -4];
        let b: Vec<i8> = vec![5, -6, -7, 8];
        let mut c = [0i32; 4];
        unsafe {
            matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 2, 2, 2, 2, 2);
        }
        assert_eq!(c, [-19, 22, 43, -50]);
    }

    #[test]
    fn test_matmul_i8_to_i32_wide() {
        let (m, n, k) = (2, 3, 64);
        let a: Vec<i8> = (0..m * k)
            .map(|i| ((i % 127) as i8).wrapping_sub(64))
            .collect();
        let b: Vec<i8> = (0..k * n)
            .map(|i| ((i * 3 % 127) as i8).wrapping_sub(64))
            .collect();
        let mut c = vec![0i32; m * n];
        unsafe {
            matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n);
        }
        let expected = reference_matmul(&a, &b, m, n, k, k, n);
        assert_eq!(c, expected);
    }

    #[test]
    fn test_matmul_i8_to_i32_strided() {
        // Every stride is wider than its logical row, so mixing up a stride with a
        // width either reads the wrong elements or writes into the padding.
        let (m, n, k) = (3, 2, 5);
        let (lda, ldb, ldc) = (k + 3, n + 2, n + 1);
        let a: Vec<i8> = (0..m * lda)
            .map(|i| (i as i8).wrapping_mul(7).wrapping_sub(20))
            .collect();
        let b: Vec<i8> = (0..k * ldb)
            .map(|i| (i as i8).wrapping_mul(-5).wrapping_add(11))
            .collect();
        const SENTINEL: i32 = i32::MIN;
        let mut c = vec![SENTINEL; m * ldc];
        unsafe {
            matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, lda, ldb, ldc);
        }
        let expected = reference_matmul(&a, &b, m, n, k, lda, ldb);
        for i in 0..m {
            assert_eq!(
                &c[i * ldc..i * ldc + n],
                &expected[i * n..(i + 1) * n],
                "row {i}"
            );
            assert!(
                c[i * ldc + n..(i + 1) * ldc].iter().all(|&x| x == SENTINEL),
                "padding of output row {i} was written"
            );
        }
    }
}