use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::color::{mean_color, Color, ColorPalette};
use crate::error::{DominantColorError, Result};
use crate::Config;
/// Extracts a palette of dominant colors from `pixels` with k-means clustering.
///
/// The number of clusters is `config.max_colors`, clamped to the number of pixels.
/// Centroids are seeded with k-means++ (RNG seeded from `config.kmeans_seed`) and refined
/// for at most `config.kmeans_max_iterations` rounds. Refinement stops early once no
/// centroid moves by `config.kmeans_convergence_threshold` or more (Euclidean distance
/// derived from `Color::sq_distance_rgb`).
///
/// Each non-empty cluster is turned into a palette entry via `mean_color`. Its percentage
/// is the cluster's share of all pixels. Entries are produced in cluster-index order.
///
/// # Errors
///
/// * `DominantColorError::EmptyImage` if `pixels` is empty or `config.max_colors` is 0.
/// * An internal error if no cluster yields a color (`mean_color` returned `None` for all).
pub fn extract(pixels: &[[u8; 3]], config: &Config) -> Result<ColorPalette> {
    // Never ask for more clusters than there are pixels.
    let k = config.max_colors.min(pixels.len());
    if k == 0 {
        return Err(DominantColorError::EmptyImage);
    }

    let mut rng = SmallRng::seed_from_u64(config.kmeans_seed);
    let mut centroids = kmeans_plus_plus_init(pixels, k, &mut rng);

    for _ in 0..config.kmeans_max_iterations {
        let assignments = assign_pixels(pixels, &centroids);
        let new_centroids = update_centroids(pixels, &assignments, k, &mut rng);
        // Largest distance any single centroid moved during this iteration.
        let max_shift = centroids
            .iter()
            .zip(&new_centroids)
            .map(|(old, new)| Color::sq_distance_rgb(old, new).sqrt())
            .fold(0.0_f64, f64::max);
        centroids = new_centroids;
        if max_shift < config.kmeans_convergence_threshold {
            break;
        }
    }

    // Final assignment against the settled centroids. Bucket pixels by cluster in a single
    // pass instead of re-scanning the whole image once per cluster.
    let assignments = assign_pixels(pixels, &centroids);
    let mut clusters: Vec<Vec<[u8; 3]>> = vec![Vec::new(); k];
    for (&pixel, &cluster) in pixels.iter().zip(&assignments) {
        clusters[cluster].push(pixel);
    }

    let total = pixels.len() as f32;
    let palette: ColorPalette = clusters
        .iter()
        .filter(|members| !members.is_empty())
        .filter_map(|members| mean_color(members, members.len() as f32 / total))
        .collect();

    if palette.is_empty() {
        return Err(DominantColorError::internal("收敛后所有簇均为空"));
    }
    Ok(palette)
}
/// Picks `k` initial centroids from `pixels` using k-means++ seeding.
///
/// The first centroid is a uniformly random pixel. Each further centroid is drawn with
/// probability proportional to the pixel's squared distance (`Color::sq_distance_rgb`) to
/// its nearest already-chosen centroid. If every pixel coincides with a chosen centroid
/// (all distances are zero), a uniformly random pixel is taken instead. Duplicate centroids
/// therefore only appear in that degenerate case.
///
/// Returns an empty vector when `k == 0` or `pixels` is empty.
fn kmeans_plus_plus_init(pixels: &[[u8; 3]], k: usize, rng: &mut SmallRng) -> Vec<[u8; 3]> {
    if k == 0 || pixels.is_empty() {
        return Vec::new();
    }

    let mut centroids: Vec<[u8; 3]> = Vec::with_capacity(k);
    // nearest_sq[i] is the squared distance from pixels[i] to its closest chosen centroid.
    // Refreshing it against only the newest centroid keeps each round O(n) instead of O(n·k).
    let mut nearest_sq = vec![f64::MAX; pixels.len()];
    let mut next = rng.gen_range(0..pixels.len());

    loop {
        let centroid = pixels[next];
        centroids.push(centroid);
        if centroids.len() == k {
            break;
        }

        for (best, pixel) in nearest_sq.iter_mut().zip(pixels) {
            *best = f64::min(*best, Color::sq_distance_rgb(pixel, &centroid));
        }

        let total: f64 = nearest_sq.iter().sum();
        next = if total == 0.0 {
            // Every pixel already sits on a centroid: no distance signal left.
            rng.gen_range(0..pixels.len())
        } else {
            sample_proportional(&nearest_sq, total, rng)
        };
    }
    centroids
}

