crate::ix!();
/// Compares the weighted-branching settings of two [`GrowerTreeConfiguration`]s.
pub trait CompareWeightedBranching {
    /// Returns a [`CompareOutcome`] describing how closely the weighted-branching
    /// configuration of `other` matches that of `self`.
    fn compare_weighted_branching(
        &self,
        other: &GrowerTreeConfiguration,
    ) -> CompareOutcome;
}
impl CompareWeightedBranching for GrowerTreeConfiguration {
    /// Compares the `weighted_branching` settings of `self` and `other`.
    ///
    /// - Both absent: [`CompareOutcome::Exact`].
    /// - Exactly one present: `CompareOutcome::Partial(0.5)`.
    /// - Both present: the absolute differences of `mean` and `variance` are taken.
    ///   - Both differences zero: `Exact`.
    ///   - Both differences below 3: `Partial(1.0 - 0.1 * (mean_diff + var_diff))`,
    ///     clamped at 0.0.
    ///   - Otherwise: `Incompatible`.
    fn compare_weighted_branching(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
        // Each difference must be strictly below this to count as a partial match.
        const MAX_DIFF_EXCLUSIVE: i128 = 3;
        // Score deducted per unit of combined (mean + variance) difference.
        const PENALTY_PER_UNIT: f32 = 0.1;

        match (self.weighted_branching(), other.weighted_branching()) {
            (None, None) => CompareOutcome::Exact,
            (None, Some(_)) | (Some(_), None) => CompareOutcome::Partial(0.5),
            (Some(a), Some(b)) => {
                // Widen to i128 before subtracting. A narrower cast such as `as i32`
                // truncates values above i32::MAX, which makes distinct settings
                // compare as equal. i128 holds every 64-bit integer and the
                // difference of any two of them without overflow.
                // NOTE(review): assumes mean/variance are primitive numerics of at
                // most 64 bits — confirm against WeightedBranchingConfiguration.
                let mean_diff = (*a.mean() as i128 - *b.mean() as i128).abs();
                let var_diff = (*a.variance() as i128 - *b.variance() as i128).abs();

                if mean_diff == 0 && var_diff == 0 {
                    CompareOutcome::Exact
                } else if mean_diff < MAX_DIFF_EXCLUSIVE && var_diff < MAX_DIFF_EXCLUSIVE {
                    let penalty = (mean_diff + var_diff) as f32 * PENALTY_PER_UNIT;
                    // The current constants keep penalty <= 0.4. The clamp only
                    // guards against future changes to them.
                    let score = (1.0 - penalty).max(0.0);
                    CompareOutcome::Partial(score)
                } else {
                    CompareOutcome::Incompatible
                }
            }
        }
    }
}
#[cfg(test)]
mod compare_weighted_branching_tests {
    use super::*;

    /// Two configurations without weighted branching are an exact match.
    #[traced_test]
    fn both_none() {
        let a = GrowerTreeConfiguration::default();
        let b = a.clone();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Exact);
    }

    /// When only one side has weighted branching, the result is Partial(0.5)
    /// regardless of which side it is.
    #[traced_test]
    fn one_side_none_is_half_partial() {
        let w = WeightedBranchingConfigurationBuilder::default()
            .mean(4).variance(2).build().unwrap();
        let a = GrowerTreeConfiguration::default();
        // Precondition: the default configuration carries no weighted branching.
        assert!(a.weighted_branching().is_none());
        let b = a.to_builder().weighted_branching(Some(w)).build().unwrap();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Partial(0.5));
        assert_eq!(b.compare_weighted_branching(&a), CompareOutcome::Partial(0.5));
    }

    /// Identical mean and variance are an exact match.
    #[traced_test]
    fn same_mean_variance_exact() {
        let wb = WeightedBranchingConfigurationBuilder::default()
            .mean(4).variance(2).build().unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder().weighted_branching(Some(wb)).build().unwrap();
        let b = a.clone();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Exact);
    }

    /// A variance difference of 2 (mean equal) yields Partial(1.0 - 0.2) = 0.8.
    #[traced_test]
    fn partial_same_mean() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(5).variance(1).build().unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(5).variance(3).build().unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder().weighted_branching(Some(w1)).build().unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        match a.compare_weighted_branching(&b) {
            CompareOutcome::Partial(score) => {
                assert!((score - 0.8).abs() < 1e-6, "expected score 0.8, got {}", score)
            }
            other => panic!("expected partial, got {:?}", other),
        }
    }

    /// A difference of exactly 3 is outside the partial window (`< 3`).
    #[traced_test]
    fn diff_of_three_is_incompatible() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(2).variance(1).build().unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(5).variance(1).build().unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder().weighted_branching(Some(w1)).build().unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Incompatible);
    }

    /// Large mean and variance differences are incompatible.
    #[traced_test]
    fn incompatible() {
        let w1 = WeightedBranchingConfigurationBuilder::default()
            .mean(3).variance(1).build().unwrap();
        let w2 = WeightedBranchingConfigurationBuilder::default()
            .mean(5).variance(8).build().unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder().weighted_branching(Some(w1)).build().unwrap();
        let b = a.to_builder().weighted_branching(Some(w2)).build().unwrap();
        assert_eq!(a.compare_weighted_branching(&b), CompareOutcome::Incompatible);
    }
}