a_sixel/
bitmerge.rs

//! Builds a palette by bit-bucketing the image (the `BitPaletteBuilder` stage
//! behind `BitSixelEncoder`), refining the buckets with k-means, and finishing
//! with agglomerative merging.
//!
//! This encoder offers the best tradeoff between speed and quality. You can
//! tune the parameters to run nearly as fast as `BitSixelEncoder` while
//! producing superior results, or to produce results as good as or better than
//! `KMeansSixelEncoder` while being much faster.
//!
//! The default parameters are tuned to produce results similar to k-means,
//! while being ~5x faster.
//!
//! # Scaling
//! - `STAGE_1_PALETTE_SIZE`: The target size of the palette produced by the
//!   first bit-bucketing pass. These buckets are then passed into k-means.
//!   Time taken scales roughly linearly with this value.
//! - `STAGE_2_PALETTE_SIZE`: The target size of the palette produced by the
//!   k-means clustering. That palette then goes through variance-minimizing
//!   agglomerative merging to produce the final palette. Time taken scales
//!   **quadratically** with this value.
21
22use std::{
23    cmp::Reverse,
24    collections::{
25        BinaryHeap,
26        HashSet,
27    },
28    sync::atomic::Ordering,
29};
30
31use ordered_float::OrderedFloat;
32use palette::{
33    color_difference::EuclideanDistance,
34    IntoColor,
35    Lab,
36    Srgb,
37};
38use rayon::iter::{
39    IndexedParallelIterator,
40    IntoParallelIterator,
41    IntoParallelRefIterator,
42    ParallelIterator,
43};
44
45use crate::{
46    dither::Sierra,
47    kmeans::parallel_kmeans,
48    private,
49    BitPaletteBuilder,
50    PaletteBuilder,
51    SixelEncoder,
52};
53
54pub type BitMergeSixelEncoderMono<D = Sierra, const LARGE: usize = { 1 << 18 }> =
55    SixelEncoder<BitMergePaletteBuilder<2, LARGE>, D>;
56pub type BitMergeSixelEncoder4<D = Sierra, const LARGE: usize = { 1 << 18 }> =
57    SixelEncoder<BitMergePaletteBuilder<4, LARGE>, D>;
58pub type BitMergeSixelEncoder8<D = Sierra, const LARGE: usize = { 1 << 18 }> =
59    SixelEncoder<BitMergePaletteBuilder<8, LARGE>, D>;
60pub type BitMergeSixelEncoder16<D = Sierra, const LARGE: usize = { 1 << 18 }> =
61    SixelEncoder<BitMergePaletteBuilder<16, LARGE>, D>;
62pub type BitMergeSixelEncoder32<D = Sierra, const LARGE: usize = { 1 << 18 }> =
63    SixelEncoder<BitMergePaletteBuilder<32, LARGE>, D>;
64pub type BitMergeSixelEncoder64<D = Sierra, const LARGE: usize = { 1 << 18 }> =
65    SixelEncoder<BitMergePaletteBuilder<64, LARGE>, D>;
66pub type BitMergeSixelEncoder128<D = Sierra, const LARGE: usize = { 1 << 18 }> =
67    SixelEncoder<BitMergePaletteBuilder<128, LARGE>, D>;
68pub type BitMergeSixelEncoder256<D = Sierra, const LARGE: usize = { 1 << 18 }> =
69    SixelEncoder<BitMergePaletteBuilder<256, LARGE>, D>;
70
/// Field-less marker type whose const parameters configure the bit-merge
/// palette pipeline.
///
/// - `TARGET_PALETTE_SIZE`: upper bound on the number of colors in the final
///   palette.
/// - `STAGE_1_PALETTE_SIZE`: size parameter handed to the bit-bucketing stage.
/// - `STAGE_2_PALETTE_SIZE`: size parameter handed to the k-means stage, whose
///   output is then merged down to the target size. Defaults to `512`.
pub struct BitMergePaletteBuilder<
    const TARGET_PALETTE_SIZE: usize,
    const STAGE_1_PALETTE_SIZE: usize,
    const STAGE_2_PALETTE_SIZE: usize = 512,
