erased_cells/
buffer.rs

1use std::fmt::{Debug, Formatter};
2
3use num_traits::ToPrimitive;
4use paste::paste;
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Error, Result};
9use crate::{with_ct, BufferOps, CellEncoding, CellType, CellValue};
10
11pub use self::ops::*;
12
13/// CellBuffer enum constructor.
14macro_rules! cb_enum {
15    ( $(($id:ident, $p:ident)),*) => {
16        /// An enum over buffers of [`CellEncoding`] types.
17        ///
18        /// See [module documentation](crate#example) for usage example.
19        ///
20        /// # Example
21        ///
22        /// ```rust
23        /// use erased_cells::{CellBuffer, CellType, CellValue, BufferOps};
24        /// // Fill a buffer with the `u8` numbers `0..=8`.
25        /// let buf1 = CellBuffer::fill_via(9, |i| i as u8);
26        ///
27        /// // `u8` maps to `CellType::UInt8`
28        /// assert_eq!(buf1.cell_type(), CellType::UInt8);
29        ///
30        /// // A fetching values maintains its CellType through a CellValue.
31        /// let val: CellValue = buf1.get(3);
32        /// assert_eq!(val, CellValue::UInt8(3));
33        /// let (min, max): (CellValue, CellValue) = buf1.min_max();
34        /// assert_eq!((min, max), (CellValue::UInt8(0), CellValue::UInt8(8)));
35        ///
36        /// // Basic math ops work on CellValues. Primitives can be converted to CellValues with `into`.
37        /// // Math ops coerce to floating point values.
38        /// assert_eq!(((max - min + 1) / 2), 4.5.into());
39        ///
40        /// // Fill another buffer with the `f32` numbers `8..=0`.
41        /// let buf2 = CellBuffer::fill_via(9, |i| 8f32 - i as f32);
42        /// assert_eq!(buf2.cell_type(), CellType::Float32);
43        /// assert_eq!(
44        ///     buf2.min_max(),
45        ///     (CellValue::Float32(0.0), CellValue::Float32(8.0))
46        /// );
47        ///
48        /// // Basic math ops also work on CellBuffers, applied element-wise.
49        /// let diff = buf2 - buf1;
50        /// assert_eq!(diff.min_max(), ((-8).into(), 8.into()));
51        /// ```
52        #[derive(Clone)]
53        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
54        pub enum CellBuffer { $($id(Vec<$p>)),* }
55    }
56}
57with_ct!(cb_enum);
58
59impl CellBuffer {}
60
61impl BufferOps for CellBuffer {
62    fn from_vec<T: CellEncoding>(data: Vec<T>) -> Self {
63        data.into()
64    }
65
66    fn with_defaults(len: usize, ct: CellType) -> Self {
67        macro_rules! empty {
68            ( $(($id:ident, $p:ident)),*) => {
69                match ct {
70                    $(CellType::$id => Self::from_vec(vec![$p::default(); len]),)*
71                }
72            };
73        }
74        with_ct!(empty)
75    }
76
77    fn fill(len: usize, value: CellValue) -> Self {
78        macro_rules! empty {
79            ( $(($id:ident, $p:ident)),*) => {
80                match value.cell_type() {
81                    $(CellType::$id => Self::from_vec::<$p>(vec![value.get().unwrap(); len]),)*
82                }
83            };
84        }
85        with_ct!(empty)
86    }
87
88    fn fill_via<T, F>(len: usize, f: F) -> Self
89    where
90        T: CellEncoding,
91        F: Fn(usize) -> T,
92    {
93        let v: Vec<T> = (0..len).map(f).collect();
94        Self::from_vec(v)
95    }
96
97    fn len(&self) -> usize {
98        macro_rules! len {
99            ( $(($id:ident, $_p:ident)),*) => {
100                match self {
101                    $(CellBuffer::$id(v) => v.len(),)*
102                }
103            };
104        }
105        with_ct!(len)
106    }
107
108    fn is_empty(&self) -> bool {
109        self.len() == 0
110    }
111
112    fn cell_type(&self) -> CellType {
113        macro_rules! ct {
114            ( $(($id:ident, $_p:ident)),*) => {
115                match self {
116                    $(CellBuffer::$id(_) => CellType::$id,)*
117                }
118            };
119        }
120        with_ct!(ct)
121    }
122
123    fn get(&self, index: usize) -> CellValue {
124        macro_rules! get {
125            ( $(($id:ident, $_p:ident)),*) => {
126                match self {
127                    $(CellBuffer::$id(b) => CellValue::$id(b[index]),)*
128                }
129            };
130        }
131        with_ct!(get)
132    }
133
134    fn put(&mut self, idx: usize, value: CellValue) -> Result<()> {
135        let value = value.convert(self.cell_type())?;
136        macro_rules! put {
137            ( $(($id:ident, $_p:ident)),*) => {
138                match (self, value) {
139                    $((CellBuffer::$id(b), CellValue::$id(v)) => b[idx] = v,)*
140                    _ => unreachable!(),
141                }
142            }
143        }
144        with_ct!(put);
145        Ok(())
146    }
147
148    fn convert(&self, cell_type: CellType) -> Result<Self> {
149        if cell_type == self.cell_type() {
150            return Ok(self.clone());
151        }
152
153        let err = || Error::NarrowingError { src: self.cell_type(), dst: cell_type };
154
155        if !self.cell_type().can_fit_into(cell_type) {
156            return Err(err());
157        }
158
159        let r: CellBuffer = self
160            .into_iter()
161            .map(|v| v.convert(cell_type).unwrap())
162            .collect();
163
164        Ok(r)
165    }
166
167    fn min_max(&self) -> (CellValue, CellValue) {
168        let init = (self.cell_type().max_value(), self.cell_type().min_value());
169        self.into_iter()
170            .fold(init, |(amin, amax), v| (amin.min(v), amax.max(v)))
171    }
172
173    fn to_vec<T: CellEncoding>(self) -> Result<Vec<T>> {
174        let r = self.convert(T::cell_type())?;
175        macro_rules! to_vec {
176            ( $(($id:ident, $_p:ident)),*) => {
177                match r {
178                    $(CellBuffer::$id(b) => Ok(danger::cast(b)),)*
179                }
180            }
181        }
182        with_ct!(to_vec)
183    }
184}
185
186impl Debug for CellBuffer {
187    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
188        use crate::Elided;
189        let basename = self.cell_type().to_string();
190        macro_rules! render {
191            ( $(($id:ident, $_p:ident)),*) => {{
192                f.write_fmt(format_args!("{basename}CellBuffer("))?;
193                match self {
194                    $(CellBuffer::$id(b) => f.write_fmt(format_args!("{:?}", Elided(b)))?,)*
195                };
196                f.write_str(")")
197            }}
198        }
199        with_ct!(render)
200    }
201}
202
203impl<C: CellEncoding> Extend<C> for CellBuffer {
204    fn extend<T: IntoIterator<Item = C>>(&mut self, iter: T) {
205        macro_rules! render {
206            ( $(($id:ident, $p:ident)),*) => { paste! {
207                match self {
208                    $(CellBuffer::$id(b) => {
209                        let conv_iter = iter.into_iter().map(|c| {
210                            c.into_cell_value().[<to_ $p>]().unwrap()
211                        });
212                        b.extend(conv_iter)
213                    },)*
214                }
215            }}
216        }
217        with_ct!(render);
218    }
219}
220
221impl<C: CellEncoding> FromIterator<C> for CellBuffer {
222    fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self {
223        Self::from_vec(iter.into_iter().collect())
224    }
225}
226
227impl FromIterator<CellValue> for CellBuffer {
228    fn from_iter<T: IntoIterator<Item = CellValue>>(iterable: T) -> Self {
229        // TODO: is there a way to avoid this collect?
230        let values = iterable.into_iter().collect::<Vec<CellValue>>();
231        match values.as_slice() {
232            [] => CellBuffer::with_defaults(0, CellType::UInt8),
233            [x, ..] => {
234                let ct: CellType = x.cell_type();
235                macro_rules! conv {
236                    ( $(($id:ident, $_p:ident)),*) => {
237                        match ct {
238                            $(CellType::$id => {
239                                CellBuffer::$id(values.iter().map(|v| v.get().unwrap()).collect())
240                            })*
241                        }
242                    }
243                }
244                with_ct!(conv)
245            }
246        }
247    }
248}
249
250impl<T: CellEncoding> From<Vec<T>> for CellBuffer {
251    fn from(values: Vec<T>) -> Self {
252        macro_rules! from {
253            ( $(($id:ident, $_p:ident)),*) => {
254                match T::cell_type() {
255                    $(CellType::$id => Self::$id(danger::cast(values)),)*
256                }
257            }
258        }
259        with_ct!(from)
260    }
261}
262
263impl<T: CellEncoding> From<&[T]> for CellBuffer {
264    fn from(values: &[T]) -> Self {
265        macro_rules! from {
266            ( $(($id:ident, $_p:ident)),*) => {
267                match T::cell_type() {
268                    $(CellType::$id => Self::$id(danger::cast(values.to_vec())),)*
269                }
270            }
271        }
272        with_ct!(from)
273    }
274}
275
276impl<'buf> IntoIterator for &'buf CellBuffer {
277    type Item = CellValue;
278    type IntoIter = CellBufferIterator<'buf>;
279    fn into_iter(self) -> Self::IntoIter {
280        CellBufferIterator { buf: self, idx: 0, len: self.len() }
281    }
282}
283
284/// Iterator over [`CellValue`] elements in a [`CellBuffer`].
285pub struct CellBufferIterator<'buf> {
286    buf: &'buf CellBuffer,
287    idx: usize,
288    len: usize,
289}
290
291impl Iterator for CellBufferIterator<'_> {
292    type Item = CellValue;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        if self.idx >= self.len {
296            None
297        } else {
298            let r = self.buf.get(self.idx);
299            self.idx += 1;
300            Some(r)
301        }
302    }
303}
304
305impl<C: CellEncoding> TryFrom<CellBuffer> for Vec<C> {
306    type Error = Error;
307
308    fn try_from(value: CellBuffer) -> Result<Self> {
309        value.to_vec()
310    }
311}
312
313mod ops {
314    use std::cmp::Ordering;
315    use std::ops::{Add, Div, Mul, Neg, Sub};
316
317    use crate::{BufferOps, CellBuffer, CellValue};
318
319    macro_rules! cb_bin_op {
320        ($trt:ident, $mth:ident, $op:tt) => {
321            // Both borrows.
322            impl $trt for &CellBuffer {
323                type Output = CellBuffer;
324                fn $mth(self, rhs: Self) -> Self::Output {
325                    self.into_iter().zip(rhs.into_iter()).map(|(l, r)| l $op r).collect()
326                }
327            }
328            // Both owned/consumed
329            impl $trt for CellBuffer {
330                type Output = CellBuffer;
331                fn $mth(self, rhs: Self) -> Self::Output {
332                    $trt::$mth(&self, &rhs)
333                }
334            }
335            // RHS borrow
336            impl $trt<&CellBuffer> for CellBuffer {
337                type Output = CellBuffer;
338                fn $mth(self, rhs: &CellBuffer) -> Self::Output {
339                    $trt::$mth(&self, &rhs)
340                }
341            }
342            // RHS scalar
343            // TODO: figure out how to implement LHS scalar, avoiding orphan rule.
344            impl <R> $trt<R> for CellBuffer where R: Into<CellValue> {
345                type Output = CellBuffer;
346                fn $mth(self, rhs: R) -> Self::Output {
347                    let r: CellValue = rhs.into();
348                    self.into_iter().map(|l | l $op r).collect()
349                }
350            }
351        }
352    }
353    cb_bin_op!(Add, add, +);
354    cb_bin_op!(Sub, sub, -);
355    cb_bin_op!(Mul, mul, *);
356    cb_bin_op!(Div, div, /);
357
358    impl Neg for &CellBuffer {
359        type Output = CellBuffer;
360        fn neg(self) -> Self::Output {
361            self.into_iter().map(|v| -v).collect()
362        }
363    }
364    impl Neg for CellBuffer {
365        type Output = CellBuffer;
366        fn neg(self) -> Self::Output {
367            Neg::neg(&self)
368        }
369    }
370
371    impl PartialEq<Self> for CellBuffer {
372        fn eq(&self, other: &Self) -> bool {
373            Ord::cmp(self, other) == Ordering::Equal
374        }
375    }
376
377    impl Eq for CellBuffer {}
378
379    impl PartialOrd for CellBuffer {
380        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
381            Some(self.cmp(other))
382        }
383    }
384
385    /// Computes ordering for [`CellBuffer`]. Unlike `Vec<CellEncoding>`, floating point
386    /// cell types are compared with `{f32|f64}::total_cmp`.  
387    impl Ord for CellBuffer {
388        fn cmp(&self, other: &Self) -> Ordering {
389            let lct = self.cell_type();
390            let rct = other.cell_type();
391
392            // If the cell types are different, then base comparison on that.
393            match lct.cmp(&rct) {
394                Ordering::Equal => (),
395                ne => return ne,
396            }
397
398            macro_rules! total_cmp {
399                ($l:ident, $r:ident) => {{
400                    // Implementation derived from core::slice::cmp::SliceOrd.
401                    let len = $l.len().min($r.len());
402
403                    // Slice to the loop iteration range to enable bound check
404                    // elimination in the compiler
405                    let lhs = &$l[..len];
406                    let rhs = &$r[..len];
407
408                    for i in 0..len {
409                        match lhs[i].total_cmp(&rhs[i]) {
410                            Ordering::Equal => (),
411                            non_eq => return non_eq,
412                        }
413                    }
414                    $l.len().cmp(&$r.len())
415                }};
416            }
417            // lhs & rhs have to be the same variant.
418            // For integral values, defer to `Vec`'s `Ord`.
419            // For floating values, we use `total_ord`.
420            match (self, other) {
421                (CellBuffer::UInt8(lhs), CellBuffer::UInt8(rhs)) => Ord::cmp(&lhs, &rhs),
422                (CellBuffer::UInt16(lhs), CellBuffer::UInt16(rhs)) => Ord::cmp(&lhs, &rhs),
423                (CellBuffer::UInt32(lhs), CellBuffer::UInt32(rhs)) => Ord::cmp(&lhs, &rhs),
424                (CellBuffer::UInt64(lhs), CellBuffer::UInt64(rhs)) => Ord::cmp(&lhs, &rhs),
425                (CellBuffer::Int8(lhs), CellBuffer::Int8(rhs)) => Ord::cmp(&lhs, &rhs),
426                (CellBuffer::Int16(lhs), CellBuffer::Int16(rhs)) => Ord::cmp(&lhs, &rhs),
427                (CellBuffer::Int32(lhs), CellBuffer::Int32(rhs)) => Ord::cmp(&lhs, &rhs),
428                (CellBuffer::Int64(lhs), CellBuffer::Int64(rhs)) => Ord::cmp(&lhs, &rhs),
429                (CellBuffer::Float32(l), CellBuffer::Float32(r)) => total_cmp!(l, r),
430                (CellBuffer::Float64(l), CellBuffer::Float64(r)) => total_cmp!(l, r),
431                _ => unreachable!("{self:?} <> {other:?}"),
432            }
433        }
434    }
435}
436
437mod danger {
438    use crate::CellEncoding;
439
440    #[inline]
441    pub(crate) fn cast<T: CellEncoding, P: CellEncoding>(buffer: Vec<T>) -> Vec<P> {
442        assert_eq!(T::cell_type(), P::cell_type());
443        // As suggested in https://doc.rust-lang.org/core/intrinsics/fn.transmute.html
444        unsafe {
445            let mut v = std::mem::ManuallyDrop::new(buffer);
446            Vec::from_raw_parts(v.as_mut_ptr() as *mut P, v.len(), v.capacity())
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use crate::{with_ct, BufferOps, CellBuffer, CellType, CellValue};

    /// All cell types that `start` fits into, per `CellType::can_fit_into`.
    fn bigger(start: CellType) -> impl Iterator<Item = CellType> {
        CellType::iter().filter(move |ct| start.can_fit_into(*ct))
    }

    #[test]
    fn defaults() {
        macro_rules! test {
            ($( ($id:ident, $p:ident) ),*) => {
                $({
                    let cv = CellBuffer::with_defaults(3, CellType::$id);
                    assert_eq!(cv.len(), 3);
                    assert_eq!(cv.get(0), CellValue::new(<$p>::default()));
                })*};
        }
        with_ct!(test);
    }

    #[test]
    fn put_get() {
        use num_traits::One;
        macro_rules! test {
            ($( ($id:ident, $p:ident) ),*) => {
                $({
                    let mut cv = CellBuffer::fill(3, <$p>::default().into());
                    let one = CellValue::new(<$p>::one());
                    cv.put(1, one).expect("Put one");
                    assert_eq!(cv.get(1), one.convert(CellType::$id).unwrap());
                })*};
        }
        with_ct!(test);
    }

    #[test]
    fn extend() {
        let mut buf = CellBuffer::fill(3, 0u8.into());
        assert!(!buf.is_empty());
        assert_eq!(buf.cell_type(), CellType::UInt8);
        buf.extend([1]);
        assert_eq!(buf.cell_type(), CellType::UInt8);
        assert_eq!(buf.get(0), 0.into());
        assert_eq!(buf.get(3), 1.into());
    }

    #[test]
    fn to_vec() {
        macro_rules! test {
            ($( ($id:ident, $p:ident) ),*) => {
                $({
                    let v = vec![<$p>::default(); 3];
                    let buf = CellBuffer::from_vec(v.clone());
                    let r = buf.to_vec::<$p>().unwrap();
                    assert_eq!(r, v);
                })*
            };
        }
        with_ct!(test);
    }

    #[test]
    fn min_max() {
        let buf = CellBuffer::from_vec(vec![-1.0, 3.0, 2000.0, -5555.5]);
        let (min, max) = buf.min_max();
        assert_eq!(min, CellValue::Float64(-5555.5));
        assert_eq!(max, CellValue::Float64(2000.0));

        let buf = CellBuffer::from_vec(vec![1u8, 3u8, 200u8, 0u8]);
        let (min, max) = buf.min_max();
        assert_eq!(min, CellValue::UInt8(0));
        assert_eq!(max, CellValue::UInt8(200));
    }

    #[test]
    fn from_others() {
        let v = vec![
            CellValue::UInt16(3),
            CellValue::UInt16(4),
            CellValue::UInt16(5),
        ];

        let b: CellBuffer = v.into_iter().collect();
        assert_eq!(b.cell_type(), CellType::UInt16);
        assert_eq!(b.len(), 3);
        assert_eq!(b.get(2), CellValue::UInt16(5));

        let v = vec![33.3f32, 44.4, 55.5];
        let b: CellBuffer = v.clone().into_iter().collect();
        assert_eq!(b.cell_type(), CellType::Float32);
        assert_eq!(b.len(), 3);
        assert_eq!(b.get(2), CellValue::Float32(55.5));

        let b: CellBuffer = v.clone().into();
        assert_eq!(b.cell_type(), CellType::Float32);
        assert_eq!(b.len(), 3);
        assert_eq!(b.get(2), CellValue::Float32(55.5));

        // `From<&[T]>` copies from the borrowed slice; no clone of `v` is needed.
        let b: CellBuffer = v.as_slice().into();
        assert_eq!(b.cell_type(), CellType::Float32);
        assert_eq!(b.len(), 3);
        assert_eq!(b.get(2), CellValue::Float32(55.5));
    }

    #[test]
    fn debug() {
        let b = CellBuffer::fill(5, 37.into());
        assert!(format!("{b:?}").starts_with("Int32CellBuffer"));
        let b = CellBuffer::fill(15, 37.into());
        assert!(format!("{b:?}").contains("..."));
    }

    #[test]
    fn convert() {
        for ct in CellType::iter() {
            let buf = CellBuffer::with_defaults(3, ct);
            for target in bigger(ct) {
                let r = buf.convert(target);
                assert!(r.is_ok(), "{ct} vs {target}");
                let r = r.unwrap();

                assert_eq!(r.cell_type(), target);
            }
        }
    }

    #[test]
    fn unary() {
        use num_traits::One;
        macro_rules! test {
            ($( ($id:ident, $p:ident) ),*) => {$({
                let one: CellValue = <$p>::one().into();
                let buf = -CellBuffer::fill(3, one);
                assert_eq!(buf.get(0), -one);
            })*};
        }

        with_ct!(test);
    }

    #[test]
    fn binary() {
        for lhs_ct in CellType::iter() {
            let lhs_val = lhs_ct.one();
            for rhs_ct in CellType::iter() {
                let lhs = CellBuffer::fill(3, lhs_val);
                let rhs_val = rhs_ct.one() + rhs_ct.one();
                let rhs = CellBuffer::fill(3, rhs_val);
                assert_eq!((&lhs + &rhs).get(0), lhs_val + rhs_val);
                assert_eq!((&rhs + &lhs).get(1), rhs_val + lhs_val);
                assert_eq!((&lhs - &rhs).get(2), lhs_val - rhs_val);
                assert_eq!((&rhs - &lhs).get(0), rhs_val - lhs_val);
                assert_eq!((&lhs * &rhs).get(1), lhs_val * rhs_val);
                assert_eq!((&rhs * &lhs).get(2), rhs_val * lhs_val);
                assert_eq!((&lhs / &rhs).get(0), lhs_val / rhs_val);
                assert_eq!((&rhs / &lhs).get(1), rhs_val / lhs_val);
                // Consuming (non-borrow) case
                assert_eq!((rhs / lhs).get(2), rhs_val / lhs_val);
            }
        }
    }

    #[test]
    fn scalar() {
        let buf = CellBuffer::fill_via(9, |i| i as u8 + 1);
        let r = buf * 2.0;
        assert_eq!(r, CellBuffer::fill_via(9, |i| (i as f64 + 1.0) * 2.0));
    }

    #[test]
    fn equal() {
        let buf = CellBuffer::fill_via(9, |i| if i % 2 == 0 { f64::NAN } else { i as f64 });
        assert_eq!(buf, buf);
        assert_eq!(
            CellBuffer::with_defaults(4, CellType::UInt8),
            CellBuffer::with_defaults(4, CellType::UInt8)
        );
        assert_ne!(
            CellBuffer::with_defaults(4, CellType::UInt8),
            CellBuffer::with_defaults(5, CellType::UInt8)
        );
    }

    #[test]
    fn cmp() {
        assert!(CellBuffer::from_vec(vec![1, 2, 3]) < CellBuffer::from_vec(vec![2, 3, 4]));
        // The first differing element decides before length does (1 < 2 at index 0).
        assert!(CellBuffer::from_vec(vec![1, 2, 3]) < CellBuffer::from_vec(vec![2, 3]));

        assert!(
            CellBuffer::from_vec(vec![f64::NAN, 2.0, 3.0])
                < CellBuffer::from_vec(vec![f64::NAN, 2.0, 4.0])
        );

        assert!(
            CellBuffer::with_defaults(4, CellType::UInt8)
                < CellBuffer::with_defaults(4, CellType::Float32)
        );
        assert!(
            CellBuffer::with_defaults(4, CellType::Float32)
                > CellBuffer::with_defaults(4, CellType::UInt8)
        );
        assert!(
            CellBuffer::with_defaults(4, CellType::UInt8)
                < CellBuffer::with_defaults(5, CellType::UInt8)
        );
        assert!(
            CellBuffer::with_defaults(5, CellType::UInt8)
                > CellBuffer::with_defaults(4, CellType::UInt8)
        );
        assert!(
            CellBuffer::with_defaults(4, CellType::Float64)
                < CellBuffer::with_defaults(5, CellType::Float64)
        );
        assert!(
            CellBuffer::with_defaults(5, CellType::Float64)
                > CellBuffer::with_defaults(4, CellType::Float64)
        );
    }
}