// lak-kernels 0.1.0
//
// BLAS-like linear algebra kernels in fully-safe Rust.
// nn_direct_microkernel.rs 

use std::simd::Simd; 
use crate::fused::faxpy::faxpy;
use crate::l1::axpy;
use crate::traits::Fma; 
use crate::types::{MatMut, MatRef, VecMut, VecRef}; 

// tunable
/// Rows of C accumulated together in one `Simd<f32, MR_F32>` vector by the f32 kernel.
/// Because it is used as a `Simd` lane count, it must be a lane count `std::simd` accepts.
pub(crate) const MR_F32: usize = 16;
/// Rows of C accumulated together in one `Simd<f64, MR_F64>` vector by the f64 kernel.
/// Because it is used as a `Simd` lane count, it must be a lane count `std::simd` accepts.
pub(crate) const MR_F64: usize = 8;

// not tunable
/// Columns of C updated per call of the f32 register-blocked kernel. The kernel body is
/// hand-unrolled for exactly four columns, so this value cannot be changed on its own.
pub(crate) const NR_F32: usize = 4;
/// Columns of C updated per call of the f64 register-blocked kernel. The kernel body is
/// hand-unrolled for exactly four columns, so this value cannot be changed on its own.
pub(crate) const NR_F64: usize = 4;

// Turn the "not tunable" note into a build failure instead of a silent miscompute:
// the kernels split C into exactly four column slices.
const _: () = assert!(NR_F32 == 4 && NR_F64 == 4);

/// Single-precision NN micro-driver: adds `alpha * A * B[kc_beg..kc_beg + kc, ..]`
/// into `c_panel`, where `(m, kc)` is the shape of `a_panel`.
///
/// * `a_panel` — the `m x kc` block of A.
/// * `b_panel` — a `k x nc` block of B. Only rows `kc_beg..kc_beg + kc` are read, and its
///   backing slice is indexed column-major with leading dimension `k`.
/// * `c_panel` — the `m x nc` block of C, updated in place.
///
/// Full groups of `NR_F32` columns go through `skernel_mrnr`. Any columns left at the
/// right edge (fewer than `NR_F32`) are handed to `faxpy` one column at a time.
#[inline(always)]
pub(crate) fn sgemm_nn_micro(
    alpha: f32,
    a_panel: MatRef<'_, f32>,
    b_panel: MatRef<'_, f32>,
    mut c_panel: MatMut<'_, f32>,
    kc_beg: usize,
) {
    let (rows, depth) = a_panel.dimension();
    let (b_rows, cols) = b_panel.dimension();
    debug_assert_eq!(rows, c_panel.n_rows());
    debug_assert_eq!(cols, c_panel.n_cols());
    debug_assert!(kc_beg + depth <= b_rows);

    // Largest multiple of NR_F32 that fits in `cols`; everything before it is covered
    // by full-width kernel calls.
    let full_cols = cols - cols % NR_F32;

    for col in (0..full_cols).step_by(NR_F32) {
        skernel_mrnr(
            alpha,
            a_panel,
            b_panel,
            c_panel.reborrow(),
            kc_beg,
            depth,
            col,
            rows,
            b_rows,
        );
    }

    // Ragged right edge: one fused matrix-vector update per remaining column.
    if full_cols < cols {
        let b_data = b_panel.as_slice();
        for col in full_cols..cols {
            let c_col = c_panel.col_mut(col);
            let start = col * b_rows + kc_beg;
            let b_col = VecRef::new(&b_data[start..start + depth]);
            faxpy(alpha, a_panel, b_col, c_col);
        }
    }
}



