concision_core/tensor/
tensor.rs

1/*
2    appellation: tensor <module>
3    authors: @FL03
4*/
5use ndarray::{ArrayBase, Data, DataMut, DataOwned, Dimension, NdIndex, RawData, ShapeBuilder};
6use num_traits::{One, Zero};
7
8/// the [`TensorBase`] struct is the base type for all tensors in the library.
9#[repr(transparent)]
10pub struct TensorBase<S, D>
11where
12    D: Dimension,
13    S: RawData,
14{
15    pub(crate) store: ArrayBase<S, D>,
16}
17
18impl<A, S, D> TensorBase<S, D>
19where
20    D: Dimension,
21    S: RawData<Elem = A>,
22{
23    /// create a new [`TensorBase`] from the given store.
24    pub const fn from_ndarray(store: ArrayBase<S, D>) -> Self {
25        Self { store }
26    }
27    /// create a new [`TensorBase`] from the given shape and a function to fill it.
28    pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
29    where
30        S: DataOwned,
31        Sh: ShapeBuilder<Dim = D>,
32        F: FnMut(D::Pattern) -> A,
33    {
34        Self {
35            store: ArrayBase::from_shape_fn(shape, f),
36        }
37    }
38    /// create a new [`TensorBase`] from the given shape and a function to fill it.
39    pub fn from_fn_with_shape<Sh, F>(shape: Sh, f: F) -> Self
40    where
41        S: DataOwned,
42        Sh: ShapeBuilder<Dim = D>,
43        F: Fn() -> A,
44    {
45        Self::from_shape_fn(shape, |_| f())
46    }
47    /// returns a new instance of the [`TensorBase`] with the given shape and values initialized
48    /// to zero.
49    pub fn ones<Sh>(shape: Sh) -> Self
50    where
51        A: Clone + One,
52        S: DataOwned,
53        Sh: ShapeBuilder<Dim = D>,
54    {
55        Self::from_fn_with_shape(shape, A::one)
56    }
57    /// returns a new instance of the [`TensorBase`] with the given shape and values initialized
58    /// to zero.
59    pub fn zeros<Sh>(shape: Sh) -> Self
60    where
61        A: Clone + Zero,
62        S: DataOwned,
63        Sh: ShapeBuilder<Dim = D>,
64    {
65        Self::from_fn_with_shape(shape, A::zero)
66    }
67    /// returns a reference to the element at the given index, if any
68    pub fn get<Ix>(&self, index: Ix) -> Option<&A>
69    where
70        S: Data,
71        Ix: NdIndex<D>,
72    {
73        self.store().get(index)
74    }
75    /// returns a mutable reference to the element at the given index, if any
76    pub fn get_mut<Ix>(&mut self, index: Ix) -> Option<&mut A>
77    where
78        S: DataMut,
79        Ix: NdIndex<D>,
80    {
81        self.store_mut().get_mut(index)
82    }
83    /// applies the function to every element within the tensor
84    pub fn map<F, B>(&self, f: F) -> super::Tensor<B, D>
85    where
86        A: Clone,
87        S: Data,
88        F: FnMut(A) -> B,
89    {
90        TensorBase {
91            store: self.store().mapv(f),
92        }
93    }
94    /// this method applies the function to the store, capturing the result in a new tensor.
95    pub(crate) fn mapd<F, U, S2, D2>(&self, f: F) -> TensorBase<S2, D2>
96    where
97        D2: Dimension,
98        S2: RawData<Elem = U>,
99        F: FnOnce(&ArrayBase<S, D>) -> ArrayBase<S2, D2>,
100    {
101        TensorBase {
102            store: f(self.store()),
103        }
104    }
105}
106
107#[doc(hidden)]
108#[allow(dead_code)]
109impl<A, S, D> TensorBase<S, D>
110where
111    D: Dimension,
112    S: RawData<Elem = A>,
113{
114    /// returns an immutable reference to the store of the tensor
115    pub(crate) const fn store(&self) -> &ArrayBase<S, D> {
116        &self.store
117    }
118    /// returns a mutable reference to the store of the tensor
119    pub(crate) const fn store_mut(&mut self) -> &mut ArrayBase<S, D> {
120        &mut self.store
121    }
122    /// update the current store and return a mutable reference to self
123    pub(crate) fn set_store(&mut self, store: ArrayBase<S, D>) -> &mut Self {
124        self.store = store;
125        self
126    }
127}