Skip to main content

dominant_color_rs/
lib.rs

1use image::{DynamicImage, Pixel};
2use palette::{FromColor, Oklab, Srgb};
3use rand::distr::{Distribution, weighted::WeightedIndex};
4use rand::seq::IndexedRandom;
5use std::ops::RangeInclusive;
6
/// Result of a K-Means clustering operation.
///
/// As produced by [`kmeans`], `centroids[i]` is the centroid of the points whose indices
/// (into the clustered data slice) are listed in `clusters[i]`. A cluster may be empty.
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansResult<const DIMS: usize> {
    /// The computed centroids for each cluster.
    pub centroids: Vec<[f32; DIMS]>,
    /// The indices of the data points assigned to each cluster.
    pub clusters: Vec<Vec<usize>>,
}
14
15/// Calculates the silhouette score for a clustering result.
16///
17/// The silhouette score is a measure of how similar an object is to its own cluster
18/// compared to other clusters. The score ranges from -1 to 1, where a high value
19/// indicates that the object is well matched to its own cluster and poorly matched
20/// to neighboring clusters.
21pub fn silhouette_score<const DIMS: usize, F>(
22    data: &[[f32; DIMS]],
23    result: &KMeansResult<DIMS>,
24    distance: F,
25) -> f32
26where
27    F: Fn(&[f32], &[f32]) -> f32,
28{
29    let mut s = vec![0.0; data.len()];
30
31    for ((cluster_index, cluster), centroid) in
32        std::iter::zip(result.clusters.iter().enumerate(), result.centroids.iter())
33    {
34        for &point_index in cluster {
35            let a = distance(&data[point_index], centroid);
36            let b = result
37                .centroids
38                .iter()
39                .enumerate()
40                .filter(|(other_cluster_index, _)| *other_cluster_index != cluster_index)
41                .map(|(_, other_cluster_centroid)| {
42                    distance(&data[point_index], other_cluster_centroid)
43                })
44                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
45                .unwrap_or(1.0);
46
47            if a < b {
48                s[point_index] = 1.0 - (a / b);
49            } else if a > b {
50                s[point_index] = (b / a) - 1.0;
51            }
52        }
53    }
54
55    s.iter().sum::<f32>() / data.len() as f32
56}
57
/// Returns the squared Euclidean distance between two points.
///
/// Coordinates are paired positionally; if one slice is longer than the other, its extra
/// trailing coordinates are ignored. Cheaper than [`eucl_distance`] because it skips the
/// square root.
pub fn eucl_distance_squared(first: &[f32], second: &[f32]) -> f32 {
    first
        .iter()
        .zip(second.iter())
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum()
}

/// Returns the Euclidean distance between two points: the square root of
/// [`eucl_distance_squared`].
pub fn eucl_distance(first: &[f32], second: &[f32]) -> f32 {
    let squared = eucl_distance_squared(first, second);
    squared.sqrt()
}
69
/// Recomputes each cluster's centroid as the arithmetic mean of its member points.
///
/// `clusters[i]` holds indices into `data`. An empty cluster keeps `old_centroids[i]`
/// unchanged. One centroid is produced per `(cluster, old_centroid)` pair, so the output
/// is as long as the shorter of the two inputs.
fn calculate_centroids<const DIMS: usize>(
    data: &[[f32; DIMS]],
    clusters: &[Vec<usize>],
    old_centroids: &[[f32; DIMS]],
) -> Vec<[f32; DIMS]> {
    clusters
        .iter()
        .zip(old_centroids)
        .map(|(members, &previous)| {
            if members.is_empty() {
                return previous;
            }

            // Accumulate member coordinates in member order, then divide by the count.
            let mut mean = [0.0_f32; DIMS];
            for &member in members {
                for (total, &coord) in mean.iter_mut().zip(data[member].iter()) {
                    *total += coord;
                }
            }

            let count = members.len() as f32;
            mean.iter_mut().for_each(|total| *total /= count);
            mean
        })
        .collect()
}
102
/// Returns `true` when `first` and `second` have the same length and every pair of
/// corresponding coordinates differs by at most `eps`.
///
/// A `NaN` coordinate never compares as equal.
fn array_eq(first: &[f32], second: &[f32], eps: f32) -> bool {
    first.len() == second.len()
        && first
            .iter()
            .zip(second)
            .all(|(a, b)| (a - b).abs() <= eps)
}

