acme_tensor/
utils.rs

1/*
2    Appellation: utils <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::prelude::{Scalar, TensorExpr, TensorResult};
6use crate::shape::{ShapeError, Stride};
7use crate::tensor::{from_vec_with_op, TensorBase};
8
9pub(crate) fn coordinates_to_index<Idx>(coords: Idx, strides: &Stride) -> usize
10where
11    Idx: AsRef<[usize]>,
12{
13    coords
14        .as_ref()
15        .iter()
16        .zip(strides.iter())
17        .fold(0, |acc, (&i, &s)| acc + i * s)
18}
19
20pub fn matmul<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>) -> TensorResult<TensorBase<T>>
21where
22    T: Scalar,
23{
24    if lhs.shape().rank() != rhs.shape().rank() {
25        return Err(ShapeError::DimensionMismatch.into());
26    }
27
28    let shape = lhs.shape().matmul_shape(rhs.shape()).unwrap();
29    let mut result = vec![T::zero(); shape.size()];
30
31    for i in 0..lhs.shape().nrows() {
32        for j in 0..rhs.shape().ncols() {
33            for k in 0..lhs.shape().ncols() {
34                let pos = i * rhs.shape().ncols() + j;
35                let left = i * lhs.shape().ncols() + k;
36                let right = k * rhs.shape().ncols() + j;
37                result[pos] += lhs.data[left] * rhs.data[right];
38            }
39        }
40    }
41    let op = TensorExpr::matmul(lhs.clone(), rhs.clone());
42    let tensor = from_vec_with_op(false, op, shape, result);
43    Ok(tensor)
44}
45
/// Shorthand for `vec![...]`: builds a `Vec` from a comma-separated list of
/// expressions.
///
/// Like `vec!`, it accepts an optional trailing comma, e.g. `i![1, 2, 3,]`.
macro_rules! i {
    ($($x:expr),* $(,)?) => {
        vec![$($x),*]
    };
}
52
53macro_rules! impl_partial_eq {
54    ($s:ident -> $cmp:tt: [$($t:ty),*]) => {
55        $(
56            impl_partial_eq!($s -> $cmp, $t);
57        )*
58    };
59    ($s:ident -> $cmp:tt, $t:ty) => {
60        impl PartialEq<$t> for $s {
61            fn eq(&self, other: &$t) -> bool {
62                self.$cmp == *other
63            }
64        }
65
66        impl PartialEq<$s> for $t {
67            fn eq(&self, other: &$s) -> bool {
68                *self == other.$cmp
69            }
70        }
71    };
72}
73
/// Zips any number of iterables into one iterator of flat tuples.
///
/// - `izip!(a)` expands to `IntoIterator::into_iter(a)`.
/// - `izip!(a, b)` expands to `into_iter(a).zip(b)`, yielding `(a_item, b_item)`.
/// - `izip!(a, b, c, ...)` chains one `.zip` per extra iterable. It then
///   `.map`s the nested `((a, b), c)` tuples into flat `(a, b, c)` tuples,
///   using a closure generated by the internal `@closure` arms.
///
/// Because it is built on `Iterator::zip`, iteration ends when the shortest
/// input is exhausted. Trailing commas are accepted.
///
/// `IntoIterator` is named unqualified (not via `::core`/`$crate`), so it is
/// resolved in the invoking scope, normally from the prelude.
macro_rules! izip {
    // `@closure` builds the tuple-flattening closure for the `.map()` call.
    // Usage:
    //   @closure partial_pattern => partial_tuple , rest , of , iterators
    // e.g. izip!( @closure ((a, b), c) => (a, b, c) , dd , ee )
    // Base case: no iterators left, so emit the closure itself.
    ( @closure $p:pat => $tup:expr ) => {
        |$p| $tup
    };

    // The "b" identifier is a different identifier on each recursion level thanks to hygiene.
    // Each remaining iterator expression is consumed only to count recursion
    // depth: `$_iter` is never expanded, so the iterators are not evaluated here.
    ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
        izip!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
    };

    // unary
    ($first:expr $(,)*) => {
        IntoIterator::into_iter($first)
    };

    // binary: `zip` already yields a flat 2-tuple, so no `.map` is needed.
    ($first:expr, $second:expr $(,)*) => {
        izip!($first)
            .zip($second)
    };

    // n-ary where n > 2
    // `$rest` appears twice. It is evaluated only in the `.zip($rest)` calls;
    // in the `@closure` call it is used purely for counting (see above).
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        izip!($first)
            $(
                .zip($rest)
            )*
            .map(
                izip!(@closure a => (a) $( , $rest )*)
            )
    };
}