redstone_ml/ndarray/
clone.rs

1use crate::dtype::RawDataType;
2use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
3use crate::iterator::flat_index_generator::FlatIndexGenerator;
4use crate::{Constructors, NdArray, StridedMemory};
5use std::ptr::copy_nonoverlapping;
6
7impl<'a, T: RawDataType> NdArray<'a, T> {
8    #[allow(clippy::should_implement_trait)]
9    pub fn clone<'r>(&self) -> NdArray<'r, T> {
10        unsafe { NdArray::from_contiguous_owned_buffer(self.shape.clone(), self.clone_data()) }
11    }
12
13    pub(super) fn clone_data(&self) -> Vec<T> {
14        if self.is_contiguous() {
15            return unsafe { self.clone_data_contiguous() };
16        }
17        unsafe { self.clone_data_non_contiguous() }
18    }
19
20    /// Safety: expects ndarray buffer is contiguously stored
21    unsafe fn clone_data_contiguous(&self) -> Vec<T> {
22        let mut data = Vec::with_capacity(self.len);
23
24        copy_nonoverlapping(self.ptr(), data.as_mut_ptr(), self.len);
25        data.set_len(self.len);
26        data
27    }
28
29    /// Safety: expects ndarray buffer is not contiguously stored
30    unsafe fn clone_data_non_contiguous(&self) -> Vec<T> {
31        let size = self.size();
32        let mut data = Vec::with_capacity(size);
33
34        let (mut shape, mut stride) = collapse_to_uniform_stride(&self.shape, &self.stride);
35
36        // safe to unwrap because if stride has no elements, this would be a scalar ndarray
37        // however, scalar arrays are contiguously stored so this method wouldn't be called
38        let &mut mut contiguous_stride = stride.last_mut().unwrap();
39
40        // if elements along the last axis are located contiguously,
41        // we can collapse the last dimension and copy contiguous_stride elements at once
42        if contiguous_stride == 1 {
43            contiguous_stride = shape.pop().unwrap();
44            stride.pop();
45        }
46
47        // if elements along the last axis aren't located contiguously,
48        // they must correspond to an NdArray view with a step-size along the last axis of > 1
49        // this is equivalent to 1 contiguous element along the last axis
50        else {
51            contiguous_stride = 1;
52        }
53
54        let src = self.ptr();
55        let mut dst = data.as_mut_ptr();
56
57        for i in FlatIndexGenerator::from(&shape, &stride) {
58            copy_nonoverlapping(src.add(i), dst, contiguous_stride);
59            dst = dst.add(contiguous_stride);
60        }
61
62        data.set_len(size);
63        data
64    }
65}