capability_grower_configuration_comparison/
compare_tree_expansion_policy.rs1crate::ix!();
3
4pub trait CompareTreeExpansionPolicy {
5 fn compare_tree_expansion_policy(&self, other: &GrowerTreeConfiguration) -> CompareOutcome;
6}
7
8impl CompareTreeExpansionPolicy for GrowerTreeConfiguration {
9 fn compare_tree_expansion_policy(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
10 use TreeExpansionPolicy::*;
11 match (self.tree_expansion_policy(), other.tree_expansion_policy()) {
12 (Simple, Simple) => CompareOutcome::Exact,
13 (AlwaysAggregate, AlwaysAggregate) => CompareOutcome::Exact,
14 (AlwaysDispatch, AlwaysDispatch) => CompareOutcome::Exact,
15 (AlwaysLeafHolder, AlwaysLeafHolder) => CompareOutcome::Exact,
16
17 (Weighted(a), Weighted(b)) => {
18 let diff_sum =
19 (a.aggregator_weight() - b.aggregator_weight()).abs() +
20 (a.dispatch_weight() - b.dispatch_weight()).abs() +
21 (a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
22 if diff_sum < 0.01 {
23 CompareOutcome::Exact
24 } else if diff_sum < 1.0 {
25 let score = (1.0 - diff_sum).max(0.0);
26 CompareOutcome::Partial(score)
27 } else {
28 CompareOutcome::Incompatible
29 }
30 }
31
32 (WeightedWithLimits(a), WeightedWithLimits(b)) => {
33 let mut score = 1.0;
34 let diff_sum =
35 (a.aggregator_weight() - b.aggregator_weight()).abs() +
36 (a.dispatch_weight() - b.dispatch_weight()).abs() +
37 (a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
38 if diff_sum > 1.0 {
39 return CompareOutcome::Incompatible;
40 } else if diff_sum > 0.01 {
41 score *= (1.0 - diff_sum);
42 }
43 if a.aggregator_max_depth() != b.aggregator_max_depth() {
44 score *= 0.8;
45 }
46 if a.dispatch_max_depth() != b.dispatch_max_depth() {
47 score *= 0.8;
48 }
49 if a.leaf_min_depth() != b.leaf_min_depth() {
50 score *= 0.8;
51 }
52 if score < 0.2 {
53 CompareOutcome::Incompatible
54 } else if score < 1.0 {
55 CompareOutcome::Partial(score)
56 } else {
57 CompareOutcome::Exact
58 }
59 }
60
61 (DepthBased(a), DepthBased(b)) => {
62 if a.aggregator_start_level() == b.aggregator_start_level()
63 && a.leaf_start_level() == b.leaf_start_level()
64 {
65 CompareOutcome::Exact
66 } else {
67 CompareOutcome::Partial(0.5)
68 }
69 }
70
71 (Phased(a), Phased(b)) => {
72 if a.phases().len() != b.phases().len() {
73 return CompareOutcome::Partial(0.5);
74 }
75 let mut total_score = 1.0;
76 for (pa, pb) in a.phases().iter().zip(b.phases().iter()) {
77 if pa.start_level() != pb.start_level() {
78 total_score *= 0.8;
79 }
80 let w_diff =
81 (pa.aggregator_weight() - pb.aggregator_weight()).abs() +
82 (pa.dispatch_weight() - pb.dispatch_weight()).abs() +
83 (pa.leaf_weight() - pb.leaf_weight()).abs();
84 if w_diff > 0.5 {
85 total_score *= 0.5;
86 } else if w_diff > 0.1 {
87 total_score *= 0.8;
88 }
89 }
90 if total_score < 0.2 {
91 CompareOutcome::Incompatible
92 } else if total_score < 1.0 {
93 CompareOutcome::Partial(total_score)
94 } else {
95 CompareOutcome::Exact
96 }
97 }
98
99 (Scripted(a), Scripted(b)) => {
100 let la = a.levels();
101 let lb = b.levels();
102 if la.len() != lb.len() {
103 return CompareOutcome::Partial(0.5);
104 }
105 let mut score = 1.0;
106 for (lvl, wa) in la.iter() {
107 if let Some(wb) = lb.get(lvl) {
108 let diff_sum =
109 (wa.aggregator_chance() - wb.aggregator_chance()).abs() +
110 (wa.dispatch_chance() - wb.dispatch_chance()).abs() +
111 (wa.leaf_chance() - wb.leaf_chance()).abs();
112 if diff_sum > 0.5 {
113 score *= 0.5;
114 } else if diff_sum > 0.1 {
115 score *= 0.8;
116 }
117 } else {
118 score *= 0.5;
119 }
120 }
121 if score < 0.2 {
122 CompareOutcome::Incompatible
123 } else if score < 1.0 {
124 CompareOutcome::Partial(score)
125 } else {
126 CompareOutcome::Exact
127 }
128 }
129
130 _ => CompareOutcome::Partial(0.3),
132 }
133 }
134}
135
#[cfg(test)]
mod compare_tree_expansion_policy_tests {
    use super::*;

    /// Builds a default configuration that differs only in its expansion policy.
    fn config_with(policy: TreeExpansionPolicy) -> GrowerTreeConfiguration {
        GrowerTreeConfiguration::default()
            .to_builder()
            .tree_expansion_policy(policy)
            .build()
            .unwrap()
    }

    #[traced_test]
    fn simple_simple_exact() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn always_leaf_holder_exact() {
        let a = config_with(TreeExpansionPolicy::AlwaysLeafHolder);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn weighted_identical_exact() {
        let w = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3)
            .dispatch_weight(0.4)
            .leaf_holder_weight(0.3)
            .build()
            .unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w));
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn weighted_partial() {
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3)
            .dispatch_weight(0.4)
            .leaf_holder_weight(0.3)
            .build()
            .unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3)
            .dispatch_weight(0.35)
            .leaf_holder_weight(0.35)
            .build()
            .unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        let forward = a.compare_tree_expansion_policy(&b);
        match forward {
            CompareOutcome::Partial(_) => {}
            ref other => panic!("expected partial, got {:?}", other),
        }
        // The comparison only uses absolute differences, so it is symmetric.
        assert_eq!(forward, b.compare_tree_expansion_policy(&a));
    }

    #[traced_test]
    fn weighted_far_apart_incompatible() {
        // Weight differences sum to 0.7 + 0.7 + 0.0 = 1.4, past the 1.0 cutoff.
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.8)
            .dispatch_weight(0.1)
            .leaf_holder_weight(0.1)
            .build()
            .unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.1)
            .dispatch_weight(0.8)
            .leaf_holder_weight(0.1)
            .build()
            .unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Incompatible);
    }

    #[traced_test]
    fn different_variants_partial() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = config_with(TreeExpansionPolicy::AlwaysDispatch);
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Partial(0.3));
    }

    #[traced_test]
    fn depth_based_exact() {
        let db = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build()
            .unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db));
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn depth_based_start_level_mismatch_partial() {
        let db1 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build()
            .unwrap();
        let db2 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(3)
            .leaf_start_level(5)
            .build()
            .unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db1));
        let b = config_with(TreeExpansionPolicy::DepthBased(db2));
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Partial(0.5));
    }
}