redstone_ml/ndarray/
constructors.rs

1use crate::ndarray::flags::NdArrayFlags;
2use crate::ndarray::NdArray;
3use crate::common::constructors::Constructors;
4use crate::RawDataType;
5use std::mem::ManuallyDrop;
6use std::ptr::NonNull;
7
/// Computes the stride of an ndarray from its given shape assuming a contiguous layout.
///
/// In the context of multidimensional arrays, the stride refers to the number of elements
/// that need to be skipped in memory to move to the next element along each dimension.
/// For a contiguous (row-major) layout, the stride of an axis is the product of the sizes
/// of all the axes that follow it, so the last axis always has a stride of 1.
///
/// # Arguments
///
/// * `shape` - A slice representing the shape of the ndarray.
///
/// # Returns
///
/// A `Vec<usize>` containing the stride for each dimension of the ndarray, with the same
/// length as the input `shape`. An empty `shape` yields an empty vector.
///
/// # Panics
///
/// In builds with overflow checks enabled, panics if the product of the dimensions
/// overflows `usize`.
///
/// # Example
///
/// ```
/// let shape = vec![5, 3, 2, 1];
///
/// // stride would be [6, 2, 1, 1]
/// // Axis 0 (size 5): stride = 3 * 2 * 1 = 6
/// // Axis 1 (size 3): stride = 2 * 1 = 2
/// // Axis 2 (size 2): stride = 1
/// // Axis 3 (size 1): stride is always 1
/// ```
pub(crate) fn stride_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut stride = vec![0; shape.len()];

    // Walk the axes from innermost to outermost, carrying the running product of the
    // dimensions visited so far; that product is exactly the stride of the next outer axis.
    let mut elements_spanned: usize = 1;
    for (axis_stride, &dimension) in stride.iter_mut().zip(shape).rev() {
        *axis_stride = elements_spanned;
        elements_spanned *= dimension;
    }

    stride
}
48
49
50impl<'a, T: RawDataType> Constructors<T> for NdArray<'a, T> {
51    unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
52        let flags = NdArrayFlags::Owned | NdArrayFlags::Contiguous | NdArrayFlags::UniformStride | NdArrayFlags::Writeable;
53
54        // take control of the data so that Rust doesn't drop it once the vector goes out of scope
55        let mut data = ManuallyDrop::new(data);
56        let stride = stride_from_shape(&shape);
57
58        Self {
59            ptr: NonNull::new_unchecked(data.as_mut_ptr()),
60            len: data.len(),
61            capacity: data.capacity(),
62
63            shape,
64            stride,
65            flags,
66
67            _marker: Default::default(),
68        }
69    }
70}
71
72impl<T: RawDataType> Drop for NdArray<'_, T> {
73    /// This method is implicitly invoked when the ndarray is deleted to clean up its memory if
74    /// the ndarray owns its data (i.e. it is not a view into another ndarray ).
75    ///
76    /// Resets `self.len` and `self.capacity` to 0.
77    fn drop(&mut self) {
78        if self.flags.contains(NdArrayFlags::Owned) {
79            // drops the data
80            unsafe { Vec::from_raw_parts(self.mut_ptr(), self.len, self.capacity) };
81        }
82
83        self.len = 0;
84        self.capacity = 0;
85    }
86}