use image::RgbImage;
use kmeans_colors::{get_kmeans_hamerly, Kmeans};
use palette::cast::from_component_slice;
use palette::{FromColor, IntoColor, Lab, Srgb};
/// Position of a cluster inside the image.
///
/// As produced by `extract_clusters`, both coordinates are normalized to
/// `[0, 1]`: `0.0` is the first column/row and `1.0` the last.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Centroid {
    /// Normalized horizontal position (left = 0.0).
    pub x: f32,
    /// Normalized vertical position (top = 0.0).
    pub y: f32,
}

/// One color cluster found in an image.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cluster {
    /// Representative color as 8-bit sRGB `[r, g, b]`.
    pub color: [u8; 3],
    /// Mean normalized position of the pixels in this cluster.
    pub centroid: Centroid,
    /// Fraction of the image's pixels assigned to this cluster, in `[0, 1]`.
    pub weight: f32,
}
/// Errors returned by `extract_clusters`.
///
/// The enum is fieldless, so it is cheap to copy and can be compared
/// directly (e.g. `assert_eq!(err, ClusterError::KZero)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClusterError {
    /// The input image has zero width or zero height.
    EmptyImage,
    /// `k` was 0; at least one cluster must be requested.
    KZero,
}

impl std::fmt::Display for ClusterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Either dimension being 0 triggers this error, so the message
            // must not claim the image is specifically 0x0.
            Self::EmptyImage => f.write_str("input image is empty (width or height is 0)"),
            Self::KZero => f.write_str("k must be at least 1"),
        }
    }
}

