use std::{
fmt::{Display, Formatter},
str::FromStr,
};
use crate::{
error::Error,
image::{
segmentation::{
DbscanSegmentation,
FastDbscanSegmentation,
KmeansSegmentation,
LabelImage,
Segmentation,
SlicSegmentation,
SnicSegmentation,
},
Pixel,
},
math::{DistanceMetric, FloatNumber},
Filter,
ImageData,
};
/// Segmentation algorithm used to partition an image's pixels into labelled regions.
///
/// Parsed case-insensitively from its name via `FromStr` and rendered back as a
/// lowercase name via `Display`. The default variant is [`Algorithm::DBSCAN`].
// Variant names mirror the conventional algorithm acronyms, so clippy's
// upper-case-acronym lint is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Default, Clone, PartialEq)]
pub enum Algorithm {
    /// K-means clustering, backed by `KmeansSegmentation`; name `kmeans`.
    KMeans,
    /// DBSCAN clustering, backed by `DbscanSegmentation`; name `dbscan`.
    #[default]
    DBSCAN,
    /// DBSCAN++, backed by `FastDbscanSegmentation`; name `dbscan++`.
    DBSCANpp,
    /// SLIC, backed by `SlicSegmentation`; name `slic`.
    SLIC,
    /// SNIC, backed by `SnicSegmentation`; name `snic`.
    SNIC,
}
impl Algorithm {
    /// Segment count passed to every builder that accepts `segments`.
    const SEGMENTS: usize = 128;
    /// Value passed to the k-means builder's `max_iter`.
    const KMEANS_MAX_ITER: usize = 50;
    /// Value passed to the k-means builder's `tolerance`.
    const KMEANS_TOLERANCE: f64 = 1e-3;
    /// Value passed to the DBSCAN builder's `min_pixels`.
    const DBSCAN_MIN_POINTS: usize = 10;
    /// DBSCAN neighbourhood radius. It is squared before use because the DBSCAN
    /// builder is configured with `SquaredEuclidean`.
    const DBSCAN_EPSILON: f64 = 0.03;
    /// Value passed to the DBSCAN++ builder's `probability`.
    const FASTDBSCAN_PROBABILITY: f64 = 0.1;
    /// Value passed to the DBSCAN++ builder's `min_pixels`.
    const FASTDBSCAN_MIN_POINTS: usize = 10;
    /// DBSCAN++ neighbourhood radius. It is squared before use because the
    /// DBSCAN++ builder is configured with `SquaredEuclidean`.
    const FASTDBSCAN_EPSILON: f64 = 0.04;
    /// Value passed to the SLIC builder's `compactness`.
    const SLIC_COMPACTNESS: f64 = 0.0225;
    /// Value passed to the SLIC builder's `max_iter`.
    const SLIC_MAX_ITER: usize = 10;
    /// Value passed to the SLIC builder's `tolerance`.
    const SLIC_TOLERANCE: f64 = 1e-3;

