#![allow(dead_code)]
/// An RGB color with `f64` channels, used as a k-means centroid.
///
/// Channel values are stored as given; they are only clamped to `[0, 255]`
/// when converted with [`ColorCluster::to_rgb_u8`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ColorCluster {
    /// Red channel.
    pub r: f64,
    /// Green channel.
    pub g: f64,
    /// Blue channel.
    pub b: f64,
}

impl ColorCluster {
    /// Builds a color from its red, green and blue channel values.
    #[must_use]
    pub fn new(r: f64, g: f64, b: f64) -> Self {
        Self { r, g, b }
    }

    /// Squared Euclidean distance to `other`.
    ///
    /// Skips the square root, so it is cheaper than [`ColorCluster::distance`]
    /// when only relative ordering matters.
    #[must_use]
    pub fn distance_sq(&self, other: &Self) -> f64 {
        let (red, green, blue) = (self.r - other.r, self.g - other.g, self.b - other.b);
        red * red + green * green + blue * blue
    }

    /// Euclidean distance to `other`.
    #[must_use]
    pub fn distance(&self, other: &Self) -> f64 {
        f64::sqrt(self.distance_sq(other))
    }

    /// Converts to 8-bit channels.
    ///
    /// Each channel is clamped to `[0, 255]`, then rounded half away from zero.
    // The casts cannot truncate or lose sign because the values are clamped first.
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn to_rgb_u8(&self) -> [u8; 3] {
        [self.r, self.g, self.b].map(|channel| channel.clamp(0.0, 255.0).round() as u8)
    }

    /// Weighted channel sum `0.299·R + 0.587·G + 0.114·B`, using the ITU-R BT.601 luma weights.
    #[must_use]
    pub fn luminance(&self) -> f64 {
        const WEIGHT_R: f64 = 0.299;
        const WEIGHT_G: f64 = 0.587;
        const WEIGHT_B: f64 = 0.114;
        WEIGHT_R * self.r + WEIGHT_G * self.g + WEIGHT_B * self.b
    }
}
/// Outcome of a clustering run, as produced by [`KMeansColorCluster::k_means`].
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// Cluster centers, in cluster-index order.
    pub centroids: Vec<ColorCluster>,
    /// `counts[i]` is the number of pixels assigned to `centroids[i]`.
    pub counts: Vec<usize>,
    /// Sum of squared distances from each pixel to the centroid it was assigned to.
    pub inertia: f64,
    /// Number of assign/update rounds that were executed.
    pub iterations: usize,
}

impl ClusterResult {
    /// Returns the centroid with the largest pixel count, or `None` when there are no clusters.
    ///
    /// On ties the later cluster wins, following [`Iterator::max_by_key`] semantics.
    /// Centroids and counts are paired by position. Because both fields are public,
    /// their lengths may differ; any unpaired tail is ignored rather than causing an
    /// out-of-bounds panic.
    #[must_use]
    pub fn dominant_color(&self) -> Option<ColorCluster> {
        self.centroids
            .iter()
            .zip(&self.counts)
            .max_by_key(|&(_, &count)| count)
            .map(|(&centroid, _)| centroid)
    }

    /// Returns `(centroid, count)` pairs ordered from most to least populated.
    ///
    /// The sort is stable, so clusters with equal counts keep their original order.
    #[must_use]
    pub fn sorted_by_count(&self) -> Vec<(ColorCluster, usize)> {
        let mut pairs: Vec<(ColorCluster, usize)> = self
            .centroids
            .iter()
            .copied()
            .zip(self.counts.iter().copied())
            .collect();
        pairs.sort_by(|a, b| b.1.cmp(&a.1));
        pairs
    }
}
/// Configuration for k-means color clustering over packed RGB pixel buffers.
#[derive(Debug)]
pub struct KMeansColorCluster {
    /// Requested number of clusters; capped at the number of pixels during clustering.
    pub k: usize,
    /// Maximum number of assign/update rounds.
    pub max_iter: usize,
    /// Clustering stops once the largest centroid movement in a round is below this value.
    pub tolerance: f64,
}

impl KMeansColorCluster {
    /// Creates a clusterer, raising `k` and `max_iter` to at least 1.
    #[must_use]
    pub fn new(k: usize, max_iter: usize, tolerance: f64) -> Self {
        Self {
            k: k.max(1),
            max_iter: max_iter.max(1),
            tolerance,
        }
    }

    /// Converts one 3-byte `[r, g, b]` chunk into a floating-point color.
    fn pixel_color(px: &[u8]) -> ColorCluster {
        ColorCluster::new(f64::from(px[0]), f64::from(px[1]), f64::from(px[2]))
    }

