redstone_ml/ndarray/
clone.rs1use crate::dtype::RawDataType;
2use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
3use crate::iterator::flat_index_generator::FlatIndexGenerator;
4use crate::{Constructors, NdArray, StridedMemory};
5use std::ptr::copy_nonoverlapping;
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8 #[allow(clippy::should_implement_trait)]
9 pub fn clone<'r>(&self) -> NdArray<'r, T> {
10 unsafe { NdArray::from_contiguous_owned_buffer(self.shape.clone(), self.clone_data()) }
11 }
12
13 pub(super) fn clone_data(&self) -> Vec<T> {
14 if self.is_contiguous() {
15 return unsafe { self.clone_data_contiguous() };
16 }
17 unsafe { self.clone_data_non_contiguous() }
18 }
19
20 unsafe fn clone_data_contiguous(&self) -> Vec<T> {
22 let mut data = Vec::with_capacity(self.len);
23
24 copy_nonoverlapping(self.ptr(), data.as_mut_ptr(), self.len);
25 data.set_len(self.len);
26 data
27 }
28
29 unsafe fn clone_data_non_contiguous(&self) -> Vec<T> {
31 let size = self.size();
32 let mut data = Vec::with_capacity(size);
33
34 let (mut shape, mut stride) = collapse_to_uniform_stride(&self.shape, &self.stride);
35
36 let &mut mut contiguous_stride = stride.last_mut().unwrap();
39
40 if contiguous_stride == 1 {
43 contiguous_stride = shape.pop().unwrap();
44 stride.pop();
45 }
46
47 else {
51 contiguous_stride = 1;
52 }
53
54 let src = self.ptr();
55 let mut dst = data.as_mut_ptr();
56
57 for i in FlatIndexGenerator::from(&shape, &stride) {
58 copy_nonoverlapping(src.add(i), dst, contiguous_stride);
59 dst = dst.add(contiguous_stride);
60 }
61
62 data.set_len(size);
63 data
64 }
65}