>;
76
77impl<
78        const TARGET_PALETTE_SIZE: usize,
79        const STAGE_1_PALETTE_SIZE: usize,
80        const STAGE_2_PALETTE_SIZE: usize,
81    > private::Sealed
82    for BitMergePaletteBuilder<TARGET_PALETTE_SIZE, STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
83{
84}
85
86impl<
87        const TARGET_PALETTE_SIZE: usize,
88        const STAGE_1_PALETTE_SIZE: usize,
89        const STAGE_2_PALETTE_SIZE: usize,
90    > PaletteBuilder
91    for BitMergePaletteBuilder<TARGET_PALETTE_SIZE, STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
92{
93    const NAME: &'static str = "Bit-Merge";
94    const PALETTE_SIZE: usize = TARGET_PALETTE_SIZE;
95
96    fn build_palette(image: &image::RgbImage) -> Vec<palette::Lab> {
97        let bit = BitPaletteBuilder::<STAGE_1_PALETTE_SIZE>::new();
98        image.par_pixels().for_each(|pixel| {
99            bit.insert(palette::Srgb::<u8>::new(pixel[0], pixel[1], pixel[2]));
100        });
101
102        let candidates = bit
103            .buckets
104            .into_par_iter()
105            .filter_map(|bucket| {
106                if bucket.count.load(Ordering::Relaxed) > 0 {
107                    let lab: Lab = Srgb::new(
108                        (bucket.color.0.load(Ordering::Relaxed)
109                            / bucket.count.load(Ordering::Relaxed)) as u8,
110                        (bucket.color.1.load(Ordering::Relaxed)
111                            / bucket.count.load(Ordering::Relaxed)) as u8,
112                        (bucket.color.2.load(Ordering::Relaxed)
113                            / bucket.count.load(Ordering::Relaxed)) as u8,
114                    )
115                    .into_format()
116                    .into_color();
117                    Some((lab, bucket.count.load(Ordering::Relaxed) as f32))
118                } else {
119                    None
120                }
121            })
122            .collect::<Vec<_>>();
123
124        let (mut stage2_colors, mut stage2_counts) =
125            parallel_kmeans::<STAGE_2_PALETTE_SIZE>(&candidates);
126
127        agglomerative_merge::<STAGE_2_PALETTE_SIZE, TARGET_PALETTE_SIZE>(
128            &mut stage2_colors,
129            &mut stage2_counts,
130        )
131    }
132}
133
134pub(crate) fn agglomerative_merge<const IN_SIZE: usize, const OUT_SIZE: usize>(
135    stage2_colors: &mut [Lab],
136    stage2_counts: &mut [f32],
137) -> Vec<Lab> {
138    let mut live_stage2_colors = stage2_colors.len();
139    let mut bucket_generations = [0; IN_SIZE];
140    for (idx, count) in stage2_counts.iter().enumerate() {
141        if *count == 0.0 {
142            live_stage2_colors -= 1;
143            bucket_generations[idx] = -1;
144        }
145    }
146    let mut pqueue = BinaryHeap::new();
147
148    let entries = stage2_colors
149        .par_iter()
150        .copied()
151        .enumerate()
152        .filter(|(idx, _)| bucket_generations[*idx] >= 0)
153        .flat_map(|(idx, b_color)| {
154            let bucket_generations = &bucket_generations;
155            let stage2_counts = &stage2_counts;
156            stage2_colors
157                .par_iter()
158                .enumerate()
159                .skip(idx + 1)
160                .filter(move |(jdx, _)| bucket_generations[*jdx] >= 0)
161                .map(move |(jdx, b2_color)| {
162                    let merged_var = ((stage2_counts[idx] * stage2_counts[jdx])
163                        / (stage2_counts[idx] + stage2_counts[jdx]))
164                        * (b_color.distance_squared(*b2_color));
165
166                    PQueueEntry {
167                        variance: Reverse(OrderedFloat(merged_var)),
168                        idx1: (idx, bucket_generations[idx]),
169                        idx2: (jdx, bucket_generations[jdx]),
170                    }
171                })
172        })
173        .collect::<Vec<_>>();
174
175    for entry in entries {
176        pqueue.push(entry);
177    }
178
179    while live_stage2_colors > OUT_SIZE {
180        let Some(PQueueEntry {
181            idx1: (idx1, gen1),
182            idx2: (idx2, gen2),
183            ..
184        }) = pqueue.pop()
185        else {
186            assert!(live_stage2_colors <= OUT_SIZE);
187            break;
188        };
189        if bucket_generations[idx1] != gen1 || bucket_generations[idx2] != gen2 {
190            continue;
191        }
192
193        let l = (stage2_colors[idx1].l * stage2_counts[idx1]
194            + stage2_colors[idx2].l * stage2_counts[idx2])
195            / (stage2_counts[idx1] + stage2_counts[idx2]);
196        let a = (stage2_colors[idx1].a * stage2_counts[idx1]
197            + stage2_colors[idx2].a * stage2_counts[idx2])
198            / (stage2_counts[idx1] + stage2_counts[idx2]);
199        let b = (stage2_colors[idx1].b * stage2_counts[idx1]
200            + stage2_colors[idx2].b * stage2_counts[idx2])
201            / (stage2_counts[idx1] + stage2_counts[idx2]);
202
203        stage2_colors[idx1].l = l;
204        stage2_colors[idx1].a = a;
205        stage2_colors[idx1].b = b;
206
207        stage2_counts[idx1] += stage2_counts[idx2];
208
209        bucket_generations[idx1] += 1;
210        bucket_generations[idx2] = -1;
211        live_stage2_colors -= 1;
212
213        for kdx in 0..stage2_colors.len() {
214            if kdx == idx1 || bucket_generations[kdx] == -1 {
215                continue;
216            }
217
218            let merged_var = ((stage2_counts[idx1] * stage2_counts[kdx])
219                / (stage2_counts[idx1] + stage2_counts[kdx]))
220                * (stage2_colors[idx1].distance_squared(stage2_colors[kdx]));
221
222            pqueue.push(PQueueEntry {
223                variance: Reverse(OrderedFloat(merged_var)),
224                idx1: (idx1, bucket_generations[idx1]),
225                idx2: (kdx, bucket_generations[kdx]),
226            });
227        }
228    }
229
230    stage2_colors
231        .iter()
232        .zip(bucket_generations)
233        .filter(|(_, generation)| *generation >= 0)
234        .map(|(lab, _)| {
235            [
236                OrderedFloat(lab.l),
237                OrderedFloat(lab.a),
238                OrderedFloat(lab.b),
239            ]
240        })
241        .collect::<HashSet<_>>()
242        .into_iter()
243        .map(|[l, a, b]| Lab::new(*l, *a, *b))
244        .collect::<Vec<_>>()
245}
246
247#[derive(Debug, Clone, Copy, PartialEq, Eq)]
248struct PQueueEntry {
249    variance: Reverse<OrderedFloat<f32>>,
250    idx1: (usize, i32),
251    idx2: (usize, i32),
252}
253
254impl Ord for PQueueEntry {
255    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
256        self.variance.cmp(&other.variance)
257    }
258}
259
260impl PartialOrd for PQueueEntry {
261    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
262        Some(self.cmp(other))
263    }
264}