use image::RgbaImage;
use ordered_float::OrderedFloat;
use palette::Lab;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::IntoParallelRefIterator;
use rayon::iter::IntoParallelRefMutIterator;
use rayon::iter::ParallelIterator;
use rayon::slice::ParallelSliceMut;
use crate::PaletteBuilder;
use crate::private;
use crate::rgba_to_lab;
/// Palette builder implementing median-cut color quantization.
///
/// A stateless unit type: the palette is produced by its
/// [`PaletteBuilder::build_palette`] implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct MedianCutPaletteBuilder;
// Marker impl of `crate::private::Sealed` (no items).
// NOTE(review): presumably the sealed-trait pattern, with `Sealed` as a supertrait of
// `PaletteBuilder` so only this crate can implement it — confirm against `crate::private`.
impl private::Sealed for MedianCutPaletteBuilder {}
impl PaletteBuilder for MedianCutPaletteBuilder {
    const NAME: &'static str = "Median-Cut";

    /// Builds a palette of at most `palette_size` colors with the median-cut algorithm,
    /// working in CIELAB space.
    ///
    /// Every pixel of `image` is converted with `rgba_to_lab` and placed in one bucket.
    /// While fewer than `palette_size` buckets exist, the splittable bucket with the
    /// widest channel spread is chosen. Each of L, a, b is normalised by that channel's
    /// nominal range before comparing. The chosen bucket is sorted along that channel
    /// and split at its median. Each palette entry is the mean color of one bucket.
    ///
    /// Edge cases:
    /// - `palette_size == 0`, or an image with no pixels, yields an empty palette.
    /// - Only buckets holding at least two pixels are split, so no bucket is ever empty
    ///   (an empty bucket's mean would be NaN). An image with fewer than `palette_size`
    ///   pixels therefore yields fewer than `palette_size` colors.
    fn build_palette(image: &RgbaImage, palette_size: usize) -> Vec<Lab> {
        let pixels: Vec<Lab> = image.pixels().copied().map(rgba_to_lab).collect();
        if palette_size == 0 || pixels.is_empty() {
            return Vec::new();
        }

        let mut buckets: Vec<Vec<Lab>> = Vec::with_capacity(palette_size);
        buckets.push(pixels);
        // Cached per-channel (min, max) of each bucket; `None` means "stale, recompute".
        // Buckets never exceed `palette_size`, so one slot per possible bucket suffices.
        let mut bounds: Vec<Option<(Lab, Lab)>> = vec![None; palette_size];

        while buckets.len() < palette_size {
            let (best, channel) = match select_bucket_to_split(&buckets, &mut bounds) {
                Some(choice) => choice,
                // Every bucket holds a single pixel: nothing left to split.
                None => break,
            };

            let bucket = &mut buckets[best];
            bucket.par_sort_by(|x, y| {
                channel_value(x, channel).total_cmp(&channel_value(y, channel))
            });
            // The chosen bucket has at least two pixels, so both halves are non-empty.
            let median = bucket.len() / 2;
            let upper = bucket.split_off(median);
            // The lower half's bounds changed; the new bucket's slot is still `None`.
            bounds[best] = None;
            buckets.push(upper);
        }

        buckets.iter().map(|bucket| mean_color(bucket)).collect()
    }
}

/// Picks the bucket to split next and the channel to split it along.
///
/// Only buckets with at least two pixels are candidates. Each candidate's spread along
/// L, a and b is normalised by that channel's nominal range (`Lab::min_*`/`max_*`).
/// The candidate with the largest normalised spread wins. Ties go to the lower bucket
/// index, so the choice does not depend on how rayon splits the work. Bounds missing
/// from `bounds` are computed and cached in place.
///
/// Returns `(bucket_index, channel)` with `channel` in `0..3` (L, a, b), or `None` when
/// no bucket can be split.
fn select_bucket_to_split(
    buckets: &[Vec<Lab>],
    bounds: &mut [Option<(Lab, Lab)>],
) -> Option<(usize, usize)> {
    buckets
        .par_iter()
        .zip(bounds.par_iter_mut())
        .enumerate()
        .filter_map(|(index, (bucket, cached))| {
            if bucket.len() < 2 {
                return None;
            }
            let (min, max) = *cached.get_or_insert_with(|| channel_bounds(bucket));
            let spread = [
                (max.l - min.l) / (<Lab>::max_l() - <Lab>::min_l()),
                (max.a - min.a) / (<Lab>::max_a() - <Lab>::min_a()),
                (max.b - min.b) / (<Lab>::max_b() - <Lab>::min_b()),
            ];
            // Equal channel spreads resolve to the later channel (`max_by_key` keeps the
            // last maximum).
            let channel = (0..spread.len())
                .max_by_key(|&c| OrderedFloat(spread[c]))
                .expect("spread has three channels");
            Some((index, channel, spread[channel]))
        })
        .max_by(|a, b| {
            OrderedFloat(a.2)
                .cmp(&OrderedFloat(b.2))
                // Equal spreads: the lower bucket index compares as greater, making the
                // ordering total and the result independent of reduction order.
                .then_with(|| b.0.cmp(&a.0))
        })
        .map(|(index, channel, _)| (index, channel))
}

/// Returns channel `index` of `color`: 0 = L, 1 = a, 2 = b.
fn channel_value(color: &Lab, index: usize) -> f32 {
    match index {
        0 => color.l,
        1 => color.a,
        2 => color.b,
        _ => unreachable!("Lab has only three channels"),
    }
}

/// Component-wise `(minimum, maximum)` over `colors`.
///
/// For an empty slice this returns `(f32::MAX, ..)` / `(f32::MIN, ..)`. It is only
/// called for buckets with at least two pixels.
fn channel_bounds(colors: &[Lab]) -> (Lab, Lab) {
    colors.iter().fold(
        (
            <Lab>::new(f32::MAX, f32::MAX, f32::MAX),
            <Lab>::new(f32::MIN, f32::MIN, f32::MIN),
        ),
        |(min, max), c| {
            (
                <Lab>::new(min.l.min(c.l), min.a.min(c.a), min.b.min(c.b)),
                <Lab>::new(max.l.max(c.l), max.a.max(c.a), max.b.max(c.b)),
            )
        },
    )
}

/// Arithmetic mean of `colors`, accumulated in `f64`.
///
/// Summing hundreds of thousands of pixels in `f32` loses precision once the running
/// total dwarfs each addend, which biases the mean. Using `f64` keeps that error
/// negligible. `colors` must be non-empty; an empty slice would divide by zero and
/// yield NaN.
fn mean_color(colors: &[Lab]) -> Lab {
    let (l, a, b) = colors
        .iter()
        .fold((0.0f64, 0.0f64, 0.0f64), |(l, a, b), c| {
            (l + f64::from(c.l), a + f64::from(c.a), b + f64::from(c.b))
        });
    let count = colors.len() as f64;
    <Lab>::new((l / count) as f32, (a / count) as f32, (b / count) as f32)
}