// mdarray/expr/expression.rs

1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5
6#[cfg(not(feature = "nightly"))]
7use crate::allocator::Allocator;
8use crate::expr::adapters::{Cloned, Copied, Enumerate, Map, Zip};
9use crate::expr::iter::Iter;
10use crate::shape::Shape;
11use crate::tensor::Tensor;
12use crate::traits::IntoCloned;
13
14/// Trait for applying a closure and returning an existing array or an expression.
15pub trait Apply<T>: IntoExpression {
16    /// The resulting type after applying a closure.
17    type Output<F: FnMut(Self::Item) -> T>: IntoExpression<Item = T, Shape = Self::Shape>;
18
19    /// The resulting type after zipping elements and applying a closure.
20    type ZippedWith<I: IntoExpression, F>: IntoExpression<Item = T>
21    where
22        F: FnMut((Self::Item, I::Item)) -> T;
23
24    /// Returns the array or an expression with the given closure applied to each element.
25    fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F>;
26
27    /// Returns the array or an expression with the given closure applied to zipped element pairs.
28    fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I, F>
29    where
30        F: FnMut((Self::Item, I::Item)) -> T;
31}
32
33/// Expression trait, for multidimensional iteration.
34pub trait Expression: IntoIterator {
35    /// Array shape type.
36    type Shape: Shape;
37
38    /// True if the expression can be restarted from the beginning after the last element.
39    const IS_REPEATABLE: bool;
40
41    /// Returns the array shape.
42    fn shape(&self) -> &Self::Shape;
43
44    /// Creates an expression which clones all of its elements.
45    #[inline]
46    fn cloned<'a, T: 'a + Clone>(self) -> Cloned<Self>
47    where
48        Self: Expression<Item = &'a T> + Sized,
49    {
50        Cloned::new(self)
51    }
52
53    /// Creates an expression which copies all of its elements.
54    #[inline]
55    fn copied<'a, T: 'a + Copy>(self) -> Copied<Self>
56    where
57        Self: Expression<Item = &'a T> + Sized,
58    {
59        Copied::new(self)
60    }
61
62    /// Returns the number of elements in the specified dimension.
63    ///
64    /// # Panics
65    ///
66    /// Panics if the dimension is out of bounds.
67    #[inline]
68    fn dim(&self, index: usize) -> usize {
69        self.shape().dim(index)
70    }
71
72    /// Creates an expression which gives tuples of the current count and the element.
73    #[inline]
74    fn enumerate(self) -> Enumerate<Self>
75    where
76        Self: Sized,
77    {
78        Enumerate::new(self)
79    }
80
81    /// Determines if the elements of the expression are equal to those of another.
82    #[inline]
83    fn eq<I: IntoExpression>(self, other: I) -> bool
84    where
85        Self: Expression<Item: PartialEq<I::Item>> + Sized,
86    {
87        self.eq_by(other, |x, y| x == y)
88    }
89
90    /// Determines if the elements of the expression are equal to those of another
91    /// with respect to the specified equality function.
92    #[inline]
93    fn eq_by<I: IntoExpression, F>(self, other: I, mut eq: F) -> bool
94    where
95        Self: Sized,
96        F: FnMut(Self::Item, I::Item) -> bool,
97    {
98        let other = other.into_expr();
99
100        self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other))
101            && self.zip(other).into_iter().all(|(x, y)| eq(x, y))
102    }
103
104    /// Evaluates the expression into a new array.
105    ///
106    /// The resulting type is `Array` if the shape has constant-sized dimensions, or
107    /// otherwise `Tensor`. If the shape type is generic, `FromExpression::from_expr`
108    /// can be used to evaluate the expression into a specific array type.
109    #[inline]
110    fn eval(self) -> <Self::Shape as Shape>::Owned<Self::Item>
111    where
112        Self: Sized,
113    {
114        FromExpression::from_expr(self)
115    }
116
117    /// Evaluates the expression with broadcasting and appends to the given array
118    /// along the first dimension.
119    ///
120    /// If the array is empty, it is reshaped to match the shape of the expression.
121    ///
122    /// # Panics
123    ///
124    /// Panics if the inner dimensions do not match, if the rank is not the same and
125    /// at least 1, or if the first dimension is not dynamically-sized.
126    #[inline]
127    fn eval_into<S: Shape, A: Allocator>(
128        self,
129        tensor: &mut Tensor<Self::Item, S, A>,
130    ) -> &mut Tensor<Self::Item, S, A>
131    where
132        Self: Sized,
133    {
134        tensor.expand(self);
135        tensor
136    }
137
138    /// Folds all elements into an accumulator by applying an operation, and returns the result.
139    #[inline]
140    fn fold<T, F: FnMut(T, Self::Item) -> T>(self, init: T, f: F) -> T
141    where
142        Self: Sized,
143    {
144        Iter::new(self).fold(init, f)
145    }
146
147    /// Calls a closure on each element of the expression.
148    #[inline]
149    fn for_each<F: FnMut(Self::Item)>(self, mut f: F)
150    where
151        Self: Sized,
152    {
153        self.fold((), |(), x| f(x));
154    }
155
156    /// Returns `true` if the array contains no elements.
157    #[inline]
158    fn is_empty(&self) -> bool {
159        self.shape().is_empty()
160    }
161
162    /// Returns the number of elements in the array.
163    #[inline]
164    fn len(&self) -> usize {
165        self.shape().len()
166    }
167
168    /// Creates an expression that calls a closure on each element.
169    #[inline]
170    fn map<T, F: FnMut(Self::Item) -> T>(self, f: F) -> Map<Self, F>
171    where
172        Self: Sized,
173    {
174        Map::new(self, f)
175    }
176
177    /// Determines if the elements of the expression are not equal to those of another.
178    #[inline]
179    fn ne<I: IntoExpression>(self, other: I) -> bool
180    where
181        Self: Expression<Item: PartialEq<I::Item>> + Sized,
182    {
183        !self.eq(other)
184    }
185
186    /// Returns the array rank, i.e. the number of dimensions.
187    #[inline]
188    fn rank(&self) -> usize {
189        self.shape().rank()
190    }
191
192    /// Creates an expression that gives tuples `(x, y)` of the elements from each expression.
193    ///
194    /// # Panics
195    ///
196    /// Panics if the expressions cannot be broadcast to a common shape.
197    #[inline]
198    fn zip<I: IntoExpression>(self, other: I) -> Zip<Self, I::IntoExpr>
199    where
200        Self: Sized,
201    {
202        Zip::new(self, other.into_expr())
203    }
204
205    #[doc(hidden)]
206    unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item;
207
208    #[doc(hidden)]
209    fn inner_rank(&self) -> usize;
210
211    #[doc(hidden)]
212    unsafe fn reset_dim(&mut self, index: usize, count: usize);
213
214    #[doc(hidden)]
215    unsafe fn step_dim(&mut self, index: usize);
216
217    #[cfg(not(feature = "nightly"))]
218    #[doc(hidden)]
219    #[inline]
220    fn clone_into_vec<T>(self, vec: &mut Vec<T>)
221    where
222        Self: Expression<Item: IntoCloned<T>> + Sized,
223    {
224        assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity");
225
226        self.for_each(|x| unsafe {
227            vec.as_mut_ptr().add(vec.len()).write(x.into_cloned());
228            vec.set_len(vec.len() + 1);
229        });
230    }
231
232    #[cfg(feature = "nightly")]
233    #[doc(hidden)]
234    #[inline]
235    fn clone_into_vec<T, A: Allocator>(self, vec: &mut Vec<T, A>)
236    where
237        Self: Expression<Item: IntoCloned<T>> + Sized,
238    {
239        assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity");
240
241        self.for_each(|x| unsafe {
242            vec.as_mut_ptr().add(vec.len()).write(x.into_cloned());
243            vec.set_len(vec.len() + 1);
244        });
245    }
246}
247
248/// Conversion trait from an expression.
249pub trait FromExpression<T, S: Shape>: Sized {
250    /// Creates an array from an expression.
251    fn from_expr<I: IntoExpression<Item = T, Shape = S>>(expr: I) -> Self;
252}
253
254/// Conversion trait into an expression.
255pub trait IntoExpression: IntoIterator {
256    /// Array shape type.
257    type Shape: Shape;
258
259    /// Which kind of expression are we turning this into?
260    type IntoExpr: Expression<Item = Self::Item, Shape = Self::Shape>;
261
262    /// Creates an expression from a value.
263    fn into_expr(self) -> Self::IntoExpr;
264}
265
266impl<T, E: Expression> Apply<T> for E {
267    type Output<F: FnMut(Self::Item) -> T> = Map<E, F>;
268    type ZippedWith<I: IntoExpression, F: FnMut((Self::Item, I::Item)) -> T> =
269        Map<Zip<Self, I::IntoExpr>, F>;
270
271    #[inline]
272    fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F> {
273        self.map(f)
274    }
275
276    #[inline]
277    fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I, F>
278    where
279        F: FnMut((Self::Item, I::Item)) -> T,
280    {
281        self.zip(expr).map(f)
282    }
283}
284
285impl<E: Expression> IntoExpression for E {
286    type Shape = E::Shape;
287    type IntoExpr = E;
288
289    #[inline]
290    fn into_expr(self) -> Self {
291        self
292    }
293}