1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use crate::bmm_backwards::BMMBackwards;
use crate::dot_backwards::DotBackwards;
use crate::matrix_product_backwards::MatrixProductBackwards;
use crate::matrix_vec_backwards::MatrixVecBackwards;
use crate::none_backwards::NoneBackwards;
use crate::{StridedMemory, Tensor, TensorDataType};
impl<'a, T: TensorDataType> Tensor<'a, T> {
    /// Computes the dot product of two 1D tensors.
    ///
    /// If either operand requires gradients, the result is created with gradient
    /// tracking enabled and a `DotBackwards` gradient function attached; otherwise
    /// a `NoneBackwards` placeholder is attached.
    ///
    /// # Panics
    /// - If either tensor is not 1D
    /// - If the two tensors do not have the same length
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let tensor1 = Tensor::new([1.0, 2.0, 3.0]);
    /// let tensor2 = Tensor::new([4.0, 5.0, 6.0]);
    /// let result = tensor1.dot(tensor2);
    /// assert_eq!(result.value(), 32.0); // 1*4 + 2*5 + 3*6 = 32
    /// ```
    pub fn dot<'b, 'r>(&self, other: impl AsRef<Tensor<'b, T>>) -> Tensor<'r, T> {
        let rhs = other.as_ref();
        let track_grad = self.requires_grad() || rhs.requires_grad();

        // The gradient function is built before the product is evaluated.
        let backward = if track_grad {
            DotBackwards::new(self, rhs)
        } else {
            NoneBackwards::new()
        };

        // NOTE(review): the safety contract of `from_raw_parts` is not visible in
        // this file — confirm that a freshly computed array paired with the
        // matching gradient function satisfies it.
        unsafe {
            let product = self.array.dot(&rhs.array);
            Tensor::from_raw_parts(product, track_grad, backward)
        }
    }

    /// Computes the matrix product of two tensors.
    ///
    /// - 1D with 1D: delegates to [`Tensor::dot`] and returns the dot product.
    /// - 2D with 2D: returns the matrix product.
    /// - 2D with 1D: returns the matrix-vector product.
    ///
    /// When gradients are required, the 2D-1D case attaches a `MatrixVecBackwards`
    /// gradient function and the 2D-2D case attaches a `MatrixProductBackwards`.
    ///
    /// # Panics
    /// - If the dimensions/shape of the tensors are incompatible
    /// - If gradients are required and the operands are neither 2D-1D nor 2D-2D
    ///   (only reachable when the underlying array product itself did not panic)
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let a = Tensor::new(vec![
    ///     [1.0, 2.0, 3.0],
    ///     [4.0, 5.0, 6.0],
    /// ]);
    ///
    /// let b = Tensor::new(vec![
    ///     [7.0, 8.0],
    ///     [9.0, 10.0],
    ///     [11.0, 12.0],
    /// ]);
    ///
    /// let result = a.matmul(&b);
    /// assert_eq!(result, Tensor::new([
    ///     [58.0, 64.0],
    ///     [139.0, 154.0],
    /// ]));
    /// ```
    pub fn matmul<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
        let rhs = other.as_ref();
        let ranks = (self.ndims(), rhs.ndims());

        // Vector-vector products are handled entirely by `dot`.
        if matches!(ranks, (1, 1)) {
            return self.dot(rhs);
        }

        let track_grad = self.requires_grad() || rhs.requires_grad();

        // The product is computed first, so shape errors raised by the array
        // layer surface before the rank dispatch below.
        let product = self.array.matmul(&rhs.array);

        let backward = if track_grad {
            match ranks {
                (2, 1) => MatrixVecBackwards::new(self, rhs),
                (2, 2) => MatrixProductBackwards::new(self, rhs),
                _ => panic!("this should never happen"),
            }
        } else {
            NoneBackwards::new()
        };

        // NOTE(review): the safety contract of `from_raw_parts` is not visible in
        // this file — confirm that a freshly computed array paired with the
        // matching gradient function satisfies it.
        unsafe { Tensor::from_raw_parts(product, track_grad, backward) }
    }

    /// Performs batch matrix multiplication on 3D tensors.
    ///
    /// The resulting tensor has shape `[batch_size, self.shape()[1], other.shape()[2]]`,
    /// where `batch_size` is the first dimension shared by both inputs.
    ///
    /// If either operand requires gradients, a `BMMBackwards` gradient function is
    /// attached to the result; otherwise a `NoneBackwards` placeholder is attached.
    ///
    /// # Panics
    /// - If either tensor is not 3D
    /// - If the tensors do not have dimensions compatible for batch matrix multiplication
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let arr1 = Tensor::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
    /// let arr2 = Tensor::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
    /// let result = arr1.bmm(&arr2);
    /// assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
    /// ```
    pub fn bmm<'r>(&self, other: impl AsRef<Tensor<'a, T>>) -> Tensor<'r, T> {
        let rhs = other.as_ref();
        let track_grad = self.requires_grad() || rhs.requires_grad();

        // The gradient function is built before the product is evaluated.
        let backward = if track_grad {
            BMMBackwards::new(self, rhs)
        } else {
            NoneBackwards::new()
        };

        // NOTE(review): the safety contract of `from_raw_parts` is not visible in
        // this file — confirm that a freshly computed array paired with the
        // matching gradient function satisfies it.
        unsafe {
            let product = self.array.bmm(&rhs.array);
            Tensor::from_raw_parts(product, track_grad, backward)
        }
    }
}