Skip to main content

mdarray/expr/
sources.rs

1use core::fmt::{Debug, Formatter, Result};
2
3use crate::dim::Dims;
4use crate::expr::expression::Expression;
5use crate::expr::iter::Iter;
6use crate::index::{Axis, Keep, Split};
7use crate::layout::Layout;
8use crate::mapping::Mapping;
9use crate::shape::{IntoShape, Shape};
10use crate::slice::Slice;
11use crate::view::{View, ViewMut};
12
13/// Array axis expression.
14pub struct AxisExpr<'a, T, S: Shape, L: Layout, A: Axis> {
15    slice: &'a Slice<T, S, L>,
16    axis: A,
17    mapping: <Keep<A, S, L> as Layout>::Mapping<(A::Dim<S>,)>,
18    offset: isize,
19}
20
21/// Mutable array axis expression.
22pub struct AxisExprMut<'a, T, S: Shape, L: Layout, A: Axis> {
23    slice: &'a mut Slice<T, S, L>,
24    axis: A,
25    mapping: <Keep<A, S, L> as Layout>::Mapping<(A::Dim<S>,)>,
26    offset: isize,
27}
28
/// Shapeless expression that produces a clone of one stored value for every element.
#[derive(Clone)]
pub struct Fill<T> {
    /// The value cloned for each element.
    value: T,
}

