// burn-vision 0.21.0
//
// Vision processing operations for burn tensors.
use std::{cmp::Ordering, marker::PhantomData};

use alloc::vec::Vec;
use burn_tensor::{
    Bool, DType, Element, ElementConversion, ElementLimits, ElementOrdered, Int, Shape, Tensor,
    TensorData,
    backend::Backend,
    ops::{BoolTensor, IntTensor},
};
use ndarray::Array2;

use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity};

mod spaghetti;
mod spaghetti_4c;

/// Runs `$body` with `$ty` bound to the concrete Rust integer type matching
/// `B::IntElem::dtype()`.
///
/// Each match arm declares a local `type $ty = ...;` alias and then evaluates its own
/// copy of `$body`. Generic code can therefore be instantiated with a concrete integer
/// type without adding extra trait bounds on the backend's int element. This was
/// needed after `ElementComparison` was removed from `Element`.
///
/// The expansion names `B` directly, so a type called `B` implementing `Backend` must
/// be in scope wherever the macro is invoked.
///
/// # Panics
///
/// Panics (via `unreachable!`) if the dtype is not one of `I64`, `I32`, `I16`, `I8`,
/// `U64`, `U32`, `U16` or `U8`.
macro_rules! dispatch_int_dtype {
    (|$ty:ident| $body:expr) => {
        match B::IntElem::dtype() {
            DType::I64 => {
                type $ty = i64;
                $body
            }
            DType::I32 => {
                type $ty = i32;
                $body
            }
            DType::I16 => {
                type $ty = i16;
                $body
            }
            DType::I8 => {
                type $ty = i8;
                $body
            }
            DType::U64 => {
                type $ty = u64;
                $body
            }
            DType::U32 => {
                type $ty = u32;
                $body
            }
            DType::U16 => {
                type $ty = u16;
                $body
            }
            DType::U8 => {
                type $ty = u8;
                $body
            }
            // Any dtype not listed above has no integer instantiation here.
            _ => unreachable!("Unsupported dtype"),
        }
    };
}

/// Labels the connected foreground regions of a 2-D boolean image.
///
/// The integer label type is chosen from the backend's int dtype. No per-label
/// statistics are gathered; only the label map is returned.
pub fn connected_components<B: Backend>(
    img: BoolTensor<B>,
    connectivity: Connectivity,
) -> IntTensor<B> {
    dispatch_int_dtype!(|I| {
        // `NoOp` discards every statistics hook, so only the label map is kept.
        let (labels, _) = run::<B, I, NoOp<I>>(img, connectivity, NoOp::<I>::default);
        labels
    })
}

/// Labels the connected foreground regions of a 2-D boolean image and also returns
/// per-label statistics: area, bounding box and maximum label.
///
/// `_options` is currently ignored; all statistics are always computed.
pub fn connected_components_with_stats<B: Backend>(
    img: BoolTensor<B>,
    connectivity: Connectivity,
    _options: ConnectedStatsOptions,
) -> (IntTensor<B>, ConnectedStatsPrimitive<B>) {
    // Captured before `img` is moved into `run`; the stats tensors go to the same device.
    let device = B::bool_device(&img);

    dispatch_int_dtype!(|I| {
        let (labels, raw_stats) =
            run::<B, I, ConnectedStatsOp<I>>(img, connectivity, ConnectedStatsOp::<I>::default);
        (labels, finalize_stats::<B, I>(&device, raw_stats))
    })
}

