pub mod backend_selection;
pub mod compute;
pub mod elementwise;
pub mod gemv;
pub mod jidoka;
pub mod microkernels;
pub mod norms;
pub mod packing;
pub mod parallel;
pub mod prepacked;
pub mod profiler;
pub mod reference;
pub mod softmax;
pub mod transpose;
pub use jidoka::{JidokaError, JidokaGuard};
pub use profiler::{BlisLevelStats, BlisProfileLevel, BlisProfiler, KaizenMetrics};
#[cfg(target_arch = "aarch64")]
pub use microkernels::microkernel_8x8_neon;
pub use microkernels::microkernel_scalar;
#[cfg(target_arch = "x86_64")]
pub use microkernels::{microkernel_8x6_avx2, microkernel_8x6_avx2_asm, microkernel_8x6_true_asm};
pub use backend_selection::{
gemm_auto, BackendCostModel, BrickLevel, ComputeBackend, PtxMicrokernelSpec, RooflineResult,
UnifiedBrickProfiler, WgslMicrokernelSpec,
};
pub use reference::{gemm_reference, gemm_reference_with_jidoka};
pub use packing::{pack_a, pack_b, packed_a_size, packed_b_size};
pub use compute::{gemm_blis, gemm_blis_with_prepacked_b};
pub use parallel::{gemm_blis_parallel, gemm_blis_parallel_with_prepacked_b, HeijunkaScheduler};
pub use prepacked::PrepackedB;
pub use transpose::transpose;
use crate::error::TruenoError;
// Blocking constants shared by this module's GEMM code.
// NOTE(review): the names follow the usual BLIS convention (MR/NR = micro-tile
// rows/columns, MC/KC/NC = cache-block sizes along m/k/n). How each one is
// actually consumed lives in sibling modules not visible here — confirm there
// before tuning any of these values.

/// Value 8; agrees with the leading `8` in the re-exported `microkernel_8x6_*`
/// and `microkernel_8x8_neon` names.
pub const MR: usize = 8;
/// Value 6; agrees with the trailing `6` in the re-exported `microkernel_8x6_*`
/// names.
// NOTE(review): the aarch64 kernel is named 8x8, so NR = 6 presumably applies
// only to the x86_64/scalar path — verify.
pub const NR: usize = 6;
/// Value 256.
// NOTE(review): presumably the block size along the k dimension — confirm.
pub const KC: usize = 256;
/// Value 72, which is a whole multiple of `MR` (9 * 8).
// NOTE(review): presumably the block size along the m dimension — confirm.
pub const MC: usize = 72;
/// Value 4096.
// NOTE(review): presumably the block size along the n dimension — confirm.
pub const NC: usize = 4096;
/// Default GEMM entry point: picks the backend at compile time and wraps the
/// call in the crate's matmul contract checks.
///
/// The sequence is:
/// 1. `contract_pre_matmul!` is invoked on `a`.
/// 2. With the `parallel` cargo feature enabled the work is handed to
///    [`gemm_blis_parallel`]; otherwise it goes to [`gemm_blis`] with no
///    profiler attached.
/// 3. `contract_post_matmul!` is invoked on `c`, but only if the backend
///    returned `Ok`.
///
/// # Parameters
///
/// `m`, `n`, `k`, `a`, `b` and `c` are forwarded unchanged to the selected
/// backend.
// NOTE(review): presumably `a` is m×k, `b` is k×n and `c` is m×n — the layout
// and length requirements are enforced (or not) by the backends; confirm there.
///
/// # Errors
///
/// Returns whatever error the selected backend returns; the post-condition
/// contract is skipped in that case.
pub fn gemm(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    contract_pre_matmul!(a);

    // Exactly one of these two bindings survives conditional compilation.
    #[cfg(feature = "parallel")]
    let outcome = gemm_blis_parallel(m, n, k, a, b, c);
    #[cfg(not(feature = "parallel"))]
    let outcome = gemm_blis(m, n, k, a, b, c, None);

    // A failed multiply leaves nothing meaningful to validate in `c`.
    if outcome.is_err() {
        return outcome;
    }
    contract_post_matmul!(c);
    outcome
}
/// GEMM entry point that runs [`gemm_blis`] with a caller-supplied profiler.
///
/// All arguments are forwarded unchanged; `profiler` is passed as
/// `Some(profiler)` in the final argument of [`gemm_blis`].
///
/// Unlike [`gemm`], this function always calls [`gemm_blis`] regardless of the
/// `parallel` feature, and it does not invoke the `contract_pre_matmul!` /
/// `contract_post_matmul!` macros.
///
/// # Errors
///
/// Returns whatever error [`gemm_blis`] returns.
pub fn gemm_profiled(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    profiler: &mut BlisProfiler,
) -> Result<(), TruenoError> {
    let attached_profiler = Some(profiler);
    gemm_blis(m, n, k, a, b, c, attached_profiler)
}
#[cfg(test)]
mod tests;