1use crate::prelude::{Scalar, TensorExpr, TensorResult};
6use crate::shape::{ShapeError, Stride};
7use crate::tensor::{from_vec_with_op, TensorBase};
8
9pub(crate) fn coordinates_to_index<Idx>(coords: Idx, strides: &Stride) -> usize
10where
11 Idx: AsRef<[usize]>,
12{
13 coords
14 .as_ref()
15 .iter()
16 .zip(strides.iter())
17 .fold(0, |acc, (&i, &s)| acc + i * s)
18}
19
20pub fn matmul<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>) -> TensorResult<TensorBase<T>>
21where
22 T: Scalar,
23{
24 if lhs.shape().rank() != rhs.shape().rank() {
25 return Err(ShapeError::DimensionMismatch.into());
26 }
27
28 let shape = lhs.shape().matmul_shape(rhs.shape()).unwrap();
29 let mut result = vec![T::zero(); shape.size()];
30
31 for i in 0..lhs.shape().nrows() {
32 for j in 0..rhs.shape().ncols() {
33 for k in 0..lhs.shape().ncols() {
34 let pos = i * rhs.shape().ncols() + j;
35 let left = i * lhs.shape().ncols() + k;
36 let right = k * rhs.shape().ncols() + j;
37 result[pos] += lhs.data[left] * rhs.data[right];
38 }
39 }
40 }
41 let op = TensorExpr::matmul(lhs.clone(), rhs.clone());
42 let tensor = from_vec_with_op(false, op, shape, result);
43 Ok(tensor)
44}
45
/// Builds a `Vec` from a comma-separated list of expressions; shorthand for `vec![...]`,
/// typically used for index/coordinate lists.
///
/// Like `vec!`, it accepts an empty invocation (`i![]`) and an optional trailing comma
/// (`i![1, 2,]`).
macro_rules! i {
    () => {
        vec![]
    };
    ($($x:expr),+ $(,)?) => {
        vec![$($x),+]
    };
}
52
/// Implements `PartialEq` in both directions between a type and other types, by
/// comparing against a single field of that type.
///
/// Forms:
/// - `impl_partial_eq!(Target -> field: [A, B, ...])` applies the single form to each
///   listed type.
/// - `impl_partial_eq!(Target -> field, A)` generates `impl PartialEq<A> for Target`
///   (evaluating `self.field == *other`) and `impl PartialEq<Target> for A`
///   (evaluating `*self == other.field`).
///
/// `field` is matched as a single token tree, e.g. a field name.
macro_rules! impl_partial_eq {
    ($target:ident -> $field:tt: [$($other:ty),*]) => {
        $( impl_partial_eq!($target -> $field, $other); )*
    };
    ($target:ident -> $field:tt, $other:ty) => {
        // `Target == Other`: delegate to the selected field.
        impl PartialEq<$other> for $target {
            fn eq(&self, rhs: &$other) -> bool {
                self.$field == *rhs
            }
        }

        // `Other == Target`: the mirrored comparison.
        impl PartialEq<$target> for $other {
            fn eq(&self, rhs: &$target) -> bool {
                *self == rhs.$field
            }
        }
    };
}
73
/// Zips any number of iterables into a single iterator of flat tuples.
///
/// - `izip!(a)` is `IntoIterator::into_iter(a)`.
/// - `izip!(a, b)` is `a.into_iter().zip(b)`, yielding `(x, y)`.
/// - `izip!(a, b, c, ...)` chains `.zip(...)` calls and then `.map(...)`s the nested
///   tuples `((x, y), z)` into flat tuples `(x, y, z)`.
///
/// The internal `@closure` arms build that flattening closure; they are not meant to be
/// invoked directly. Trailing commas are accepted in the public forms.
macro_rules! izip {
    // Base case of the `@closure` helper: emit the finished flattening closure.
    ( @closure $p:pat => $tup:expr ) => {
        |$p| $tup
    };

    // Recursive `@closure` step: for each remaining iterator, wrap the pattern one level
    // deeper (`($p, b)`) and append `b` to the flat output tuple. The iterator expression
    // itself is only used for counting. The identifier `b` is distinct on each recursion
    // level thanks to macro hygiene, so the bindings do not clash.
    ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
        izip!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
    };

    // Unary: just turn the argument into an iterator.
    ($first:expr $(,)*) => {
        IntoIterator::into_iter($first)
    };

    // Binary: a plain `zip` already yields flat pairs, so no `map` is needed.
    ($first:expr, $second:expr $(,)*) => {
        izip!($first)
            .zip($second)
    };

    // N-ary (n > 2): zip everything, then flatten the nested tuples with the
    // closure generated by the `@closure` arms.
    ( $first:expr $( , $rest:expr )* $(,)* ) => {
        izip!($first)
            $(
                .zip($rest)
            )*
            .map(
                izip!(@closure a => (a) $( , $rest )*)
            )
    };
}