use crate::l3::gemm_impl::nt_blocked::{dgemm_nt_blocked, sgemm_nt_blocked};
use crate::l3::gemm_impl::nt_direct::{dgemm_nt, sgemm_nt};
use crate::l3::gemm_impl::tn_blocked::{dgemm_tn_blocked, sgemm_tn_blocked};
use crate::l3::gemm_impl::tt_blocked::{dgemm_tt_blocked, sgemm_tt_blocked};
use crate::traits::GemmDispatch;
use crate::types::{MatMut, MatRef, Transpose};
use crate::l3::gemm_impl::{
nn_direct::{sgemm_nn, dgemm_nn},
nn_blocked::{sgemm_nn_blocked, dgemm_nn_blocked},
};
// Row-count thresholds that the `sgemm`/`dgemm` dispatchers in this file compare
// against `c.n_rows()` when choosing between a direct kernel and a blocked kernel
// (the blocked kernel is chosen when the row count is strictly greater).
// The `S`/`D` prefix matches the f32/f64 entry point that reads the constant.
// NOTE(review): "NN"/"NT" presumably name the (A, B) transpose combination the
// threshold is meant for (NN = neither transposed, NT = only B transposed) — confirm.
const SGEMM_NN_BLOCKED_THRESHOLD: usize = 96;
const DGEMM_NN_BLOCKED_THRESHOLD: usize = 48;
const SGEMM_NT_BLOCKED_THRESHOLD: usize = 256;
const DGEMM_NT_BLOCKED_THRESHOLD: usize = 128;
/// Single-precision (`f32`) GEMM entry point.
///
/// Selects a kernel from the `(atrans, btrans)` pair and forwards `alpha`, `beta`,
/// `a`, `b` and `c` to it unchanged. Kernel selection:
///
/// - `(NoTranspose, NoTranspose)`: `sgemm_nn_blocked` when `c.n_rows()` is strictly
///   greater than `SGEMM_NN_BLOCKED_THRESHOLD`, otherwise `sgemm_nn`.
/// - `(NoTranspose, Transpose)`: `sgemm_nt_blocked` when `c.n_rows()` is strictly
///   greater than `SGEMM_NT_BLOCKED_THRESHOLD`, otherwise `sgemm_nt`.
/// - `(Transpose, NoTranspose)`: always `sgemm_tn_blocked`.
/// - `(Transpose, Transpose)`: always `sgemm_tt_blocked`.
///
/// NOTE(review): presumably the kernels perform the usual BLAS update
/// `C <- alpha * op(A) * op(B) + beta * C`; they are defined in other modules, so
/// confirm the exact contract (shape requirements, panics) there.
#[inline(always)]
pub fn sgemm(
    atrans: Transpose,
    btrans: Transpose,
    alpha: f32,
    beta: f32,
    a: MatRef<'_, f32>,
    b: MatRef<'_, f32>,
    c: MatMut<'_, f32>,
) {
    match (atrans, btrans) {
        (Transpose::NoTranspose, Transpose::NoTranspose) => {
            // Each layout is gated by its own threshold: the NN path must consult
            // the NN constant, not the NT one.
            if c.n_rows() > SGEMM_NN_BLOCKED_THRESHOLD {
                sgemm_nn_blocked(alpha, beta, a, b, c);
            } else {
                sgemm_nn(alpha, beta, a, b, c);
            }
        }
        (Transpose::NoTranspose, Transpose::Transpose) => {
            if c.n_rows() > SGEMM_NT_BLOCKED_THRESHOLD {
                sgemm_nt_blocked(alpha, beta, a, b, c);
            } else {
                sgemm_nt(alpha, beta, a, b, c);
            }
        }
        // Only blocked kernels are imported for the transposed-A layouts, so these
        // arms do not branch on size.
        (Transpose::Transpose, Transpose::NoTranspose) => {
            sgemm_tn_blocked(alpha, beta, a, b, c);
        }
        (Transpose::Transpose, Transpose::Transpose) => {
            sgemm_tt_blocked(alpha, beta, a, b, c);
        }
    }
}
/// Double-precision (`f64`) GEMM entry point.
///
/// Selects a kernel from the `(atrans, btrans)` pair and forwards `alpha`, `beta`,
/// `a`, `b` and `c` to it unchanged. Kernel selection:
///
/// - `(NoTranspose, NoTranspose)`: `dgemm_nn_blocked` when `c.n_rows()` is strictly
///   greater than `DGEMM_NN_BLOCKED_THRESHOLD`, otherwise `dgemm_nn`.
/// - `(NoTranspose, Transpose)`: `dgemm_nt_blocked` when `c.n_rows()` is strictly
///   greater than `DGEMM_NT_BLOCKED_THRESHOLD`, otherwise `dgemm_nt`.
/// - `(Transpose, NoTranspose)`: always `dgemm_tn_blocked`.
/// - `(Transpose, Transpose)`: always `dgemm_tt_blocked`.
///
/// NOTE(review): presumably the kernels perform the usual BLAS update
/// `C <- alpha * op(A) * op(B) + beta * C`; they are defined in other modules, so
/// confirm the exact contract (shape requirements, panics) there.
#[inline(always)]
pub fn dgemm(
    atrans: Transpose,
    btrans: Transpose,
    alpha: f64,
    beta: f64,
    a: MatRef<'_, f64>,
    b: MatRef<'_, f64>,
    c: MatMut<'_, f64>,
) {
    match (atrans, btrans) {
        (Transpose::NoTranspose, Transpose::NoTranspose) => {
            // Each layout is gated by its own threshold: the NN path must consult
            // the NN constant, not the NT one.
            if c.n_rows() > DGEMM_NN_BLOCKED_THRESHOLD {
                dgemm_nn_blocked(alpha, beta, a, b, c);
            } else {
                dgemm_nn(alpha, beta, a, b, c);
            }
        }
        (Transpose::NoTranspose, Transpose::Transpose) => {
            if c.n_rows() > DGEMM_NT_BLOCKED_THRESHOLD {
                dgemm_nt_blocked(alpha, beta, a, b, c);
            } else {
                dgemm_nt(alpha, beta, a, b, c);
            }
        }
        // Only blocked kernels are imported for the transposed-A layouts, so these
        // arms do not branch on size.
        (Transpose::Transpose, Transpose::NoTranspose) => {
            dgemm_tn_blocked(alpha, beta, a, b, c);
        }
        (Transpose::Transpose, Transpose::Transpose) => {
            dgemm_tt_blocked(alpha, beta, a, b, c);
        }
    }
}
/// Type-generic GEMM front end.
///
/// Hands every argument, unchanged and in the same order, to the `gemm`
/// associated function that the element type `T` provides through its
/// `GemmDispatch` bound. All kernel selection happens in that implementation.
pub fn gemm<T>(
    atrans: Transpose,
    btrans: Transpose,
    alpha: T,
    beta: T,
    a: MatRef<'_, T>,
    b: MatRef<'_, T>,
    c: MatMut<'_, T>,
) where
    T: GemmDispatch,
{
    <T as GemmDispatch>::gemm(atrans, btrans, alpha, beta, a, b, c);
}