imagequant/
mediancut.rs

1use crate::hist::{HistItem, HistogramInternal};
2use crate::pal::{f_pixel, PalF, PalLen, PalPop, ARGBF};
3use crate::quant::quality_to_mse;
4use crate::PushInCapacity;
5use crate::{Error, OrdFloat};
6use rgb::prelude::*;
7use std::cmp::Reverse;
8
/// State of one median-cut run: the current set of boxes plus the limits that
/// decide when to stop splitting them.
struct MedianCutter<'hist> {
    /// Boxes produced so far; each one becomes a palette entry in `into_palette`.
    boxes: Vec<MBox<'hist>>,
    /// Copy of the histogram's `total_perceptual_weight`; `total_box_error_below_target`
    /// multiplies the target MSE by it before comparing against summed box errors.
    hist_total_perceptual_weight: f64,
    /// `cut` keeps splitting only while the number of boxes is below this value.
    target_colors: PalLen,
}
14
/// A group of histogram entries that is represented by a single color.
struct MBox<'hist> {
    /// Histogram entries that fell into this bucket
    pub colors: &'hist mut [HistItem],
    /// Center color selected to represent the colors
    pub avg_color: f_pixel,
    /// Per-channel sum of squared differences from `avg_color`, each entry weighted by its `adjusted_weight`
    pub variance: ARGBF,
    /// Sum of `adjusted_weight` over all entries in `colors`
    pub adjusted_weight_sum: f64,
    /// Cached result of `compute_total_error`; `None` until that method has been called on this box
    pub total_error: Option<f64>,
    /// max color difference between avg_color and any histogram entry
    pub max_error: f32,
}
27
28impl<'hist> MBox<'hist> {
29    pub fn new(hist: &'hist mut [HistItem]) -> Self {
30        let weight_sum = hist.iter().map(|item| {
31            debug_assert!(item.adjusted_weight.is_finite());
32            debug_assert!(item.adjusted_weight > 0.);
33            f64::from(item.adjusted_weight)
34        }).sum();
35        Self::new_inner(hist, weight_sum, weighed_average_color(hist))
36    }
37
38    fn from_split(hist: &'hist mut [HistItem], adjusted_weight_sum: f64) -> Self {
39        debug_assert!(!hist.is_empty());
40        let avg_color = weighed_average_color(hist);
41        Self::new_inner(hist, adjusted_weight_sum, avg_color)
42    }
43
44    fn new_inner(hist: &'hist mut [HistItem], adjusted_weight_sum: f64, avg_color: f_pixel) -> Self {
45        let (variance, max_error) = Self::box_stats(hist, avg_color);
46        Self {
47            variance,
48            max_error,
49            avg_color,
50            colors: hist,
51            adjusted_weight_sum,
52            total_error: None,
53        }
54    }
55
56    fn box_stats(hist: &[HistItem], avg_color: f_pixel) -> (ARGBF, f32) {
57        let mut variance = ARGBF::default();
58        let mut max_error = 0.;
59        for item in hist {
60            variance += (avg_color.0 - item.color.0).map(|c| c * c) * item.adjusted_weight;
61            let diff = avg_color.diff(&item.color);
62            if diff > max_error {
63                max_error = diff;
64            }
65        }
66        (variance, max_error)
67    }
68
69    pub fn compute_total_error(&mut self) -> f64 {
70        let avg = self.avg_color;
71        let e = self.colors.iter().map(move |a| f64::from(avg.diff(&a.color)) * f64::from(a.perceptual_weight)).sum::<f64>();
72        self.total_error = Some(e);
73        e
74    }
75
76    pub fn prepare_sort(&mut self) {
77        struct ChanVariance {
78            pub chan: usize,
79            pub variance: f32,
80        }
81
82        // Sort dimensions by their variance, and then sort colors first by dimension with the highest variance
83        let vars: [f32; 4] = rgb::bytemuck::cast(self.variance);
84        let mut channels = [
85            ChanVariance { chan: 0, variance: vars[0] },
86            ChanVariance { chan: 1, variance: vars[1] },
87            ChanVariance { chan: 2, variance: vars[2] },
88            ChanVariance { chan: 3, variance: vars[3] },
89        ];
90        channels.sort_unstable_by_key(|ch| Reverse(OrdFloat::new(ch.variance)));
91
92        for item in self.colors.iter_mut() {
93            let chans: [f32; 4] = rgb::bytemuck::cast(item.color.0);
94            // Only the first channel really matters. But other channels are included, because when trying median cut
95            // many times with different histogram weights, I don't want sort randomness to influence the outcome.
96            item.tmp.mc_sort_value = (u32::from((chans[channels[0].chan] * 65535.) as u16) << 16)
97                | u32::from(((chans[channels[2].chan] + chans[channels[1].chan] / 2. + chans[channels[3].chan] / 4.) * 65535.) as u16); // box will be split to make color_weight of each side even
98        }
99    }
100
101    fn median_color(&mut self) -> f_pixel {
102        let len = self.colors.len();
103        let (_, mid_item, _) = self.colors.select_nth_unstable_by_key(len / 2, |a| a.mc_sort_value());
104        mid_item.color
105    }
106
107    pub fn prepare_color_weight_total(&mut self) -> f64 {
108        let median = self.median_color();
109        self.colors.iter_mut().map(move |a| {
110            let w = (median.diff(&a.color).sqrt() * (2. + a.adjusted_weight)).sqrt();
111            debug_assert!(w.is_finite());
112            a.mc_color_weight = w;
113            f64::from(w)
114        })
115        .sum()
116    }
117
118    #[inline]
119    pub fn split(mut self) -> [Self; 2] {
120        self.prepare_sort();
121        let half_weight = self.prepare_color_weight_total() / 2.;
122        // yeah, there's some off-by-one error in there
123        let break_at = hist_item_sort_half(self.colors, half_weight).max(1);
124
125        let (left, right) = self.colors.split_at_mut(break_at);
126        let left_sum = left.iter().map(|a| f64::from(a.adjusted_weight)).sum();
127        let right_sum = self.adjusted_weight_sum - left_sum;
128
129        [MBox::from_split(left, left_sum),
130         MBox::from_split(right, right_sum)]
131    }
132}
133
134#[inline]
135fn qsort_pivot(base: &[HistItem]) -> usize {
136    let len = base.len();
137    if len < 32 {
138        return len / 2;
139    }
140    let mut pivots = [8, len / 2, len - 1];
141    // LLVM can't see it's in bounds :(
142    pivots.sort_unstable_by_key(move |&idx| unsafe {
143        debug_assert!(base.get(idx).is_some());
144        base.get_unchecked(idx)
145    }.mc_sort_value());
146    pivots[1]
147}
148
149fn qsort_partition(base: &mut [HistItem]) -> usize {
150    let mut r = base.len();
151    base.swap(qsort_pivot(base), 0);
152    let pivot_value = base[0].mc_sort_value();
153    let mut l = 1;
154    while l < r {
155        if base[l].mc_sort_value() >= pivot_value {
156            l += 1;
157        } else {
158            r -= 1;
159            while l < r && base[r].mc_sort_value() <= pivot_value {
160                r -= 1;
161            }
162            base.swap(l, r);
163        }
164    }
165    l -= 1;
166    base.swap(l, 0);
167    l
168}
169
170/// sorts the slice to make the sum of weights lower than `weight_half_sum` one side,
171/// returns index of the edge between <halfvar and >halfvar parts of the set
172#[inline(never)]
173fn hist_item_sort_half(mut base: &mut [HistItem], mut weight_half_sum: f64) -> usize {
174    let mut base_index = 0;
175    if base.is_empty() {
176        return 0;
177    }
178    loop {
179        let partition = qsort_partition(base);
180        let (left, right) = base.split_at_mut(partition + 1); // +1, because pivot stays on the left side
181        let left_sum = left.iter().map(|c| f64::from(c.mc_color_weight)).sum::<f64>();
182        if left_sum >= weight_half_sum {
183            match left.get_mut(..partition) { // trim pivot point, avoid panick branch in []
184                Some(left) if !left.is_empty() => { base = left; continue; },
185                _ => return base_index,
186            }
187        }
188        weight_half_sum -= left_sum;
189        base_index += left.len();
190        if !right.is_empty() {
191            base = right;
192        } else {
193            return base_index;
194        }
195    }
196}
197
198impl<'hist> MedianCutter<'hist> {
199    fn total_box_error_below_target(&mut self, mut target_mse: f64) -> bool {
200        target_mse *= self.hist_total_perceptual_weight;
201        let mut total_error = self.boxes.iter().filter_map(|mb| mb.total_error).sum::<f64>();
202        if total_error > target_mse {
203            return false;
204        }
205        for mb in self.boxes.iter_mut().filter(|mb| mb.total_error.is_none()) {
206            total_error += mb.compute_total_error();
207            if total_error > target_mse {
208                return false;
209            }
210        }
211        true
212    }
213
214    pub fn new(hist: &'hist mut HistogramInternal, target_colors: PalLen) -> Result<Self, Error> {
215        let hist_total_perceptual_weight = hist.total_perceptual_weight;
216
217        debug_assert!(hist.clusters[0].begin == 0);
218        debug_assert!(hist.clusters.last().unwrap().end as usize == hist.items.len());
219
220        let mut hist_items = &mut hist.items[..];
221        let mut boxes = Vec::new();
222        boxes.try_reserve(target_colors as usize)?;
223
224        let used_boxes = hist.clusters.iter().filter(|b| b.begin != b.end).count();
225        if used_boxes <= target_colors as usize / 3 {
226            // boxes are guaranteed to be sorted
227            let mut prev_end = 0;
228            for b in hist.clusters.iter().filter(|b| b.begin != b.end) {
229                let begin = b.begin as usize;
230                debug_assert_eq!(begin, prev_end);
231                let end = b.end as usize;
232                prev_end = end;
233                let (this_box, rest) = hist_items.split_at_mut(end - begin);
234                hist_items = rest;
235                boxes.push_in_cap(MBox::new(this_box));
236            }
237        } else {
238            boxes.push_in_cap(MBox::new(hist_items));
239        }
240
241        Ok(Self {
242            boxes,
243            hist_total_perceptual_weight,
244            target_colors,
245        })
246    }
247
248    fn into_palette(mut self) -> PalF {
249        let mut palette = PalF::new();
250
251        for (i, mbox) in self.boxes.iter_mut().enumerate() {
252            mbox.colors.iter_mut().for_each(move |a| a.tmp.likely_palette_index = i as _);
253
254            // store total color popularity (perceptual_weight is approximation of it)
255            let pop = mbox.colors.iter().map(|a| f64::from(a.perceptual_weight)).sum::<f64>();
256            let mut representative_color = mbox.avg_color;
257            if mbox.colors.len() > 2 {
258                representative_color = mbox.colors.iter().min_by_key(|a| OrdFloat::new(representative_color.diff(&a.color))).map(|a| a.color).unwrap_or_default();
259            }
260            palette.push(representative_color, PalPop::new(pop as f32));
261        }
262        palette
263    }
264
265    fn cut(mut self, target_mse: f64, max_mse: f64) -> PalF {
266        let max_mse = max_mse.max(quality_to_mse(20));
267
268        while self.boxes.len() < self.target_colors as usize {
269            // first splits boxes that exceed quality limit (to have colors for things like odd green pixel),
270            // later raises the limit to allow large smooth areas/gradients get colors.
271            let fraction_done = self.boxes.len() as f64 / f64::from(self.target_colors);
272            let current_max_mse = (fraction_done * 16.).mul_add(max_mse, max_mse);
273            let Some(bi) = self.take_best_splittable_box(current_max_mse) else {
274                break
275            };
276
277            self.boxes.extend(bi.split());
278
279            if self.total_box_error_below_target(target_mse) {
280                break;
281            }
282        }
283
284        self.into_palette()
285    }
286
287    fn take_best_splittable_box(&mut self, max_mse: f64) -> Option<MBox<'hist>> {
288        self.boxes.iter().enumerate()
289            .filter(|(_, mbox)| mbox.colors.len() > 1)
290            .map(move |(i, mbox)| {
291                let mut thissum = mbox.adjusted_weight_sum * mbox.variance.iter().map(|f| f as f64).sum::<f64>();
292                if f64::from(mbox.max_error) > max_mse {
293                    thissum = thissum * f64::from(mbox.max_error) / max_mse;
294                }
295                (i, thissum)
296            })
297            .max_by_key(|&(_, thissum)| OrdFloat::new64(thissum))
298            .map(|(i, _)| self.boxes.swap_remove(i))
299    }
300}
301
302#[inline(never)]
303pub(crate) fn mediancut(hist: &mut HistogramInternal, target_colors: PalLen, target_mse: f64, max_mse_per_color: f64) -> Result<PalF, Error> {
304    Ok(MedianCutter::new(hist, target_colors)?.cut(target_mse, max_mse_per_color))
305}
306
307fn weighed_average_color(hist: &[HistItem]) -> f_pixel {
308    debug_assert!(!hist.is_empty());
309    let mut t = f_pixel::default();
310    let mut sum = 0.;
311    for c in hist {
312        sum += c.adjusted_weight;
313        t.0 += c.color.0 * c.adjusted_weight;
314    }
315    if sum != 0. {
316        t.0 /= sum;
317    }
318    t
319}