// tensorism 0.3.0
//
// A library for easy tensor manipulation on top of ndarray.
use criterion::{Criterion, criterion_group, criterion_main};
use ndarray::{Array2, Array4, Axis, LinalgScalar};
use num_traits::Zero;
use std::{
    f64::consts::PI,
    hint::black_box,
    iter::Sum,
    ops::{Add, Mul},
};
use tensorism::new_ndarray;
use utils::{C64, are_close};

struct SumAccu<T> {
    value: T,
}

impl<T> SumAccu<T>
where
    T: Copy + Add<Output = T> + Zero,
{
    fn new() -> Self {
        Self { value: T::zero() }
    }
    fn accumulate(&mut self, x: T) {
        self.value = self.value + x;
    }
    fn get(&self) -> T {
        self.value
    }
}

/// Reference implementation of the double contraction
/// `result[i, n, m, j] = SUM_{k, l} (t_a[k, j, i, l] * t_b[l, n, k, m])`
/// built only from ndarray primitives.
///
/// Both operands are permuted and reshaped into matrices so that the two
/// contracted axes `(k, l)` collapse into one inner dimension. The whole
/// contraction then becomes a single matrix product (`dot`).
///
/// # Panics
///
/// Panics if the contracted axes of `t_a` and `t_b` have different lengths:
/// `t_a` axis 0 vs `t_b` axis 2 (`k`), or `t_a` axis 3 vs `t_b` axis 0 (`l`).
fn ndarray<F>(t_a: &Array4<F>, t_b: &Array4<F>) -> Array4<F>
where
    F: Clone + LinalgScalar,
{
    let di = t_a.len_of(Axis(2));
    let dj = t_a.len_of(Axis(1));
    let dk = t_a.len_of(Axis(0));
    let dl = t_a.len_of(Axis(3));
    let dn = t_b.len_of(Axis(1));
    let dm = t_b.len_of(Axis(3));

    // The reshape of B below only checks the total element count. A B whose
    // k and l extents differ from A's (e.g. swapped) could therefore still be
    // reshaped and silently produce a wrong result. Validate each contracted
    // axis explicitly, with the same messages as the hand-written variants.
    if dk != t_b.len_of(Axis(2)) {
        panic!("Dimensions are not matching between t_a[k, _, _, _] and t_b[_, _, k, _]");
    }
    if dl != t_b.len_of(Axis(0)) {
        panic!("Dimensions are not matching between t_a[_, _, _, l] and t_b[l, _, _, _]");
    }

    // --- 1. Permute and reshape both operands to line up the contraction ---

    // A[k, j, i, l] -> A[i, j, k, l], flattened into the matrix A'[I, J] of
    // shape (i*j, k*l), with I = (i, j) and J = (k, l).
    // `as_standard_layout` borrows its receiver, so the permuted view needs a
    // named binding that outlives the reshaped result.
    let a_permuted = t_a.view().permuted_axes([2, 1, 0, 3]);
    let a_reshaped = a_permuted
        .as_standard_layout()
        .into_shape_with_order((di * dj, dk * dl))
        .unwrap();

    // B[l, n, k, m] -> B[k, l, n, m], flattened into the matrix B'[J, K] of
    // shape (k*l, n*m), with J = (k, l) and K = (n, m).
    let b_permuted = t_b.view().permuted_axes([2, 0, 1, 3]);
    let b_reshaped = b_permuted
        .as_standard_layout()
        .into_shape_with_order((dk * dl, dn * dm))
        .unwrap();

    // --- 2. Matrix product performing the contraction over J = (k, l) ---

    // C'[I, K] = SUM_J A'[I, J] * B'[J, K], a matrix of shape (i*j) x (n*m).
    let c_prime: Array2<F> = a_reshaped.dot(&b_reshaped);

    // --- 3. Reshape and permute into the output order result[i, n, m, j] ---

    // 3a. Reshape C' to rank 4 in the order (i, j, n, m).
    let c_temp = c_prime.into_shape_with_order((di, dj, dn, dm)).unwrap();

    // 3b. Permute [i, j, n, m] -> [i, n, m, j]. `permuted_axes` on an owned
    // array already yields an owned array (only the strides are rearranged),
    // so no extra copy is needed.
    c_temp.permuted_axes([0, 2, 3, 1])
}

