1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/*
    Appellation: utils <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
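//! # Utilities
//!
//! Free-standing helpers that implement tensor operations, such as matrix
//! multiplication, on top of `TensorBase`.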
use crate::prelude::{Scalar, TensorOp, TensorResult};
use crate::shape::ShapeError;
use crate::tensor::{from_vec_with_op, TensorBase};

/// Multiplies `lhs` by `rhs` as row-major matrices and records the operation.
///
/// Element `(i, j)` of the result is `sum_k lhs[i, k] * rhs[k, j]`. The
/// elements are read directly from each operand's backing `store` using
/// row-major offsets. The result shape is the one produced by
/// `Shape::matmul_shape`. The returned tensor carries a `TensorOp::matmul`
/// node that holds clones of both operands.
///
/// # Errors
///
/// Returns `ShapeError::IncompatibleShapes` (converted into the tensor error
/// type) when the operands have different ranks. It also returns that error
/// when the column count of `lhs` differs from the row count of `rhs`. Any
/// error reported by `matmul_shape` is propagated instead of panicking.
pub fn matmul<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>) -> TensorResult<TensorBase<T>>
where
    T: Scalar,
{
    let (lshape, rshape) = (lhs.shape(), rhs.shape());
    if lshape.rank() != rshape.rank() {
        return Err(ShapeError::IncompatibleShapes.into());
    }
    // The kernel below reads `rhs.store[k * cols + j]` for every
    // `k < lhs.ncols()`. If the inner dimensions disagree, it reads out of
    // bounds or pairs the wrong elements, so reject that case up front.
    if lshape.ncols() != rshape.nrows() {
        return Err(ShapeError::IncompatibleShapes.into());
    }

    // A shape mismatch is a recoverable error, so propagate it instead of unwrapping.
    let shape = lshape.matmul_shape(&rshape)?;

    // Hoist the loop-invariant dimensions out of the triple loop.
    let (rows, inner, cols) = (lshape.nrows(), lshape.ncols(), rshape.ncols());
    let mut result = vec![T::zero(); shape.size()];

    for i in 0..rows {
        for j in 0..cols {
            // Accumulate locally, in the same k order as before, so the
            // floating-point summation order is unchanged.
            let mut acc = T::zero();
            for k in 0..inner {
                acc += lhs.store[i * inner + k] * rhs.store[k * cols + j];
            }
            result[i * cols + j] = acc;
        }
    }

    let op = TensorOp::matmul(lhs.clone(), rhs.clone());
    Ok(from_vec_with_op(false, op, shape, result))
}

pub fn dot_product<T>(lhs: &TensorBase<T>, rhs: &TensorBase<T>) -> TensorResult<TensorBase<T>>
where
    T: Scalar,
{
    if lhs.shape().rank() != rhs.shape().rank() {
        return Err(ShapeError::IncompatibleShapes.into());
    }

    let shape = lhs.shape().matmul_shape(&rhs.shape()).unwrap();
    let mut result = vec![T::zero(); shape.size()];

    for i in 0..lhs.shape().nrows() {
        for j in 0..rhs.shape().ncols() {
            for k in 0..lhs.shape().ncols() {
                let pos = i * rhs.shape().ncols() + j;
                let left = i * lhs.shape().ncols() + k;
                let right = k * rhs.shape().ncols() + j;
                result[pos] += lhs.store[left] * rhs.store[right];
            }
        }
    }
    let op = TensorOp::matmul(lhs.clone(), rhs.clone());
    let tensor = from_vec_with_op(false, op, shape, result);
    Ok(tensor)
}