redstone_ml/ndarray/
methods.rs

1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::common::methods::StridedMemory;
4use crate::NdArray;
5
6impl<'a, T: RawDataType> NdArray<'a, T> {
7    /// Retrieves the single value contained within an ndarray with a singular element.
8    ///
9    /// # Panics
10    /// If the ndarray contains more than one element (i.e., it is not a scalar or an ndarray with a
11    /// single element)
12    ///
13    /// # Example
14    /// ```
15    /// # use redstone_ml::*;
16    ///
17    /// let ndarray = NdArray::scalar(50f32);
18    /// let value = ndarray.value();
19    /// assert_eq!(value, 50.0);
20    /// ```
21    ///
22    /// # Notes
23    /// This function is only meant for arrays that are guaranteed to have
24    /// exactly one element. For arrays with multiple elements, consider using
25    /// appropriate methods to access individual elements or slices safely.
26    pub fn value(&self) -> T {
27        assert_eq!(self.size(), 1, "cannot get value of an ndarray with more than one element");
28        unsafe { self.ptr.read() }
29    }
30
31    /// Returns a slice of the ndarray's (flattened) data buffer
32    ///
33    /// # Example
34    /// ```
35    /// # use redstone_ml::*;
36    ///
37    /// let ndarray = NdArray::new([[50, 60], [-5, -10]]);
38    /// let data = ndarray.data_slice();
39    /// assert_eq!(data, &[50, 60, -5, -10]);
40    /// ```
41    pub fn data_slice(&self) -> &'a [T] {
42        assert!(self.is_contiguous(), "cannot get data slice of non-contiguous tensor");
43        unsafe { std::slice::from_raw_parts(self.ptr(), self.len) }
44    }
45
46    /// Converts an `NdArray` into its underlying data vector by flattening its dimensions.
47    ///
48    /// # Panics
49    /// - If the ndarray does not own its data (it is a NdArray view).
50    ///
51    /// # Example
52    /// ```
53    /// # use redstone_ml::*;
54    ///
55    /// let ndarray = NdArray::new([[50, 60], [-5, -10]]);
56    /// let data = ndarray.into_data_vector();
57    /// assert_eq!(data, vec![50, 60, -5, -10]);
58    /// ```
59    pub fn into_data_vector(mut self) -> Vec<T> {
60        if !self.flags.contains(NdArrayFlags::Owned) {
61            panic!("cannot return data vector of non-owned tensor");
62        }
63        assert!(self.is_contiguous(), "cannot get data vector of non-contiguous tensor");
64
65        // ensure the vector's data is not dropped when self goes out of scope and is destroyed
66        self.flags -= NdArrayFlags::Owned;
67
68        unsafe { Vec::from_raw_parts(self.mut_ptr(), self.len, self.capacity) }
69    }
70}
71
72impl<T: RawDataType> StridedMemory for NdArray<'_, T> {
73    #[inline]
74    fn shape(&self) -> &[usize] {
75        &self.shape
76    }
77
78    #[inline]
79    fn stride(&self) -> &[usize] {
80        &self.stride
81    }
82
83    #[inline]
84    fn flags(&self) -> NdArrayFlags {
85        self.flags
86    }
87}
88
89impl<T: RawDataType> StridedMemory for &NdArray<'_, T> {
90    #[inline]
91    fn shape(&self) -> &[usize] {
92        &self.shape
93    }
94
95    #[inline]
96    fn stride(&self) -> &[usize] {
97        &self.stride
98    }
99
100    #[inline]
101    fn flags(&self) -> NdArrayFlags {
102        self.flags
103    }
104}
105
106impl<'a, T: RawDataType> NdArray<'a, T> {
107    pub(crate) unsafe fn mut_ptr(&self) -> *mut T {
108        self.ptr.as_ptr()
109    }
110
111    pub(crate) unsafe fn ptr(&self) -> *const T {
112        self.ptr.as_ptr()
113    }
114}