use crate::{
Point,
backends::cpu::{self, MorphOp, morph},
};
use bon::Builder;
use burn_tensor::{
Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
backend::Backend,
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
};
/// Pixel neighbourhood used when grouping pixels into connected components.
///
/// The variant names suggest the usual 4-neighbour (edge-adjacent) vs 8-neighbour
/// (edge- or corner-adjacent) definitions — confirm against `cpu::connected_components`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum Connectivity {
    /// Four-way connectivity.
    Four,
    /// Eight-way connectivity.
    Eight,
}
/// Flags selecting which optional outputs `connected_components_with_stats` computes.
///
/// Use `ConnectedStatsOptions::none()` / `ConnectedStatsOptions::all()` for the two
/// extremes; `Default` enables everything.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ConnectedStatsOptions {
    /// Whether per-component bounds (left/top/right/bottom) are requested.
    pub bounds_enabled: bool,
    /// Whether the maximum label is requested.
    pub max_label_enabled: bool,
    /// Whether labels should be compacted (presumably renumbered contiguously — confirm in `cpu`).
    pub compact_labels: bool,
}
/// Parameters for a morphological erosion/dilation.
///
/// Construct via the `bon`-generated builder or start from `Default`; both give
/// `iterations == 1` and the default `BorderType`.
#[derive(Builder, Debug, Clone)]
pub struct MorphOptions<B: Backend, K: TensorKind<B>> {
    /// Kernel anchor. NOTE(review): `None` presumably means the kernel centre — confirm in `cpu::morph`.
    pub anchor: Option<Point>,
    /// How many times the operation is applied (builder default: 1).
    #[builder(default = 1)]
    pub iterations: usize,
    /// Border handling strategy (builder default: `BorderType::default()`).
    #[builder(default)]
    pub border_type: BorderType,
    /// Optional fill value for the border; the expected length (e.g. one value per
    /// channel) is defined by `cpu::morph` — TODO confirm.
    pub border_value: Option<Tensor<B, 1, K>>,
}
impl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {
fn default() -> Self {
Self {
anchor: Default::default(),
iterations: 1,
border_type: Default::default(),
border_value: Default::default(),
}
}
}
/// Strategy for synthesising pixels outside the image during morphology.
///
/// The variant names presumably mirror OpenCV's `BORDER_*` modes (constant fill,
/// edge replication, mirror with/without edge repeat, wrap-around) — confirm
/// against `cpu::morph`. `Constant` is the default.
#[derive(Default, Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum BorderType {
    /// Default variant.
    #[default]
    Constant,
    Replicate,
    Reflect,
    Reflect101,
    Wrap,
}
/// Per-component statistics returned by `connected_components_with_stats`,
/// wrapped in high-level `Tensor`s.
///
/// NOTE(review): each field is presumably indexed by component label — confirm in `cpu`.
#[derive(Debug, Clone)]
pub struct ConnectedStats<B: Backend> {
    /// Component area.
    pub area: Tensor<B, 1, Int>,
    /// Top edge of each component's bounds.
    pub top: Tensor<B, 1, Int>,
    /// Left edge of each component's bounds.
    pub left: Tensor<B, 1, Int>,
    /// Right edge of each component's bounds.
    pub right: Tensor<B, 1, Int>,
    /// Bottom edge of each component's bounds.
    pub bottom: Tensor<B, 1, Int>,
    /// Maximum label.
    pub max_label: Tensor<B, 1, Int>,
}
/// Backend-primitive counterpart of `ConnectedStats`, as produced by
/// `BoolVisionOps::connected_components_with_stats`.
///
/// Derives `Clone`/`Debug` like `ConnectedStats`, so results can be duplicated or
/// logged without first converting to high-level tensors.
#[derive(Clone, Debug)]
pub struct ConnectedStatsPrimitive<B: Backend> {
    /// Component area.
    pub area: IntTensor<B>,
    /// Left edge of each component's bounds.
    pub left: IntTensor<B>,
    /// Top edge of each component's bounds.
    pub top: IntTensor<B>,
    /// Right edge of each component's bounds.
    pub right: IntTensor<B>,
    /// Bottom edge of each component's bounds.
    pub bottom: IntTensor<B>,
    /// Maximum label.
    pub max_label: IntTensor<B>,
}
impl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {
fn from(value: ConnectedStatsPrimitive<B>) -> Self {
ConnectedStats {
area: Tensor::from_primitive(value.area),
top: Tensor::from_primitive(value.top),
left: Tensor::from_primitive(value.left),
right: Tensor::from_primitive(value.right),
bottom: Tensor::from_primitive(value.bottom),
max_label: Tensor::from_primitive(value.max_label),
}
}
}
impl<B: Backend> ConnectedStats<B> {
    /// Unwraps every statistic into its backend primitive.
    ///
    /// This is the reverse of the `From<ConnectedStatsPrimitive<B>>` conversion.
    pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {
        let Self {
            area,
            top,
            left,
            right,
            bottom,
            max_label,
        } = self;
        ConnectedStatsPrimitive {
            area: area.into_primitive(),
            top: top.into_primitive(),
            left: left.into_primitive(),
            right: right.into_primitive(),
            bottom: bottom.into_primitive(),
            max_label: max_label.into_primitive(),
        }
    }
}
impl Default for ConnectedStatsOptions {
fn default() -> Self {
Self::all()
}
}
impl ConnectedStatsOptions {
    /// Options with every flag cleared.
    pub fn none() -> Self {
        let enabled = false;
        Self {
            bounds_enabled: enabled,
            max_label_enabled: enabled,
            compact_labels: enabled,
        }
    }