impl std::error::Error for ClusterError {}
/// Number of independent k-means restarts. `extract_clusters` keeps the run
/// that reports the lowest `score`.
const KMEANS_RUNS: usize = 3;
/// Iteration cap handed to `get_kmeans_hamerly` for every run.
const KMEANS_MAX_ITER: usize = 20;
/// Convergence threshold handed to `get_kmeans_hamerly`. Presumably it is
/// measured in Lab distance units, since clustering runs on Lab colors;
/// confirm against the kmeans_colors docs before tuning.
const KMEANS_CONVERGE: f32 = 5.0;
/// Base RNG seed. Run `i` uses `KMEANS_SEED + i`, so every call uses the
/// same sequence of seeds.
const KMEANS_SEED: u64 = 0;
/// Extracts up to `k` dominant color clusters from `img`.
///
/// Every pixel is converted from 8-bit sRGB to CIE L\*a\*b\*, and the Lab
/// values are clustered with `kmeans_colors::get_kmeans_hamerly`. Clustering
/// is repeated `KMEANS_RUNS` times with seeds `KMEANS_SEED`,
/// `KMEANS_SEED + 1`, …, and the run reporting the lowest `score` is kept.
///
/// Each returned [`Cluster`] carries:
/// - `color`: the cluster's Lab centroid converted back to 8-bit sRGB;
/// - `centroid`: the mean position of its pixels, normalized so that the first
///   column/row maps to `0.0` and the last to `1.0` (an axis of length 1 maps
///   to `0.0`);
/// - `weight`: the fraction of the image's pixels assigned to it.
///
/// Clusters that received no pixels are dropped, so the result may hold fewer
/// than `k` entries. `k` is also capped at the pixel count and at 256, which
/// is the number of distinct cluster ids that `kmeans_colors` can store per
/// pixel. The result is sorted by descending `weight`.
///
/// # Errors
///
/// - [`ClusterError::KZero`] if `k == 0` (checked first).
/// - [`ClusterError::EmptyImage`] if the image width or height is 0.
pub fn extract_clusters(img: &RgbImage, k: usize) -> Result<Vec<Cluster>, ClusterError> {
    // `Kmeans::indices` stores one `u8` cluster id per pixel. Requesting more
    // than 256 clusters would make ids wrap and alias distinct clusters.
    const MAX_CLUSTERS: usize = u8::MAX as usize + 1;

    if k == 0 {
        return Err(ClusterError::KZero);
    }
    let (width, height) = img.dimensions();
    if width == 0 || height == 0 {
        return Err(ClusterError::EmptyImage);
    }

    // Cluster in Lab rather than sRGB: Euclidean distance in Lab tracks
    // perceived color difference much more closely.
    let raw: &[u8] = img.as_raw();
    let lab: Vec<Lab> = from_component_slice::<Srgb<u8>>(raw)
        .iter()
        .map(|x| x.into_linear::<f32>().into_color())
        .collect();

    let effective_k = k.min(lab.len()).min(MAX_CLUSTERS);
    let mut best = Kmeans::new();
    for run_idx in 0..KMEANS_RUNS {
        let run = get_kmeans_hamerly(
            effective_k,
            KMEANS_MAX_ITER,
            KMEANS_CONVERGE,
            false,
            &lab,
            KMEANS_SEED + run_idx as u64,
        );
        if run.score < best.score {
            best = run;
        }
    }

    // Accumulate per-cluster pixel counts and coordinate sums.
    let cluster_count = best.centroids.len();
    let mut counts = vec![0u64; cluster_count];
    let mut sum_x = vec![0f64; cluster_count];
    let mut sum_y = vec![0f64; cluster_count];
    let total_pixels = best.indices.len();
    // `indices` follows the pixel buffer's row-major order. Walking it row by
    // row yields (x, y) as `usize` directly, which avoids the truncation a
    // flat `i as u32` would hit on images with more than `u32::MAX` pixels.
    // `width > 0` is guaranteed above, so `chunks` cannot panic.
    for (y, row) in best.indices.chunks(width as usize).enumerate() {
        for (x, &idx) in row.iter().enumerate() {
            let idx = idx as usize;
            // Defensive: an id outside the centroid list would index out of bounds.
            if idx >= cluster_count {
                continue;
            }
            counts[idx] += 1;
            sum_x[idx] += x as f64;
            sum_y[idx] += y as f64;
        }
    }

    // Map pixel coordinates to [0, 1]. The `max(1)` keeps 1-pixel axes from
    // dividing by zero; they map to 0.0.
    let denom_x = (width.saturating_sub(1)).max(1) as f64;
    let denom_y = (height.saturating_sub(1)).max(1) as f64;
    let mut clusters: Vec<Cluster> = best
        .centroids
        .iter()
        .enumerate()
        .filter_map(|(idx, lab_centroid)| {
            let count = counts[idx];
            // Empty clusters can occur, e.g. when k exceeds the number of
            // distinct colors; they carry no position or weight.
            if count == 0 {
                return None;
            }
            let mean_x = sum_x[idx] / count as f64;
            let mean_y = sum_y[idx] / count as f64;
            let cx = (mean_x / denom_x).clamp(0.0, 1.0) as f32;
            let cy = (mean_y / denom_y).clamp(0.0, 1.0) as f32;
            let srgb: Srgb = Srgb::from_color(*lab_centroid);
            let rgb_u8: Srgb<u8> = srgb.into_format();
            Some(Cluster {
                color: [rgb_u8.red, rgb_u8.green, rgb_u8.blue],
                centroid: Centroid { x: cx, y: cy },
                weight: (count as f64 / total_pixels as f64) as f32,
            })
        })
        .collect();
    clusters.sort_by(|a, b| b.weight.total_cmp(&a.weight));
    Ok(clusters)
}
pub fn derive_background_rgba(clusters: &[Cluster]) -> [u8; 4] {
clusters
.iter()
.max_by(|a, b| a.weight.total_cmp(&b.weight))
.map(|c| {
let [r, g, b] = c.color;
[r, g, b, 255]
})
.unwrap_or([0, 0, 0, 255])
}
/// Returns a copy of `clusters` with its heaviest cluster removed.
///
/// The heaviest cluster is the one with the largest `weight` under
/// [`f32::total_cmp`]; among equal weights, the last one is removed. This is
/// the same selection `derive_background_rgba` makes. All other clusters keep
/// their relative order. Empty input yields an empty vector.
pub fn drop_dominant(clusters: &[Cluster]) -> Vec<Cluster> {
    let heaviest = clusters
        .iter()
        .enumerate()
        .max_by(|(_, lhs), (_, rhs)| lhs.weight.total_cmp(&rhs.weight))
        .map(|(pos, _)| pos);
    match heaviest {
        None => Vec::new(),
        Some(skip) => {
            // Copy everything before and after the removed position.
            let (before, after) = (&clusters[..skip], &clusters[skip + 1..]);
            let mut remaining = Vec::with_capacity(before.len() + after.len());
            remaining.extend_from_slice(before);
            remaining.extend_from_slice(after);
            remaining
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Asserts that `a` lies strictly within `eps` of the expected value `b`.
    fn approx(a: f32, b: f32, eps: f32, label: &str) {
        assert!(
            (a - b).abs() < eps,
            "{}: expected ~{}, got {} (eps={})",
            label,
            b,
            a,
            eps
        );
    }

    /// Builds a centered cluster. The helper-function tests below only care
    /// about `color` and `weight`.
    fn cluster(color: [u8; 3], weight: f32) -> Cluster {
        Cluster {
            color,
            centroid: Centroid { x: 0.5, y: 0.5 },
            weight,
        }
    }

    /// Checks the invariants every `extract_clusters` result must satisfy:
    /// weights sum to ~1 and appear in descending order.
    fn assert_well_formed(clusters: &[Cluster]) {
        let total: f32 = clusters.iter().map(|c| c.weight).sum();
        approx(total, 1.0, 1e-4, "sum of weights");
        assert!(
            clusters.windows(2).all(|w| w[0].weight >= w[1].weight),
            "clusters not sorted by descending weight: {:?}",
            clusters
        );
    }

    #[test]
    fn single_color_k1() {
        let img: RgbImage = ImageBuffer::from_fn(256, 256, |_, _| Rgb([255u8, 0, 0]));
        let clusters = extract_clusters(&img, 1).expect("clusters");
        assert_eq!(clusters.len(), 1);
        assert_well_formed(&clusters);
        let c = clusters[0];
        approx(c.color[0] as f32, 255.0, 2.0, "color.r");
        approx(c.color[1] as f32, 0.0, 2.0, "color.g");
        approx(c.color[2] as f32, 0.0, 2.0, "color.b");
        approx(c.centroid.x, 0.5, 0.05, "centroid.x");
        approx(c.centroid.y, 0.5, 0.05, "centroid.y");
        approx(c.weight, 1.0, 1e-4, "weight");
    }

    #[test]
    fn top_red_bottom_blue_k2() {
        let img: RgbImage = ImageBuffer::from_fn(100, 100, |_, y| {
            if y < 50 {
                Rgb([255u8, 0, 0])
            } else {
                Rgb([0u8, 0, 255])
            }
        });
        let clusters = extract_clusters(&img, 2).expect("clusters");
        assert_eq!(clusters.len(), 2);
        assert_well_formed(&clusters);
        let red = clusters
            .iter()
            .find(|c| c.color[0] > c.color[2])
            .expect("red cluster");
        let blue = clusters
            .iter()
            .find(|c| c.color[2] > c.color[0])
            .expect("blue cluster");
        approx(red.color[0] as f32, 255.0, 2.0, "red.color.r");
        approx(red.color[2] as f32, 0.0, 2.0, "red.color.b");
        approx(blue.color[0] as f32, 0.0, 2.0, "blue.color.r");
        approx(blue.color[2] as f32, 255.0, 2.0, "blue.color.b");
        approx(red.centroid.x, 0.5, 0.05, "red.centroid.x");
        approx(blue.centroid.x, 0.5, 0.05, "blue.centroid.x");
        approx(red.centroid.y, 0.247, 0.05, "red.centroid.y");
        approx(blue.centroid.y, 0.752, 0.05, "blue.centroid.y");
        approx(red.weight, 0.5, 0.02, "red.weight");
        approx(blue.weight, 0.5, 0.02, "blue.weight");
    }

    #[test]
    fn single_pixel_image_k1() {
        let img: RgbImage = ImageBuffer::from_fn(1, 1, |_, _| Rgb([128u8, 64, 200]));
        let clusters = extract_clusters(&img, 1).expect("clusters");
        assert_eq!(clusters.len(), 1);
        assert_well_formed(&clusters);
        let c = clusters[0];
        approx(c.centroid.x, 0.0, 1e-6, "centroid.x");
        approx(c.centroid.y, 0.0, 1e-6, "centroid.y");
        approx(c.weight, 1.0, 1e-4, "weight");
        approx(c.color[0] as f32, 128.0, 2.0, "color.r");
        approx(c.color[1] as f32, 64.0, 2.0, "color.g");
        approx(c.color[2] as f32, 200.0, 2.0, "color.b");
    }

    #[test]
    fn k_larger_than_pixel_count_is_capped() {
        // A single pixel can form at most one cluster, whatever k is.
        let img: RgbImage = ImageBuffer::from_fn(1, 1, |_, _| Rgb([128u8, 64, 200]));
        let clusters = extract_clusters(&img, 5).expect("clusters");
        assert_eq!(clusters.len(), 1);
        assert_well_formed(&clusters);
    }

    #[test]
    fn empty_image_returns_error() {
        let img: RgbImage = ImageBuffer::new(0, 0);
        match extract_clusters(&img, 3) {
            Err(ClusterError::EmptyImage) => {}
            other => panic!("expected EmptyImage, got {:?}", other),
        }
    }

    #[test]
    fn zero_height_image_returns_error() {
        // Only one dimension is zero; this must still be rejected.
        let img: RgbImage = ImageBuffer::new(5, 0);
        match extract_clusters(&img, 3) {
            Err(ClusterError::EmptyImage) => {}
            other => panic!("expected EmptyImage, got {:?}", other),
        }
    }

    #[test]
    fn k_zero_returns_error() {
        let img: RgbImage = ImageBuffer::from_fn(8, 8, |_, _| Rgb([10u8, 20, 30]));
        match extract_clusters(&img, 0) {
            Err(ClusterError::KZero) => {}
            other => panic!("expected KZero, got {:?}", other),
        }
    }

    #[test]
    fn derive_background_picks_max_weight_cluster() {
        let clusters = vec![
            cluster([10, 20, 30], 0.2),
            cluster([200, 100, 50], 0.7),
            cluster([0, 0, 0], 0.1),
        ];
        assert_eq!(derive_background_rgba(&clusters), [200, 100, 50, 255]);
    }

    #[test]
    fn derive_background_unsorted_input() {
        let clusters = vec![cluster([200, 100, 50], 0.7), cluster([10, 20, 30], 0.2)];
        assert_eq!(derive_background_rgba(&clusters), [200, 100, 50, 255]);
    }

    #[test]
    fn derive_background_unsorted_three_clusters() {
        let clusters = vec![
            cluster([10, 10, 10], 0.2),
            cluster([20, 20, 20], 0.5),
            cluster([30, 30, 30], 0.3),
        ];
        assert_eq!(derive_background_rgba(&clusters), [20, 20, 20, 255]);
    }

    #[test]
    fn derive_background_empty_clusters_yields_black() {
        let clusters: Vec<Cluster> = vec![];
        assert_eq!(derive_background_rgba(&clusters), [0, 0, 0, 255]);
    }

    #[test]
    fn drop_dominant_removes_max_weight() {
        let clusters = vec![
            cluster([10, 20, 30], 0.2),
            cluster([200, 100, 50], 0.7),
            cluster([0, 0, 0], 0.1),
        ];
        let rest = drop_dominant(&clusters);
        assert_eq!(rest.len(), 2);
        assert!(rest.iter().all(|c| c.color != [200, 100, 50]));
        assert_eq!(rest[0].color, [10, 20, 30]);
        assert_eq!(rest[1].color, [0, 0, 0]);
    }

    #[test]
    fn drop_dominant_on_sorted_input_returns_tail() {
        let clusters = vec![
            cluster([200, 100, 50], 0.6),
            cluster([10, 20, 30], 0.3),
            cluster([0, 0, 0], 0.1),
        ];
        let rest = drop_dominant(&clusters);
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[0].color, [10, 20, 30]);
        assert_eq!(rest[1].color, [0, 0, 0]);
    }

    #[test]
    fn drop_dominant_empty_input() {
        let clusters: Vec<Cluster> = vec![];
        assert!(drop_dominant(&clusters).is_empty());
    }

    #[test]
    fn drop_dominant_single_cluster_yields_empty() {
        let clusters = vec![cluster([200, 100, 50], 1.0)];
        assert!(drop_dominant(&clusters).is_empty());
    }

    #[test]
    fn background_and_drop_dominant_agree_on_ties() {
        // Whichever equal-weight cluster becomes the background must be the
        // one drop_dominant removes.
        let clusters = vec![cluster([1, 1, 1], 0.5), cluster([2, 2, 2], 0.5)];
        let [r, g, b, _] = derive_background_rgba(&clusters);
        let rest = drop_dominant(&clusters);
        assert_eq!(rest.len(), 1);
        assert_ne!(rest[0].color, [r, g, b]);
    }
}