use burn_tensor::{
BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,
ops::BoolTensor,
};
use crate::{
BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
VisionBackend,
};
/// Connected-component labeling for 2-D boolean masks.
///
/// Implemented for `Tensor<B, 2, Bool>` by delegating to the backend's [`BoolVisionOps`].
pub trait ConnectedComponents<B>
where
    B: Backend,
{
    /// Labels the connected regions of `self` under the given pixel `connectivity`,
    /// returning a 2-D integer label tensor.
    ///
    /// NOTE(review): the labeling scheme (background value, label ordering) is defined by
    /// the backend implementation — confirm before relying on specific label values.
    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;

    /// Like [`ConnectedComponents::connected_components`], but also returns per-component
    /// statistics. `options` presumably selects which statistics are computed — see
    /// [`ConnectedStatsOptions`].
    fn connected_components_with_stats(
        self,
        connectivity: Connectivity,
        options: ConnectedStatsOptions,
    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);
}
/// Morphological erosion and dilation for tensors of kind `K`, using a 2-D boolean
/// structuring element.
pub trait Morphology<B, K>
where
    B: Backend,
    K: TensorKind<B>,
{
    /// Erodes `self` with the boolean structuring element `kernel`, configured by `opts`.
    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;

    /// Dilates `self` with the boolean structuring element `kernel`, configured by `opts`.
    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;
}
/// Per-tensor-kind dispatch of erosion/dilation onto backend primitives.
///
/// Implemented for `Float`, `Int` and `Bool`, which lets [`Morphology`] be implemented
/// once, generically over the tensor kind.
pub trait MorphologyKind<B>: BasicOps<B>
where
    B: Backend,
{
    /// Erodes the raw primitive `tensor` with the boolean `kernel` primitive.
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive;

    /// Dilates the raw primitive `tensor` with the boolean `kernel` primitive.
    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive;
}
/// Non-maximum suppression over a set of scored boxes.
pub trait Nms<B>
where
    B: Backend,
{
    /// Runs NMS on the boxes held in `self`, ranked by the matching `scores`, and returns a
    /// 1-D integer tensor — presumably the indices of the retained boxes.
    ///
    /// NOTE(review): the expected box layout and coordinate format are not visible here;
    /// confirm against [`NmsOptions`] and the backend implementation.
    fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
}
impl<B> ConnectedComponents<B> for Tensor<B, 2, Bool>
where
    B: BoolVisionOps,
{
    /// Unwraps the mask into its backend primitive, labels it with the backend's
    /// `connected_components`, and wraps the labels back into a tensor.
    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
        let mask = self.into_primitive();
        let labels = B::connected_components(mask, connectivity);
        Tensor::from_primitive(labels)
    }

    /// Same as `connected_components`, additionally converting the backend's raw
    /// statistics into [`ConnectedStats`].
    fn connected_components_with_stats(
        self,
        connectivity: Connectivity,
        options: ConnectedStatsOptions,
    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {
        let mask = self.into_primitive();
        let (label_primitive, raw_stats) =
            B::connected_components_with_stats(mask, connectivity, options);
        let labels = Tensor::<B, 2, Int>::from_primitive(label_primitive);
        let stats: ConnectedStats<B> = raw_stats.into();
        (labels, stats)
    }
}
/// Blanket morphology for 3-D tensors of any kind that implements [`MorphologyKind`].
///
/// NOTE(review): the meaning of the three axes (e.g. channels/height/width) is decided by
/// the backend ops, not here — confirm before documenting it for users.
impl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {
    /// Erodes `self` with `kernel` by dispatching on the tensor kind `K`
    /// (see [`MorphologyKind::erode`]).
    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
        // `from_primitive` for consistency with the other wrappers in this module.
        Tensor::from_primitive(K::erode(self.into_primitive(), kernel.into_primitive(), opts))
    }

    /// Dilates `self` with `kernel` by dispatching on the tensor kind `K`
    /// (see [`MorphologyKind::dilate`]).
    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {
        Tensor::from_primitive(K::dilate(self.into_primitive(), kernel.into_primitive(), opts))
    }
}
impl<B> MorphologyKind<B> for Float
where
    B: VisionBackend,
{
    /// Erodes a float primitive. Quantized inputs go to `q_erode` and regular float inputs
    /// to `float_erode`; the result keeps the input's representation.
    fn erode(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let eroded = B::q_erode(quantized, kernel, opts);
                TensorPrimitive::QFloat(eroded)
            }
            TensorPrimitive::Float(dense) => {
                let eroded = B::float_erode(dense, kernel, opts);
                TensorPrimitive::Float(eroded)
            }
        }
    }

    /// Dilates a float primitive. Quantized inputs go to `q_dilate` and regular float
    /// inputs to `float_dilate`; the result keeps the input's representation.
    fn dilate(
        tensor: Self::Primitive,
        kernel: BoolTensor<B>,
        opts: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quantized) => {
                let dilated = B::q_dilate(quantized, kernel, opts);
                TensorPrimitive::QFloat(dilated)
            }
            TensorPrimitive::Float(dense) => {
                let dilated = B::float_dilate(dense, kernel, opts);
                TensorPrimitive::Float(dilated)
            }
        }
    }
}
impl<B> MorphologyKind<B> for Int
where
    B: VisionBackend,
{
    /// Erodes an integer primitive via the backend's `int_erode`.
    fn erode(
        input: Self::Primitive,
        structuring_element: BoolTensor<B>,
        options: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        B::int_erode(input, structuring_element, options)
    }

    /// Dilates an integer primitive via the backend's `int_dilate`.
    fn dilate(
        input: Self::Primitive,
        structuring_element: BoolTensor<B>,
        options: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        B::int_dilate(input, structuring_element, options)
    }
}
impl<B> MorphologyKind<B> for Bool
where
    B: VisionBackend,
{
    /// Erodes a boolean primitive via the backend's `bool_erode`.
    fn erode(
        input: Self::Primitive,
        structuring_element: BoolTensor<B>,
        options: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        B::bool_erode(input, structuring_element, options)
    }

    /// Dilates a boolean primitive via the backend's `bool_dilate`.
    fn dilate(
        input: Self::Primitive,
        structuring_element: BoolTensor<B>,
        options: MorphOptions<B, Self>,
    ) -> Self::Primitive {
        B::bool_dilate(input, structuring_element, options)
    }
}
impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
    /// Runs non-maximum suppression on the boxes in `self` with the given `scores`,
    /// delegating to the backend's `nms`.
    ///
    /// Quantized boxes and/or scores are dequantized to regular float primitives first.
    /// This is sufficient because the result is an integer tensor, not float values,
    /// so nothing has to be requantized. Previously any quantized input panicked via
    /// `todo!`, including mixed float/quantized pairs.
    fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
        // `TensorPrimitive::tensor` returns `Float` primitives unchanged and dequantizes
        // `QFloat` ones, so every combination of representations is handled.
        let boxes = self.into_primitive().tensor();
        let scores = scores.into_primitive().tensor();
        Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
    }
}