crate::ix!();
/// Compares the tree-expansion policy carried by two grower configurations.
pub trait CompareTreeExpansionPolicy {
    /// Returns a [`CompareOutcome`] describing how closely the tree-expansion
    /// policy of `self` matches the one held by `other`.
    fn compare_tree_expansion_policy(
        &self,
        other: &GrowerTreeConfiguration,
    ) -> CompareOutcome;
}
impl CompareTreeExpansionPolicy for GrowerTreeConfiguration {
fn compare_tree_expansion_policy(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
use TreeExpansionPolicy::*;
match (self.tree_expansion_policy(), other.tree_expansion_policy()) {
(Simple, Simple) => CompareOutcome::Exact,
(AlwaysAggregate, AlwaysAggregate) => CompareOutcome::Exact,
(AlwaysDispatch, AlwaysDispatch) => CompareOutcome::Exact,
(AlwaysLeafHolder, AlwaysLeafHolder) => CompareOutcome::Exact,
(Weighted(a), Weighted(b)) => {
let diff_sum =
(a.aggregator_weight() - b.aggregator_weight()).abs() +
(a.dispatch_weight() - b.dispatch_weight()).abs() +
(a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
if diff_sum < 0.01 {
CompareOutcome::Exact
} else if diff_sum < 1.0 {
let score = (1.0 - diff_sum).max(0.0);
CompareOutcome::Partial(score)
} else {
CompareOutcome::Incompatible
}
}
(WeightedWithLimits(a), WeightedWithLimits(b)) => {
let mut score = 1.0;
let diff_sum =
(a.aggregator_weight() - b.aggregator_weight()).abs() +
(a.dispatch_weight() - b.dispatch_weight()).abs() +
(a.leaf_holder_weight() - b.leaf_holder_weight()).abs();
if diff_sum > 1.0 {
return CompareOutcome::Incompatible;
} else if diff_sum > 0.01 {
score *= (1.0 - diff_sum);
}
if a.aggregator_max_depth() != b.aggregator_max_depth() {
score *= 0.8;
}
if a.dispatch_max_depth() != b.dispatch_max_depth() {
score *= 0.8;
}
if a.leaf_min_depth() != b.leaf_min_depth() {
score *= 0.8;
}
if score < 0.2 {
CompareOutcome::Incompatible
} else if score < 1.0 {
CompareOutcome::Partial(score)
} else {
CompareOutcome::Exact
}
}
(DepthBased(a), DepthBased(b)) => {
if a.aggregator_start_level() == b.aggregator_start_level()
&& a.leaf_start_level() == b.leaf_start_level()
{
CompareOutcome::Exact
} else {
CompareOutcome::Partial(0.5)
}
}
(Phased(a), Phased(b)) => {
if a.phases().len() != b.phases().len() {
return CompareOutcome::Partial(0.5);
}
let mut total_score = 1.0;
for (pa, pb) in a.phases().iter().zip(b.phases().iter()) {
if pa.start_level() != pb.start_level() {
total_score *= 0.8;
}
let w_diff =
(pa.aggregator_weight() - pb.aggregator_weight()).abs() +
(pa.dispatch_weight() - pb.dispatch_weight()).abs() +
(pa.leaf_weight() - pb.leaf_weight()).abs();
if w_diff > 0.5 {
total_score *= 0.5;
} else if w_diff > 0.1 {
total_score *= 0.8;
}
}
if total_score < 0.2 {
CompareOutcome::Incompatible
} else if total_score < 1.0 {
CompareOutcome::Partial(total_score)
} else {
CompareOutcome::Exact
}
}
(Scripted(a), Scripted(b)) => {
let la = a.levels();
let lb = b.levels();
if la.len() != lb.len() {
return CompareOutcome::Partial(0.5);
}
let mut score = 1.0;
for (lvl, wa) in la.iter() {
if let Some(wb) = lb.get(lvl) {
let diff_sum =
(wa.aggregator_chance() - wb.aggregator_chance()).abs() +
(wa.dispatch_chance() - wb.dispatch_chance()).abs() +
(wa.leaf_chance() - wb.leaf_chance()).abs();
if diff_sum > 0.5 {
score *= 0.5;
} else if diff_sum > 0.1 {
score *= 0.8;
}
} else {
score *= 0.5;
}
}
if score < 0.2 {
CompareOutcome::Incompatible
} else if score < 1.0 {
CompareOutcome::Partial(score)
} else {
CompareOutcome::Exact
}
}
_ => CompareOutcome::Partial(0.3),
}
}
}
#[cfg(test)]
mod compare_tree_expansion_policy_tests {
    use super::*;

