1use alloc::collections::BinaryHeap;
10use core::cmp::Ordering;
11
12use super::histogram::{
13 DistanceScratch, Histogram, histogram_distance_reuse, histogram_kl_divergence,
14};
15use crate::error::{Error, Result};
16
/// Greedy seeding stops once the farthest remaining histogram is closer than
/// this to an existing cluster (see `fast_cluster_histograms_with_prev`).
const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;

/// Hard upper bound on the number of clusters any clustering may produce.
pub const CLUSTERS_LIMIT: usize = 256;
22
/// Output of histogram clustering.
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// The clustered (possibly merged) histograms.
    pub histograms: Vec<Histogram>,
    /// For each input histogram, the index into `histograms` it was assigned to.
    pub symbols: Vec<u32>,
}
31
/// Trade-off between clustering speed and quality.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ClusteringType {
    /// Greedy clustering with the cluster budget additionally capped at 4.
    Fastest,
    /// Greedy clustering only (no merge refinement).
    #[default]
    Fast,
    /// Greedy clustering followed by cost-driven merge refinement.
    Best,
}
43
/// Entropy coder whose cost model is used when evaluating cluster merges.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum EntropyType {
    /// Cost model built from an actual Huffman tree (`huffman_population_cost`).
    #[default]
    Huffman,
    /// Cost model based on cached Shannon entropy plus a per-symbol header
    /// charge (`ans_population_cost`).
    Ans,
}
53
/// Convenience wrapper around [`fast_cluster_histograms_with_prev`] with no
/// previously-fixed histograms.
pub fn fast_cluster_histograms(
    input: &[Histogram],
    max_histograms: usize,
) -> Result<ClusterResult> {
    fast_cluster_histograms_with_prev(input, max_histograms, &[])
}
69
70pub fn fast_cluster_histograms_with_prev(
76 input: &[Histogram],
77 max_histograms: usize,
78 prev_histograms: &[Histogram],
79) -> Result<ClusterResult> {
80 if input.is_empty() {
81 return Ok(ClusterResult {
82 histograms: prev_histograms.to_vec(),
83 symbols: Vec::new(),
84 });
85 }
86
87 let prev_count = prev_histograms.len();
88 let mut out: Vec<Histogram> = prev_histograms.to_vec();
89 out.reserve(max_histograms);
90 let mut dist_scratch = DistanceScratch::new();
91
92 let unassigned = max_histograms as u32;
94 let mut symbols = vec![unassigned; input.len()];
95
96 let mut dists = vec![f32::MAX; input.len()];
98
99 let mut largest_idx = 0;
101 for (i, h) in input.iter().enumerate() {
102 if h.total_count == 0 {
103 symbols[i] = 0;
105 dists[i] = 0.0;
106 continue;
107 }
108 h.shannon_entropy(); if h.total_count > input[largest_idx].total_count {
110 largest_idx = i;
111 }
112 }
113
114 if prev_count > 0 {
117 for h in &out {
118 h.shannon_entropy();
119 }
120 for (i, dist) in dists.iter_mut().enumerate() {
121 if *dist == 0.0 {
122 continue;
123 }
124 for out_hist in out.iter().take(prev_count) {
125 let kl = histogram_kl_divergence(&input[i], out_hist);
126 *dist = dist.min(kl);
127 }
128 }
129 if let Some((max_idx, &max_dist)) = dists
131 .iter()
132 .enumerate()
133 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
134 && max_dist > 0.0
135 {
136 largest_idx = max_idx;
137 }
138 }
139
140 while out.len() < prev_count + max_histograms {
142 symbols[largest_idx] = out.len() as u32;
144 out.push(input[largest_idx].clone());
145 dists[largest_idx] = 0.0;
146
147 let mut new_largest_idx = 0;
149 for (i, h) in input.iter().enumerate() {
150 if dists[i] == 0.0 {
151 continue;
152 }
153 let dist = histogram_distance_reuse(h, out.last().unwrap(), &mut dist_scratch);
155 dists[i] = dists[i].min(dist);
156 if dists[i] > dists[new_largest_idx] {
157 new_largest_idx = i;
158 }
159 }
160 largest_idx = new_largest_idx;
161
162 if dists[largest_idx] < MIN_DISTANCE_FOR_DISTINCT {
164 break;
165 }
166 }
167
168 for i in 0..input.len() {
170 if symbols[i] != unassigned {
171 continue;
172 }
173
174 let mut best = 0;
176 let mut best_dist = f32::MAX;
177
178 for (j, out_hist) in out.iter().enumerate() {
179 let dist = if j < prev_count {
180 histogram_kl_divergence(&input[i], out_hist)
182 } else {
183 histogram_distance_reuse(&input[i], out_hist, &mut dist_scratch)
185 };
186
187 if dist < best_dist {
188 best = j;
189 best_dist = dist;
190 }
191 }
192
193 if best_dist >= f32::MAX {
194 return Err(Error::InvalidHistogram(format!(
195 "Failed to find cluster for histogram {}",
196 i
197 )));
198 }
199
200 if best >= prev_count {
202 out[best].add_histogram(&input[i]);
203 out[best].shannon_entropy(); }
205 symbols[i] = best as u32;
206 }
207
208 Ok(ClusterResult {
209 histograms: out,
210 symbols,
211 })
212}
213
/// Candidate merge of two clusters, ordered by `cost` on the merge heap.
#[derive(Clone, Copy, Debug)]
struct HistogramPair {
    /// Cost delta of merging `first` and `second` (negative = beneficial).
    cost: f32,
    /// Index of the lower-numbered cluster.
    first: u32,
    /// Index of the higher-numbered cluster.
    second: u32,
    /// Snapshot of `max(version[first], version[second])` at push time; used
    /// to detect and skip stale heap entries.
    version: u32,
}
222
223impl PartialEq for HistogramPair {
224 fn eq(&self, other: &Self) -> bool {
225 self.cost == other.cost
226 && self.first == other.first
227 && self.second == other.second
228 && self.version == other.version
229 }
230}
231
232impl Eq for HistogramPair {}
233
234impl PartialOrd for HistogramPair {
235 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
236 Some(self.cmp(other))
237 }
238}
239
240impl Ord for HistogramPair {
241 fn cmp(&self, other: &Self) -> Ordering {
242 let self_tuple = (
245 ordered_float::OrderedFloat(self.cost),
246 self.first,
247 self.second,
248 self.version,
249 );
250 let other_tuple = (
251 ordered_float::OrderedFloat(other.cost),
252 other.first,
253 other.second,
254 other.version,
255 );
256 other_tuple.cmp(&self_tuple)
258 }
259}
260
/// Minimal in-house stand-in for the `ordered-float` crate: wraps an `f32`
/// and supplies a total order by treating incomparable values (NaN) as equal.
mod ordered_float {
    use core::cmp::Ordering;

