1use std::{
23 cmp::Reverse,
24 collections::{
25 BinaryHeap,
26 HashSet,
27 },
28 sync::atomic::Ordering,
29};
30
31use ordered_float::OrderedFloat;
32use palette::{
33 color_difference::EuclideanDistance,
34 IntoColor,
35 Lab,
36 Srgb,
37};
38use rayon::iter::{
39 IndexedParallelIterator,
40 IntoParallelIterator,
41 IntoParallelRefIterator,
42 ParallelIterator,
43};
44
45use crate::{
46 dither::Sierra,
47 kmeans::parallel_kmeans,
48 private,
49 BitPaletteBuilder,
50 PaletteBuilder,
51 SixelEncoder,
52};
53
54pub type BitMergeSixelEncoderMono<D = Sierra, const LARGE: usize = { 1 << 18 }> =
55 SixelEncoder<BitMergePaletteBuilder<2, LARGE>, D>;
56pub type BitMergeSixelEncoder4<D = Sierra, const LARGE: usize = { 1 << 18 }> =
57 SixelEncoder<BitMergePaletteBuilder<4, LARGE>, D>;
58pub type BitMergeSixelEncoder8<D = Sierra, const LARGE: usize = { 1 << 18 }> =
59 SixelEncoder<BitMergePaletteBuilder<8, LARGE>, D>;
60pub type BitMergeSixelEncoder16<D = Sierra, const LARGE: usize = { 1 << 18 }> =
61 SixelEncoder<BitMergePaletteBuilder<16, LARGE>, D>;
62pub type BitMergeSixelEncoder32<D = Sierra, const LARGE: usize = { 1 << 18 }> =
63 SixelEncoder<BitMergePaletteBuilder<32, LARGE>, D>;
64pub type BitMergeSixelEncoder64<D = Sierra, const LARGE: usize = { 1 << 18 }> =
65 SixelEncoder<BitMergePaletteBuilder<64, LARGE>, D>;
66pub type BitMergeSixelEncoder128<D = Sierra, const LARGE: usize = { 1 << 18 }> =
67 SixelEncoder<BitMergePaletteBuilder<128, LARGE>, D>;
68pub type BitMergeSixelEncoder256<D = Sierra, const LARGE: usize = { 1 << 18 }> =
69 SixelEncoder<BitMergePaletteBuilder<256, LARGE>, D>;
70
/// Zero-sized palette builder that narrows an image's colors down in stages.
///
/// The const parameters size each stage:
///
/// * `TARGET_PALETTE_SIZE` - upper bound on the number of colors in the final palette.
/// * `STAGE_1_PALETTE_SIZE` - size parameter of the `BitPaletteBuilder` that first buckets the
///   image's pixels.
/// * `STAGE_2_PALETTE_SIZE` - size parameter handed to `parallel_kmeans` and to
///   `agglomerative_merge` as its input bound; defaults to `512`.
pub struct BitMergePaletteBuilder<
    const TARGET_PALETTE_SIZE: usize,
    const STAGE_1_PALETTE_SIZE: usize,
    const STAGE_2_PALETTE_SIZE: usize = 512,
