redstone_ml/tensor/methods.rs
1use std::rc::Rc;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::common::methods::StridedMemory;
4use crate::{NdArray, Tensor, TensorDataType};
5
6impl<'a, T: TensorDataType> Tensor<'a, T> {
7 /// Retrieves the single value contained within a tensor with a singular element.
8 ///
9 /// # Panics
10 /// If the tensor contains more than one element (i.e., it is not a scalar or a tensor with a
11 /// single element)
12 ///
13 /// # Example
14 /// ```
15 /// # use redstone_ml::*;
16 ///
17 /// let tensor = Tensor::scalar(50.0);
18 /// let value = tensor.value();
19 /// assert_eq!(value, 50.0);
20 /// ```
21 ///
22 /// # Notes
23 /// This function is only meant for arrays that are guaranteed to have
24 /// exactly one element. For arrays with multiple elements, consider using
25 /// appropriate methods to access individual elements or slices safely.
26 pub fn value(&self) -> T {
27 self.array.value()
28 }
29
30 /// Returns a reference to the underlying `NdArray` of the tensor
31 pub fn ndarray(&self) -> &NdArray<'a, T> {
32 self.array.as_ref()
33 }
34
35 /// Returns a reference-counted pointer to the underlying `NdArray` of the tensor
36 pub fn get_ndarray(&self) -> Rc<NdArray<'static, T>> {
37 self.array.clone()
38 }
39
40 /// Converts the tensor to an `NdArray`
41 pub fn into_ndarray(self) -> NdArray<'static, T> {
42 match Rc::try_unwrap(self.array) {
43 Ok(result) => { result }
44 Err(rc) => { (*rc).clone() }
45 }
46 }
47}
48
49#[allow(clippy::len_without_is_empty)]
50impl<'a, T: TensorDataType> StridedMemory for Tensor<'a, T> {
51 /// Returns the dimensions of the tensor along each axis.
52 ///
53 /// ```
54 /// # use redstone_ml::*;
55 ///
56 /// let a = Tensor::new([3.0, 4.0, 5.0]);
57 /// assert_eq!(a.shape(), &[3]);
58 ///
59 /// let b = Tensor::new([[3.0], [5.0]]);
60 /// assert_eq!(b.shape(), &[2, 1]);
61 ///
62 /// let c = Tensor::scalar(0.0);
63 /// assert_eq!(c.shape(), &[]);
64 /// ```
65 #[inline]
66 fn shape(&self) -> &[usize] {
67 self.array.shape()
68 }
69
70 /// Returns the stride of the tensor.
71 ///
72 /// The stride represents the distance in memory between elements in a tensor along each axis.
73 ///
74 /// ```
75 /// # use redstone_ml::*;
76 ///
77 /// let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
78 /// assert_eq!(a.stride(), &[2, 1]);
79 /// ```
80 #[inline]
81 fn stride(&self) -> &[usize] {
82 self.array.stride()
83 }
84
85 /// Returns flags containing information about various tensor metadata.
86 #[inline]
87 fn flags(&self) -> NdArrayFlags {
88 self.array.flags()
89 }
90}
91
92#[allow(clippy::len_without_is_empty)]
93impl<T: TensorDataType> StridedMemory for &Tensor<'_, T> {
94 /// Returns the dimensions of the tensor along each axis.
95 ///
96 /// ```
97 /// # use redstone_ml::*;
98 ///
99 /// let a = Tensor::new([3.0, 4.0, 5.0]);
100 /// assert_eq!(a.shape(), &[3]);
101 ///
102 /// let b = Tensor::new([[3.0], [5.0]]);
103 /// assert_eq!(b.shape(), &[2, 1]);
104 ///
105 /// let c = Tensor::scalar(0.0);
106 /// assert_eq!(c.shape(), &[]);
107 /// ```
108 #[inline]
109 fn shape(&self) -> &[usize] {
110 self.array.shape()
111 }
112
113 /// Returns the stride of the tensor.
114 ///
115 /// The stride represents the distance in memory between elements in a tensor along each axis.
116 ///
117 /// ```
118 /// # use redstone_ml::*;
119 ///
120 /// let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
121 /// assert_eq!(a.stride(), &[2, 1]);
122 /// ```
123 #[inline]
124 fn stride(&self) -> &[usize] {
125 self.array.stride()
126 }
127
128 /// Returns flags containing information about various tensor metadata.
129 #[inline]
130 fn flags(&self) -> NdArrayFlags {
131 self.array.flags()
132 }
133}