/// Double contraction
/// `result[i, n, m, j] = SUM_{k, l} (t_a[k, j, i, l] * t_b[l, n, k, m])`
/// expressed with tensorism's `new_ndarray!` macro.
///
/// The outer `for i n m j` presumably lists the output axes in order, matching
/// the `ndarray` reference this is compared against in `benchmark`. The inner
/// `for k l` expression is reduced with `<F as Sum>::sum`.
// NOTE(review): dimension checks and the iteration/memory-access strategy are
// decided by the macro expansion, which is not visible in this file — confirm
// against the tensorism crate when interpreting benchmark numbers.
fn tensorism<F>(t_a: &Array4<F>, t_b: &Array4<F>) -> Array4<F>
where
    F: Copy + Mul<Output = F> + Sum,
{
    new_ndarray!(for i n m j => <F as Sum>::sum(for k l => t_a[k, j, i, l] * t_b[l, n, k, m]))
}

/// Hand-written strided-pointer version of the double contraction
/// `result[i, n, m, j] = SUM_{k, l} (t_a[k, j, i, l] * t_b[l, n, k, m])`.
///
/// Element offsets are computed from each operand's own strides rather than
/// by assuming a standard layout. Each output element is the `Sum` of an
/// iterator over `(k, l)` pairs, with `l` in the outer position of the
/// `flat_map`. Results are written sequentially into an uninitialized output
/// array.
///
/// # Panics
///
/// Panics if `t_a`'s axis 0 (`k`) differs from `t_b`'s axis 2, or if `t_a`'s
/// axis 3 (`l`) differs from `t_b`'s axis 0.
fn flatten<F>(t_a: &Array4<F>, t_b: &Array4<F>) -> Array4<F>
where
    F: Copy + Mul<Output = F> + Sum,
{
    // Contracted axes: k (dim_number_4) and l (dim_number_5) must agree.
    let dim_number_4 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).0;
    if dim_number_4 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).2 {
        panic!("Dimensions are not matching between t_a[k, _, _, _] and t_b[_, _, k, _]");
    }
    let dim_number_5 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).3;
    if dim_number_5 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).0 {
        panic!("Dimensions are not matching between t_a[_, _, _, l] and t_b[l, _, _, _]");
    }
    // Output axes in order: i (t_a axis 2), n (t_b axis 1), m (t_b axis 3),
    // j (t_a axis 1).
    let dim_number_0 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).2;
    let dim_number_1 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).1;
    let dim_number_2 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).3;
    let dim_number_3 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).1;

    // Pointers to element [0, 0, 0, 0] of each operand; every other element
    // is reached by adding index * stride offsets.
    let p_a = t_a.as_ptr();
    let p_b = t_b.as_ptr();
    // Created from a plain tuple shape, i.e. ndarray's default row-major
    // layout, so elements are contiguous in [i, n, m, j] order.
    let mut result = ::ndarray::Array::<F, ::ndarray::Dim<[::ndarray::Ix; 4usize]>>::uninit((
        dim_number_0,
        dim_number_1,
        dim_number_2,
        dim_number_3,
    ));

    // Strides are in elements (not bytes) and may be negative.
    let strides_a = t_a.strides();
    let strides_b = t_b.strides();
    let stride_a0 = strides_a[0];
    let stride_a1 = strides_a[1];
    let stride_a2 = strides_a[2];
    let stride_a3 = strides_a[3];
    let stride_b0 = strides_b[0];
    let stride_b1 = strides_b[1];
    let stride_b2 = strides_b[2];
    let stride_b3 = strides_b[3];

    // Cursor over the output buffer. `MaybeUninit<F>` has the same layout as
    // `F`, so the cast is only a type change.
    let mut ptr_on_result = result.as_mut_ptr() as *mut F;

    // The loop nesting (i, n, m, j) matches the output's row-major axis order,
    // so advancing the cursor by one element per iteration visits every output
    // element exactly once, in memory order.
    for i in 0..dim_number_0 {
        for n in 0..dim_number_1 {
            for m in 0..dim_number_2 {
                for j in 0..dim_number_3 {
                    let value: F = ((0usize..dim_number_5)
                        .flat_map(|l| (0usize..dim_number_4).map(move |k| (k, l)))
                        .map(|(k, l)| {
                            let i_a: isize = (k as isize) * stride_a0
                                + (j as isize) * stride_a1
                                + (i as isize) * stride_a2
                                + (l as isize) * stride_a3;
                            let i_b: isize = (l as isize) * stride_b0
                                + (n as isize) * stride_b1
                                + (k as isize) * stride_b2
                                + (m as isize) * stride_b3;
                            // SAFETY: every index is below the matching axis
                            // length (k, l were checked to agree between t_a
                            // and t_b above), so i_a / i_b address existing
                            // elements of t_a / t_b.
                            (unsafe { *p_a.offset(i_a) }) * (unsafe { *p_b.offset(i_b) })
                        }))
                    .sum::<F>();
                    // SAFETY: the cursor stays within the
                    // dim_number_0 * dim_number_1 * dim_number_2 * dim_number_3
                    // element buffer (see loop-order note above). After the
                    // last write it points one past the end, which is allowed.
                    unsafe {
                        ptr_on_result.write(value);
                        ptr_on_result = ptr_on_result.add(1);
                    }
                }
            }
        }
    }
    // SAFETY: the loops above wrote every element of `result` exactly once.
    unsafe { result.assume_init() }
}

