crate::ix!();
/// Compares the level-skipping portion of two [`GrowerTreeConfiguration`]s.
///
/// Implementors return a [`CompareOutcome`] describing how closely the
/// level-skipping settings of `self` agree with those of `other`.
pub trait CompareLevelSkipping {
    /// Returns how closely `self`'s level-skipping settings match `other`'s.
    fn compare_level_skipping(
        &self,
        other: &GrowerTreeConfiguration,
    ) -> CompareOutcome;
}
impl CompareLevelSkipping for GrowerTreeConfiguration {
    /// Scores how closely the level-skipping settings of `self` and `other` agree.
    ///
    /// * Neither side has level skipping → [`CompareOutcome::Exact`].
    /// * Exactly one side has it → `CompareOutcome::Partial(0.5)`.
    /// * Both have it, but their `leaf_probability_per_level` vectors differ in
    ///   length → [`CompareOutcome::Incompatible`].
    /// * Both have it with equal lengths → start from a score of `1.0` and,
    ///   level by level, multiply by `0.5` when the absolute difference exceeds
    ///   `0.3`, or by `0.8` when it exceeds `0.1` (but not `0.3`). A level whose
    ///   difference is NaN (e.g. a NaN entry on either side) is treated as a
    ///   large mismatch (`0.5`). The final score maps to `Incompatible` when it
    ///   is below `0.2`, to `Partial(score)` when below `1.0`, and to `Exact`
    ///   otherwise.
    fn compare_level_skipping(&self, other: &GrowerTreeConfiguration) -> CompareOutcome {
        match (self.level_skipping(), other.level_skipping()) {
            (None, None) => CompareOutcome::Exact,
            (None, Some(_)) | (Some(_), None) => CompareOutcome::Partial(0.5),
            (Some(a), Some(b)) => {
                let va = a.leaf_probability_per_level();
                let vb = b.leaf_probability_per_level();

                // Vectors of different depth cannot be aligned level-by-level.
                if va.len() != vb.len() {
                    return CompareOutcome::Incompatible;
                }

                let score = va.iter().zip(vb.iter()).fold(1.0, |acc, (x, y)| {
                    let diff = (x - y).abs();
                    // NaN compares false against every threshold; without the
                    // explicit check a NaN level would silently count as a match.
                    if diff.is_nan() || diff > 0.3 {
                        acc * 0.5
                    } else if diff > 0.1 {
                        acc * 0.8
                    } else {
                        acc
                    }
                });

                if score < 0.2 {
                    CompareOutcome::Incompatible
                } else if score < 1.0 {
                    CompareOutcome::Partial(score)
                } else {
                    CompareOutcome::Exact
                }
            }
        }
    }
}
#[cfg(test)]
mod compare_level_skipping_tests {
    use super::*;

    /// Two clones of the default configuration compare as an exact match
    /// (presumably both have no level skipping, per the test name).
    #[traced_test]
    fn none_none() {
        let a = GrowerTreeConfiguration::default();
        let b = a.clone();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Exact);
    }

    /// Identical per-level probability vectors compare as an exact match.
    #[traced_test]
    fn same_arrays_exact() {
        let ls = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 1.0])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls.clone()))
            .build()
            .unwrap();
        let b = a.clone();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Exact);
    }

    /// Level skipping configured on only one side yields `Partial(0.5)`,
    /// regardless of which side carries it.
    #[traced_test]
    fn one_side_none_is_partial_half() {
        let ls = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 1.0])
            .build()
            .unwrap();
        let plain = GrowerTreeConfiguration::default();
        let skipping = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls))
            .build()
            .unwrap();
        assert_eq!(
            plain.compare_level_skipping(&skipping),
            CompareOutcome::Partial(0.5)
        );
        assert_eq!(
            skipping.compare_level_skipping(&plain),
            CompareOutcome::Partial(0.5)
        );
    }

    /// Same length, one level differing by more than 0.1 → a partial match.
    #[traced_test]
    fn partial_same_length_different_values() {
        let c1 = vec![0.0, 0.3, 0.5];
        let c2 = vec![0.1, 0.1, 0.5];
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(c1)
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(c2)
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        match a.compare_level_skipping(&b) {
            CompareOutcome::Partial(_) => {}
            other => panic!("expected partial, got {:?}", other),
        }
    }

    /// Every level differs by more than 0.3, so the score is 0.5^3 = 0.125.
    /// That is below the 0.2 cut-off, so the result is incompatible even
    /// though the lengths match.
    #[traced_test]
    fn incompatible_when_every_level_differs_widely() {
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.1, 0.2])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.5, 0.6, 0.9])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Incompatible);
    }

    /// Probability vectors of different lengths are incompatible.
    #[traced_test]
    fn incompatible_diff_len() {
        let ls1 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2])
            .build()
            .unwrap();
        let ls2 = LevelSkippingConfigurationBuilder::default()
            .leaf_probability_per_level(vec![0.0, 0.2, 0.8])
            .build()
            .unwrap();
        let a = GrowerTreeConfiguration::default()
            .to_builder()
            .level_skipping(Some(ls1))
            .build()
            .unwrap();
        let b = a.to_builder().level_skipping(Some(ls2)).build().unwrap();
        assert_eq!(a.compare_level_skipping(&b), CompareOutcome::Incompatible);
    }
}