// redstone-ml 0.0.0
//
// High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
// Documentation
use std::fmt::Display;
use paste::paste;
use crate::acceleration::simd::Simd;

/// Expands to eight `unsafe` element-wise kernels for a single binary operator.
///
/// Macro arguments:
/// * `$name`     - operation name spliced into the generated identifiers
///   (`simd_<name>_stride_<lhs>_<rhs>`).
/// * `$simd_op`  - associated function on `Self` applied to two SIMD vectors.
/// * `$operator` - scalar operator token used for the leftover elements.
///
/// The `<lhs>` / `<rhs>` suffix describes how each input pointer moves:
/// * `0` - the pointer never moves; its value is read once and spread over a
///   vector with `simd_from_constant` (the scalar tail re-reads it per element).
/// * `1` - the pointer advances one element per output element.
/// * `n` - the pointer advances by the explicit `*_stride` argument per output
///   element; vectors are built with `simd_vec_from_stride` (presumably a strided
///   gather of `LANES` elements - confirm against the `Simd` trait).
///
/// `dst` always advances one element per result. Every kernel handles blocks of
/// `4 * Self::LANES` elements with SIMD and finishes the remainder with scalar code.
/// All generated functions are gated on `#[cfg(neon_simd)]`.
///
/// # Safety
/// The generated functions dereference raw pointers without any checks; callers
/// must guarantee that every address touched for `count` elements is valid.
macro_rules! simd_elementwise_operations {
    ($name:ident, $simd_op:ident, $operator:tt) => {
        paste! {
            // lhs: broadcast scalar, rhs: contiguous.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_1>](lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;
                let lhs_splat = Self::simd_from_constant(*lhs);

                // Unrolled main loop: all loads happen before the first store of a block.
                while count >= block {
                    let r0 = Self::simd_load(rhs);
                    let r1 = Self::simd_load(rhs.add(Self::LANES));
                    let r2 = Self::simd_load(rhs.add(2 * Self::LANES));
                    let r3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let out0 = Self::$simd_op(lhs_splat, r0);
                    let out1 = Self::$simd_op(lhs_splat, r1);
                    let out2 = Self::$simd_op(lhs_splat, r2);
                    let out3 = Self::$simd_op(lhs_splat, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    rhs = rhs.add(block);
                    dst = dst.add(block);
                    count -= block;
                }

                // Scalar tail for whatever did not fill a whole block.
                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs: contiguous, rhs: broadcast scalar.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_0>](mut lhs: *const Self, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;
                let rhs_splat = Self::simd_from_constant(*rhs);

                while count >= block {
                    let l0 = Self::simd_load(lhs);
                    let l1 = Self::simd_load(lhs.add(Self::LANES));
                    let l2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let l3 = Self::simd_load(lhs.add(3 * Self::LANES));

                    let out0 = Self::$simd_op(l0, rhs_splat);
                    let out1 = Self::$simd_op(l1, rhs_splat);
                    let out2 = Self::$simd_op(l2, rhs_splat);
                    let out3 = Self::$simd_op(l3, rhs_splat);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(block);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs: contiguous, rhs: contiguous.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_1>](mut lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;

                while count >= block {
                    let l0 = Self::simd_load(lhs);
                    let r0 = Self::simd_load(rhs);

                    let l1 = Self::simd_load(lhs.add(Self::LANES));
                    let r1 = Self::simd_load(rhs.add(Self::LANES));

                    let l2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let r2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let l3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let r3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let out0 = Self::$simd_op(l0, r0);
                    let out1 = Self::$simd_op(l1, r1);
                    let out2 = Self::$simd_op(l2, r2);
                    let out3 = Self::$simd_op(l3, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(block);
                    rhs = rhs.add(block);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(1);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs: strided, rhs: broadcast scalar.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_0>](mut lhs: *const Self, lhs_stride: usize, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;
                let rhs_splat = Self::simd_from_constant(*rhs);

                while count >= block {
                    // Element distance between the starts of two consecutive lhs vectors.
                    let lhs_step = lhs_stride * Self::LANES;

                    let l0 = Self::simd_vec_from_stride(lhs, lhs_stride);
                    let l1 = Self::simd_vec_from_stride(lhs.add(lhs_step), lhs_stride);
                    let l2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_step), lhs_stride);
                    let l3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_step), lhs_stride);

                    let out0 = Self::$simd_op(l0, rhs_splat);
                    let out1 = Self::$simd_op(l1, rhs_splat);
                    let out2 = Self::$simd_op(l2, rhs_splat);
                    let out3 = Self::$simd_op(l3, rhs_splat);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(4 * lhs_step);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(lhs_stride);
                    dst = dst.add(1);
                }
            }

            // lhs: broadcast scalar, rhs: strided.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_n>](lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;
                let lhs_splat = Self::simd_from_constant(*lhs);

                while count >= block {
                    // Element distance between the starts of two consecutive rhs vectors.
                    let rhs_step = rhs_stride * Self::LANES;

                    let r0 = Self::simd_vec_from_stride(rhs, rhs_stride);
                    let r1 = Self::simd_vec_from_stride(rhs.add(rhs_step), rhs_stride);
                    let r2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_step), rhs_stride);
                    let r3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_step), rhs_stride);

                    let out0 = Self::$simd_op(lhs_splat, r0);
                    let out1 = Self::$simd_op(lhs_splat, r1);
                    let out2 = Self::$simd_op(lhs_splat, r2);
                    let out3 = Self::$simd_op(lhs_splat, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    rhs = rhs.add(4 * rhs_step);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            // lhs: strided, rhs: contiguous.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_1>](mut lhs: *const Self, lhs_stride: usize, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;

                while count >= block {
                    let lhs_step = lhs_stride * Self::LANES;

                    let l0 = Self::simd_vec_from_stride(lhs, lhs_stride);
                    let r0 = Self::simd_load(rhs);

                    let l1 = Self::simd_vec_from_stride(lhs.add(lhs_step), lhs_stride);
                    let r1 = Self::simd_load(rhs.add(Self::LANES));

                    let l2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_step), lhs_stride);
                    let r2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let l3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_step), lhs_stride);
                    let r3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let out0 = Self::$simd_op(l0, r0);
                    let out1 = Self::$simd_op(l1, r1);
                    let out2 = Self::$simd_op(l2, r2);
                    let out3 = Self::$simd_op(l3, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(4 * lhs_step);
                    rhs = rhs.add(block);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(lhs_stride);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs: contiguous, rhs: strided.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_n>](mut lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;

                while count >= block {
                    let rhs_step = rhs_stride * Self::LANES;

                    let l0 = Self::simd_load(lhs);
                    let r0 = Self::simd_vec_from_stride(rhs, rhs_stride);

                    let l1 = Self::simd_load(lhs.add(Self::LANES));
                    let r1 = Self::simd_vec_from_stride(rhs.add(rhs_step), rhs_stride);

                    let l2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let r2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_step), rhs_stride);

                    let l3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let r3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_step), rhs_stride);

                    let out0 = Self::$simd_op(l0, r0);
                    let out1 = Self::$simd_op(l1, r1);
                    let out2 = Self::$simd_op(l2, r2);
                    let out3 = Self::$simd_op(l3, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(block);
                    rhs = rhs.add(4 * rhs_step);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(1);
                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            // lhs: strided, rhs: strided.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_n>](mut lhs: *const Self, lhs_stride: usize,
                                                  mut rhs: *const Self, rhs_stride: usize,
                                                  mut dst: *mut Self, mut count: usize) {
                let block = 4 * Self::LANES;

                while count >= block {
                    let lhs_step = lhs_stride * Self::LANES;
                    let rhs_step = rhs_stride * Self::LANES;

                    let l0 = Self::simd_vec_from_stride(lhs, lhs_stride);
                    let r0 = Self::simd_vec_from_stride(rhs, rhs_stride);

                    let l1 = Self::simd_vec_from_stride(lhs.add(lhs_step), lhs_stride);
                    let r1 = Self::simd_vec_from_stride(rhs.add(rhs_step), rhs_stride);

                    let l2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_step), lhs_stride);
                    let r2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_step), rhs_stride);

                    let l3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_step), lhs_stride);
                    let r3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_step), rhs_stride);

                    let out0 = Self::$simd_op(l0, r0);
                    let out1 = Self::$simd_op(l1, r1);
                    let out2 = Self::$simd_op(l2, r2);
                    let out3 = Self::$simd_op(l3, r3);

                    Self::simd_store(dst, out0);
                    Self::simd_store(dst.add(Self::LANES), out1);
                    Self::simd_store(dst.add(2 * Self::LANES), out2);
                    Self::simd_store(dst.add(3 * Self::LANES), out3);

                    lhs = lhs.add(4 * lhs_step);
                    rhs = rhs.add(4 * rhs_step);
                    dst = dst.add(block);
                    count -= block;
                }

                for _ in 0..count {
                    *dst = *lhs $operator *rhs;

                    lhs = lhs.add(lhs_stride);
                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }
        }
    };
}

