redstone-ml 0.0.0

High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
Documentation
use crate::iterator::multi_flat_index_generator::MultiFlatIndexGenerator;
use crate::linalg::sum_of_products::*;
use crate::ndarray::{MAX_ARGS, MAX_DIMS};
use crate::{first_n_elements, NdArray};
use std::hint::assert_unchecked;
use std::ptr::null_mut;

/// General-purpose einsum loop for any operand count and any number of iteration dimensions.
///
/// Dimension 0 of the iteration space is handed to a sum-of-products kernel chosen from the
/// dimension-0 strides; every remaining dimension is walked with a `MultiFlatIndexGenerator`.
/// For each position it reports, every operand pointer and `dst` are advanced by the
/// corresponding offset before the kernel is invoked over `iter_shape[0]` elements.
///
/// `strides[d][k]` is used as the stride of pointer `k` along dimension `d`, where
/// `k == operands.len()` refers to `dst`. Only the first `iter_ndims` rows are read.
///
/// # Safety
///
/// Every offset produced by the index generator is applied with `ptr::add`, so each resulting
/// pointer (and everything the kernel touches from it) must stay inside the allocation of the
/// respective operand / `dst`.
///
/// # Panics
///
/// Panics on out-of-range slicing, e.g. when `iter_ndims == 0`, `iter_shape` is empty, or
/// `operands.len() + 1` exceeds `MAX_ARGS`.
pub(super) unsafe fn unspecialized_einsum_loop<T: SumOfProductsType>(operands: &[&NdArray<T>],
                                                                     strides: &[[usize; MAX_ARGS]; MAX_DIMS],
                                                                     iter_ndims: usize,
                                                                     iter_shape: &[usize],
                                                                     dst: *mut T) {
    let n_operands = operands.len();
    // One pointer per operand plus a trailing slot for the destination.
    let n_ptrs = n_operands + 1;

    let active_strides = &strides[0..iter_ndims];
    let inner_stride = &active_strides[0][..n_ptrs];
    let mut outer_positions = MultiFlatIndexGenerator::from(n_ptrs, &iter_shape[1..], &active_strides[1..]);

    let kernel = get_sum_of_products_function_generic_nops(inner_stride);

    // `base_ptrs` holds the untouched starting pointers; `scratch` is rewritten for every
    // outer position and is what the kernel actually receives.
    let mut base_ptrs: [*mut T; MAX_ARGS] = [null_mut(); MAX_ARGS];
    let mut scratch: [*mut T; MAX_ARGS] = [null_mut(); MAX_ARGS];
    let ptrs = &mut scratch[0..n_ptrs];

    base_ptrs[n_operands] = dst;

    unsafe {
        for (slot, &operand) in base_ptrs.iter_mut().zip(operands.iter()) {
            *slot = operand.mut_ptr();
        }

        // Number of kernel invocations: one per combination of the non-innermost dimensions.
        let outer_iterations: usize = iter_shape[1..].iter().product();

        for _ in 0..outer_iterations {
            let indices = outer_positions.cur_indices();

            // NOTE(review): `cur_indices` presumably yields one flat element offset per
            // pointer for the current outer position - confirm against the generator.
            for ((current, &base), &offset) in ptrs.iter_mut().zip(base_ptrs.iter()).zip(indices[..n_ptrs].iter()) {
                *current = base.add(offset);
            }

            kernel(ptrs, inner_stride, iter_shape[0]);
            outer_positions.increment_flat_indices();
        }
    }
}