/// Double-precision NN micro-driver: adds `alpha * A * B[kc_beg..kc_beg + kc, ..]`
/// into `c_panel`, where `(m, kc)` is the shape of `a_panel`.
///
/// * `a_panel` — the `m x kc` block of A.
/// * `b_panel` — a `k x nc` block of B. Only rows `kc_beg..kc_beg + kc` are read, and its
///   backing slice is indexed column-major with leading dimension `k`.
/// * `c_panel` — the `m x nc` block of C, updated in place.
///
/// Full groups of `NR_F64` columns go through `dkernel_mrnr`. Any columns left at the
/// right edge (fewer than `NR_F64`) are handed to `faxpy` one column at a time.
#[inline(always)]
pub(crate) fn dgemm_nn_micro(
    alpha: f64,
    a_panel: MatRef<'_, f64>,
    b_panel: MatRef<'_, f64>,
    mut c_panel: MatMut<'_, f64>,
    kc_beg: usize,
) {
    let (rows, depth) = a_panel.dimension();
    let (b_rows, cols) = b_panel.dimension();
    debug_assert_eq!(rows, c_panel.n_rows());
    debug_assert_eq!(cols, c_panel.n_cols());
    debug_assert!(kc_beg + depth <= b_rows);

    // Largest multiple of NR_F64 that fits in `cols`; everything before it is covered
    // by full-width kernel calls.
    let full_cols = cols - cols % NR_F64;

    for col in (0..full_cols).step_by(NR_F64) {
        dkernel_mrnr(
            alpha,
            a_panel,
            b_panel,
            c_panel.reborrow(),
            kc_beg,
            depth,
            col,
            rows,
            b_rows,
        );
    }

    // Ragged right edge: one fused matrix-vector update per remaining column.
    if full_cols < cols {
        let b_data = b_panel.as_slice();
        for col in full_cols..cols {
            let c_col = c_panel.col_mut(col);
            let start = col * b_rows + kc_beg;
            let b_col = VecRef::new(&b_data[start..start + depth]);
            faxpy(alpha, a_panel, b_col, c_col);
        }
    }
}


