erased_cells/masked/
masked_buffer.rs

1use std::fmt::{Debug, Formatter};
2
3pub use self::ops::*;
4use crate::masked::nodata::IsNodata;
5use crate::{BufferOps, CellBuffer, CellEncoding, CellType, CellValue, Mask, NoData};
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9/// A [`CellBuffer`] with a companion [`Mask`].
10///
11/// The `Mask` tracks which cells are valid across operations, and which should be
12/// treated as "no-data" values.
13///
14/// # Example
15///
16/// ```rust
17/// use erased_cells::{BufferOps, Mask, MaskedCellBuffer};
18/// // Fill a buffer with the `u16` numbers `0..=3` and mask [true, false, true, false].
19/// let buf = MaskedCellBuffer::fill_with_mask_via(4, |i| (i as f64, i % 2 == 0));
20/// assert_eq!(buf.mask(), &Mask::new(vec![true, false, true, false]));
21/// // We can count the data/no-data values
22/// assert_eq!(buf.counts(), (2, 2));
23///
24/// // Mask values are propagated across math operations.
25/// let ones = MaskedCellBuffer::from_vec(vec![1.0; 4]);
26/// let r = (buf + ones) * 2.0;
27///
28/// let expected = MaskedCellBuffer::new(
29///     vec![
30///         (0.0 + 1.0) * 2.0,
31///         (1.0 + 1.0) * 2.0,
32///         (2.0 + 1.0) * 2.0,
33///         (3.0 + 1.0) * 2.0,
34///     ]
35///     .into(),
36///     Mask::new(vec![true, false, true, false]),
37/// );
38/// assert_eq!(r, expected);
39/// ```
40#[derive(Clone, PartialEq, PartialOrd)]
41#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
42pub struct MaskedCellBuffer(CellBuffer, Mask);
43
44impl MaskedCellBuffer {
45    /// Create a new combined [`CellBuffer`] and [`Mask`].
46    ///
47    /// # Panics
48    /// Will panics if `buffer` and `mask` are not the same length.
49    pub fn new(buffer: CellBuffer, mask: Mask) -> Self {
50        assert_eq!(
51            buffer.len(),
52            mask.len(),
53            "Mask and buffer must have the same length."
54        );
55        Self(buffer, mask)
56    }
57
58    /// Constructs a `MaskedCellBuffer` from a `Vec<CellEncoding>`, specifying a `NoData<T>` value.
59    ///
60    /// Mask value will be `false` when associated cell matches `nodata`.
61    ///
62    /// Use [`Self::from_vec`]
63    pub fn from_vec_with_nodata<T: CellEncoding>(data: Vec<T>, nodata: NoData<T>) -> Self {
64        let mut mask = Mask::fill(data.len(), true);
65        let buf = CellBuffer::from_vec(data);
66
67        buf.into_iter().zip(mask.iter_mut()).for_each(|(v, m)| {
68            *m = !v.is(nodata);
69        });
70
71        Self::new(buf, mask)
72    }
73
74    pub fn fill_with_mask_via<T, F>(len: usize, mv: F) -> Self
75    where
76        T: CellEncoding,
77        F: Fn(usize) -> (T, bool),
78    {
79        (0..len).map(mv).collect()
80    }
81
82    pub fn buffer(&self) -> &CellBuffer {
83        &self.0
84    }
85
86    pub fn buffer_mut(&mut self) -> &mut CellBuffer {
87        &mut self.0
88    }
89
90    pub fn mask(&self) -> &Mask {
91        &self.1
92    }
93
94    pub fn mask_mut(&mut self) -> &mut Mask {
95        &mut self.1
96    }
97
98    /// Get a buffer value at position `index` with mask evaluated.
99    ///
100    /// Returns `Some(CellValue)` if mask at `index` is `true`, `None` otherwise.
101    pub fn get_masked(&self, index: usize) -> Option<CellValue> {
102        if self.mask().get(index) {
103            Some(self.buffer().get(index))
104        } else {
105            None
106        }
107    }
108
109    /// Get the cell value and mask value at position `index`.
110    ///
111    /// Returns `(CellValue, bool)`. If `bool` is `false`, associated
112    /// `CellValue` should be considered invalid.
113    pub fn get_with_mask(&self, index: usize) -> (CellValue, bool) {
114        (self.buffer().get(index), self.mask().get(index))
115    }
116
117    /// Set the `value` and `mask` at position `index`.
118    ///
119    /// Returns `Err(NarrowingError)` if `value` cannot be converted to
120    /// `self.cell_type()` without data loss (e.g. overflow).
121    pub fn put_with_mask(
122        &mut self,
123        index: usize,
124        value: CellValue,
125        mask: bool,
126    ) -> crate::error::Result<()> {
127        self.put(index, value)?;
128        self.mask_mut().put(index, mask);
129        Ok(())
130    }
131
132    /// Returns a tuple of representing counts of `(data, nodata)`.
133    pub fn counts(&self) -> (usize, usize) {
134        self.mask().counts()
135    }
136
137    /// Convert `self` into a `Vec<T>`, replacing values where the mask is `0` to `no_data.value()`
138    pub fn to_vec_with_nodata<T: CellEncoding>(
139        self,
140        no_data: NoData<T>,
141    ) -> crate::error::Result<Vec<T>> {
142        let Self(buf, mask) = self;
143        let out = buf.to_vec::<T>()?;
144        if let Some(no_data) = no_data.value() {
145            Ok(out
146                .into_iter()
147                .zip(mask)
148                .map(|(v, m)| if m { v } else { no_data })
149                .collect())
150        } else {
151            Ok(out)
152        }
153    }
154}
155
156impl BufferOps for MaskedCellBuffer {
157    fn from_vec<T: CellEncoding>(data: Vec<T>) -> Self {
158        let buffer = CellBuffer::from_vec(data);
159        let mask = Mask::fill(buffer.len(), true);
160        Self::new(buffer, mask)
161    }
162
163    fn with_defaults(len: usize, ct: CellType) -> Self {
164        let buffer = CellBuffer::with_defaults(len, ct);
165        let mask = Mask::fill(len, true);
166        Self::new(buffer, mask)
167    }
168
169    fn fill(len: usize, value: CellValue) -> Self {
170        let buffer = CellBuffer::fill(len, value);
171        let mask = Mask::fill(len, true);
172        Self::new(buffer, mask)
173    }
174
175    fn fill_via<T, F>(len: usize, f: F) -> Self
176    where
177        T: CellEncoding,
178        F: Fn(usize) -> T,
179    {
180        let buffer = CellBuffer::fill_via(len, f);
181        let mask = Mask::fill(len, true);
182        Self::new(buffer, mask)
183    }
184
185    fn len(&self) -> usize {
186        self.buffer().len()
187    }
188
189    fn cell_type(&self) -> CellType {
190        self.buffer().cell_type()
191    }
192
193    fn get(&self, index: usize) -> CellValue {
194        self.buffer().get(index)
195    }
196
197    fn put(&mut self, idx: usize, value: CellValue) -> crate::error::Result<()> {
198        self.buffer_mut().put(idx, value)
199    }
200
201    fn convert(&self, cell_type: CellType) -> crate::error::Result<Self>
202    where
203        Self: Sized,
204    {
205        let converted = self.buffer().convert(cell_type)?;
206        Ok(Self::new(converted, self.mask().to_owned()))
207    }
208
209    fn min_max(&self) -> (CellValue, CellValue) {
210        let init = (self.cell_type().max_value(), self.cell_type().min_value());
211        self.into_iter().fold(init, |(amin, amax), (v, m)| {
212            if m {
213                (amin.min(v), amax.max(v))
214            } else {
215                (amin, amax)
216            }
217        })
218    }
219
220    /// Converts `self` to `Vec<T>`, ignoring the `mask` values.
221    ///
222    /// See also: [`Self::to_vec_with_nodata`] and [`NoData`].
223    fn to_vec<T: CellEncoding>(self) -> crate::error::Result<Vec<T>> {
224        self.0.to_vec()
225    }
226}
227
228impl Debug for MaskedCellBuffer {
229    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230        let basename = self.cell_type().to_string();
231        f.debug_tuple(&format!("{basename}MaskedCellBuffer"))
232            .field(self.buffer())
233            .field(self.mask())
234            .finish()
235    }
236}
237
238impl From<MaskedCellBuffer> for (CellBuffer, Mask) {
239    fn from(value: MaskedCellBuffer) -> Self {
240        (value.0, value.1)
241    }
242}
243
244impl<'a> From<&'a MaskedCellBuffer> for (&'a CellBuffer, &'a Mask) {
245    fn from(value: &'a MaskedCellBuffer) -> Self {
246        (&value.0, &value.1)
247    }
248}
249
250/// Converts a [`CellBuffer`] into a [`MaskedCellBuffer`] with an all-true mask.
251impl From<CellBuffer> for MaskedCellBuffer {
252    fn from(value: CellBuffer) -> Self {
253        let len = value.len();
254        Self::new(value, Mask::fill(len, true))
255    }
256}
257
258impl<C: CellEncoding> FromIterator<C> for MaskedCellBuffer {
259    fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self {
260        Self::from_vec(iter.into_iter().collect())
261    }
262}
263
264impl<C: CellEncoding> FromIterator<(C, bool)> for MaskedCellBuffer {
265    fn from_iter<T: IntoIterator<Item = (C, bool)>>(iter: T) -> Self {
266        // This is basically a copy of `unzip` except that we control
267        // instantiation of an empty CellBuffer to ensure the right
268        // CellType. Possible because
269        // `impl Extend for (impl Extend, impl Extend)` exists.
270        let mut pair = (
271            CellBuffer::with_defaults(0, C::cell_type()),
272            Mask::default(),
273        );
274        pair.extend(iter);
275
276        let (data, mask) = pair;
277        Self::new(data, mask)
278    }
279}
280
281impl<C: CellEncoding> Extend<(C, bool)> for MaskedCellBuffer {
282    fn extend<T: IntoIterator<Item = (C, bool)>>(&mut self, iter: T) {
283        for (v, m) in iter {
284            self.buffer_mut().extend(Some(v));
285            self.mask_mut().extend(Some(m));
286        }
287    }
288}
289
290impl<'buf> IntoIterator for &'buf MaskedCellBuffer {
291    type Item = (CellValue, bool);
292    type IntoIter = MaskedCellBufferIterator<'buf>;
293
294    fn into_iter(self) -> Self::IntoIter {
295        MaskedCellBufferIterator { buf: self, idx: 0, len: self.len() }
296    }
297}
298
299/// Iterator over ([`CellValue`], `bool`) elements in a [`MaskedCellBuffer`].
300pub struct MaskedCellBufferIterator<'buf> {
301    buf: &'buf MaskedCellBuffer,
302    idx: usize,
303    len: usize,
304}
305
306impl Iterator for MaskedCellBufferIterator<'_> {
307    type Item = (CellValue, bool);
308
309    fn next(&mut self) -> Option<Self::Item> {
310        if self.idx >= self.len {
311            None
312        } else {
313            let r = self.buf.get_with_mask(self.idx);
314            self.idx += 1;
315            Some(r)
316        }
317    }
318}
319
320mod ops {
321    use crate::{CellValue, MaskedCellBuffer};
322    use std::ops::{Add, Div, Mul, Neg, Sub};
323
324    macro_rules! cb_bin_op {
325        ($trt:ident, $mth:ident, $op:tt) => {
326            // Both borrows.
327            impl $trt for &MaskedCellBuffer {
328                type Output = MaskedCellBuffer;
329                fn $mth(self, rhs: Self) -> Self::Output {
330                    let (lbuf, lmask) = self.into();
331                    let (rbuf, rmask) = rhs.into();
332                    let new_buf = lbuf.into_iter().zip(rbuf.into_iter()).map(|(l, r)| l $op r).collect();
333                    #[allow(clippy::suspicious_arithmetic_impl)]
334                    let new_mask = lmask & rmask;
335                    Self::Output::new(new_buf, new_mask)
336                }
337            }
338            // Both owned/consumed
339            impl $trt for MaskedCellBuffer {
340                type Output = MaskedCellBuffer;
341                fn $mth(self, rhs: Self) -> Self::Output {
342                    $trt::$mth(&self, &rhs)
343                }
344            }
345            // RHS borrow
346            impl $trt<&MaskedCellBuffer> for MaskedCellBuffer {
347                type Output = MaskedCellBuffer;
348                fn $mth(self, rhs: &MaskedCellBuffer) -> Self::Output {
349                    $trt::$mth(&self, &rhs)
350                }
351            }
352            // RHS scalar
353            // TODO: figure out how to implement LHS scalar, avoiding orphan rule.
354            impl<R> $trt<R> for MaskedCellBuffer
355            where
356                R: Into<CellValue>,
357            {
358                type Output = MaskedCellBuffer;
359                fn $mth(self, rhs: R) -> Self::Output {
360                    let r: CellValue = rhs.into();
361                    let (buf, mask) = self.into();
362                    let new_buf = buf.into_iter().map(|l | l $op r).collect();
363                    Self::new(new_buf, mask)
364                }
365            }
366        };
367    }
368    cb_bin_op!(Add, add, +);
369    cb_bin_op!(Sub, sub, -);
370    cb_bin_op!(Mul, mul, *);
371    cb_bin_op!(Div, div, /);
372
373    impl Neg for &MaskedCellBuffer {
374        type Output = MaskedCellBuffer;
375        fn neg(self) -> Self::Output {
376            Self::Output::new(self.buffer().neg(), self.mask().clone())
377        }
378    }
379    impl Neg for MaskedCellBuffer {
380        type Output = MaskedCellBuffer;
381        fn neg(self) -> Self::Output {
382            Self::Output::new(self.buffer().neg(), self.1)
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use crate::{BufferOps, CellBuffer, CellType, CellValue, Mask, MaskedCellBuffer, NoData};

    /// Cell value generator: the index truncated to `u8`.
    fn filler(i: usize) -> u8 {
        i as u8
    }

    /// Mask generator: even indices are valid, odd indices are no-data.
    fn masker(i: usize) -> bool {
        i % 2 == 0
    }

    /// Combined value/mask generator for `fill_with_mask_via`.
    fn filler_masker(i: usize) -> (u8, bool) {
        (filler(i), masker(i))
    }

    #[test]
    fn ctor() {
        let via_fill = MaskedCellBuffer::fill_via(3, filler);
        let via_new =
            MaskedCellBuffer::new(CellBuffer::fill_via(3, filler), Mask::fill(3, true));
        assert_eq!(via_fill, via_new);

        // Unmasked constructors mark every cell valid.
        let all_valid = (4, 0);
        assert_eq!(
            MaskedCellBuffer::from_vec(vec![0.0; 4]).mask().counts(),
            all_valid
        );
        assert_eq!(
            MaskedCellBuffer::with_defaults(4, CellType::Int8).mask().counts(),
            all_valid
        );
    }

    #[test]
    fn vec_with_nodata() {
        let values = vec![1.0, f64::NAN, 3.0, f64::NAN];

        // Default no-data for floats masks out the NaNs.
        let nan_masked = MaskedCellBuffer::from_vec_with_nodata(values.clone(), NoData::Default);
        assert_eq!(
            nan_masked,
            MaskedCellBuffer::new(
                values.clone().into(),
                Mask::new(vec![true, false, true, false])
            )
        );

        // An explicit no-data value masks out only the matching cell.
        let three_masked = MaskedCellBuffer::from_vec_with_nodata(values.clone(), NoData::new(3.0));
        assert_eq!(
            three_masked,
            MaskedCellBuffer::new(values.into(), Mask::new(vec![true, true, false, true]))
        );
    }

    #[test]
    fn get_masked() {
        let mut mbuf = MaskedCellBuffer::fill_with_mask_via(9, filler_masker);
        assert_eq!(mbuf.get(4), 4.into());
        assert_eq!(mbuf.get_masked(4), Some(4.into()));

        // Writing a value leaves the mask alone: index 5 stays masked out.
        assert_eq!(mbuf.get_masked(5), None);
        mbuf.put(5, CellValue::new(4u8)).unwrap();
        assert_eq!(mbuf.get_masked(5), None);

        // Flipping the mask exposes the value just written...
        mbuf.mask_mut().put(5, true);
        assert_eq!(mbuf.get_masked(5), Some(4.into()));
        // ...and `put_with_mask` can hide it again.
        mbuf.put_with_mask(5, CellValue::new(99u8), false).unwrap();
        assert_eq!(mbuf.get_masked(5), None);
    }

    #[test]
    fn convert() {
        let source = MaskedCellBuffer::fill_with_mask_via(4, filler_masker);
        let as_f64 = source.convert(CellType::Float64).unwrap();
        assert_eq!(as_f64.to_vec::<f64>().unwrap(), [0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend() {
        let mut mbuf = MaskedCellBuffer::fill(3, 0.into());
        mbuf.extend([(1, false)]);
        assert_eq!(mbuf.get_masked(0), Some(0.into()));
        assert_eq!(mbuf.get_masked(3), None);
    }

    #[test]
    fn from_iter() {
        let collected: MaskedCellBuffer = (0..5i16).collect();
        assert!(collected.mask().all(true));
        assert_eq!(collected.to_vec::<i16>().unwrap(), [0, 1, 2, 3, 4i16]);
    }

    #[test]
    fn unary() {
        let mbuf = MaskedCellBuffer::fill_with_mask_via(9, filler_masker);
        let expected = vec![0, i16::MIN, -2, i16::MIN, -4, i16::MIN, -6, i16::MIN, -8];

        let from_borrowed = (-&mbuf)
            .to_vec_with_nodata::<i16>(NoData::Default)
            .unwrap();
        assert_eq!(from_borrowed, expected);

        let from_owned = (-mbuf).to_vec_with_nodata::<i16>(NoData::Default).unwrap();
        assert_eq!(from_owned, from_borrowed);
    }

    #[test]
    fn min_max() {
        // Mask out the extremes (indices 0 and 8) so they are ignored.
        let mbuf = MaskedCellBuffer::fill_with_mask_via(9, |i| (filler(i), i != 0 && i != 8));
        assert_eq!(mbuf.min_max(), (1u8.into(), 7u8.into()));
    }

    #[test]
    fn scalar() {
        let expected = CellBuffer::fill_via(9, filler) * 2.0;

        // With an all-`true` mask, the result equals the plain buffer result.
        let all_valid = MaskedCellBuffer::fill_with_mask_via(9, |i| (filler(i), true)) * 2.0;
        assert_eq!(all_valid, expected.clone().into());

        // With an alternating mask, the masked-out cells differ.
        let alternating = MaskedCellBuffer::fill_with_mask_via(9, filler_masker) * 2.0;
        assert_ne!(alternating, expected.into());

        let filled = alternating
            .to_vec_with_nodata::<f64>(NoData::Value(f64::MIN))
            .unwrap();
        assert_eq!(
            filled,
            vec![0.0, f64::MIN, 4.0, f64::MIN, 8.0, f64::MIN, 12.0, f64::MIN, 16.0]
        );
    }

    #[test]
    fn binary() {
        let lhs =
            MaskedCellBuffer::new(CellBuffer::fill(9, 1f64.into()), Mask::fill_via(9, masker));
        let rhs = MaskedCellBuffer::new(CellBuffer::fill(9, 2f64.into()), Mask::fill(9, true));

        // Exercise the borrowed/borrowed, owned/borrowed and owned/owned forms of
        // each operator. Even indices stay valid; odd ones inherit `lhs`'s no-data.
        macro_rules! test_ops {
            ($($op:tt)*) => {$({
                let r = &lhs $op &rhs;
                assert_eq!(r.get_masked(0), Some((1f64 $op 2f64).into()));
                assert_eq!(r.get_masked(1), None);
                let r = lhs.clone() $op &rhs;
                assert_eq!(r.get_masked(2), Some((1f64 $op 2f64).into()));
                assert_eq!(r.get_masked(3), None);
                let r = lhs.clone() $op rhs.clone();
                assert_eq!(r.get_masked(4), Some((1f64 $op 2f64).into()));
                assert_eq!(r.get_masked(5), None);
            })*};
        }
        test_ops!(+ - * /);
    }

    #[test]
    fn debug() {
        let single: MaskedCellBuffer = (0..1).collect();
        let rendered = format!("{single:?}");
        assert!(rendered.starts_with("Int32MaskedCellBuffer"));
        assert!(rendered.contains("CellBuffer(0)"));
        assert!(rendered.contains("Mask(true)"));
    }
}