use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashSet;
use std::sync::atomic::Ordering;
use ordered_float::OrderedFloat;
use palette::IntoColor;
use palette::Lab;
use palette::Srgb;
use palette::color_difference::EuclideanDistance;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::IntoParallelIterator;
use rayon::iter::IntoParallelRefIterator;
use rayon::iter::ParallelIterator;
use crate::BitPaletteBuilder;
use crate::PaletteBuilder;
use crate::kmeans::parallel_kmeans;
use crate::private;
/// Palette builder registered under the name `"Bit-Merge"`.
///
/// Its `build_palette` implementation feeds every pixel into a `BitPaletteBuilder` created with
/// `STAGE_1_PALETTE_SIZE`, reduces the non-empty buckets with `parallel_kmeans` using
/// `STAGE_2_PALETTE_SIZE`, and finally merges that result down to the requested palette size
/// with `agglomerative_merge`.
///
/// The type carries no data; it is only used through its associated items.
pub struct BitMergePaletteBuilder<
const STAGE_1_PALETTE_SIZE: usize = { 1 << 18 },
const STAGE_2_PALETTE_SIZE: usize = 512,
>;
// Marker implementation of the crate-private `Sealed` trait.
// NOTE(review): presumably `PaletteBuilder` requires `Sealed` so that it cannot be implemented
// outside this crate — confirm against the trait definitions.
impl<const STAGE_1_PALETTE_SIZE: usize, const STAGE_2_PALETTE_SIZE: usize> private::Sealed
for BitMergePaletteBuilder<STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
{
}
impl<const STAGE_1_PALETTE_SIZE: usize, const STAGE_2_PALETTE_SIZE: usize> PaletteBuilder
    for BitMergePaletteBuilder<STAGE_1_PALETTE_SIZE, STAGE_2_PALETTE_SIZE>
{
    const NAME: &'static str = "Bit-Merge";

    /// Builds a Lab palette for `image`, aiming for at most `palette_size` colors.
    ///
    /// Steps:
    /// 1. Every pixel is inserted (in parallel) into a `BitPaletteBuilder` created with
    ///    `STAGE_1_PALETTE_SIZE`.
    /// 2. Each bucket with a non-zero count becomes one weighted candidate: its per-channel sum
    ///    divided by its count (integer division), converted to Lab, weighted by the count.
    /// 3. The candidates are handed to `parallel_kmeans` together with `STAGE_2_PALETTE_SIZE`.
    /// 4. The resulting colors and counts are reduced by `agglomerative_merge`.
    fn build_palette(
        image: &image::RgbImage,
        palette_size: usize,
    ) -> Vec<palette::Lab> {
        let accumulator = BitPaletteBuilder::new(STAGE_1_PALETTE_SIZE);
        image.par_pixels().for_each(|px| {
            accumulator.insert(Srgb::<u8>::new(px[0], px[1], px[2]));
        });

        let weighted_colors = accumulator
            .buckets
            .into_par_iter()
            .filter_map(|bucket| {
                // All writers finished with the `for_each` above, so one load per counter is
                // enough.
                let population = bucket.count.load(Ordering::Relaxed);
                (population > 0).then(|| {
                    let red = (bucket.color.0.load(Ordering::Relaxed) / population) as u8;
                    let green = (bucket.color.1.load(Ordering::Relaxed) / population) as u8;
                    let blue = (bucket.color.2.load(Ordering::Relaxed) / population) as u8;
                    let mean: Lab = Srgb::new(red, green, blue).into_format().into_color();
                    (mean, population as f32)
                })
            })
            .collect::<Vec<_>>();

        let (mut centroids, mut weights) =
            parallel_kmeans(&weighted_colors, STAGE_2_PALETTE_SIZE);
        agglomerative_merge::<STAGE_2_PALETTE_SIZE>(&mut centroids, &mut weights, palette_size)
    }
}
/// Greedily merges weighted Lab colors until at most `out_size` of them remain.
///
/// `stage2_colors[i]` and `stage2_counts[i]` describe bucket `i`; the two slices are expected to
/// have the same length. Buckets whose count is exactly `0.0` are ignored. At every step the
/// pair of live buckets with the smallest merge cost
/// `(n1 * n2) / (n1 + n2) * distance_squared(c1, c2)` is merged: the count-weighted mean color
/// and the summed count are written into the first bucket of the pair and the second bucket is
/// retired.
///
/// Both slices are used as scratch space and are modified in place.
///
/// `IN_SIZE` is kept for call-site compatibility only; bookkeeping is sized from the slices
/// themselves, so inputs longer than `IN_SIZE` are accepted.
///
/// # Returns
///
/// The colors of the surviving buckets with exact duplicates removed, in ascending bucket order
/// (so the result is reproducible from run to run).
///
/// # Panics
///
/// * If `stage2_counts` is shorter than `stage2_colors` (index out of bounds).
/// * If `out_size` is `0` while at least one bucket has a non-zero count: the candidate queue
///   runs dry before the target is reached and the internal assertion fails.
pub(crate) fn agglomerative_merge<const IN_SIZE: usize>(
    stage2_colors: &mut [Lab],
    stage2_counts: &mut [f32],
    out_size: usize,
) -> Vec<Lab> {
    // Cost of merging two weighted colors; the operand order is the same at every call site so
    // the floating-point result does not depend on where it is computed.
    fn merge_cost(
        color_a: Lab,
        count_a: f32,
        color_b: Lab,
        count_b: f32,
    ) -> f32 {
        ((count_a * count_b) / (count_a + count_b)) * color_a.distance_squared(color_b)
    }

    // Generation value marking a bucket that is empty or has been merged away.
    const RETIRED: i32 = -1;

    let mut live_stage2_colors = stage2_colors.len();
    // One generation counter per bucket. A live bucket's counter is bumped each time it absorbs
    // another bucket, which invalidates every queue entry recorded against its old state.
    let mut bucket_generations =
        vec![0_i32; stage2_colors.len().max(stage2_counts.len())];
    for (idx, count) in stage2_counts.iter().enumerate() {
        if *count == 0.0 {
            live_stage2_colors -= 1;
            bucket_generations[idx] = RETIRED;
        }
    }

    // Score every unordered pair of live buckets in parallel.
    let entries = {
        let colors: &[Lab] = &*stage2_colors;
        let counts: &[f32] = &*stage2_counts;
        let generations: &[i32] = &bucket_generations;
        colors
            .par_iter()
            .copied()
            .enumerate()
            .filter(|(idx, _)| generations[*idx] >= 0)
            .flat_map(move |(idx, color)| {
                colors
                    .par_iter()
                    .enumerate()
                    .skip(idx + 1)
                    .filter(move |(jdx, _)| generations[*jdx] >= 0)
                    .map(move |(jdx, other)| PQueueEntry {
                        variance: Reverse(OrderedFloat(merge_cost(
                            color,
                            counts[idx],
                            *other,
                            counts[jdx],
                        ))),
                        idx1: (idx, generations[idx]),
                        idx2: (jdx, generations[jdx]),
                    })
            })
            .collect::<Vec<_>>()
    };

    let mut pqueue = BinaryHeap::with_capacity(entries.len());
    for entry in entries {
        pqueue.push(entry);
    }

    while live_stage2_colors > out_size {
        let Some(PQueueEntry {
            idx1: (idx1, gen1),
            idx2: (idx2, gen2),
            ..
        }) = pqueue.pop()
        else {
            // No candidate pair is left, so the target cannot be reached.
            assert!(live_stage2_colors <= out_size);
            break;
        };
        // Lazy deletion: skip entries that describe a bucket state that no longer exists.
        if bucket_generations[idx1] != gen1 || bucket_generations[idx2] != gen2 {
            continue;
        }

        // Fold bucket `idx2` into bucket `idx1` as a count-weighted mean.
        let (count1, count2) = (stage2_counts[idx1], stage2_counts[idx2]);
        let total = count1 + count2;
        let (color1, color2) = (stage2_colors[idx1], stage2_colors[idx2]);
        stage2_colors[idx1].l = (color1.l * count1 + color2.l * count2) / total;
        stage2_colors[idx1].a = (color1.a * count1 + color2.a * count2) / total;
        stage2_colors[idx1].b = (color1.b * count1 + color2.b * count2) / total;
        stage2_counts[idx1] = total;
        bucket_generations[idx1] += 1;
        bucket_generations[idx2] = RETIRED;
        live_stage2_colors -= 1;

        // Re-score the merged bucket against every other live bucket.
        for kdx in 0..stage2_colors.len() {
            if kdx == idx1 || bucket_generations[kdx] == RETIRED {
                continue;
            }
            pqueue.push(PQueueEntry {
                variance: Reverse(OrderedFloat(merge_cost(
                    stage2_colors[idx1],
                    stage2_counts[idx1],
                    stage2_colors[kdx],
                    stage2_counts[kdx],
                ))),
                idx1: (idx1, bucket_generations[idx1]),
                idx2: (kdx, bucket_generations[kdx]),
            });
        }
    }

    // Keep the first occurrence of each distinct color. The set is only a "seen" filter, so the
    // output order follows bucket order instead of the set's arbitrary iteration order.
    let mut seen = HashSet::new();
    stage2_colors
        .iter()
        .zip(bucket_generations.iter().copied())
        .filter(|(_, generation)| *generation >= 0)
        .filter(|(lab, _)| {
            seen.insert([
                OrderedFloat(lab.l),
                OrderedFloat(lab.a),
                OrderedFloat(lab.b),
            ])
        })
        .map(|(lab, _)| *lab)
        .collect::<Vec<_>>()
}
/// Candidate merge of two buckets, as queued in the `BinaryHeap` of `agglomerative_merge`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PQueueEntry {
/// Cost of the merge. Wrapped in `Reverse` so that comparisons are inverted and the max-heap
/// `BinaryHeap` yields the entry with the smallest cost first.
variance: Reverse<OrderedFloat<f32>>,
/// First bucket as `(index, generation at the time the entry was pushed)`.
idx1: (usize, i32),
/// Second bucket as `(index, generation at the time the entry was pushed)`.
idx2: (usize, i32),
}
impl Ord for PQueueEntry {
    /// Orders entries by `variance`, then by `idx1` and `idx2`.
    ///
    /// The index fields act as tie-breakers so that `cmp` returns `Equal` exactly when the
    /// derived `Eq` considers two entries equal, as the `Ord` contract requires. Entries with
    /// different variances compare exactly as they would on `variance` alone.
    fn cmp(
        &self,
        other: &Self,
    ) -> std::cmp::Ordering {
        self.variance
            .cmp(&other.variance)
            .then_with(|| self.idx1.cmp(&other.idx1))
            .then_with(|| self.idx2.cmp(&other.idx2))
    }
}
impl PartialOrd for PQueueEntry {
    /// Always returns `Some`: the result is taken from the total order defined by `Ord::cmp`,
    /// so the partial and total orderings cannot disagree.
    fn partial_cmp(
        &self,
        other: &Self,
    ) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}