mdarray/expr/sources.rs

1use core::fmt::{Debug, Formatter, Result};
2
3use crate::expr::expression::Expression;
4use crate::expr::iter::Iter;
5use crate::index::{Axis, Keep, Split};
6use crate::layout::Layout;
7use crate::mapping::Mapping;
8use crate::shape::{IntoShape, Shape};
9use crate::slice::Slice;
10use crate::view::{View, ViewMut};
11
12/// Array axis expression.
13pub struct AxisExpr<'a, T, S: Shape, L: Layout, A: Axis> {
14    slice: &'a Slice<T, S, L>,
15    axis: A,
16    mapping: <Keep<A, S, L> as Layout>::Mapping<(A::Dim<S>,)>,
17    offset: isize,
18}
19
20/// Mutable array axis expression.
21pub struct AxisExprMut<'a, T, S: Shape, L: Layout, A: Axis> {
22    slice: &'a mut Slice<T, S, L>,
23    axis: A,
24    mapping: <Keep<A, S, L> as Layout>::Mapping<(A::Dim<S>,)>,
25    offset: isize,
26}
27
/// Shapeless expression whose every element is a clone of one stored value.
#[derive(Clone)]
pub struct Fill<T> {
    // Value handed out (by clone) for each element.
    value: T,
}
33
/// Shapeless expression that produces each element by invoking a closure.
#[derive(Clone)]
pub struct FillWith<F> {
    // Closure invoked every time an element is requested.
    f: F,
}
39
/// Expression of a fixed shape in which every element is a clone of `elem`.
#[derive(Clone)]
pub struct FromElem<T, S> {
    // Shape reported by the expression.
    shape: S,
    // Value cloned for each element.
    elem: T,
}
46
47/// Expression with a defined shape and elements from the given function.
48#[derive(Clone)]
49pub struct FromFn<S: Shape, F> {
50    shape: S,
51    f: F,
52    index: S::Dims<usize>,
53}
54
55/// Array lanes expression.
56pub struct Lanes<'a, T, S: Shape, L: Layout, A: Axis> {
57    slice: &'a Slice<T, S, L>,
58    axis: A,
59    mapping: <Split<A, S, L> as Layout>::Mapping<A::Remove<S>>,
60    offset: isize,
61}
62
63/// Mutable array lanes expression.
64pub struct LanesMut<'a, T, S: Shape, L: Layout, A: Axis> {
65    slice: &'a mut Slice<T, S, L>,
66    axis: A,
67    mapping: <Split<A, S, L> as Layout>::Mapping<A::Remove<S>>,
68    offset: isize,
69}
70
71/// Creates an expression with elements by cloning `value`.
72///
73/// # Examples
74///
75/// ```
76/// use mdarray::{expr, tensor, view};
77///
78/// let mut t = tensor![0; 3];
79///
80/// t.assign(expr::fill(1));
81///
82/// assert_eq!(t, view![1; 3]);
83/// ```
84pub fn fill<T: Clone>(value: T) -> Fill<T> {
85    Fill::new(value)
86}
87
88/// Creates an expression with elements returned by calling a closure repeatedly.
89///
90/// # Examples
91///
92/// ```
93/// use mdarray::{expr, tensor, view};
94///
95/// let mut t = tensor![0; 3];
96///
97/// t.assign(expr::fill_with(|| 1));
98///
99/// assert_eq!(t, view![1; 3]);
100/// ```
101pub fn fill_with<T, F: FnMut() -> T>(f: F) -> FillWith<F> {
102    FillWith::new(f)
103}
104
105/// Creates an expression with the given shape and elements by cloning `value`.
106///
107/// # Examples
108///
109/// ```
110/// use mdarray::{view, expr, expr::Expression};
111///
112/// assert_eq!(expr::from_elem([2, 3], 1).eval(), view![[1; 3]; 2]);
113/// ```
114pub fn from_elem<T: Clone, I: IntoShape>(shape: I, elem: T) -> FromElem<T, I::IntoShape> {
115    FromElem::new(shape.into_shape(), elem)
116}
117
118/// Creates an expression with the given shape and elements from the given function.
119///
120/// # Examples
121///
122/// ```
123/// use mdarray::{expr, expr::Expression, view};
124///
125/// assert_eq!(expr::from_fn([2, 3], |i| 3 * i[0] + i[1] + 1).eval(), view![[1, 2, 3], [4, 5, 6]]);
126/// ```
127pub fn from_fn<T, I: IntoShape, F>(shape: I, f: F) -> FromFn<I::IntoShape, F>
128where
129    F: FnMut(&[usize]) -> T,
130{
131    FromFn::new(shape.into_shape(), f)
132}
133
134macro_rules! impl_axis_expr {
135    ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => {
136        impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> {
137            pub(crate) fn new(
138                slice: &'a $($mut)? Slice<T, S, L>,
139                axis: A,
140            ) -> Self {
141                let mapping = axis.get(slice.mapping());
142
143                Self { slice, axis, mapping, offset: 0 }
144            }
145        }
146
147        impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> {
148            fn fmt(&self, f: &mut Formatter<'_>) -> Result {
149                let index = self.axis.index(self.slice.rank());
150
151                f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish()
152            }
153        }
154
155        impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> {
156            type Shape = (A::Dim<S>,);
157
158            const IS_REPEATABLE: bool = $repeatable;
159
160            fn shape(&self) -> &Self::Shape {
161                self.mapping.shape()
162            }
163
164            unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item {
165                let offset = self.offset + self.mapping.inner_stride() * index as isize;
166
167                let mapping = self.axis.remove(self.slice.mapping());
168                let len = mapping.shape().checked_len().expect("invalid length");
169
170                // If the view is empty, we must not offset the pointer.
171                let count = if len == 0 { 0 } else { offset };
172
173                unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) }
174            }
175
176            fn inner_rank(&self) -> usize {
177                1
178            }
179
180            unsafe fn reset_dim(&mut self, _: usize, _: usize) {
181                self.offset = 0;
182            }
183
184            unsafe fn step_dim(&mut self, _: usize) {
185                self.offset += self.mapping.inner_stride();
186            }
187        }
188
189        impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> {
190            type Item = $expr<'a, T, A::Remove<S>, Split<A, S, L>>;
191            type IntoIter = Iter<Self>;
192
193            fn into_iter(self) -> Iter<Self> {
194                Iter::new(self)
195            }
196        }
197    };
198}
199
200impl_axis_expr!(AxisExpr, View, as_ptr, {}, true);
201impl_axis_expr!(AxisExprMut, ViewMut, as_mut_ptr, {mut}, false);
202
203impl<T, S: Shape, L: Layout, A: Axis> Clone for AxisExpr<'_, T, S, L, A> {
204    fn clone(&self) -> Self {
205        Self {
206            slice: self.slice,
207            axis: self.axis,
208            mapping: self.mapping.clone(),
209            offset: self.offset,
210        }
211    }
212
213    fn clone_from(&mut self, source: &Self) {
214        self.slice = source.slice;
215        self.axis = source.axis;
216        self.mapping.clone_from(&source.mapping);
217        self.offset = source.offset;
218    }
219}
220
221impl<T> Fill<T> {
222    pub(crate) fn new(value: T) -> Self {
223        Self { value }
224    }
225}
226
227impl<T: Debug> Debug for Fill<T> {
228    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
229        f.debug_tuple("Fill").field(&self.value).finish()
230    }
231}
232
233impl<T: Clone> Expression for Fill<T> {
234    type Shape = ();
235
236    const IS_REPEATABLE: bool = true;
237
238    fn shape(&self) -> &() {
239        &()
240    }
241
242    unsafe fn get_unchecked(&mut self, _: usize) -> T {
243        self.value.clone()
244    }
245
246    fn inner_rank(&self) -> usize {
247        usize::MAX
248    }
249
250    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
251    unsafe fn step_dim(&mut self, _: usize) {}
252}
253
254impl<T: Clone> IntoIterator for Fill<T> {
255    type Item = T;
256    type IntoIter = Iter<Self>;
257
258    fn into_iter(self) -> Iter<Self> {
259        Iter::new(self)
260    }
261}
262
263impl<F> FillWith<F> {
264    pub(crate) fn new(f: F) -> Self {
265        Self { f }
266    }
267}
268
269impl<T: Debug, F: FnMut() -> T> Debug for FillWith<F> {
270    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
271        f.debug_tuple("FillWith").finish()
272    }
273}
274
275impl<T, F: FnMut() -> T> Expression for FillWith<F> {
276    type Shape = ();
277
278    const IS_REPEATABLE: bool = true;
279
280    fn shape(&self) -> &() {
281        &()
282    }
283
284    unsafe fn get_unchecked(&mut self, _: usize) -> T {
285        (self.f)()
286    }
287
288    fn inner_rank(&self) -> usize {
289        usize::MAX
290    }
291
292    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
293    unsafe fn step_dim(&mut self, _: usize) {}
294}
295
296impl<T, F: FnMut() -> T> IntoIterator for FillWith<F> {
297    type Item = T;
298    type IntoIter = Iter<Self>;
299
300    fn into_iter(self) -> Iter<Self> {
301        Iter::new(self)
302    }
303}
304
305impl<T, S: Shape> FromElem<T, S> {
306    pub(crate) fn new(shape: S, elem: T) -> Self {
307        _ = shape.checked_len().expect("invalid length");
308
309        Self { shape, elem }
310    }
311}
312
313impl<T: Debug, S: Shape> Debug for FromElem<T, S> {
314    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
315        f.debug_tuple("FromElem").field(&self.shape).field(&self.elem).finish()
316    }
317}
318
319impl<T: Clone, S: Shape> Expression for FromElem<T, S> {
320    type Shape = S;
321
322    const IS_REPEATABLE: bool = true;
323
324    fn shape(&self) -> &S {
325        &self.shape
326    }
327
328    unsafe fn get_unchecked(&mut self, _: usize) -> T {
329        self.elem.clone()
330    }
331
332    fn inner_rank(&self) -> usize {
333        usize::MAX
334    }
335
336    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
337    unsafe fn step_dim(&mut self, _: usize) {}
338}
339
340impl<T: Clone, S: Shape> IntoIterator for FromElem<T, S> {
341    type Item = T;
342    type IntoIter = Iter<Self>;
343
344    fn into_iter(self) -> Iter<Self> {
345        Iter::new(self)
346    }
347}
348
349impl<S: Shape, F> FromFn<S, F> {
350    pub(crate) fn new(shape: S, f: F) -> Self {
351        _ = shape.checked_len().expect("invalid length");
352
353        Self { shape, f, index: S::Dims::default() }
354    }
355}
356
357impl<S: Shape, F> Debug for FromFn<S, F> {
358    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
359        f.debug_tuple("FromFn").field(&self.shape).finish()
360    }
361}
362
363impl<T, S: Shape, F: FnMut(&[usize]) -> T> Expression for FromFn<S, F> {
364    type Shape = S;
365
366    const IS_REPEATABLE: bool = true;
367
368    fn shape(&self) -> &S {
369        &self.shape
370    }
371
372    unsafe fn get_unchecked(&mut self, _: usize) -> T {
373        let value = (self.f)(self.index.as_ref());
374
375        // Increment the last dimension, which will be reset by reset_dim().
376        if self.rank() > 0 {
377            self.index.as_mut()[self.shape.rank() - 1] += 1;
378        }
379
380        value
381    }
382
383    fn inner_rank(&self) -> usize {
384        if self.shape.rank() > 0 { 1 } else { usize::MAX }
385    }
386
387    unsafe fn reset_dim(&mut self, index: usize, _: usize) {
388        self.index.as_mut()[index] = 0;
389    }
390
391    unsafe fn step_dim(&mut self, index: usize) {
392        // Don't increment the last dimension, since it is done in get_unchecked().
393        if index + 1 < self.rank() {
394            self.index.as_mut()[index] += 1;
395        }
396    }
397}
398
399impl<T, S: Shape, F: FnMut(&[usize]) -> T> IntoIterator for FromFn<S, F> {
400    type Item = T;
401    type IntoIter = Iter<Self>;
402
403    fn into_iter(self) -> Iter<Self> {
404        Iter::new(self)
405    }
406}
407
408macro_rules! impl_lanes {
409    ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => {
410        impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> {
411            pub(crate) fn new(
412                slice: &'a $($mut)? Slice<T, S, L>,
413                axis: A,
414            ) -> Self {
415                let mapping = axis.remove(slice.mapping());
416
417                // Ensure that the subarray is valid.
418                _ = mapping.shape().checked_len().expect("invalid length");
419
420                Self { slice, axis, mapping, offset: 0 }
421            }
422        }
423
424        impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> {
425            fn fmt(&self, f: &mut Formatter<'_>) -> Result {
426                let index = self.axis.index(self.slice.rank());
427
428                f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish()
429            }
430        }
431
432        impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> {
433            type Shape = A::Remove<S>;
434
435            const IS_REPEATABLE: bool = $repeatable;
436
437            fn shape(&self) -> &Self::Shape {
438                self.mapping.shape()
439            }
440
441            unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item {
442                let offset = self.mapping.inner_stride() * index as isize;
443                let mapping = self.axis.get(self.slice.mapping());
444
445                // If the view is empty, we must not offset the pointer.
446                let count = if mapping.is_empty() { 0 } else { offset };
447
448                unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) }
449            }
450
451            fn inner_rank(&self) -> usize {
452                if Split::<A, S, L>::IS_DENSE {
453                    // For static rank 0, the inner stride is 0 so we allow inner rank >0.
454                    if A::Remove::<S>::RANK == Some(0) { usize::MAX } else { self.mapping.rank() }
455                } else {
456                    // For rank 0, the inner stride is always 0 so we can allow inner rank >0.
457                    if self.mapping.rank() > 0 { 1 } else { usize::MAX }
458                }
459            }
460
461            unsafe fn reset_dim(&mut self, index: usize, count: usize) {
462                self.offset -= self.mapping.stride(index) * count as isize;
463            }
464
465            unsafe fn step_dim(&mut self, index: usize) {
466                self.offset += self.mapping.stride(index);
467            }
468        }
469
470        impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> {
471            type Item = $expr<'a, T, (A::Dim<S>,), Keep<A, S, L>>;
472            type IntoIter = Iter<Self>;
473
474            fn into_iter(self) -> Iter<Self> {
475                Iter::new(self)
476            }
477        }
478    };
479}
480
481impl_lanes!(Lanes, View, as_ptr, {}, true);
482impl_lanes!(LanesMut, ViewMut, as_mut_ptr, {mut}, false);
483
484impl<T, S: Shape, L: Layout, A: Axis> Clone for Lanes<'_, T, S, L, A> {
485    fn clone(&self) -> Self {
486        Self {
487            slice: self.slice,
488            axis: self.axis,
489            mapping: self.mapping.clone(),
490            offset: self.offset,
491        }
492    }
493
494    fn clone_from(&mut self, source: &Self) {
495        self.slice = source.slice;
496        self.axis = source.axis;
497        self.mapping.clone_from(&source.mapping);
498        self.offset = source.offset;
499    }
500}