a_sixel/
wu.rs

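//! Builds a palette from an image with a variant of Wu's color quantization.
//!
//! Every pixel is converted to Lab and placed in a single cluster. The cluster
//! with the highest variance is then repeatedly split in half at the median of
//! its projection onto the first principal component (PCA), or along the L, a or
//! b axis with the largest spread if PCA fails, until the desired palette size is
//! reached. The mean of each resulting cluster becomes a palette color.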
6
7use std::{
8    cmp::Ordering,
9    collections::{
10        BinaryHeap,
11        HashSet,
12    },
13};
14
15use ndarray::Array2;
16use ordered_float::OrderedFloat;
17use palette::{
18    color_difference::EuclideanDistance,
19    Lab,
20};
21use rayon::{
22    iter::{
23        IntoParallelRefIterator,
24        ParallelIterator,
25    },
26    slice::ParallelSliceMut,
27};
28use rustyml::utility::principal_component_analysis::PCA;
29
30use crate::{
31    dither::Sierra,
32    private,
33    rgb_to_lab,
34    PaletteBuilder,
35    SixelEncoder,
36};
37
38pub type WuSixelEncoderMono<D = Sierra> = SixelEncoder<WuPaletteBuilder<2>, D>;
39pub type WuSixelEncoder4<D = Sierra> = SixelEncoder<WuPaletteBuilder<4>, D>;
40pub type WuSixelEncoder8<D = Sierra> = SixelEncoder<WuPaletteBuilder<8>, D>;
41pub type WuSixelEncoder16<D = Sierra> = SixelEncoder<WuPaletteBuilder<16>, D>;
42pub type WuSixelEncoder32<D = Sierra> = SixelEncoder<WuPaletteBuilder<32>, D>;
43pub type WuSixelEncoder64<D = Sierra> = SixelEncoder<WuPaletteBuilder<64>, D>;
44pub type WuSixelEncoder128<D = Sierra> = SixelEncoder<WuPaletteBuilder<128>, D>;
45pub type WuSixelEncoder256<D = Sierra> = SixelEncoder<WuPaletteBuilder<256>, D>;
46
47#[derive(Debug)]
48struct Hist {
49    points: Vec<Lab>,
50    mean: Lab,
51    variance: OrderedFloat<f32>,
52}
53
54impl Hist {
55    fn new(points: Vec<Lab>) -> Self {
56        let count = points.len() as f32;
57        let sum = points
58            .par_iter()
59            .copied()
60            .reduce(|| <Lab>::new(0.0, 0.0, 0.0), |acc, p| acc + p);
61        let mean = sum / count;
62
63        let variance = points
64            .par_iter()
65            .map(|p| p.distance_squared(mean))
66            .sum::<f32>()
67            / count;
68
69        Self {
70            points,
71            mean,
72            variance: OrderedFloat(variance),
73        }
74    }
75
76    fn split(&mut self) -> (Self, Self) {
77        let data = Array2::from_shape_fn((self.points.len(), 3), |(i, j)| match j {
78            0 => self.points[i].l as f64,
79            1 => self.points[i].a as f64,
80            2 => self.points[i].b as f64,
81            _ => unreachable!(),
82        });
83
84        let mut pca = PCA::new(3);
85
86        match pca.fit_transform(data.view()) {
87            Ok(projection) => {
88                let mut projections = projection
89                    .column(0)
90                    .into_iter()
91                    .zip(self.points.iter())
92                    .map(|(proj, point)| (*proj as f32, *point))
93                    .collect::<Vec<_>>();
94                projections.par_sort_by_key(|(v, _)| OrderedFloat(*v));
95
96                let left = projections[..projections.len() / 2]
97                    .iter()
98                    .copied()
99                    .map(|(_, p)| p)
100                    .collect::<Vec<_>>();
101                let right = projections[projections.len() / 2..]
102                    .iter()
103                    .copied()
104                    .map(|(_, p)| p)
105                    .collect::<Vec<_>>();
106
107                (Self::new(left), Self::new(right))
108            }
109            Err(_) => {
110                let l_var = self
111                    .points
112                    .par_iter()
113                    .map(|p| (p.l - self.mean.l).powi(2))
114                    .sum::<f32>();
115
116                let a_var = self
117                    .points
118                    .par_iter()
119                    .map(|p| (p.a - self.mean.a).powi(2))
120                    .sum::<f32>();
121
122                let b_var = self
123                    .points
124                    .par_iter()
125                    .map(|p| (p.b - self.mean.b).powi(2))
126                    .sum::<f32>();
127
128                if l_var >= a_var && l_var >= b_var {
129                    self.points.sort_by_key(|p| OrderedFloat(p.l));
130                } else if a_var >= b_var {
131                    self.points.sort_by_key(|p| OrderedFloat(p.a));
132                } else {
133                    self.points.sort_by_key(|p| OrderedFloat(p.b));
134                }
135
136                let left_points = self.points[..self.points.len() / 2].to_vec();
137                let right_points = self.points[self.points.len() / 2..].to_vec();
138
139                (Self::new(left_points), Self::new(right_points))
140            }
141        }
142    }
143}
144
145impl PartialEq for Hist {
146    fn eq(&self, other: &Self) -> bool {
147        self.variance == other.variance
148    }
149}
150
151impl Eq for Hist {}
152
153impl PartialOrd for Hist {
154    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
155        Some(self.cmp(other))
156    }
157}
158
159impl Ord for Hist {
160    fn cmp(&self, other: &Self) -> Ordering {
161        self.variance.cmp(&other.variance)
162    }
163}
164
165pub struct WuPaletteBuilder<const PALETTE_SIZE: usize>;
166
167impl<const PALETTE_SIZE: usize> private::Sealed for WuPaletteBuilder<PALETTE_SIZE> {}
168impl<const PALETTE_SIZE: usize> PaletteBuilder for WuPaletteBuilder<PALETTE_SIZE> {
169    const NAME: &'static str = "Wu";
170    const PALETTE_SIZE: usize = PALETTE_SIZE;
171
172    fn build_palette(image: &image::RgbImage) -> Vec<Lab> {
173        let lab_points: Vec<Lab> = image.pixels().copied().map(rgb_to_lab).collect();
174
175        let mut heap = BinaryHeap::new();
176        heap.push(Hist::new(lab_points));
177
178        while heap.len() < PALETTE_SIZE {
179            let Some(mut hist) = heap.pop() else {
180                break;
181            };
182
183            let (left, right) = hist.split();
184            if !left.points.is_empty() {
185                heap.push(left);
186            }
187            if !right.points.is_empty() {
188                heap.push(right);
189            }
190        }
191
192        heap.into_iter()
193            .map(|hist| {
194                [
195                    OrderedFloat(hist.mean.l),
196                    OrderedFloat(hist.mean.a),
197                    OrderedFloat(hist.mean.b),
198                ]
199            })
200            .collect::<HashSet<_>>()
201            .into_iter()
202            .map(|[l, a, b]| Lab::new(*l, *a, *b))
203            .collect()
204    }
205}