use super::traits::NumericOps;
/// Reinterprets the heap allocation of a `Vec<T>` as a `Vec<U>` without copying.
///
/// The pointer, length and capacity of `v` are carried over unchanged; only
/// the element type is relabelled.
///
/// # Safety
///
/// The caller must guarantee that every initialized element of `v`, viewed as
/// raw bytes, is a valid value of type `U`, and that dropping the elements as
/// `U` instead of `T` is acceptable (the destructors of `T` are never run).
///
/// # Panics
///
/// Panics if `T` and `U` differ in size or alignment. Both checks compare
/// compile-time constants, so they cost nothing once monomorphized, and they
/// stay active in release builds because a layout mismatch would make the
/// returned vector free its buffer with the wrong layout.
pub(super) unsafe fn reinterpret_vec<T, U>(v: Vec<T>) -> Vec<U> {
    assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<U>());
    // Ownership of the buffer moves to the returned vector, so the original
    // must not run its destructor.
    let mut v = std::mem::ManuallyDrop::new(v);
    // SAFETY: the pointer, length and capacity all come from a live `Vec<T>`
    // that will not be dropped, and `T`/`U` were just checked to share size
    // and alignment, so the allocation layout is identical for `Vec<U>`.
    // Validity of the element bytes as `U` is the caller's obligation.
    unsafe { Vec::from_raw_parts(v.as_mut_ptr().cast::<U>(), v.len(), v.capacity()) }
}
/// Computes the matrix product `C = A * B` on the CPU.
///
/// All matrices are flat, row-major slices: `a` is read as `m x k`
/// (`a[i * k + p]`), `b` as `k x n` (`b[p * n + j]`), and the returned vector
/// holds `m x n` elements (`c[i * n + j]`).
///
/// Each output element starts at `T::zero()` and accumulates its products in
/// increasing `p` order.
///
/// # Panics
///
/// Panics if `a` or `b` is too short for the rows that are actually visited.
pub(super) fn gemm_cpu<T: NumericOps>(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
    let mut c = vec![T::zero(); m * n];
    for i in 0..m {
        // Slicing one row at a time replaces the per-element bounds checks of
        // plain indexing with a single check per row.
        let a_row = &a[i * k..(i + 1) * k];
        let c_row = &mut c[i * n..(i + 1) * n];
        for (p, &a_val) in a_row.iter().enumerate() {
            let b_row = &b[p * n..(p + 1) * n];
            // Rank-1 update of the output row: c[i, :] += a[i, p] * b[p, :].
            for (c_ij, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_ij = (*c_ij).add(a_val.mul(b_val));
            }
        }
    }
    c
}
/// Computes `C = A * B^T` on the CPU.
///
/// All matrices are flat, row-major slices: `a` is read as `m x k`
/// (`a[i * k + p]`), `b` as `n x k` (`b[j * k + p]`), and the returned vector
/// holds `m x n` elements (`c[i * n + j]`).
///
/// Each output element is the dot product of row `i` of `a` with row `j` of
/// `b`, accumulated from `T::zero()` in increasing `p` order.
///
/// # Panics
///
/// Panics if `a` or `b` is too short for the rows that are actually visited.
pub(super) fn gemm_transpose_b_cpu<T: NumericOps>(
    a: &[T],
    b: &[T],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    let mut c = vec![T::zero(); m * n];
    for i in 0..m {
        let c_row = &mut c[i * n..(i + 1) * n];
        for (j, c_ij) in c_row.iter_mut().enumerate() {
            // Both rows are sliced here, inside the column loop, so that
            // `n == 0` never touches `a` at all.
            let a_row = &a[i * k..(i + 1) * k];
            let b_row = &b[j * k..(j + 1) * k];
            // Zipping the two rows avoids a bounds check per multiplied pair.
            *c_ij = a_row
                .iter()
                .zip(b_row)
                .fold(T::zero(), |sum, (&x, &y)| sum.add(x.mul(y)));
        }
    }
    c
}
/// Computes `C = A^T * B` on the CPU.
///
/// All matrices are flat, row-major slices: `a` is read as `k x m`
/// (`a[p * m + i]`), `b` as `k x n` (`b[p * n + j]`), and the returned vector
/// holds `m x n` elements (`c[i * n + j]`).
///
/// Each output element starts at `T::zero()` and accumulates its products in
/// increasing `p` order.
///
/// # Panics
///
/// Panics if `a` or `b` is too short for the elements that are actually
/// visited.
pub(super) fn gemm_transpose_a_cpu<T: NumericOps>(
    a: &[T],
    b: &[T],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<T> {
    let mut c = vec![T::zero(); m * n];
    for i in 0..m {
        let c_row = &mut c[i * n..(i + 1) * n];
        for p in 0..k {
            // `a` is walked down column `i` (stride `m`), so it stays indexed.
            let a_val = a[p * m + i];
            // Slicing the row of `b` once lets the inner loop run without a
            // bounds check per element.
            let b_row = &b[p * n..(p + 1) * n];
            for (c_ij, &b_val) in c_row.iter_mut().zip(b_row) {
                *c_ij = (*c_ij).add(a_val.mul(b_val));
            }
        }
    }
    c
}