1use alloc::collections::BinaryHeap;
10use core::cmp::Ordering;
11
12use super::histogram::{Histogram, histogram_distance, histogram_kl_divergence};
13use crate::error::{Error, Result};
14
/// Minimum histogram distance an input must have from every already-selected
/// cluster centre before it is promoted to a new, distinct cluster.
const MIN_DISTANCE_FOR_DISTINCT: f32 = 48.0;

/// Hard upper bound on the number of clusters produced by any strategy.
pub const CLUSTERS_LIMIT: usize = 256;
20
/// Output of histogram clustering.
#[derive(Debug, Clone)]
pub struct ClusterResult {
    /// One representative (accumulated) histogram per cluster.
    pub histograms: Vec<Histogram>,
    /// For each input histogram, the index of its cluster in `histograms`.
    pub symbols: Vec<u32>,
}
29
/// Speed/quality trade-off selector for the clustering pass.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ClusteringType {
    /// Greedy clustering with the cluster budget capped very low.
    Fastest,
    /// Greedy clustering only (the default).
    #[default]
    Fast,
    /// Greedy clustering followed by cost-driven merge refinement.
    Best,
}
41
/// Entropy coder whose cost model guides cluster refinement.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum EntropyType {
    /// Huffman coding cost model (the default).
    #[default]
    Huffman,
    /// Asymmetric numeral systems cost model.
    Ans,
}
51
52pub fn fast_cluster_histograms(
62 input: &[Histogram],
63 max_histograms: usize,
64) -> Result<ClusterResult> {
65 fast_cluster_histograms_with_prev(input, max_histograms, &[])
66}
67
68pub fn fast_cluster_histograms_with_prev(
74 input: &[Histogram],
75 max_histograms: usize,
76 prev_histograms: &[Histogram],
77) -> Result<ClusterResult> {
78 if input.is_empty() {
79 return Ok(ClusterResult {
80 histograms: prev_histograms.to_vec(),
81 symbols: Vec::new(),
82 });
83 }
84
85 let prev_count = prev_histograms.len();
86 let mut out: Vec<Histogram> = prev_histograms.to_vec();
87 out.reserve(max_histograms);
88
89 let unassigned = max_histograms as u32;
91 let mut symbols = vec![unassigned; input.len()];
92
93 let mut dists = vec![f32::MAX; input.len()];
95
96 let mut largest_idx = 0;
98 for (i, h) in input.iter().enumerate() {
99 if h.total_count == 0 {
100 symbols[i] = 0;
102 dists[i] = 0.0;
103 continue;
104 }
105 h.shannon_entropy(); if h.total_count > input[largest_idx].total_count {
107 largest_idx = i;
108 }
109 }
110
111 if prev_count > 0 {
114 for h in &out {
115 h.shannon_entropy();
116 }
117 for (i, dist) in dists.iter_mut().enumerate() {
118 if *dist == 0.0 {
119 continue;
120 }
121 for out_hist in out.iter().take(prev_count) {
122 let kl = histogram_kl_divergence(&input[i], out_hist);
123 *dist = dist.min(kl);
124 }
125 }
126 if let Some((max_idx, &max_dist)) = dists
128 .iter()
129 .enumerate()
130 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
131 && max_dist > 0.0
132 {
133 largest_idx = max_idx;
134 }
135 }
136
137 while out.len() < prev_count + max_histograms {
139 symbols[largest_idx] = out.len() as u32;
141 out.push(input[largest_idx].clone());
142 dists[largest_idx] = 0.0;
143
144 let mut new_largest_idx = 0;
146 for (i, h) in input.iter().enumerate() {
147 if dists[i] == 0.0 {
148 continue;
149 }
150 let dist = histogram_distance(h, out.last().unwrap());
152 dists[i] = dists[i].min(dist);
153 if dists[i] > dists[new_largest_idx] {
154 new_largest_idx = i;
155 }
156 }
157 largest_idx = new_largest_idx;
158
159 if dists[largest_idx] < MIN_DISTANCE_FOR_DISTINCT {
161 break;
162 }
163 }
164
165 for i in 0..input.len() {
167 if symbols[i] != unassigned {
168 continue;
169 }
170
171 let mut best = 0;
173 let mut best_dist = f32::MAX;
174
175 for (j, out_hist) in out.iter().enumerate() {
176 let dist = if j < prev_count {
177 histogram_kl_divergence(&input[i], out_hist)
179 } else {
180 histogram_distance(&input[i], out_hist)
182 };
183
184 if dist < best_dist {
185 best = j;
186 best_dist = dist;
187 }
188 }
189
190 if best_dist >= f32::MAX {
191 return Err(Error::InvalidHistogram(format!(
192 "Failed to find cluster for histogram {}",
193 i
194 )));
195 }
196
197 if best >= prev_count {
199 out[best].add_histogram(&input[i]);
200 out[best].shannon_entropy(); }
202 symbols[i] = best as u32;
203 }
204
205 Ok(ClusterResult {
206 histograms: out,
207 symbols,
208 })
209}
210
/// Candidate merge of two clusters queued in a `BinaryHeap`; the reversed
/// `Ord` impl below makes the heap pop the cheapest (most negative cost)
/// pair first.
#[derive(Clone, Copy, Debug)]
struct HistogramPair {
    /// Cost delta of merging (`merged - individual`); negative means the
    /// merge saves bits.
    cost: f32,
    /// Index of the first cluster in the pair.
    first: u32,
    /// Index of the second cluster in the pair.
    second: u32,
    /// Version stamp used to detect and skip stale heap entries after merges.
    version: u32,
}
219
220impl PartialEq for HistogramPair {
221 fn eq(&self, other: &Self) -> bool {
222 self.cost == other.cost
223 && self.first == other.first
224 && self.second == other.second
225 && self.version == other.version
226 }
227}
228
229impl Eq for HistogramPair {}
230
231impl PartialOrd for HistogramPair {
232 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
233 Some(self.cmp(other))
234 }
235}
236
237impl Ord for HistogramPair {
238 fn cmp(&self, other: &Self) -> Ordering {
239 let self_tuple = (
242 ordered_float::OrderedFloat(self.cost),
243 self.first,
244 self.second,
245 self.version,
246 );
247 let other_tuple = (
248 ordered_float::OrderedFloat(other.cost),
249 other.first,
250 other.second,
251 other.version,
252 );
253 other_tuple.cmp(&self_tuple)
255 }
256}
257
/// Minimal stand-in for the `ordered_float` crate: an `f32` wrapper with a
/// total order, usable as an `Ord` sort/heap key.
mod ordered_float {
    use core::cmp::Ordering;

