use crate::dtype::Element;
/// Widens a strided `rows × cols` matrix of `T` into a densely packed `f32` buffer.
///
/// Source row `r` starts `r * ld` elements past `src`. The destination has no row
/// padding, so element `(r, c)` lands at `dst.add(r * cols + c)`. Each value is
/// converted with `T::to_f32`.
///
/// # Safety
///
/// For every `r < rows` and `c < cols`:
/// - `src.add(r * ld + c)` must be valid for reads of an initialized `T`;
/// - `dst.add(r * cols + c)` must be valid for writes of an `f32`.
#[inline]
unsafe fn convert_to_f32<T: Element>(
    src: *const T,
    dst: *mut f32,
    rows: usize,
    cols: usize,
    ld: usize,
) {
    // Running offset into the packed destination; always equals `r * cols + c`.
    let mut packed = 0usize;
    for r in 0..rows {
        let strided_row = src.add(r * ld);
        for c in 0..cols {
            *dst.add(packed) = (*strided_row.add(c)).to_f32();
            packed += 1;
        }
    }
}
/// Narrows a densely packed `rows × cols` `f32` buffer into a strided matrix of `T`.
///
/// Element `(r, c)` is read from `src.add(r * cols + c)` and converted with
/// `T::from_f32`. It is then stored in row `r` of `dst`, which starts `r * ld`
/// elements past `dst`. Slots between column `cols` and `ld` in each destination row
/// are not touched.
///
/// # Safety
///
/// For every `r < rows` and `c < cols`:
/// - `src.add(r * cols + c)` must be valid for reads of an `f32`;
/// - `dst.add(r * ld + c)` must be valid for writes of a `T`.
///
/// The store is a plain assignment. If `T` has drop glue, each destination slot must
/// therefore already hold an initialized `T`.
#[inline]
unsafe fn convert_from_f32<T: Element>(
    src: *const f32,
    dst: *mut T,
    rows: usize,
    cols: usize,
    ld: usize,
) {
    // Running offset into the packed source; always equals `r * cols + c`.
    let mut packed = 0usize;
    for r in 0..rows {
        let strided_row = dst.add(r * ld);
        for c in 0..cols {
            let value = T::from_f32(*src.add(packed));
            *strided_row.add(c) = value;
            packed += 1;
        }
    }
}
/// Computes `out = a · b` for a generic element type by round-tripping through `f32`.
///
/// Inputs and output use the following layout:
/// - `a` is read as `m` rows of `k` elements; row `i` starts at `a.add(i * lda)`.
/// - `b` is read as `k` rows of `n` elements with row stride `ldb`.
/// - The `m × n` result is written into `out`, row `i` starting at `out.add(i * ldc)`.
///
/// Both inputs are first widened with `T::to_f32` into densely packed `f32` scratch
/// buffers. Those buffers are multiplied by `super::matmul_f32`, and the packed
/// result is narrowed back with `T::from_f32`. Precision is therefore that of the
/// `f32` kernel followed by one rounding to `T`. Only the first `n` elements of each
/// output row are written; any padding up to `ldc` is left untouched.
///
/// Scratch cost is `m * k + k * n + m * n` `f32` values, allocated per call.
///
/// If `m` or `n` is zero the result is empty. The function then returns immediately
/// without reading or writing through any pointer.
///
/// # Safety
///
/// Unless `m == 0` or `n == 0`, the caller must guarantee:
/// - `a.add(i * lda + p)` is valid for reads of an initialized `T` for every `i < m`,
///   `p < k`;
/// - `b.add(p * ldb + j)` is valid for reads of an initialized `T` for every `p < k`,
///   `j < n`;
/// - `out.add(i * ldc + j)` is valid for writes of a `T` for every `i < m`, `j < n`.
///   Results are stored by plain assignment, so if `T` has drop glue those slots must
///   already hold initialized values.
// The argument list deliberately mirrors the BLAS-like `(m, n, k, lda, ldb, ldc)`
// shape of the `f32` kernel it wraps, so it cannot be shortened without diverging.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_via_f32<T: Element>(
    a: *const T,
    b: *const T,
    out: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
) {
    // An m × n result with m == 0 or n == 0 has no elements to write. Skip the
    // allocations, the input conversions and the kernel call entirely.
    if m == 0 || n == 0 {
        return;
    }

    let mut a_f32 = vec![0.0f32; m * k];
    let mut b_f32 = vec![0.0f32; k * n];
    let mut c_f32 = vec![0.0f32; m * n];

    // The destination buffers were allocated with exactly `m * k` / `k * n` slots. The
    // source regions are covered by the caller's contract above.
    convert_to_f32(a, a_f32.as_mut_ptr(), m, k, lda);
    convert_to_f32(b, b_f32.as_mut_ptr(), k, n, ldb);

    // The scratch buffers are densely packed, so their leading dimensions are just
    // their row widths: `k` for A and `n` for B and C. This assumes the kernel takes
    // them in the same (lda, ldb, ldc) positions as this function.
    super::matmul_f32(
        a_f32.as_ptr(),
        b_f32.as_ptr(),
        c_f32.as_mut_ptr(),
        m,
        n,
        k,
        k,
        n,
        n,
    );

    convert_from_f32(c_f32.as_ptr(), out, m, n, ldc);
}
/// Computes a biased matrix product for a generic element type by round-tripping
/// through `f32`.
///
/// Inputs and output use the following layout:
/// - `a` is read as `m` rows of `k` elements with row stride `lda`.
/// - `b` is read as `k` rows of `n` elements with row stride `ldb`.
/// - `bias` is read as `n` contiguous elements.
/// - The packed `m × n` result is written into `out`, row `i` starting at
///   `out.add(i * ldc)`.
///
/// All three inputs are widened with `T::to_f32` into densely packed `f32` scratch
/// buffers and handed to `super::matmul_bias_f32`. The result is then narrowed back
/// with `T::from_f32`. Only the first `n` elements of each output row are written.
///
/// How `bias` is combined with the product is decided by `super::matmul_bias_f32`
/// (presumably added to every output row; confirm against that kernel).
///
/// Scratch cost is `m * k + k * n + n + m * n` `f32` values, allocated per call.
///
/// If `m` or `n` is zero the result is empty. The function then returns immediately
/// without reading or writing through any pointer.
///
/// # Safety
///
/// Unless `m == 0` or `n == 0`, the caller must guarantee:
/// - `a.add(i * lda + p)` is valid for reads of an initialized `T` for every `i < m`,
///   `p < k`;
/// - `b.add(p * ldb + j)` is valid for reads of an initialized `T` for every `p < k`,
///   `j < n`;
/// - `bias.add(j)` is valid for reads of an initialized `T` for every `j < n`;
/// - `out.add(i * ldc + j)` is valid for writes of a `T` for every `i < m`, `j < n`.
///   Results are stored by plain assignment, so if `T` has drop glue those slots must
///   already hold initialized values.
// The argument list deliberately mirrors the BLAS-like shape of the `f32` kernel it
// wraps (plus `bias`), so it cannot be shortened without diverging.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_bias_via_f32<T: Element>(
    a: *const T,
    b: *const T,
    bias: *const T,
    out: *mut T,
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
) {
    // An m × n result with m == 0 or n == 0 has no elements to write. Skip the
    // allocations, the input conversions and the kernel call entirely.
    if m == 0 || n == 0 {
        return;
    }

    let mut a_f32 = vec![0.0f32; m * k];
    let mut b_f32 = vec![0.0f32; k * n];
    let mut bias_f32 = vec![0.0f32; n];
    let mut c_f32 = vec![0.0f32; m * n];

    convert_to_f32(a, a_f32.as_mut_ptr(), m, k, lda);
    convert_to_f32(b, b_f32.as_mut_ptr(), k, n, ldb);
    // `bias` is a single row, so its stride argument is never used to step between rows.
    convert_to_f32(bias, bias_f32.as_mut_ptr(), 1, n, n);

    // The scratch buffers are densely packed, so their leading dimensions are just
    // their row widths: `k` for A and `n` for B and C. This assumes the kernel takes
    // them in the same (lda, ldb, ldc) positions as this function.
    super::matmul_bias_f32(
        a_f32.as_ptr(),
        b_f32.as_ptr(),
        bias_f32.as_ptr(),
        c_f32.as_mut_ptr(),
        m,
        n,
        k,
        k,
        n,
        n,
    );

    convert_from_f32(c_f32.as_ptr(), out, m, n, ldc);
}