1use std::cell::RefCell;
6
7use dilate::DilateExpand;
8use image::Rgba;
9use palette::IntoColor;
10use palette::Lab;
11use palette::Srgb;
12use rayon::iter::ParallelIterator;
13
/// Bucketing scheme that groups colors by the leading bits of their
/// interleaved (Morton-ordered) 24-bit RGB code.
pub struct BitPaletteBuilder {
    /// How many low bits are dropped from the 24-bit code to form a bucket index.
    pub(crate) shift: usize,
}
22
23impl BitPaletteBuilder {
24 pub(crate) fn new(palette_size: usize) -> Self {
25 BitPaletteBuilder {
26 shift: Self::shift(palette_size),
27 }
28 }
29
30 pub(crate) fn shift(palette_size: usize) -> usize {
31 24 - palette_size.ilog2() as usize
32 }
33
34 pub(crate) fn index(
35 color: Srgb<u8>,
36 shift: usize,
37 ) -> usize {
38 let r = color.red.dilate_expand::<3>().value();
39 let g = color.green.dilate_expand::<3>().value();
40 let b = color.blue.dilate_expand::<3>().value();
41
42 let rgb = g << 2 | r << 1 | b;
46
47 (rgb >> shift) as usize
48 }
49
50 pub fn build_palette(
53 image: &image::RgbaImage,
54 palette_size: usize,
55 ) -> Vec<Lab> {
56 let builder = Self::new(palette_size);
57
58 thread_local! {
59 static PALETTE: RefCell<Vec<(u64, u64, u64, u64)>> = RefCell::default();
60 }
61
62 let pool = rayon::ThreadPoolBuilder::new()
63 .num_threads(rayon::current_num_threads())
64 .build()
65 .unwrap();
66
67 pool.install(|| {
68 image.par_pixels().for_each(|pixel| {
69 PALETTE.with_borrow_mut(|palette| {
70 palette.resize(palette_size, (0, 0, 0, 0));
71
72 let pixel = Srgb::<u8>::new(pixel[0], pixel[1], pixel[2]);
73 let index = Self::index(pixel, builder.shift);
74 palette[index].0 += pixel.red as u64;
75 palette[index].1 += pixel.green as u64;
76 palette[index].2 += pixel.blue as u64;
77 palette[index].3 += 1;
78 });
79 });
80 });
81
82 let per_thread_palettes = pool.broadcast(|_ctx| PALETTE.with_borrow_mut(std::mem::take));
83
84 let mut final_palette = vec![(0, 0, 0, 0); palette_size];
85 for palette in per_thread_palettes {
86 for (dest, src) in final_palette.iter_mut().zip(palette) {
87 dest.0 += src.0;
88 dest.1 += src.1;
89 dest.2 += src.2;
90 dest.3 += src.3;
91 }
92 }
93
94 final_palette
95 .into_iter()
96 .filter(|node| node.3 > 0)
97 .map(|node| {
98 let rgb = Srgb::new(
99 (node.0 / node.3) as u8,
100 (node.1 / node.3) as u8,
101 (node.2 / node.3) as u8,
102 );
103 rgb.into_format().into_color()
104 })
105 .collect::<Vec<_>>()
106 }
107
108 pub(crate) fn build_bucketer(
109 palette: &[Lab],
110 palette_size: usize,
111 ) -> BitPaletteBucketer {
112 BitPaletteBucketer::new(palette, palette_size)
113 }
114}
115
/// Lookup table that maps each color bucket to precomputed palette entries.
pub struct BitPaletteBucketer {
    /// One entry per bucket: `(entry, runner_up, entry_distance, runner_up_distance)`.
    /// The distances are squared Lab distances.
    lut: Vec<(usize, usize, f32, f32)>,
    /// Bit shift applied to the interleaved RGB code to obtain a bucket index.
    shift: usize,
}
126
127impl BitPaletteBucketer {
128 fn new(
133 palette: &[Lab],
134 palette_size: usize,
135 ) -> Self {
136 let shift = BitPaletteBuilder::shift(palette_size);
137 let n_bits = palette_size.ilog2() as usize;
138 let masks = morton_dim_masks(n_bits);
139 let all_mask = palette_size - 1;
140
141 let mut bucket_entries: Vec<Vec<usize>> = vec![vec![]; palette_size];
142 for (pi, &lab) in palette.iter().enumerate() {
143 let rgb: Srgb = lab.into_color();
144 let rgb: Srgb<u8> = rgb.into_format();
145 let idx = BitPaletteBuilder::index(rgb, shift);
146 bucket_entries[idx].push(pi);
147 }
148
149 let lut = (0..palette_size)
150 .map(|bucket_idx| {
151 if bucket_entries[bucket_idx].is_empty() {
152 return (0, 0, 0.0, 0.0);
153 }
154
155 let nearest = bucket_entries[bucket_idx][0];
156 let second = find_second_nearest(
157 nearest,
158 bucket_idx,
159 &bucket_entries,
160 palette,
161 masks,
162 all_mask,
163 );
164 (nearest, second.0, 0.0, second.1)
165 })
166 .collect();
167
168 Self { lut, shift }
169 }
170
171 pub(crate) fn nearest(
172 &self,
173 point: &[f32; 3],
174 ) -> usize {
175 let lab = Lab::new(point[0], point[1], point[2]);
176 let rgb: Srgb = lab.into_color();
177 let rgb: Srgb<u8> = rgb.into_format();
178 self.lut[BitPaletteBuilder::index(rgb, self.shift)].0
179 }
180
181 pub(crate) fn nearest_two(
182 &self,
183 point: Rgba<u8>,
184 ) -> [(usize, f32); 2] {
185 let (n1, n2, d1, d2) =
186 self.lut[BitPaletteBuilder::index(Srgb::new(point[0], point[1], point[2]), self.shift)];
187 [(n1, d1), (n2, d2)]
188 }
189
190 pub(crate) fn nearest_rgb(
191 &self,
192 pixel: Rgba<u8>,
193 ) -> usize {
194 let index = BitPaletteBuilder::index(Srgb::new(pixel[0], pixel[1], pixel[2]), self.shift);
195 self.lut[index].0
196 }
197}
198
/// Splits the low `n` bits of a Morton index among three dimensions.
///
/// Bits are assigned round-robin: bit `i` belongs to dimension `i % 3`.
/// Element `d` of the result is the mask of bits owned by dimension `d`.
fn morton_dim_masks(n: usize) -> [usize; 3] {
    std::array::from_fn(|dim| (dim..n).step_by(3).fold(0usize, |mask, bit| mask | 1 << bit))
}
208
/// Increments the coordinate stored in the `dim_mask` bits of Morton code `z`.
///
/// The bits outside `dim_mask` are left untouched. Returns `None` if the
/// coordinate is already at its maximum within `all_mask`.
fn morton_inc(z: usize, dim_mask: usize, all_mask: usize) -> Option<usize> {
    let others = all_mask & !dim_mask;
    // With every foreign bit forced to 1, the +1 carry ripples straight through
    // them, so only the coordinate's own bits are effectively incremented.
    let bumped = (z | others) + 1;
    (bumped <= all_mask).then(|| (bumped & dim_mask) | (z & others))
}
222
/// Decrements the coordinate stored in the `dim_mask` bits of Morton code `z`.
///
/// The bits outside `dim_mask` (within `all_mask`) are left untouched. Returns
/// `None` if the coordinate is already zero.
fn morton_dec(z: usize, dim_mask: usize, all_mask: usize) -> Option<usize> {
    let coord_bits = z & dim_mask;
    if coord_bits == 0 {
        return None;
    }
    let others = all_mask & !dim_mask;
    // Subtracting 1 from the isolated coordinate borrows across the zero gaps.
    // Re-masking discards the bits that belong to other dimensions.
    Some(((coord_bits - 1) & dim_mask) | (z & others))
}
236
237fn morton_neighbor(
240 z: usize,
241 deltas: [i32; 3],
242 masks: &[usize; 3],
243 all_mask: usize,
244) -> Option<usize> {
245 let mut result = z;
246 for (delta, &mask) in deltas.iter().zip(masks.iter()) {
247 match delta.signum() {
248 1 => result = morton_inc(result, mask, all_mask)?,
249 -1 => result = morton_dec(result, mask, all_mask)?,
250 _ => {}
251 }
252 }
253 Some(result)
254}
255
256fn find_second_nearest(
260 nearest: usize,
261 bucket_idx: usize,
262 bucket_entries: &[Vec<usize>],
263 palette: &[Lab],
264 masks: [usize; 3],
265 all_mask: usize,
266) -> (usize, f32) {
267 use palette::color_difference::EuclideanDistance;
268
269 let target = &palette[nearest];
270 let mut best = (0usize, f32::INFINITY);
271
272 for db in -1i32..=1 {
273 for dr in -1i32..=1 {
274 for dg in -1i32..=1 {
275 let nb = if db == 0 && dr == 0 && dg == 0 {
276 Some(bucket_idx)
277 } else {
278 morton_neighbor(bucket_idx, [db, dr, dg], &masks, all_mask)
279 };
280 if let Some(nb) = nb {
281 for &pi in &bucket_entries[nb] {
282 if pi == nearest {
283 continue;
284 }
285 let dist = target.distance_squared(palette[pi]);
286 if dist < best.1 {
287 best = (pi, dist);
288 }
289 }
290 }
291 }
292 }
293 }
294
295 if best.1.is_finite() {
296 return best;
297 }
298
299 for (pi, color) in palette.iter().enumerate() {
300 if pi == nearest {
301 continue;
302 }
303 let dist = target.distance_squared(*color);
304 if dist < best.1 {
305 best = (pi, dist);
306 }
307 }
308
309 best
310}