concision_core/traits/
tensor.rs

1/*
2    Appellation: tensor <module>
3    Contrib: @FL03
4*/
5use crate::traits::Scalar;
6use ndarray::{
7    ArrayBase, Axis, DataMut, DataOwned, Dimension, OwnedRepr, RawData, RemoveAxis, ShapeBuilder,
8};
9use num::Signed;
10use num_traits::{One, Zero};
11
12/// The [`RawTensor`] trait defines the base interface for all tensors,
13pub trait RawTensor<A, D> {
14    type Repr: RawData<Elem = A>;
15    type Container<U: RawData, V: Dimension>;
16
17    private!();
18}
19/// The [`Tensor`] trait extends the [`RawTensor`] trait to provide additional functionality
20/// for tensors, such as creating tensors from shapes, applying functions, and iterating over
21/// elements. It is generic over the element type `A` and the dimension type `D
22pub trait Tensor<A, D>: RawTensor<A, D>
23where
24    D: Dimension,
25{
26    /// Create a new tensor with the given shape and a function to fill it
27    fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<Self::Repr, D>
28    where
29        Sh: ShapeBuilder<Dim = D>,
30        F: FnMut(D::Pattern) -> A,
31        Self: Sized;
32    /// Create a new tensor with the given shape and value
33    fn from_shape_with_value<Sh>(shape: Sh, value: A) -> Self::Container<Self::Repr, D>
34    where
35        Sh: ShapeBuilder<Dim = D>,
36        Self: Sized;
37    /// Create a new tensor with the given shape and all values set to their default
38    fn default<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
39    where
40        Sh: ShapeBuilder<Dim = D>,
41        Self: Sized,
42        A: Default,
43    {
44        Self::from_shape_with_value(shape, A::default())
45    }
46    /// create a new tensor with the given shape and all values set to one
47    fn ones<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
48    where
49        Sh: ShapeBuilder<Dim = D>,
50        Self: Sized,
51        A: Clone + One,
52    {
53        Self::from_shape_with_value(shape, A::one())
54    }
55    /// create a new tensor with the given shape and all values set to zero
56    fn zeros<Sh>(shape: Sh) -> Self::Container<Self::Repr, D>
57    where
58        Sh: ShapeBuilder<Dim = D>,
59        Self: Sized,
60        A: Clone + Zero,
61    {
62        Self::from_shape_with_value(shape, <A>::zero())
63    }
64    /// returns a reference to the data of the object
65    fn data(&self) -> &Self::Container<Self::Repr, D>;
66    /// returns a mutable reference to the data of the object
67    fn data_mut(&mut self) -> &mut Self::Container<Self::Repr, D>;
68    /// returns the number of dimensions of the object
69    fn dim(&self) -> D::Pattern;
70    /// returns the shape of the object
71    fn raw_dim(&self) -> D;
72    /// returns the shape of the object
73    fn shape(&self) -> &[usize];
74    /// returns a new tensor with the same shape as the object and the given function applied
75    /// to each element
76    fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
77    where
78        F: FnMut(A) -> B;
79    /// returns a new tensor with the same shape as the object and the given function applied
80    fn apply_mut<F>(&mut self, f: F)
81    where
82        Self::Repr: DataMut,
83        F: FnMut(A) -> A;
84
85    fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, A, D::Smaller>
86    where
87        D: RemoveAxis;
88
89    fn iter(&self) -> ndarray::iter::Iter<'_, A, D>;
90
91    fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
92    where
93        Self::Repr: DataMut;
94
95    fn mean(&self) -> A
96    where
97        A: Scalar,
98    {
99        let sum = self.sum();
100        let count = self.iter().count();
101        sum / A::from_usize(count).unwrap()
102    }
103    /// sets the data of the object and returns a mutable reference to the object
104    fn set_data(&mut self, data: Self::Container<Self::Repr, D>) -> &mut Self {
105        *self.data_mut() = data;
106        self
107    }
108
109    fn sum(&self) -> A
110    where
111        A: Clone + core::iter::Sum,
112    {
113        self.iter().cloned().sum()
114    }
115
116    fn pow2(&self) -> Self::Container<OwnedRepr<A>, D>
117    where
118        A: Scalar,
119    {
120        let two = A::from_usize(2).unwrap();
121        self.apply(|x| x.pow(two))
122    }
123
124    fn abs(&self) -> Self::Container<OwnedRepr<A>, D>
125    where
126        A: Signed,
127    {
128        self.apply(|x| x.abs())
129    }
130
131    fn neg(&self) -> Self::Container<OwnedRepr<A>, D>
132    where
133        A: core::ops::Neg<Output = A>,
134    {
135        self.apply(|x| -x)
136    }
137}
138
139/*
140 ************* Implementations *************
141*/
142
143impl<A, S, D> RawTensor<A, D> for ArrayBase<S, D>
144where
145    S: RawData<Elem = A>,
146    A: Scalar,
147    D: Dimension,
148{
149    type Repr = S;
150    type Container<U: RawData, V: Dimension> = ArrayBase<U, V>;
151
152    seal!();
153}
154
155impl<A, S, D> Tensor<A, D> for ArrayBase<S, D>
156where
157    S: DataOwned<Elem = A>,
158    A: Scalar,
159    D: Dimension,
160{
161    fn from_shape_with_value<Sh>(shape: Sh, value: A) -> Self::Container<Self::Repr, D>
162    where
163        Self: Sized,
164        Sh: ndarray::ShapeBuilder<Dim = D>,
165    {
166        Self::Container::<S, D>::from_elem(shape, value)
167    }
168
169    fn from_shape_with_fn<Sh, F>(shape: Sh, f: F) -> Self::Container<Self::Repr, D>
170    where
171        Self: Sized,
172        Sh: ShapeBuilder<Dim = D>,
173        F: FnMut(D::Pattern) -> A,
174    {
175        Self::Container::<S, D>::from_shape_fn(shape, f)
176    }
177
178    fn data(&self) -> &Self::Container<Self::Repr, D> {
179        self
180    }
181
182    fn data_mut(&mut self) -> &mut Self::Container<Self::Repr, D> {
183        self
184    }
185
186    fn dim(&self) -> D::Pattern {
187        self.dim()
188    }
189
190    fn raw_dim(&self) -> D {
191        self.raw_dim()
192    }
193
194    fn shape(&self) -> &[usize] {
195        self.shape()
196    }
197
198    fn apply<F, B>(&self, f: F) -> Self::Container<OwnedRepr<B>, D>
199    where
200        F: FnMut(A) -> B,
201    {
202        self.mapv(f)
203    }
204
205    fn apply_mut<F>(&mut self, f: F)
206    where
207        F: FnMut(A) -> A,
208        S: DataMut,
209    {
210        self.mapv_inplace(f)
211    }
212
213    fn iter(&self) -> ndarray::iter::Iter<'_, A, D> {
214        self.iter()
215    }
216    fn iter_mut(&mut self) -> ndarray::iter::IterMut<'_, A, D>
217    where
218        S: DataMut,
219    {
220        self.iter_mut()
221    }
222    fn axis_iter(&self, axis: usize) -> ndarray::iter::AxisIter<'_, A, D::Smaller>
223    where
224        D: RemoveAxis,
225    {
226        self.axis_iter(Axis(axis))
227    }
228}