/// Same contraction as `flatten`,
/// `result[i, n, m, j] = SUM_{k, l} (t_a[k, j, i, l] * t_b[l, n, k, m])`,
/// but reduced with explicit nested `k` / `l` loops feeding a `SumAccu`
/// instead of an iterator `Sum`.
///
/// The accumulation order is `k` outer, `l` inner, whereas `flatten` iterates
/// `l` outer. For floating-point types the rounding of the result may
/// therefore differ slightly between the two.
///
/// # Panics
///
/// Panics if `t_a`'s axis 0 (`k`) differs from `t_b`'s axis 2, or if `t_a`'s
/// axis 3 (`l`) differs from `t_b`'s axis 0.
fn flatten_accumulator<F>(t_a: &Array4<F>, t_b: &Array4<F>) -> Array4<F>
where
    F: Copy + Mul<Output = F> + Zero + Add<Output = F>,
{
    // Contracted axes: k (dim_number_4) and l (dim_number_5) must agree.
    let dim_number_4 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).0;
    if dim_number_4 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).2 {
        panic!("Dimensions are not matching between t_a[k, _, _, _] and t_b[_, _, k, _]");
    }
    let dim_number_5 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).3;
    if dim_number_5 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).0 {
        panic!("Dimensions are not matching between t_a[_, _, _, l] and t_b[l, _, _, _]");
    }
    // Output axes in order: i (t_a axis 2), n (t_b axis 1), m (t_b axis 3),
    // j (t_a axis 1).
    let dim_number_0 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).2;
    let dim_number_1 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).1;
    let dim_number_2 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).3;
    let dim_number_3 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).1;

    // Pointers to element [0, 0, 0, 0]; other elements are reached via
    // index * stride offsets.
    let p_a = t_a.as_ptr();
    let p_b = t_b.as_ptr();
    // Plain tuple shape -> ndarray's default row-major layout, so elements are
    // contiguous in [i, n, m, j] order.
    let mut result = ::ndarray::Array::<F, ::ndarray::Dim<[::ndarray::Ix; 4usize]>>::uninit((
        dim_number_0,
        dim_number_1,
        dim_number_2,
        dim_number_3,
    ));

    // Strides are in elements (not bytes) and may be negative.
    let strides_a = t_a.strides();
    let strides_b = t_b.strides();
    let stride_a0 = strides_a[0];
    let stride_a1 = strides_a[1];
    let stride_a2 = strides_a[2];
    let stride_a3 = strides_a[3];
    let stride_b0 = strides_b[0];
    let stride_b1 = strides_b[1];
    let stride_b2 = strides_b[2];
    let stride_b3 = strides_b[3];

    // Cursor over the output buffer (`MaybeUninit<F>` has `F`'s layout).
    let mut ptr_on_result = result.as_mut_ptr() as *mut F;

    // Loop nesting (i, n, m, j) matches the output's row-major axis order, so
    // one cursor step per iteration fills each element once, in memory order.
    for i in 0..dim_number_0 {
        for n in 0..dim_number_1 {
            for m in 0..dim_number_2 {
                for j in 0..dim_number_3 {
                    let mut accu = SumAccu::<F>::new();
                    for k in 0..dim_number_4 {
                        for l in 0..dim_number_5 {
                            let i_a: isize = (k as isize) * stride_a0
                                + (j as isize) * stride_a1
                                + (i as isize) * stride_a2
                                + (l as isize) * stride_a3;
                            let i_b: isize = (l as isize) * stride_b0
                                + (n as isize) * stride_b1
                                + (k as isize) * stride_b2
                                + (m as isize) * stride_b3;
                            // SAFETY: all indices are below their axis lengths
                            // (k, l validated above), so both offsets address
                            // existing elements of t_a / t_b.
                            let value =
                                (unsafe { *p_a.offset(i_a) }) * (unsafe { *p_b.offset(i_b) });
                            accu.accumulate(value);
                        }
                    }
                    // SAFETY: the cursor stays within the output buffer (see
                    // loop-order note above); one-past-the-end after the last
                    // write is allowed.
                    unsafe {
                        ptr_on_result.write(accu.get());
                        ptr_on_result = ptr_on_result.add(1);
                    }
                }
            }
        }
    }
    // SAFETY: every element of `result` was written exactly once above.
    unsafe { result.assume_init() }
}

