use crate::{
image::{
segmentation::{
kmeans::KmeansError,
label::{Builder as SegmentBuilder, LabelImage},
seed::SeedGenerator,
segment::SegmentMetadata,
Segmentation,
},
Pixel,
},
math::{
neighbors::{kdtree::KdTreeSearch, NeighborSearch},
DistanceMetric,
FloatNumber,
},
};
/// K-means based image segmentation.
///
/// Instances are created through [`KmeansSegmentation::builder`]; `Builder::build`
/// validates the parameters before producing one.
#[derive(Debug, PartialEq)]
pub struct KmeansSegmentation<T: FloatNumber> {
    /// Number of seeds requested from `generator` (i.e. the initial number of centers).
    segments: usize,
    /// Upper bound on the number of assignment/update rounds.
    max_iter: usize,
    /// Convergence threshold: a round converges when no center moves farther than this,
    /// as measured by `metric`.
    tolerance: T,
    /// Strategy that picks the initial seed pixel indices.
    generator: SeedGenerator,
    /// Distance used both for nearest-center lookup and for the convergence test.
    metric: DistanceMetric,
}
impl<T: FloatNumber> KmeansSegmentation<T> {
    const DEFAULT_SEGMENTS: usize = 64;
    const DEFAULT_MAX_ITER: usize = 100;
    const DEFAULT_TOLERANCE: f64 = 1e-4;

    /// Returns a [`Builder`] pre-populated with the default parameters.
    #[must_use]
    pub fn builder() -> Builder<T> {
        Builder::default()
    }

    /// Performs a single k-means round and reports whether it converged.
    ///
    /// 1. Clears every segment in `builder`.
    /// 2. Assigns each pixel whose `mask` entry is `true` to its nearest center
    ///    (looked up through a k-d tree built over `centers` with `self.metric`).
    /// 3. Replaces each center with the center of its segment.
    ///
    /// Segments whose label does not index into `centers` are skipped.
    ///
    /// Returns `true` when no center moved by more than `self.tolerance`.
    ///
    /// # Panics
    ///
    /// Panics if `mask` is shorter than `pixels`, since `mask` is indexed by pixel index.
    #[must_use]
    fn iterate(
        &self,
        pixels: &[Pixel<T>],
        mask: &[bool],
        centers: &mut [Pixel<T>],
        builder: &mut SegmentBuilder<T>,
    ) -> bool {
        for segment in builder.iter_mut() {
            segment.clear();
        }

        // Assignment step. The last argument (16) is passed straight through to
        // `KdTreeSearch::build` — presumably a leaf size; confirm in that module.
        let search = KdTreeSearch::build(centers, self.metric, 16);
        pixels
            .iter()
            .enumerate()
            .filter(|&(index, _)| mask[index])
            .for_each(|(index, pixel)| {
                if let Some(nearest) = search.search_nearest(pixel) {
                    builder.get_mut(&nearest.index).insert(index, pixel);
                }
            });

        // Update step: move each center and remember whether any moved beyond tolerance.
        let mut moved = false;
        for segment in builder.iter() {
            if let Some(center) = centers.get_mut(segment.label()) {
                let updated = segment.center();
                moved |= self.metric.measure(center, updated) > self.tolerance;
                *center = *updated;
            }
        }
        !moved
    }
}
impl<T> Segmentation<T> for KmeansSegmentation<T>
where
    T: FloatNumber,
{
    type Err = KmeansError<T>;

    /// Segments a `width` x `height` image with k-means, considering only pixels whose
    /// `mask` entry is `true`.
    ///
    /// Initial centers are the pixels at the indices chosen by the configured seed
    /// generator. Assignment/update rounds then run until one converges or `max_iter`
    /// rounds have been performed.
    ///
    /// # Errors
    ///
    /// Returns `KmeansError::UnexpectedLength` when:
    /// - `pixels.len()` differs from `width * height`, including when that product
    ///   overflows `usize` (reported with a saturated `expected` value); or
    /// - `mask.len()` differs from `pixels.len()`.
    fn segment_with_mask(
        &self,
        width: usize,
        height: usize,
        pixels: &[Pixel<T>],
        mask: &[bool],
    ) -> Result<LabelImage<T>, Self::Err> {
        // `checked_mul` prevents a wrapped product (release builds) from spuriously
        // matching `pixels.len()`, and avoids the overflow panic in debug builds.
        if width.checked_mul(height) != Some(pixels.len()) {
            return Err(KmeansError::UnexpectedLength {
                actual: pixels.len(),
                expected: width.saturating_mul(height),
            });
        }
        // `iterate` indexes `mask` by pixel index, so a shorter mask would panic there.
        // A longer mask is rejected as well because it indicates a caller bug.
        // NOTE(review): this reuses `UnexpectedLength`; if its message names pixels
        // specifically, consider a dedicated mask-length variant.
        if mask.len() != pixels.len() {
            return Err(KmeansError::UnexpectedLength {
                actual: mask.len(),
                expected: pixels.len(),
            });
        }

        let mut centers: Vec<_> = self
            .generator
            .generate(width, height, pixels, mask, self.segments)
            .iter()
            .map(|&seed| pixels[seed])
            .collect();
        let mut builder = LabelImage::builder(width, height);
        for _ in 0..self.max_iter {
            if self.iterate(pixels, mask, &mut centers, &mut builder) {
                break;
            }
        }
        Ok(builder.build())
    }
}
/// Builder for [`KmeansSegmentation`].
///
/// Every setter consumes and returns the builder. Parameters are only validated when
/// [`Builder::build`] is called.
#[derive(Debug, PartialEq)]
pub struct Builder<T: FloatNumber> {
    /// Requested number of segments; must be non-zero.
    segments: usize,
    /// Maximum number of k-means rounds; must be non-zero.
    max_iter: usize,
    /// Convergence threshold; must be strictly positive and not NaN.
    tolerance: T,
    /// Seed generator for the initial centers.
    generator: SeedGenerator,
    /// Distance metric for assignment and convergence.
    metric: DistanceMetric,
}
impl<T: FloatNumber> Builder<T> {
    /// Sets the number of segments (initial seeds). Must be non-zero.
    #[must_use]
    pub fn segments(self, segments: usize) -> Self {
        Self { segments, ..self }
    }

