lak-kernels 0.1.0

BLAS-like linear algebra kernels in fully-safe Rust.
// dot.rs

use std::ops::{Add, AddAssign, Mul}; 
use std::simd::{Simd, SimdElement};
use std::simd::num::SimdFloat; 

use crate::traits::Fma; 
use crate::types::VecRef; 
use crate::assert_length_eq; 

/// Number of elements held by each SIMD vector in this module's kernels.
///
/// `dot` splits its inputs into chunks of exactly this many elements and
/// handles any leftover elements with scalar code.
pub(crate) const LANES: usize = 32;

/// Computes the dot product `x · y = Σ x[i] * y[i]`.
///
/// Full chunks of [`LANES`] elements are combined with SIMD fused
/// multiply-adds into a single vector accumulator. Elements that do not fill
/// a full chunk are combined with scalar fused multiply-adds. The vector
/// accumulator is then reduced horizontally and the scalar remainder is added.
/// Because of this summation order, the result may differ in the last bits
/// from a naive left-to-right loop.
///
/// # Arguments
///
/// * `x` - left-hand operand, a [`VecRef`]
/// * `y` - right-hand operand, a [`VecRef`] of the same length as `x`
///
/// # Returns
///
/// The scalar dot product of `x` and `y` (zero when both are empty).
///
/// # Panics
///
/// Fails the `assert_length_eq!` check when `x` and `y` differ in length
/// (presumably a panic — confirm against that macro's definition).
pub fn dot<T>(
    x: VecRef<'_, T>,
    y: VecRef<'_, T>,
) -> T
where
    T: SimdElement
        + Copy
        + Default
        + AddAssign
        + Mul<Output = T>
        + Add<Output = T>
        + Fma,
    Simd<T, LANES>: SimdFloat<Scalar = T> + AddAssign + Fma,
{
    assert_length_eq!(x, y);

    let (lhs_blocks, lhs_rest) = x.as_slice().as_chunks::<LANES>();
    let (rhs_blocks, rhs_rest) = y.as_slice().as_chunks::<LANES>();

    // Vectorized part: one SIMD FMA per full block, applied in input order.
    let lanes = lhs_blocks
        .iter()
        .zip(rhs_blocks)
        .fold(Simd::<T, LANES>::splat(T::default()), |acc, (a, b)| {
            Simd::from_array(*a).fma(Simd::from_array(*b), acc)
        });

    // Scalar part: trailing elements too few to fill a whole block.
    let tail = lhs_rest
        .iter()
        .zip(rhs_rest)
        .fold(T::default(), |acc, (&a, &b)| a.fma(b, acc));

    lanes.reduce_sum() + tail
}