use std::{
fmt::{Display, Formatter},
str::FromStr,
};
use crate::{
error::Error,
math::{
clustering::{Cluster, ClusteringAlgorithm, DBSCANPlusPlus, KMeans, DBSCAN, SLIC, SNIC},
DistanceMetric,
FloatNumber,
Point,
},
};
/// Clustering algorithm used to group pixels during palette extraction.
///
/// Parsed case-insensitively from `"kmeans"`, `"dbscan"` or `"dbscan++"` and displayed using
/// those same lowercase names. Defaults to [`Algorithm::DBSCAN`].
// `DBSCAN` is the algorithm's conventional spelling, so the acronym lint is silenced on purpose.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// K-means clustering.
    KMeans,
    /// DBSCAN clustering (the default variant).
    #[default]
    DBSCAN,
    /// DBSCAN++ clustering.
    DBSCANpp,
}
impl Algorithm {
    /// Clusters `pixels` with the algorithm selected by `self`.
    ///
    /// `width` and `height` are converted to `usize` and used only by the k-means path; the
    /// DBSCAN and DBSCAN++ paths ignore them.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected clustering helper returns.
    pub(crate) fn cluster<T>(
        &self,
        width: u32,
        height: u32,
        pixels: &[Point<T, 5>],
    ) -> Result<Vec<Cluster<T, 5>>, Error>
    where
        T: FloatNumber,
    {
        match *self {
            Self::DBSCAN => dbscan(pixels),
            Self::DBSCANpp => cluster_with_dbscanpp(pixels),
            Self::KMeans => {
                let (w, h) = (width as usize, height as usize);
                kmeans(w, h, pixels)
            }
        }
    }
}
impl FromStr for Algorithm {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kmeans" => Ok(Self::KMeans),
"dbscan" => Ok(Self::DBSCAN),
"dbscan++" => Ok(Self::DBSCANpp),
_ => Err(Error::UnsupportedAlgorithm {
name: s.to_string(),
}),
}
}
}
/// Formats the algorithm as its lowercase name: `kmeans`, `dbscan` or `dbscan++`.
///
/// Width, fill, alignment and precision flags are honoured the same way they are for `str`.
/// For example, `format!("{:<8}", Algorithm::KMeans)` yields `"kmeans  "`.
impl Display for Algorithm {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::KMeans => "kmeans",
            Self::DBSCAN => "dbscan",
            Self::DBSCANpp => "dbscan++",
        };
        // `pad` applies the caller's formatting flags; `write!` with a literal would ignore them.
        f.pad(name)
    }
}
/// Number of clusters requested from `KMeans::new` by the `kmeans` helper.
const KMEANS_CLUSTER_COUNT: usize = 128;
/// Maximum number of iterations passed to `KMeans::new`.
const KMEANS_MAX_ITER: usize = 10;
/// Tolerance passed to `KMeans::new`, after conversion to `T` with `T::from_f64`.
const KMEANS_TOLERANCE: f64 = 1e-3;
fn kmeans<T>(
width: usize,
height: usize,
pixels: &[Point<T, 5>],
) -> Result<Vec<Cluster<T, 5>>, Error>
where
T: FloatNumber,
{
let clustering = KMeans::new(
(width, height),
KMEANS_CLUSTER_COUNT,
KMEANS_MAX_ITER,
T::from_f64(KMEANS_TOLERANCE),
DistanceMetric::SquaredEuclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
clustering
.fit(pixels)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Minimum point count passed to `DBSCAN::new` by the `dbscan` helper.
const DBSCAN_MIN_POINTS: usize = 16;
/// Epsilon passed to `DBSCAN::new`, after conversion to `T` with `T::from_f64`.
///
/// NOTE(review): the `dbscan` helper pairs this with `DistanceMetric::SquaredEuclidean`, so this
/// is presumably a squared radius (16e-4 = 0.04²). Confirm against the `DBSCAN` implementation.
const DBSCAN_EPSILON: f64 = 16e-4;
fn dbscan<T>(pixels: &[Point<T, 5>]) -> Result<Vec<Cluster<T, 5>>, Error>
where
T: FloatNumber,
{
let clustering = DBSCAN::new(
DBSCAN_MIN_POINTS,
T::from_f64(DBSCAN_EPSILON),
DistanceMetric::SquaredEuclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
clustering
.fit(pixels)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Probability passed to `DBSCANPlusPlus::new`, after conversion to `T` with `T::from_f64`.
const DBSCANPP_PROBABILITY: f64 = 0.1;
/// Minimum point count passed to `DBSCANPlusPlus::new`.
const DBSCANPP_MIN_POINTS: usize = 16;
/// Epsilon passed to `DBSCANPlusPlus::new`, after conversion to `T` with `T::from_f64`.
///
/// NOTE(review): the helper pairs this with `DistanceMetric::SquaredEuclidean`, so this is
/// presumably a squared radius (16e-4 = 0.04²). Confirm against the `DBSCANPlusPlus`
/// implementation.
const DBSCANPP_EPSILON: f64 = 16e-4;
fn cluster_with_dbscanpp<T>(pixels: &[Point<T, 5>]) -> Result<Vec<Cluster<T, 5>>, Error>
where
T: FloatNumber,
{
let clustering = DBSCANPlusPlus::new(
T::from_f64(DBSCANPP_PROBABILITY),
DBSCANPP_MIN_POINTS,
T::from_f64(DBSCANPP_EPSILON),
DistanceMetric::SquaredEuclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
clustering
.fit(pixels)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Segment count passed to `SLIC::new` by the `slic` helper.
const SLIC_SEGMENTS: usize = 128;
/// Compactness passed to `SLIC::new`, after conversion to `T` with `T::from_f64`.
///
/// NOTE(review): 0.0225 = 0.15². It may be a squared weight because the helper uses
/// `DistanceMetric::SquaredEuclidean`; confirm against the `SLIC` implementation.
const SLIC_COMPACTNESS: f64 = 0.0225;
/// Maximum number of iterations passed to `SLIC::new`.
const SLIC_MAX_ITER: usize = 10;
/// Tolerance passed to `SLIC::new`, after conversion to `T` with `T::from_f64`.
const SLIC_TOLERANCE: f64 = 1e-3;
#[allow(dead_code)]
fn slic<T>(width: usize, height: usize, pixels: &[Point<T, 5>]) -> Result<Vec<Cluster<T, 5>>, Error>
where
T: FloatNumber,
{
let clustering = SLIC::new(
(width, height),
SLIC_SEGMENTS,
T::from_f64(SLIC_COMPACTNESS),
SLIC_MAX_ITER,
T::from_f64(SLIC_TOLERANCE),
DistanceMetric::SquaredEuclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
clustering
.fit(pixels)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Segment count passed to `SNIC::new` by the `snic` helper.
const SNIC_SEGMENTS: usize = 128;
#[allow(dead_code)]
fn snic<T>(width: usize, height: usize, pixels: &[Point<T, 5>]) -> Result<Vec<Cluster<T, 5>>, Error>
where
T: FloatNumber,
{
let clustering = SNIC::new(
(width, height),
SNIC_SEGMENTS,
DistanceMetric::SquaredEuclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
clustering
.fit(pixels)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Returns an empty pixel buffer for exercising the helpers' error paths.
    #[must_use]
    fn empty_points() -> Vec<Point<f64, 5>> {
        vec![]
    }

    #[rstest]
    #[case::kmeans("kmeans", Algorithm::KMeans)]
    #[case::dbscan("dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp("dbscan++", Algorithm::DBSCANpp)]
    #[case::kmeans_upper("KMEANS", Algorithm::KMeans)]
    #[case::dbscan_upper("DBSCAN", Algorithm::DBSCAN)]
    #[case::dbscanpp_upper("DBSCAN++", Algorithm::DBSCANpp)]
    #[case::kmeans_capitalized("Kmeans", Algorithm::KMeans)]
    #[case::dbscan_capitalized("Dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp_capitalized("Dbscan++", Algorithm::DBSCANpp)]
    fn test_from_str(#[case] input: &str, #[case] expected: Algorithm) {
        let parsed = Algorithm::from_str(input).expect("input should name a supported algorithm");
        assert_eq!(parsed, expected);
    }

    #[rstest]
    #[case::empty("")]
    #[case::invalid("unknown")]
    fn test_from_str_error(#[case] input: &str) {
        let error = Algorithm::from_str(input).expect_err("input should be rejected");
        assert_eq!(
            error.to_string(),
            format!("Unsupported algorithm specified: '{input}'")
        );
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans, "kmeans")]
    #[case::dbscan(Algorithm::DBSCAN, "dbscan")]
    #[case::dbscanpp(Algorithm::DBSCANpp, "dbscan++")]
    fn test_fmt(#[case] algorithm: Algorithm, #[case] expected: &str) {
        assert_eq!(algorithm.to_string(), expected);
    }

    #[test]
    fn test_dbscan_empty() {
        assert!(dbscan(&empty_points()).is_err());
    }

    #[test]
    fn test_dbscanpp_empty() {
        assert!(cluster_with_dbscanpp(&empty_points()).is_err());
    }

    #[test]
    fn test_kmeans_empty() {
        assert!(kmeans(192, 128, &empty_points()).is_err());
    }
}