/// Same contraction as `flatten`,
/// `result[i, n, m, j] = SUM_{k, l} (t_a[k, j, i, l] * t_b[l, n, k, m])`,
/// with the offset terms that depend only on the output indices hoisted out
/// of the inner reduction:
/// - `k_a`, from `(i, j)` for `t_a`;
/// - `k_b`, from `(n, m)` for `t_b`.
///
/// Reduction order is identical to `flatten` (`l` outer, `k` inner).
///
/// # Panics
///
/// Panics if `t_a`'s axis 0 (`k`) differs from `t_b`'s axis 2, or if `t_a`'s
/// axis 3 (`l`) differs from `t_b`'s axis 0.
fn flatten_factorized<F>(t_a: &Array4<F>, t_b: &Array4<F>) -> Array4<F>
where
    F: Copy + Mul<Output = F> + Sum,
{
    // Contracted axes: k (dim_number_4) and l (dim_number_5) must agree.
    let dim_number_4 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).0;
    if dim_number_4 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).2 {
        panic!("Dimensions are not matching between t_a[k, _, _, _] and t_b[_, _, k, _]");
    }
    let dim_number_5 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).3;
    if dim_number_5 != ::ndarray::ArrayBase::<_, _>::dim(&t_b).0 {
        panic!("Dimensions are not matching between t_a[_, _, _, l] and t_b[l, _, _, _]");
    }
    // Output axes in order: i (t_a axis 2), n (t_b axis 1), m (t_b axis 3),
    // j (t_a axis 1).
    let dim_number_0 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).2;
    let dim_number_1 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).1;
    let dim_number_2 = ::ndarray::ArrayBase::<_, _>::dim(&t_b).3;
    let dim_number_3 = ::ndarray::ArrayBase::<_, _>::dim(&t_a).1;

    // Pointers to element [0, 0, 0, 0]; other elements are reached via
    // index * stride offsets.
    let p_a = t_a.as_ptr();
    let p_b = t_b.as_ptr();
    // Plain tuple shape -> ndarray's default row-major layout, so elements are
    // contiguous in [i, n, m, j] order.
    let mut result = ::ndarray::Array::<F, ::ndarray::Dim<[::ndarray::Ix; 4usize]>>::uninit((
        dim_number_0,
        dim_number_1,
        dim_number_2,
        dim_number_3,
    ));
    // Strides are in elements (not bytes) and may be negative.
    let strides_a = t_a.strides();
    let strides_b = t_b.strides();
    let stride_a0 = strides_a[0];
    let stride_a1 = strides_a[1];
    let stride_a2 = strides_a[2];
    let stride_a3 = strides_a[3];
    let stride_b0 = strides_b[0];
    let stride_b1 = strides_b[1];
    let stride_b2 = strides_b[2];
    let stride_b3 = strides_b[3];
    // Cursor over the output buffer (`MaybeUninit<F>` has `F`'s layout).
    let mut ptr_on_result = result.as_mut_ptr() as *mut F;

    // Loop nesting (i, n, m, j) matches the output's row-major axis order, so
    // one cursor step per iteration fills each element once, in memory order.
    for i in 0..dim_number_0 {
        for n in 0..dim_number_1 {
            for m in 0..dim_number_2 {
                for j in 0..dim_number_3 {
                    // Offset contributions that are constant over (k, l).
                    let k_a = (j as isize) * stride_a1 + (i as isize) * stride_a2;
                    let k_b = (n as isize) * stride_b1 + (m as isize) * stride_b3;
                    let value: F = ((0usize..dim_number_5)
                        .flat_map(|l| (0usize..dim_number_4).map(move |k| (k, l)))
                        .map(|(k, l)| {
                            let i_a: isize =
                                (k as isize) * stride_a0 + k_a + (l as isize) * stride_a3;
                            let i_b: isize =
                                (l as isize) * stride_b0 + k_b + (k as isize) * stride_b2;
                            // SAFETY: all indices are below their axis lengths
                            // (k, l validated above), so both offsets address
                            // existing elements of t_a / t_b.
                            (unsafe { *p_a.offset(i_a) }) * (unsafe { *p_b.offset(i_b) })
                        }))
                    .sum::<F>();
                    // SAFETY: the cursor stays within the output buffer (see
                    // loop-order note above); one-past-the-end after the last
                    // write is allowed.
                    unsafe {
                        ptr_on_result.write(value);
                        ptr_on_result = ptr_on_result.add(1);
                    }
                }
            }
        }
    }
    // SAFETY: every element of `result` was written exactly once above.
    unsafe { result.assume_init() }
}

