redstone_ml/ndarray/methods.rs
1use crate::dtype::RawDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::common::methods::StridedMemory;
4use crate::NdArray;
5
6impl<'a, T: RawDataType> NdArray<'a, T> {
7 /// Retrieves the single value contained within an ndarray with a singular element.
8 ///
9 /// # Panics
10 /// If the ndarray contains more than one element (i.e., it is not a scalar or an ndarray with a
11 /// single element)
12 ///
13 /// # Example
14 /// ```
15 /// # use redstone_ml::*;
16 ///
17 /// let ndarray = NdArray::scalar(50f32);
18 /// let value = ndarray.value();
19 /// assert_eq!(value, 50.0);
20 /// ```
21 ///
22 /// # Notes
23 /// This function is only meant for arrays that are guaranteed to have
24 /// exactly one element. For arrays with multiple elements, consider using
25 /// appropriate methods to access individual elements or slices safely.
26 pub fn value(&self) -> T {
27 assert_eq!(self.size(), 1, "cannot get value of an ndarray with more than one element");
28 unsafe { self.ptr.read() }
29 }
30
31 /// Returns a slice of the ndarray's (flattened) data buffer
32 ///
33 /// # Example
34 /// ```
35 /// # use redstone_ml::*;
36 ///
37 /// let ndarray = NdArray::new([[50, 60], [-5, -10]]);
38 /// let data = ndarray.data_slice();
39 /// assert_eq!(data, &[50, 60, -5, -10]);
40 /// ```
41 pub fn data_slice(&self) -> &'a [T] {
42 assert!(self.is_contiguous(), "cannot get data slice of non-contiguous tensor");
43 unsafe { std::slice::from_raw_parts(self.ptr(), self.len) }
44 }
45
46 /// Converts an `NdArray` into its underlying data vector by flattening its dimensions.
47 ///
48 /// # Panics
49 /// - If the ndarray does not own its data (it is a NdArray view).
50 ///
51 /// # Example
52 /// ```
53 /// # use redstone_ml::*;
54 ///
55 /// let ndarray = NdArray::new([[50, 60], [-5, -10]]);
56 /// let data = ndarray.into_data_vector();
57 /// assert_eq!(data, vec![50, 60, -5, -10]);
58 /// ```
59 pub fn into_data_vector(mut self) -> Vec<T> {
60 if !self.flags.contains(NdArrayFlags::Owned) {
61 panic!("cannot return data vector of non-owned tensor");
62 }
63 assert!(self.is_contiguous(), "cannot get data vector of non-contiguous tensor");
64
65 // ensure the vector's data is not dropped when self goes out of scope and is destroyed
66 self.flags -= NdArrayFlags::Owned;
67
68 unsafe { Vec::from_raw_parts(self.mut_ptr(), self.len, self.capacity) }
69 }
70}
71
72impl<T: RawDataType> StridedMemory for NdArray<'_, T> {
73 #[inline]
74 fn shape(&self) -> &[usize] {
75 &self.shape
76 }
77
78 #[inline]
79 fn stride(&self) -> &[usize] {
80 &self.stride
81 }
82
83 #[inline]
84 fn flags(&self) -> NdArrayFlags {
85 self.flags
86 }
87}
88
89impl<T: RawDataType> StridedMemory for &NdArray<'_, T> {
90 #[inline]
91 fn shape(&self) -> &[usize] {
92 &self.shape
93 }
94
95 #[inline]
96 fn stride(&self) -> &[usize] {
97 &self.stride
98 }
99
100 #[inline]
101 fn flags(&self) -> NdArrayFlags {
102 self.flags
103 }
104}
105
106impl<'a, T: RawDataType> NdArray<'a, T> {
107 pub(crate) unsafe fn mut_ptr(&self) -> *mut T {
108 self.ptr.as_ptr()
109 }
110
111 pub(crate) unsafe fn ptr(&self) -> *const T {
112 self.ptr.as_ptr()
113 }
114}