/// Roulette-wheel selection: returns index `i` with probability `weights[i] / total`.
///
/// `total` is expected to be the positive sum of the non-negative `weights`.
/// Zero-weight entries are never returned.
fn sample_proportional(weights: &[f64], total: f64, rng: &mut SmallRng) -> usize {
    let mut dart = rng.gen::<f64>() * total;
    for (i, &w) in weights.iter().enumerate() {
        // A zero weight means this pixel is already a centroid; never pick it.
        if w <= 0.0 {
            continue;
        }
        dart -= w;
        if dart <= 0.0 {
            return i;
        }
    }
    // Rounding in `total` can leave a tiny positive remainder after the loop. Fall back to
    // the last selectable entry rather than an arbitrary (possibly zero-weight) one.
    weights
        .iter()
        .rposition(|&w| w > 0.0)
        .unwrap_or(weights.len() - 1)
}
/// Maps every pixel to the index of its nearest centroid under `Color::sq_distance_rgb`.
///
/// Ties resolve to the lowest centroid index. With no centroids, every pixel maps to 0.
/// Panics if two distances cannot be ordered (NaN), mirroring `partial_cmp(..).unwrap()`.
fn assign_pixels(pixels: &[[u8; 3]], centroids: &[[u8; 3]]) -> Vec<usize> {
    let nearest = |pixel: &[u8; 3]| -> usize {
        let mut best: Option<(usize, f64)> = None;
        for (idx, centroid) in centroids.iter().enumerate() {
            let dist = Color::sq_distance_rgb(pixel, centroid);
            best = match best {
                // Keep the current winner unless the new candidate is strictly closer.
                Some((_, best_dist)) if best_dist.partial_cmp(&dist).unwrap().is_le() => best,
                _ => Some((idx, dist)),
            };
        }
        best.map_or(0, |(idx, _)| idx)
    };
    pixels.iter().map(nearest).collect()
}
/// Recomputes every centroid as the mean of the pixels assigned to it.
///
/// Channel means are rounded to the nearest integer (half up). A cluster that received no
/// pixels is re-seeded with a randomly chosen member of the currently largest cluster, so
/// `k` centroids stay in play. If no pixel is assigned anywhere (empty `pixels`), the
/// fallback is `[128, 128, 128]`.
///
/// Each `assignments[i]` must be the cluster index (`< k`) of `pixels[i]`; indexing panics
/// otherwise.
fn update_centroids(
    pixels: &[[u8; 3]],
    assignments: &[usize],
    k: usize,
    rng: &mut SmallRng,
) -> Vec<[u8; 3]> {
    let mut sums = vec![[0u64; 3]; k];
    let mut counts = vec![0usize; k];
    for (&p, &idx) in pixels.iter().zip(assignments) {
        for (acc, &channel) in sums[idx].iter_mut().zip(p.iter()) {
            *acc += u64::from(channel);
        }
        counts[idx] += 1;
    }

    let largest = counts
        .iter()
        .enumerate()
        .max_by_key(|(_, &c)| c)
        .map(|(i, _)| i)
        .unwrap_or(0);

    // Members of the largest cluster. They are gathered lazily and at most once, however
    // many clusters went empty.
    let mut reseed_pool: Option<Vec<[u8; 3]>> = None;

    (0..k)
        .map(|i| {
            if counts[i] == 0 {
                let pool = reseed_pool.get_or_insert_with(|| {
                    pixels
                        .iter()
                        .zip(assignments)
                        .filter(|(_, &a)| a == largest)
                        .map(|(&p, _)| p)
                        .collect()
                });
                if pool.is_empty() {
                    [128, 128, 128]
                } else {
                    pool[rng.gen_range(0..pool.len())]
                }
            } else {
                let n = counts[i] as u64;
                // Round to nearest. Plain integer division truncates and would bias every
                // centroid toward 0. The result never exceeds 255, since sum <= 255 * n.
                let mean = |sum: u64| ((sum + n / 2) / n) as u8;
                [mean(sums[i][0]), mean(sums[i][1]), mean(sums[i][2])]
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reproducible test config: `k` colors, `sample_size(None)`, seed 0, up to 200 iterations.
    fn cfg(k: usize) -> Config {
        Config::default()
            .max_colors(k)
            .sample_size(None)
            .kmeans_seed(0)
            .kmeans_max_iterations(200)
    }

    #[test]
    fn test_extract_single_color() {
        let pixels = vec![[255u8, 0, 0]; 50];
        let palette = extract(&pixels, &cfg(3)).unwrap();
        assert!(!palette.is_empty());
        let top = &palette[0];
        assert!(top.percentage > 0.9, "纯色图片:首位颜色占比应超过 90%");
    }

    #[test]
    fn test_extract_two_clusters() {
        let mut pixels = vec![[0u8, 0, 255]; 50];
        pixels.extend(vec![[255u8, 0, 0]; 50]);
        let palette = extract(&pixels, &cfg(2)).unwrap();
        assert_eq!(palette.len(), 2);
        for c in &palette {
            assert!((c.percentage - 0.5).abs() < 0.15, "每个簇应占约 50%");
        }
    }

    #[test]
    fn test_percentages_sum_to_one() {
        let pixels: Vec<[u8; 3]> = (0..255u8).map(|i| [i, i, i]).collect();
        let palette = extract(&pixels, &cfg(5)).unwrap();
        let total: f32 = palette.iter().map(|c| c.percentage).sum();
        assert!((total - 1.0).abs() < 1e-5, "占比之和应为 1.0,实际 {total}");
    }

    #[test]
    fn test_k_clamped_to_pixel_count() {
        let pixels = vec![[0u8, 0, 0], [255, 255, 255]];
        let palette = extract(&pixels, &cfg(10)).unwrap();
        assert!(palette.len() <= 2);
    }

    #[test]
    fn test_empty_image_errors() {
        let pixels: Vec<[u8; 3]> = Vec::new();
        assert!(matches!(
            extract(&pixels, &cfg(3)),
            Err(DominantColorError::EmptyImage)
        ));
    }

    #[test]
    fn test_kmeans_plus_plus_k_equals_1() {
        let pixels = vec![[1u8, 2, 3]; 10];
        let mut rng = SmallRng::seed_from_u64(0);
        let centroids = kmeans_plus_plus_init(&pixels, 1, &mut rng);
        assert_eq!(centroids.len(), 1);
    }

    #[test]
    fn test_kmeans_plus_plus_picks_distinct_centroids() {
        // With two distinct pixels and k = 2, the second pick must be the pixel that is
        // not already a centroid (it carries all of the weight).
        let pixels = vec![[0u8, 0, 0], [255, 255, 255]];
        for seed in 0..8 {
            let mut rng = SmallRng::seed_from_u64(seed);
            let centroids = kmeans_plus_plus_init(&pixels, 2, &mut rng);
            assert_eq!(centroids.len(), 2);
            assert_ne!(centroids[0], centroids[1], "seed {seed}: duplicate centroid");
        }
    }

    #[test]
    fn test_assign_pixels_nearest() {
        let pixels = vec![[0u8, 0, 0], [200, 200, 200]];
        let centroids = vec![[0u8, 0, 0], [255, 255, 255]];
        let assignments = assign_pixels(&pixels, &centroids);
        assert_eq!(assignments, vec![0, 1]);
    }

    #[test]
    fn test_deterministic_with_same_seed() {
        let pixels: Vec<[u8; 3]> = (0..100u8)
            .map(|i| [i, i.wrapping_mul(2), i.wrapping_mul(3)])
            .collect();
        let palette1 = extract(&pixels, &cfg(4)).unwrap();
        let palette2 = extract(&pixels, &cfg(4)).unwrap();
        // `zip` would silently ignore a length mismatch, so check it explicitly.
        assert_eq!(palette1.len(), palette2.len(), "palette length differs between runs");
        for (a, b) in palette1.iter().zip(palette2.iter()) {
            assert_eq!((a.r, a.g, a.b), (b.r, b.g, b.b), "相同种子结果应一致");
        }
    }
}