/// Runs connected-component labeling of a 2-D boolean tensor on the CPU.
///
/// The image is read back into a `Vec<B::BoolElem>`. It is then labeled by
/// `spaghetti_4c::process` (4-connectivity) or `spaghetti::process` (8-connectivity),
/// using a [`UnionFind`] solver with label type `I`. The resulting `[height, width]`
/// label map is written to a new int tensor on the input's device.
///
/// `stats` builds the [`StatsOp`] that the labeling pass feeds; the populated
/// instance is returned next to the labels.
///
/// # Panics
///
/// Panics if the tensor data cannot be read as `B::BoolElem`. For 8-connectivity it
/// also panics if the data length does not equal `height * width`.
fn run<B: Backend, I: ElementOrdered, Stats: StatsOp<Label = I>>(
    img: BoolTensor<B>,
    connectivity: Connectivity,
    stats: impl Fn() -> Stats,
) -> (IntTensor<B>, Stats) {
    let device = B::bool_device(&img);
    let img = Tensor::<B, 2, Bool>::from_primitive(img);
    let [height, width] = img.shape().dims();
    let img = img
        .into_data()
        .into_vec::<B::BoolElem>()
        .expect("bool tensor data should be readable as the backend's bool element type");

    let mut stats = stats();

    let out = match connectivity {
        Connectivity::Four => {
            spaghetti_4c::process::<B::BoolElem, UnionFind<_>>(img, height, width, &mut stats)
        }
        Connectivity::Eight => {
            // Checked construction: an O(1) length check replaces the previous `unsafe`
            // call, whose soundness relied on an invariant maintained outside this function.
            let img = Array2::from_shape_vec((height, width), img)
                .expect("tensor data length should equal height * width");
            spaghetti::process::<B::BoolElem, UnionFind<_>>(img, &mut stats)
        }
    };

    // NOTE(review): the offset is discarded, which assumes the label array is in
    // standard layout starting at offset 0 — confirm against the `process` functions.
    let (data, _) = out.into_raw_vec_and_offset();
    let data = TensorData::new(data, Shape::new([height, width]));
    let labels = Tensor::<B, 2, Int>::from_data(data, &device).into_primitive();
    (labels, stats)
}

/// Label-equivalence solver used while assigning connected-component labels.
///
/// A solver hands out provisional labels and records which labels are equivalent.
/// It then compacts them into final labels. In the [`UnionFind`] implementation,
/// label 0 is reserved for the background.
pub trait Solver {
    /// Integer type used for labels.
    type Label: ElementOrdered;

    /// Creates a solver sized for up to `max_labels` labels. `UnionFind` uses this as
    /// a capacity hint.
    fn init(max_labels: usize) -> Self;
    /// Records that `label_1` and `label_2` belong to the same component and returns
    /// the representative label they now share.
    ///
    /// Hack to get around mutable borrow limitations on methods: the solver is
    /// passed explicitly instead of through `&mut self`.
    fn merge(label_1: Self::Label, label_2: Self::Label, solver: &mut Self) -> Self::Label;
    /// Allocates and returns a fresh provisional label.
    fn new_label(&mut self) -> Self::Label;
    /// Resolves every provisional label to a compact final label.
    ///
    /// In `UnionFind`, the return value is one past the highest final label, which is
    /// also the label count including background label 0.
    fn flatten(&mut self) -> Self::Label;
    /// Looks up the label currently stored for provisional label `i_label`. After
    /// `flatten`, this is its final label in `UnionFind`.
    fn get_label(&self, i_label: Self::Label) -> Self::Label;
}

/// Union-find label solver backed by a flat parent table.
///
/// `labels[i]` holds the parent of label `i`. A label whose entry is not smaller than
/// itself is a root. Merges always hang the larger root under the smaller one, so a
/// parent never has a larger index than its child. Slot 0 is the background and is
/// never returned by `new_label`.
pub(crate) struct UnionFind<I: Element> {
    labels: Vec<I>,
}

impl<I: ElementOrdered> Solver for UnionFind<I> {
    type Label = I;

    fn init(max_labels: usize) -> Self {
        // Pre-seed the background label as its own root.
        let mut labels: Vec<I> = Vec::with_capacity(max_labels);
        labels.push(0.elem());
        Self { labels }
    }