    /// Segments `image_data` with this algorithm and its built-in parameters.
    ///
    /// The pixels and per-pixel mask produced by `filter` are handed to the
    /// selected segmentation together with the image's width and height.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PaletteExtractionError`] if the segmentation cannot be
    /// built from its parameters or if segmenting the pixels fails.
    pub(crate) fn segment<T, F>(
        &self,
        image_data: &ImageData,
        filter: &F,
    ) -> Result<LabelImage<T>, Error>
    where
        T: FloatNumber,
        F: Filter,
    {
        match self {
            Self::KMeans => segment_internal(image_data, filter, || {
                KmeansSegmentation::builder()
                    .segments(Self::SEGMENTS)
                    .max_iter(Self::KMEANS_MAX_ITER)
                    .tolerance(T::from_f64(Self::KMEANS_TOLERANCE))
                    .metric(DistanceMetric::SquaredEuclidean)
                    .build()
            }),
            // Both DBSCAN variants compare squared Euclidean distances, so their
            // epsilon radius is squared too. The squaring is done in f64 before
            // converting to `T`, so both variants derive the threshold the same way.
            Self::DBSCAN => segment_internal(image_data, filter, || {
                DbscanSegmentation::builder()
                    .segments(Self::SEGMENTS)
                    .min_pixels(Self::DBSCAN_MIN_POINTS)
                    .epsilon(T::from_f64(Self::DBSCAN_EPSILON.powi(2)))
                    .metric(DistanceMetric::SquaredEuclidean)
                    .build()
            }),
            Self::DBSCANpp => segment_internal(image_data, filter, || {
                // NOTE(review): unlike the other variants, no segment count is set
                // here — confirm this is intentional for FastDbscanSegmentation.
                FastDbscanSegmentation::builder()
                    .min_pixels(Self::FASTDBSCAN_MIN_POINTS)
                    .probability(T::from_f64(Self::FASTDBSCAN_PROBABILITY))
                    .epsilon(T::from_f64(Self::FASTDBSCAN_EPSILON.powi(2)))
                    .metric(DistanceMetric::SquaredEuclidean)
                    .build()
            }),
            Self::SLIC => segment_internal(image_data, filter, || {
                SlicSegmentation::builder()
                    .segments(Self::SEGMENTS)
                    .max_iter(Self::SLIC_MAX_ITER)
                    .compactness(T::from_f64(Self::SLIC_COMPACTNESS))
                    .tolerance(T::from_f64(Self::SLIC_TOLERANCE))
                    .metric(DistanceMetric::SquaredEuclidean)
                    .build()
            }),
            Self::SNIC => segment_internal(image_data, filter, || {
                // NOTE(review): SNIC is the only variant configured with plain
                // `Euclidean` rather than `SquaredEuclidean` — confirm intended.
                SnicSegmentation::<T>::builder()
                    .segments(Self::SEGMENTS)
                    .metric(DistanceMetric::Euclidean)
                    .build()
            }),
        }
    }
}
impl FromStr for Algorithm {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kmeans" => Ok(Self::KMeans),
"dbscan" => Ok(Self::DBSCAN),
"dbscan++" => Ok(Self::DBSCANpp),
"slic" => Ok(Self::SLIC),
"snic" => Ok(Self::SNIC),
_ => Err(Error::UnsupportedAlgorithm {
name: s.to_string(),
}),
}
}
}
impl Display for Algorithm {
    /// Writes the lowercase algorithm name, the same spelling `FromStr` accepts.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::KMeans => "kmeans",
            Self::DBSCAN => "dbscan",
            Self::DBSCANpp => "dbscan++",
            Self::SLIC => "slic",
            Self::SNIC => "snic",
        };
        f.write_str(name)
    }
}
fn segment_internal<T, F, B, S, E>(
image_data: &ImageData,
filter: &F,
builder: B,
) -> Result<LabelImage<T>, Error>
where
T: FloatNumber,
F: Filter,
B: FnOnce() -> Result<S, E>,
S: Segmentation<T, Err = E>,
E: Display,
{
let segmentation = builder().map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
let width = image_data.width() as usize;
let height = image_data.height() as usize;
let (pixels, mask) = collect_pixels_and_mask(image_data, filter);
segmentation
.segment_with_mask(width, height, &pixels, &mask)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Splits the filtered pixel stream of `image_data` into two parallel vectors.
///
/// Each `(pixel, flag)` item yielded by `pixels_with_filter(filter)` becomes one
/// entry in the pixel vector and one entry in the mask vector, in the same
/// order. Both vectors are preallocated for `width * height` entries.
#[must_use]
fn collect_pixels_and_mask<T, F>(image_data: &ImageData, filter: &F) -> (Vec<Pixel<T>>, Vec<bool>)
where
    T: FloatNumber,
    F: Filter,
{
    let capacity = image_data.width() as usize * image_data.height() as usize;
    let mut pixels: Vec<Pixel<T>> = Vec::with_capacity(capacity);
    let mut mask: Vec<bool> = Vec::with_capacity(capacity);
    for (pixel, flag) in image_data.pixels_with_filter(filter) {
        pixels.push(pixel);
        mask.push(flag);
    }
    (pixels, mask)
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::Rgba;

    /// Every variant parses from its name regardless of letter case.
    #[rstest]
    #[case::kmeans("kmeans", Algorithm::KMeans)]
    #[case::dbscan("dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp("dbscan++", Algorithm::DBSCANpp)]
    #[case::slic("slic", Algorithm::SLIC)]
    #[case::snic("snic", Algorithm::SNIC)]
    #[case::kmeans_upper("KMEANS", Algorithm::KMeans)]
    #[case::dbscan_upper("DBSCAN", Algorithm::DBSCAN)]
    #[case::dbscanpp_upper("DBSCAN++", Algorithm::DBSCANpp)]
    #[case::slic_upper("SLIC", Algorithm::SLIC)]
    #[case::snic_upper("SNIC", Algorithm::SNIC)]
    #[case::kmeans_capitalized("Kmeans", Algorithm::KMeans)]
    #[case::dbscan_capitalized("Dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp_capitalized("Dbscan++", Algorithm::DBSCANpp)]
    #[case::slic_capitalized("Slic", Algorithm::SLIC)]
    #[case::snic_capitalized("Snic", Algorithm::SNIC)]
    fn test_from_str(#[case] input: &str, #[case] expected: Algorithm) {
        let actual = Algorithm::from_str(input).unwrap();
        assert_eq!(actual, expected);
    }

    /// Unknown names, including names with surrounding whitespace, are rejected
    /// and the error echoes the input verbatim.
    #[rstest]
    #[case::empty("")]
    #[case::invalid("unknown")]
    #[case::untrimmed(" kmeans")]
    fn test_from_str_error(#[case] input: &str) {
        let actual = Algorithm::from_str(input);
        assert!(actual.is_err());
        assert_eq!(
            actual.unwrap_err().to_string(),
            format!("Unsupported algorithm specified: '{input}'")
        );
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans, "kmeans")]
    #[case::dbscan(Algorithm::DBSCAN, "dbscan")]
    #[case::dbscanpp(Algorithm::DBSCANpp, "dbscan++")]
    #[case::slic(Algorithm::SLIC, "slic")]
    #[case::snic(Algorithm::SNIC, "snic")]
    fn test_fmt(#[case] algorithm: Algorithm, #[case] expected: &str) {
        let actual = algorithm.to_string();
        assert_eq!(actual, expected);
    }

    /// `Display` output must always be accepted back by `FromStr`.
    #[rstest]
    #[case::kmeans(Algorithm::KMeans)]
    #[case::dbscan(Algorithm::DBSCAN)]
    #[case::dbscanpp(Algorithm::DBSCANpp)]
    #[case::slic(Algorithm::SLIC)]
    #[case::snic(Algorithm::SNIC)]
    fn test_display_from_str_roundtrip(#[case] algorithm: Algorithm) {
        let parsed = Algorithm::from_str(&algorithm.to_string()).unwrap();
        assert_eq!(parsed, algorithm);
    }

    #[test]
    fn test_default() {
        assert_eq!(Algorithm::default(), Algorithm::DBSCAN);
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans)]
    #[case::dbscan(Algorithm::DBSCAN)]
    #[case::dbscanpp(Algorithm::DBSCANpp)]
    #[case::slic(Algorithm::SLIC)]
    #[case::snic(Algorithm::SNIC)]
    fn test_segment_empty(#[case] algorithm: Algorithm) {
        let pixels: Vec<_> = Vec::new();
        let image_data = ImageData::new(0, 0, &pixels).expect("Failed to create empty image data");
        let actual = algorithm.segment(&image_data, &|rgba: &Rgba| rgba[0] != 0);
        assert!(actual.is_ok());
        let label_image: LabelImage<f64> = actual.unwrap();
        assert_eq!(label_image.width(), 0);
        assert_eq!(label_image.height(), 0);
    }
}