summed-area 1.0.0

Implementation of a summed-area table for fast sums or averages of subsections of a 2d array or an image
Documentation
#![doc = include_str!("../README.md")]

use std::ops::{Add, Sub, RangeBounds};
pub use imgref::ImgRef;

/// A summed-area table (AKA integral image): precomputed sums of subsections of a 2d array.
///
/// Once built, the sum of any axis-aligned rectangle of the input is read back from
/// four table entries, regardless of the rectangle's size.
///
/// The implementation is generic and works with anything that can be added and subtracted (`f32`, `u64`, etc.).
///
/// `SumType` is the type the sums are stored in; it should be large enough to hold the
/// sum of the whole input. The element type of the input (`InputType`) is a parameter of
/// the constructors, not of this struct, and only has to convert into `SumType`.
#[derive(Debug, Clone)]
pub struct SummedArea<SumType> {
    /// Length of one table row: the input width plus one leading column of zeros.
    stride: usize,
    /// Row-major table of `(width + 1) * (height + 1)` sums; the first row and the first column are zeros.
    sums: Vec<SumType>,
}

impl<SumType> SummedArea<SumType> where SumType: Default + Copy + Add<Output=SumType> + Sub<Output=SumType> {
    /// Sum area in a flat, row-major slice that is `width` elements wide.
    ///
    /// Height is implied from the slice's length (`input.len() / width`, rounded down).
    ///
    /// Call it like `SummedArea::<u64>::new_slice::<u8>(byte_slice, width)` to specify a larger type to sum into (e.g. sum `u8`s into `u64`, so that they don't overflow).
    ///
    /// # Panics
    ///
    /// Panics if `width` is 0, because the height cannot be derived from the length then.
    #[inline]
    #[must_use]
    #[track_caller]
    pub fn new_slice<InputType>(input: &[InputType], width: usize) -> Self where SumType: From<InputType>, InputType: Copy {
        // Fail with a clear message at the caller instead of a bare division-by-zero panic.
        assert!(width > 0, "width must be greater than 0");
        let height = input.len() / width;
        Self::new(ImgRef::new(input, width, height))
    }

    /// Builds the summed-area table for a 2d image view. See [`ImgRef::new`].
    ///
    /// The table is one column wider and one row taller than `input`. The extra
    /// first row and first column hold `SumType::default()` so that lookups never
    /// need a special case for `x-1` or `y-1`. Every other cell is computed as
    /// `current + up + left - up_left`, i.e. the sum of all input values above and
    /// to the left of it, inclusive.
    ///
    /// # Panics
    ///
    /// Panics if `(width + 1) * (height + 1)` overflows `usize`.
    #[must_use]
    pub fn new<InputType>(input: ImgRef<InputType>) -> Self where SumType: From<InputType>, InputType: Copy {
        // Table dimensions: one extra column and one extra row for the zero border.
        let out_width = input.width()+1;
        let out_height = input.height()+1;
        let area = out_width.checked_mul(out_height).unwrap();

        // Allocate the whole table up front; only the first row is initialized here,
        // the remaining rows are written through `MaybeUninit` below.
        let mut sums: Vec<SumType> = Vec::with_capacity(area);
        sums.resize_with(out_width, SumType::default); // first row of 0s to avoid check on y-1

        // Uninitialized storage for all rows after the zero row.
        let out = &mut sums.spare_capacity_mut()[..area-out_width];

        // Pair every input row with the output row it fills.
        let mut rows = input.rows().zip(out.chunks_exact_mut(out_width));
        // The first input row has only zeros above it, so a running row sum is enough.
        if let Some((in_row, out_row)) = rows.next() {
            // Chunks are `out_width >= 1` long, so splitting off the first cell cannot fail.
            let (first, out_row) = out_row.split_first_mut().unwrap();
            debug_assert_eq!(out_row.len(), input.width());
            first.write(SumType::default()); // avoid check on x-1
            let mut row_sum = SumType::default();
            in_row.iter().copied().zip(out_row.iter_mut()).for_each(|(curr, out_col)| {
                let curr = SumType::from(curr);
                row_sum = row_sum + curr;
                out_col.write(row_sum);
            });
            // Remaining rows: each cell combines the input value with the already
            // computed cells above (`up`), to the left (`left`) and diagonally (`up_left`).
            let mut prev_out_row = out_row;
            rows.for_each(|(in_row, out_row)| {
                let (first, out_row) = out_row.split_first_mut().unwrap();
                debug_assert_eq!(out_row.len(), input.width());
                first.write(SumType::default()); // avoid check on x-1
                // Both start at the zero column of their respective rows.
                let mut up_left = SumType::default();
                let mut left = SumType::default();
                let cols = in_row.iter().copied().zip(out_row.iter_mut().zip(prev_out_row.iter_mut()));
                cols.for_each(|(curr, (out_col, prev_out_px))| {
                    // SAFETY: the previous row has been written to entirely by the preceding
                    // iteration (or by the first-row pass above), and `SumType: Copy`, so reading
                    // it out does not duplicate ownership. Use slice_assume_init_ref when stable.
                    let up = unsafe { prev_out_px.assume_init_read() };
                    let curr = SumType::from(curr);
                    let curr_out = curr + up + left - up_left;
                    out_col.write(curr_out);
                    left = curr_out;
                    up_left = up;
                });
                prev_out_row = out_row;
            });
        }
        // SAFETY: rows from chunks exact cover this entire area, and each one was fully written above.
        // NOTE(review): this assumes `input.rows()` yields exactly `input.height()` rows of
        // `input.width()` elements each — confirm against imgref's guarantees.
        unsafe {
            sums.set_len(area);
        }
        debug_assert_eq!(sums.len()/out_width, out_height);
        Self {
            stride: out_width,
            sums,
        }
    }