    /// Sets the maximum number of k-means rounds. Must be non-zero.
    #[must_use]
    pub fn max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Sets the convergence tolerance. Must be strictly positive and not NaN.
    #[must_use]
    pub fn tolerance(self, tolerance: T) -> Self {
        Self { tolerance, ..self }
    }

    /// Sets the seed generator used to choose the initial centers.
    // NOTE(review): the lint allowance suggests this setter is currently only used by
    // tests; confirm before removing it.
    #[allow(dead_code)]
    #[must_use]
    pub fn generator(self, generator: SeedGenerator) -> Self {
        Self { generator, ..self }
    }

    /// Sets the distance metric.
    #[must_use]
    pub fn metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }

    /// Validates the configured parameters and builds a [`KmeansSegmentation`].
    ///
    /// # Errors
    ///
    /// Checks are performed in this order:
    /// - `KmeansError::InvalidSegments` if `segments` is zero;
    /// - `KmeansError::InvalidIterations` if `max_iter` is zero;
    /// - `KmeansError::InvalidTolerance` if `tolerance` is NaN or not greater than zero.
    pub fn build(self) -> Result<KmeansSegmentation<T>, KmeansError<T>> {
        let Self {
            segments,
            max_iter,
            tolerance,
            generator,
            metric,
        } = self;

        if segments == 0 {
            return Err(KmeansError::InvalidSegments);
        }
        if max_iter == 0 {
            return Err(KmeansError::InvalidIterations);
        }
        if tolerance.is_nan() || tolerance <= T::zero() {
            return Err(KmeansError::InvalidTolerance(tolerance));
        }

        Ok(KmeansSegmentation {
            segments,
            max_iter,
            tolerance,
            generator,
            metric,
        })
    }
}
impl<T: FloatNumber> Default for Builder<T> {
    /// Creates a builder populated from the `DEFAULT_*` constants of
    /// [`KmeansSegmentation`], plus the default seed generator and distance metric.
    fn default() -> Self {
        let segments = KmeansSegmentation::<T>::DEFAULT_SEGMENTS;
        let max_iter = KmeansSegmentation::<T>::DEFAULT_MAX_ITER;
        let tolerance = T::from_f64(KmeansSegmentation::<T>::DEFAULT_TOLERANCE);
        Self {
            segments,
            max_iter,
            tolerance,
            generator: Default::default(),
            metric: Default::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::ImageData;

    #[test]
    fn test_builder() {
        // `builder()` must be equivalent to a default-constructed `Builder`.
        let expected = Builder::<f64>::default();
        assert_eq!(KmeansSegmentation::<f64>::builder(), expected);
    }

    #[test]
    fn test_builder_build() {
        let expected = KmeansSegmentation {
            segments: 10,
            max_iter: 100,
            tolerance: 1e-4,
            generator: SeedGenerator::RegularGrid,
            metric: DistanceMetric::SquaredEuclidean,
        };
        let segmentation = KmeansSegmentation::<f64>::builder()
            .segments(10)
            .max_iter(100)
            .tolerance(1e-4)
            .generator(SeedGenerator::RegularGrid)
            .metric(DistanceMetric::SquaredEuclidean)
            .build()
            .expect("valid parameters must build successfully");
        assert_eq!(segmentation, expected);
    }

    #[rstest]
    #[case(0, 25, 1e-4, KmeansError::InvalidSegments)]
    #[case(48, 0, 1e-4, KmeansError::InvalidIterations)]
    #[case(48, 25, -1e-4, KmeansError::InvalidTolerance(-1e-4))]
    fn test_builder_build_invalid_parameters(
        #[case] segments: usize,
        #[case] max_iter: usize,
        #[case] tolerance: f64,
        #[case] expected: KmeansError<f64>,
    ) {
        let error = KmeansSegmentation::builder()
            .segments(segments)
            .max_iter(max_iter)
            .tolerance(tolerance)
            .build()
            .expect_err("invalid parameters must be rejected");
        assert_eq!(error, expected);
    }

    #[test]
    fn test_builder_build_invalid_tolerance_nan() {
        let error = KmeansSegmentation::<f64>::builder()
            .tolerance(f64::NAN)
            .build()
            .expect_err("a NaN tolerance must be rejected");
        assert_eq!(
            error.to_string(),
            "Tolerance must be greater than zero and not NaN: NaN"
        );
    }

    #[test]
    #[cfg(feature = "image")]
    fn test_segment() {
        let image_data = ImageData::load("../../gfx/flags/za.png").unwrap();
        let (width, height) = (image_data.width() as usize, image_data.height() as usize);
        let pixels: Vec<_> = image_data.pixels().collect();

        let segmentation = KmeansSegmentation::builder()
            .segments(24)
            .max_iter(5)
            .tolerance(1e-4)
            .build()
            .unwrap();
        let label_image = segmentation
            .segment(width, height, &pixels)
            .expect("segmenting a valid image must succeed");
        assert_eq!(label_image.segments().count(), 24);
    }
}