1use crate::traits::Scalar;
6use ndarray::{
7 ArrayBase, Axis, DataMut, DataOwned, Dimension, OwnedRepr, RawData, RemoveAxis, ShapeBuilder,
8};
9use num::Signed;
10use num_traits::{FromPrimitive, One, Pow, Zero};
11
12pub trait Tensor<S, D>
13where
14 S: RawData<Elem = Self::Scalar>,
15 D: Dimension,
16{
17 type Scalar;
18 type Container<U: RawData, V: Dimension>;
19 fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<S, D>
21 where
22 Sh: ShapeBuilder<Dim = D>,
23 F: FnMut(D::Pattern) -> Self::Scalar,
24 Self: Sized;
25 fn from_shape_with_value<Sh>(shape: Sh, value: Self::Scalar) -> Self::Container<S, D>
27 where
28 Sh: ShapeBuilder<Dim = D>,
29 Self: Sized;
30 fn default<Sh>(shape: Sh) -> Self::Container<S, D>
32 where
33 Sh: ShapeBuilder<Dim = D>,
34 Self: Sized,
35 Self::Scalar: Default,
36 {
37 Self::from_shape_with_value(shape, Self::Scalar::default())
38 }
39 fn ones<Sh>(shape: Sh) -> Self::Container<S, D>
41 where
42 Sh: ShapeBuilder<Dim = D>,
43 Self: Sized,
44 Self::Scalar: Clone + One,
45 {
46 Self::from_shape_with_value(shape, Self::Scalar::one())
47 }
48 fn zeros<Sh>(shape: Sh) -> Self::Container<S, D>
50 where
51 Sh: ShapeBuilder<Dim = D>,
52 Self: Sized,
53 Self::Scalar: Clone + Zero,
54 {
55 Self::from_shape_with_value(shape, <Self as Tensor<S, D>>::Scalar::zero())
56 }
57 fn data(&self) -> &Self::Container<S, D>;
59 fn data_mut(&mut self) -> &mut Self::Container<S, D>;
61 fn dim(&self) -> D::Pattern;
63 fn raw_dim(&self) -> D;
65 fn shape(&self) -> &[usize];
67 fn set_data(&mut self, data: Self::Container<S, D>) -> &mut Self {
69 *self.data_mut() = data;
70 self
71 }
72 fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
75 where
76 F: FnMut(Self::Scalar) -> B;
77 fn apply_mut<F>(&mut self, f: F)
79 where
80 S: DataMut,
81 F: FnMut(Self::Scalar) -> Self::Scalar;
82
83 fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, Self::Scalar, D::Smaller>
84 where
85 D: RemoveAxis;
86
87 fn iter(&self) -> ndarray::iter::Iter<'_, Self::Scalar, D>;
88
89 fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, Self::Scalar, D>
90 where
91 S: DataMut;
92
93 fn mean(&self) -> Self::Scalar
94 where
95 Self::Scalar: Scalar,
96 {
97 let sum = self.sum();
98 let count = self.iter().count();
99 sum / Self::Scalar::from_usize(count).unwrap()
100 }
101
102 fn sum(&self) -> Self::Scalar
103 where
104 Self::Scalar: Clone + core::iter::Sum,
105 {
106 self.iter().cloned().sum()
107 }
108
109 fn pow2(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
110 where
111 Self::Scalar: Scalar,
112 {
113 let two = Self::Scalar::from_usize(2).unwrap();
114 self.apply(|x| x.pow(two))
115 }
116
117 fn abs(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
118 where
119 Self::Scalar: Signed,
120 {
121 self.apply(|x| x.abs())
122 }
123
124 fn neg(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
125 where
126 Self::Scalar: core::ops::Neg<Output = Self::Scalar>,
127 {
128 self.apply(|x| -x)
129 }
130}
131
132impl<A, S, D> Tensor<S, D> for ArrayBase<S, D>
133where
134 S: DataOwned<Elem = A>,
135 A: Scalar,
136 D: Dimension,
137{
138 type Scalar = A;
139 type Container<U: RawData, V: Dimension> = ArrayBase<U, V>;
140
141 fn from_shape_with_value<Sh>(shape: Sh, value: Self::Scalar) -> Self::Container<S, D>
142 where
143 Self: Sized,
144 Sh: ndarray::ShapeBuilder<Dim = D>,
145 {
146 Self::Container::<S, D>::from_elem(shape, value)
147 }
148
149 fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<S, D>
150 where
151 Self: Sized,
152 Sh: ShapeBuilder<Dim = D>,
153 F: FnMut(D::Pattern) -> Self::Scalar,
154 {
155 Self::Container::<S, D>::from_shape_fn(shape, f)
156 }
157
158 fn data(&self) -> &Self::Container<S, D> {
159 self
160 }
161
162 fn data_mut(&mut self) -> &mut Self::Container<S, D> {
163 self
164 }
165
166 fn dim(&self) -> D::Pattern {
167 self.dim()
168 }
169
170 fn raw_dim(&self) -> D {
171 self.raw_dim()
172 }
173
174 fn shape(&self) -> &[usize] {
175 self.shape()
176 }
177
178 fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
179 where
180 F: FnMut(Self::Scalar) -> B,
181 {
182 self.mapv(f)
183 }
184
185 fn apply_mut<F>(&mut self, f: F)
186 where
187 F: FnMut(Self::Scalar) -> Self::Scalar,
188 S: DataMut,
189 {
190 self.mapv_inplace(f)
191 }
192
193 fn iter(&self) -> ndarray::iter::Iter<'_, Self::Scalar, D> {
194 self.iter()
195 }
196 fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, Self::Scalar, D>
197 where
198 S: DataMut,
199 {
200 self.iter_mut()
201 }
202 fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, Self::Scalar, D::Smaller>
203 where
204 D: RemoveAxis,
205 {
206 self.axis_iter(Axis(axis))
207 }
208}