/// Runs a hand-specialised einsum kernel if one exists for this operand / dimension count.
///
/// Specialised kernels are available for one or two operands combined with two or three
/// iteration dimensions. When the combination matches, the kernel is invoked and `true` is
/// returned; otherwise nothing is executed and `false` is returned (presumably so the caller
/// can fall back to the generic loop - confirm at the call site).
///
/// # Safety
///
/// The call is forwarded to `einsum_{1operand,2operands}_{2labels,3labels}`, so the caller
/// must satisfy their requirements: the used entries of `iter_shape` are non-zero and every
/// stride-derived offset stays inside the respective operand / `dst` allocation.
pub(super) unsafe fn try_specialized_einsum_loop<T: SumOfProductsType>(operands: &[&NdArray<T>],
                                                                       strides: &[[usize; MAX_ARGS]; MAX_DIMS],
                                                                       iter_ndims: usize,
                                                                       iter_shape: &[usize],
                                                                       dst: *mut T) -> bool {
    match (operands.len(), iter_ndims) {
        (1, 2) => einsum_1operand_2labels(operands[0],
                                          first_n_elements!(strides[0], 2),
                                          first_n_elements!(strides[1], 2),
                                          first_n_elements!(iter_shape, 2),
                                          dst),

        (1, 3) => einsum_1operand_3labels(operands[0],
                                          first_n_elements!(strides[0], 2),
                                          first_n_elements!(strides[1], 2),
                                          first_n_elements!(strides[2], 2),
                                          first_n_elements!(iter_shape, 3),
                                          dst),

        (2, 2) => einsum_2operands_2labels(operands[0], operands[1],
                                           first_n_elements!(strides[0], 3),
                                           first_n_elements!(strides[1], 3),
                                           first_n_elements!(iter_shape, 2),
                                           dst),

        (2, 3) => einsum_2operands_3labels(operands[0], operands[1],
                                           first_n_elements!(strides[0], 3),
                                           first_n_elements!(strides[1], 3),
                                           first_n_elements!(strides[2], 3),
                                           first_n_elements!(iter_shape, 3),
                                           dst),

        // No specialised kernel for this combination.
        _ => return false,
    }

    true
}


/// Specialised einsum loop for a single operand iterated over two dimensions.
///
/// Dimension 0 is processed by the sum-of-products kernel selected from `strides_dim0`;
/// dimension 1 is walked here, offsetting the operand and `dst` pointers for each step.
///
/// Each stride array holds `[operand stride, dst stride]` for its dimension, counted in
/// elements of `T`. `iter_shape` is `[len of dim 0, len of dim 1]`.
///
/// # Safety
///
/// * Both entries of `iter_shape` must be non-zero; this is asserted with `assert_unchecked`,
///   so a zero length is undefined behaviour.
/// * Every `index * stride` offset must stay inside the operand / `dst` allocation.
pub(super) unsafe fn einsum_1operand_2labels<T: SumOfProductsType>(operand: &NdArray<T>,
                                                                   strides_dim0: &[usize; 2],
                                                                   strides_dim1: &[usize; 2],
                                                                   iter_shape: &[usize; 2],
                                                                   dst: *mut T)
{
    let [inner_len, outer_len] = *iter_shape;
    let [src_step, dst_step] = *strides_dim1;

    let kernel = get_sum_of_products_function(strides_dim0);

    unsafe {
        let src_base = operand.mut_ptr();

        assert_unchecked(inner_len > 0);
        assert_unchecked(outer_len > 0);

        for outer in 0..outer_len {
            let src = src_base.add(outer * src_step);
            let out = dst.add(outer * dst_step);

            kernel(&[src, out], strides_dim0, inner_len);
        }
    }
}

/// Specialised einsum loop for a single operand iterated over three dimensions.
///
/// Dimension 0 is processed by the sum-of-products kernel selected from `strides_dim0`;
/// dimension 2 is the outermost loop and dimension 1 the middle loop, with the operand and
/// `dst` pointers offset by the combined strides for each `(dim 2, dim 1)` position.
///
/// Each stride array holds `[operand stride, dst stride]` for its dimension, counted in
/// elements of `T`. `iter_shape` is `[len of dim 0, len of dim 1, len of dim 2]`.
///
/// # Safety
///
/// * All entries of `iter_shape` must be non-zero; this is asserted with `assert_unchecked`,
///   so a zero length is undefined behaviour.
/// * Every combined offset must stay inside the operand / `dst` allocation.
pub(super) unsafe fn einsum_1operand_3labels<T: SumOfProductsType>(operand: &NdArray<T>,
                                                                   strides_dim0: &[usize; 2],
                                                                   strides_dim1: &[usize; 2],
                                                                   strides_dim2: &[usize; 2],
                                                                   iter_shape: &[usize; 3],
                                                                   dst: *mut T)
{
    let [inner_len, mid_len, outer_len] = *iter_shape;

    let kernel = get_sum_of_products_function(strides_dim0);

    unsafe {
        let src_base = operand.mut_ptr();

        assert_unchecked(inner_len > 0);
        assert_unchecked(mid_len > 0);
        assert_unchecked(outer_len > 0);

        for outer in 0..outer_len {
            // Offsets contributed by the outermost dimension are fixed for the inner loop.
            let src_plane = outer * strides_dim2[0];
            let dst_plane = outer * strides_dim2[1];

            for mid in 0..mid_len {
                let src = src_base.add(src_plane + mid * strides_dim1[0]);
                let out = dst.add(dst_plane + mid * strides_dim1[1]);

                kernel(&[src, out], strides_dim0, inner_len);
            }
        }
    }
}

