lak-kernels 0.1.1

BLAS-like linear algebra kernels in fully-safe Rust.
Documentation
// nn_direct.rs 

use crate::l1::scal;
use crate::types::{MatRef, MatMut}; 
use crate::l3::gemm_impl::nn_direct_microkernel::{
    sgemm_nn_micro, 
    dgemm_nn_micro, 
}; 

/// Maximum number of columns of B and C processed per column panel in `sgemm_nn`.
pub(crate) const NC_F32: usize = 128; 
/// Maximum number of columns of B and C processed per column panel in `dgemm_nn`.
pub(crate) const NC_F64: usize = 64; 
/// Maximum number of columns of A (i.e. extent along the shared inner
/// dimension) handed to the microkernel per call in `sgemm_nn`.
pub(crate) const KC_F32: usize = 128; 
/// Maximum number of columns of A (i.e. extent along the shared inner
/// dimension) handed to the microkernel per call in `dgemm_nn`.
pub(crate) const KC_F64: usize = 64; 

/// Single-precision blocked driver for the non-transposed ("NN") matrix product.
///
/// Intended to update `c` as `alpha * A * B + beta * C`: every column of `c`
/// is first scaled by `beta` through [`scal`], then `sgemm_nn_micro` is called
/// once per (column panel, inner-dimension block) pair.
/// NOTE(review): the accumulate-into-`c` behaviour of `sgemm_nn_micro` is
/// assumed from how it is used here — confirm against the microkernel.
///
/// # Arguments
///
/// * `alpha` - scalar forwarded to the microkernel.
/// * `beta`  - scalar applied to every column of `c` before accumulation.
/// * `a`     - left operand; the second component of `a.dimension()` is the
///   inner dimension.
/// * `b`     - right operand; `b.dimension()` is read as `(inner, n)`.
/// * `c`     - output matrix, modified in place.
///
/// # Panics
///
/// Panics if the inner dimensions of `a` and `b` differ.
#[inline(always)]
pub(crate) fn sgemm_nn(
    alpha: f32,
    beta: f32,
    a: MatRef<'_, f32>,
    b: MatRef<'_, f32>,
    mut c: MatMut<'_, f32>,
) {
    let (_, k_a) = a.dimension();
    let (k_b, n) = b.dimension();
    assert_eq!(k_a, k_b, "A and B inner dimension must match");

    let k = k_a;

    let n_nc = n.div_ceil(NC_F32);
    let n_kc = k.div_ceil(KC_F32);

    for nc_idx in 0..n_nc {
        // Column range [j0, j1) of the current panel of B and C.
        let j0 = nc_idx * NC_F32;
        let j1 = (j0 + NC_F32).min(n);

        let mut c_panel = c.col_panel_mut(j0..j1);
        let b_panel = b.col_panel(j0..j1);

        // Always apply beta, including beta == 0.0. Skipping the scale in
        // that case would leave the previous contents of C in place and the
        // microkernel calls below would build on top of them. This mirrors
        // `dgemm_nn`, which scales unconditionally.
        for j in 0..c_panel.n_cols() {
            scal(beta, c_panel.col_mut(j));
        }

        for kc_idx in 0..n_kc {
            // Inner-dimension range [kc_beg, kc_end) handled by this call.
            let kc_beg = kc_idx * KC_F32;
            let kc_end = (kc_beg + KC_F32).min(k);

            let a_panel = a.col_panel(kc_beg..kc_end);

            // `b_panel` still spans the full inner dimension; `kc_beg` is
            // passed so the microkernel can locate the matching rows of B.
            // NOTE(review): presumed use of the offset — confirm in the
            // microkernel.
            sgemm_nn_micro(
                alpha,
                a_panel,
                b_panel,
                c_panel.reborrow(),
                kc_beg,
            );
        }
    }
}

/// Double-precision blocked driver for the non-transposed ("NN") matrix product.
///
/// For each panel of at most `NC_F64` columns of `b` and `c`, every column of
/// the `c` panel is scaled by `beta` via [`scal`], after which
/// `dgemm_nn_micro` is invoked once per block of at most `KC_F64` columns
/// of `a`.
/// NOTE(review): the overall effect is presumably
/// `C <- alpha * A * B + beta * C`; this relies on the microkernel
/// accumulating into `c` — confirm against its implementation.
///
/// # Panics
///
/// Panics if the inner dimensions of `a` and `b` differ.
#[inline(always)]
pub(crate) fn dgemm_nn(
    alpha: f64,
    beta: f64,
    a: MatRef<'_, f64>,
    b: MatRef<'_, f64>,
    mut c: MatMut<'_, f64>,
) {
    let (_, inner) = a.dimension();
    let (b_rows, total_cols) = b.dimension();
    assert_eq!(inner, b_rows, "A and B inner dimension must match");

    // Walk B and C in column panels of at most NC_F64 columns.
    for col_beg in (0..total_cols).step_by(NC_F64) {
        let col_end = total_cols.min(col_beg + NC_F64);

        let mut c_block = c.col_panel_mut(col_beg..col_end);
        let b_block = b.col_panel(col_beg..col_end);

        // Apply beta to the whole C panel before any microkernel call.
        let width = c_block.n_cols();
        for col in 0..width {
            scal(beta, c_block.col_mut(col));
        }

        // Walk the inner dimension in blocks of at most KC_F64.
        for depth_beg in (0..inner).step_by(KC_F64) {
            let depth_end = inner.min(depth_beg + KC_F64);

            // The B panel is passed whole; the block's starting offset goes
            // along as the final argument.
            dgemm_nn_micro(
                alpha,
                a.col_panel(depth_beg..depth_end),
                b_block,
                c_block.reborrow(),
                depth_beg,
            );
        }
    }
}