use crate::error::ScError;
use crate::seed::{init_seeds, perturb};
use crate::{
calculate_grid_interval, distance_lab, distance_s, distance_xy, get_in_bounds,
get_mut_in_bounds, m_div_s,
};
use num_traits::ToPrimitive;
use palette::{cast, FromColor, Lab, Srgb};
/// Per-pixel bookkeeping for the SLIC assignment step.
///
/// Both vectors are indexed by the flat pixel index; entry `i` of
/// `distances` is the best distance found so far for pixel `i`, and entry `i`
/// of `labels` is the cluster that produced it.
#[derive(Debug, Clone)]
struct SlicInfo<T, U> {
    /// Best distance recorded so far for each pixel.
    pub distances: Vec<T>,
    /// Cluster label currently assigned to each pixel.
    pub labels: Vec<U>,
}

impl<T, U> SlicInfo<T, U> {
    /// Creates an instance whose `distances` and `labels` are both empty.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self {
            distances: Vec::new(),
            labels: Vec::new(),
        }
    }
}

// Implemented by hand rather than derived so that neither `T` nor `U` is
// required to implement `Default`.
impl<T, U> Default for SlicInfo<T, U> {
    /// Equivalent to [`SlicInfo::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
/// Accumulator used to recompute one cluster center.
///
/// `data`, `x` and `y` hold running sums over the pixels assigned to the
/// cluster and `count` holds how many pixels were summed. Every field starts
/// at its type's default value.
#[derive(Debug, Clone, Copy, Default)]
struct SlicUpdate<T> {
    /// Running sum of the per-pixel payload (the color, in `slic`).
    pub data: T,
    /// Running sum of pixel x coordinates.
    pub x: f64,
    /// Running sum of pixel y coordinates.
    pub y: f64,
    /// Number of pixels accumulated so far.
    pub count: f64,
}

impl<T: Default> SlicUpdate<T> {
    /// Creates an accumulator with every field set to its default value.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
/// Runs [`slic`] on an image supplied as raw 8-bit sRGB bytes.
///
/// `image` is reinterpreted as a sequence of `Srgb<u8>` pixels (three bytes
/// per pixel), each pixel is converted to `Lab` with `f64` components, and the
/// result is passed to [`slic`] together with the remaining arguments
/// unchanged.
///
/// # Errors
///
/// * Returns an error if `width * height` does not fit in `usize`.
/// * Returns [`ScError::MismatchedSlicBuffer`] if `image.len()` is not exactly
///   `width * height * 3`.
/// * Returns an error if the Lab buffer cannot be allocated.
/// * Propagates any error returned by [`slic`].
pub fn slic_from_bytes(
    k: u32,
    m: u8,
    width: u32,
    height: u32,
    iter: Option<u8>,
    image: &[u8],
) -> Result<Vec<usize>, ScError> {
    let pixel_count = usize::try_from(u64::from(width) * u64::from(height))
        .or(Err("Invalid image dimensions in SLIC from bytes"))?;

    // Compare against the exact byte count instead of `image.len() / 3`: the
    // truncating division would accept a buffer with one or two trailing
    // bytes, and `from_component_slice` panics when the length is not a
    // multiple of the pixel size.
    if pixel_count.checked_mul(3) != Some(image.len()) {
        return Err(ScError::MismatchedSlicBuffer);
    }

    let input_buffer = cast::from_component_slice::<Srgb<u8>>(image);

    // Reserve fallibly up front so an oversized image surfaces as an error
    // rather than an allocation abort.
    let mut input_lab: Vec<Lab<_, f64>> = Vec::new();
    input_lab.try_reserve_exact(input_buffer.len())?;
    input_lab.extend(
        input_buffer
            .iter()
            .map(|&c| Lab::from_color(c.into_format())),
    );

    slic(k, m, width, height, iter, &input_lab)
}
pub fn slic<Wp>(
k: u32,
m: u8,
width: u32,
height: u32,
iter: Option<u8>,
image: &[Lab<Wp, f64>],
) -> Result<Vec<usize>, ScError> {
let m = m.clamp(1, 20);
let iter = iter.unwrap_or(10);
if k == 0 {
return Err(ScError::ZeroSuperpixelCount);
}
if width == 0 || height == 0 {
return Err(ScError::InvalidImageDimension);
}
match u64::from(k).cmp(&(u64::from(width) * u64::from(height))) {
std::cmp::Ordering::Less => {}
std::cmp::Ordering::Equal | std::cmp::Ordering::Greater => {
return Err(ScError::InvalidSuperpixelCount);
}
}
let s = calculate_grid_interval(width, height, k)
.to_u32()
.ok_or(ScError::InvalidGridInterval)?;
if s == 0 {
return Err(ScError::ZeroGridInterval);
}
let m_s_term = m_div_s(f64::from(m), f64::from(s));
let mut clusters = Vec::new();
init_seeds(width, height, s, k, image, &mut clusters)?;
for seed in &mut clusters {
perturb(seed, i64::from(width), i64::from(height), image)?;
}
let mut info = SlicInfo::<f64, usize>::new();
info.distances.try_reserve_exact(image.len())?;
info.labels.try_reserve_exact(image.len())?;
info.distances
.extend((0..image.len()).map(|_| f64::INFINITY));
info.labels.extend((0..image.len()).map(|_| 0));
let mut updates: Vec<SlicUpdate<Lab<Wp, f64>>> = Vec::new();
updates.try_reserve_exact(clusters.len())?;
updates.extend((0..clusters.len()).map(|_| SlicUpdate::new()));
for _ in 0..iter {
for (center_index, center) in clusters.iter().enumerate() {
for y in center.y.saturating_sub(s)..center.y.saturating_add(s).min(height) {
let x_start = center.x.saturating_sub(s);
let x_end = center.x.saturating_add(s).min(width);
let row_start = u64::from(y).saturating_mul(u64::from(width));
for (x, idx) in (x_start..x_end).zip(
row_start.saturating_add(u64::from(center.x.saturating_sub(s)))
..row_start.saturating_add(u64::from(x_end)),
) {
let idx = usize::try_from(idx)
.or(Err("Index out of bounds for finding new neighbors"))?;
if idx < image.len() && idx < info.distances.len() && idx < info.labels.len() {
let color = image[idx];
let distance = distance_s(
m_s_term,
distance_lab(color, center.data),
distance_xy(
(f64::from(x), f64::from(y)),
(f64::from(center.x), f64::from(center.y)),
),
);
if distance < info.distances[idx] {
info.distances[idx] = distance;
info.labels[idx] = center_index;
}
}
}
}
}
let width_usize = usize::try_from(width).or(Err("Could not convert width to usize"))?;
for (y, (row, info_labels)) in image
.chunks_exact(width_usize)
.zip(info.labels.chunks_exact(width_usize))
.enumerate()
{
#[allow(clippy::cast_precision_loss)]
for (x, (&color, &info_label)) in row.iter().zip(info_labels).enumerate() {
if let Some(update) = updates.get_mut(info_label) {
update.data += color;
update.x += x as f64;
update.y += y as f64;
update.count += 1.0;
}
}
}
for (update, center) in updates.iter_mut().zip(&mut clusters) {
if update.count == 0.0 {
continue;
}
center.data = update.data / update.count;
center.x = (update.x / update.count)
.to_u32()
.ok_or("Update X out of bounds")?;
center.y = (update.y / update.count)
.to_u32()
.ok_or("Update Y out of bounds")?;
*update = SlicUpdate::new();
}
}
enforce_connectivity(width, height, s, &mut info.labels)?;
Ok(info.labels)
}
/// Relabels `labels` in place so that every label covers one 4-connected
/// region.
///
/// Pixels are scanned row by row. Each pixel that has not been relabelled yet
/// starts a flood fill over 4-connected pixels sharing its original label, and
/// the filled region receives the next fresh label. If the region holds at
/// most `s * s / 4` pixels it is instead given the label of an already
/// relabelled neighbor of its first pixel (or the most recently remembered
/// such label when that pixel has none), and the fresh label is reused for the
/// next region.
///
/// # Errors
///
/// Returns an error if a size conversion to `usize` fails, if scratch memory
/// cannot be reserved, or if a coordinate lookup unexpectedly falls outside
/// the label buffer.
fn enforce_connectivity(
    width: u32,
    height: u32,
    s: u32,
    labels: &mut [usize],
) -> Result<(), ScError> {
    /// `(dx, dy)` offsets of the four edge-adjacent pixels, in visiting order.
    const NEIGHBORS: [(i64, i64); 4] = [(-1, 0), (0, -1), (1, 0), (0, 1)];

    let w = i64::from(width);
    let h = i64::from(height);

    // Regions with at most this many pixels are absorbed into a neighbor.
    let min_region_size =
        usize::try_from(u64::from(s).pow(2) / 4).or(Err("Could not convert cluster threshold"))?;

    // Output labels; `usize::MAX` marks a pixel that has not been visited yet.
    let mut relabeled: Vec<usize> = Vec::new();
    relabeled.try_reserve_exact(labels.len())?;
    relabeled.resize(labels.len(), usize::MAX);
    let relabeled = relabeled.as_mut_slice();

    // Flood-fill work list, reused for every region.
    let mut queue: Vec<(i64, i64)> = Vec::new();
    queue.try_reserve(
        usize::try_from(u64::from(s).pow(2).saturating_mul(8))
            .or(Err("Could not calculate label set size"))?,
    )?;

    // Deliberately kept across regions: a region whose first pixel has no
    // relabelled neighbor falls back to the last label remembered here.
    let mut adjacent_label = 0_usize;
    let mut next_label = 0_usize;

    let row_len = usize::try_from(width).or(Err(
        "Could not convert width to usize in enforce_connectivity",
    ))?;

    for (y, row) in labels.chunks_exact(row_len).enumerate() {
        for (x, &region_label) in row.iter().enumerate() {
            let start = y.saturating_mul(row_len).saturating_add(x);
            if relabeled.get(start) != Some(&usize::MAX) {
                // Already claimed by an earlier flood fill.
                continue;
            }
            *relabeled
                .get_mut(start)
                .ok_or("Label index out of bounds")? = next_label;

            let seed_x = x as i64;
            let seed_y = y as i64;

            // Remember the label of an already relabelled neighbor; when
            // several exist, the last one in `NEIGHBORS` order wins.
            for &(dx, dy) in &NEIGHBORS {
                if let Some(l) = get_in_bounds(w, h, seed_x + dx, seed_y + dy, relabeled) {
                    if *l != usize::MAX {
                        adjacent_label = *l;
                    }
                }
            }

            // Breadth-first flood fill; `queue` doubles as the list of every
            // pixel that belongs to the region.
            queue.clear();
            queue.push((seed_x, seed_y));
            let mut head = 0_usize;
            let mut region_size = 1_usize;
            while head < region_size {
                let (cx, cy) = *queue.get(head).ok_or("Could not get label")?;
                for &(dx, dy) in &NEIGHBORS {
                    let nx = cx + dx;
                    let ny = cy + dy;
                    if let (Some(old_visit_label), Some(new_visit_label)) = (
                        get_in_bounds(w, h, nx, ny, labels),
                        get_mut_in_bounds(w, h, nx, ny, relabeled),
                    ) {
                        if *old_visit_label == region_label && *new_visit_label == usize::MAX {
                            // Grow fallibly before pushing so the push itself
                            // never has to allocate.
                            if queue.capacity() == queue.len() {
                                queue.try_reserve(1)?;
                            }
                            queue.push((nx, ny));
                            *new_visit_label = next_label;
                            region_size = region_size.saturating_add(1);
                        }
                    }
                }
                head = head.saturating_add(1);
            }

            if region_size <= min_region_size {
                // Too small: hand every pixel to the remembered neighbor and
                // leave `next_label` untouched so it is reused.
                for &(px, py) in &queue {
                    *get_mut_in_bounds(w, h, px, py, relabeled)
                        .ok_or("New label index out of bounds")? = adjacent_label;
                }
            } else {
                next_label = next_label.saturating_add(1);
            }
        }
    }

    labels.copy_from_slice(relabeled);
    Ok(())
}