/// Benchmarks the double contraction
/// `result[i, n, m, j] = SUM_{k, l} (A[k, j, i, l] * B[l, n, k, m])`
/// for every implementation in this file, on `f64` and on `C64` inputs.
///
/// Before any timing, each implementation's output is compared with the
/// `ndarray` reference via `are_close` (tolerance `0.001`). Every variant that
/// is benchmarked is therefore cross-checked for both scalar types.
// NOTE(review): the value returned by `are_close` (if any) is discarded. If it
// reports a mismatch through its return value instead of panicking, these
// checks are no-ops — confirm against `utils::are_close`.
fn benchmark(c: &mut Criterion) {
    let (di, dj, dk, dl, dn, dm) = (4, 6, 8, 3, 12, 11);

    // A is indexed as A[k, j, i, l]. The closure arguments below are only the
    // positional indices of `shape1`, used to generate deterministic data.
    let shape1 = (dk, dj, di, dl);
    let tensor1_f = Array4::<f64>::from_shape_fn(shape1, |(i, j, k, l)| {
        let i = (i as f64) * PI;
        let j = (j as f64) * PI;
        let k = (k as f64) * PI;
        let l = (l as f64) * PI;
        (i * k + 3.0).sin() * 0.9 - (j / 2.0 + 1.44).sin() * 1.5
            + (k / 3.0 + 4.9 - j).cos() * 2.7
            + (l + i / 2.0).sin() * 0.5
    });
    let tensor1_c = Array4::<C64>::from_shape_fn(shape1, |(i, j, k, l)| {
        let i = C64::from_usize(i) * PI;
        let j = C64::from_usize(j) * PI;
        let k = C64::from_usize(k) * PI + C64::I;
        let l = C64::from_usize(l) * PI;
        (i * k + 3.0).sin() * 0.9 - (j / 2.0 + 1.44).sin() * 1.5
            + (k / 3.0 + 4.9 - j).cos() * 2.7
            + (l + i / 2.0).sin() * 0.5
    });

    // B is indexed as B[l, n, k, m].
    let shape2 = (dl, dn, dk, dm);
    let tensor2_f = Array4::<f64>::from_shape_fn(shape2, |(i, j, k, l)| {
        let i = (i as f64) * PI;
        let j = (j as f64) * PI;
        let k = (k as f64) * PI;
        let l = (l as f64) * PI;
        (j * i + 2.9).sin() * 1.1 - (k / 1.9 + 1.54).sin() * 1.6
            + (j / 3.1 + 4.8 - k).cos() * 2.6
            + (l - i / 3.0).cos() * 0.4
    });
    let tensor2_c = Array4::<C64>::from_shape_fn(shape2, |(i, j, k, l)| {
        let i = C64::from_usize(i) * PI;
        let j = C64::from_usize(j) * PI + C64::I;
        let k = C64::from_usize(k) * PI;
        let l = C64::from_usize(l) * PI;
        C64::from_f64(1.1) * (j * i + 2.9).sin() - (k / 1.9 + 1.54).sin() * 1.6
            + (j / 3.1 + 4.8 - k).cos() * 2.6
            + (l - i / 3.0).cos() * 0.4
    });

    // Cross-check every benchmarked f64 implementation against the reference.
    let result_ndarray = ndarray(&tensor1_f, &tensor2_f);
    let result_tensorism = tensorism(&tensor1_f, &tensor2_f);
    let result_flatten = flatten(&tensor1_f, &tensor2_f);
    let result_flatten_factorized = flatten_factorized(&tensor1_f, &tensor2_f);
    let result_flatten_accu = flatten_accumulator(&tensor1_f, &tensor2_f);
    are_close(&result_ndarray, &result_tensorism, 0.001f64);
    are_close(&result_ndarray, &result_flatten, 0.001f64);
    are_close(&result_ndarray, &result_flatten_factorized, 0.001f64);
    are_close(&result_ndarray, &result_flatten_accu, 0.001f64);

    // Cross-check every benchmarked C64 implementation against the reference.
    let result_ndarray = ndarray(&tensor1_c, &tensor2_c);
    let result_tensorism = tensorism(&tensor1_c, &tensor2_c);
    let result_flatten = flatten(&tensor1_c, &tensor2_c);
    let result_flatten_factorized = flatten_factorized(&tensor1_c, &tensor2_c);
    let result_flatten_accu = flatten_accumulator(&tensor1_c, &tensor2_c);
    are_close(&result_ndarray, &result_tensorism, 0.001f64);
    are_close(&result_ndarray, &result_flatten, 0.001f64);
    are_close(&result_ndarray, &result_flatten_factorized, 0.001f64);
    are_close(&result_ndarray, &result_flatten_accu, 0.001f64);

    c.bench_function("Double contraction f64 - Ndarray", |b| {
        b.iter(|| ndarray(black_box(&tensor1_f), black_box(&tensor2_f)))
    });
    c.bench_function("Double contraction f64 - Tensorism", |b| {
        b.iter(|| tensorism(black_box(&tensor1_f), black_box(&tensor2_f)))
    });
    c.bench_function("Double contraction f64 - Flatten", |b| {
        b.iter(|| flatten(black_box(&tensor1_f), black_box(&tensor2_f)))
    });
    c.bench_function("Double contraction f64 - Flatten factorized", |b| {
        b.iter(|| flatten_factorized(black_box(&tensor1_f), black_box(&tensor2_f)))
    });
    c.bench_function("Double contraction f64 - Flatten accumulator", |b| {
        b.iter(|| flatten_accumulator(black_box(&tensor1_f), black_box(&tensor2_f)))
    });
    c.bench_function("Double contraction C64 - Ndarray", |b| {
        b.iter(|| ndarray(black_box(&tensor1_c), black_box(&tensor2_c)))
    });
    c.bench_function("Double contraction C64 - Tensorism", |b| {
        b.iter(|| tensorism(black_box(&tensor1_c), black_box(&tensor2_c)))
    });
    c.bench_function("Double contraction C64 - Flatten", |b| {
        b.iter(|| flatten(black_box(&tensor1_c), black_box(&tensor2_c)))
    });
    c.bench_function("Double contraction C64 - Flatten factorized", |b| {
        b.iter(|| flatten_factorized(black_box(&tensor1_c), black_box(&tensor2_c)))
    });
    c.bench_function("Double contraction C64 - Flatten accumulator", |b| {
        b.iter(|| flatten_accumulator(black_box(&tensor1_c), black_box(&tensor2_c)))
    });
}

// Register `benchmark` in a group named `benches` using criterion's default
// configuration, then generate the bench binary's `main` for that group.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets = benchmark
}
criterion_main!(benches);