// jxl_encoder/entropy_coding/cluster.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Histogram clustering for entropy coding.
6//!
7//! Ported from libjxl `lib/jxl/enc_cluster.cc`.
8
9use alloc::collections::BinaryHeap;
10use core::cmp::Ordering;
11
12use super::histogram::{
13    DistanceScratch, Histogram, histogram_distance_reuse, histogram_kl_divergence,
14};
15use crate::error::{Error, Result};
16
/// Minimum distance threshold for creating distinct clusters.
///
/// During greedy cluster seeding, once the most distant remaining histogram
/// is closer than this to an existing cluster, no further clusters are added.
const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;

/// Maximum number of histogram clusters.
///
/// Hard cap applied on top of any caller-supplied limit in `cluster_histograms`.
pub const CLUSTERS_LIMIT: usize = 256;
/// Result of clustering histograms.
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// The clustered histograms, one per output cluster, in cluster-id order.
    pub histograms: Vec<Histogram>,
    /// Mapping from input index to cluster index; `symbols[i]` indexes `histograms`.
    pub symbols: Vec<u32>,
}
31
/// Clustering aggressiveness level.
///
/// Controls how much work `cluster_histograms` spends reducing the number
/// of output clusters (speed vs. compression trade-off).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ClusteringType {
    /// Only 4 clusters maximum (fastest encoding).
    Fastest,
    /// Default clustering (greedy seeding, no refinement).
    #[default]
    Fast,
    /// With pair merge refinement (best compression).
    Best,
}
43
/// Entropy coding method - affects header cost estimation for clustering.
///
/// Merge decisions compare header + data cost, and the header format differs
/// between the two coders, so the cost model must match the coder in use.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum EntropyType {
    /// Huffman prefix codes (used by libjxl-tiny, simpler header format).
    #[default]
    Huffman,
    /// ANS (Asymmetric Numeral Systems) - used by full libjxl, larger alphabet support.
    Ans,
}
53
54/// Fast k-means-like clustering.
55///
56/// Algorithm:
57/// 1. Start with largest histogram as first cluster
58/// 2. Repeatedly add most distant histogram as new cluster
59/// 3. Stop when max clusters reached or distance < threshold
60/// 4. Assign remaining histograms to nearest cluster
61///
62/// Matches libjxl's `FastClusterHistograms` function.
63pub fn fast_cluster_histograms(
64    input: &[Histogram],
65    max_histograms: usize,
66) -> Result<ClusterResult> {
67    fast_cluster_histograms_with_prev(input, max_histograms, &[])
68}
69
70/// Fast clustering with support for pre-existing histograms.
71///
72/// This is the full implementation matching libjxl's `FastClusterHistograms`.
73/// The `prev_histograms` are fixed clusters that new histograms can be assigned to,
74/// but won't be merged into.
75pub fn fast_cluster_histograms_with_prev(
76    input: &[Histogram],
77    max_histograms: usize,
78    prev_histograms: &[Histogram],
79) -> Result<ClusterResult> {
80    if input.is_empty() {
81        return Ok(ClusterResult {
82            histograms: prev_histograms.to_vec(),
83            symbols: Vec::new(),
84        });
85    }
86
87    let prev_count = prev_histograms.len();
88    let mut out: Vec<Histogram> = prev_histograms.to_vec();
89    out.reserve(max_histograms);
90    let mut dist_scratch = DistanceScratch::new();
91
92    // Initialize symbols to "unassigned" marker
93    let unassigned = max_histograms as u32;
94    let mut symbols = vec![unassigned; input.len()];
95
96    // Initialize distances to max (except empty histograms)
97    let mut dists = vec![f32::MAX; input.len()];
98
99    // Find largest histogram and compute entropies
100    let mut largest_idx = 0;
101    for (i, h) in input.iter().enumerate() {
102        if h.total_count == 0 {
103            // Empty histograms get assigned to cluster 0
104            symbols[i] = 0;
105            dists[i] = 0.0;
106            continue;
107        }
108        h.shannon_entropy(); // Compute and cache entropy
109        if h.total_count > input[largest_idx].total_count {
110            largest_idx = i;
111        }
112    }
113
114    // If there are previous histograms, compute their entropies and
115    // update distances using KL divergence
116    if prev_count > 0 {
117        for h in &out {
118            h.shannon_entropy();
119        }
120        for (i, dist) in dists.iter_mut().enumerate() {
121            if *dist == 0.0 {
122                continue;
123            }
124            for out_hist in out.iter().take(prev_count) {
125                let kl = histogram_kl_divergence(&input[i], out_hist);
126                *dist = dist.min(kl);
127            }
128        }
129        // Find the histogram with maximum distance (most different from prev)
130        if let Some((max_idx, &max_dist)) = dists
131            .iter()
132            .enumerate()
133            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
134            && max_dist > 0.0
135        {
136            largest_idx = max_idx;
137        }
138    }
139
140    // Main clustering loop
141    while out.len() < prev_count + max_histograms {
142        // Add the largest/most distant histogram as a new cluster
143        symbols[largest_idx] = out.len() as u32;
144        out.push(input[largest_idx].clone());
145        dists[largest_idx] = 0.0;
146
147        // Find next candidate: histogram with maximum distance
148        let mut new_largest_idx = 0;
149        for (i, h) in input.iter().enumerate() {
150            if dists[i] == 0.0 {
151                continue;
152            }
153            // Update distance using histogram distance to new cluster
154            let dist = histogram_distance_reuse(h, out.last().unwrap(), &mut dist_scratch);
155            dists[i] = dists[i].min(dist);
156            if dists[i] > dists[new_largest_idx] {
157                new_largest_idx = i;
158            }
159        }
160        largest_idx = new_largest_idx;
161
162        // Stop if distance is below threshold
163        if dists[largest_idx] < MIN_DISTANCE_FOR_DISTINCT {
164            break;
165        }
166    }
167
168    // Assign remaining histograms to nearest cluster
169    for i in 0..input.len() {
170        if symbols[i] != unassigned {
171            continue;
172        }
173
174        // Find best cluster
175        let mut best = 0;
176        let mut best_dist = f32::MAX;
177
178        for (j, out_hist) in out.iter().enumerate() {
179            let dist = if j < prev_count {
180                // Use KL divergence for previous histograms
181                histogram_kl_divergence(&input[i], out_hist)
182            } else {
183                // Use symmetric distance for new histograms
184                histogram_distance_reuse(&input[i], out_hist, &mut dist_scratch)
185            };
186
187            if dist < best_dist {
188                best = j;
189                best_dist = dist;
190            }
191        }
192
193        if best_dist >= f32::MAX {
194            return Err(Error::InvalidHistogram(format!(
195                "Failed to find cluster for histogram {}",
196                i
197            )));
198        }
199
200        // Merge into best cluster (only for non-previous histograms)
201        if best >= prev_count {
202            out[best].add_histogram(&input[i]);
203            out[best].shannon_entropy(); // Recompute entropy
204        }
205        symbols[i] = best as u32;
206    }
207
208    Ok(ClusterResult {
209        histograms: out,
210        symbols,
211    })
212}
213
/// Candidate merge of two clusters for the refinement priority queue.
///
/// `version` snapshots the clusters' generation counters at enqueue time so
/// stale entries can be discarded after either cluster has changed.
#[derive(Clone, Copy, Debug)]
struct HistogramPair {
    cost: f32,
    first: u32,
    second: u32,
    version: u32,
}

