swatchthis 0.2.0

Colour swatch extraction from images using k-means clustering
Documentation
use crate::color::Rgb;
use crate::sample_step;

/// Extracts up to `k` representative colours from `pixels` using median cut.
///
/// The input is first thinned by keeping every `sample_step(pixels.len())`-th
/// pixel. The sampled pixels start out in a single box; the box with the widest
/// channel range is then split in two, repeatedly, until there are `k` boxes or
/// every remaining box holds a single colour.
///
/// Returns one `(average colour, population)` pair per box. Populations count
/// *sampled* pixels, so they add up to the number of sampled pixels rather than
/// `pixels.len()` whenever the sampling step is greater than one. Fewer than
/// `k` entries are returned when the sample has fewer than `k` pixels or runs
/// out of boxes that can usefully be split. An empty slice or `k == 0` yields
/// an empty vector.
///
/// # Panics
///
/// Panics if `sample_step` returns 0, because `Iterator::step_by` rejects a
/// zero step.
pub fn extract_colors_median_cut(pixels: &[Rgb], k: usize) -> Vec<(Rgb, u32)> {
    if pixels.is_empty() || k == 0 {
        return Vec::new();
    }

    let step = sample_step(pixels.len());
    let sampled: Vec<Rgb> = pixels.iter().step_by(step).copied().collect();
    // Never ask for more boxes than there are sampled pixels.
    let k = k.min(sampled.len());

    let mut boxes = vec![ColorBox::from_pixels(&sampled)];

    while boxes.len() < k {
        // Only boxes with a non-zero range are candidates: splitting a box
        // whose pixels are all identical would just emit the same colour twice.
        // A non-zero range also implies the box holds at least two pixels.
        let widest = boxes
            .iter()
            .enumerate()
            .map(|(idx, b)| (idx, b.range()))
            .filter(|&(_, range)| range > 0)
            .max_by_key(|&(_, range)| range);

        let Some((split_idx, _)) = widest else {
            // Nothing left that can be meaningfully split.
            break;
        };

        let to_split = boxes.swap_remove(split_idx);
        let (a, b) = to_split.split();
        boxes.push(a);
        boxes.push(b);
    }

    boxes
        .iter()
        .filter(|b| b.count > 0)
        .map(|b| {
            // Clamp before narrowing so an oversized count saturates at
            // `u32::MAX` instead of silently wrapping.
            let population = b.count.min(u32::MAX as usize) as u32;
            (b.average(), population)
        })
        .collect()
}

/// A bucket of pixels that median cut treats as a single palette entry.
#[derive(Clone)]
struct ColorBox {
    /// The pixels currently assigned to this box.
    pixels: Vec<Rgb>,
    /// Number of pixels in the box; `from_pixels` sets it to `pixels.len()`.
    count: usize,
}

impl ColorBox {
    /// Builds a box holding a copy of `pixels`.
    fn from_pixels(pixels: &[Rgb]) -> Self {
        Self::from_vec(pixels.to_vec())
    }

    /// Builds a box that takes ownership of `pixels` without copying them.
    fn from_vec(pixels: Vec<Rgb>) -> Self {
        let count = pixels.len();
        Self { pixels, count }
    }

    /// Returns the spread (`max - min`) of the red, green and blue channels
    /// across the box, in that order. An empty box reports `(0, 0, 0)`.
    fn ranges(&self) -> (u8, u8, u8) {
        let (mut r_min, mut g_min, mut b_min) = (u8::MAX, u8::MAX, u8::MAX);
        let (mut r_max, mut g_max, mut b_max) = (u8::MIN, u8::MIN, u8::MIN);
        for p in &self.pixels {
            r_min = r_min.min(p.r);
            r_max = r_max.max(p.r);
            g_min = g_min.min(p.g);
            g_max = g_max.max(p.g);
            b_min = b_min.min(p.b);
            b_max = b_max.max(p.b);
        }
        // With no pixels the minima are still `u8::MAX` and the maxima 0;
        // saturating keeps that case at zero instead of overflowing.
        (
            r_max.saturating_sub(r_min),
            g_max.saturating_sub(g_min),
            b_max.saturating_sub(b_min),
        )
    }

    /// Returns the largest of the three per-channel spreads.
    fn range(&self) -> u8 {
        let (r, g, b) = self.ranges();
        r.max(g).max(b)
    }

