capability_grower_configuration_comparison/
compare_level_skipping.rs1crate::ix!();
3
4pub trait CompareLevelSkipping {
5 fn compare_level_skipping(&self, other: &GrowerTreeConfiguration) -> CompareOutcome;
6}
7
8impl CompareLevelSkipping for GrowerTreeConfiguration {
9 fn compare_level_skipping(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
10 match (self.level_skipping(), other.level_skipping()) {
11 (None, None) => CompareOutcome::Exact,
12 (None, Some(_)) | (Some(_), None) => CompareOutcome::Partial(0.5),
13 (Some(a), Some(b)) => {
14 let va = a.leaf_probability_per_level();
15 let vb = b.leaf_probability_per_level();
16 if va.len() != vb.len() {
17 CompareOutcome::Incompatible
18 } else {
19 let mut score = 1.0;
20 for (x, y) in va.iter().zip(vb.iter()) {
21 let diff = (x - y).abs();
22 if diff > 0.3 {
23 score *= 0.5;
24 } else if diff > 0.1 {
25 score *= 0.8;
26 }
27 }
28 if score < 0.2 {
29 CompareOutcome::Incompatible
30 } else if score < 1.0 {
31 CompareOutcome::Partial(score)
32 } else {
33 CompareOutcome::Exact
34 }
35 }
36 }
37 }
38 }
39}
40
#[cfg(test)]
mod compare_level_skipping_tests {
    use super::*;

    #[traced_test]
    fn none_none() {
        let a = GrowerTreeConfiguration::default();
        let b = a.clone();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn one_sided_is_half_partial() {
        let ls = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 1.0])
            .build()
            .unwrap();
        let absent = GrowerTreeConfiguration::default();
        // Precondition: the default configuration carries no level skipping.
        assert!(absent.level_skipping().is_none());
        let present = absent
            .to_builder()
            .level_skipping(Some(ls))
            .build()
            .unwrap();
        // The one-sided case scores the same regardless of argument order.
        assert_eq!(
            absent.compare_level_skipping(&present),
            CompareOutcome::Partial(0.5)
        );
        assert_eq!(
            present.compare_level_skipping(&absent),
            CompareOutcome::Partial(0.5)
        );
    }

    #[traced_test]
    fn same_arrays_exact() {
        let ls = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 1.0])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls.clone()))
            .build()
            .unwrap();
        let b = a.clone();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn small_differences_within_tolerance_are_exact() {
        // Every level differs by roughly 0.05, which is within the 0.1 tolerance.
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 1.0])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.05, 0.25, 0.95])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn partial_same_length_different_values() {
        // Level 0 differs by exactly 0.1 (tolerated), level 1 by about 0.2
        // (one minor mismatch, x0.8), level 2 is identical.
        let c1 = vec![0.0, 0.3, 0.5];
        let c2 = vec![0.1, 0.1, 0.5];
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(c1)
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(c2)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Partial(0.8));
    }

    #[traced_test]
    fn two_major_mismatches_still_partial() {
        // Two levels differ by more than 0.3: 0.5 * 0.5 = 0.25, just above the 0.2 cutoff.
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.0, 0.5])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.9, 0.9, 0.5])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Partial(0.25));
    }

    #[traced_test]
    fn incompatible_low_score() {
        // Three major mismatches: 0.5^3 = 0.125, below the 0.2 cutoff.
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.0, 0.0])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.9, 0.9, 0.9])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Incompatible);
    }

    #[traced_test]
    fn incompatible_diff_len() {
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 0.8])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Incompatible);
    }
}