    /// Number of columns in the input that has been summed up.
    #[inline]
    #[must_use]
    pub fn width(&self) -> usize {
        // Each table row carries one extra leading column of zeros.
        let table_columns = self.stride;
        table_columns - 1
    }

    /// Number of rows in the input that has been summed up.
    #[inline]
    #[must_use]
    pub fn height(&self) -> usize {
        // The table stores one extra leading row of zeros.
        let table_rows = self.sums.len() / self.stride;
        table_rows - 1
    }

    /// Adds up every value inside the rect whose top-left corner is (x,y) and which is (width,height) elements large.
    ///
    /// A rect reaching out of bounds (x+width or y+height too large) may give a numerically incorrect result or panic.
    #[inline(always)]
    #[track_caller]
    #[must_use]
    pub fn sum_rect(&self, x: usize, y: usize, width: usize, height: usize) -> SumType {
        let total = self.try_sum_rect(x, y, width, height);
        total.expect("oob")
    }

    /// Sum of all values in a rect at (x1,y1) being (width,height) elements large.
    ///
    /// If x1+width or y1+height are out of bounds it may return numerically incorrect result or `None`.
    /// If either addition would overflow `usize`, it returns `None`.
    #[inline]
    #[must_use]
    pub fn try_sum_rect(&self, x1: usize, y1: usize, width: usize, height: usize) -> Option<SumType> {
        // Checked additions: a `try_` method must not panic (debug) or wrap (release) on huge inputs.
        let x2 = x1.checked_add(width)?;
        let y2 = y1.checked_add(height)?;
        self.try_sum_bounds(x1, y1, x2, y2)
    }

    /// Adds up every value in the rect covering the `horizontal` range of columns and the `vertical` range of rows (indexed from 0).
    ///
    /// Ranges reaching out of bounds may give a numerically incorrect result or panic.
    #[inline(always)]
    #[track_caller]
    #[must_use]
    pub fn sum_range(&self, horizontal: impl RangeBounds<usize>, vertical: impl RangeBounds<usize>) -> SumType {
        let total = self.try_sum_range(horizontal, vertical);
        total.expect("oob")
    }

    /// Sum of all values in a rect spanning `horizontal` columns and `vertical` rows (indexed from 0).
    ///
    /// A range without an end (e.g. `..` or `2..`) extends to the last column or row of the summed area.
    ///
    /// If ranges are out of bounds it may return numerically incorrect result or `None`.
    #[inline]
    #[must_use]
    pub fn try_sum_range(&self, horizontal: impl RangeBounds<usize>, vertical: impl RangeBounds<usize>) -> Option<SumType> {
        // The open-ended maximum is the input width, not the table stride (which is width + 1):
        // using the stride would index one column past the end of the row.
        let (x1, x2) = bounds(horizontal, self.width());
        let (y1, y2) = bounds(vertical, self.height());
        self.try_sum_bounds(x1, y1, x2, y2)
    }

    /// Sum of the rect between corner (`x1`, `y1`), inclusive, and corner (`x2`, `y2`), exclusive.
    ///
    /// Coordinates double as table indices: the table cell at (x, y) holds the sum of all
    /// input values in columns `0..x` and rows `0..y`.
    ///
    /// Returns `None` if the corners are inverted (`x1 > x2` or `y1 > y2`), if any corner
    /// lies outside the table, or if computing an index would overflow `usize`.
    #[inline]
    fn try_sum_bounds(&self, x1: usize, y1: usize, x2: usize, y2: usize) -> Option<SumType> {
        let stride = self.stride;
        // A column index >= stride would alias a cell of a later row and silently give a
        // wrong sum, so it has to be rejected here; `get` alone cannot detect it.
        if x1 > x2 || y1 > y2 || x2 >= stride {
            return None;
        }
        let top = y1.checked_mul(stride)?;
        let bottom = y2.checked_mul(stride)?;
        // Checked lookup of one table cell; `None` for overflow or a row past the end.
        let cell = |row_start: usize, x: usize| -> Option<SumType> {
            self.sums.get(row_start.checked_add(x)?).copied()
        };
        let tl = cell(top, x1)?;
        let tr = cell(top, x2)?;
        let bl = cell(bottom, x1)?;
        let br = cell(bottom, x2)?;
        Some(tl + br - tr - bl)
    }
}