    /// Options with every flag set.
    pub fn all() -> Self {
        let enabled = true;
        Self {
            bounds_enabled: enabled,
            max_label_enabled: enabled,
            compact_labels: enabled,
        }
    }
}
/// Parameters for non-maximum suppression ([`FloatVisionOps::nms`]).
///
/// Derives `PartialEq` like the other option types in this module, so configurations
/// can be compared (e.g. against `NmsOptions::default()`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NmsOptions {
    /// IoU threshold used to decide whether two boxes overlap too much.
    /// NOTE(review): exact comparison (`>` vs `>=`) lives in `cpu::nms` — confirm.
    pub iou_threshold: f32,
    /// Score threshold for candidate boxes; semantics defined by `cpu::nms`.
    pub score_threshold: f32,
    /// Cap on the number of returned indices.
    /// NOTE(review): the default of `0` presumably means "no limit" — confirm in `cpu::nms`.
    pub max_output_boxes: usize,
}

impl Default for NmsOptions {
    /// `iou_threshold = 0.5`, `score_threshold = 0.0`, `max_output_boxes = 0`.
    fn default() -> Self {
        Self {
            iou_threshold: 0.5,
            score_threshold: 0.0,
            max_output_boxes: 0,
        }
    }
}
/// Umbrella trait for a backend implementing every vision op family
/// (boolean, integer, float and quantized).
pub trait VisionBackend:
    Backend + BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps
{
}
pub trait BoolVisionOps: Backend {
fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {
cpu::connected_components::<Self>(img, connectivity)
}
fn connected_components_with_stats(
img: BoolTensor<Self>,
connectivity: Connectivity,
opts: ConnectedStatsOptions,
) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {
cpu::connected_components_with_stats(img, connectivity, opts)
}
fn bool_erode(
input: BoolTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Bool>,
) -> BoolTensor<Self> {
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
}
fn bool_dilate(
input: BoolTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Bool>,
) -> BoolTensor<Self> {
let input = Tensor::<Self, 3, Bool>::from_primitive(input);
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
}
}
pub trait IntVisionOps: Backend {
fn int_erode(
input: IntTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Int>,
) -> IntTensor<Self> {
let input = Tensor::<Self, 3, Int>::from_primitive(input);
morph(input, kernel, MorphOp::Erode, opts).into_primitive()
}
fn int_dilate(
input: IntTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Int>,
) -> IntTensor<Self> {
let input = Tensor::<Self, 3, Int>::from_primitive(input);
morph(input, kernel, MorphOp::Dilate, opts).into_primitive()
}
}
pub trait FloatVisionOps: Backend {
fn float_erode(
input: FloatTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> FloatTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
morph(input, kernel, MorphOp::Erode, opts)
.into_primitive()
.tensor()
}
fn float_dilate(
input: FloatTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> FloatTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));
morph(input, kernel, MorphOp::Dilate, opts)
.into_primitive()
.tensor()
}
fn nms(
boxes: FloatTensor<Self>,
scores: FloatTensor<Self>,
options: NmsOptions,
) -> IntTensor<Self> {
let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
cpu::nms::<Self>(boxes, scores, options).into_primitive()
}
}
pub trait QVisionOps: Backend {
fn q_erode(
input: QuantizedTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> QuantizedTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {
TensorPrimitive::QFloat(tensor) => tensor,
_ => unreachable!(),
}
}
fn q_dilate(
input: QuantizedTensor<Self>,
kernel: BoolTensor<Self>,
opts: MorphOptions<Self, Float>,
) -> QuantizedTensor<Self> {
let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));
match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {
TensorPrimitive::QFloat(tensor) => tensor,
_ => unreachable!(),
}
}
}