use crate::booster::core::PerpetualBooster;
use crate::data::Matrix;
use crate::drift::stats::chi2_contingency_2x2;
use std::collections::HashMap;
/// Scores how much the row-major `data` has drifted relative to the booster's trees.
///
/// Each row is routed through the booster's prediction trees. The resulting per-node visit
/// sets are handed to `calculate_drift_from_nodes`, which compares them with the counts
/// stored on the trees.
///
/// * `drift_type` – `"data"` scores every split, `"concept"` only splits with a leaf child;
///   any other value yields `0.0`.
/// * `parallel` – forwarded unchanged to `PerpetualBooster::predict_nodes`.
///
/// Returns `0.0` when the booster has no prediction trees.
pub fn calculate_drift(
    booster: &PerpetualBooster,
    data: &Matrix<f64>,
    drift_type: &str,
    parallel: bool,
) -> f32 {
    let trees = booster.get_prediction_trees();
    if trees.is_empty() {
        0.0
    } else {
        let visited_nodes = booster.predict_nodes(data, parallel);
        calculate_drift_from_nodes(trees, &visited_nodes, drift_type)
    }
}
/// Column-major counterpart of `calculate_drift`.
///
/// Routes the columns of `data` through the booster's prediction trees via
/// `PerpetualBooster::predict_nodes_columnar`, then scores drift exactly like the row-major
/// entry point.
///
/// * `drift_type` – `"data"` scores every split, `"concept"` only splits with a leaf child;
///   any other value yields `0.0`.
/// * `parallel` – forwarded unchanged to the node-prediction call.
///
/// Returns `0.0` when the booster has no prediction trees.
pub fn calculate_drift_columnar(
    booster: &PerpetualBooster,
    data: &crate::data::ColumnarMatrix<f64>,
    drift_type: &str,
    parallel: bool,
) -> f32 {
    let trees = booster.get_prediction_trees();
    if trees.is_empty() {
        0.0
    } else {
        let visited_nodes = booster.predict_nodes_columnar(data, parallel);
        calculate_drift_from_nodes(trees, &visited_nodes, drift_type)
    }
}
/// Split-selection mode parsed from the `drift_type` string accepted by the public API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DriftKind {
    /// `"data"`: every internal split is scored.
    Data,
    /// `"concept"`: only splits with at least one leaf child are scored.
    Concept,
}

impl DriftKind {
    /// Parses the public `drift_type` flag; any unrecognised string yields `None`.
    fn parse(drift_type: &str) -> Option<Self> {
        match drift_type {
            "data" => Some(Self::Data),
            "concept" => Some(Self::Concept),
            _ => None,
        }
    }

    /// Whether a split whose children have the given leaf flags is scored in this mode.
    fn includes_split(self, left_is_leaf: bool, right_is_leaf: bool) -> bool {
        match self {
            Self::Data => true,
            Self::Concept => left_is_leaf || right_is_leaf,
        }
    }
}

/// Counts, for each of the first `n_trees` trees, how many samples recorded each node id.
///
/// `node_preds[t][s]` is the set of node ids recorded for sample `s` in tree `t`. Because it
/// is a set, a sample contributes at most once per node. Entries beyond `n_trees` are
/// ignored, and trees without an entry get an empty map.
fn count_node_visits(
    n_trees: usize,
    node_preds: &[Vec<std::collections::HashSet<usize>>],
) -> Vec<HashMap<usize, usize>> {
    let mut counts: Vec<HashMap<usize, usize>> = vec![HashMap::new(); n_trees];
    for (tree_counts, tree_results) in counts.iter_mut().zip(node_preds) {
        for &node_idx in tree_results.iter().flatten() {
            *tree_counts.entry(node_idx).or_default() += 1;
        }
    }
    counts
}

