concision_core/tensor/
tensor.rs1use ndarray::{ArrayBase, Data, DataMut, DataOwned, Dimension, NdIndex, RawData, ShapeBuilder};
6use num_traits::{One, Zero};
7
8#[doc(hidden)]
9pub struct TensorBase<S, D>
11where
12    D: Dimension,
13    S: RawData,
14{
15    pub(crate) store: ArrayBase<S, D>,
16}
17
18impl<A, S, D> TensorBase<S, D>
19where
20    D: Dimension,
21    S: RawData<Elem = A>,
22{
23    pub const fn from_ndarray(store: ArrayBase<S, D>) -> Self {
25        Self { store }
26    }
27    pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
29    where
30        S: DataOwned,
31        Sh: ShapeBuilder<Dim = D>,
32        F: FnMut(D::Pattern) -> A,
33    {
34        Self {
35            store: ArrayBase::from_shape_fn(shape, f),
36        }
37    }
38    pub fn from_fn_with_shape<Sh, F>(shape: Sh, f: F) -> Self
40    where
41        S: DataOwned,
42        Sh: ShapeBuilder<Dim = D>,
43        F: Fn() -> A,
44    {
45        Self::from_shape_fn(shape, |_| f())
46    }
47    pub fn ones<Sh>(shape: Sh) -> Self
50    where
51        A: Clone + One,
52        S: DataOwned,
53        Sh: ShapeBuilder<Dim = D>,
54    {
55        Self::from_fn_with_shape(shape, A::one)
56    }
57    pub fn zeros<Sh>(shape: Sh) -> Self
60    where
61        A: Clone + Zero,
62        S: DataOwned,
63        Sh: ShapeBuilder<Dim = D>,
64    {
65        Self::from_fn_with_shape(shape, A::zero)
66    }
67    pub fn get<Ix>(&self, index: Ix) -> Option<&A>
69    where
70        S: Data,
71        Ix: NdIndex<D>,
72    {
73        self.store().get(index)
74    }
75    pub fn get_mut<Ix>(&mut self, index: Ix) -> Option<&mut A>
77    where
78        S: DataMut,
79        Ix: NdIndex<D>,
80    {
81        self.store_mut().get_mut(index)
82    }
83    pub fn map<F, B>(&self, f: F) -> super::Tensor<B, D>
85    where
86        S: DataOwned,
87        A: Clone,
88        F: FnMut(A) -> B,
89    {
90        TensorBase {
91            store: self.store().mapv(f),
92        }
93    }
94}
95
96#[doc(hidden)]
97#[allow(dead_code)]
98impl<A, S, D> TensorBase<S, D>
99where
100    D: Dimension,
101    S: RawData<Elem = A>,
102{
103    pub(crate) const fn store(&self) -> &ArrayBase<S, D> {
105        &self.store
106    }
107    pub(crate) const fn store_mut(&mut self) -> &mut ArrayBase<S, D> {
109        &mut self.store
110    }
111    pub(crate) fn set_store(&mut self, store: ArrayBase<S, D>) -> &mut Self {
113        self.store = store;
114        self
115    }
116}