    /// Returns the index of the centroid closest to `px`.
    ///
    /// On ties the lowest index wins; returns 0 when `centroids` is empty.
    fn nearest_centroid(px: &ColorCluster, centroids: &[ColorCluster]) -> usize {
        centroids
            .iter()
            .map(|c| px.distance_sq(c))
            .enumerate()
            // `min_by` keeps the first of several equal minima.
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map_or(0, |(i, _)| i)
    }

    /// Clusters packed RGB pixels with Lloyd's k-means algorithm.
    ///
    /// `pixels` is a flat `[r, g, b, r, g, b, ...]` byte buffer. Returns `None` when it
    /// is empty or its length is not a multiple of 3.
    ///
    /// Cluster count and stopping:
    /// - The number of clusters is `self.k` clamped to `1..=pixel_count`.
    /// - Centroids are seeded from pixels at evenly spaced indices.
    /// - At least one round always runs, and at most `max(self.max_iter, 1)` rounds run.
    /// - Iteration also stops early once no centroid moved by `tolerance` or more.
    ///
    /// Each round:
    /// - assigns every pixel to its nearest centroid (lowest index on ties);
    /// - moves every non-empty cluster's centroid to the mean of its members;
    /// - leaves empty clusters at their previous centroid.
    ///
    /// The reported `counts` and `inertia` use the pixel assignments from the last
    /// round, measured against the final centroids.
    // Pixel counts far below 2^53 convert to f64 exactly.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn k_means(&self, pixels: &[u8]) -> Option<ClusterResult> {
        if pixels.is_empty() || pixels.len() % 3 != 0 {
            return None;
        }
        let n = pixels.len() / 3;
        // The fields are public, so `k` / `max_iter` may be 0 despite `new`'s clamping.
        // With k == 0 the per-cluster vectors below would be empty and indexing would panic.
        let k = self.k.clamp(1, n);
        let max_iter = self.max_iter.max(1);

        // Seed with k pixels spread evenly across the buffer; i * n / k < n for i < k.
        let mut centroids: Vec<ColorCluster> = (0..k)
            .map(|i| {
                let off = (i * n / k) * 3;
                Self::pixel_color(&pixels[off..off + 3])
            })
            .collect();

        let mut assignments = vec![0usize; n];
        let mut iterations = 0;
        for _ in 0..max_iter {
            iterations += 1;

            // Assignment step: attach every pixel to its nearest centroid.
            for (slot, chunk) in assignments.iter_mut().zip(pixels.chunks_exact(3)) {
                *slot = Self::nearest_centroid(&Self::pixel_color(chunk), &centroids);
            }

            // Update step: accumulate per-cluster channel sums and member counts.
            let mut sums = vec![ColorCluster::new(0.0, 0.0, 0.0); k];
            let mut counts = vec![0usize; k];
            for (&ci, chunk) in assignments.iter().zip(pixels.chunks_exact(3)) {
                let px = Self::pixel_color(chunk);
                sums[ci].r += px.r;
                sums[ci].g += px.g;
                sums[ci].b += px.b;
                counts[ci] += 1;
            }

            let mut max_shift = 0.0f64;
            for ((centroid, sum), &count) in centroids.iter_mut().zip(&sums).zip(&counts) {
                // An empty cluster has no mean; keep its previous centroid.
                if count == 0 {
                    continue;
                }
                let members = count as f64;
                let new_c = ColorCluster::new(sum.r / members, sum.g / members, sum.b / members);
                max_shift = max_shift.max(centroid.distance(&new_c));
                *centroid = new_c;
            }
            if max_shift < self.tolerance {
                break;
            }
        }