    /// Consumes the box and cuts it in half along its widest channel.
    ///
    /// The pixels are sorted by the channel with the greatest spread (ties
    /// prefer red, then green) and divided at the midpoint of the pixel list.
    /// The first box receives the lower `len / 2` pixels, the second the rest,
    /// so a one-pixel box yields an empty first half.
    fn split(mut self) -> (ColorBox, ColorBox) {
        let (r_range, g_range, b_range) = self.ranges();

        if r_range >= g_range && r_range >= b_range {
            self.pixels.sort_unstable_by_key(|p| p.r);
        } else if g_range >= b_range {
            self.pixels.sort_unstable_by_key(|p| p.g);
        } else {
            self.pixels.sort_unstable_by_key(|p| p.b);
        }

        let mid = self.pixels.len() / 2;
        let upper = self.pixels.split_off(mid);

        // Both halves are already owned buffers, so move them into the new
        // boxes rather than cloning them again.
        (ColorBox::from_vec(self.pixels), ColorBox::from_vec(upper))
    }

    /// Returns the per-channel mean of the box's pixels, using integer
    /// division (the fractional part is discarded).
    ///
    /// # Panics
    ///
    /// Panics with a division by zero when `count` is 0; callers must skip
    /// empty boxes.
    fn average(&self) -> Rgb {
        let (mut r, mut g, mut b) = (0u64, 0u64, 0u64);
        for p in &self.pixels {
            r += u64::from(p.r);
            g += u64::from(p.g);
            b += u64::from(p.b);
        }
        let n = self.count as u64;
        // Each quotient is a mean of `u8` values, so it always fits in a `u8`.
        Rgb::new((r / n) as u8, (g / n) as u8, (b / n) as u8)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pixel buffer containing `count_each` copies of every colour in
    /// `colors`, grouped in the order given.
    fn make_pixels(colors: &[(u8, u8, u8)], count_each: usize) -> Vec<Rgb> {
        colors
            .iter()
            .flat_map(|&(r, g, b)| std::iter::repeat_n(Rgb::new(r, g, b), count_each))
            .collect()
    }

    /// Sums the population column of an extraction result.
    fn total_population(result: &[(Rgb, u32)]) -> u32 {
        result.iter().map(|(_, p)| p).sum()
    }

    #[test]
    fn empty_input() {
        assert!(extract_colors_median_cut(&[], 5).is_empty());
    }

    #[test]
    fn zero_k() {
        let pixels = vec![Rgb::new(255, 0, 0); 10];
        assert!(extract_colors_median_cut(&pixels, 0).is_empty());
    }

    // NOTE(review): the exact population checks below assume `sample_step`
    // returns 1 for inputs this small, i.e. no pixels are skipped — confirm.
    #[test]
    fn single_color() {
        let pixels = make_pixels(&[(42, 42, 42)], 50);
        let result = extract_colors_median_cut(&pixels, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, Rgb::new(42, 42, 42));
        assert_eq!(result[0].1, 50);
    }

    #[test]
    fn extracts_distinct_colors() {
        let pixels = make_pixels(&[(255, 0, 0), (0, 255, 0), (0, 0, 255)], 100);
        let result = extract_colors_median_cut(&pixels, 3);
        assert_eq!(result.len(), 3);
        // Compare every pair, not just neighbours: duplicates need not end up
        // next to each other in the output.
        for (i, &(first, _)) in result.iter().enumerate() {
            for &(second, _) in &result[i + 1..] {
                assert!(
                    first.distance_squared(second) > 0,
                    "got duplicate colours in result",
                );
            }
        }
    }

    #[test]
    fn population_sums_to_pixel_count() {
        let pixels = make_pixels(&[(200, 50, 50), (50, 200, 50), (50, 50, 200)], 80);
        let result = extract_colors_median_cut(&pixels, 3);
        assert_eq!(total_population(&result), 240);
    }

    #[test]
    fn k_larger_than_pixels() {
        let pixels = vec![Rgb::new(10, 20, 30), Rgb::new(40, 50, 60)];
        let result = extract_colors_median_cut(&pixels, 100);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn k_one() {
        let pixels = make_pixels(&[(100, 150, 200), (200, 100, 50)], 50);
        let result = extract_colors_median_cut(&pixels, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(total_population(&result), 100);
    }
}