jxl_encoder/entropy_coding/
histogram.rs1use core::cell::Cell;
10
/// Histogram storage lengths are rounded up to a multiple of this value so
/// the `jxl_simd` entropy kernels can process counts in fixed-size chunks.
pub const HISTOGRAM_ROUNDING: usize = 8;

/// Minimum entropy distance (in bits) below which two histograms are not
/// considered distinct. Not referenced in this file — presumably consumed by
/// histogram clustering elsewhere; confirm against callers.
pub const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;
/// A symbol-count histogram used for entropy estimation during encoding.
///
/// The backing `counts` vector is kept zero-padded to a multiple of
/// [`HISTOGRAM_ROUNDING`] by the constructors and growth paths.
#[derive(Clone, Debug)]
pub struct Histogram {
    // Per-symbol occurrence counts; index = symbol value.
    pub counts: Vec<i32>,
    // Sum of all counts (see `condition()` for recomputation after fast_add).
    pub total_count: usize,
    // Cached Shannon entropy in bits, written by `shannon_entropy()` and
    // `set_cached_entropy()`. `Cell` allows updating through `&self`.
    entropy: Cell<f32>,
}
34
35impl Default for Histogram {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl Histogram {
42 pub fn new() -> Self {
44 Self {
45 counts: Vec::new(),
46 total_count: 0,
47 entropy: Cell::new(0.0),
48 }
49 }
50
51 pub fn with_capacity(length: usize) -> Self {
54 let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
55 Self {
56 counts: vec![0; rounded_len],
57 total_count: 0,
58 entropy: Cell::new(0.0),
59 }
60 }
61
62 pub fn from_counts(counts: &[i32]) -> Self {
64 let total: i32 = counts.iter().sum();
65 let rounded_len = div_ceil(counts.len(), HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
66 let mut result_counts = vec![0i32; rounded_len];
67 result_counts[..counts.len()].copy_from_slice(counts);
68
69 Self {
70 counts: result_counts,
71 total_count: total as usize,
72 entropy: Cell::new(0.0),
73 }
74 }
75
76 pub fn flat(length: usize, total_count: usize) -> Self {
78 let base = (total_count / length) as i32;
79 let remainder = total_count % length;
80
81 let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
82 let mut counts = vec![0i32; rounded_len];
83
84 for (i, count) in counts.iter_mut().enumerate().take(length) {
85 *count = base + if i < remainder { 1 } else { 0 };
86 }
87
88 Self {
89 counts,
90 total_count,
91 entropy: Cell::new(0.0),
92 }
93 }
94
95 pub fn clear(&mut self) {
97 self.counts.clear();
98 self.total_count = 0;
99 self.entropy.set(0.0);
100 }
101
102 pub fn add(&mut self, symbol: usize) {
104 self.ensure_capacity(symbol + 1);
105 self.counts[symbol] += 1;
106 self.total_count += 1;
107 }
108
109 pub fn ensure_capacity(&mut self, length: usize) {
111 let rounded_len = div_ceil(length, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
112 if self.counts.len() < rounded_len {
113 self.counts.resize(rounded_len, 0);
114 }
115 }
116
117 #[inline]
119 pub fn fast_add(&mut self, symbol: usize) {
120 debug_assert!(symbol < self.counts.len());
121 self.counts[symbol] += 1;
122 }
123
124 pub fn add_histogram(&mut self, other: &Histogram) {
126 if other.counts.len() > self.counts.len() {
127 self.counts.resize(other.counts.len(), 0);
128 }
129 for (i, &count) in other.counts.iter().enumerate() {
130 self.counts[i] += count;
131 }
132 self.total_count += other.total_count;
133 }
134
135 pub fn condition(&mut self) {
138 let mut last_nonzero: i32 = -1;
140 let mut total: i64 = 0;
141
142 for (i, &count) in self.counts.iter().enumerate() {
143 total += count as i64;
144 if count != 0 {
145 last_nonzero = i as i32;
146 }
147 }
148
149 let new_len = if last_nonzero >= 0 {
151 div_ceil((last_nonzero + 1) as usize, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING
152 } else {
153 0
154 };
155 self.counts.resize(new_len, 0);
156 self.total_count = total as usize;
157 }
158
159 pub fn shannon_entropy(&self) -> f32 {
170 if self.total_count == 0 {
171 self.entropy.set(0.0);
172 return 0.0;
173 }
174
175 let entropy = jxl_simd::shannon_entropy_bits(&self.counts, self.total_count);
176 self.entropy.set(entropy);
177 entropy
178 }
179
180 pub fn cached_entropy(&self) -> f32 {
183 self.entropy.get()
184 }
185
186 pub fn set_cached_entropy(&self, entropy: f32) {
188 self.entropy.set(entropy);
189 }
190
191 pub fn alphabet_size(&self) -> usize {
193 for i in (0..self.counts.len()).rev() {
194 if self.counts[i] > 0 {
195 return i + 1;
196 }
197 }
198 0
199 }
200
201 pub fn max_symbol(&self) -> usize {
203 if self.total_count == 0 {
204 return 0;
205 }
206 for i in (1..self.counts.len()).rev() {
207 if self.counts[i] > 0 {
208 return i;
209 }
210 }
211 0
212 }
213
214 pub fn is_empty(&self) -> bool {
216 self.total_count == 0
217 }
218
219 pub fn copy_from(&mut self, source: &Histogram) {
224 let src_len = source.counts.len();
225 if self.counts.len() < src_len {
226 self.counts.resize(src_len, 0);
227 }
228 self.counts[..src_len].copy_from_slice(&source.counts[..src_len]);
229 if self.counts.len() > src_len {
230 self.counts[src_len..].fill(0);
231 }
232 self.total_count = source.total_count;
233 self.entropy.set(source.cached_entropy());
234 }
235}
236
/// Reusable scratch space for [`histogram_distance_reuse`]: holds the
/// element-wise sum of two histograms' counts so repeated distance
/// computations avoid re-allocating.
pub struct DistanceScratch {
    // Combined-counts work buffer; grown on demand, never shrunk.
    combined_counts: Vec<i32>,
}

impl Default for DistanceScratch {
    /// Equivalent to [`DistanceScratch::new`].
    fn default() -> Self {
        DistanceScratch::new()
    }
}

impl DistanceScratch {
    /// Creates an empty scratch buffer; storage is allocated lazily.
    pub fn new() -> Self {
        DistanceScratch {
            combined_counts: Vec::new(),
        }
    }

    /// Ensures at least `len` entries exist, zero-filling newly added slots.
    /// Existing contents are preserved; callers overwrite the prefix they use.
    #[inline]
    fn ensure_capacity(&mut self, len: usize) {
        if len > self.combined_counts.len() {
            self.combined_counts.resize(len, 0);
        }
    }
}
267
268pub fn histogram_distance(a: &Histogram, b: &Histogram) -> f32 {
278 let mut scratch = DistanceScratch::new();
279 histogram_distance_reuse(a, b, &mut scratch)
280}
281
/// Entropy cost (in bits) of merging histograms `a` and `b`: the combined
/// histogram's entropy minus the entropies of the two parts. Small values
/// mean the histograms are similar enough to share one entropy code.
///
/// Returns 0.0 when either histogram is empty.
///
/// NOTE: reads `cached_entropy()` of both inputs — callers must have called
/// `shannon_entropy()` on `a` and `b` beforehand, otherwise the result is
/// based on stale (possibly zero) cached values.
pub fn histogram_distance_reuse(
    a: &Histogram,
    b: &Histogram,
    scratch: &mut DistanceScratch,
) -> f32 {
    if a.total_count == 0 || b.total_count == 0 {
        return 0.0;
    }

    let combined_total = a.total_count + b.total_count;
    let a_len = a.counts.len();
    let b_len = b.counts.len();
    let max_len = a_len.max(b_len);

    // Align the work buffer length to HISTOGRAM_ROUNDING for the jxl_simd
    // entropy kernel.
    let aligned_len = div_ceil(max_len, HISTOGRAM_ROUNDING) * HISTOGRAM_ROUNDING;
    scratch.ensure_capacity(aligned_len);
    let combined_counts = &mut scratch.combined_counts[..aligned_len];

    // Element-wise sum over the overlapping prefix.
    let min_len = a_len.min(b_len);
    for ((slot, &ac), &bc) in combined_counts[..min_len]
        .iter_mut()
        .zip(&a.counts[..min_len])
        .zip(&b.counts[..min_len])
    {
        *slot = ac + bc;
    }
    // Copy the tail of whichever histogram is longer (at most one branch runs).
    if a_len > min_len {
        combined_counts[min_len..a_len].copy_from_slice(&a.counts[min_len..a_len]);
    } else if b_len > min_len {
        combined_counts[min_len..b_len].copy_from_slice(&b.counts[min_len..b_len]);
    }
    // Zero the alignment padding: the reused scratch may hold stale data.
    if max_len < aligned_len {
        combined_counts[max_len..aligned_len].fill(0);
    }

    let combined_entropy = jxl_simd::shannon_entropy_bits(combined_counts, combined_total);

    combined_entropy - a.cached_entropy() - b.cached_entropy()
}
327
/// Extra coding cost (in bits) of encoding the symbols of `actual` with a
/// code built from the `coding` distribution: the cross-entropy of `actual`
/// under `coding`, minus `actual`'s own (cached) entropy.
///
/// Returns 0.0 for an empty `actual`, and `f32::INFINITY` when `coding` is
/// empty or lacks a symbol that `actual` uses (that symbol would be
/// uncodable).
///
/// NOTE: reads `actual.cached_entropy()` — callers must have called
/// `shannon_entropy()` on `actual` beforehand, otherwise the result is off by
/// the stale cached value.
pub fn histogram_kl_divergence(actual: &Histogram, coding: &Histogram) -> f32 {
    if actual.total_count == 0 {
        return 0.0;
    }
    if coding.total_count == 0 {
        return f32::INFINITY;
    }

    let coding_inv = 1.0 / coding.total_count as f32;
    let mut cost = 0.0f32;

    for (i, &count) in actual.counts.iter().enumerate() {
        if count > 0 {
            // Out-of-range indices in `coding` are treated as zero counts.
            let coding_count = coding.counts.get(i).copied().unwrap_or(0);
            if coding_count == 0 {
                // `actual` uses a symbol that `coding` cannot represent.
                return f32::INFINITY;
            }
            // Cross-entropy contribution: -count * log2(p_coding(symbol)).
            let coding_prob = coding_count as f32 * coding_inv;
            cost -= count as f32 * jxl_simd::fast_log2f(coding_prob);
        }
    }

    cost - actual.cached_entropy()
}
364
/// Ceiling division: the smallest `q` such that `q * b >= a` (requires
/// `b > 0`). Thin local alias over the standard-library implementation.
#[inline]
fn div_ceil(a: usize, b: usize) -> usize {
    usize::div_ceil(a, b)
}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_histogram_new() {
        let h = Histogram::new();
        assert!(h.is_empty());
        assert_eq!(h.total_count, 0);
        assert_eq!(h.alphabet_size(), 0);
    }

    #[test]
    fn test_histogram_from_counts() {
        let h = Histogram::from_counts(&[10, 20, 30]);
        assert_eq!(h.total_count, 60);
        assert_eq!(h.alphabet_size(), 3);
        assert!(!h.is_empty());
    }

    #[test]
    fn test_histogram_add() {
        let mut h = Histogram::new();
        h.add(0);
        h.add(0);
        h.add(5);

        assert_eq!(h.total_count, 3);
        assert_eq!(h.counts[0], 2);
        assert_eq!(h.counts[5], 1);
        assert_eq!(h.alphabet_size(), 6);
    }

    #[test]
    fn test_histogram_flat() {
        // 100 divides evenly over 4 symbols.
        let h = Histogram::flat(4, 100);
        assert_eq!(h.total_count, 100);
        assert_eq!(h.counts[0], 25);
        assert_eq!(h.counts[1], 25);
        assert_eq!(h.counts[2], 25);
        assert_eq!(h.counts[3], 25);
    }

    #[test]
    fn test_histogram_flat_remainder() {
        // 10 over 4 symbols: the first 10 % 4 = 2 symbols get one extra count.
        let h = Histogram::flat(4, 10);
        assert_eq!(h.total_count, 10);
        assert_eq!(h.counts[0], 3);
        assert_eq!(h.counts[1], 3);
        assert_eq!(h.counts[2], 2);
        assert_eq!(h.counts[3], 2);
    }

    #[test]
    fn test_histogram_condition() {
        let mut h = Histogram::with_capacity(100);
        h.fast_add(0);
        h.fast_add(0);
        h.fast_add(5);
        // condition() recomputes total_count (fast_add skips it) and trims
        // the trailing zero storage down to a single rounding block.
        h.condition();

        assert_eq!(h.total_count, 3);
        assert_eq!(h.counts.len(), HISTOGRAM_ROUNDING);
    }

    #[test]
    fn test_shannon_entropy_uniform() {
        // 400 symbols, 4 equiprobable values: 2 bits each = 800 bits total.
        let h = Histogram::from_counts(&[100, 100, 100, 100]);
        let entropy = h.shannon_entropy();
        assert!((entropy - 800.0).abs() < 0.01, "entropy = {}", entropy);
    }

    #[test]
    fn test_shannon_entropy_skewed() {
        // A single-symbol distribution carries no information.
        let h = Histogram::from_counts(&[100, 0, 0, 0]);
        let entropy = h.shannon_entropy();
        assert!((entropy - 0.0).abs() < 0.01, "entropy = {}", entropy);
    }

    #[test]
    fn test_shannon_entropy_binary() {
        // 100 symbols, 2 equiprobable values: 1 bit each = 100 bits total.
        let h = Histogram::from_counts(&[50, 50]);
        let entropy = h.shannon_entropy();
        assert!((entropy - 100.0).abs() < 0.01, "entropy = {}", entropy);
    }

    #[test]
    fn test_histogram_distance_identical() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        let b = Histogram::from_counts(&[100, 50, 25]);
        // Distance relies on cached entropies; prime them first.
        a.shannon_entropy();
        b.shannon_entropy();

        let dist = histogram_distance(&a, &b);
        assert!(dist.abs() < 0.01, "distance = {}", dist);
    }

    #[test]
    fn test_histogram_distance_different() {
        // Disjoint supports: merging costs one extra bit for each of the
        // 200 combined symbols.
        let a = Histogram::from_counts(&[100, 0, 0]);
        let b = Histogram::from_counts(&[0, 0, 100]);
        a.shannon_entropy();
        b.shannon_entropy();

        let dist = histogram_distance(&a, &b);
        assert!((dist - 200.0).abs() < 0.01, "distance = {}", dist);
    }

    #[test]
    fn test_histogram_distance_empty() {
        let a = Histogram::new();
        let b = Histogram::from_counts(&[100]);
        a.shannon_entropy();
        b.shannon_entropy();

        // An empty side short-circuits to zero distance.
        let dist = histogram_distance(&a, &b);
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_kl_divergence_identical() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        a.shannon_entropy();

        // Coding with the true distribution adds no extra cost.
        let div = histogram_kl_divergence(&a, &a);
        assert!(div.abs() < 0.01, "kl = {}", div);
    }

    #[test]
    fn test_kl_divergence_missing_symbol() {
        let a = Histogram::from_counts(&[100, 50, 25]);
        // `b` cannot encode symbol 2, which `a` uses.
        let b = Histogram::from_counts(&[100, 50, 0]);
        a.shannon_entropy();
        b.shannon_entropy();

        let div = histogram_kl_divergence(&a, &b);
        assert!(div.is_infinite(), "kl = {}", div);
    }

    #[test]
    fn test_add_histogram() {
        // `b` is longer than `a`; add_histogram must grow `a` to fit.
        let mut a = Histogram::from_counts(&[10, 20]);
        let b = Histogram::from_counts(&[5, 10, 15]);

        a.add_histogram(&b);

        assert_eq!(a.total_count, 60);
        assert_eq!(a.counts[0], 15);
        assert_eq!(a.counts[1], 30);
        assert_eq!(a.counts[2], 15);
    }

    #[test]
    fn test_max_symbol() {
        // Trailing zeros are ignored; highest non-zero index wins.
        let h = Histogram::from_counts(&[10, 20, 0, 5, 0, 0]);
        assert_eq!(h.max_symbol(), 3);

        let h2 = Histogram::from_counts(&[10]);
        assert_eq!(h2.max_symbol(), 0);

        let h3 = Histogram::new();
        assert_eq!(h3.max_symbol(), 0);
    }
}
549}