        let mut inertia = 0.0f64;
        let mut counts = vec![0usize; k];
        for (&ci, chunk) in assignments.iter().zip(pixels.chunks_exact(3)) {
            inertia += Self::pixel_color(chunk).distance_sq(&centroids[ci]);
            counts[ci] += 1;
        }
        Some(ClusterResult {
            centroids,
            counts,
            inertia,
            iterations,
        })
    }
}
/// Unit tests for color distance/conversion helpers, `ClusterResult` queries and k-means clustering.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_color_cluster_new() {
        let c = ColorCluster::new(100.0, 150.0, 200.0);
        assert!((c.r - 100.0).abs() < f64::EPSILON);
        assert!((c.g - 150.0).abs() < f64::EPSILON);
        assert!((c.b - 200.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_distance_sq_zero() {
        let c = ColorCluster::new(1.0, 2.0, 3.0);
        assert!((c.distance_sq(&c)).abs() < f64::EPSILON);
    }

    #[test]
    fn test_distance_sq_known() {
        let a = ColorCluster::new(0.0, 0.0, 0.0);
        let b = ColorCluster::new(3.0, 4.0, 0.0);
        assert!((a.distance_sq(&b) - 25.0).abs() < 1e-9);
    }

    #[test]
    fn test_distance_known() {
        let a = ColorCluster::new(0.0, 0.0, 0.0);
        let b = ColorCluster::new(3.0, 4.0, 0.0);
        assert!((a.distance(&b) - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_to_rgb_u8() {
        let c = ColorCluster::new(128.4, 0.0, 255.9);
        let rgb = c.to_rgb_u8();
        assert_eq!(rgb, [128, 0, 255]);
    }

    #[test]
    fn test_to_rgb_u8_clamp() {
        // 127.5 rounds half away from zero, to 128.
        let c = ColorCluster::new(-10.0, 300.0, 127.5);
        let rgb = c.to_rgb_u8();
        assert_eq!(rgb, [0, 255, 128]);
    }

    #[test]
    fn test_luminance() {
        let white = ColorCluster::new(255.0, 255.0, 255.0);
        assert!((white.luminance() - 255.0).abs() < 1e-9);
        let black = ColorCluster::new(0.0, 0.0, 0.0);
        assert!((black.luminance()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_new_clamps_parameters() {
        let km = KMeansColorCluster::new(0, 0, 0.5);
        assert_eq!(km.k, 1);
        assert_eq!(km.max_iter, 1);
        assert!((km.tolerance - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_kmeans_empty() {
        let km = KMeansColorCluster::new(3, 10, 0.1);
        assert!(km.k_means(&[]).is_none());
    }

    #[test]
    fn test_kmeans_bad_len() {
        let km = KMeansColorCluster::new(2, 10, 0.1);
        assert!(km.k_means(&[1, 2]).is_none());
    }

    #[test]
    fn test_kmeans_single_pixel() {
        let km = KMeansColorCluster::new(1, 10, 0.1);
        let result = km
            .k_means(&[100, 150, 200])
            .expect("k_means should succeed");
        assert_eq!(result.centroids.len(), 1);
        assert!((result.centroids[0].r - 100.0).abs() < 1e-9);
        assert_eq!(result.counts[0], 1);
    }

    #[test]
    fn test_kmeans_two_clusters() {
        let mut pixels = Vec::new();
        for _ in 0..4 {
            pixels.extend_from_slice(&[255, 0, 0]);
        }
        for _ in 0..4 {
            pixels.extend_from_slice(&[0, 0, 255]);
        }
        let km = KMeansColorCluster::new(2, 50, 0.01);
        let result = km.k_means(&pixels).expect("k_means should succeed");
        assert_eq!(result.centroids.len(), 2);
        assert!(result.counts.iter().all(|&c| c == 4));
    }

    #[test]
    fn test_kmeans_counts_cover_all_pixels() {
        // 30 bytes = 10 pixels; every pixel must be assigned to exactly one cluster.
        let pixels: Vec<u8> = (0..30u8).collect();
        let km = KMeansColorCluster::new(3, 20, 0.01);
        let result = km.k_means(&pixels).expect("k_means should succeed");
        assert_eq!(result.counts.iter().sum::<usize>(), 10);
        assert_eq!(result.centroids.len(), result.counts.len());
    }

    #[test]
    fn test_dominant_color() {
        let mut pixels = Vec::new();
        for _ in 0..6 {
            pixels.extend_from_slice(&[255, 0, 0]);
        }
        for _ in 0..2 {
            pixels.extend_from_slice(&[0, 255, 0]);
        }
        let km = KMeansColorCluster::new(2, 50, 0.01);
        let result = km.k_means(&pixels).expect("k_means should succeed");
        let dom = result
            .dominant_color()
            .expect("dominant_color should succeed");
        assert!(dom.r > dom.g);
    }

    #[test]
    fn test_sorted_by_count() {
        let result = ClusterResult {
            centroids: vec![
                ColorCluster::new(0.0, 0.0, 0.0),
                ColorCluster::new(255.0, 255.0, 255.0),
            ],
            counts: vec![10, 50],
            inertia: 0.0,
            iterations: 1,
        };
        let sorted = result.sorted_by_count();
        assert_eq!(sorted[0].1, 50);
        assert_eq!(sorted[1].1, 10);
    }

    #[test]
    fn test_kmeans_k_exceeds_n() {
        let km = KMeansColorCluster::new(5, 10, 0.1);
        let result = km
            .k_means(&[10, 20, 30, 40, 50, 60])
            .expect("k_means should succeed");
        assert_eq!(result.centroids.len(), 2);
    }

    #[test]
    fn test_kmeans_converges() {
        let mut pixels = Vec::new();
        for _ in 0..20 {
            pixels.extend_from_slice(&[100, 100, 100]);
        }
        let km = KMeansColorCluster::new(1, 100, 0.001);
        let result = km.k_means(&pixels).expect("k_means should succeed");
        assert!(result.iterations <= 2);
        assert!(result.inertia < 1e-6);
    }

    #[test]
    fn test_cluster_result_dominant_empty() {
        let result = ClusterResult {
            centroids: Vec::new(),
            counts: Vec::new(),
            inertia: 0.0,
            iterations: 0,
        };
        assert!(result.dominant_color().is_none());
    }
}