redstone_ml/ndarray/
reshape.rs1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::slice::update_flags_with_contiguity;
4use crate::{NdArray, Reshape, StridedMemory};
5use crate::common::constructors::Constructors;
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8 pub fn flatten(&self) -> NdArray<'static, T> {
22 unsafe {
23 NdArray::from_contiguous_owned_buffer(vec![self.size()], self.clone_data())
24 }
25 }
26
27 pub(crate) unsafe fn reshaped_view_with_offset(&'a self,
35 offset: usize,
36 shape: Vec<usize>,
37 stride: Vec<usize>) -> NdArray<'a, T> {
38 let mut flags = update_flags_with_contiguity(self.flags, &shape, &stride);
39 flags -= NdArrayFlags::UserCreated;
40 flags -= NdArrayFlags::Owned;
41
42 NdArray {
43 ptr: self.ptr.add(offset),
44 len: shape.iter().product(),
45 capacity: 0,
46
47 shape,
48 stride,
49 flags,
50
51 _marker: self._marker,
52 }
53 }
54}
55
56impl<T: RawDataType> Reshape<T> for NdArray<'_, T> {
57 type Output = NdArray<'static, T>;
58
59 unsafe fn reshaped_view(mut self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
60 let flags = update_flags_with_contiguity(self.flags, &shape, &stride);
61
62 self.flags -= NdArrayFlags::Owned;
64
65 NdArray {
66 ptr: self.ptr,
67 len: self.len,
68 capacity: 0,
69
70 shape,
71 stride,
72 flags,
73
74 _marker: Default::default(),
75 }
76 }
77}
78
79impl<'a, T: RawDataType> Reshape<T> for &'a NdArray<'a, T> {
80 type Output = NdArray<'a, T>;
81
82 unsafe fn reshaped_view(self, shape: Vec<usize>, stride: Vec<usize>) -> Self::Output {
83 let mut flags = update_flags_with_contiguity(self.flags, &shape, &stride);
84 flags -= NdArrayFlags::UserCreated;
85 flags -= NdArrayFlags::Owned;
86
87 NdArray {
88 ptr: self.ptr,
89 len: shape.iter().product(),
90 capacity: self.capacity,
91
92 shape,
93 stride,
94 flags,
95
96 _marker: Default::default(),
97 }
98 }
99}