impl PartialEq for HistogramPair {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.first, self.second, self.version);
        let rhs = (other.first, other.second, other.version);
        self.cost == other.cost && lhs == rhs
    }
}

impl Eq for HistogramPair {}

impl PartialOrd for HistogramPair {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HistogramPair {
    /// Reversed lexicographic ordering so that `BinaryHeap` (a max-heap)
    /// pops the pair with the LOWEST cost first; ties break on the
    /// remaining fields for determinism.
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN costs compare as equal, matching the previous OrderedFloat behavior.
        other
            .cost
            .partial_cmp(&self.cost)
            .unwrap_or(Ordering::Equal)
            .then_with(|| other.first.cmp(&self.first))
            .then_with(|| other.second.cmp(&self.second))
            .then_with(|| other.version.cmp(&self.version))
    }
}
260
/// Wrapper for f32 that implements Ord for use in priority queues.
mod ordered_float {
    use core::cmp::Ordering;

    /// f32 newtype with a total order: NaN compares equal to everything.
    #[derive(Clone, Copy, Debug)]
    pub struct OrderedFloat(pub f32);

    impl PartialEq for OrderedFloat {
        fn eq(&self, other: &Self) -> bool {
            // IEEE-754 equality: NaN != NaN.
            self.0 == other.0
        }
    }