/// Crate-internal trait whose provided methods are the element-wise kernels
/// generated by `simd_elementwise_operations!` for add, sub, mul and div.
///
/// It declares no required items; every method comes from the macro expansions
/// below and is only present when the `neon_simd` cfg is set.
pub(crate) trait SimdBinaryOps: Simd + Display {
    // Each invocation expands to eight `simd_<name>_stride_*` functions:
    // (identifier fragment, SIMD function on `Self`, scalar operator for the tail).
    simd_elementwise_operations!(add, simd_add, +);
    simd_elementwise_operations!(sub, simd_sub, -);
    simd_elementwise_operations!(mul, simd_mul, *);
    simd_elementwise_operations!(div, simd_div, /);
}

// Blanket implementation: every `Simd + Display` type receives the trait's
// provided methods as-is, since the impl body overrides nothing.
impl<T: Simd + Display> SimdBinaryOps for T {}

/// Expands to eight `<name>_stride_<lhs>_<rhs>` functions, each of which simply
/// forwards its arguments to the `simd_<name>_stride_<lhs>_<rhs>` function of the
/// same suffix provided by `SimdBinaryOps`.
///
/// `$name` is the operation name spliced into both the generated identifier and
/// the forwarded-to identifier. The expansion uses `paste!` unqualified, so the
/// invoking module must have it in scope. All generated functions are gated on
/// `#[cfg(neon_simd)]`.
///
/// # Safety
/// The generated functions are `unsafe` and pass raw pointers straight through;
/// the caller's obligations are those of the forwarded-to kernels.
#[macro_export]
macro_rules! simd_binary_op_specializations {
    ($name: ident) => {
        paste! {
            // Scalar lhs, contiguous rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_0_1>](lhs, rhs, dst, count);
            }

            // Contiguous lhs, scalar rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_0>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_1_0>](lhs, rhs, dst, count);
            }

            // Contiguous lhs, contiguous rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_1_1>](lhs, rhs, dst, count);
            }

            // Strided lhs, scalar rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_0>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_n_0>](lhs, lhs_stride, rhs, dst, count);
            }

            // Scalar lhs, strided rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_n>](lhs: *const Self, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_0_n>](lhs, rhs, rhs_stride, dst, count);
            }

            // Strided lhs, contiguous rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_1>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_n_1>](lhs, lhs_stride, rhs, dst, count);
            }

            // Contiguous lhs, strided rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_n>](lhs: *const Self, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_1_n>](lhs, rhs, rhs_stride, dst, count);
            }

            // Strided lhs, strided rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_n>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                <Self as $crate::ops::simd_binary_ops::SimdBinaryOps>::[<simd_ $name _stride_n_n>](lhs, lhs_stride, rhs, rhs_stride, dst, count);
            }
        }
    };
}