>;
76
77impl<
78 const TARGET_PALETTE_SIZE: usize,
79 const STAGE_1_PALETTE_SIZE: usize,
80 const STAGE_2_PALETTE_SIZE: usize,
81 > private::Sealed
82 for BitMergePaletteBuilder<TARGET_PALETTE_SIZE, STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
83{
84}
85
86impl<
87 const TARGET_PALETTE_SIZE: usize,
88 const STAGE_1_PALETTE_SIZE: usize,
89 const STAGE_2_PALETTE_SIZE: usize,
90 > PaletteBuilder
91 for BitMergePaletteBuilder<TARGET_PALETTE_SIZE, STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
92{
93 const NAME: &'static str = "Bit-Merge";
94 const PALETTE_SIZE: usize = TARGET_PALETTE_SIZE;
95
96 fn build_palette(image: &image::RgbImage) -> Vec<palette::Lab> {
97 let bit = BitPaletteBuilder::<STAGE_1_PALETTE_SIZE>::new();
98 image.par_pixels().for_each(|pixel| {
99 bit.insert(palette::Srgb::<u8>::new(pixel[0], pixel[1], pixel[2]));
100 });
101
102 let candidates = bit
103 .buckets
104 .into_par_iter()
105 .filter_map(|bucket| {
106 if bucket.count.load(Ordering::Relaxed) > 0 {
107 let lab: Lab = Srgb::new(
108 (bucket.color.0.load(Ordering::Relaxed)
109 / bucket.count.load(Ordering::Relaxed)) as u8,
110 (bucket.color.1.load(Ordering::Relaxed)
111 / bucket.count.load(Ordering::Relaxed)) as u8,
112 (bucket.color.2.load(Ordering::Relaxed)
113 / bucket.count.load(Ordering::Relaxed)) as u8,
114 )
115 .into_format()
116 .into_color();
117 Some((lab, bucket.count.load(Ordering::Relaxed) as f32))
118 } else {
119 None
120 }
121 })
122 .collect::<Vec<_>>();
123
124 let (mut stage2_colors, mut stage2_counts) =
125 parallel_kmeans::<STAGE_2_PALETTE_SIZE>(&candidates);
126
127 agglomerative_merge::<STAGE_2_PALETTE_SIZE, TARGET_PALETTE_SIZE>(
128 &mut stage2_colors,
129 &mut stage2_counts,
130 )
131 }
132}
133
134pub(crate) fn agglomerative_merge<const IN_SIZE: usize, const OUT_SIZE: usize>(
135 stage2_colors: &mut [Lab],
136 stage2_counts: &mut [f32],
137) -> Vec<Lab> {
138 let mut live_stage2_colors = stage2_colors.len();
139 let mut bucket_generations = [0; IN_SIZE];
140 for (idx, count) in stage2_counts.iter().enumerate() {
141 if *count == 0.0 {
142 live_stage2_colors -= 1;
143 bucket_generations[idx] = -1;
144 }
145 }
146 let mut pqueue = BinaryHeap::new();
147
148 let entries = stage2_colors
149 .par_iter()
150 .copied()
151 .enumerate()
152 .filter(|(idx, _)| bucket_generations[*idx] >= 0)
153 .flat_map(|(idx, b_color)| {
154 let bucket_generations = &bucket_generations;
155 let stage2_counts = &stage2_counts;
156 stage2_colors
157 .par_iter()
158 .enumerate()
159 .skip(idx + 1)
160 .filter(move |(jdx, _)| bucket_generations[*jdx] >= 0)
161 .map(move |(jdx, b2_color)| {
162 let merged_var = ((stage2_counts[idx] * stage2_counts[jdx])
163 / (stage2_counts[idx] + stage2_counts[jdx]))
164 * (b_color.distance_squared(*b2_color));
165
166 PQueueEntry {
167 variance: Reverse(OrderedFloat(merged_var)),
168 idx1: (idx, bucket_generations[idx]),
169 idx2: (jdx, bucket_generations[jdx]),
170 }
171 })
172 })
173 .collect::<Vec<_>>();
174
175 for entry in entries {
176 pqueue.push(entry);
177 }
178
179 while live_stage2_colors > OUT_SIZE {
180 let Some(PQueueEntry {
181 idx1: (idx1, gen1),
182 idx2: (idx2, gen2),
183 ..
184 }) = pqueue.pop()
185 else {
186 assert!(live_stage2_colors <= OUT_SIZE);
187 break;
188 };
189 if bucket_generations[idx1] != gen1 || bucket_generations[idx2] != gen2 {
190 continue;
191 }
192
193 let l = (stage2_colors[idx1].l * stage2_counts[idx1]
194 + stage2_colors[idx2].l * stage2_counts[idx2])
195 / (stage2_counts[idx1] + stage2_counts[idx2]);
196 let a = (stage2_colors[idx1].a * stage2_counts[idx1]
197 + stage2_colors[idx2].a * stage2_counts[idx2])
198 / (stage2_counts[idx1] + stage2_counts[idx2]);
199 let b = (stage2_colors[idx1].b * stage2_counts[idx1]
200 + stage2_colors[idx2].b * stage2_counts[idx2])
201 / (stage2_counts[idx1] + stage2_counts[idx2]);
202
203 stage2_colors[idx1].l = l;
204 stage2_colors[idx1].a = a;
205 stage2_colors[idx1].b = b;
206
207 stage2_counts[idx1] += stage2_counts[idx2];
208
209 bucket_generations[idx1] += 1;
210 bucket_generations[idx2] = -1;
211 live_stage2_colors -= 1;
212
213 for kdx in 0..stage2_colors.len() {
214 if kdx == idx1 || bucket_generations[kdx] == -1 {
215 continue;
216 }
217
218 let merged_var = ((stage2_counts[idx1] * stage2_counts[kdx])
219 / (stage2_counts[idx1] + stage2_counts[kdx]))
220 * (stage2_colors[idx1].distance_squared(stage2_colors[kdx]));
221
222 pqueue.push(PQueueEntry {
223 variance: Reverse(OrderedFloat(merged_var)),
224 idx1: (idx1, bucket_generations[idx1]),
225 idx2: (kdx, bucket_generations[kdx]),
226 });
227 }
228 }
229
230 stage2_colors
231 .iter()
232 .zip(bucket_generations)
233 .filter(|(_, generation)| *generation >= 0)
234 .map(|(lab, _)| {
235 [
236 OrderedFloat(lab.l),
237 OrderedFloat(lab.a),
238 OrderedFloat(lab.b),
239 ]
240 })
241 .collect::<HashSet<_>>()
242 .into_iter()
243 .map(|[l, a, b]| Lab::new(*l, *a, *b))
244 .collect::<Vec<_>>()
245}
246
247#[derive(Debug, Clone, Copy, PartialEq, Eq)]
248struct PQueueEntry {
249 variance: Reverse<OrderedFloat<f32>>,
250 idx1: (usize, i32),
251 idx2: (usize, i32),
252}
253
254impl Ord for PQueueEntry {
255 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
256 self.variance.cmp(&other.variance)
257 }
258}
259
260impl PartialOrd for PQueueEntry {
261 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
262 Some(self.cmp(other))
263 }
264}