/// Specialised einsum loop for two operands iterated over two dimensions.
///
/// Dimension 0 is processed by the sum-of-products kernel selected from `strides_dim0`;
/// dimension 1 is walked here, offsetting both operand pointers and `dst` for each step.
///
/// Each stride array holds `[operand1 stride, operand2 stride, dst stride]` for its
/// dimension, counted in elements of `T`. `iter_shape` is `[len of dim 0, len of dim 1]`.
///
/// # Safety
///
/// * Both entries of `iter_shape` must be non-zero; this is asserted with `assert_unchecked`,
///   so a zero length is undefined behaviour.
/// * Every `index * stride` offset must stay inside the respective operand / `dst` allocation.
pub(super) unsafe fn einsum_2operands_2labels<T: SumOfProductsType>(operand1: &NdArray<T>,
                                                                    operand2: &NdArray<T>,
                                                                    strides_dim0: &[usize; 3],
                                                                    strides_dim1: &[usize; 3],
                                                                    iter_shape: &[usize; 2],
                                                                    dst: *mut T)
{
    let [inner_len, outer_len] = *iter_shape;
    let [lhs_step, rhs_step, dst_step] = *strides_dim1;

    let kernel = get_sum_of_products_function(strides_dim0);

    unsafe {
        let lhs_base = operand1.mut_ptr();
        let rhs_base = operand2.mut_ptr();

        assert_unchecked(inner_len > 0);
        assert_unchecked(outer_len > 0);

        for outer in 0..outer_len {
            let lhs = lhs_base.add(outer * lhs_step);
            let rhs = rhs_base.add(outer * rhs_step);
            let out = dst.add(outer * dst_step);

            kernel(&[lhs, rhs, out], strides_dim0, inner_len);
        }
    }
}

/// Specialised einsum loop for two operands iterated over three dimensions.
///
/// Dimension 0 is processed by the sum-of-products kernel selected from `strides_dim0`;
/// dimension 2 is the outermost loop and dimension 1 the middle loop, with both operand
/// pointers and `dst` offset by the combined strides for each `(dim 2, dim 1)` position.
///
/// Each stride array holds `[operand1 stride, operand2 stride, dst stride]` for its
/// dimension, counted in elements of `T`. `iter_shape` is
/// `[len of dim 0, len of dim 1, len of dim 2]`.
///
/// # Safety
///
/// * All entries of `iter_shape` must be non-zero; this is asserted with `assert_unchecked`,
///   so a zero length is undefined behaviour.
/// * Every combined offset must stay inside the respective operand / `dst` allocation.
pub(super) unsafe fn einsum_2operands_3labels<T: SumOfProductsType>(operand1: &NdArray<T>,
                                                                    operand2: &NdArray<T>,
                                                                    strides_dim0: &[usize; 3],
                                                                    strides_dim1: &[usize; 3],
                                                                    strides_dim2: &[usize; 3],
                                                                    iter_shape: &[usize; 3],
                                                                    dst: *mut T)
{
    let [inner_len, mid_len, outer_len] = *iter_shape;

    let kernel = get_sum_of_products_function(strides_dim0);

    unsafe {
        let lhs_base = operand1.mut_ptr();
        let rhs_base = operand2.mut_ptr();

        assert_unchecked(inner_len > 0);
        assert_unchecked(mid_len > 0);
        assert_unchecked(outer_len > 0);

        for outer in 0..outer_len {
            // Offsets contributed by the outermost dimension are fixed for the inner loop.
            let lhs_plane = outer * strides_dim2[0];
            let rhs_plane = outer * strides_dim2[1];
            let dst_plane = outer * strides_dim2[2];

            for mid in 0..mid_len {
                let lhs = lhs_base.add(lhs_plane + mid * strides_dim1[0]);
                let rhs = rhs_base.add(rhs_plane + mid * strides_dim1[1]);
                let out = dst.add(dst_plane + mid * strides_dim1[2]);

                kernel(&[lhs, rhs, out], strides_dim0, inner_len);
            }
        }
    }
}