    /// `f32` ordered via IEEE-754 `totalOrder` (`f32::total_cmp`).
    #[derive(Clone, Copy, Debug)]
    pub struct OrderedFloat(pub f32);

    impl PartialEq for OrderedFloat {
        fn eq(&self, other: &Self) -> bool {
            // Derive equality from the total order so `Eq` and `Ord` agree
            // (the previous `==`-based eq said NaN != NaN while `cmp`
            // reported `Equal`, breaking the `Ord` contract).
            self.cmp(other) == Ordering::Equal
        }
    }

    impl Eq for OrderedFloat {}

    impl PartialOrd for OrderedFloat {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for OrderedFloat {
        fn cmp(&self, other: &Self) -> Ordering {
            // `total_cmp` is a genuine total order (NaN and signed zeros
            // included), unlike `partial_cmp(..).unwrap_or(Equal)` which
            // treated NaN as equal to every value and was intransitive.
            self.0.total_cmp(&other.0)
        }
    }
}
285
286fn huffman_population_cost(h: &Histogram) -> f32 {
297 if h.total_count == 0 {
298 return 0.0;
299 }
300
301 let alphabet_size = h.alphabet_size();
302 if alphabet_size == 0 {
303 return 0.0;
304 }
305
306 let data_cost = compute_huffman_data_cost(h, alphabet_size);
308
309 let non_zero_count = h
315 .counts
316 .iter()
317 .take(alphabet_size)
318 .filter(|&&c| c > 0)
319 .count();
320
321 let header_penalty = (non_zero_count as f32) * 0.1;
325
326 data_cost + header_penalty
327}
328
329fn compute_huffman_data_cost(h: &Histogram, alphabet_size: usize) -> f32 {
334 use super::huffman_tree::create_huffman_tree;
335
336 if alphabet_size == 0 {
337 return 0.0;
338 }
339
340 let counts: Vec<u32> = h
342 .counts
343 .iter()
344 .take(alphabet_size)
345 .map(|&c| c.max(0) as u32)
346 .collect();
347
348 let non_zero = counts.iter().filter(|&&c| c > 0).count();
350 if non_zero == 0 {
351 return 0.0;
352 }
353 if non_zero == 1 {
354 return counts.iter().sum::<u32>() as f32;
356 }
357
358 let depths = create_huffman_tree(&counts, 15);
360
361 let mut cost = 0.0f32;
363 for (i, &count) in counts.iter().enumerate() {
364 if count > 0 && i < depths.len() {
365 cost += count as f32 * depths[i] as f32;
366 }
367 }
368
369 cost
370}
371
372#[allow(dead_code)]
378fn compute_cross_coding_cost(data: &Histogram, tree: &Histogram, alphabet_size: usize) -> f32 {
379 use super::huffman_tree::create_huffman_tree;
380
381 if alphabet_size == 0 {
382 return 0.0;
383 }
384
385 let tree_counts: Vec<u32> = tree
387 .counts
388 .iter()
389 .take(alphabet_size)
390 .map(|&c| c.max(0) as u32)
391 .collect();
392
393 let non_zero = tree_counts.iter().filter(|&&c| c > 0).count();
394 if non_zero == 0 {
395 return 0.0;
396 }
397
398 let depths = if non_zero == 1 {
399 vec![1u8; alphabet_size]
400 } else {
401 create_huffman_tree(&tree_counts, 15)
402 };
403
404 let mut cost = 0.0f32;
406 for (i, &count) in data.counts.iter().take(alphabet_size).enumerate() {
407 if count > 0 && i < depths.len() {
408 let depth = if depths[i] == 0 { 15 } else { depths[i] }; cost += count.max(0) as f32 * depth as f32;
410 }
411 }
412
413 cost
414}
415
416fn ans_population_cost(h: &Histogram) -> f32 {
421 if h.total_count == 0 {
422 return 0.0;
423 }
424
425 let alphabet_size = h.alphabet_size();
426 if alphabet_size <= 1 {
427 return 0.0;
429 }
430
431 let data_cost = h.cached_entropy();
433
434 let header_cost = (alphabet_size as f32) * 5.0;
438
439 data_cost + header_cost
440}
441
442fn population_cost(h: &Histogram, entropy_type: EntropyType) -> f32 {
444 match entropy_type {
445 EntropyType::Huffman => huffman_population_cost(h),
446 EntropyType::Ans => ans_population_cost(h),
447 }
448}
449
/// Greedily merges clusters whenever the merged histogram is cheaper to
/// encode (per `entropy_type`'s cost model) than the two separately.
///
/// Uses a priority queue of candidate merges with per-cluster version
/// stamps so stale entries can be skipped instead of eagerly removed.
/// `histograms` is compacted in place; `symbols` is rewritten to the new
/// cluster indices.
///
/// # Errors
/// Never errors in the current implementation; the `Result` return keeps
/// the signature uniform with the other clustering entry points.
pub fn refine_clusters_by_merging(
    histograms: &mut Vec<Histogram>,
    symbols: &mut [u32],
    entropy_type: EntropyType,
) -> Result<()> {
    if histograms.is_empty() {
        return Ok(());
    }

    // Warm the entropy caches; the return value is intentionally discarded
    // (presumably memoized inside Histogram — cached_entropy() reads it).
    for h in histograms.iter() {
        h.shannon_entropy();
    }

    // version[i] == 0 marks a dead (merged-away) cluster; live clusters get
    // a fresh version after each merge so outdated heap entries are skipped.
    let mut version = vec![1u32; histograms.len()];
    let mut next_version = 2u32;

    // renumbering[i] = index of the live cluster that absorbed cluster i.
    let mut renumbering: Vec<u32> = (0..histograms.len() as u32).collect();

    // Max-heap with reversed Ord: pops the most beneficial (most negative
    // cost) merge first.
    let mut pairs_to_merge: BinaryHeap<HistogramPair> = BinaryHeap::new();

    // Seed the heap with every beneficial pairwise merge.
    for i in 0..histograms.len() as u32 {
        for j in (i + 1)..histograms.len() as u32 {
            let mut merged = histograms[i as usize].clone();
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[i as usize], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            // Only merges that save bits are worth queueing.
            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    first: i,
                    second: j,
                    version: version[i as usize].max(version[j as usize]),
                });
            }
        }
    }

    while let Some(pair) = pairs_to_merge.pop() {
        let first = pair.first as usize;
        let second = pair.second as usize;

        // Skip entries computed before either endpoint last changed, or
        // whose endpoint has already been merged away (version 0).
        let expected_version = version[first].max(version[second]);
        if pair.version != expected_version || version[first] == 0 || version[second] == 0 {
            continue;
        }

        // Fold `second` into `first`.
        let second_histo = histograms[second].clone();
        histograms[first].add_histogram(&second_histo);
        histograms[first].shannon_entropy();

        // Redirect every symbol of the absorbed cluster to the survivor.
        for item in renumbering.iter_mut() {
            if *item == pair.second {
                *item = pair.first;
            }
        }

        // Kill `second`; bump `first`'s version to invalidate old pairs.
        version[second] = 0;
        version[first] = next_version;
        next_version += 1;

        // Re-evaluate merging the grown `first` with every other live
        // cluster.
        for j in 0..histograms.len() as u32 {
            if j == pair.first || version[j as usize] == 0 {
                continue;
            }

            let mut merged = histograms[first].clone();
            merged.add_histogram(&histograms[j as usize]);
            merged.shannon_entropy();

            let merged_cost = population_cost(&merged, entropy_type);
            let individual_cost = population_cost(&histograms[first], entropy_type)
                + population_cost(&histograms[j as usize], entropy_type);

            let cost = merged_cost - individual_cost;

            if cost < 0.0 {
                pairs_to_merge.push(HistogramPair {
                    cost,
                    // Keep (first < second) ordering, matching the seeding
                    // loop's convention.
                    first: pair.first.min(j),
                    second: pair.first.max(j),
                    version: version[first].max(version[j as usize]),
                });
            }
        }
    }

    // Compact live clusters to the front, recording old -> new indices.
    let mut reverse_renumbering = vec![u32::MAX; histograms.len()];
    let mut num_alive = 0u32;

    for i in 0..histograms.len() {
        if version[i] == 0 {
            continue;
        }
        if num_alive != i as u32 {
            histograms[num_alive as usize] = histograms[i].clone();
        }
        reverse_renumbering[i] = num_alive;
        num_alive += 1;
    }
    histograms.truncate(num_alive as usize);

    // Map each symbol through its absorbing cluster, then to its compacted
    // position.
    for symbol in symbols.iter_mut() {
        let renumbered = renumbering[*symbol as usize];
        *symbol = reverse_renumbering[renumbered as usize];
    }

    Ok(())
}
586
587fn histogram_reindex(histograms: &mut Vec<Histogram>, prev_count: usize, symbols: &mut [u32]) {
589 use std::collections::HashMap;
590
591 let tmp = histograms.clone();
592 let mut new_index: HashMap<u32, u32> = HashMap::new();
593
594 for i in 0..prev_count {
596 new_index.insert(i as u32, i as u32);
597 }
598
599 let mut next_index = prev_count as u32;
601 for &symbol in symbols.iter() {
602 if let std::collections::hash_map::Entry::Vacant(e) = new_index.entry(symbol) {
603 e.insert(next_index);
604 histograms[next_index as usize] = tmp[symbol as usize].clone();
605 next_index += 1;
606 }
607 }
608
609 histograms.truncate(next_index as usize);
610
611 for symbol in symbols.iter_mut() {
613 *symbol = new_index[symbol];
614 }
615}
616
617pub fn cluster_histograms(
628 clustering_type: ClusteringType,
629 entropy_type: EntropyType,
630 input: &[Histogram],
631 max_histograms: usize,
632) -> Result<ClusterResult> {
633 let max_histograms = match clustering_type {
634 ClusteringType::Fastest => max_histograms.min(4),
635 _ => max_histograms,
636 };
637
638 let max_histograms = max_histograms.min(input.len()).min(CLUSTERS_LIMIT);
639
640 let mut result = fast_cluster_histograms(input, max_histograms)?;
642
643 if clustering_type == ClusteringType::Best && !result.histograms.is_empty() {
645 refine_clusters_by_merging(&mut result.histograms, &mut result.symbols, entropy_type)?;
646 }
647
648 histogram_reindex(&mut result.histograms, 0, &mut result.symbols);
650
651 Ok(result)
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a histogram from raw counts with its entropy cache warmed.
    fn make_histogram(counts: &[i32]) -> Histogram {
        let h = Histogram::from_counts(counts);
        h.shannon_entropy(); // warm the cached entropy
        h
    }

    #[test]
    fn test_fast_cluster_single() {
        // A lone histogram must form exactly one cluster.
        let input = vec![make_histogram(&[100, 50, 25])];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols.len(), 1);
        assert_eq!(result.symbols[0], 0);
    }

    #[test]
    fn test_fast_cluster_identical() {
        // Identical histograms should all collapse into a single cluster.
        let input = vec![
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
            make_histogram(&[100, 50, 25]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert_eq!(result.histograms.len(), 1);
        assert_eq!(result.symbols, vec![0, 0, 0]);
    }

    #[test]
    fn test_fast_cluster_different() {
        // Disjoint distributions should produce multiple clusters, and
        // every symbol must point at a valid cluster index.
        let input = vec![
            make_histogram(&[100, 0, 0]),
            make_histogram(&[0, 100, 0]),
            make_histogram(&[0, 0, 100]),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(result.histograms.len() >= 2);
        assert!(
            result
                .symbols
                .iter()
                .all(|&s| (s as usize) < result.histograms.len())
        );
    }

    #[test]
    fn test_fast_cluster_max_limit() {
        // The cluster count must never exceed the requested budget.
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result = fast_cluster_histograms(&input, 4).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_fast_cluster_empty() {
        // Empty input yields an empty result.
        let input: Vec<Histogram> = vec![];
        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(result.histograms.is_empty());
        assert!(result.symbols.is_empty());
    }

    #[test]
    fn test_fast_cluster_with_empty_histograms() {
        // Empty histograms are routed to cluster 0.
        let input = vec![
            Histogram::new(),
            make_histogram(&[100, 50]),
            Histogram::new(),
        ];

        let result = fast_cluster_histograms(&input, 10).unwrap();

        assert!(!result.histograms.is_empty());
        assert_eq!(result.symbols[0], 0);
        assert_eq!(result.symbols[2], 0);
    }

    #[test]
    fn test_cluster_histograms_fastest() {
        // `Fastest` caps the cluster budget at 4.
        let input: Vec<Histogram> = (0..10)
            .map(|i| {
                let mut counts = vec![0i32; 10];
                counts[i] = 100;
                make_histogram(&counts)
            })
            .collect();

        let result =
            cluster_histograms(ClusteringType::Fastest, EntropyType::Huffman, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_cluster_histograms_best_merges_huffman() {
        // Two pairs of near-identical histograms; `Best` should not exceed
        // the input count after merge refinement.
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]),
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]),
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Huffman, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_cluster_histograms_best_merges_ans() {
        // Same shape as the Huffman case, but with the ANS cost model.
        let input = vec![
            make_histogram(&[100, 50, 25, 10]),
            make_histogram(&[105, 52, 23, 11]),
            make_histogram(&[10, 25, 50, 100]),
            make_histogram(&[11, 23, 52, 105]),
        ];

        let result =
            cluster_histograms(ClusteringType::Best, EntropyType::Ans, &input, 10).unwrap();

        assert!(result.histograms.len() <= 4);
    }

    #[test]
    fn test_huffman_vs_ans_cost_model() {
        // A 64-symbol ramp distribution: the two cost models should both be
        // positive and measurably different from each other.
        let mut counts = vec![0i32; 64];
        for (i, c) in counts.iter_mut().enumerate() {
            *c = (64 - i as i32) * 10;
        }
        let h = make_histogram(&counts);

        let huffman_cost = huffman_population_cost(&h);
        let ans_cost = ans_population_cost(&h);

        assert!(huffman_cost > 0.0);
        assert!(ans_cost > 0.0);

        assert!((huffman_cost - ans_cost).abs() > 1.0);
    }

    #[test]
    fn test_histogram_reindex() {
        // symbols [2,0,2,1,0] in first-use order: 2->0, 0->1, 1->2.
        let mut histograms = vec![
            make_histogram(&[100]),
            make_histogram(&[200]),
            make_histogram(&[300]),
        ];
        let mut symbols = vec![2, 0, 2, 1, 0];

        histogram_reindex(&mut histograms, 0, &mut symbols);

        assert_eq!(symbols, vec![0, 1, 0, 2, 1]);
    }
}
841
/// Merging two histograms over disjoint alphabets must never look
/// beneficial to the Huffman cost model (the merged code serves both
/// alphabets, so its cost is at least the sum of the individual costs).
///
/// Gated with `#[cfg(test)]` for consistency with the `tests` module above,
/// so it is excluded from non-test builds.
#[cfg(test)]
#[test]
fn test_huffman_cost_disjoint_histograms() {
    let a = Histogram::from_counts(&[100, 50, 25, 0, 0, 0, 0, 0]);
    a.shannon_entropy();

    let b = Histogram::from_counts(&[0, 0, 0, 80, 40, 20, 0, 0]);
    b.shannon_entropy();

    let mut merged = a.clone();
    merged.add_histogram(&b);
    merged.shannon_entropy();

    let cost_a = huffman_population_cost(&a);
    let cost_b = huffman_population_cost(&b);
    let cost_merged = huffman_population_cost(&merged);
    let delta = cost_merged - cost_a - cost_b;

    assert!(delta >= 0.0, "Disjoint merge should not be beneficial");
}
863
/// Merging two identical histograms should be roughly cost-neutral: the
/// merged distribution has the same shape, so per-symbol code lengths are
/// unchanged and the data cost simply doubles.
///
/// Gated with `#[cfg(test)]` for consistency with the `tests` module above,
/// so it is excluded from non-test builds.
#[cfg(test)]
#[test]
fn test_huffman_cost_identical_histograms() {
    let a = Histogram::from_counts(&[100, 50, 25, 10, 0, 0, 0, 0]);
    a.shannon_entropy();

    let b = Histogram::from_counts(&[100, 50, 25, 10, 0, 0, 0, 0]);
    b.shannon_entropy();

    let mut merged = a.clone();
    merged.add_histogram(&b);
    merged.shannon_entropy();

    let cost_a = huffman_population_cost(&a);
    let cost_b = huffman_population_cost(&b);
    let cost_merged = huffman_population_cost(&merged);
    let delta = cost_merged - cost_a - cost_b;

    assert!(
        delta.abs() < 1.0,
        "Identical histograms should have near-zero delta, got {}",
        delta
    );
}