1use std::{
8 cmp::Ordering,
9 collections::{
10 BinaryHeap,
11 HashSet,
12 },
13};
14
15use ndarray::Array2;
16use ordered_float::OrderedFloat;
17use palette::{
18 color_difference::EuclideanDistance,
19 Lab,
20};
21use rayon::{
22 iter::{
23 IntoParallelRefIterator,
24 ParallelIterator,
25 },
26 slice::ParallelSliceMut,
27};
28use rustyml::utility::principal_component_analysis::PCA;
29
30use crate::{
31 dither::Sierra,
32 private,
33 rgb_to_lab,
34 PaletteBuilder,
35 SixelEncoder,
36};
37
38pub type WuSixelEncoderMono<D = Sierra> = SixelEncoder<WuPaletteBuilder<2>, D>;
39pub type WuSixelEncoder4<D = Sierra> = SixelEncoder<WuPaletteBuilder<4>, D>;
40pub type WuSixelEncoder8<D = Sierra> = SixelEncoder<WuPaletteBuilder<8>, D>;
41pub type WuSixelEncoder16<D = Sierra> = SixelEncoder<WuPaletteBuilder<16>, D>;
42pub type WuSixelEncoder32<D = Sierra> = SixelEncoder<WuPaletteBuilder<32>, D>;
43pub type WuSixelEncoder64<D = Sierra> = SixelEncoder<WuPaletteBuilder<64>, D>;
44pub type WuSixelEncoder128<D = Sierra> = SixelEncoder<WuPaletteBuilder<128>, D>;
45pub type WuSixelEncoder256<D = Sierra> = SixelEncoder<WuPaletteBuilder<256>, D>;
46
47#[derive(Debug)]
48struct Hist {
49 points: Vec<Lab>,
50 mean: Lab,
51 variance: OrderedFloat<f32>,
52}
53
54impl Hist {
55 fn new(points: Vec<Lab>) -> Self {
56 let count = points.len() as f32;
57 let sum = points
58 .par_iter()
59 .copied()
60 .reduce(|| <Lab>::new(0.0, 0.0, 0.0), |acc, p| acc + p);
61 let mean = sum / count;
62
63 let variance = points
64 .par_iter()
65 .map(|p| p.distance_squared(mean))
66 .sum::<f32>()
67 / count;
68
69 Self {
70 points,
71 mean,
72 variance: OrderedFloat(variance),
73 }
74 }
75
76 fn split(&mut self) -> (Self, Self) {
77 let data = Array2::from_shape_fn((self.points.len(), 3), |(i, j)| match j {
78 0 => self.points[i].l as f64,
79 1 => self.points[i].a as f64,
80 2 => self.points[i].b as f64,
81 _ => unreachable!(),
82 });
83
84 let mut pca = PCA::new(3);
85
86 match pca.fit_transform(data.view()) {
87 Ok(projection) => {
88 let mut projections = projection
89 .column(0)
90 .into_iter()
91 .zip(self.points.iter())
92 .map(|(proj, point)| (*proj as f32, *point))
93 .collect::<Vec<_>>();
94 projections.par_sort_by_key(|(v, _)| OrderedFloat(*v));
95
96 let left = projections[..projections.len() / 2]
97 .iter()
98 .copied()
99 .map(|(_, p)| p)
100 .collect::<Vec<_>>();
101 let right = projections[projections.len() / 2..]
102 .iter()
103 .copied()
104 .map(|(_, p)| p)
105 .collect::<Vec<_>>();
106
107 (Self::new(left), Self::new(right))
108 }
109 Err(_) => {
110 let l_var = self
111 .points
112 .par_iter()
113 .map(|p| (p.l - self.mean.l).powi(2))
114 .sum::<f32>();
115
116 let a_var = self
117 .points
118 .par_iter()
119 .map(|p| (p.a - self.mean.a).powi(2))
120 .sum::<f32>();
121
122 let b_var = self
123 .points
124 .par_iter()
125 .map(|p| (p.b - self.mean.b).powi(2))
126 .sum::<f32>();
127
128 if l_var >= a_var && l_var >= b_var {
129 self.points.sort_by_key(|p| OrderedFloat(p.l));
130 } else if a_var >= b_var {
131 self.points.sort_by_key(|p| OrderedFloat(p.a));
132 } else {
133 self.points.sort_by_key(|p| OrderedFloat(p.b));
134 }
135
136 let left_points = self.points[..self.points.len() / 2].to_vec();
137 let right_points = self.points[self.points.len() / 2..].to_vec();
138
139 (Self::new(left_points), Self::new(right_points))
140 }
141 }
142 }
143}
144
145impl PartialEq for Hist {
146 fn eq(&self, other: &Self) -> bool {
147 self.variance == other.variance
148 }
149}
150
151impl Eq for Hist {}
152
153impl PartialOrd for Hist {
154 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
155 Some(self.cmp(other))
156 }
157}
158
159impl Ord for Hist {
160 fn cmp(&self, other: &Self) -> Ordering {
161 self.variance.cmp(&other.variance)
162 }
163}
164
165pub struct WuPaletteBuilder<const PALETTE_SIZE: usize>;
166
167impl<const PALETTE_SIZE: usize> private::Sealed for WuPaletteBuilder<PALETTE_SIZE> {}
168impl<const PALETTE_SIZE: usize> PaletteBuilder for WuPaletteBuilder<PALETTE_SIZE> {
169 const NAME: &'static str = "Wu";
170 const PALETTE_SIZE: usize = PALETTE_SIZE;
171
172 fn build_palette(image: &image::RgbImage) -> Vec<Lab> {
173 let lab_points: Vec<Lab> = image.pixels().copied().map(rgb_to_lab).collect();
174
175 let mut heap = BinaryHeap::new();
176 heap.push(Hist::new(lab_points));
177
178 while heap.len() < PALETTE_SIZE {
179 let Some(mut hist) = heap.pop() else {
180 break;
181 };
182
183 let (left, right) = hist.split();
184 if !left.points.is_empty() {
185 heap.push(left);
186 }
187 if !right.points.is_empty() {
188 heap.push(right);
189 }
190 }
191
192 heap.into_iter()
193 .map(|hist| {
194 [
195 OrderedFloat(hist.mean.l),
196 OrderedFloat(hist.mean.a),
197 OrderedFloat(hist.mean.b),
198 ]
199 })
200 .collect::<HashSet<_>>()
201 .into_iter()
202 .map(|[l, a, b]| Lab::new(*l, *a, *b))
203 .collect()
204 }
205}