/// Returns `true` when both centroid lists contain the same number of centroids and each
/// centroid is within `eps` (per coordinate) of its counterpart.
///
/// Lists of different lengths are never equal; a plain `zip` would silently compare only
/// their common prefix.
fn centroids_eq<const DIMS: usize>(
    first: &[[f32; DIMS]],
    second: &[[f32; DIMS]],
    eps: f32,
) -> bool {
    first.len() == second.len()
        && first
            .iter()
            .zip(second)
            .all(|(a, b)| array_eq(a, b, eps))
}
114
115fn initialize_centroids<const DIMS: usize, F>(
116    data: &[[f32; DIMS]],
117    k: usize,
118    distance: &F,
119    init: KMeansInit,
120) -> Vec<[f32; DIMS]>
121where
122    F: Fn(&[f32], &[f32]) -> f32,
123{
124    let mut rng = rand::rng();
125
126    let mut centroids = Vec::with_capacity(k);
127    if !data.is_empty() && k > 0 {
128        match init {
129            KMeansInit::Random => {
130                centroids = data.sample(&mut rng, k).cloned().collect();
131            }
132            KMeansInit::KMeansPlusPlus => {
133                // 1. Choose first centroid uniformly at random
134                centroids.push(*data.choose(&mut rng).expect("Data is not empty"));
135
136                // 2. Choose remaining k-1 centroids
137                for _ in 1..k {
138                    let weights: Vec<f64> = data
139                        .iter()
140                        .map(|point| {
141                            centroids
142                                .iter()
143                                .map(|c| distance(c, point))
144                                .min_by(|a, b| a.total_cmp(b))
145                                .unwrap_or(0.0) as f64
146                        })
147                        .collect();
148
149                    if let Ok(dist) = WeightedIndex::new(&weights) {
150                        centroids.push(data[dist.sample(&mut rng)]);
151                    } else {
152                        // If all weights are 0 or valid indices cannot be created, pick randomly
153                        centroids.push(*data.choose(&mut rng).expect("Data is not empty"));
154                    }
155                }
156            }
157        }
158    }
159    centroids
160}
161
162/// Performs K-Means clustering on the provided data.
163///
164/// * `data`: The data points to cluster.
165/// * `k`: The number of clusters to find.
166/// * `distance`: A function that calculates the distance between two points.
167/// * `max_iters`: The maximum number of iterations to perform.
168/// * `eps`: The convergence threshold. If the centroids move by less than this amount, the algorithm stops.
169pub fn kmeans<const DIMS: usize, F>(
170    data: &[[f32; DIMS]],
171    k: usize,
172    distance: F,
173    max_iters: usize,
174    eps: f32,
175    init: KMeansInit,
176) -> KMeansResult<DIMS>
177where
178    F: Fn(&[f32], &[f32]) -> f32,
179{
180    let mut centroids = initialize_centroids(data, k, &distance, init);
181
182    let mut clusters: Vec<Vec<usize>> = vec![vec![]; k];
183    for _i in 0..max_iters {
184        for c in clusters.iter_mut() {
185            c.clear();
186        }
187
188        // Assign each point to the "closest" centroid
189        for (index, point) in data.iter().enumerate() {
190            let closest_centroid = centroids
191                .iter()
192                .map(|centroid| distance(centroid, point))
193                .enumerate()
194                .min_by(|(_, a), (_, b)| a.total_cmp(b))
195                .map(|(index, _)| index)
196                .expect("Can't assign point to the closest centroid");
197            clusters[closest_centroid].push(index);
198        }
199
200        let new_centroids = calculate_centroids(data, &clusters, &centroids);
201        if centroids_eq(&new_centroids, &centroids, eps) {
202            break;
203        }
204        centroids = new_centroids;
205    }
206
207    KMeansResult {
208        centroids,
209        clusters,
210    }
211}
212
/// Returns the chroma of an RGB color: its largest component minus its smallest.
///
/// Expected input is an RGB array where each component is in the range `[0.0, 1.0]`.
/// The result is then also in `[0.0, 1.0]`, and `0.0` for grays.
pub fn saturation(point: &[f32; 3]) -> f32 {
    let [red, green, blue] = *point;

    // Same tie-breaking as `Iterator::max_by` / `min_by` with a partial comparison: the
    // maximum keeps the later value unless the earlier one is strictly greater, and the
    // minimum keeps the earlier value unless it is strictly greater.
    let larger = |x: f32, y: f32| if x > y { x } else { y };
    let smaller = |x: f32, y: f32| if x > y { y } else { x };

    let highest = larger(larger(red, green), blue);
    let lowest = smaller(smaller(red, green), blue);
    highest - lowest
}
228
/// Initialization method for K-Means.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KMeansInit {
    /// Standard random initialization: up to `k` distinct data points are sampled
    /// uniformly at random as the initial centroids.
    Random,
    /// K-Means++ initialization for better convergence: after a uniformly random first
    /// centroid, each further centroid is drawn with probability proportional to its
    /// distance from the nearest centroid chosen so far.
    KMeansPlusPlus,
}

