// jxl_encoder/entropy_coding/histogram.rs

// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
// Algorithms and constants derived from libjxl (BSD-3-Clause).
// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing

//! Histogram data structure with entropy calculations.
//!
//! Ported from libjxl `lib/jxl/enc_ans_params.h` and `lib/jxl/enc_cluster.cc`.

use core::cell::Cell;

/// Alignment (in entries) for SIMD-friendly histogram count buffers.
/// Matches libjxl's `Histogram::kRounding`.
pub const HISTOGRAM_ROUNDING: usize = 8;

/// Clustering threshold: histogram pairs whose distance falls below this
/// many bits are considered similar enough to merge into one cluster.
pub const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;

/// Encoder-side symbol-occurrence histogram.
///
/// Mirrors libjxl's `Histogram` class from `enc_ans_params.h`: counts are
/// stored as `i32` (libjxl's `ANSHistBin`) in a buffer whose length is kept
/// at a multiple of `HISTOGRAM_ROUNDING`.
#[derive(Clone, Debug)]
pub struct Histogram {
    /// Per-symbol occurrence counts; length is padded to `HISTOGRAM_ROUNDING`.
    pub counts: Vec<i32>,
    /// Sum over all entries in `counts`.
    pub total_count: usize,
    /// Last value computed by `shannon_entropy()`.
    /// WARNING: stale until `shannon_entropy()` is called again.
    entropy: Cell<f32>,
}

