redstone_ml/tensor/
methods.rs

1use std::rc::Rc;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::common::methods::StridedMemory;
4use crate::{NdArray, Tensor, TensorDataType};
5
6impl<'a, T: TensorDataType> Tensor<'a, T> {
7    /// Retrieves the single value contained within a tensor with a singular element.
8    ///
9    /// # Panics
10    /// If the tensor contains more than one element (i.e., it is not a scalar or a tensor with a
11    /// single element)
12    ///
13    /// # Example
14    /// ```
15    /// # use redstone_ml::*;
16    ///
17    /// let tensor = Tensor::scalar(50.0);
18    /// let value = tensor.value();
19    /// assert_eq!(value, 50.0);
20    /// ```
21    ///
22    /// # Notes
23    /// This function is only meant for arrays that are guaranteed to have
24    /// exactly one element. For arrays with multiple elements, consider using
25    /// appropriate methods to access individual elements or slices safely.
26    pub fn value(&self) -> T {
27        self.array.value()
28    }
29
30    /// Returns a reference to the underlying `NdArray` of the tensor
31    pub fn ndarray(&self) -> &NdArray<'a, T> {
32        self.array.as_ref()
33    }
34
35    /// Returns a reference-counted pointer to the underlying `NdArray` of the tensor
36    pub fn get_ndarray(&self) -> Rc<NdArray<'static, T>> {
37        self.array.clone()
38    }
39
40    /// Converts the tensor to an `NdArray`
41    pub fn into_ndarray(self) -> NdArray<'static, T> {
42        match Rc::try_unwrap(self.array) {
43            Ok(result) => { result }
44            Err(rc) => { (*rc).clone() }
45        }
46    }
47}
48
49#[allow(clippy::len_without_is_empty)]
50impl<'a, T: TensorDataType> StridedMemory for Tensor<'a, T> {
51    /// Returns the dimensions of the tensor along each axis.
52    ///
53    /// ```
54    /// # use redstone_ml::*;
55    ///
56    /// let a = Tensor::new([3.0, 4.0, 5.0]);
57    /// assert_eq!(a.shape(), &[3]);
58    ///
59    /// let b = Tensor::new([[3.0], [5.0]]);
60    /// assert_eq!(b.shape(), &[2, 1]);
61    ///
62    /// let c = Tensor::scalar(0.0);
63    /// assert_eq!(c.shape(), &[]);
64    /// ```
65    #[inline]
66    fn shape(&self) -> &[usize] {
67        self.array.shape()
68    }
69
70    /// Returns the stride of the tensor.
71    ///
72    /// The stride represents the distance in memory between elements in a tensor along each axis.
73    ///
74    /// ```
75    /// # use redstone_ml::*;
76    ///
77    /// let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
78    /// assert_eq!(a.stride(), &[2, 1]);
79    /// ```
80    #[inline]
81    fn stride(&self) -> &[usize] {
82        self.array.stride()
83    }
84
85    /// Returns flags containing information about various tensor metadata.
86    #[inline]
87    fn flags(&self) -> NdArrayFlags {
88        self.array.flags()
89    }
90}
91
92#[allow(clippy::len_without_is_empty)]
93impl<T: TensorDataType> StridedMemory for &Tensor<'_, T> {
94    /// Returns the dimensions of the tensor along each axis.
95    ///
96    /// ```
97    /// # use redstone_ml::*;
98    ///
99    /// let a = Tensor::new([3.0, 4.0, 5.0]);
100    /// assert_eq!(a.shape(), &[3]);
101    ///
102    /// let b = Tensor::new([[3.0], [5.0]]);
103    /// assert_eq!(b.shape(), &[2, 1]);
104    ///
105    /// let c = Tensor::scalar(0.0);
106    /// assert_eq!(c.shape(), &[]);
107    /// ```
108    #[inline]
109    fn shape(&self) -> &[usize] {
110        self.array.shape()
111    }
112
113    /// Returns the stride of the tensor.
114    ///
115    /// The stride represents the distance in memory between elements in a tensor along each axis.
116    ///
117    /// ```
118    /// # use redstone_ml::*;
119    ///
120    /// let a = Tensor::new([[3.0, 4.0], [5.0, 6.0]]);
121    /// assert_eq!(a.stride(), &[2, 1]);
122    /// ```
123    #[inline]
124    fn stride(&self) -> &[usize] {
125        self.array.stride()
126    }
127
128    /// Returns flags containing information about various tensor metadata.
129    #[inline]
130    fn flags(&self) -> NdArrayFlags {
131        self.array.flags()
132    }
133}