    #[derive(Clone, Copy, Debug)]
    pub struct OrderedFloat(pub f32);

    impl PartialEq for OrderedFloat {
        fn eq(&self, other: &Self) -> bool {
            self.0 == other.0
        }
    }

    impl Eq for OrderedFloat {}

    impl PartialOrd for OrderedFloat {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for OrderedFloat {
        fn cmp(&self, other: &Self) -> Ordering {
            match self.0.partial_cmp(&other.0) {
                Some(ordering) => ordering,
                // NaN is incomparable; treat it as equal so heap ordering
                // never panics.
                None => Ordering::Equal,
            }
        }
    }
}
288
289fn huffman_population_cost(h: &Histogram) -> f32 {
300 if h.total_count == 0 {
301 return 0.0;
302 }
303
304 let alphabet_size = h.alphabet_size();
305 if alphabet_size == 0 {
306 return 0.0;
307 }
308
309 let data_cost = compute_huffman_data_cost(h, alphabet_size);
311
312 let non_zero_count = h
318 .counts
319 .iter()
320 .take(alphabet_size)
321 .filter(|&&c| c > 0)
322 .count();
323
324 let header_penalty = (non_zero_count as f32) * 0.1;
328
329 data_cost + header_penalty
330}
331
332fn compute_huffman_data_cost(h: &Histogram, alphabet_size: usize) -> f32 {
337 use super::huffman_tree::create_huffman_tree;
338
339 if alphabet_size == 0 {
340 return 0.0;
341 }
342
343 let counts: Vec<u32> = h
345 .counts
346 .iter()
347 .take(alphabet_size)
348 .map(|&c| c.max(0) as u32)
349 .collect();
350
351 let non_zero = counts.iter().filter(|&&c| c > 0).count();
353 if non_zero == 0 {
354 return 0.0;
355 }
356 if non_zero == 1 {
357 return counts.iter().sum::<u32>() as f32;
359 }
360
361 let depths = create_huffman_tree(&counts, 15);
363
364 let mut cost = 0.0f32;
366 for (i, &count) in counts.iter().enumerate() {
367 if count > 0 && i < depths.len() {
368 cost += count as f32 * depths[i] as f32;
369 }
370 }
371
372 cost
373}
374
375#[allow(dead_code)]
381fn compute_cross_coding_cost(data: &Histogram, tree: &Histogram, alphabet_size: usize) -> f32 {
382 use super::huffman_tree::create_huffman_tree;
383
384 if alphabet_size == 0 {
385 return 0.0;
386 }
387
388 let tree_counts: Vec<u32> = tree
390 .counts
391 .iter()
392 .take(alphabet_size)
393 .map(|&c| c.max(0) as u32)
394 .collect();
395
396 let non_zero = tree_counts.iter().filter(|&&c| c > 0).count();
397 if non_zero == 0 {
398 return 0.0;
399 }
400
401 let depths = if non_zero == 1 {
402 vec![1u8; alphabet_size]
403 } else {
404 create_huffman_tree(&tree_counts, 15)
405 };
406
407 let mut cost = 0.0f32;
409 for (i, &count) in data.counts.iter().take(alphabet_size).enumerate() {
410 if count > 0 && i < depths.len() {
411 let depth = if depths[i] == 0 { 15 } else { depths[i] }; cost += count.max(0) as f32 * depth as f32;
413 }
414 }
415
416 cost
417}
418
419fn ans_population_cost(h: &Histogram) -> f32 {
424 if h.total_count == 0 {
425 return 0.0;
426 }
427
428 let alphabet_size = h.alphabet_size();
429 if alphabet_size <= 1 {
430 return 0.0;
432 }
433
434 let data_cost = h.cached_entropy();
436
437 let header_cost = (alphabet_size as f32) * 5.0;
441
442 data_cost + header_cost
443}
444
445fn population_cost(h: &Histogram, entropy_type: EntropyType) -> f32 {
447 match entropy_type {
448 EntropyType::Huffman => huffman_population_cost(h),
449 EntropyType::Ans => ans_population_cost(h),
450 }
451}
452
/// Greedily merges clusters whose combined coding cost is lower than the sum
/// of their individual costs, using a priority queue of candidate pairs
/// (lowest merge-cost delta popped first) with lazy invalidation via version
/// counters. `symbols` is rewritten in place to reference the surviving,
/// compacted histogram indices.
pub fn refine_clusters_by_merging(
    histograms: &mut Vec<Histogram>,
    symbols: &mut [u32],
    entropy_type: EntropyType,
) -> Result<()> {
    if histograms.is_empty() {
        return Ok(());
    }

    // Warm the entropy caches before any cost evaluation.
    for h in histograms.iter() {
        h.shannon_entropy();
    }

    // version[i] == 0 marks a dead (already merged-away) cluster; any other
    // value stamps the last modification so stale heap entries can be skipped.
    let mut version = vec![1u32; histograms.len()];
    let mut next_version = 2u32;

    // renumbering[i] = index of the cluster that i was merged into (initially i).
    let mut renumbering: Vec<u32> = (0..histograms.len() as u32).collect();

    let mut pairs_to_merge: BinaryHeap<HistogramPair> = BinaryHeap::new();

    // Scratch histogram reused for every trial merge.
    let mut merged = Histogram::new();

    // Seed the heap with every beneficial pair (negative cost delta).
    for i in 0..histograms.len() as u32 {
        for j in (i + 1)..histograms.len() as u32 {
            merged.copy_from(&histograms[i as usize]);
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[i as usize], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    first: i,
                    second: j,
                    version: version[i as usize].max(version[j as usize]),
                });
            }
        }
    }

    // Pop candidate merges best-first; entries whose clusters changed since
    // they were pushed (version mismatch) or that involve a dead cluster are
    // skipped lazily instead of being removed from the heap.
    while let Some(pair) = pairs_to_merge.pop() {
        let first = pair.first as usize;
        let second = pair.second as usize;

        let expected_version = version[first].max(version[second]);
        if pair.version != expected_version || version[first] == 0 || version[second] == 0 {
            continue;
        }

        // Fold `second` into `first` (copied via scratch so the mutable borrow
        // of histograms[first] doesn't conflict with reading histograms[second]).
        merged.copy_from(&histograms[second]);
        histograms[first].add_histogram(&merged);
        histograms[first].shannon_entropy();

        // Redirect every cluster previously mapped to `second` onto `first`.
        for item in renumbering.iter_mut() {
            if *item == pair.second {
                *item = pair.first;
            }
        }

        // Kill `second`; bump `first`'s version so stale pairs are rejected.
        version[second] = 0;
        version[first] = next_version;
        next_version += 1;

        // Re-evaluate merges between the grown `first` and all live clusters.
        for j in 0..histograms.len() as u32 {
            if j == pair.first || version[j as usize] == 0 {
                continue;
            }

            merged.copy_from(&histograms[first]);
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[first], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    first: pair.first.min(j),
                    second: pair.first.max(j),
                    version: version[first].max(version[j as usize]),
                });
            }
        }
    }

    // Compact the surviving clusters to the front; reverse_renumbering maps
    // old (pre-compaction) indices to their new positions.
    let mut reverse_renumbering = vec![u32::MAX; histograms.len()];
    let mut num_alive = 0u32;

    for i in 0..histograms.len() {
        if version[i] == 0 {
            continue;
        }
        if num_alive != i as u32 {
            histograms[num_alive as usize] = histograms[i].clone();
        }
        reverse_renumbering[i] = num_alive;
        num_alive += 1;
    }
    histograms.truncate(num_alive as usize);

    // Map each symbol through merge redirection, then compaction.
    for symbol in symbols.iter_mut() {
        let renumbered = renumbering[*symbol as usize];
        *symbol = reverse_renumbering[renumbered as usize];
    }

    Ok(())
}
592
593fn histogram_reindex(histograms: &mut Vec<Histogram>, prev_count: usize, symbols: &mut [u32]) {
595 use std::collections::HashMap;
596
597 let tmp = histograms.clone();
598 let mut new_index: HashMap<u32, u32> = HashMap::new();
599
600 for i in 0..prev_count {
602 new_index.insert(i as u32, i as u32);
603 }
604
605 let mut next_index = prev_count as u32;
607 for &symbol in symbols.iter() {
608 if let std::collections::hash_map::Entry::Vacant(e) = new_index.entry(symbol) {
609 e.insert(next_index);
610 histograms[next_index as usize] = tmp[symbol as usize].clone();
611 next_index += 1;
612 }
613 }
614
615 histograms.truncate(next_index as usize);
616
617 for symbol in symbols.iter_mut() {
619 *symbol = new_index[symbol];
620 }
621}
622
623pub fn cluster_histograms(
634 clustering_type: ClusteringType,
635 entropy_type: EntropyType,
636 input: &[Histogram],
637 max_histograms: usize,
638) -> Result<ClusterResult> {
639 let max_histograms = match clustering_type {
640 ClusteringType::Fastest => max_histograms.min(4),
641 _ => max_histograms,
642 };
643
644 let max_histograms = max_histograms.min(input.len()).min(CLUSTERS_LIMIT);
645
646 let mut result = fast_cluster_histograms(input, max_histograms)?;
648
649 if clustering_type == ClusteringType::Best && !result.histograms.is_empty() {
651 refine_clusters_by_merging(&mut result.histograms, &mut result.symbols, entropy_type)?;
652 }
653
654 histogram_reindex(&mut result.histograms, 0, &mut result.symbols);
656
657 Ok(result)
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a histogram from raw counts and warms its entropy cache, as
    // production callers do before clustering.
    fn make_histogram(counts: &[i32]) -> Histogram {
        let h = Histogram::from_counts(counts);
        h.shannon_entropy();
        h
    }

    // A single histogram clusters to itself.
    #[test]
    fn test_fast_cluster_single() {
        let input = vec![make_histogram(&[100, 50, 25])];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols.len(), 1);
        assert_eq!(result.symbols[0], 0);
    }

    // Identical histograms collapse into a single cluster.
    #[test]
    fn test_fast_cluster_identical() {
        let input = vec![
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols, vec![0, 0, 0]);
    }

    // Disjoint histograms stay in separate clusters, and every symbol is a
    // valid index into the returned histogram list.
    #[test]
    fn test_fast_cluster_different() {
        let input = vec![
            make_histogram(&[100, 0, 0]),
            make_histogram(&[0, 100, 0]),
            make_histogram(&[0, 0, 100]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(result.histograms.len() >= 2);
        assert!(
            result
                .symbols
                .iter()
                .all(|&s| (s as usize) < result.histograms.len())
        );
    }

    // The cluster budget is respected even when all inputs are distinct.
    #[test]
    fn test_fast_cluster_max_limit() {
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result = fast_cluster_histograms(&input, 4).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    // Empty input yields empty output without error.
    #[test]
    fn test_fast_cluster_empty() {
        let input: Vec<Histogram> = vec![];
        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(result.histograms.is_empty());
        assert!(result.symbols.is_empty());
    }

    // Histograms with no counts are assigned to cluster 0.
    #[test]
    fn test_fast_cluster_with_empty_histograms() {
        let input = vec![
            Histogram::new(),
            make_histogram(&[100, 50]),
            Histogram::new(),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(!result.histograms.is_empty());
        assert_eq!(result.symbols[0], 0);
        assert_eq!(result.symbols[2], 0);
    }

    // Fastest mode caps the number of clusters at 4 regardless of the
    // requested maximum.
    #[test]
    fn test_cluster_histograms_fastest() {
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result =
            cluster_histograms(ClusteringType::Fastest, EntropyType::Huffman, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    // Best mode with the Huffman cost model merges similar histogram pairs.
    #[test]
    fn test_cluster_histograms_best_merges_huffman() {
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]),
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]),
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Huffman, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    // Best mode with the ANS cost model merges similar histogram pairs.
    #[test]
    fn test_cluster_histograms_best_merges_ans() {
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]),
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]),
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Ans, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    // The Huffman and ANS cost models should produce different (but both
    // positive) costs for the same non-trivial distribution.
    #[test]
    fn test_huffman_vs_ans_cost_model() {
        let mut counts = vec![0i32; 64];
        for (i, c) in counts.iter_mut().enumerate() {
            *c = (64 - i as i32) * 10;
        }
        let h = make_histogram(&counts);

        let huffman_cost = huffman_population_cost(&h);
        let ans_cost = ans_population_cost(&h);

        assert!(huffman_cost > 0.0);
        assert!(ans_cost > 0.0);

        assert!((huffman_cost - ans_cost).abs() > 1.0);
    }

    // Reindexing renumbers clusters in order of first use in `symbols`.
    #[test]
    fn test_histogram_reindex() {
        let mut histograms = vec![
            make_histogram(&[100]),
            make_histogram(&[200]),
            make_histogram(&[300]),
        ];
        let mut symbols = vec![2, 0, 2, 1, 0];

        histogram_reindex(&mut histograms, 0, &mut symbols);

        assert_eq!(symbols, vec![0, 1, 0, 2, 1]);
    }
}
847
/// Two histograms over disjoint symbol ranges: merging them should never
/// reduce total Huffman cost.
#[test]
fn test_huffman_cost_disjoint_histograms() {
    let low = Histogram::from_counts(&[100, 50, 25, 0, 0, 0, 0, 0]);
    low.shannon_entropy();

    let high = Histogram::from_counts(&[0, 0, 0, 80, 40, 20, 0, 0]);
    high.shannon_entropy();

    let mut combined = low.clone();
    combined.add_histogram(&high);
    combined.shannon_entropy();

    // Cost delta of merging versus coding each side separately.
    let delta = huffman_population_cost(&combined)
        - huffman_population_cost(&low)
        - huffman_population_cost(&high);

    assert!(delta >= 0.0, "Disjoint merge should not be beneficial");
}
869
/// Merging a histogram with an identical copy should cost about the same as
/// coding both separately (delta near zero).
#[test]
fn test_huffman_cost_identical_histograms() {
    let first = Histogram::from_counts(&[100, 50, 25, 10, 0, 0, 0, 0]);
    first.shannon_entropy();

    let second = Histogram::from_counts(&[100, 50, 25, 10, 0, 0, 0, 0]);
    second.shannon_entropy();

    let mut combined = first.clone();
    combined.add_histogram(&second);
    combined.shannon_entropy();

    // Cost delta of merging versus coding each copy separately.
    let delta = huffman_population_cost(&combined)
        - huffman_population_cost(&first)
        - huffman_population_cost(&second);

    assert!(
        delta.abs() < 1.0,
        "Identical histograms should have near-zero delta, got {}",
        delta
    );
}