// eryon_surface/ops/fill.rs
use ndarray::prelude::*;
use ndarray::{DataMut, RawData};

8pub trait MaskFill<A, D>
11where
12 D: Dimension,
13{
14 type Output;
15
16 fn masked_fill(&self, mask: &Array<bool, D>, value: A) -> Self::Output;
17}
18
19pub trait IntoAxis {
20 fn into_axis(self) -> Axis;
21}
22
/// Query for whether a shaped value is "square", i.e. has the same extent along every axis.
pub trait IsSquare {
    /// Return `true` when all of the value's axes have the same length.
    fn is_square(&self) -> bool;
}

27impl<A, S, D> MaskFill<A, D> for ArrayBase<S, D>
32where
33 A: Clone,
34 D: Dimension,
35 S: DataMut<Elem = A>,
36 Self: Clone,
37{
38 type Output = ArrayBase<S, D>;
39
40 fn masked_fill(&self, mask: &Array<bool, D>, value: A) -> Self::Output {
41 let mut arr = self.clone();
42 arr.zip_mut_with(&mask, |x, &m| {
43 if m {
44 *x = value.clone();
45 }
46 });
47 arr
48 }
49}
50
51impl<S> IntoAxis for S
52where
53 S: AsRef<usize>,
54{
55 fn into_axis(self) -> Axis {
56 Axis(*self.as_ref())
57 }
58}
59
60impl<S, D> IsSquare for ArrayBase<S, D>
61where
62 D: Dimension,
63 S: RawData,
64{
65 fn is_square(&self) -> bool {
66 let first = self.shape().first().unwrap();
67 self.shape().iter().all(|x| x == first)
68 }
69}