1use crate::traits::Scalar;
6use ndarray::{
7 ArrayBase, Axis, DataMut, DataOwned, Dimension, OwnedRepr, RawData, RemoveAxis, ShapeBuilder,
8};
9use num::Signed;
10use num_traits::{One, Zero};
11
12pub trait RawTensor<A, D> {
14 type Repr: RawData<Elem = A>;
15 type Container<U: RawData, V: Dimension>;
16
17 private!();
18}
19pub trait Tensor<A, D>: RawTensor<A, D>
23where
24 D: Dimension,
25{
26 fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<Self::Repr, D>
28 where
29 Sh: ShapeBuilder<Dim = D>,
30 F: FnMut(D::Pattern) -> A,
31 Self: Sized;
32 fn from_shape_with_value<Sh>(shape: Sh, value: A) -> Self::Container<Self::Repr, D>
34 where
35 Sh: ShapeBuilder<Dim = D>,
36 Self: Sized;
37 fn default<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
39 where
40 Sh: ShapeBuilder<Dim = D>,
41 Self: Sized,
42 A: Default,
43 {
44 Self::from_shape_with_value(shape, A::default())
45 }
46 fn ones<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
48 where
49 Sh: ShapeBuilder<Dim = D>,
50 Self: Sized,
51 A: Clone + One,
52 {
53 Self::from_shape_with_value(shape, A::one())
54 }
55 fn zeros<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
57 where
58 Sh: ShapeBuilder<Dim = D>,
59 Self: Sized,
60 A: Clone + Zero,
61 {
62 Self::from_shape_with_value(shape, <A>::zero())
63 }
64 fn data(&self) -> &Self::Container<Self::Repr, D>;
66 fn data_mut(&mut self) -> &mut Self::Container<Self::Repr, D>;
68 fn dim(&self) -> D::Pattern;
70 fn raw_dim(&self) -> D;
72 fn shape(&self) -> &[usize];
74 fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
77 where
78 F: FnMut(A) -> B;
79 fn apply_mut<F>(&mut self, f: F)
81 where
82 Self::Repr: DataMut,
83 F: FnMut(A) -> A;
84
85 fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, A, D::Smaller>
86 where
87 D: RemoveAxis;
88
89 fn iter(&self) -> ndarray::iter::Iter<'_, A, D>;
90
91 fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
92 where
93 Self::Repr: DataMut;
94
95 fn mean(&self) -> A
96 where
97 A: Scalar,
98 {
99 let sum = self.sum();
100 let count = self.iter().count();
101 sum / A::from_usize(count).unwrap()
102 }
103 fn set_data(&mut self, data: Self::Container<Self::Repr, D>) -> &mut Self {
105 *self.data_mut() = data;
106 self
107 }
108
109 fn sum(&self) -> A
110 where
111 A: Clone + core::iter::Sum,
112 {
113 self.iter().cloned().sum()
114 }
115
116 fn pow2(&self) -> Self::Container<OwnedRepr<A>, D>
117 where
118 A: Scalar,
119 {
120 let two = A::from_usize(2).unwrap();
121 self.apply(|x| x.pow(two))
122 }
123
124 fn abs(&self) -> Self::Container<OwnedRepr<A>, D>
125 where
126 A: Signed,
127 {
128 self.apply(|x| x.abs())
129 }
130
131 fn neg(&self) -> Self::Container<OwnedRepr<A>, D>
132 where
133 A: core::ops::Neg<Output = A>,
134 {
135 self.apply(|x| -x)
136 }
137}
138
139impl<A, S, D> RawTensor<A, D> for ArrayBase<S, D>
144where
145 S: RawData<Elem = A>,
146 A: Scalar,
147 D: Dimension,
148{
149 type Repr = S;
150 type Container<U: RawData, V: Dimension> = ArrayBase<U, V>;
151
152 seal!();
153}
154
155impl<A, S, D> Tensor<A, D> for ArrayBase<S, D>
156where
157 S: DataOwned<Elem = A>,
158 A: Scalar,
159 D: Dimension,
160{
161 fn from_shape_with_value<Sh>(shape: Sh, value: A) -> Self::Container<Self::Repr, D>
162 where
163 Self: Sized,
164 Sh: ndarray::ShapeBuilder<Dim = D>,
165 {
166 Self::Container::<S, D>::from_elem(shape, value)
167 }
168
169 fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<Self::Repr, D>
170 where
171 Self: Sized,
172 Sh: ShapeBuilder<Dim = D>,
173 F: FnMut(D::Pattern) -> A,
174 {
175 Self::Container::<S, D>::from_shape_fn(shape, f)
176 }
177
178 fn data(&self) -> &Self::Container<Self::Repr, D> {
179 self
180 }
181
182 fn data_mut(&mut self) -> &mut Self::Container<Self::Repr, D> {
183 self
184 }
185
186 fn dim(&self) -> D::Pattern {
187 self.dim()
188 }
189
190 fn raw_dim(&self) -> D {
191 self.raw_dim()
192 }
193
194 fn shape(&self) -> &[usize] {
195 self.shape()
196 }
197
198 fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
199 where
200 F: FnMut(A) -> B,
201 {
202 self.mapv(f)
203 }
204
205 fn apply_mut<F>(&mut self, f: F)
206 where
207 F: FnMut(A) -> A,
208 S: DataMut,
209 {
210 self.mapv_inplace(f)
211 }
212
213 fn iter(&self) -> ndarray::iter::Iter<'_, A, D> {
214 self.iter()
215 }
216 fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
217 where
218 S: DataMut,
219 {
220 self.iter_mut()
221 }
222 fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, A, D::Smaller>
223 where
224 D: RemoveAxis,
225 {
226 self.axis_iter(Axis(axis))
227 }
228}