/// Single-precision register-blocked kernel. It adds
/// `alpha * A * B[kc_beg..kc_beg + kc, j..j + 4]` into columns `j..j + 4` of `c_panel`.
///
/// * `a_panel` — an `m x kc` block; its backing slice is read column-major with leading
///   dimension `m`.
/// * `b_panel` — its backing slice is read column-major with leading dimension `k`.
/// * `c_panel` — its backing slice is read and written column-major with leading
///   dimension `m`.
/// * `j` — the first of the four columns to update. `sgemm_nn_micro` only calls this
///   function when columns `j..j + NR_F32` all exist.
///
/// Rows are processed `MR_F32` at a time as `Simd` vectors. Each C chunk is loaded once,
/// accumulated over all `kc` steps with `fma`, and stored once. The last `m % MR_F32`
/// rows fall back to `axpy` calls.
fn skernel_mrnr(
    alpha: f32,
    a_panel: MatRef<'_, f32>,
    b_panel: MatRef<'_, f32>,
    mut c_panel: MatMut<'_, f32>,
    kc_beg: usize,
    kc: usize,
    j: usize,
    m: usize,
    k: usize,
) {
    // The body below is hand-unrolled for four columns; changing NR_F32 must not compile.
    const _: () = assert!(NR_F32 == 4);
    debug_assert_eq!(a_panel.dimension(), (m, kc));
    debug_assert!(kc_beg + kc <= k);

    let a_slice = a_panel.as_slice();
    let b_slice = b_panel.as_slice();

    // Rows kc_beg..kc_beg + kc of B columns j..j + 4. These do not depend on the row chunk
    // or on kk. Slicing them once keeps address arithmetic and bounds checks out of the
    // hot k-loops.
    let b_beg = j * k + kc_beg;
    let bcol0 = &b_slice[b_beg..b_beg + kc];
    let bcol1 = &b_slice[b_beg + k..b_beg + k + kc];
    let bcol2 = &b_slice[b_beg + 2 * k..b_beg + 2 * k + kc];
    let bcol3 = &b_slice[b_beg + 3 * k..b_beg + 3 * k + kc];

    // Split the four C columns into disjoint mutable slices.
    let c_slice = c_panel.as_slice_mut();
    let c_base = j * m;
    let c_block = &mut c_slice[c_base..c_base + NR_F32 * m];
    let (cl, cr) = c_block.split_at_mut(m * 2);
    let (c0, c1) = cl.split_at_mut(m);
    let (c2, c3) = cr.split_at_mut(m);

    let (c0_chunks, c0_tail) = c0.as_chunks_mut::<MR_F32>();
    let (c1_chunks, c1_tail) = c1.as_chunks_mut::<MR_F32>();
    let (c2_chunks, c2_tail) = c2.as_chunks_mut::<MR_F32>();
    let (c3_chunks, c3_tail) = c3.as_chunks_mut::<MR_F32>();
    let n_chunks = c0_chunks.len();

    for chunk_idx in 0..n_chunks {
        let row_beg = chunk_idx * MR_F32;

        let mut acc0 = Simd::from_array(c0_chunks[chunk_idx]);
        let mut acc1 = Simd::from_array(c1_chunks[chunk_idx]);
        let mut acc2 = Simd::from_array(c2_chunks[chunk_idx]);
        let mut acc3 = Simd::from_array(c3_chunks[chunk_idx]);

        for kk in 0..kc {
            // Rows row_beg..row_beg + MR_F32 of A column kk.
            let a_beg = kk * m + row_beg;
            let avec = Simd::<f32, MR_F32>::from_slice(&a_slice[a_beg..a_beg + MR_F32]);

            acc0 = Simd::<f32, MR_F32>::splat(alpha * bcol0[kk]).fma(avec, acc0);
            acc1 = Simd::<f32, MR_F32>::splat(alpha * bcol1[kk]).fma(avec, acc1);
            acc2 = Simd::<f32, MR_F32>::splat(alpha * bcol2[kk]).fma(avec, acc2);
            acc3 = Simd::<f32, MR_F32>::splat(alpha * bcol3[kk]).fma(avec, acc3);
        }

        c0_chunks[chunk_idx] = acc0.to_array();
        c1_chunks[chunk_idx] = acc1.to_array();
        c2_chunks[chunk_idx] = acc2.to_array();
        c3_chunks[chunk_idx] = acc3.to_array();
    }

    // Rows that do not fill a whole MR_F32 chunk.
    let tail_len = c0_tail.len();
    if tail_len != 0 {
        let row_beg = n_chunks * MR_F32;

        let mut c0 = VecMut::new(c0_tail);
        let mut c1 = VecMut::new(c1_tail);
        let mut c2 = VecMut::new(c2_tail);
        let mut c3 = VecMut::new(c3_tail);

        for kk in 0..kc {
            let a_beg = kk * m + row_beg;
            let acol = VecRef::new(&a_slice[a_beg..a_beg + tail_len]);

            axpy(alpha * bcol0[kk], acol, c0.reborrow());
            axpy(alpha * bcol1[kk], acol, c1.reborrow());
            axpy(alpha * bcol2[kk], acol, c2.reborrow());
            axpy(alpha * bcol3[kk], acol, c3.reborrow());
        }
    }
}

