1use image::{DynamicImage, Pixel};
2use palette::{FromColor, Oklab, Srgb};
3use rand::distr::{Distribution, weighted::WeightedIndex};
4use rand::seq::IndexedRandom;
5use std::ops::RangeInclusive;
6
/// Outcome of a [`kmeans`] run.
///
/// `centroids` and `clusters` are paired index-by-index: `clusters[i]` lists the indices
/// (into the clustered data slice) of the points assigned to `centroids[i]`.
/// A cluster may be empty.
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansResult<const DIMS: usize> {
    /// Centroid positions, one per cluster.
    pub centroids: Vec<[f32; DIMS]>,
    /// For each cluster, the indices of its member points in the input data.
    pub clusters: Vec<Vec<usize>>,
}
14
15pub fn silhouette_score<const DIMS: usize, F>(
22 data: &[[f32; DIMS]],
23 result: &KMeansResult<DIMS>,
24 distance: F,
25) -> f32
26where
27 F: Fn(&[f32], &[f32]) -> f32,
28{
29 let mut s = vec![0.0; data.len()];
30
31 for ((cluster_index, cluster), centroid) in
32 std::iter::zip(result.clusters.iter().enumerate(), result.centroids.iter())
33 {
34 for &point_index in cluster {
35 let a = distance(&data[point_index], centroid);
36 let b = result
37 .centroids
38 .iter()
39 .enumerate()
40 .filter(|(other_cluster_index, _)| *other_cluster_index != cluster_index)
41 .map(|(_, other_cluster_centroid)| {
42 distance(&data[point_index], other_cluster_centroid)
43 })
44 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
45 .unwrap_or(1.0);
46
47 if a < b {
48 s[point_index] = 1.0 - (a / b);
49 } else if a > b {
50 s[point_index] = (b / a) - 1.0;
51 }
52 }
53 }
54
55 s.iter().sum::<f32>() / data.len() as f32
56}
57
/// Squared Euclidean distance: the sum of squared component differences.
///
/// Only the overlapping prefix of the two slices is compared; extra trailing
/// components of the longer slice are ignored.
pub fn eucl_distance_squared(first: &[f32], second: &[f32]) -> f32 {
    let mut total = 0.0;
    for (x, y) in first.iter().zip(second.iter()) {
        total += (x - y).powi(2);
    }
    total
}
64
65pub fn eucl_distance(first: &[f32], second: &[f32]) -> f32 {
67 eucl_distance_squared(first, second).sqrt()
68}
69
/// Recomputes each cluster's centroid as the component-wise mean of its member points.
///
/// `clusters[i]` holds indices into `data`. An empty cluster keeps its previous centroid,
/// `old_centroids[i]`. The output has one entry per `(cluster, old_centroid)` pair, so it
/// is as long as the shorter of the two inputs.
fn calculate_centroids<const DIMS: usize>(
    data: &[[f32; DIMS]],
    clusters: &[Vec<usize>],
    old_centroids: &[[f32; DIMS]],
) -> Vec<[f32; DIMS]> {
    clusters
        .iter()
        .zip(old_centroids.iter())
        .map(|(members, previous)| {
            if members.is_empty() {
                return *previous;
            }

            let count = members.len() as f32;
            // Per dimension, accumulate members in order starting from 0.0, then divide.
            let mean: [f32; DIMS] = std::array::from_fn(|dim| {
                members
                    .iter()
                    .fold(0.0, |total, &member| total + data[member][dim])
                    / count
            });
            mean
        })
        .collect()
}
102
/// Approximate component-wise equality.
///
/// Returns `true` when every paired component differs by at most `eps`. Only the
/// overlapping prefix of the two slices is compared. A NaN component never compares
/// equal, so it makes the result `false`.
fn array_eq(first: &[f32], second: &[f32], eps: f32) -> bool {
    first
        .iter()
        .zip(second.iter())
        .all(|(lhs, rhs)| f32::abs(lhs - rhs) <= eps)
}
106
107fn centroids_eq<const DIMS: usize>(
108 first: &Vec<[f32; DIMS]>,
109 second: &Vec<[f32; DIMS]>,
110 eps: f32,
111) -> bool {
112 std::iter::zip(first, second).all(|(a, b)| array_eq(a, b, eps))
113}
114
115fn initialize_centroids<const DIMS: usize, F>(
116 data: &[[f32; DIMS]],
117 k: usize,
118 distance: &F,
119 init: KMeansInit,
120) -> Vec<[f32; DIMS]>
121where
122 F: Fn(&[f32], &[f32]) -> f32,
123{
124 let mut rng = rand::rng();
125
126 let mut centroids = Vec::with_capacity(k);
127 if !data.is_empty() && k > 0 {
128 match init {
129 KMeansInit::Random => {
130 centroids = data.sample(&mut rng, k).cloned().collect();
131 }
132 KMeansInit::KMeansPlusPlus => {
133 centroids.push(*data.choose(&mut rng).expect("Data is not empty"));
135
136 for _ in 1..k {
138 let weights: Vec<f64> = data
139 .iter()
140 .map(|point| {
141 centroids
142 .iter()
143 .map(|c| distance(c, point))
144 .min_by(|a, b| a.total_cmp(b))
145 .unwrap_or(0.0) as f64
146 })
147 .collect();
148
149 if let Ok(dist) = WeightedIndex::new(&weights) {
150 centroids.push(data[dist.sample(&mut rng)]);
151 } else {
152 centroids.push(*data.choose(&mut rng).expect("Data is not empty"));
154 }
155 }
156 }
157 }
158 }
159 centroids
160}
161
162pub fn kmeans<const DIMS: usize, F>(
170 data: &[[f32; DIMS]],
171 k: usize,
172 distance: F,
173 max_iters: usize,
174 eps: f32,
175 init: KMeansInit,
176) -> KMeansResult<DIMS>
177where
178 F: Fn(&[f32], &[f32]) -> f32,
179{
180 let mut centroids = initialize_centroids(data, k, &distance, init);
181
182 let mut clusters: Vec<Vec<usize>> = vec![vec![]; k];
183 for _i in 0..max_iters {
184 for c in clusters.iter_mut() {
185 c.clear();
186 }
187
188 for (index, point) in data.iter().enumerate() {
190 let closest_centroid = centroids
191 .iter()
192 .map(|centroid| distance(centroid, point))
193 .enumerate()
194 .min_by(|(_, a), (_, b)| a.total_cmp(b))
195 .map(|(index, _)| index)
196 .expect("Can't assign point to the closest centroid");
197 clusters[closest_centroid].push(index);
198 }
199
200 let new_centroids = calculate_centroids(data, &clusters, ¢roids);
201 if centroids_eq(&new_centroids, ¢roids, eps) {
202 break;
203 }
204 centroids = new_centroids;
205 }
206
207 KMeansResult {
208 centroids,
209 clusters,
210 }
211}
212
/// Simple saturation measure for an RGB triple: largest channel minus smallest channel.
///
/// Channels are compared with `partial_cmp`, and incomparable pairs (NaN) are treated
/// as equal.
pub fn saturation(point: &[f32; 3]) -> f32 {
    let by_value =
        |x: &f32, y: &f32| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal);

    let highest = point.iter().copied().max_by(by_value).unwrap_or(0.0);
    let lowest = point.iter().copied().min_by(by_value).unwrap_or(0.0);

    highest - lowest
}
228
/// Strategy used to choose the starting centroids of a k-means run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KMeansInit {
    /// Pick `k` data points at random.
    Random,
    /// k-means++ seeding: after a random first centroid, each further centroid is drawn
    /// with probability proportional to the distance from the point to its nearest
    /// already-chosen centroid.
    KMeansPlusPlus,
}
237
/// Colour space in which image pixels are clustered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorSpace {
    /// 8-bit sRGB channels scaled to `[0, 1]`. Vividness is measured by
    /// [`saturation`] (largest minus smallest channel).
    Rgb,
    /// Oklab `(L, a, b)`. Vividness is measured by chroma, `sqrt(a² + b²)`.
    Oklab,
}
246
247pub struct Settings {
249 pub img_size: u32,
251 pub clusters: RangeInclusive<usize>,
253 pub max_iters: usize,
255 pub eps: f32,
257 pub init: KMeansInit,
259 pub color_space: ColorSpace,
261}
262
263impl Default for Settings {
264 fn default() -> Self {
265 Self {
266 img_size: 72,
267 clusters: 2..=6,
268 max_iters: 100,
269 eps: 1e-6,
270 init: KMeansInit::KMeansPlusPlus,
271 color_space: ColorSpace::Oklab,
272 }
273 }
274}
275
276fn dominant_colors_private(img: &DynamicImage, settings: &Settings) -> Vec<([f32; 3], f32)> {
277 let resized = image::imageops::resize(
278 img,
279 settings.img_size,
280 settings.img_size,
281 image::imageops::FilterType::Triangle,
282 );
283
284 let pixels: Vec<_> = resized
285 .pixels()
286 .map(|pixel| {
287 let rgb = pixel.to_rgb();
288 match settings.color_space {
289 ColorSpace::Rgb => [
290 rgb.0[0] as f32 / 255.0,
291 rgb.0[1] as f32 / 255.0,
292 rgb.0[2] as f32 / 255.0,
293 ],
294 ColorSpace::Oklab => {
295 let srgb = Srgb::new(
296 rgb.0[0] as f32 / 255.0,
297 rgb.0[1] as f32 / 255.0,
298 rgb.0[2] as f32 / 255.0,
299 );
300 let lab = Oklab::from_color(srgb);
301 [lab.l, lab.a, lab.b]
302 }
303 }
304 })
305 .collect();
306
307 let kmeans_result = settings
309 .clusters
310 .clone()
311 .map(|k| {
312 kmeans(
313 &pixels,
314 k,
315 eucl_distance_squared,
316 settings.max_iters,
317 settings.eps,
318 settings.init,
319 )
320 })
321 .map(|kmeans_result| {
322 (
323 silhouette_score(&pixels, &kmeans_result, eucl_distance),
324 kmeans_result,
325 )
326 })
327 .max_by(|(score1, _), (score2, _)| {
328 score1
329 .partial_cmp(score2)
330 .unwrap_or(std::cmp::Ordering::Equal)
331 })
332 .map(|(_, kmeans_result)| kmeans_result);
333
334 match kmeans_result {
335 Some(kmeans_result) => std::iter::zip(
336 kmeans_result.centroids.iter(),
337 kmeans_result.clusters.iter(),
338 )
339 .filter(|(_centroid, cluster)| !cluster.is_empty())
340 .map(|(centroid, _cluster)| match settings.color_space {
341 ColorSpace::Rgb => (*centroid, saturation(centroid)),
342 ColorSpace::Oklab => {
343 let lab = Oklab::new(centroid[0], centroid[1], centroid[2]);
344 let rgb = Srgb::from_color(lab);
345 let chroma = (centroid[1].powi(2) + centroid[2].powi(2)).sqrt();
346 (
347 [
348 rgb.red.clamp(0.0, 1.0),
349 rgb.green.clamp(0.0, 1.0),
350 rgb.blue.clamp(0.0, 1.0),
351 ],
352 chroma,
353 )
354 }
355 })
356 .collect(),
357 None => vec![],
358 }
359}
360
361pub fn dominant_colors(img: &DynamicImage, settings: &Settings) -> Vec<[f32; 3]> {
367 let mut centroids_and_saturations = dominant_colors_private(img, settings);
368
369 centroids_and_saturations.sort_by(|(_, sat1), (_, sat2)| {
371 sat2.partial_cmp(sat1).unwrap_or(std::cmp::Ordering::Equal)
372 });
373
374 centroids_and_saturations
376 .into_iter()
377 .map(|(centroid, _)| centroid)
378 .collect()
379}
380
381pub fn dominant_color(img: &DynamicImage, settings: &Settings) -> Option<[f32; 3]> {
386 dominant_colors_private(img, settings)
388 .into_iter()
389 .max_by(|(_, sat1), (_, sat2)| sat1.partial_cmp(sat2).unwrap_or(std::cmp::Ordering::Equal))
390 .map(|(centroid, _)| centroid)
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_saturation_chroma() {
        // A near-black red has little spread between channels; a vivid red has a lot.
        let dark_red = [4.0 / 255.0, 2.0 / 255.0, 2.0 / 255.0];
        let vivid_red = [187.0 / 255.0, 78.0 / 255.0, 69.0 / 255.0];

        let sat_dark = saturation(&dark_red);
        let sat_vivid = saturation(&vivid_red);
        assert!(sat_vivid > sat_dark);
    }

    #[test]
    fn test_eucl_distance() {
        assert_eq!(eucl_distance_squared(&[0.0, 0.0], &[3.0, 4.0]), 25.0);
        assert_eq!(eucl_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0);
    }

    #[test]
    fn test_calculate_centroids_keeps_old_centroid_for_empty_cluster() {
        let data: [[f32; 2]; 2] = [[0.0, 0.0], [2.0, 4.0]];
        let clusters: Vec<Vec<usize>> = vec![vec![0, 1], vec![]];
        let old: [[f32; 2]; 2] = [[9.0, 9.0], [7.0, 7.0]];
        assert_eq!(
            calculate_centroids(&data, &clusters, &old),
            vec![[1.0, 2.0], [7.0, 7.0]]
        );
    }

    #[test]
    fn test_kmeans_separates_well_separated_groups() {
        let data: [[f32; 1]; 6] = [[0.0], [0.1], [0.2], [10.0], [10.1], [10.2]];
        let result = kmeans(
            &data,
            2,
            eucl_distance_squared,
            100,
            1e-6,
            KMeansInit::KMeansPlusPlus,
        );

        let mut sizes: Vec<usize> = result.clusters.iter().map(Vec::len).collect();
        sizes.sort_unstable();
        assert_eq!(sizes, vec![3, 3]);
        assert!(silhouette_score(&data, &result, eucl_distance) > 0.9);
    }

    /// Smoke test over every file in `testimg/`: each image must yield a dominant colour.
    #[test]
    fn test_dominant_color_on_sample_images() {
        let entries =
            std::fs::read_dir("testimg").expect("expected a testimg/ directory with sample images");
        for entry in entries {
            let path = entry.expect("readable directory entry").path();
            if path.is_file() {
                let img = image::open(&path)
                    .unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()));
                assert!(
                    dominant_color(&img, &Settings::default()).is_some(),
                    "no dominant color found for {}",
                    path.display()
                );
            }
        }
    }
}