#![allow(
unreachable_code,
clippy::collapsible_else_if,
clippy::if_same_then_else
)]
use std::cmp::Ordering;
use burn_tensor::{Element, ElementConversion};
use ndarray::{Array2, Axis, s};
#[allow(non_snake_case)]
mod Spaghetti_forest_labels;
pub(crate) use Spaghetti_forest_labels::*;
use crate::Connectivity;
use super::{Solver, StatsOp, max_labels};
/// Labels the connected foreground components of a 2-D image with the Spaghetti
/// block-based scan.
///
/// A pixel is foreground when `to_u8() > 0`. Labeling runs in two passes:
///
/// 1. The generated decision-forest code (pulled in with `include!`) scans the image in
///    2x2 blocks and stores a provisional label at the top-left pixel of each block, using
///    a `LabelsSolver` initialised with `max_labels(h, w, Connectivity::Eight)` labels.
/// 2. After `solver.flatten()`, each block's provisional label is resolved through the
///    solver and written to every foreground pixel of that block; background pixels get 0.
///
/// `stats` receives `init` once with the flattened label count, then
/// `update(row, col, label)` exactly once per pixel (background pixels report label 0),
/// and finally `finish`.
///
/// Returns an array with the same shape as `img_arr` holding the final labels.
pub fn process<I: Element, B: Element, LabelsSolver: Solver<I>>(
    img_arr: Array2<B>,
    stats: &mut impl StatsOp<I>,
) -> Array2<I> {
    // The first pass walks the image with raw pointer arithmetic (`row * w + col`), which
    // is only valid for row-major contiguous storage. Owned arrays can still be
    // column-major or carry custom/negative strides, so normalise them first (no copy when
    // the input is already in standard layout).
    let img_arr = if img_arr.is_standard_layout() {
        img_arr
    } else {
        img_arr.as_standard_layout().into_owned()
    };
    let (h, w) = img_arr.dim();
    // Freshly allocated, hence always in standard layout.
    let mut img_labels_arr = Array2::<I>::default(img_arr.raw_dim());
    let img = img_arr.as_ptr();
    // Number of rows/cols covered by full 2x2 blocks, and whether a trailing odd row/col remains.
    let e_rows = h as u32 & 0xfffffffe;
    let o_rows = h % 2 == 1;
    let e_cols = w as u32 & 0xfffffffe;
    let o_cols = w % 2 == 1;
    let img_labels = img_labels_arr.as_mut_ptr();
    let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight));
    let solver = &mut solver;
    let w = w as i32;
    // An empty image has nothing for the generated code to read; running it anyway would
    // dereference pointers into zero-length buffers.
    if h > 0 && w > 0 {
        // SAFETY: both arrays are in standard (row-major, contiguous) layout with `h * w`
        // elements, so `ptr.add(row * w)` is the start of `row` for every `row < h`. Row
        // pointers are only formed for rows in `0..h`: `r + 1 < h` in the first/center
        // passes because `h >= 2` there, and `r - 2 >= 0` in the last-line pass because an
        // odd `h > 1` is at least 3. Column accesses happen inside the included forest
        // code, which is assumed to stay within `0..w`.
        unsafe {
            if h == 1 {
                // Single-row image.
                let r = 0;
                let img_row00 = img.add(r * w as usize);
                let img_labels_row00 = img_labels.add(r * w as usize);
                let mut c = -2i32;
                let entry = singleLabels::sl_tree_0;
                include!("Spaghetti_single_line_forest_code.rs");
            } else {
                {
                    // First row pair: no rows above to connect to.
                    let r = 0;
                    let img_row00 = img.add(r * w as usize);
                    let img_row01 = img.add((r + 1) * w as usize);
                    let img_labels_row00 = img_labels.add(r * w as usize);
                    let mut c = -2i32;
                    let entry = firstLabels::fl_tree_0;
                    include!("Spaghetti_first_line_forest_code.rs");
                }
                // Remaining full row pairs, each looking back at the two rows above.
                for r in (2..e_rows as usize).step_by(2) {
                    let img_row00 = img.add(r * w as usize);
                    let img_row12 = img.add((r - 2) * w as usize);
                    let img_row11 = img.add((r - 1) * w as usize);
                    let img_row01 = img.add((r + 1) * w as usize);
                    let img_labels_row00 = img_labels.add(r * w as usize);
                    let img_labels_row12 = img_labels.add((r - 2) * w as usize);
                    let mut c = -2;
                    let entry = centerLabels::cl_tree_0;
                    include!("Spaghetti_center_line_forest_code.rs");
                }
                // Trailing single row when the height is odd.
                if o_rows {
                    let r = h - 1;
                    let img_row00 = img.add(r * w as usize);
                    let img_row12 = img.add((r - 2) * w as usize);
                    let img_row11 = img.add((r - 1) * w as usize);
                    let img_labels_row00 = img_labels.add(r * w as usize);
                    let img_labels_row12 = img_labels.add((r - 2) * w as usize);
                    let mut c = -2;
                    let entry = lastLabels::ll_tree_0;
                    include!("Spaghetti_last_line_forest_code.rs");
                }
            }
        }
    }

    let n_labels = solver.flatten();
    stats.init(n_labels.to_usize());
    let img = img_arr;
    let mut img_labels = img_labels_arr;
    let e_rows = e_rows as usize;
    let e_cols = e_cols as usize;

    // Full row pairs: every pixel of a block takes the label stored at the block's top-left.
    for r in (0..e_rows).step_by(2) {
        let img_row00 = img.index_axis(Axis(0), r);
        let img_row01 = img.index_axis(Axis(0), r + 1);
        let (mut labels_row00, mut labels_row01) =
            img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..]));
        for c in (0..e_cols).step_by(2) {
            let label = resolve_label(solver, labels_row00[c]);
            write_pixel(img_row00[c], label, &mut labels_row00[c], r, c, stats);
            write_pixel(img_row00[c + 1], label, &mut labels_row00[c + 1], r, c + 1, stats);
            write_pixel(img_row01[c], label, &mut labels_row01[c], r + 1, c, stats);
            write_pixel(img_row01[c + 1], label, &mut labels_row01[c + 1], r + 1, c + 1, stats);
        }
        // Odd width: the last column forms 2x1 blocks.
        if o_cols {
            let c = e_cols;
            let label = resolve_label(solver, labels_row00[c]);
            write_pixel(img_row00[c], label, &mut labels_row00[c], r, c, stats);
            write_pixel(img_row01[c], label, &mut labels_row01[c], r + 1, c, stats);
        }
    }
    // Odd height: the last row forms 1x2 blocks (plus a 1x1 block when the width is odd).
    if o_rows {
        let r = e_rows;
        let img_row00 = img.index_axis(Axis(0), r);
        let mut labels_row00 = img_labels.slice_mut(s![r, ..]);
        for c in (0..e_cols).step_by(2) {
            let label = resolve_label(solver, labels_row00[c]);
            write_pixel(img_row00[c], label, &mut labels_row00[c], r, c, stats);
            write_pixel(img_row00[c + 1], label, &mut labels_row00[c + 1], r, c + 1, stats);
        }
        if o_cols {
            let c = e_cols;
            let label = resolve_label(solver, labels_row00[c]);
            write_pixel(img_row00[c], label, &mut labels_row00[c], r, c, stats);
        }
    }
    stats.finish();
    img_labels
}

/// Resolves a block's provisional label to its final label through `solver`.
/// Non-positive provisional labels resolve to background `0`.
fn resolve_label<I: Element, S: Solver<I>>(solver: &mut S, provisional: I) -> I {
    if matches!(provisional.cmp(&0.elem()), Ordering::Greater) {
        solver.get_label(provisional)
    } else {
        0.elem()
    }
}

/// Writes one pixel's final label into `dst` and reports it to `stats`: `label` when the
/// pixel is foreground (`to_u8() > 0`), `0` otherwise.
fn write_pixel<I: Element, B: Element>(
    pixel: B,
    label: I,
    dst: &mut I,
    row: usize,
    col: usize,
    stats: &mut impl StatsOp<I>,
) {
    let value = if pixel.to_u8() > 0 { label } else { 0.elem() };
    *dst = value;
    stats.update(row, col, value);
}