    fn merge(label_1: I, label_2: I, solver: &mut Self) -> I {
        /// Follows parent links from `node` until it reaches a root.
        fn find_root<L: ElementOrdered>(parents: &[L], mut node: L) -> L {
            loop {
                let parent = parents[node.to_usize()];
                if !matches!(parent.cmp(&node), Ordering::Less) {
                    return node;
                }
                node = parent;
            }
        }

        let root_1 = find_root(&solver.labels, label_1);
        let root_2 = find_root(&solver.labels, label_2);

        // The smaller root survives. On a tie, `root_1` is pointed at itself (a no-op).
        let (keep, absorb) = if matches!(root_1.cmp(&root_2), Ordering::Less) {
            (root_1, root_2)
        } else {
            (root_2, root_1)
        };
        solver.labels[absorb.to_usize()] = keep;
        keep
    }

    fn new_label(&mut self) -> I {
        // The next free slot becomes a new root that points at itself.
        let fresh = I::from_elem(self.labels.len());
        self.labels.push(fresh);
        fresh
    }

    fn flatten(&mut self) -> I {
        let mut next: i32 = 1;
        for idx in 1..self.labels.len() {
            let parent = self.labels[idx];
            // A non-root's parent has a smaller index and has already been rewritten
            // to its final label. Each root receives the next consecutive label.
            self.labels[idx] = if matches!(parent.cmp(&I::from_elem(idx)), Ordering::Less) {
                self.labels[parent.to_usize()]
            } else {
                let assigned = next.elem();
                next += 1;
                assigned
            };
        }
        next.elem()
    }

    fn get_label(&self, i_label: I) -> I {
        self.labels[i_label.to_usize()]
    }
}

/// Hooks for collecting per-label statistics while an image is labeled.
///
/// NOTE(review): the expected call sequence is inferred from `ConnectedStatsOp`:
/// `init` once with the label count, `update` once per labeled pixel, then `finish`.
/// The driving code lives in the `spaghetti` modules — confirm there.
pub trait StatsOp {
    /// Label type reported to `update`.
    type Label;

    /// Prepares storage for `num_labels` labels.
    fn init(&mut self, num_labels: usize);
    /// Records that the pixel at (`row`, `column`) carries `label`.
    fn update(&mut self, row: usize, column: usize, label: Self::Label);
    /// Final hook once all pixels have been reported.
    fn finish(&mut self);
}

/// A [`StatsOp`] that ignores every hook, used when only the label map is wanted.
#[derive(Default)]
struct NoOp<I: Element> {
    _i: PhantomData<I>,
}

impl<I: Element> StatsOp for NoOp<I> {
    // The trait still requires a label type even though nothing is recorded.
    type Label = I;

    fn init(&mut self, _: usize) {}

    fn update(&mut self, _: usize, _: usize, _: I) {}

    fn finish(&mut self) {}
}

/// Per-label statistics collected during labeling: pixel area and inclusive bounding box.
///
/// After `init(n)`, every vector has `n` entries and index `l` describes label `l`.
/// `finish` zeroes the entries of label 0 (the background).
#[derive(Default, Debug)]
struct ConnectedStatsOp<I: Element> {
    /// Pixel count per label.
    pub area: Vec<I>,
    /// Smallest column index per label (`I::MAX` until the label is first seen).
    pub left: Vec<I>,
    /// Smallest row index per label (`I::MAX` until the label is first seen).
    pub top: Vec<I>,
    /// Largest column index per label (0 until the label is first seen).
    pub right: Vec<I>,
    /// Largest row index per label (0 until the label is first seen).
    pub bottom: Vec<I>,
}

impl<I: Element + ElementLimits> StatsOp for ConnectedStatsOp<I> {
    type Label = I;

    /// Resets every statistic to `num_labels` entries. Minimums start at `I::MAX` and
    /// maximums and areas start at 0.
    fn init(&mut self, num_labels: usize) {
        self.area = vec![0.elem(); num_labels];
        self.left = vec![I::MAX; num_labels];
        self.top = vec![I::MAX; num_labels];
        self.right = vec![0.elem(); num_labels];
        self.bottom = vec![0.elem(); num_labels];
    }