/// Double-precision register-blocked kernel. It adds
/// `alpha * A * B[kc_beg..kc_beg + kc, j..j + 4]` into columns `j..j + 4` of `c_panel`.
///
/// * `a_panel` — an `m x kc` block; its backing slice is read column-major with leading
///   dimension `m`.
/// * `b_panel` — its backing slice is read column-major with leading dimension `k`.
/// * `c_panel` — its backing slice is read and written column-major with leading
///   dimension `m`.
/// * `j` — the first of the four columns to update. `dgemm_nn_micro` only calls this
///   function when columns `j..j + NR_F64` all exist.
///
/// Rows are processed `MR_F64` at a time as `Simd` vectors. Each C chunk is loaded once,
/// accumulated over all `kc` steps with `fma`, and stored once. The last `m % MR_F64`
/// rows fall back to `axpy` calls.
fn dkernel_mrnr(
    alpha: f64,
    a_panel: MatRef<'_, f64>,
    b_panel: MatRef<'_, f64>,
    mut c_panel: MatMut<'_, f64>,
    kc_beg: usize,
    kc: usize,
    j: usize,
    m: usize,
    k: usize,
) {
    // The body below is hand-unrolled for four columns; changing NR_F64 must not compile.
    const _: () = assert!(NR_F64 == 4);
    debug_assert_eq!(a_panel.dimension(), (m, kc));
    debug_assert!(kc_beg + kc <= k);

    let a_slice = a_panel.as_slice();
    let b_slice = b_panel.as_slice();

    // Rows kc_beg..kc_beg + kc of B columns j..j + 4. These do not depend on the row chunk
    // or on kk. Slicing them once keeps address arithmetic and bounds checks out of the
    // hot k-loops.
    let b_beg = j * k + kc_beg;
    let bcol0 = &b_slice[b_beg..b_beg + kc];
    let bcol1 = &b_slice[b_beg + k..b_beg + k + kc];
    let bcol2 = &b_slice[b_beg + 2 * k..b_beg + 2 * k + kc];
    let bcol3 = &b_slice[b_beg + 3 * k..b_beg + 3 * k + kc];

    // Split the four C columns into disjoint mutable slices.
    let c_slice = c_panel.as_slice_mut();
    let c_base = j * m;
    let c_block = &mut c_slice[c_base..c_base + NR_F64 * m];
    let (cl, cr) = c_block.split_at_mut(m * 2);
    let (c0, c1) = cl.split_at_mut(m);
    let (c2, c3) = cr.split_at_mut(m);

    let (c0_chunks, c0_tail) = c0.as_chunks_mut::<MR_F64>();
    let (c1_chunks, c1_tail) = c1.as_chunks_mut::<MR_F64>();
    let (c2_chunks, c2_tail) = c2.as_chunks_mut::<MR_F64>();
    let (c3_chunks, c3_tail) = c3.as_chunks_mut::<MR_F64>();
    let n_chunks = c0_chunks.len();

    for chunk_idx in 0..n_chunks {
        let row_beg = chunk_idx * MR_F64;

        let mut acc0 = Simd::from_array(c0_chunks[chunk_idx]);
        let mut acc1 = Simd::from_array(c1_chunks[chunk_idx]);
        let mut acc2 = Simd::from_array(c2_chunks[chunk_idx]);
        let mut acc3 = Simd::from_array(c3_chunks[chunk_idx]);

        for kk in 0..kc {
            // Rows row_beg..row_beg + MR_F64 of A column kk.
            let a_beg = kk * m + row_beg;
            let avec = Simd::<f64, MR_F64>::from_slice(&a_slice[a_beg..a_beg + MR_F64]);

            acc0 = Simd::<f64, MR_F64>::splat(alpha * bcol0[kk]).fma(avec, acc0);
            acc1 = Simd::<f64, MR_F64>::splat(alpha * bcol1[kk]).fma(avec, acc1);
            acc2 = Simd::<f64, MR_F64>::splat(alpha * bcol2[kk]).fma(avec, acc2);
            acc3 = Simd::<f64, MR_F64>::splat(alpha * bcol3[kk]).fma(avec, acc3);
        }

        c0_chunks[chunk_idx] = acc0.to_array();
        c1_chunks[chunk_idx] = acc1.to_array();
        c2_chunks[chunk_idx] = acc2.to_array();
        c3_chunks[chunk_idx] = acc3.to_array();
    }

    // Rows that do not fill a whole MR_F64 chunk.
    let tail_len = c0_tail.len();
    if tail_len != 0 {
        let row_beg = n_chunks * MR_F64;

        let mut c0 = VecMut::new(c0_tail);
        let mut c1 = VecMut::new(c1_tail);
        let mut c2 = VecMut::new(c2_tail);
        let mut c3 = VecMut::new(c3_tail);

        for kk in 0..kc {
            let a_beg = kk * m + row_beg;
            let acol = VecRef::new(&a_slice[a_beg..a_beg + tail_len]);

            axpy(alpha * bcol0[kk], acol, c0.reborrow());
            axpy(alpha * bcol1[kk], acol, c1.reborrow());
            axpy(alpha * bcol2[kk], acol, c2.reborrow());
            axpy(alpha * bcol3[kk], acol, c3.reborrow());
        }
    }
}