concision_core/ops/
fill.rs

/*
    Appellation: fill <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use ndarray::prelude::*;
6use ndarray::{DataMut, RawData};
7
8/// This trait is used to fill an array with a value based on a mask.
9/// The mask is a boolean array of the same shape as the array.
10pub trait MaskFill<A, D>
11where
12    D: Dimension,
13{
14    type Output;
15
16    fn masked_fill(&self, mask: &Array<bool, D>, value: A) -> Self::Output;
17}
18
/// Reports whether the layout (dimensionality) of a tensor is square.
///
/// "Square" means that all of the tensor's axes share a single length.
pub trait IsSquare {
    /// Returns `true` if every dimension of the tensor has the same extent.
    fn is_square(&self) -> bool;
}
24
/*
 ******** implementations ********
*/
28
29impl<A, S, D> MaskFill<A, D> for ArrayBase<S, D>
30where
31    A: Clone,
32    D: Dimension,
33    S: DataMut<Elem = A>,
34    Self: Clone,
35{
36    type Output = ArrayBase<S, D>;
37
38    fn masked_fill(&self, mask: &Array<bool, D>, value: A) -> Self::Output {
39        let mut arr = self.clone();
40        arr.zip_mut_with(mask, |x, &m| {
41            if m {
42                *x = value.clone();
43            }
44        });
45        arr
46    }
47}
48
49impl<S, D> IsSquare for ArrayBase<S, D>
50where
51    D: Dimension,
52    S: RawData,
53{
54    fn is_square(&self) -> bool {
55        let first = self.shape().first().unwrap();
56        self.shape().iter().all(|x| x == first)
57    }
58}