/// Color space used for K-Means clustering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorSpace {
    /// Standard RGB color space. Faster but not perceptually uniform.
    Rgb,
    /// Oklab color space. Perceptually uniform, producing more accurate visual clusters.
    Oklab,
}

/// Settings for dominant color extraction.
///
/// Start from [`Settings::default`] and override individual fields with struct update
/// syntax, e.g. `Settings { color_space: ColorSpace::Rgb, ..Settings::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct Settings {
    /// The image is resized to exactly `img_size × img_size` pixels before processing
    /// (the aspect ratio is not preserved).
    pub img_size: u32,
    /// The range of cluster counts (k) to try. The one with the best silhouette score will
    /// be chosen. An empty range yields no colors.
    pub clusters: RangeInclusive<usize>,
    /// Maximum iterations for K-Means.
    pub max_iters: usize,
    /// Convergence threshold for K-Means: iteration stops once no centroid coordinate moves
    /// by more than this amount.
    pub eps: f32,
    /// Initialization method.
    pub init: KMeansInit,
    /// Color space to use for clustering.
    pub color_space: ColorSpace,
}

impl Default for Settings {
    /// A 72×72 thumbnail, k tried over `2..=6`, at most 100 iterations, `eps = 1e-6`,
    /// K-Means++ initialization, clustering in Oklab.
    fn default() -> Self {
        Self {
            img_size: 72,
            clusters: 2..=6,
            max_iters: 100,
            eps: 1e-6,
            init: KMeansInit::KMeansPlusPlus,
            color_space: ColorSpace::Oklab,
        }
    }
}
275
276fn dominant_colors_private(img: &DynamicImage, settings: &Settings) -> Vec<([f32; 3], f32)> {
277    let resized = image::imageops::resize(
278        img,
279        settings.img_size,
280        settings.img_size,
281        image::imageops::FilterType::Triangle,
282    );
283
284    let pixels: Vec<_> = resized
285        .pixels()
286        .map(|pixel| {
287            let rgb = pixel.to_rgb();
288            match settings.color_space {
289                ColorSpace::Rgb => [
290                    rgb.0[0] as f32 / 255.0,
291                    rgb.0[1] as f32 / 255.0,
292                    rgb.0[2] as f32 / 255.0,
293                ],
294                ColorSpace::Oklab => {
295                    let srgb = Srgb::new(
296                        rgb.0[0] as f32 / 255.0,
297                        rgb.0[1] as f32 / 255.0,
298                        rgb.0[2] as f32 / 255.0,
299                    );
300                    let lab = Oklab::from_color(srgb);
301                    [lab.l, lab.a, lab.b]
302                }
303            }
304        })
305        .collect();
306
307    // take the kmeans_result maximizing silhouette_score
308    let kmeans_result = settings
309        .clusters
310        .clone()
311        .map(|k| {
312            kmeans(
313                &pixels,
314                k,
315                eucl_distance_squared,
316                settings.max_iters,
317                settings.eps,
318                settings.init,
319            )
320        })
321        .map(|kmeans_result| {
322            (
323                silhouette_score(&pixels, &kmeans_result, eucl_distance),
324                kmeans_result,
325            )
326        })
327        .max_by(|(score1, _), (score2, _)| {
328            score1
329                .partial_cmp(score2)
330                .unwrap_or(std::cmp::Ordering::Equal)
331        })
332        .map(|(_, kmeans_result)| kmeans_result);
333
334    match kmeans_result {
335        Some(kmeans_result) => std::iter::zip(
336            kmeans_result.centroids.iter(),
337            kmeans_result.clusters.iter(),
338        )
339        .filter(|(_centroid, cluster)| !cluster.is_empty())
340        .map(|(centroid, _cluster)| match settings.color_space {
341            ColorSpace::Rgb => (*centroid, saturation(centroid)),
342            ColorSpace::Oklab => {
343                let lab = Oklab::new(centroid[0], centroid[1], centroid[2]);
344                let rgb = Srgb::from_color(lab);
345                let chroma = (centroid[1].powi(2) + centroid[2].powi(2)).sqrt();
346                (
347                    [
348                        rgb.red.clamp(0.0, 1.0),
349                        rgb.green.clamp(0.0, 1.0),
350                        rgb.blue.clamp(0.0, 1.0),
351                    ],
352                    chroma,
353                )
354            }
355        })
356        .collect(),
357        None => vec![],
358    }
359}
360
361/// Calculates the dominant colors of an image.
362///
363/// Returns a vector of RGB colors represented as `[f32; 3]`,
364/// where each component is in the range `[0.0, 1.0]`.
365/// The colors are sorted by saturation in descending order.
366pub fn dominant_colors(img: &DynamicImage, settings: &Settings) -> Vec<[f32; 3]> {
367    let mut centroids_and_saturations = dominant_colors_private(img, settings);
368
369    // sort clusters by their saturation value (descending)
370    centroids_and_saturations.sort_by(|(_, sat1), (_, sat2)| {
371        sat2.partial_cmp(sat1).unwrap_or(std::cmp::Ordering::Equal)
372    });
373
374    // return just the centroid colors
375    centroids_and_saturations
376        .into_iter()
377        .map(|(centroid, _)| centroid)
378        .collect()
379}
380
381/// Calculates the dominant color of an image.
382///
383/// The dominant color is chosen as the cluster centroid with the highest saturation.
384/// Returns the RGB color as `[f32; 3]` where each component is in the range `[0.0, 1.0]`.
385pub fn dominant_color(img: &DynamicImage, settings: &Settings) -> Option<[f32; 3]> {
386    // dominant color is centroid having the highest saturation
387    dominant_colors_private(img, settings)
388        .into_iter()
389        .max_by(|(_, sat1), (_, sat2)| sat1.partial_cmp(sat2).unwrap_or(std::cmp::Ordering::Equal))
390        .map(|(centroid, _)| centroid)
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_saturation_chroma() {
        let dark_red = [4.0 / 255.0, 2.0 / 255.0, 2.0 / 255.0];
        let vivid_red = [187.0 / 255.0, 78.0 / 255.0, 69.0 / 255.0];

        let sat_dark = saturation(&dark_red);
        let sat_vivid = saturation(&vivid_red);
        assert!(sat_vivid > sat_dark);
    }

    #[test]
    fn test_distances() {
        assert_eq!(eucl_distance_squared(&[0.0, 0.0], &[3.0, 4.0]), 25.0);
        assert_eq!(eucl_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0);
    }

    #[test]
    fn test_kmeans_separates_obvious_clusters() {
        // With duplicated points, k-means++ is deterministic here. Once the first centroid
        // lands on one group, that group's points get weight 0, so the second centroid must
        // come from the other group.
        let data = [[0.0, 0.0], [0.0, 0.0], [10.0, 10.0], [10.0, 10.0]];
        let result = kmeans(
            &data,
            2,
            eucl_distance_squared,
            100,
            1e-6,
            KMeansInit::KMeansPlusPlus,
        );

        let mut centroids = result.centroids.clone();
        centroids.sort_by(|a, b| a[0].total_cmp(&b[0]));
        assert_eq!(centroids, vec![[0.0, 0.0], [10.0, 10.0]]);
        assert!(result.clusters.iter().all(|cluster| cluster.len() == 2));

        let score = silhouette_score(&data, &result, eucl_distance);
        assert!((score - 1.0).abs() < 1e-6, "unexpected silhouette score {score}");
    }

    /// Runs the public API on every image in the `testimg` fixture directory and checks
    /// that the output is well-formed, not merely that nothing panics.
    #[test]
    fn test_dominant_colors_on_test_images() {
        let entries = std::fs::read_dir("testimg").expect("`testimg` fixture directory");
        for entry in entries {
            let path = entry.expect("readable `testimg` entry").path();
            if !path.is_file() {
                continue;
            }
            let img = image::open(&path)
                .unwrap_or_else(|err| panic!("failed to open {}: {err}", path.display()));

            for color_space in [ColorSpace::Rgb, ColorSpace::Oklab] {
                let settings = Settings {
                    color_space,
                    ..Settings::default()
                };

                let color = dominant_color(&img, &settings)
                    .unwrap_or_else(|| panic!("no dominant color for {}", path.display()));
                assert!(
                    color.iter().all(|c| (0.0..=1.0).contains(c)),
                    "{}: out-of-range color {color:?}",
                    path.display()
                );

                let colors = dominant_colors(&img, &settings);
                assert!(!colors.is_empty(), "{}: no colors", path.display());
                assert!(colors.iter().flatten().all(|c| (0.0..=1.0).contains(c)));
            }

            // In RGB mode the ranking key is `saturation` of the returned color itself,
            // so the output must be non-increasing in saturation.
            let rgb_settings = Settings {
                color_space: ColorSpace::Rgb,
                ..Settings::default()
            };
            let colors = dominant_colors(&img, &rgb_settings);
            assert!(colors
                .windows(2)
                .all(|pair| saturation(&pair[0]) >= saturation(&pair[1])));
        }
    }
}