/// Shapeless expression that produces each element by invoking a stored closure.
#[derive(Clone)]
pub struct FillWith<F> {
    /// Closure invoked once per element.
    f: F,
}
40
41/// Expression with a defined shape that repeats an element by cloning.
42#[derive(Clone)]
43pub struct FromElem<T, S> {
44    shape: S,
45    elem: T,
46}
47
48/// Expression with a defined shape and elements from the given function.
49#[derive(Clone)]
50pub struct FromFn<S: Shape, F> {
51    shape: S,
52    f: F,
53    index: S::Dims<usize>,
54}
55
56/// Array lanes expression.
57pub struct Lanes<'a, T, S: Shape, L: Layout, A: Axis> {
58    slice: &'a Slice<T, S, L>,
59    axis: A,
60    mapping: <Split<A, S, L> as Layout>::Mapping<A::Remove<S>>,
61    offset: isize,
62}
63
64/// Mutable array lanes expression.
65pub struct LanesMut<'a, T, S: Shape, L: Layout, A: Axis> {
66    slice: &'a mut Slice<T, S, L>,
67    axis: A,
68    mapping: <Split<A, S, L> as Layout>::Mapping<A::Remove<S>>,
69    offset: isize,
70}
71
72/// Creates an expression with elements by cloning `value`.
73///
74/// # Examples
75///
76/// ```
77/// use mdarray::{array, expr, view};
78///
79/// let mut a = array![0; 3];
80///
81/// a.assign(expr::fill(1));
82///
83/// assert_eq!(a, view![1; 3]);
84/// ```
85#[inline]
86pub fn fill<T: Clone>(value: T) -> Fill<T> {
87    Fill::new(value)
88}
89
90/// Creates an expression with elements returned by calling a closure repeatedly.
91///
92/// # Examples
93///
94/// ```
95/// use mdarray::{array, expr, view};
96///
97/// let mut a = array![0; 3];
98///
99/// a.assign(expr::fill_with(|| 1));
100///
101/// assert_eq!(a, view![1; 3]);
102/// ```
103#[inline]
104pub fn fill_with<T, F: FnMut() -> T>(f: F) -> FillWith<F> {
105    FillWith::new(f)
106}
107
108/// Creates an expression with the given shape and elements by cloning `value`.
109///
110/// # Examples
111///
112/// ```
113/// use mdarray::{expr, expr::Expression, view};
114///
115/// assert_eq!(expr::from_elem([2, 3], 1).eval(), view![[1; 3]; 2]);
116/// ```
117#[inline]
118pub fn from_elem<T: Clone, I: IntoShape>(shape: I, elem: T) -> FromElem<T, I::IntoShape> {
119    FromElem::new(shape.into_shape(), elem)
120}
121
122/// Creates an expression with the given shape and elements from the given function.
123///
124/// # Examples
125///
126/// ```
127/// use mdarray::{expr, expr::Expression, view};
128///
129/// assert_eq!(expr::from_fn([2, 3], |i| 3 * i[0] + i[1] + 1).eval(), view![[1, 2, 3], [4, 5, 6]]);
130/// ```
131#[inline]
132pub fn from_fn<T, I: IntoShape, F>(shape: I, f: F) -> FromFn<I::IntoShape, F>
133where
134    F: FnMut(&[usize]) -> T,
135{
136    FromFn::new(shape.into_shape(), f)
137}
138
139macro_rules! impl_axis_expr {
140    ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => {
141        impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> {
142            #[inline]
143            pub(crate) fn new(
144                slice: &'a $($mut)? Slice<T, S, L>,
145                axis: A,
146            ) -> Self {
147                let mapping = axis.get(slice.mapping());
148
149                Self { slice, axis, mapping, offset: 0 }
150            }
151        }
152
153        impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> {
154            fn fmt(&self, f: &mut Formatter<'_>) -> Result {
155                let index = self.axis.index(self.slice.rank());
156
157                f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish()
158            }
159        }
160
161        impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> {
162            type Shape = (A::Dim<S>,);
163
164            const IS_REPEATABLE: bool = $repeatable;
165
166            #[inline]
167            fn shape(&self) -> &Self::Shape {
168                self.mapping.shape()
169            }
170
171            #[inline]
172            unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item {
173                let offset = self.offset + self.mapping.inner_stride() * index as isize;
174
175                let mapping = self.axis.remove(self.slice.mapping());
176                let len = mapping.shape().checked_len().expect("invalid length");
177
178                // If the view is empty, we must not offset the pointer.
179                let count = if len == 0 { 0 } else { offset };
180
181                unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) }
182            }
183
184            #[inline]
185            fn inner_rank(&self) -> usize {
186                1
187            }
188
189            #[inline]
190            unsafe fn reset_dim(&mut self, _: usize, _: usize) {
191                self.offset = 0;
192            }
193
194            #[inline]
195            unsafe fn step_dim(&mut self, _: usize) {
196                self.offset += self.mapping.inner_stride();
197            }
198        }
199
200        impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> {
201            type Item = $expr<'a, T, A::Remove<S>, Split<A, S, L>>;
202            type IntoIter = Iter<Self>;
203
204            #[inline]
205            fn into_iter(self) -> Iter<Self> {
206                Iter::new(self)
207            }
208        }
209    };
210}
211
212impl_axis_expr!(AxisExpr, View, as_ptr, {}, true);
213impl_axis_expr!(AxisExprMut, ViewMut, as_mut_ptr, {mut}, false);
214
215impl<T, S: Shape, L: Layout, A: Axis> Clone for AxisExpr<'_, T, S, L, A> {
216    #[inline]
217    fn clone(&self) -> Self {
218        Self {
219            slice: self.slice,
220            axis: self.axis,
221            mapping: self.mapping.clone(),
222            offset: self.offset,
223        }
224    }
225
226    #[inline]
227    fn clone_from(&mut self, source: &Self) {
228        self.slice = source.slice;
229        self.axis = source.axis;
230        self.mapping.clone_from(&source.mapping);
231        self.offset = source.offset;
232    }
233}
234
235impl<T> Fill<T> {
236    #[inline]
237    pub(crate) fn new(value: T) -> Self {
238        Self { value }
239    }
240}
241
242impl<T: Debug> Debug for Fill<T> {
243    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
244        f.debug_tuple("Fill").field(&self.value).finish()
245    }
246}
247
248impl<T: Clone> Expression for Fill<T> {
249    type Shape = ();
250
251    const IS_REPEATABLE: bool = true;
252
253    #[inline]
254    fn shape(&self) -> &() {
255        &()
256    }
257
258    #[inline]
259    unsafe fn get_unchecked(&mut self, _: usize) -> T {
260        self.value.clone()
261    }
262
263    #[inline]
264    fn inner_rank(&self) -> usize {
265        usize::MAX
266    }
267
268    #[inline]
269    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
270
271    #[inline]
272    unsafe fn step_dim(&mut self, _: usize) {}
273}
274
275impl<T: Clone> IntoIterator for Fill<T> {
276    type Item = T;
277    type IntoIter = Iter<Self>;
278
279    #[inline]
280    fn into_iter(self) -> Iter<Self> {
281        Iter::new(self)
282    }
283}
284
285impl<F> FillWith<F> {
286    #[inline]
287    pub(crate) fn new(f: F) -> Self {
288        Self { f }
289    }
290}
291
292impl<T: Debug, F: FnMut() -> T> Debug for FillWith<F> {
293    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
294        f.debug_tuple("FillWith").finish()
295    }
296}
297
298impl<T, F: FnMut() -> T> Expression for FillWith<F> {
299    type Shape = ();
300
301    const IS_REPEATABLE: bool = true;
302
303    #[inline]
304    fn shape(&self) -> &() {
305        &()
306    }
307
308    #[inline]
309    unsafe fn get_unchecked(&mut self, _: usize) -> T {
310        (self.f)()
311    }
312
313    #[inline]
314    fn inner_rank(&self) -> usize {
315        usize::MAX
316    }
317
318    #[inline]
319    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
320
321    #[inline]
322    unsafe fn step_dim(&mut self, _: usize) {}
323}
324
325impl<T, F: FnMut() -> T> IntoIterator for FillWith<F> {
326    type Item = T;
327    type IntoIter = Iter<Self>;
328
329    #[inline]
330    fn into_iter(self) -> Iter<Self> {
331        Iter::new(self)
332    }
333}
334
335impl<T, S: Shape> FromElem<T, S> {
336    #[inline]
337    pub(crate) fn new(shape: S, elem: T) -> Self {
338        _ = shape.checked_len().expect("invalid length");
339
340        Self { shape, elem }
341    }
342}
343
344impl<T: Debug, S: Shape> Debug for FromElem<T, S> {
345    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
346        f.debug_tuple("FromElem").field(&self.shape).field(&self.elem).finish()
347    }
348}
349
350impl<T: Clone, S: Shape> Expression for FromElem<T, S> {
351    type Shape = S;
352
353    const IS_REPEATABLE: bool = true;
354
355    #[inline]
356    fn shape(&self) -> &S {
357        &self.shape
358    }
359
360    #[inline]
361    unsafe fn get_unchecked(&mut self, _: usize) -> T {
362        self.elem.clone()
363    }
364
365    #[inline]
366    fn inner_rank(&self) -> usize {
367        usize::MAX
368    }
369
370    #[inline]
371    unsafe fn reset_dim(&mut self, _: usize, _: usize) {}
372
373    #[inline]
374    unsafe fn step_dim(&mut self, _: usize) {}
375}
376
377impl<T: Clone, S: Shape> IntoIterator for FromElem<T, S> {
378    type Item = T;
379    type IntoIter = Iter<Self>;
380
381    #[inline]
382    fn into_iter(self) -> Iter<Self> {
383        Iter::new(self)
384    }
385}
386
387impl<S: Shape, F> FromFn<S, F> {
388    #[inline]
389    pub(crate) fn new(shape: S, f: F) -> Self {
390        _ = shape.checked_len().expect("invalid length");
391
392        // Initialize the index buffer with the correct rank for this shape.
393        //
394        // For static-rank shapes, use the compile-time rank if available.
395        // For dynamic-rank shapes like `DynRank`, fall back to the runtime
396        // dimension length obtained via `with_dims`.
397        let rank = S::RANK.unwrap_or_else(|| shape.with_dims(|dims| dims.len()));
398        let index = Dims::new(rank);
399
400        Self { shape, f, index }
401    }
402}
403
404impl<S: Shape, F> Debug for FromFn<S, F> {
405    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
406        f.debug_tuple("FromFn").field(&self.shape).finish()
407    }
408}
409
410impl<T, S: Shape, F: FnMut(&[usize]) -> T> Expression for FromFn<S, F> {
411    type Shape = S;
412
413    const IS_REPEATABLE: bool = true;
414
415    #[inline]
416    fn shape(&self) -> &S {
417        &self.shape
418    }
419
420    #[inline]
421    unsafe fn get_unchecked(&mut self, _: usize) -> T {
422        let value = (self.f)(self.index.as_ref());
423
424        // Increment the last dimension, which will be reset by reset_dim().
425        if self.rank() > 0 {
426            self.index.as_mut()[self.shape.rank() - 1] += 1;
427        }
428
429        value
430    }
431
432    #[inline]
433    fn inner_rank(&self) -> usize {
434        if self.shape.rank() > 0 { 1 } else { usize::MAX }
435    }
436
437    #[inline]
438    unsafe fn reset_dim(&mut self, index: usize, _: usize) {
439        self.index.as_mut()[index] = 0;
440    }
441
442    #[inline]
443    unsafe fn step_dim(&mut self, index: usize) {
444        // Don't increment the last dimension, since it is done in get_unchecked().
445        if index + 1 < self.rank() {
446            self.index.as_mut()[index] += 1;
447        }
448    }
449}
450
451impl<T, S: Shape, F: FnMut(&[usize]) -> T> IntoIterator for FromFn<S, F> {
452    type Item = T;
453    type IntoIter = Iter<Self>;
454
455    #[inline]
456    fn into_iter(self) -> Iter<Self> {
457        Iter::new(self)
458    }
459}
460
461macro_rules! impl_lanes {
462    ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => {
463        impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> {
464            #[inline]
465            pub(crate) fn new(
466                slice: &'a $($mut)? Slice<T, S, L>,
467                axis: A,
468            ) -> Self {
469                let mapping = axis.remove(slice.mapping());
470
471                // Ensure that the subarray is valid.
472                _ = mapping.shape().checked_len().expect("invalid length");
473
474                Self { slice, axis, mapping, offset: 0 }
475            }
476        }
477
478        impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> {
479            fn fmt(&self, f: &mut Formatter<'_>) -> Result {
480                let index = self.axis.index(self.slice.rank());
481
482                f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish()
483            }
484        }
485
486        impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> {
487            type Shape = A::Remove<S>;
488
489            const IS_REPEATABLE: bool = $repeatable;
490
491            #[inline]
492            fn shape(&self) -> &Self::Shape {
493                self.mapping.shape()
494            }
495
496            #[inline]
497            unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item {
498                let offset = self.mapping.inner_stride() * index as isize;
499                let mapping = self.axis.get(self.slice.mapping());
500
501                // If the view is empty, we must not offset the pointer.
502                let count = if mapping.is_empty() { 0 } else { offset };
503
504                unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) }
505            }
506
507            #[inline]
508            fn inner_rank(&self) -> usize {
509                if Split::<A, S, L>::IS_DENSE {
510                    // For static rank 0, the inner stride is 0 so we allow inner rank >0.
511                    if A::Remove::<S>::RANK == Some(0) { usize::MAX } else { self.mapping.rank() }
512                } else {
513                    // For rank 0, the inner stride is always 0 so we can allow inner rank >0.
514                    if self.mapping.rank() > 0 { 1 } else { usize::MAX }
515                }
516            }
517
518            #[inline]
519            unsafe fn reset_dim(&mut self, index: usize, count: usize) {
520                self.offset -= self.mapping.stride(index) * count as isize;
521            }
522
523            #[inline]
524            unsafe fn step_dim(&mut self, index: usize) {
525                self.offset += self.mapping.stride(index);
526            }
527        }
528
529        impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> {
530            type Item = $expr<'a, T, (A::Dim<S>,), Keep<A, S, L>>;
531            type IntoIter = Iter<Self>;
532
533            #[inline]
534            fn into_iter(self) -> Iter<Self> {
535                Iter::new(self)
536            }
537        }
538    };
539}
540
541impl_lanes!(Lanes, View, as_ptr, {}, true);
542impl_lanes!(LanesMut, ViewMut, as_mut_ptr, {mut}, false);
543
544impl<T, S: Shape, L: Layout, A: Axis> Clone for Lanes<'_, T, S, L, A> {
545    #[inline]
546    fn clone(&self) -> Self {
547        Self {
548            slice: self.slice,
549            axis: self.axis,
550            mapping: self.mapping.clone(),
551            offset: self.offset,
552        }
553    }
554
555    #[inline]
556    fn clone_from(&mut self, source: &Self) {
557        self.slice = source.slice;
558        self.axis = source.axis;
559        self.mapping.clone_from(&source.mapping);
560        self.offset = source.offset;
561    }
562}