concision_core/tensor/
tensor.rs1use ndarray::{ArrayBase, Data, DataMut, DataOwned, Dimension, NdIndex, RawData, ShapeBuilder};
6use num_traits::{One, Zero};
7
8#[repr(transparent)]
10pub struct TensorBase<S, D>
11where
12 D: Dimension,
13 S: RawData,
14{
15 pub(crate) store: ArrayBase<S, D>,
16}
17
18impl<A, S, D> TensorBase<S, D>
19where
20 D: Dimension,
21 S: RawData<Elem = A>,
22{
23 pub const fn from_ndarray(store: ArrayBase<S, D>) -> Self {
25 Self { store }
26 }
27 pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
29 where
30 S: DataOwned,
31 Sh: ShapeBuilder<Dim = D>,
32 F: FnMut(D::Pattern) -> A,
33 {
34 Self {
35 store: ArrayBase::from_shape_fn(shape, f),
36 }
37 }
38 pub fn from_fn_with_shape<Sh, F>(shape: Sh, f: F) -> Self
40 where
41 S: DataOwned,
42 Sh: ShapeBuilder<Dim = D>,
43 F: Fn() -> A,
44 {
45 Self::from_shape_fn(shape, |_| f())
46 }
47 pub fn ones<Sh>(shape: Sh) -> Self
50 where
51 A: Clone + One,
52 S: DataOwned,
53 Sh: ShapeBuilder<Dim = D>,
54 {
55 Self::from_fn_with_shape(shape, A::one)
56 }
57 pub fn zeros<Sh>(shape: Sh) -> Self
60 where
61 A: Clone + Zero,
62 S: DataOwned,
63 Sh: ShapeBuilder<Dim = D>,
64 {
65 Self::from_fn_with_shape(shape, A::zero)
66 }
67 pub fn get<Ix>(&self, index: Ix) -> Option<&A>
69 where
70 S: Data,
71 Ix: NdIndex<D>,
72 {
73 self.store().get(index)
74 }
75 pub fn get_mut<Ix>(&mut self, index: Ix) -> Option<&mut A>
77 where
78 S: DataMut,
79 Ix: NdIndex<D>,
80 {
81 self.store_mut().get_mut(index)
82 }
83 pub fn map<F, B>(&self, f: F) -> super::Tensor<B, D>
85 where
86 A: Clone,
87 S: Data,
88 F: FnMut(A) -> B,
89 {
90 TensorBase {
91 store: self.store().mapv(f),
92 }
93 }
94 pub(crate) fn mapd<F, U, S2, D2>(&self, f: F) -> TensorBase<S2, D2>
96 where
97 D2: Dimension,
98 S2: RawData<Elem = U>,
99 F: FnOnce(&ArrayBase<S, D>) -> ArrayBase<S2, D2>,
100 {
101 TensorBase {
102 store: f(self.store()),
103 }
104 }
105}
106
107#[doc(hidden)]
108#[allow(dead_code)]
109impl<A, S, D> TensorBase<S, D>
110where
111 D: Dimension,
112 S: RawData<Elem = A>,
113{
114 pub(crate) const fn store(&self) -> &ArrayBase<S, D> {
116 &self.store
117 }
118 pub(crate) const fn store_mut(&mut self) -> &mut ArrayBase<S, D> {
120 &mut self.store
121 }
122 pub(crate) fn set_store(&mut self, store: ArrayBase<S, D>) -> &mut Self {
124 self.store = store;
125 self
126 }
127}