    /// Adds one pixel at (`row`, `column`) to `label`'s area and bounding box.
    ///
    /// `label` must be below the `num_labels` passed to `init`. This is only checked
    /// in debug builds; violating it in a release build is undefined behavior.
    fn update(&mut self, row: usize, column: usize, label: I) {
        let l = label.to_usize();
        debug_assert!(
            [
                self.area.len(),
                self.left.len(),
                self.top.len(),
                self.right.len(),
                self.bottom.len(),
            ]
            .iter()
            .all(|&len| l < len),
            "label {l} out of range for stats sized for {} labels",
            self.area.len()
        );
        // SAFETY: `init` sizes all five vectors to `num_labels`, and callers must only
        // report labels below that count (asserted above in debug builds).
        // NOTE(review): release-build soundness relies on the labeling passes
        // upholding this contract.
        unsafe {
            *self.area.get_unchecked_mut(l) =
                I::from_elem((*self.area.get_unchecked(l)).to_usize() + 1);
            *self.left.get_unchecked_mut(l) =
                I::from_elem((*self.left.get_unchecked(l)).to_usize().min(column));
            *self.top.get_unchecked_mut(l) =
                I::from_elem((*self.top.get_unchecked(l)).to_usize().min(row));
            *self.right.get_unchecked_mut(l) =
                I::from_elem((*self.right.get_unchecked(l)).to_usize().max(column));
            *self.bottom.get_unchecked_mut(l) =
                I::from_elem((*self.bottom.get_unchecked(l)).to_usize().max(row));
        }
    }

    /// Clears the background's (label 0) statistics.
    ///
    /// Panics if `init` was never called, because the vectors are then empty.
    fn finish(&mut self) {
        // Background shouldn't have stats
        self.area[0] = 0.elem();
        self.left[0] = 0.elem();
        self.right[0] = 0.elem();
        self.top[0] = 0.elem();
        self.bottom[0] = 0.elem();
    }
}

/// Converts collected statistics into backend int tensors on `device`.
///
/// Each statistic becomes a 1-D tensor with one entry per label, and its length is
/// taken from `stats.area`. `max_label` is a one-element tensor holding the highest
/// label index (`labels - 1`), or 0 when no labels were recorded.
fn finalize_stats<B: Backend, I: Element>(
    device: &B::Device,
    stats: ConnectedStatsOp<I>,
) -> ConnectedStatsPrimitive<B> {
    let labels = stats.area.len();

    let into_prim = |data: Vec<I>| {
        let data = TensorData::new(data, Shape::new([labels]));
        Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
    };

    let max_label = {
        // `saturating_sub` stops an empty stats set from underflowing to `usize::MAX`.
        let highest = labels.saturating_sub(1);
        let data = TensorData::new(vec![I::from_elem(highest)], Shape::new([1]));
        Tensor::<B, 1, Int>::from_data(data, device).into_primitive()
    };

    ConnectedStatsPrimitive {
        area: into_prim(stats.area),
        left: into_prim(stats.left),
        top: into_prim(stats.top),
        right: into_prim(stats.right),
        bottom: into_prim(stats.bottom),
        max_label,
    }
}

/// Upper bound on the number of labels for an `h` x `w` image, including background
/// label 0.
///
/// With 4-connectivity at most `ceil(h * w / 2)` components fit, as in a
/// checkerboard. With 8-connectivity components also need a diagonal gap, so at most
/// `ceil(h / 2) * ceil(w / 2)` fit. One extra slot is added for the background.
pub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize {
    let foreground = match conn {
        Connectivity::Eight => usize::div_ceil(h, 2) * usize::div_ceil(w, 2),
        Connectivity::Four => usize::div_ceil(h * w, 2),
    };
    // Reserve label 0 for the background.
    foreground + 1
}