yarnn 0.1.0

Yet Another Rust Neural Network framework
Documentation

/// Accumulates `C += alpha * A * B` for row-major, non-transposed `A` and `B`.
///
/// * `a` — `m x k` matrix whose consecutive rows start `lda` elements apart.
/// * `b` — `k x n` matrix whose consecutive rows start `ldb` elements apart.
/// * `c` — `m x n` matrix whose consecutive rows start `ldc` elements apart;
///   updated in place (existing contents are added to, not replaced).
///
/// Padding between rows (leading dimension larger than the row width) is
/// left untouched.
///
/// # Panics
/// Panics if a slice is shorter than its dimensions and leading dimension require.
fn gemm_nn(m: usize, n: usize, k: usize, alpha: f32,
           a: &[f32], lda: usize,
           b: &[f32], ldb: usize,
           c: &mut [f32], ldc: usize)
{
    // Elements spanned by a row-major `rows x cols` matrix whose rows are `ld`
    // apart. The last row only needs `cols` elements, so buffers need not
    // carry trailing padding after it.
    fn span(rows: usize, cols: usize, ld: usize) -> usize {
        if rows == 0 || cols == 0 { 0 } else { (rows - 1) * ld + cols }
    }

    // Re-slice up front so undersized buffers fail before C is modified.
    // The lengths honour the leading dimensions instead of assuming densely
    // packed rows (which made any padded layout panic).
    let a = &a[..span(m, k, lda)];
    let b = &b[..span(k, n, ldb)];
    let c = &mut c[..span(m, n, ldc)];

    for i_m in 0..m {
        for i_k in 0..k {
            // The scaled A element is constant across the inner column loop.
            let a_part = alpha * a[i_m * lda + i_k];
            for i_n in 0..n {
                c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
            }
        }
    }
}

/// Accumulates `C += alpha * A * B^T` for row-major storage.
///
/// * `a` — `m x k` matrix whose consecutive rows start `lda` elements apart.
/// * `b` — stored as an `n x k` matrix (rows `ldb` apart) and used transposed.
/// * `c` — `m x n` matrix whose consecutive rows start `ldc` elements apart;
///   updated in place (existing contents are added to, not replaced).
///
/// # Panics
/// Panics if a slice is shorter than its dimensions and leading dimension require.
fn gemm_nt(m: usize, n: usize, k: usize, alpha: f32,
           a: &[f32], lda: usize,
           b: &[f32], ldb: usize,
           c: &mut [f32], ldc: usize)
{
    // Elements spanned by a row-major `rows x cols` matrix whose rows are `ld`
    // apart. The last row only needs `cols` elements.
    fn span(rows: usize, cols: usize, ld: usize) -> usize {
        if rows == 0 || cols == 0 { 0 } else { (rows - 1) * ld + cols }
    }

    // Re-slice up front so undersized buffers fail before C is modified.
    // B is stored with `n` rows of `k` elements, so its extent follows `ldb`
    // over `n` rows rather than assuming a packed `n * k` block.
    let a = &a[..span(m, k, lda)];
    let b = &b[..span(n, k, ldb)];
    let c = &mut c[..span(m, n, ldc)];

    for i_m in 0..m {
        for i_n in 0..n {
            // Dot product of row i_m of A with row i_n of the stored B.
            let mut sum = 0.0;

            for i_k in 0..k {
                sum += alpha * a[i_m * lda + i_k] * b[i_n * ldb + i_k];
            }

            c[i_m * ldc + i_n] += sum;
        }
    }
}

/// Accumulates `C += alpha * A^T * B` for row-major storage.
///
/// * `a` — stored as a `k x m` matrix (rows `lda` apart) and used transposed.
/// * `b` — `k x n` matrix whose consecutive rows start `ldb` elements apart.
/// * `c` — `m x n` matrix whose consecutive rows start `ldc` elements apart;
///   updated in place (existing contents are added to, not replaced).
///
/// # Panics
/// Panics if a slice is shorter than its dimensions and leading dimension require.
fn gemm_tn(m: usize, n: usize, k: usize, alpha: f32,
           a: &[f32], lda: usize,
           b: &[f32], ldb: usize,
           c: &mut [f32], ldc: usize)
{
    // Elements spanned by a row-major `rows x cols` matrix whose rows are `ld`
    // apart. The last row only needs `cols` elements.
    fn span(rows: usize, cols: usize, ld: usize) -> usize {
        if rows == 0 || cols == 0 { 0 } else { (rows - 1) * ld + cols }
    }

    // Re-slice up front so undersized buffers fail before C is modified.
    // A is stored with `k` rows of `m` elements, so its extent follows `lda`
    // over `k` rows rather than assuming a packed `m * k` block.
    let a = &a[..span(k, m, lda)];
    let b = &b[..span(k, n, ldb)];
    let c = &mut c[..span(m, n, ldc)];

    for i_m in 0..m {
        for i_k in 0..k {
            // A^T[i_m][i_k] == stored A[i_k][i_m]; constant across the column loop.
            let a_part = alpha * a[i_k * lda + i_m];

            for i_n in 0..n {
                c[i_m * ldc + i_n] += a_part * b[i_k * ldb + i_n];
            }
        }
    }
}

