capability_grower_configuration_comparison/
compare_weighted_branching.rs1crate::ix!();
3
4pub trait CompareWeightedBranching {
5 fn compare_weighted_branching(&self, other: &GrowerTreeConfiguration) -> CompareOutcome;
6}
7
8impl CompareWeightedBranching for GrowerTreeConfiguration {
9 fn compare_weighted_branching(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
10 match (self.weighted_branching(), other.weighted_branching()) {
11 (None, None) => CompareOutcome::Exact,
12 (None, Some(_)) | (Some(_), None) => CompareOutcome::Partial(0.5),
13 (Some(a), Some(b)) => {
14 let mean_diff = (*a.mean() as i32 - *b.mean() as i32).abs();
15 let var_diff = (*a.variance() as i32 - *b.variance() as i32).abs();
16 if mean_diff == 0 && var_diff == 0 {
17 CompareOutcome::Exact
18 } else if mean_diff < 3 && var_diff < 3 {
19 let penalty = (mean_diff + var_diff) as f32 * 0.1;
20 let score = (1.0 - penalty).max(0.0);
21 CompareOutcome::Partial(score)
22 } else {
23 CompareOutcome::Incompatible
24 }
25 }
26 }
27 }
28}
29
#[cfg(test)]
mod compare_weighted_branching_tests {
    use super::*;

    /// Two configurations without weighted branching compare as an exact match.
    #[traced_test]
    fn both_none() {
        let a = GrowerTreeConfiguration::default();
        let b = a.clone();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Exact);
    }

    /// If only one side has weighted branching, the score is 0.5 in both directions.
    #[traced_test]
    fn one_side_none_is_half_partial() {
        let without = GrowerTreeConfiguration::default();
        assert!(
            without.weighted_branching().is_none(),
            "default configuration is expected to have no weighted branching"
        );

        let wb = WeightedBranchingConfigurationBuilder::default()
            .mean(4)
            .variance(2)
            .build()
            .unwrap();
        let with = GrowerTreeConfiguration::default()
            .to_builder()
            .weighted_branching(Some(wb))
            .build()
            .unwrap();

        assert_eq!(with.compare_weighted_branching(&without), CompareOutcome::Partial(0.5));
        assert_eq!(without.compare_weighted_branching(&with), CompareOutcome::Partial(0.5));
    }

    /// Identical mean and variance compare as an exact match.
    #[traced_test]
    fn same_mean_variance_exact() {
        let wb = WeightedBranchingConfigurationBuilder::default()
            .mean(4)
            .variance(2)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .weighted_branching(Some(wb))
            .build()
            .unwrap();
        let b = a.clone();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Exact);
    }

    /// A small variance difference yields a partial score of 1.0 - 0.1 * diff.
    #[traced_test]
    fn partial_same_mean() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(5)
            .variance(1)
            .build()
            .unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(5)
            .variance(3)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .weighted_branching(Some(w1))
            .build()
            .unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        match a.compare_weighted_branching(&b) {
            // mean_diff = 0, var_diff = 2 → 1.0 - 0.2 = 0.8
            CompareOutcome::Partial(score) => assert!(
                (score - 0.8).abs() < 1e-6,
                "expected score ≈ 0.8, got {}",
                score
            ),
            other => panic!("expected partial, got {:?}", other),
        }
    }

    /// The tolerance is strict: a difference of exactly 3 is incompatible.
    #[traced_test]
    fn difference_of_three_is_incompatible() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(4)
            .variance(2)
            .build()
            .unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(7)
            .variance(2)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .weighted_branching(Some(w1))
            .build()
            .unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Incompatible);
    }

    /// A large variance difference is incompatible even when the means are close.
    #[traced_test]
    fn incompatible() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(3)
            .variance(1)
            .build()
            .unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(5)
            .variance(8)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .weighted_branching(Some(w1))
            .build()
            .unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Incompatible);
    }
}