// redstone_ml/tensor/matrix_ops.rs — dot / matmul / bmm methods for Tensor

1use crate::bmm_backwards::BMMBackwards;
2use crate::dot_backwards::DotBackwards;
3use crate::matrix_product_backwards::MatrixProductBackwards;
4use crate::matrix_vec_backwards::MatrixVecBackwards;
5use crate::none_backwards::NoneBackwards;
6use crate::{StridedMemory, Tensor, TensorDataType};
7
8impl<'a, T: TensorDataType> Tensor<'a, T> {
9    /// Calculates the dot product of two 1D tensors.
10    ///
11    /// # Panics
12    /// - Panics if either tensor is not 1D
13    /// - Panics if the lengths of the two tensors are not equal
14    ///
15    /// # Examples
16    /// ```
17    /// # use redstone_ml::*;
18    /// let tensor1 = Tensor::new([1.0, 2.0, 3.0]);
19    /// let tensor2 = Tensor::new([4.0, 5.0, 6.0]);
20    /// let result = tensor1.dot(tensor2);
21    /// assert_eq!(result.value(), 32.0); // 1*4 + 2*5 + 3*6 = 32
22    /// ```
23    pub fn dot<'b, 'r>(&self, other: impl AsRef<Tensor<'b, T>>) -> Tensor<'r, T> {
24        let other = other.as_ref();
25
26        let requires_grad = self.requires_grad() || other.requires_grad();
27        let grad_fn = if requires_grad { DotBackwards::new(self, other) } else { NoneBackwards::new() };
28
29        unsafe { Tensor::from_raw_parts(self.array.dot(&other.array), requires_grad, grad_fn) }
30    }
31
32    /// Calculates the matrix product of two tensors.
33    ///
34    /// - If both tensors are 1D, then their dot product is returned.
35    /// - If both tensors are 2D, then their matrix product is returned.
36    /// - If the first tensor is 2D and the second tensor is 1D, then the matrix-vector product is returned.
37    ///
38    /// # Panics
39    /// - If the dimensions/shape of the tensors are incompatible
40    ///
41    /// # Example
42    /// ```
43    /// # use redstone_ml::*;
44    ///
45    /// let a = Tensor::new(vec![
46    ///     [1.0, 2.0, 3.0],
47    ///     [4.0, 5.0, 6.0],
48    /// ]);
49    ///
50    /// let b = Tensor::new(vec![
51    ///     [7.0, 8.0],
52    ///     [9.0, 10.0],
53    ///     [11.0, 12.0],
54    /// ]);
55    ///
56    /// let result = a.matmul(&b);
57    /// assert_eq!(result, Tensor::new([
58    ///     [58.0, 64.0],
59    ///     [139.0, 154.0],
60    /// ]));
61    /// ```
62    pub fn matmul<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
63        let other = other.as_ref();
64
65        if self.ndims() == 1 && other.ndims() == 1 {
66            return self.dot(other);
67        }
68
69        let requires_grad = self.requires_grad() || other.requires_grad();
70        let result = self.array.matmul(&other.array);
71
72        let grad_fn = if requires_grad {
73            if self.ndims() == 2 && other.ndims() == 1 {
74                MatrixVecBackwards::new(self, other)
75            } else if self.ndims() == 2 && other.ndims() == 2 {
76                MatrixProductBackwards::new(self, other)
77            } else {
78                panic!("this should never happen")
79            }
80        } else { NoneBackwards::new() };
81
82        unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
83    }
84
85    /// Performs batch matrix multiplication on 3D tensors.
86    ///
87    /// The shape of the resulting ndarray will be `[batch_size, self.shape()[1], other.shape()[2]]`,
88    /// where `batch_size` is the shared first dimension of both input tensors.
89    ///
90    /// # Panics
91    /// - If either tensor is not 3D
92    /// - If the tensors do not have dimensions compatible for batch matrix multiplication.
93    ///
94    /// # Example
95    /// ```
96    /// # use redstone_ml::*;
97    ///
98    /// let arr1 = Tensor::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
99    /// let arr2 = Tensor::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
100    /// let result = arr1.bmm(&arr2);
101    /// assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
102    /// ```
103    pub fn bmm<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
104        let other = other.as_ref();
105        
106        let requires_grad = self.requires_grad() || other.requires_grad();
107        let grad_fn = if requires_grad { BMMBackwards::new(self, other) } else { NoneBackwards::new() };
108
109        unsafe { Tensor::from_raw_parts(self.array.bmm(&other.array), requires_grad, grad_fn) }
110    }
111}