// File: jxl_encoder/entropy_coding/histogram.rs

// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
// Algorithms and constants derived from libjxl (BSD-3-Clause).
// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing

//! Histogram data structure with entropy calculations.
//!
//! Ported from libjxl `lib/jxl/enc_ans_params.h` and `lib/jxl/enc_cluster.cc`.
use core::cell::Cell;

/// Alignment for SIMD-friendly histogram operations.
/// Count vectors are always sized to a multiple of this value.
/// Matches libjxl's `Histogram::kRounding`.
pub const HISTOGRAM_ROUNDING: usize = 8;

/// Minimum distance threshold for creating distinct clusters.
/// Below this threshold, histograms are considered similar enough to merge.
/// NOTE(review): not referenced in this file — presumably consumed by the
/// clustering pass; confirm against callers before changing.
pub const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;
/// A histogram counting symbol occurrences.
///
/// This is the encoder-side histogram structure, corresponding to libjxl's
/// `Histogram` class in `enc_ans_params.h`.
#[derive(Clone, Debug)]
pub struct Histogram {
    /// Symbol counts (length padded to a multiple of HISTOGRAM_ROUNDING).
    /// Uses i32 to match libjxl's `ANSHistBin` type.
    pub counts: Vec<i32>,
    /// Sum of all counts.
    pub total_count: usize,
    /// Cached entropy value, mutable through `&self` via `Cell`.
    /// WARNING: Not automatically kept up-to-date - call `shannon_entropy()` to refresh.
    entropy: Cell<f32>,
}
35impl Default for Histogram {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl Histogram {
42    /// Creates an empty histogram.
43    pub fn new() -> Self {
44        Self {
45            counts: Vec::new(),
46            total_count: 0,
47            entropy: Cell::new(0.0),
48        }
49    }
50
51    /// Creates a histogram with pre-allocated capacity for `length` symbols.
52    /// The capacity is rounded up to HISTOGRAM_ROUNDING.
53    pub fn with_capacity(length: usize) -> Self {
54        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
55        Self {
56            counts: vec![0; rounded_len],
57            total_count: 0,
58            entropy: Cell::new(0.0),
59        }
60    }
61
62    /// Creates a histogram from a slice of counts.
63    pub fn from_counts(counts: &[i32]) -> Self {
64        let total: i32 = counts.iter().sum();
65        let rounded_len = div_ceil(counts.len(), HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
66        let mut result_counts = vec![0i32; rounded_len];
67        result_counts[..counts.len()].copy_from_slice(counts);
68
69        Self {
70            counts: result_counts,
71            total_count: total as usize,
72            entropy: Cell::new(0.0),
73        }
74    }
75
76    /// Creates a flat (uniform) histogram.
77    pub fn flat(length: usize, total_count: usize) -> Self {
78        let base = (total_count / length) as i32;
79        let remainder = total_count % length;
80
81        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
82        let mut counts = vec![0i32; rounded_len];
83
84        for (i, count) in counts.iter_mut().enumerate().take(length) {
85            *count = base + if i < remainder { 1 } else { 0 };
86        }
87
88        Self {
89            counts,
90            total_count,
91            entropy: Cell::new(0.0),
92        }
93    }
94
95    /// Clears all counts.
96    pub fn clear(&mut self) {
97        self.counts.clear();
98        self.total_count = 0;
99        self.entropy.set(0.0);
100    }
101
102    /// Add one occurrence of a symbol.
103    pub fn add(&mut self, symbol: usize) {
104        self.ensure_capacity(symbol + 1);
105        self.counts[symbol] += 1;
106        self.total_count += 1;
107    }
108
109    /// Ensures the histogram can hold at least `length` symbols.
110    pub fn ensure_capacity(&mut self, length: usize) {
111        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
112        if self.counts.len() < rounded_len {
113            self.counts.resize(rounded_len, 0);
114        }
115    }
116
117    /// Fast add (caller must ensure capacity first).
118    #[inline]
119    pub fn fast_add(&mut self, symbol: usize) {
120        debug_assert!(symbol < self.counts.len());
121        self.counts[symbol] += 1;
122    }
123
124    /// Add another histogram's counts to this one.
125    pub fn add_histogram(&mut self, other: &Histogram) {
126        if other.counts.len() > self.counts.len() {
127            self.counts.resize(other.counts.len(), 0);
128        }
129        for (i, &count) in other.counts.iter().enumerate() {
130            self.counts[i] += count;
131        }
132        self.total_count += other.total_count;
133    }
134
135    /// Trim trailing zeros and update total_count.
136    /// Should be called after a sequence of `fast_add` calls.
137    pub fn condition(&mut self) {
138        // Find the last non-zero position
139        let mut last_nonzero: i32 = -1;
140        let mut total: i64 = 0;
141
142        for (i, &count) in self.counts.iter().enumerate() {
143            total += count as i64;
144            if count != 0 {
145                last_nonzero = i as i32;
146            }
147        }
148
149        // Resize to rounded length past last non-zero
150        let new_len = if last_nonzero >= 0 {
151            div_ceil((last_nonzero + 1) as usize, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING
152        } else {
153            0
154        };
155        self.counts.resize(new_len, 0);
156        self.total_count = total as usize;
157    }
158
159    /// Compute Shannon entropy: -sum(count * log2(count / total)).
160    /// Result is in bits (not nats).
161    ///
162    /// Formula: sum of -(count/total) * log2(count/total) * total
163    ///        = sum of -count * log2(count/total)
164    ///        = sum of -count * (log2(count) - log2(total))
165    ///        = sum of -count * log2(count) + count * log2(total)
166    ///        = -sum(count * log2(count)) + total * log2(total)
167    ///
168    /// libjxl uses: -count * log2(count / total), excluding when count == total.
169    pub fn shannon_entropy(&self) -> f32 {
170        if self.total_count == 0 {
171            self.entropy.set(0.0);
172            return 0.0;
173        }
174
175        let entropy = jxl_simd::shannon_entropy_bits(&self.counts, self.total_count);
176        self.entropy.set(entropy);
177        entropy
178    }
179
180    /// Get the cached entropy value.
181    /// Call `shannon_entropy()` first to ensure it's up-to-date.
182    pub fn cached_entropy(&self) -> f32 {
183        self.entropy.get()
184    }
185
186    /// Set the cached entropy value (used when loading from test data).
187    pub fn set_cached_entropy(&self, entropy: f32) {
188        self.entropy.set(entropy);
189    }
190
191    /// Alphabet size (highest non-zero symbol + 1).
192    pub fn alphabet_size(&self) -> usize {
193        for i in (0..self.counts.len()).rev() {
194            if self.counts[i] > 0 {
195                return i + 1;
196            }
197        }
198        0
199    }
200
201    /// Returns the index of the maximum symbol with non-zero count.
202    pub fn max_symbol(&self) -> usize {
203        if self.total_count == 0 {
204            return 0;
205        }
206        for i in (1..self.counts.len()).rev() {
207            if self.counts[i] > 0 {
208                return i;
209            }
210        }
211        0
212    }
213
214    /// Check if histogram is empty (all zeros).
215    pub fn is_empty(&self) -> bool {
216        self.total_count == 0
217    }
218
219    /// Copy contents from another histogram, reusing this histogram's allocation.
220    ///
221    /// Unlike `clone()`, this avoids allocating a new `Vec` when `self` already
222    /// has sufficient capacity.
223    pub fn copy_from(&mut self, source: &Histogram) {
224        let src_len = source.counts.len();
225        if self.counts.len() < src_len {
226            self.counts.resize(src_len, 0);
227        }
228        self.counts[..src_len].copy_from_slice(&source.counts[..src_len]);
229        if self.counts.len() > src_len {
230            self.counts[src_len..].fill(0);
231        }
232        self.total_count = source.total_count;
233        self.entropy.set(source.cached_entropy());
234    }
235}
236
/// Scratch buffer for `histogram_distance` to avoid per-call heap allocation.
///
/// Reuse across multiple calls in hot clustering loops.
pub struct DistanceScratch {
    /// Element-wise sum of two histograms' counts; contents are transient
    /// and only valid within a single distance computation.
    combined_counts: Vec<i32>,
}
244impl Default for DistanceScratch {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250impl DistanceScratch {
251    /// Create a new scratch buffer.
252    pub fn new() -> Self {
253        Self {
254            combined_counts: Vec::new(),
255        }
256    }
257
258    /// Ensure the scratch buffer has at least `len` elements.
259    /// Does NOT zero — caller is responsible for writing all used positions.
260    #[inline]
261    fn ensure_capacity(&mut self, len: usize) {
262        if self.combined_counts.len() < len {
263            self.combined_counts.resize(len, 0);
264        }
265    }
266}
267
268/// Distance between two histograms (for clustering).
269///
270/// This measures how many extra bits are needed to encode the combined
271/// distribution vs encoding them separately. Lower = more similar.
272///
273/// Formula: entropy(combined) - entropy(a) - entropy(b)
274///
275/// IMPORTANT: Both histograms must have their entropy pre-computed
276/// (call `shannon_entropy()` first).
277pub fn histogram_distance(a: &Histogram, b: &Histogram) -> f32 {
278    let mut scratch = DistanceScratch::new();
279    histogram_distance_reuse(a, b, &mut scratch)
280}
281
282/// Like [`histogram_distance`] but reuses a scratch buffer to avoid allocation.
283pub fn histogram_distance_reuse(
284    a: &Histogram,
285    b: &Histogram,
286    scratch: &mut DistanceScratch,
287) -> f32 {
288    if a.total_count == 0 || b.total_count == 0 {
289        return 0.0;
290    }
291
292    let combined_total = a.total_count + b.total_count;
293    let a_len = a.counts.len();
294    let b_len = b.counts.len();
295    let max_len = a_len.max(b_len);
296
297    // Build combined counts (HISTOGRAM_ROUNDING-aligned for SIMD)
298    let aligned_len = div_ceil(max_len, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
299    scratch.ensure_capacity(aligned_len);
300    let combined_counts = &mut scratch.combined_counts[..aligned_len];
301
302    // Add overlapping region using zip (no per-element bounds checks)
303    let min_len = a_len.min(b_len);
304    for ((slot, &ac), &bc) in combined_counts[..min_len]
305        .iter_mut()
306        .zip(&a.counts[..min_len])
307        .zip(&b.counts[..min_len])
308    {
309        *slot = ac + bc;
310    }
311    // Copy non-overlapping tail from whichever histogram is longer
312    if a_len > min_len {
313        combined_counts[min_len..a_len].copy_from_slice(&a.counts[min_len..a_len]);
314    } else if b_len > min_len {
315        combined_counts[min_len..b_len].copy_from_slice(&b.counts[min_len..b_len]);
316    }
317    // Zero only the SIMD padding tail (positions max_len..aligned_len)
318    if max_len < aligned_len {
319        combined_counts[max_len..aligned_len].fill(0);
320    }
321
322    let combined_entropy = jxl_simd::shannon_entropy_bits(combined_counts, combined_total);
323
324    // Distance = combined_entropy - a.entropy - b.entropy
325    combined_entropy - a.cached_entropy() - b.cached_entropy()
326}
327
328/// KL divergence: cost of encoding `actual` using `coding` histogram.
329///
330/// Returns the extra bits needed to encode `actual`'s symbols using
331/// `coding`'s probability distribution, compared to using `actual`'s
332/// own distribution.
333///
334/// Returns infinity if `actual` has symbols not present in `coding`.
335///
336/// IMPORTANT: Both histograms must have their entropy pre-computed.
337pub fn histogram_kl_divergence(actual: &Histogram, coding: &Histogram) -> f32 {
338    if actual.total_count == 0 {
339        return 0.0;
340    }
341    if coding.total_count == 0 {
342        return f32::INFINITY;
343    }
344
345    let coding_inv = 1.0 / coding.total_count as f32;
346    let mut cost = 0.0f32;
347
348    for (i, &count) in actual.counts.iter().enumerate() {
349        if count > 0 {
350            let coding_count = coding.counts.get(i).copied().unwrap_or(0);
351            if coding_count == 0 {
352                // Symbol in actual but not in coding -> infinite cost
353                return f32::INFINITY;
354            }
355            let coding_prob = coding_count as f32 * coding_inv;
356            // Cost: -count * log2(coding_prob)
357            cost -= count as f32 * jxl_simd::fast_log2f(coding_prob);
358        }
359    }
360
361    // KL divergence = cost - entropy(actual)
362    cost - actual.cached_entropy()
363}
364
/// Ceiling division: the smallest `q` such that `q * b >= a`.
/// Panics if `b == 0`, matching integer division semantics.
#[inline]
fn div_ceil(a: usize, b: usize) -> usize {
    let quotient = a / b;
    if a % b == 0 {
        quotient
    } else {
        quotient + 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Construction -----------------------------------------------------

    #[test]
    fn test_histogram_new() {
        let h = Histogram::new();
        assert!(h.is_empty());
        assert_eq!(h.total_count, 0);
        assert_eq!(h.alphabet_size(), 0);
    }

    #[test]
    fn test_histogram_from_counts() {
        let h = Histogram::from_counts(&[10, 20, 30]);
        assert_eq!(h.total_count, 60);
        assert_eq!(h.alphabet_size(), 3);
        assert!(!h.is_empty());
    }

    #[test]
    fn test_histogram_add() {
        let mut h = Histogram::new();
        h.add(0);
        h.add(0);
        h.add(5);

        assert_eq!(h.total_count, 3);
        assert_eq!(h.counts[0], 2);
        assert_eq!(h.counts[5], 1);
        assert_eq!(h.alphabet_size(), 6);
    }

    #[test]
    fn test_histogram_flat() {
        let h = Histogram::flat(4, 100);
        assert_eq!(h.total_count, 100);
        // 100 / 4 = 25 each
        assert_eq!(h.counts[0], 25);
        assert_eq!(h.counts[1], 25);
        assert_eq!(h.counts[2], 25);
        assert_eq!(h.counts[3], 25);
    }

    #[test]
    fn test_histogram_flat_remainder() {
        let h = Histogram::flat(4, 10);
        assert_eq!(h.total_count, 10);
        // 10 / 4 = 2 remainder 2, so first two get 3, last two get 2
        assert_eq!(h.counts[0], 3);
        assert_eq!(h.counts[1], 3);
        assert_eq!(h.counts[2], 2);
        assert_eq!(h.counts[3], 2);
    }

    #[test]
    fn test_histogram_condition() {
        let mut h = Histogram::with_capacity(100);
        h.fast_add(0);
        h.fast_add(0);
        h.fast_add(5);
        h.condition();

        assert_eq!(h.total_count, 3);
        assert_eq!(h.counts.len(), HISTOGRAM_ROUNDING); // Rounded up from 6
    }

    // --- Entropy ----------------------------------------------------------
    // Entropies below are totals in bits (count-weighted), not per-symbol.

    #[test]
    fn test_shannon_entropy_uniform() {
        // Uniform distribution: entropy = log2(n) bits per symbol
        let h = Histogram::from_counts(&[100, 100, 100, 100]);
        let entropy = h.shannon_entropy();
        // Expected: 400 * log2(4) = 400 * 2 = 800 bits total
        // But our formula gives bits for the total, which is:
        // sum of -count * log2(count/total) = 4 * (-100 * log2(0.25)) = 4 * 100 * 2 = 800
        assert!((entropy - 800.0).abs() < 0.01, "entropy = {}", entropy);
    }

    #[test]
    fn test_shannon_entropy_skewed() {
        // Single symbol: entropy = 0 (no uncertainty)
        let h = Histogram::from_counts(&[100, 0, 0, 0]);
        let entropy = h.shannon_entropy();
        assert!((entropy - 0.0).abs() < 0.01, "entropy = {}", entropy);
    }

    #[test]
    fn test_shannon_entropy_binary() {
        // Two equal symbols: entropy = n * 1 bit = n
        let h = Histogram::from_counts(&[50, 50]);
        let entropy = h.shannon_entropy();
        // 2 * (-50 * log2(0.5)) = 2 * 50 * 1 = 100
        assert!((entropy - 100.0).abs() < 0.01, "entropy = {}", entropy);
    }

    // --- Distance and divergence ------------------------------------------
    // Distance/KL functions require cached entropies, hence the
    // shannon_entropy() calls before each measurement.

    #[test]
    fn test_histogram_distance_identical() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        let b = Histogram::from_counts(&[100, 50, 25]);
        a.shannon_entropy();
        b.shannon_entropy();

        let dist = histogram_distance(&a, &b);
        // Identical histograms: combined entropy = 2x each entropy
        // So distance = 2*E - E - E = 0
        assert!(dist.abs() < 0.01, "distance = {}", dist);
    }

    #[test]
    fn test_histogram_distance_different() {
        let a = Histogram::from_counts(&[100, 0, 0]);
        let b = Histogram::from_counts(&[0, 0, 100]);
        a.shannon_entropy();
        b.shannon_entropy();

        let dist = histogram_distance(&a, &b);
        // a has entropy 0, b has entropy 0 (single symbol each)
        // Combined has 100 each in symbols 0 and 2
        // Combined entropy = 2 * (-100 * log2(0.5)) = 200
        // Distance = 200 - 0 - 0 = 200
        assert!((dist - 200.0).abs() < 0.01, "distance = {}", dist);
    }

    #[test]
    fn test_histogram_distance_empty() {
        let a = Histogram::new();
        let b = Histogram::from_counts(&[100]);
        a.shannon_entropy();
        b.shannon_entropy();

        let dist = histogram_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_kl_divergence_identical() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        a.shannon_entropy();

        let div = histogram_kl_divergence(&a, &a);
        assert!(div.abs() < 0.01, "kl = {}", div);
    }

    #[test]
    fn test_kl_divergence_missing_symbol() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        let b = Histogram::from_counts(&[100, 50, 0]); // Missing symbol 2
        a.shannon_entropy();
        b.shannon_entropy();

        let div = histogram_kl_divergence(&a, &b);
        assert!(div.is_infinite(), "kl = {}", div);
    }

    // --- Merging and bookkeeping ------------------------------------------

    #[test]
    fn test_add_histogram() {
        let mut a = Histogram::from_counts(&[10, 20]);
        let b = Histogram::from_counts(&[5, 10, 15]);

        a.add_histogram(&b);

        assert_eq!(a.total_count, 60);
        assert_eq!(a.counts[0], 15);
        assert_eq!(a.counts[1], 30);
        assert_eq!(a.counts[2], 15);
    }

    #[test]
    fn test_max_symbol() {
        let h = Histogram::from_counts(&[10, 20, 0, 5, 0, 0]);
        assert_eq!(h.max_symbol(), 3);

        let h2 = Histogram::from_counts(&[10]);
        assert_eq!(h2.max_symbol(), 0);

        let h3 = Histogram::new();
        assert_eq!(h3.max_symbol(), 0);
    }
}
549}