redstone_ml/tensor/
reshape.rs

1use crate::none_backwards::NoneBackwards;
2use crate::reshape_backwards::ReshapeBackwards;
3use crate::transpose_backwards::TransposeBackwards;
4use crate::{AxisType, Reshape, StridedMemory, Tensor, TensorDataType};
5use crate::identity_backwards::IdentityBackwards;
6
7
8impl<'a, T: TensorDataType> Reshape<T> for &'a Tensor<'a, T> {
9    type Output = Tensor<'a, T>;
10
11    /// Provides a non-owning view of the tensor with the specified shape and stride.
12    /// The data pointed to by the view is shared with the original tensor.
13    ///
14    /// # Safety
15    /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
16    ///   by the original tensor.
17    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
18        let requires_grad = self.requires_grad();
19        let grad_fn = if requires_grad { ReshapeBackwards::new(self, self.shape()) } else { NoneBackwards::new() };
20
21        let result = self.array.as_ref().reshaped_view(shape, stride);
22
23        // NdArray<'static, T> needed to create a shared pointer to the result
24        // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
25        let result = result.lifetime_cast();
26
27        Tensor::from_raw_parts(result, requires_grad, grad_fn)
28    }
29
30    /// Provides a non-owning view of the tensor that shares its data with the original tensor.
31    ///
32    /// # Example
33    /// ```
34    /// # use redstone_ml::*;
35    ///
36    /// let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
37    /// let view = (&tensor).view();
38    /// assert!(view.is_view())
39    /// ```
40    fn view(self) -> Self::Output {
41        let requires_grad = self.requires_grad();
42        let grad_fn = if requires_grad { IdentityBackwards::new(self) } else { NoneBackwards::new() };
43
44        let result = self.array.as_ref().view();
45
46        unsafe {
47            // NdArray<'static, T> needed to create a shared pointer to the result
48            // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
49            let result = result.lifetime_cast();
50
51            Tensor::from_raw_parts(result, requires_grad, grad_fn)
52        }
53    }
54
55    /// Returns a transposed version of the tensor, swapping the specified axes.
56    ///
57    /// # Panics
58    /// - If `axis1` or `axis2` are out of bounds
59    ///
60    /// # Examples
61    /// ```
62    /// # use redstone_ml::*;
63    ///
64    /// let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
65    ///
66    /// let transposed = array.transpose(0, 1);
67    /// assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
68    /// ```
69    fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
70        let requires_grad = self.requires_grad();
71        let grad_fn =
72            if requires_grad {
73                TransposeBackwards::new(self, axis1.isize(), axis2.isize())
74            } else {
75                NoneBackwards::new()
76            };
77
78        let result = self.array.as_ref().transpose(axis1, axis2);
79
80        unsafe {
81            // NdArray<'static, T> needed to create a shared pointer to the result
82            // this function outputs a Tensor<'a, T> where ('a: 'static) so it should be safe.
83            let result = result.lifetime_cast();
84
85            Tensor::from_raw_parts(result, requires_grad, grad_fn)
86        }
87    }
88}
89
90impl<T: TensorDataType> Reshape<T> for Tensor<'_, T> {
91    type Output = Tensor<'static, T>;
92
93    /// Provides a non-owning view of the tensor with the specified shape and stride.
94    /// The data pointed to by the view is shared with the original tensor.
95    ///
96    /// # Safety
97    /// - Ensure the memory layout referenced by `shape`, and `stride` is valid and owned
98    ///   by the original tensor.
99    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
100        let requires_grad = self.requires_grad();
101        let grad_fn = if requires_grad { ReshapeBackwards::new(&self, self.shape()) } else { NoneBackwards::new() };
102
103        let result = self.into_ndarray().reshaped_view(shape, stride);
104        Tensor::from_raw_parts(result, requires_grad, grad_fn)
105    }
106
107    /// Provides a non-owning view of the tensor that shares its data with the original tensor.
108    ///
109    /// # Example
110    /// ```
111    /// # use redstone_ml::*;
112    ///
113    /// let tensor = Tensor::new([1.0, 2.0, 3.0, 4.0]);
114    /// let view = (&tensor).view();
115    /// assert!(view.is_view())
116    /// ```
117    fn view(self) -> Self::Output {
118        let requires_grad = self.requires_grad();
119        let grad_fn = if requires_grad { IdentityBackwards::new(&self) } else { NoneBackwards::new() };
120
121        let result = self.into_ndarray().view();
122        unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
123    }
124
125    /// Returns a transposed version of the tensor, swapping the specified axes.
126    ///
127    /// # Panics
128    /// - If `axis1` or `axis2` are out of bounds
129    ///
130    /// # Examples
131    /// ```
132    /// # use redstone_ml::*;
133    ///
134    /// let array = Tensor::new([[2.0, 3.0, 4.0], [10.0, 20.0, 30.0]]);
135    ///
136    /// let transposed = array.transpose(0, 1);
137    /// assert_eq!(transposed, Tensor::new([[2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]));
138    /// ```
139    fn transpose(self, axis1: impl AxisType, axis2: impl AxisType) -> Self::Output {
140        let requires_grad = self.requires_grad();
141        let grad_fn =
142            if requires_grad {
143                TransposeBackwards::new(&self, axis1.isize(), axis2.isize())
144            } else {
145                NoneBackwards::new()
146            };
147
148        let result = self.into_ndarray().transpose(axis1, axis2);
149        unsafe { Tensor::from_raw_parts(result, requires_grad, grad_fn) }
150    }
151}