capability_grower_configuration_comparison/
compare_tree_expansion_policy.rs

1// ---------------- [ File: capability-grower-configuration-comparison/src/compare_tree_expansion_policy.rs ]
2crate::ix!();
3
4pub trait CompareTreeExpansionPolicy {
5    fn compare_tree_expansion_policy(&self, other: &GrowerTreeConfiguration) -> CompareOutcome;
6}
7
8impl CompareTreeExpansionPolicy for GrowerTreeConfiguration {
9    fn compare_tree_expansion_policy(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
10        use TreeExpansionPolicy::*;
11        match (self.tree_expansion_policy(), other.tree_expansion_policy()) {
12            (Simple, Simple) => CompareOutcome::Exact,
13            (AlwaysAggregate, AlwaysAggregate) => CompareOutcome::Exact,
14            (AlwaysDispatch, AlwaysDispatch) => CompareOutcome::Exact,
15            (AlwaysLeafHolder, AlwaysLeafHolder) => CompareOutcome::Exact,
16
17            (Weighted(a), Weighted(b)) => {
18                let diff_sum = 
19                    (a.aggregator_weight() - b.aggregator_weight()).abs() +
20                    (a.dispatch_weight() - b.dispatch_weight()).abs() +
21                    (a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
22                if diff_sum < 0.01 {
23                    CompareOutcome::Exact
24                } else if diff_sum < 1.0 {
25                    let score = (1.0 - diff_sum).max(0.0);
26                    CompareOutcome::Partial(score)
27                } else {
28                    CompareOutcome::Incompatible
29                }
30            }
31
32            (WeightedWithLimits(a), WeightedWithLimits(b)) => {
33                let mut score = 1.0;
34                let diff_sum =
35                    (a.aggregator_weight() - b.aggregator_weight()).abs() +
36                    (a.dispatch_weight() - b.dispatch_weight()).abs() +
37                    (a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
38                if diff_sum > 1.0 {
39                    return CompareOutcome::Incompatible;
40                } else if diff_sum > 0.01 {
41                    score *= (1.0 - diff_sum);
42                }
43                if a.aggregator_max_depth() != b.aggregator_max_depth() {
44                    score *= 0.8;
45                }
46                if a.dispatch_max_depth() != b.dispatch_max_depth() {
47                    score *= 0.8;
48                }
49                if a.leaf_min_depth() != b.leaf_min_depth() {
50                    score *= 0.8;
51                }
52                if score < 0.2 {
53                    CompareOutcome::Incompatible
54                } else if score < 1.0 {
55                    CompareOutcome::Partial(score)
56                } else {
57                    CompareOutcome::Exact
58                }
59            }
60
61            (DepthBased(a), DepthBased(b)) => {
62                if a.aggregator_start_level() == b.aggregator_start_level()
63                    && a.leaf_start_level() == b.leaf_start_level()
64                {
65                    CompareOutcome::Exact
66                } else {
67                    CompareOutcome::Partial(0.5)
68                }
69            }
70
71            (Phased(a), Phased(b)) => {
72                if a.phases().len() != b.phases().len() {
73                    return CompareOutcome::Partial(0.5);
74                }
75                let mut total_score = 1.0;
76                for (pa, pb) in a.phases().iter().zip(b.phases().iter()) {
77                    if pa.start_level() != pb.start_level() {
78                        total_score *= 0.8;
79                    }
80                    let w_diff = 
81                        (pa.aggregator_weight() - pb.aggregator_weight()).abs() +
82                        (pa.dispatch_weight() - pb.dispatch_weight()).abs() +
83                        (pa.leaf_weight() - pb.leaf_weight()).abs();
84                    if w_diff > 0.5 {
85                        total_score *= 0.5;
86                    } else if w_diff > 0.1 {
87                        total_score *= 0.8;
88                    }
89                }
90                if total_score < 0.2 {
91                    CompareOutcome::Incompatible
92                } else if total_score < 1.0 {
93                    CompareOutcome::Partial(total_score)
94                } else {
95                    CompareOutcome::Exact
96                }
97            }
98
99            (Scripted(a), Scripted(b)) => {
100                let la = a.levels();
101                let lb = b.levels();
102                if la.len() != lb.len() {
103                    return CompareOutcome::Partial(0.5);
104                }
105                let mut score = 1.0;
106                for (lvl, wa) in la.iter() {
107                    if let Some(wb) = lb.get(lvl) {
108                        let diff_sum = 
109                            (wa.aggregator_chance() - wb.aggregator_chance()).abs() +
110                            (wa.dispatch_chance() - wb.dispatch_chance()).abs() +
111                            (wa.leaf_chance() - wb.leaf_chance()).abs();
112                        if diff_sum > 0.5 {
113                            score *= 0.5;
114                        } else if diff_sum > 0.1 {
115                            score *= 0.8;
116                        }
117                    } else {
118                        score *= 0.5;
119                    }
120                }
121                if score < 0.2 {
122                    CompareOutcome::Incompatible
123                } else if score < 1.0 {
124                    CompareOutcome::Partial(score)
125                } else {
126                    CompareOutcome::Exact
127                }
128            }
129
130            // default fallback if variants differ
131            _ => CompareOutcome::Partial(0.3),
132        }
133    }
134}
135
#[cfg(test)]
mod compare_tree_expansion_policy_tests {
    use super::*;

    /// Builds a default configuration whose only customisation is `policy`.
    fn config_with(policy: TreeExpansionPolicy) -> GrowerTreeConfiguration {
        GrowerTreeConfiguration::default()
            .to_builder()
            .tree_expansion_policy(policy)
            .build()
            .unwrap()
    }

    #[traced_test]
    fn simple_simple_exact() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn always_leaf_holder_exact() {
        let a = config_with(TreeExpansionPolicy::AlwaysLeafHolder);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn weighted_identical_exact() {
        let w = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3).dispatch_weight(0.4).leaf_holder_weight(0.3)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w));
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn weighted_partial() {
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3).dispatch_weight(0.4).leaf_holder_weight(0.3)
            .build().unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3).dispatch_weight(0.35).leaf_holder_weight(0.35)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        // Summed weight difference is 0.0 + 0.05 + 0.05 = 0.1, so the score is ~0.9.
        match a.compare_tree_expansion_policy(&b) {
            CompareOutcome::Partial(score) => {
                assert!((score - 0.9).abs() < 1e-4, "expected ~0.9, got {}", score);
            }
            other => panic!("expected partial, got {:?}", other),
        }
    }

    #[traced_test]
    fn weighted_far_apart_incompatible() {
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.8).dispatch_weight(0.1).leaf_holder_weight(0.1)
            .build().unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.1).dispatch_weight(0.8).leaf_holder_weight(0.1)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        // Summed weight difference is 1.4, which is beyond the partial range.
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Incompatible);
    }

    #[traced_test]
    fn different_variants_partial() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = config_with(TreeExpansionPolicy::AlwaysDispatch);
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Partial(0.3));
    }

    #[traced_test]
    fn depth_based_exact() {
        let db = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db));
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn depth_based_mismatch_partial() {
        let db1 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build().unwrap();
        let db2 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(3)
            .leaf_start_level(5)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db1));
        let b = config_with(TreeExpansionPolicy::DepthBased(db2));
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Partial(0.5));
    }
}