/// Averages per-split chi-square drift statistics over all trees.
///
/// A split is eligible when all of the following hold:
/// * it is an internal node;
/// * both of its children exist in `tree.nodes`;
/// * both children carry `stats`.
///
/// For each eligible split, a 2x2 table is built from two sources:
/// * the children's stored counts (`stats.count`);
/// * the number of new samples recorded at each child in `node_preds`.
///
/// The table is scored with `chi2_contingency_2x2`. `drift_type` chooses which eligible
/// splits are scored: `"data"` scores all of them, and `"concept"` scores only those with
/// at least one leaf child.
///
/// Returns the mean of the non-NaN statistics as `f32`. Returns `0.0` when `drift_type` is
/// not recognised or when no split produced a statistic.
fn calculate_drift_from_nodes(
    trees: &[crate::tree::core::Tree],
    node_preds: &[Vec<std::collections::HashSet<usize>>],
    drift_type: &str,
) -> f32 {
    let kind = match DriftKind::parse(drift_type) {
        Some(kind) => kind,
        // An unknown drift type selects no splits, so skip all counting work.
        None => return 0.0,
    };
    let new_node_counts = count_node_visits(trees.len(), node_preds);

    let mut drift_stats: Vec<f64> = Vec::new();
    for (tree, counts) in trees.iter().zip(&new_node_counts) {
        let new_count = |idx: usize| counts.get(&idx).copied().unwrap_or(0) as f64;
        for node in tree.nodes.values() {
            if node.is_leaf {
                continue;
            }
            let (left_idx, right_idx) = (node.left_child, node.right_child);
            let (left, right) = match (tree.nodes.get(&left_idx), tree.nodes.get(&right_idx)) {
                (Some(left), Some(right)) => (left, right),
                _ => continue,
            };
            // Splits whose children lack recorded `stats` have no stored counts to compare
            // against.
            let (l_s, r_s) = match (left.stats.as_ref(), right.stats.as_ref()) {
                (Some(l_s), Some(r_s)) => (l_s, r_s),
                _ => continue,
            };
            if !kind.includes_split(left.is_leaf, right.is_leaf) {
                continue;
            }
            let train_l = l_s.count as f64;
            let train_r = r_s.count as f64;
            let new_l = new_count(left_idx);
            let new_r = new_count(right_idx);
            // Skip tables where either the stored counts or the new counts are all zero.
            if (train_l > 0.0 || train_r > 0.0) && (new_l > 0.0 || new_r > 0.0) {
                let stat = chi2_contingency_2x2(train_l, new_l, train_r, new_r);
                if !stat.is_nan() {
                    drift_stats.push(stat);
                }
            }
        }
    }

    if drift_stats.is_empty() {
        return 0.0;
    }
    // `tree.nodes` is walked in map order, which may be unspecified (e.g. if it is a
    // `HashMap`). Sorting first makes the floating-point sum independent of that order.
    drift_stats.sort_unstable_by(f64::total_cmp);
    let mean = drift_stats.iter().sum::<f64>() / drift_stats.len() as f64;
    mean as f32
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::ColumnarMatrix;
    use crate::objective::Objective;

    /// Feature column shared by the fitted-model tests: four rows, one feature.
    fn feature_values() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0]
    }

    /// Fits a squared-loss booster on `data`/`target`, optionally keeping per-node stats.
    fn fitted_booster(data: &Matrix<f64>, target: &[f64], save_node_stats: bool) -> PerpetualBooster {
        let mut booster = PerpetualBooster::default();
        booster.cfg.objective = Objective::SquaredLoss;
        booster.cfg.save_node_stats = save_node_stats;
        booster.fit(data, target, None, None).unwrap();
        booster
    }

    #[test]
    fn test_calculate_drift() {
        let data_vec = feature_values();
        let data = Matrix::new(&data_vec, 4, 1);
        let booster = fitted_booster(&data, &[0.0, 1.0, 0.0, 1.0], true);
        let drift = calculate_drift(&booster, &data, "data", false);
        assert!(drift >= 0.0);
    }

    #[test]
    fn test_calculate_drift_types() {
        let data_vec = feature_values();
        let data = Matrix::new(&data_vec, 4, 1);
        let booster = fitted_booster(&data, &[0.0, 1.0, 0.0, 1.0], true);
        let drift_data = calculate_drift(&booster, &data, "data", false);
        assert!(drift_data >= 0.0);
        let drift_concept = calculate_drift(&booster, &data, "concept", false);
        assert!(drift_concept >= 0.0);
        let drift_invalid = calculate_drift(&booster, &data, "invalid", false);
        assert_eq!(drift_invalid, 0.0);
    }

    #[test]
    fn test_calculate_drift_empty() {
        let booster = PerpetualBooster::default();
        let data_vec = vec![1.0, 2.0];
        let data = Matrix::new(&data_vec, 2, 1);
        let drift = calculate_drift(&booster, &data, "data", false);
        assert_eq!(drift, 0.0);
    }

    #[test]
    fn test_calculate_drift_columnar() {
        let data_vec = feature_values();
        let data = Matrix::new(&data_vec, 4, 1);
        let booster = fitted_booster(&data, &[0.0, 1.0, 0.0, 1.0], true);
        let columnar_data = ColumnarMatrix::new(vec![&data_vec], None, 4);
        let drift = calculate_drift_columnar(&booster, &columnar_data, "data", false);
        assert!(drift >= 0.0);
        // The same rows in either layout must yield the same drift score.
        let row_major = calculate_drift(&booster, &data, "data", false);
        assert_eq!(drift, row_major);
    }

    #[test]
    fn test_calculate_drift_no_stats() {
        let data_vec = feature_values();
        let data = Matrix::new(&data_vec, 4, 1);
        let booster = fitted_booster(&data, &[0.0, 1.0, 1.0, 0.0], false);
        let drift = calculate_drift(&booster, &data, "data", false);
        assert_eq!(drift, 0.0);
    }

    #[test]
    fn test_calculate_drift_from_nodes_without_trees() {
        assert_eq!(calculate_drift_from_nodes(&[], &[], "data"), 0.0);
        assert_eq!(calculate_drift_from_nodes(&[], &[], "concept"), 0.0);
        assert_eq!(calculate_drift_from_nodes(&[], &[], "unknown"), 0.0);
        // Node predictions for trees that do not exist are ignored.
        let extra = vec![vec![std::collections::HashSet::from([0usize])]];
        assert_eq!(calculate_drift_from_nodes(&[], &extra, "data"), 0.0);
    }
}