concision_core/traits/
tensor.rs

/*
    Appellation: tensor <module>
    Contrib: @FL03
*/
5use crate::traits::Scalar;
6use ndarray::{
7    ArrayBase, Axis, DataMut, DataOwned, Dimension, OwnedRepr, RawData, RemoveAxis, ShapeBuilder,
8};
9use num::Signed;
10use num_traits::{FromPrimitive, One, Pow, Zero};
11
12pub trait Tensor<S, D>
13where
14    S: RawData<Elem = Self::Scalar>,
15    D: Dimension,
16{
17    type Scalar;
18    type Container<U: RawData, V: Dimension>;
19    /// Create a new tensor with the given shape and a function to fill it
20    fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<S, D>
21    where
22        Sh: ShapeBuilder<Dim = D>,
23        F: FnMut(D::Pattern) -> Self::Scalar,
24        Self: Sized;
25    /// Create a new tensor with the given shape and value
26    fn from_shape_with_value<Sh>(shape: Sh, value: Self::Scalar) -> Self::Container<S, D>
27    where
28        Sh: ShapeBuilder<Dim = D>,
29        Self: Sized;
30    /// Create a new tensor with the given shape and all values set to their default
31    fn default<Sh>(shape: Sh) -> Self::Container<S, D>
32    where
33        Sh: ShapeBuilder<Dim = D>,
34        Self: Sized,
35        Self::Scalar: Default,
36    {
37        Self::from_shape_with_value(shape, Self::Scalar::default())
38    }
39    /// create a new tensor with the given shape and all values set to one
40    fn ones<Sh>(shape: Sh) -> Self::Container<S, D>
41    where
42        Sh: ShapeBuilder<Dim = D>,
43        Self: Sized,
44        Self::Scalar: Clone + One,
45    {
46        Self::from_shape_with_value(shape, Self::Scalar::one())
47    }
48    /// create a new tensor with the given shape and all values set to zero
49    fn zeros<Sh>(shape: Sh) -> Self::Container<S, D>
50    where
51        Sh: ShapeBuilder<Dim = D>,
52        Self: Sized,
53        Self::Scalar: Clone + Zero,
54    {
55        Self::from_shape_with_value(shape, <Self as Tensor<S, D>>::Scalar::zero())
56    }
57    /// returns a reference to the data of the object
58    fn data(&self) -> &Self::Container<S, D>;
59    /// returns a mutable reference to the data of the object
60    fn data_mut(&mut self) -> &mut Self::Container<S, D>;
61    /// returns the number of dimensions of the object
62    fn dim(&self) -> D::Pattern;
63    /// returns the shape of the object
64    fn raw_dim(&self) -> D;
65    /// returns the shape of the object
66    fn shape(&self) -> &[usize];
67    /// sets the data of the object
68    fn set_data(&mut self, data: Self::Container<S, D>) -> &mut Self {
69        *self.data_mut() = data;
70        self
71    }
72    /// returns a new tensor with the same shape as the object and the given function applied
73    /// to each element
74    fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
75    where
76        F: FnMut(Self::Scalar) -> B;
77    /// returns a new tensor with the same shape as the object and the given function applied
78    fn apply_mut<F>(&mut self, f: F)
79    where
80        S: DataMut,
81        F: FnMut(Self::Scalar) -> Self::Scalar;
82
83    fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, Self::Scalar, D::Smaller>
84    where
85        D: RemoveAxis;
86
87    fn iter(&self) -> ndarray::iter::Iter<'_, Self::Scalar, D>;
88
89    fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, Self::Scalar, D>
90    where
91        S: DataMut;
92
93    fn mean(&self) -> Self::Scalar
94    where
95        Self::Scalar: Scalar,
96    {
97        let sum = self.sum();
98        let count = self.iter().count();
99        sum / Self::Scalar::from_usize(count).unwrap()
100    }
101
102    fn sum(&self) -> Self::Scalar
103    where
104        Self::Scalar: Clone + core::iter::Sum,
105    {
106        self.iter().cloned().sum()
107    }
108
109    fn pow2(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
110    where
111        Self::Scalar: Scalar,
112    {
113        let two = Self::Scalar::from_usize(2).unwrap();
114        self.apply(|x| x.pow(two))
115    }
116
117    fn abs(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
118    where
119        Self::Scalar: Signed,
120    {
121        self.apply(|x| x.abs())
122    }
123
124    fn neg(&self) -> Self::Container<OwnedRepr<Self::Scalar>, D>
125    where
126        Self::Scalar: core::ops::Neg<Output = Self::Scalar>,
127    {
128        self.apply(|x| -x)
129    }
130}
131
132impl<A, S, D> Tensor<S, D> for ArrayBase<S, D>
133where
134    S: DataOwned<Elem = A>,
135    A: Scalar,
136    D: Dimension,
137{
138    type Scalar = A;
139    type Container<U: RawData, V: Dimension> = ArrayBase<U, V>;
140
141    fn from_shape_with_value<Sh>(shape: Sh, value: Self::Scalar) -> Self::Container<S, D>
142    where
143        Self: Sized,
144        Sh: ndarray::ShapeBuilder<Dim = D>,
145    {
146        Self::Container::<S, D>::from_elem(shape, value)
147    }
148
149    fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<S, D>
150    where
151        Self: Sized,
152        Sh: ShapeBuilder<Dim = D>,
153        F: FnMut(D::Pattern) -> Self::Scalar,
154    {
155        Self::Container::<S, D>::from_shape_fn(shape, f)
156    }
157
158    fn data(&self) -> &Self::Container<S, D> {
159        self
160    }
161
162    fn data_mut(&mut self) -> &mut Self::Container<S, D> {
163        self
164    }
165
166    fn dim(&self) -> D::Pattern {
167        self.dim()
168    }
169
170    fn raw_dim(&self) -> D {
171        self.raw_dim()
172    }
173
174    fn shape(&self) -> &[usize] {
175        self.shape()
176    }
177
178    fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
179    where
180        F: FnMut(Self::Scalar) -> B,
181    {
182        self.mapv(f)
183    }
184
185    fn apply_mut<F>(&mut self, f: F)
186    where
187        F: FnMut(Self::Scalar) -> Self::Scalar,
188        S: DataMut,
189    {
190        self.mapv_inplace(f)
191    }
192
193    fn iter(&self) -> ndarray::iter::Iter<'_, Self::Scalar, D> {
194        self.iter()
195    }
196    fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, Self::Scalar, D>
197    where
198        S: DataMut,
199    {
200        self.iter_mut()
201    }
202    fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, Self::Scalar, D::Smaller>
203    where
204        D: RemoveAxis,
205    {
206        self.axis_iter(Axis(axis))
207    }
208}