    impl Eq for OrderedFloat {}

    impl PartialOrd for OrderedFloat {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for OrderedFloat {
        fn cmp(&self, other: &Self) -> Ordering {
            // Incomparable values (NaN involved) are treated as equal.
            match self.0.partial_cmp(&other.0) {
                Some(ordering) => ordering,
                None => Ordering::Equal,
            }
        }
    }
}
288
289/// Estimate Huffman population cost for clustering merge decisions.
290///
291/// For Huffman coding, the key insight is that header cost savings from
292/// merging histograms are minimal compared to data cost increases.
293/// The simple/tiny clustering uses data-only cost (sum of count * depth)
294/// and produces good results.
295///
296/// This function computes data cost using actual Huffman code lengths,
297/// plus a small header penalty that scales with alphabet size to
298/// discourage creating very large merged alphabets.
299fn huffman_population_cost(h: &Histogram) -> f32 {
300    if h.total_count == 0 {
301        return 0.0;
302    }
303
304    let alphabet_size = h.alphabet_size();
305    if alphabet_size == 0 {
306        return 0.0;
307    }
308
309    // Compute ACTUAL Huffman data cost using real code lengths
310    let data_cost = compute_huffman_data_cost(h, alphabet_size);
311
312    // For merge decisions, we want to penalize large alphabets slightly
313    // because they require more complex tree serialization.
314    // But don't over-penalize - the data cost is the main factor.
315    //
316    // Count non-zero symbols to estimate header complexity
317    let non_zero_count = h
318        .counts
319        .iter()
320        .take(alphabet_size)
321        .filter(|&&c| c > 0)
322        .count();
323
324    // Small header penalty: 0.1 bits per non-zero symbol
325    // This is much smaller than data cost, so it only tips the balance
326    // when data costs are very close.
327    let header_penalty = (non_zero_count as f32) * 0.1;
328
329    data_cost + header_penalty
330}
331
332/// Compute actual Huffman data cost using real code lengths.
333///
334/// This builds a real Huffman tree and computes sum(count * depth),
335/// which is the exact number of bits needed to encode the data.
336fn compute_huffman_data_cost(h: &Histogram, alphabet_size: usize) -> f32 {
337    use super::huffman_tree::create_huffman_tree;
338
339    if alphabet_size == 0 {
340        return 0.0;
341    }
342
343    // Convert to u32 counts for create_huffman_tree
344    let counts: Vec<u32> = h
345        .counts
346        .iter()
347        .take(alphabet_size)
348        .map(|&c| c.max(0) as u32)
349        .collect();
350
351    // Check for empty or single-symbol histogram
352    let non_zero = counts.iter().filter(|&&c| c > 0).count();
353    if non_zero == 0 {
354        return 0.0;
355    }
356    if non_zero == 1 {
357        // Single symbol needs 1 bit per occurrence
358        return counts.iter().sum::<u32>() as f32;
359    }
360
361    // Build actual Huffman tree with depth limit 15
362    let depths = create_huffman_tree(&counts, 15);
363
364    // Compute data cost: sum(count * depth)
365    let mut cost = 0.0f32;
366    for (i, &count) in counts.iter().enumerate() {
367        if count > 0 && i < depths.len() {
368            cost += count as f32 * depths[i] as f32;
369        }
370    }
371
372    cost
373}
374
375/// Compute cost of encoding histogram A's data using a tree built for histogram B.
376///
377/// This is the key insight for correct merge cost estimation:
378/// When contexts are merged, BOTH original contexts use the merged tree,
379/// which is suboptimal for each individually.
380#[allow(dead_code)]
381fn compute_cross_coding_cost(data: &Histogram, tree: &Histogram, alphabet_size: usize) -> f32 {
382    use super::huffman_tree::create_huffman_tree;
383
384    if alphabet_size == 0 {
385        return 0.0;
386    }
387
388    // Build tree from 'tree' histogram
389    let tree_counts: Vec<u32> = tree
390        .counts
391        .iter()
392        .take(alphabet_size)
393        .map(|&c| c.max(0) as u32)
394        .collect();
395
396    let non_zero = tree_counts.iter().filter(|&&c| c > 0).count();
397    if non_zero == 0 {
398        return 0.0;
399    }
400
401    let depths = if non_zero == 1 {
402        vec![1u8; alphabet_size]
403    } else {
404        create_huffman_tree(&tree_counts, 15)
405    };
406
407    // Encode 'data' using depths from 'tree'
408    let mut cost = 0.0f32;
409    for (i, &count) in data.counts.iter().take(alphabet_size).enumerate() {
410        if count > 0 && i < depths.len() {
411            let depth = if depths[i] == 0 { 15 } else { depths[i] }; // Penalize symbols not in tree
412            cost += count.max(0) as f32 * depth as f32;
413        }
414    }
415
416    cost
417}
418
419/// Estimate ANS population cost (header + data bits).
420///
421/// This is a simplified version of libjxl's `Histogram::ANSPopulationCost()`.
422/// ANS uses a frequency table with log-scale precision, supporting larger alphabets.
423fn ans_population_cost(h: &Histogram) -> f32 {
424    if h.total_count == 0 {
425        return 0.0;
426    }
427
428    let alphabet_size = h.alphabet_size();
429    if alphabet_size <= 1 {
430        // Single symbol or empty: almost no header cost
431        return 0.0;
432    }
433
434    // Data cost (entropy)
435    let data_cost = h.cached_entropy();
436
437    // Header cost estimate: roughly 5 bits per symbol for frequency table
438    // ANS encodes frequencies using variable-length coding based on precision
439    // This is a rough approximation - actual cost depends on the shift parameter
440    let header_cost = (alphabet_size as f32) * 5.0;
441
442    data_cost + header_cost
443}
444
445/// Estimate population cost for a histogram based on entropy type.
446fn population_cost(h: &Histogram, entropy_type: EntropyType) -> f32 {
447    match entropy_type {
448        EntropyType::Huffman => huffman_population_cost(h),
449        EntropyType::Ans => ans_population_cost(h),
450    }
451}
452
/// Refine clusters by merging pairs that reduce total cost.
///
/// This implements the pair merge refinement from libjxl's `ClusterHistograms`
/// when `params.clustering == ClusteringType::Best`.
///
/// Greedy agglomerative merging: every beneficial pair is pushed into a
/// priority queue ordered by cost delta; the cheapest merge is applied,
/// stale queue entries are detected via per-cluster version counters, and
/// fresh candidate pairs are enqueued for the merged cluster. Finally dead
/// clusters are compacted out and `symbols` is rewritten.
///
/// The `entropy_type` parameter controls the cost model used for merge decisions:
/// - `EntropyType::Huffman`: Uses Huffman tree serialization cost model
/// - `EntropyType::Ans`: Uses ANS frequency table cost model
pub fn refine_clusters_by_merging(
    histograms: &mut Vec<Histogram>,
    symbols: &mut [u32],
    entropy_type: EntropyType,
) -> Result<()> {
    if histograms.is_empty() {
        return Ok(());
    }

    // Compute initial costs (shannon_entropy caches the entropy on the histogram)
    for h in histograms.iter() {
        h.shannon_entropy();
    }

    // Version tracking for invalidation: version 0 marks a dead (merged-away)
    // cluster; a merge bumps the survivor to a fresh version so queue entries
    // that referenced its old contents are skipped when popped.
    let mut version = vec![1u32; histograms.len()];
    let mut next_version = 2u32;

    // Renumbering map (for tracking merges): original cluster id -> surviving id
    let mut renumbering: Vec<u32> = (0..histograms.len() as u32).collect();

    // Create priority queue of pairs to merge; HistogramPair's Ord is reversed
    // so the most beneficial (most negative cost) pair pops first.
    let mut pairs_to_merge: BinaryHeap<HistogramPair> = BinaryHeap::new();

    // Reusable scratch histogram to avoid per-pair clone allocation
    let mut merged = Histogram::new();

    for i in 0..histograms.len() as u32 {
        for j in (i + 1)..histograms.len() as u32 {
            // Compute cost of merging (reuse scratch allocation)
            merged.copy_from(&histograms[i as usize]);
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[i as usize], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            // Only enqueue if merging is beneficial
            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    first: i,
                    second: j,
                    version: version[i as usize].max(version[j as usize]),
                });
            }
        }
    }

    // Process merges
    while let Some(pair) = pairs_to_merge.pop() {
        let first = pair.first as usize;
        let second = pair.second as usize;

        // Check if pair is still valid: either endpoint may have been merged
        // away (version 0) or changed (version bumped) since it was enqueued.
        let expected_version = version[first].max(version[second]);
        if pair.version != expected_version || version[first] == 0 || version[second] == 0 {
            continue;
        }

        // Merge second into first (copy into scratch to avoid borrow conflict)
        merged.copy_from(&histograms[second]);
        histograms[first].add_histogram(&merged);
        histograms[first].shannon_entropy();

        // Update renumbering so inputs pointing at `second` now point at `first`
        for item in renumbering.iter_mut() {
            if *item == pair.second {
                *item = pair.first;
            }
        }

        // Mark second as dead
        version[second] = 0;
        version[first] = next_version;
        next_version += 1;

        // Add new pairs with the merged histogram
        for j in 0..histograms.len() as u32 {
            if j == pair.first || version[j as usize] == 0 {
                continue;
            }

            merged.copy_from(&histograms[first]);
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[first], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    // Keep (first, second) sorted for deterministic tie-breaking
                    first: pair.first.min(j),
                    second: pair.first.max(j),
                    version: version[first].max(version[j as usize]),
                });
            }
        }
    }

    // Build reverse renumbering and compact surviving histograms to the front
    let mut reverse_renumbering = vec![u32::MAX; histograms.len()];
    let mut num_alive = 0u32;

    for i in 0..histograms.len() {
        if version[i] == 0 {
            continue;
        }
        if num_alive != i as u32 {
            histograms[num_alive as usize] = histograms[i].clone();
        }
        reverse_renumbering[i] = num_alive;
        num_alive += 1;
    }
    histograms.truncate(num_alive as usize);

    // Update symbols: original id -> surviving id -> compacted id
    for symbol in symbols.iter_mut() {
        let renumbered = renumbering[*symbol as usize];
        *symbol = reverse_renumbering[renumbered as usize];
    }

    Ok(())
}
592
593/// Reindex histograms so that symbols appear in increasing order.
594fn histogram_reindex(histograms: &mut Vec<Histogram>, prev_count: usize, symbols: &mut [u32]) {
595    use std::collections::HashMap;
596
597    let tmp = histograms.clone();
598    let mut new_index: HashMap<u32, u32> = HashMap::new();
599
600    // Previous histograms keep their indices
601    for i in 0..prev_count {
602        new_index.insert(i as u32, i as u32);
603    }
604
605    // Assign new indices in order of first appearance
606    let mut next_index = prev_count as u32;
607    for &symbol in symbols.iter() {
608        if let std::collections::hash_map::Entry::Vacant(e) = new_index.entry(symbol) {
609            e.insert(next_index);
610            histograms[next_index as usize] = tmp[symbol as usize].clone();
611            next_index += 1;
612        }
613    }
614
615    histograms.truncate(next_index as usize);
616
617    // Update symbols
618    for symbol in symbols.iter_mut() {
619        *symbol = new_index[symbol];
620    }
621}
622
623/// Full clustering pipeline.
624///
625/// Combines fast clustering with optional pair merge refinement.
626///
627/// # Arguments
628///
629/// * `clustering_type` - Controls clustering aggressiveness (Fastest/Fast/Best)
630/// * `entropy_type` - Controls cost model for merge decisions (Huffman/Ans)
631/// * `input` - Input histograms to cluster
632/// * `max_histograms` - Maximum number of output clusters
633pub fn cluster_histograms(
634    clustering_type: ClusteringType,
635    entropy_type: EntropyType,
636    input: &[Histogram],
637    max_histograms: usize,
638) -> Result<ClusterResult> {
639    let max_histograms = match clustering_type {
640        ClusteringType::Fastest => max_histograms.min(4),
641        _ => max_histograms,
642    };
643
644    let max_histograms = max_histograms.min(input.len()).min(CLUSTERS_LIMIT);
645
646    // Fast clustering
647    let mut result = fast_cluster_histograms(input, max_histograms)?;
648
649    // Pair merge refinement for Best quality
650    if clustering_type == ClusteringType::Best && !result.histograms.is_empty() {
651        refine_clusters_by_merging(&mut result.histograms, &mut result.symbols, entropy_type)?;
652    }
653
654    // Reindex for canonical form
655    histogram_reindex(&mut result.histograms, 0, &mut result.symbols);
656
657    Ok(result)
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a histogram from raw counts with its entropy pre-computed,
    /// mirroring the state the clustering entry points expect.
    fn make_histogram(counts: &[i32]) -> Histogram {
        let h = Histogram::from_counts(counts);
        h.shannon_entropy(); // Pre-compute entropy
        h
    }

    // A single input histogram must map to a single cluster with symbol 0.
    #[test]
    fn test_fast_cluster_single() {
        let input = vec![make_histogram(&[100, 50, 25])];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols.len(), 1);
        assert_eq!(result.symbols[0], 0);
    }

    #[test]
    fn test_fast_cluster_identical() {
        let input = vec![
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        // All identical histograms should cluster together
        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols, vec![0, 0, 0]);
    }

    #[test]
    fn test_fast_cluster_different() {
        // Disjoint single-symbol histograms are maximally dissimilar.
        let input = vec![
            make_histogram(&[100, 0, 0]),
            make_histogram(&[0, 100, 0]),
            make_histogram(&[0, 0, 100]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        // Very different histograms should be in separate clusters
        assert!(result.histograms.len() >= 2);
        // Each histogram should be assigned to some cluster
        assert!(
            result
                .symbols
                .iter()
                .all(|&s| (s as usize) < result.histograms.len())
        );
    }

    #[test]
    fn test_fast_cluster_max_limit() {
        // 10 mutually-disjoint histograms but only 4 clusters allowed.
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result = fast_cluster_histograms(&input, 4).unwrap();

        // Should not exceed max limit
        assert!(result.histograms.len() <= 4);
    }

    // Empty input yields empty output rather than an error.
    #[test]
    fn test_fast_cluster_empty() {
        let input: Vec<Histogram> = vec![];
        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(result.histograms.is_empty());
        assert!(result.symbols.is_empty());
    }

    #[test]
    fn test_fast_cluster_with_empty_histograms() {
        let input = vec![
            Histogram::new(), // Empty
            make_histogram(&[100, 50]),
            Histogram::new(), // Empty
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        // Empty histograms should be assigned to cluster 0
        assert!(!result.histograms.is_empty());
        assert_eq!(result.symbols[0], 0);
        assert_eq!(result.symbols[2], 0);
    }

    #[test]
    fn test_cluster_histograms_fastest() {
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result =
            cluster_histograms(ClusteringType::Fastest, EntropyType::Huffman, &input, 10).unwrap();

        // Fastest should limit to 4 clusters
        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_cluster_histograms_best_merges_huffman() {
        // Create two pairs of similar histograms
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]), // Similar to 0
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]), // Similar to 2
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Huffman, &input, 10).unwrap();

        // With best quality, similar histograms should be merged
        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_cluster_histograms_best_merges_ans() {
        // Create two pairs of similar histograms
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]), // Similar to 0
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]), // Similar to 2
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Ans, &input, 10).unwrap();

        // With best quality, similar histograms should be merged
        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_huffman_vs_ans_cost_model() {
        // Histogram with many symbols - ANS and Huffman should have different costs
        let mut counts = vec![0i32; 64];
        for (i, c) in counts.iter_mut().enumerate() {
            *c = (64 - i as i32) * 10; // Decreasing frequencies
        }
        let h = make_histogram(&counts);

        let huffman_cost = huffman_population_cost(&h);
        let ans_cost = ans_population_cost(&h);

        // Both should be positive
        assert!(huffman_cost > 0.0);
        assert!(ans_cost > 0.0);

        // For large alphabets, ANS header cost (5 bits/symbol) should be higher
        // than Huffman's nested tree (~3 bits/symbol + 30 bit overhead)
        // This test just verifies they're different - actual values depend on distribution
        assert!((huffman_cost - ans_cost).abs() > 1.0);
    }

    #[test]
    fn test_histogram_reindex() {
        let mut histograms = vec![
            make_histogram(&[100]),
            make_histogram(&[200]),
            make_histogram(&[300]),
        ];
        let mut symbols = vec![2, 0, 2, 1, 0];

        histogram_reindex(&mut histograms, 0, &mut symbols);

        // Symbols should now be renumbered in order of first appearance
        // 2 -> 0, 0 -> 1, 1 -> 2
        assert_eq!(symbols, vec![0, 1, 0, 2, 1]);
    }
}
847
#[test]
fn test_huffman_cost_disjoint_histograms() {
    // Histograms over non-overlapping symbol ranges: a merged tree must
    // cover both ranges, so the combined data cost can only increase.
    let lo = Histogram::from_counts(&[100, 50, 25, 0, 0, 0, 0, 0]);
    lo.shannon_entropy();

    let hi = Histogram::from_counts(&[0, 0, 0, 80, 40, 20, 0, 0]);
    hi.shannon_entropy();

    let mut merged = lo.clone();
    merged.add_histogram(&hi);
    merged.shannon_entropy();

    let delta =
        huffman_population_cost(&merged) - huffman_population_cost(&lo) - huffman_population_cost(&hi);

    // Disjoint histograms: merging increases data cost significantly
    assert!(delta >= 0.0, "Disjoint merge should not be beneficial");
}
869
#[test]
fn test_huffman_cost_identical_histograms() {
    // Two copies of the same distribution share an optimal tree, so the
    // merged cost should be (almost exactly) twice the individual cost.
    let counts = [100, 50, 25, 10, 0, 0, 0, 0];

    let a = Histogram::from_counts(&counts);
    a.shannon_entropy();

    let b = Histogram::from_counts(&counts);
    b.shannon_entropy();

    let mut merged = a.clone();
    merged.add_histogram(&b);
    merged.shannon_entropy();

    let cost_a = huffman_population_cost(&a);
    let cost_b = huffman_population_cost(&b);
    let cost_merged = huffman_population_cost(&merged);
    let delta = cost_merged - cost_a - cost_b;

    // Identical histograms use same Huffman tree, so merged cost = 2x individual
    assert!(
        delta.abs() < 1.0,
        "Identical histograms should have near-zero delta, got {}",
        delta
    );
}