//! redstone_ml/tensor/reshape.rs
use crate::none_backwards::NoneBackwards;
use crate::reshape_backwards::ReshapeBackwards;
use crate::transpose_backwards::TransposeBackwards;
use crate::{AxisType, Reshape, StridedMemory, Tensor, TensorDataType};
use crate::identity_backwards::IdentityBackwards;
6
7
8impl<'a, T: TensorDataType> Reshape<T> for &'a Tensor<'a, T> {
9 type Output = Tensor<'a, T>;
10
11 /// Provides a non-owning view of the tensor with the specified shape and stride.
12 /// The data pointed to by the view is shared with the original tensor.
13 ///
14 /// # Safety
15 /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
16 /// by the original tensor.
17 unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
18 let requires_grad = self.requires_grad();
19 let grad_fn = if requires_grad { ReshapeBackwards::new(self, self.shape()) } else { NoneBackwards::new() };
20
21 let result = self.array.as_ref().reshaped_view(shape, stride);
22
23 // NdArray<'static, T> needed to create a shared pointer to the result
24 // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
25 let result = result.lifetime_cast();
26
27 Tensor::from_raw_parts(result, requires_grad, grad_fn)
28 }
29
30 /// Provides a non-owning view of the tensor that shares its data with the original tensor.
31 ///
32 /// # Example
33 /// ```
34 /// # use redstone_ml::*;
35 ///
36 /// let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
37 /// let view = (&tensor).view();
38 /// assert!(view.is_view())
39 /// ```
40 fn view(self) -> Self::Output {
41 let requires_grad = self.requires_grad();
42 let grad_fn = if requires_grad { IdentityBackwards::new(self) } else { NoneBackwards::new() };
43
44 let result = self.array.as_ref().view();
45
46 unsafe {
47 // NdArray<'static, T> needed to create a shared pointer to the result
48 // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
49 let result = result.lifetime_cast();
50
51 Tensor::from_raw_parts(result, requires_grad, grad_fn)
52 }
53 }
54
55 /// Returns a transposed version of the tensor, swapping the specified axes.
56 ///
57 /// # Panics
58 /// - If `axis1` or `axis2` are out of bounds
59 ///
60 /// # Examples
61 /// ```
62 /// # use redstone_ml::*;
63 ///
64 /// let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
65 ///
66 /// let transposed = array.transpose(0, 1);
67 /// assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
68 /// ```
69 fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
70 let requires_grad = self.requires_grad();
71 let grad_fn =
72 if requires_grad {
73 TransposeBackwards::new(self, axis1.isize(), axis2.isize())
74 } else {
75 NoneBackwards::new()
76 };
77
78 let result = self.array.as_ref().transpose(axis1, axis2);
79
80 unsafe {
81 // NdArray<'static, T> needed to create a shared pointer to the result
82 // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
83 let result = result.lifetime_cast();
84
85 Tensor::from_raw_parts(result, requires_grad, grad_fn)
86 }
87 }
88}
89
impl<T: TensorDataType> Reshape<T> for Tensor<'_, T> {
    type Output = Tensor<'static, T>;

    /// Provides a non-owning view of the tensor with the specified shape and stride.
    /// The data pointed to by the view is shared with the original tensor.
    ///
    /// Consumes `self`: the underlying array is moved into the returned
    /// `Tensor<'static, T>` via `into_ndarray`.
    ///
    /// # Safety
    /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
    ///   by the original tensor.
    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
        let requires_grad = self.requires_grad();
        // The grad node must be built while `self` is still intact: it borrows
        // `&self` and needs the pre-reshape shape for the backward pass.
        // NOTE(review): grad_fn is built from `&self` yet the output is
        // 'static — presumably the backward node holds a shared handle rather
        // than a borrow; confirm against ReshapeBackwards::new.
        let grad_fn = if requires_grad { ReshapeBackwards::new(&self, self.shape()) } else { NoneBackwards::new() };

        // `into_ndarray` consumes `self`, which is why `grad_fn` was created first.
        let result = self.into_ndarray().reshaped_view(shape, stride);
        Tensor::from_raw_parts(result, requires_grad, grad_fn)
    }

    /// Provides a non-owning view of the tensor that shares its data with the original tensor.
    ///
    /// # Example
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
    /// let view = (&tensor).view();
    /// assert!(view.is_view())
    /// ```
    // NOTE(review): the doctest above calls the `&Tensor` impl via
    // `(&tensor).view()`, not this consuming impl — consider an owned
    // `tensor.view()` example here.
    fn view(self) -> Self::Output {
        let requires_grad = self.requires_grad();
        // Built from `&self` before `into_ndarray` consumes `self` below.
        let grad_fn = if requires_grad { IdentityBackwards::new(&self) } else { NoneBackwards::new() };

        let result = self.into_ndarray().view();
        // SAFETY: presumably upheld because `result` aliases memory now owned
        // by the moved-in NdArray — TODO(review): confirm the exact invariant
        // `from_raw_parts` requires.
        unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
    }

    /// Returns a transposed version of the tensor, swapping the specified axes.
    ///
    /// # Panics
    /// - If `axis1` or `axis2` are out of bounds
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
    ///
    /// let transposed = array.transpose(0, 1);
    /// assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
    /// ```
    fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
        let requires_grad = self.requires_grad();
        // Built from `&self` (recording both axes) before `into_ndarray`
        // consumes `self` below.
        let grad_fn =
            if requires_grad {
                TransposeBackwards::new(&self, axis1.isize(), axis2.isize())
            } else {
                NoneBackwards::new()
            };

        let result = self.into_ndarray().transpose(axis1, axis2);
        // SAFETY: presumably upheld by construction, as in `view` above —
        // TODO(review): confirm the invariant `from_raw_parts` requires.
        unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
    }
}