/// Accumulates `C += alpha * A^T * B^T` for row-major storage.
///
/// * `a` — stored as a `k x m` matrix (rows `lda` apart) and used transposed.
/// * `b` — stored as an `n x k` matrix (rows `ldb` apart) and used transposed.
/// * `c` — `m x n` matrix whose consecutive rows start `ldc` elements apart;
///   updated in place (existing contents are added to, not replaced).
///
/// # Panics
/// Panics if a slice is shorter than its dimensions and leading dimension require.
fn gemm_tt(m: usize, n: usize, k: usize, alpha: f32,
           a: &[f32], lda: usize,
           b: &[f32], ldb: usize,
           c: &mut [f32], ldc: usize)
{
    // Elements spanned by a row-major `rows x cols` matrix whose rows are `ld`
    // apart. The last row only needs `cols` elements.
    fn span(rows: usize, cols: usize, ld: usize) -> usize {
        if rows == 0 || cols == 0 { 0 } else { (rows - 1) * ld + cols }
    }

    // Re-slice up front so undersized buffers fail before C is modified.
    // Both inputs are stored transposed: A as `k` rows of `m`, B as `n` rows of `k`.
    let a = &a[..span(k, m, lda)];
    let b = &b[..span(n, k, ldb)];
    let c = &mut c[..span(m, n, ldc)];

    for i_m in 0..m {
        for i_n in 0..n {
            // Dot product of column i_m of stored A with row i_n of stored B.
            let mut sum = 0.0;

            for i_k in 0..k {
                sum += alpha * a[i_k * lda + i_m] * b[i_n * ldb + i_k];
            }

            c[i_m * ldc + i_n] += sum;
        }
    }
}

/// General matrix multiply on row-major `f32` data:
/// `C = alpha * op(A) * op(B) + beta * C`.
///
/// `op(X)` is `X` when the corresponding flag (`ta` / `tb`) is `false` and
/// `X^T` when it is `true`. `op(A)` is `m x k`, `op(B)` is `k x n`, and `C`
/// is `m x n`. `lda`, `ldb` and `ldc` give the distance, in elements, between
/// the starts of consecutive stored rows of each matrix.
///
/// Following BLAS conventions:
/// * `beta == 0` overwrites C, so its prior contents (even NaN/inf) are ignored;
/// * `alpha == 0` only scales C and does not read `a` or `b`.
///
/// # Panics
/// Panics if any slice is shorter than its dimensions and leading dimension require.
pub fn gemm(ta: bool, tb: bool, m: usize, n: usize, k: usize, alpha: f32,
        a: &[f32], lda: usize,
        b: &[f32], ldb: usize, beta: f32,
        c: &mut [f32], ldc: usize)
{
    // Apply beta up front; the kernels below only accumulate into C.
    // Multiplying by 1 is a no-op, so that common case skips the pass.
    if beta != 1.0 {
        for i in 0..m {
            for j in 0..n {
                let x = &mut c[i * ldc + j];
                // 0.0 * NaN == NaN, so an explicit store is needed to honour
                // "beta == 0 means C is write-only".
                *x = if beta == 0.0 { 0.0 } else { *x * beta };
            }
        }
    }

    // Quick return: op(A) * op(B) contributes nothing when alpha is zero.
    if alpha == 0.0 {
        return;
    }

    match (ta, tb) {
        (false, false) => gemm_nn(m, n, k, alpha, a, lda, b, ldb, c, ldc),
        (true, false) => gemm_tn(m, n, k, alpha, a, lda, b, ldb, c, ldc),
        (false, true) => gemm_nt(m, n, k, alpha, a, lda, b, ldb, c, ldc),
        (true, true) => gemm_tt(m, n, k, alpha, a, lda, b, ldb, c, ldc),
    }
}