redstone_ml/ndarray/reshape.rs

1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::slice::update_flags_with_contiguity;
4use crate::{NdArray, Reshape, StridedMemory};
5use crate::common::constructors::Constructors;
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8    /// Returns a 1D copy of a flattened multidimensional ndarray.
9    ///
10    /// If copying the data is not desirable, it may be possible to return a view.
11    /// See `NdArray::ravel()`.
12    ///
13    /// # Examples
14    /// ```
15    /// # use redstone_ml::*;
16    ///
17    /// let ndarray = NdArray::new([[1, 2, 3], [4, 5, 6]]);
18    /// let flat_array = ndarray.flatten();
19    /// assert_eq!(flat_array, NdArray::new([1, 2, 3, 4, 5, 6]));
20    /// ```
21    pub fn flatten(&self) -> NdArray<'static, T> {
22        unsafe {
23            NdArray::from_contiguous_owned_buffer(vec![self.size()], self.clone_data())
24        }
25    }
26
27    /// Provides a non-owning view of the ndarray with the specified shape and stride.
28    /// The data pointed to by the view is shared with the original ndarray
29    /// but offset by the specified amount.
30    ///
31    /// # Safety
32    /// - Ensure the memory referenced by `offset`, `shape`, and `stride` is valid and owned
33    ///   by the original ndarray.
34    pub(crate) unsafe fn reshaped_view_with_offset(&'a self,
35                                                   offset: usize,
36                                                   shape: Vec<usize>,
37                                                   stride: Vec<usize>) -> NdArray<'a, T> {
38        let mut flags = update_flags_with_contiguity(self.flags, &shape, &stride);
39        flags -= NdArrayFlags::UserCreated;
40        flags -= NdArrayFlags::Owned;
41
42        NdArray {
43            ptr: self.ptr.add(offset),
44            len: shape.iter().product(),
45            capacity: 0,
46
47            shape,
48            stride,
49            flags,
50
51            _marker: self._marker,
52        }
53    }
54}
55
56impl<T: RawDataType> Reshape<T> for NdArray<'_, T> {
57    type Output = NdArray<'static, T>;
58
59    unsafe fn reshaped_view(mut self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
60        let flags = update_flags_with_contiguity(self.flags, &shape, &stride);
61
62        // prevent ndarray's data from being deallocated once this method ends
63        self.flags -= NdArrayFlags::Owned;
64
65        NdArray {
66            ptr: self.ptr,
67            len: self.len,
68            capacity: 0,
69
70            shape,
71            stride,
72            flags,
73
74            _marker: Default::default(),
75        }
76    }
77}
78
79impl<'a, T: RawDataType> Reshape<T> for &'a NdArray<'a, T> {
80    type Output = NdArray<'a, T>;
81
82    unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
83        let mut flags = update_flags_with_contiguity(self.flags, &shape, &stride);
84        flags -= NdArrayFlags::UserCreated;
85        flags -= NdArrayFlags::Owned;
86
87        NdArray {
88            ptr: self.ptr,
89            len: shape.iter().product(),
90            capacity: self.capacity,
91
92            shape,
93            stride,
94            flags,
95
96            _marker: Default::default(),
97        }
98    }
99}