// redstone_ml/tensor/matrix_ops.rs
1use crate::bmm_backwards::BMMBackwards;
2use crate::dot_backwards::DotBackwards;
3use crate::matrix_product_backwards::MatrixProductBackwards;
4use crate::matrix_vec_backwards::MatrixVecBackwards;
5use crate::none_backwards::NoneBackwards;
6use crate::{StridedMemory, Tensor, TensorDataType};
7
8impl<'a, T: TensorDataType> Tensor<'a, T> {
9 /// Calculates the dot product of two 1D tensors.
10 ///
11 /// # Panics
12 /// - Panics if either tensor is not 1D
13 /// - Panics if the lengths of the two tensors are not equal
14 ///
15 /// # Examples
16 /// ```
17 /// # use redstone_ml::*;
18 /// let tensor1 = Tensor::new([1.0, 2.0, 3.0]);
19 /// let tensor2 = Tensor::new([4.0, 5.0, 6.0]);
20 /// let result = tensor1.dot(tensor2);
21 /// assert_eq!(result.value(), 32.0); // 1*4 + 2*5 + 3*6 = 32
22 /// ```
23 pub fn dot<'b, 'r>(&self, other: impl AsRef<Tensor<'b, T>>) -> Tensor<'r, T> {
24 let other = other.as_ref();
25
26 let requires_grad = self.requires_grad() || other.requires_grad();
27 let grad_fn = if requires_grad { DotBackwards::new(self, other) } else { NoneBackwards::new() };
28
29 unsafe { Tensor::from_raw_parts(self.array.dot(&other.array), requires_grad, grad_fn) }
30 }
31
32 /// Calculates the matrix product of two tensors.
33 ///
34 /// - If both tensors are 1D, then their dot product is returned.
35 /// - If both tensors are 2D, then their matrix product is returned.
36 /// - If the first tensor is 2D and the second tensor is 1D, then the matrix-vector product is returned.
37 ///
38 /// # Panics
39 /// - If the dimensions/shape of the tensors are incompatible
40 ///
41 /// # Example
42 /// ```
43 /// # use redstone_ml::*;
44 ///
45 /// let a = Tensor::new(vec![
46 /// [1.0, 2.0, 3.0],
47 /// [4.0, 5.0, 6.0],
48 /// ]);
49 ///
50 /// let b = Tensor::new(vec![
51 /// [7.0, 8.0],
52 /// [9.0, 10.0],
53 /// [11.0, 12.0],
54 /// ]);
55 ///
56 /// let result = a.matmul(&b);
57 /// assert_eq!(result, Tensor::new([
58 /// [58.0, 64.0],
59 /// [139.0, 154.0],
60 /// ]));
61 /// ```
62 pub fn matmul<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
63 let other = other.as_ref();
64
65 if self.ndims() == 1 && other.ndims() == 1 {
66 return self.dot(other);
67 }
68
69 let requires_grad = self.requires_grad() || other.requires_grad();
70 let result = self.array.matmul(&other.array);
71
72 let grad_fn = if requires_grad {
73 if self.ndims() == 2 && other.ndims() == 1 {
74 MatrixVecBackwards::new(self, other)
75 } else if self.ndims() == 2 && other.ndims() == 2 {
76 MatrixProductBackwards::new(self, other)
77 } else {
78 panic!("this should never happen")
79 }
80 } else { NoneBackwards::new() };
81
82 unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
83 }
84
85 /// Performs batch matrix multiplication on 3D tensors.
86 ///
87 /// The shape of the resulting ndarray will be `[batch_size, self.shape()[1], other.shape()[2]]`,
88 /// where `batch_size` is the shared first dimension of both input tensors.
89 ///
90 /// # Panics
91 /// - If either tensor is not 3D
92 /// - If the tensors do not have dimensions compatible for batch matrix multiplication.
93 ///
94 /// # Example
95 /// ```
96 /// # use redstone_ml::*;
97 ///
98 /// let arr1 = Tensor::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
99 /// let arr2 = Tensor::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
100 /// let result = arr1.bmm(&arr2);
101 /// assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
102 /// ```
103 pub fn bmm<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
104 let other = other.as_ref();
105
106 let requires_grad = self.requires_grad() || other.requires_grad();
107 let grad_fn = if requires_grad { BMMBackwards::new(self, other) } else { NoneBackwards::new() };
108
109 unsafe { Tensor::from_raw_parts(self.array.bmm(&other.array), requires_grad, grad_fn) }
110 }
111}