/// Converts any `RangeBounds<usize>` into a half-open `(start, end)` pair.
///
/// An unbounded start becomes `0` and an unbounded end becomes `max`.
/// Inclusive ends and exclusive starts are shifted by one; a bound of `usize::MAX`
/// saturates instead of overflowing, since no slice can be indexed that far anyway.
fn bounds(range: impl RangeBounds<usize>, max: usize) -> (usize, usize) {
    use std::ops::Bound::{Excluded, Included, Unbounded};

    let start = match range.start_bound() {
        Included(&x) => x,
        Excluded(&x) => x.saturating_add(1),
        Unbounded => 0,
    };
    let end = match range.end_bound() {
        Included(&x) => x.saturating_add(1),
        Excluded(&x) => x,
        Unbounded => max,
    };
    (start, end)
}

#[test]
fn bounds_test() {
    // Fallback used whenever the range has no explicit end.
    let limit = 99;

    // open end falls back to the limit
    assert_eq!((0, limit), bounds(0.., limit));

    // half-open ranges keep their end as-is
    assert_eq!((0, 0), bounds(0..0, limit));
    assert_eq!((0, 1), bounds(0..1, limit));

    // inclusive ends are moved one past the last element
    assert_eq!((0, 2), bounds(0..=1, limit));
    assert_eq!((0, 1), bounds(0..=0, limit));

    // open start falls back to 0
    assert_eq!((0, 50), bounds(..50, limit));
    assert_eq!((0, 50), bounds(..=49, limit));
}

#[test]
fn wiki() {
    // 6x6 input, row-major.
    let input = [
        31.,2.,4.,33.,5.,36.,
        12.,26.,9.,10.,29.,25.,
        13.,17.,21.,22.,20.,18.,
        24.,23.,15.,16.,14.,19.,
        30.,8.,28.,27.,11.,7.,
        1.,35.,34.,3.,32.,6.,
    ];
    let table = SummedArea::new_slice(&input, 6);

    // The table has one extra row and one extra column of zeros.
    assert_eq!(table.sums.len(), 7*7);
    assert_eq!(table.width(), 6);
    assert_eq!(table.height(), 6);

    let expected_sums = [
        0., 0.,0.,0.,0.,0.,0.,
        0., 31.,33.,37.,70.,75.,111.,
        0., 43.,71.,84.,127.,161.,222.,
        0., 56.,101.,135.,200.,254.,333.,
        0., 80.,148.,197.,278.,346.,444.,
        0., 110.,186.,263.,371.,450.,555.,
        0., 111.,222.,333.,444.,555.,666.,
    ];
    assert_eq!(expected_sums, table.sums.as_slice());

    // The same 3x2 region, addressed as a rect and as a pair of ranges.
    assert_eq!(111., table.sum_rect(2, 3, 3, 2));
    assert_eq!(111., table.try_sum_range(2..=4, 3..5).unwrap());
}

#[test]
fn pixels() {
    use rgb::RGB;

    // Two 8-bit pixels stacked vertically (width 1), summed into 16-bit channels.
    let column = [
        RGB::<u8>::new(1,2,3),
        RGB::<u8>::new(4,5,6),
    ];
    let table = SummedArea::<RGB<u16>>::new_slice(&column, 1);
    assert_eq!(2, table.height());

    // each pixel on its own
    assert_eq!(RGB::<u16>::new(1,2,3), table.sum_range(0..1, 0..1));
    assert_eq!(RGB::new(4,5,6), table.sum_range(0..1, 1..2));
    // both pixels together
    assert_eq!(RGB::new(5,7,9), table.sum_range(0..1, 0..2));
}

#[test]
fn ones() {
    // Smallest possible input, summed into a wider integer type; only construction is exercised.
    let _ = SummedArea::<i64>::new_slice(&[1i32], 1);

    // 3x3 grid of ones: every rect sums to its own area.
    let grid = [
        1.,1.,1.,
        1.,1.,1.,
        1.,1.,1.,
    ];
    let table = SummedArea::new_slice(&grid, 3);
    assert_eq!(table.height(), 3);

    let expected_sums = [
        0.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 2.0, 3.0,
        0.0, 2.0, 4.0, 6.0,
        0.0, 3.0, 6.0, 9.0,
    ];
    assert_eq!(expected_sums, table.sums.as_slice());

    // (x, y, width, height, expected sum)
    let cases = [
        (0, 0, 0, 0, 0.),
        (2, 2, 0, 0, 0.),
        (0, 0, 1, 0, 0.),
        (0, 0, 0, 1, 0.),
        (0, 0, 1, 1, 1.),
        (1, 1, 1, 1, 1.),
        (0, 0, 1, 2, 2.),
        (0, 0, 2, 1, 2.),
        (0, 0, 2, 2, 4.),
        (1, 1, 2, 2, 4.),
        (0, 0, 2, 3, 6.),
        (0, 0, 3, 2, 6.),
        (0, 1, 3, 2, 6.),
        (0, 0, 3, 3, 9.),
        (2, 2, 1, 1, 1.),
    ];
    for &(x, y, w, h, expected) in cases.iter() {
        assert_eq!(expected, table.sum_rect(x, y, w, h));
    }
}