    /// Builds a default configuration whose only customisation is `policy`.
    fn config_with(policy: TreeExpansionPolicy) -> GrowerTreeConfiguration {
        GrowerTreeConfiguration::default()
            .to_builder()
            .tree_expansion_policy(policy)
            .build()
            .unwrap()
    }

    #[traced_test]
    fn simple_simple_exact() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn always_leaf_holder_exact() {
        let a = config_with(TreeExpansionPolicy::AlwaysLeafHolder);
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn always_aggregate_and_always_dispatch_exact() {
        let agg = config_with(TreeExpansionPolicy::AlwaysAggregate);
        assert_eq!(agg.compare_tree_expansion_policy(&agg.clone()), CompareOutcome::Exact);

        let dispatch = config_with(TreeExpansionPolicy::AlwaysDispatch);
        assert_eq!(
            dispatch.compare_tree_expansion_policy(&dispatch.clone()),
            CompareOutcome::Exact
        );
    }

    #[traced_test]
    fn weighted_partial() {
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3).dispatch_weight(0.4).leaf_holder_weight(0.3)
            .build().unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.3).dispatch_weight(0.35).leaf_holder_weight(0.35)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        match a.compare_tree_expansion_policy(&b) {
            // Summed weight difference is 0.05 + 0.05 = 0.1, so the score is ~0.9.
            CompareOutcome::Partial(score) => {
                assert!((score - 0.9).abs() < 1e-4, "unexpected score {:?}", score)
            }
            other => panic!("expected partial, got {:?}", other),
        }
    }

    #[traced_test]
    fn weighted_far_apart_incompatible() {
        let w1 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.8).dispatch_weight(0.1).leaf_holder_weight(0.1)
            .build().unwrap();
        let w2 = WeightedNodeVariantPolicyBuilder::default()
            .aggregator_weight(0.1).dispatch_weight(0.1).leaf_holder_weight(0.8)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::Weighted(w1));
        let b = config_with(TreeExpansionPolicy::Weighted(w2));
        // Summed weight difference is 0.7 + 0.0 + 0.7 = 1.4, above the 1.0 cut-off.
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Incompatible);
    }

    #[traced_test]
    fn different_variants_partial() {
        let a = config_with(TreeExpansionPolicy::Simple);
        let b = config_with(TreeExpansionPolicy::AlwaysDispatch);
        // Cross-variant comparison must be partial in both directions.
        for outcome in [
            a.compare_tree_expansion_policy(&b),
            b.compare_tree_expansion_policy(&a),
        ] {
            match outcome {
                CompareOutcome::Partial(_) => {}
                other => panic!("expected partial, got {:?}", other),
            }
        }
    }

    #[traced_test]
    fn depth_based_exact() {
        let db = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db));
        let b = a.clone();
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Exact);
    }

    #[traced_test]
    fn depth_based_mismatch_partial() {
        let db1 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(2)
            .leaf_start_level(5)
            .build().unwrap();
        let db2 = DepthBasedNodeVariantPolicyBuilder::default()
            .aggregator_start_level(3)
            .leaf_start_level(5)
            .build().unwrap();
        let a = config_with(TreeExpansionPolicy::DepthBased(db1));
        let b = config_with(TreeExpansionPolicy::DepthBased(db2));
        assert_eq!(a.compare_tree_expansion_policy(&b), CompareOutcome::Partial(0.5));
    }
}