35impl Default for Histogram {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl Histogram {
42    /// Creates an empty histogram.
43    pub fn new() -> Self {
44        Self {
45            counts: Vec::new(),
46            total_count: 0,
47            entropy: Cell::new(0.0),
48        }
49    }
50
51    /// Creates a histogram with pre-allocated capacity for `length` symbols.
52    /// The capacity is rounded up to HISTOGRAM_ROUNDING.
53    pub fn with_capacity(length: usize) -> Self {
54        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
55        Self {
56            counts: vec![0; rounded_len],
57            total_count: 0,
58            entropy: Cell::new(0.0),
59        }
60    }
61
62    /// Creates a histogram from a slice of counts.
63    pub fn from_counts(counts: &[i32]) -> Self {
64        let total: i32 = counts.iter().sum();
65        let rounded_len = div_ceil(counts.len(), HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
66        let mut result_counts = vec![0i32; rounded_len];
67        result_counts[..counts.len()].copy_from_slice(counts);
68
69        Self {
70            counts: result_counts,
71            total_count: total as usize,
72            entropy: Cell::new(0.0),
73        }
74    }
75
76    /// Creates a flat (uniform) histogram.
77    pub fn flat(length: usize, total_count: usize) -> Self {
78        let base = (total_count / length) as i32;
79        let remainder = total_count % length;
80
81        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
82        let mut counts = vec![0i32; rounded_len];
83
84        for (i, count) in counts.iter_mut().enumerate().take(length) {
85            *count = base + if i < remainder { 1 } else { 0 };
86        }
87
88        Self {
89            counts,
90            total_count,
91            entropy: Cell::new(0.0),
92        }
93    }
94
95    /// Clears all counts.
96    pub fn clear(&mut self) {
97        self.counts.clear();
98        self.total_count = 0;
99        self.entropy.set(0.0);
100    }
101
102    /// Add one occurrence of a symbol.
103    pub fn add(&mut self, symbol: usize) {
104        self.ensure_capacity(symbol + 1);
105        self.counts[symbol] += 1;
106        self.total_count += 1;
107    }
108
109    /// Ensures the histogram can hold at least `length` symbols.
110    pub fn ensure_capacity(&mut self, length: usize) {
111        let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
112        if self.counts.len() < rounded_len {
113            self.counts.resize(rounded_len, 0);
114        }
115    }
116
117    /// Fast add (caller must ensure capacity first).
118    #[inline]
119    pub fn fast_add(&mut self, symbol: usize) {
120        debug_assert!(symbol < self.counts.len());
121        self.counts[symbol] += 1;
122    }
123
124    /// Add another histogram's counts to this one.
125    pub fn add_histogram(&mut self, other: &Histogram) {
126        if other.counts.len() > self.counts.len() {
127            self.counts.resize(other.counts.len(), 0);
128        }
129        for (i, &count) in other.counts.iter().enumerate() {
130            self.counts[i] += count;
131        }
132        self.total_count += other.total_count;
133    }
134
135    /// Trim trailing zeros and update total_count.
136    /// Should be called after a sequence of `fast_add` calls.
137    pub fn condition(&mut self) {
138        // Find the last non-zero position
139        let mut last_nonzero: i32 = -1;
140        let mut total: i64 = 0;
141
142        for (i, &count) in self.counts.iter().enumerate() {
143            total += count as i64;
144            if count != 0 {
145                last_nonzero = i as i32;
146            }
147        }
148
149        // Resize to rounded length past last non-zero
150        let new_len = if last_nonzero >= 0 {
151            div_ceil((last_nonzero + 1) as usize, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING
152        } else {
153            0
154        };
155        self.counts.resize(new_len, 0);
156        self.total_count = total as usize;
157    }
158
159    /// Compute Shannon entropy: -sum(count * log2(count / total)).
160    /// Result is in bits (not nats).
161    ///
162    /// Formula: sum of -(count/total) * log2(count/total) * total
163    ///        = sum of -count * log2(count/total)
164    ///        = sum of -count * (log2(count) - log2(total))
165    ///        = sum of -count * log2(count) + count * log2(total)
166    ///        = -sum(count * log2(count)) + total * log2(total)
167    ///
168    /// libjxl uses: -count * log2(count / total), excluding when count == total.
169    pub fn shannon_entropy(&self) -> f32 {
170        if self.total_count == 0 {
171            self.entropy.set(0.0);
172            return 0.0;
173        }
174
175        let inv_total = 1.0 / self.total_count as f32;
176        let total = self.total_count as f32;
177        let mut entropy = 0.0f32;
178
179        for &count in &self.counts {
180            if count > 0 {
181                let count_f = count as f32;
182                // When count == total, this symbol has probability 1, so entropy contribution is 0
183                if count_f != total {
184                    // Entropy contribution: -count * log2(count * inv_total)
185                    entropy -= count_f * (count_f * inv_total).log2();
186                }
187            }
188        }
189
190        self.entropy.set(entropy);
191        entropy
192    }
193
194    /// Get the cached entropy value.
195    /// Call `shannon_entropy()` first to ensure it's up-to-date.
196    pub fn cached_entropy(&self) -> f32 {
197        self.entropy.get()
198    }
199
200    /// Set the cached entropy value (used when loading from test data).
201    pub fn set_cached_entropy(&self, entropy: f32) {
202        self.entropy.set(entropy);
203    }
204
205    /// Alphabet size (highest non-zero symbol + 1).
206    pub fn alphabet_size(&self) -> usize {
207        for i in (0..self.counts.len()).rev() {
208            if self.counts[i] > 0 {
209                return i + 1;
210            }
211        }
212        0
213    }
214
215    /// Returns the index of the maximum symbol with non-zero count.
216    pub fn max_symbol(&self) -> usize {
217        if self.total_count == 0 {
218            return 0;
219        }
220        for i in (1..self.counts.len()).rev() {
221            if self.counts[i] > 0 {
222                return i;
223            }
224        }
225        0
226    }
227
228    /// Check if histogram is empty (all zeros).
229    pub fn is_empty(&self) -> bool {
230        self.total_count == 0
231    }
232}
233
234/// Distance between two histograms (for clustering).
235///
236/// This measures how many extra bits are needed to encode the combined
237/// distribution vs encoding them separately. Lower = more similar.
238///
239/// Formula: entropy(combined) - entropy(a) - entropy(b)
240///
241/// IMPORTANT: Both histograms must have their entropy pre-computed
242/// (call `shannon_entropy()` first).
243pub fn histogram_distance(a: &Histogram, b: &Histogram) -> f32 {
244    if a.total_count == 0 || b.total_count == 0 {
245        return 0.0;
246    }
247
248    let combined_total = (a.total_count + b.total_count) as f32;
249    let inv_total = 1.0 / combined_total;
250    let mut combined_entropy = 0.0f32;
251
252    let max_len = a.counts.len().max(b.counts.len());
253
254    for i in 0..max_len {
255        let a_count = a.counts.get(i).copied().unwrap_or(0);
256        let b_count = b.counts.get(i).copied().unwrap_or(0);
257        let combined = (a_count + b_count) as f32;
258
259        if combined > 0.0 && combined != combined_total {
260            combined_entropy -= combined * (combined * inv_total).log2();
261        }
262    }
263
264    // Distance = combined_entropy - a.entropy - b.entropy
265    combined_entropy - a.cached_entropy() - b.cached_entropy()
266}
267
268/// KL divergence: cost of encoding `actual` using `coding` histogram.
269///
270/// Returns the extra bits needed to encode `actual`'s symbols using
271/// `coding`'s probability distribution, compared to using `actual`'s
272/// own distribution.
273///
274/// Returns infinity if `actual` has symbols not present in `coding`.
275///
276/// IMPORTANT: Both histograms must have their entropy pre-computed.
277pub fn histogram_kl_divergence(actual: &Histogram, coding: &Histogram) -> f32 {
278    if actual.total_count == 0 {
279        return 0.0;
280    }
281    if coding.total_count == 0 {
282        return f32::INFINITY;
283    }
284
285    let coding_inv = 1.0 / coding.total_count as f32;
286    let mut cost = 0.0f32;
287
288    for (i, &count) in actual.counts.iter().enumerate() {
289        if count > 0 {
290            let coding_count = coding.counts.get(i).copied().unwrap_or(0);
291            if coding_count == 0 {
292                // Symbol in actual but not in coding -> infinite cost
293                return f32::INFINITY;
294            }
295            let coding_prob = coding_count as f32 * coding_inv;
296            // Cost: -count * log2(coding_prob)
297            cost -= count as f32 * coding_prob.log2();
298        }
299    }
300
301    // KL divergence = cost - entropy(actual)
302    cost - actual.cached_entropy()
303}
304
/// Ceiling division: the smallest `q` such that `q * b >= a`.
/// Panics if `b` is zero.
#[inline]
fn div_ceil(a: usize, b: usize) -> usize {
    if a % b == 0 { a / b } else { a / b + 1 }
}

311#[cfg(test)]
312mod tests {
313    use super::*;
314
315    #[test]
316    fn test_histogram_new() {
317        let h = Histogram::new();
318        assert!(h.is_empty());
319        assert_eq!(h.total_count, 0);
320        assert_eq!(h.alphabet_size(), 0);
321    }
322
323    #[test]
324    fn test_histogram_from_counts() {
325        let h = Histogram::from_counts(&[10, 20, 30]);
326        assert_eq!(h.total_count, 60);
327        assert_eq!(h.alphabet_size(), 3);
328        assert!(!h.is_empty());
329    }
330
331    #[test]
332    fn test_histogram_add() {
333        let mut h = Histogram::new();
334        h.add(0);
335        h.add(0);
336        h.add(5);
337
338        assert_eq!(h.total_count, 3);
339        assert_eq!(h.counts[0], 2);
340        assert_eq!(h.counts[5], 1);
341        assert_eq!(h.alphabet_size(), 6);
342    }
343
344    #[test]
345    fn test_histogram_flat() {
346        let h = Histogram::flat(4, 100);
347        assert_eq!(h.total_count, 100);
348        // 100 / 4 = 25 each
349        assert_eq!(h.counts[0], 25);
350        assert_eq!(h.counts[1], 25);
351        assert_eq!(h.counts[2], 25);
352        assert_eq!(h.counts[3], 25);
353    }
354
355    #[test]
356    fn test_histogram_flat_remainder() {
357        let h = Histogram::flat(4, 10);
358        assert_eq!(h.total_count, 10);
359        // 10 / 4 = 2 remainder 2, so first two get 3, last two get 2
360        assert_eq!(h.counts[0], 3);
361        assert_eq!(h.counts[1], 3);
362        assert_eq!(h.counts[2], 2);
363        assert_eq!(h.counts[3], 2);
364    }
365
366    #[test]
367    fn test_histogram_condition() {
368        let mut h = Histogram::with_capacity(100);
369        h.fast_add(0);
370        h.fast_add(0);
371        h.fast_add(5);
372        h.condition();
373
374        assert_eq!(h.total_count, 3);
375        assert_eq!(h.counts.len(), HISTOGRAM_ROUNDING); // Rounded up from 6
376    }
377
378    #[test]
379    fn test_shannon_entropy_uniform() {
380        // Uniform distribution: entropy = log2(n) bits per symbol
381        let h = Histogram::from_counts(&[100, 100, 100, 100]);
382        let entropy = h.shannon_entropy();
383        // Expected: 400 * log2(4) = 400 * 2 = 800 bits total
384        // But our formula gives bits for the total, which is:
385        // sum of -count * log2(count/total) = 4 * (-100 * log2(0.25)) = 4 * 100 * 2 = 800
386        assert!((entropy - 800.0).abs() < 0.01, "entropy = {}", entropy);
387    }
388
389    #[test]
390    fn test_shannon_entropy_skewed() {
391        // Single symbol: entropy = 0 (no uncertainty)
392        let h = Histogram::from_counts(&[100, 0, 0, 0]);
393        let entropy = h.shannon_entropy();
394        assert!((entropy - 0.0).abs() < 0.01, "entropy = {}", entropy);
395    }
396
397    #[test]
398    fn test_shannon_entropy_binary() {
399        // Two equal symbols: entropy = n * 1 bit = n
400        let h = Histogram::from_counts(&[50, 50]);
401        let entropy = h.shannon_entropy();
402        // 2 * (-50 * log2(0.5)) = 2 * 50 * 1 = 100
403        assert!((entropy - 100.0).abs() < 0.01, "entropy = {}", entropy);
404    }
405
406    #[test]
407    fn test_histogram_distance_identical() {
408        let a = Histogram::from_counts(&[100, 50, 25]);
409        let b = Histogram::from_counts(&[100, 50, 25]);
410        a.shannon_entropy();
411        b.shannon_entropy();
412
413        let dist = histogram_distance(&a, &b);
414        // Identical histograms: combined entropy = 2x each entropy
415        // So distance = 2*E - E - E = 0
416        assert!(dist.abs() < 0.01, "distance = {}", dist);
417    }
418
419    #[test]
420    fn test_histogram_distance_different() {
421        let a = Histogram::from_counts(&[100, 0, 0]);
422        let b = Histogram::from_counts(&[0, 0, 100]);
423        a.shannon_entropy();
424        b.shannon_entropy();
425
426        let dist = histogram_distance(&a, &b);
427        // a has entropy 0, b has entropy 0 (single symbol each)
428        // Combined has 100 each in symbols 0 and 2
429        // Combined entropy = 2 * (-100 * log2(0.5)) = 200
430        // Distance = 200 - 0 - 0 = 200
431        assert!((dist - 200.0).abs() < 0.01, "distance = {}", dist);
432    }
433
434    #[test]
435    fn test_histogram_distance_empty() {
436        let a = Histogram::new();
437        let b = Histogram::from_counts(&[100]);
438        a.shannon_entropy();
439        b.shannon_entropy();
440
441        let dist = histogram_distance(&a, &b);
442        assert_eq!(dist, 0.0);
443    }
444
445    #[test]
446    fn test_kl_divergence_identical() {
447        let a = Histogram::from_counts(&[100, 50, 25]);
448        a.shannon_entropy();
449
450        let div = histogram_kl_divergence(&a, &a);
451        assert!(div.abs() < 0.01, "kl = {}", div);
452    }
453
454    #[test]
455    fn test_kl_divergence_missing_symbol() {
456        let a = Histogram::from_counts(&[100, 50, 25]);
457        let b = Histogram::from_counts(&[100, 50, 0]); // Missing symbol 2
458        a.shannon_entropy();
459        b.shannon_entropy();
460
461        let div = histogram_kl_divergence(&a, &b);
462        assert!(div.is_infinite(), "kl = {}", div);
463    }
464
465    #[test]
466    fn test_add_histogram() {
467        let mut a = Histogram::from_counts(&[10, 20]);
468        let b = Histogram::from_counts(&[5, 10, 15]);
469
470        a.add_histogram(&b);
471
472        assert_eq!(a.total_count, 60);
473        assert_eq!(a.counts[0], 15);
474        assert_eq!(a.counts[1], 30);
475        assert_eq!(a.counts[2], 15);
476    }
477
478    #[test]
479    fn test_max_symbol() {
480        let h = Histogram::from_counts(&[10, 20, 0, 5, 0, 0]);
481        assert_eq!(h.max_symbol(), 3);
482
483        let h2 = Histogram::from_counts(&[10]);
484        assert_eq!(h2.max_symbol(), 0);
485
486        let h3 = Histogram::new();
487        assert_eq!(h3.max_symbol(), 0);
488    }
489}