use crate::data::ColumnarMatrix;
use crate::tree::Tree;
use crate::{Matrix, utils::odds};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
impl Tree {
pub fn predict_contributions_row_probability_change(
&self,
row: &[f64],
contribs: &mut [f64],
missing: &f64,
current_logodds: f64,
) -> f64 {
contribs[contribs.len() - 1] +=
odds(current_logodds + self.nodes.get(&0).unwrap().weight_value as f64) - odds(current_logodds);
let mut node_idx = 0;
let mut lo = current_logodds;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
let node_odds = odds(node.weight_value as f64 + current_logodds);
if node.is_leaf {
lo += node.weight_value as f64;
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
let child_odds = odds(self.nodes.get(&child_idx).unwrap().weight_value as f64 + current_logodds);
let delta = child_odds - node_odds;
contribs[node.split_feature] += delta;
node_idx = child_idx;
}
lo
}
pub fn predict_contributions_row_midpoint_difference(&self, row: &[f64], contribs: &mut [f64], missing: &f64) {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
let child = &self.nodes.get(&child_idx).unwrap();
if node.has_missing_branch() && child_idx == node.missing_node {
node_idx = child_idx;
continue;
}
let other_child = if child_idx == node.left_child {
&self.nodes[&node.right_child]
} else {
&self.nodes[&node.left_child]
};
let mid = (child.weight_value * child.hessian_sum + other_child.weight_value * other_child.hessian_sum)
/ (child.hessian_sum + other_child.hessian_sum);
let delta = child.weight_value - mid;
contribs[node.split_feature] += delta as f64;
node_idx = child_idx;
}
}
pub fn predict_contributions_row_branch_difference(&self, row: &[f64], contribs: &mut [f64], missing: &f64) {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
if node.has_missing_branch() && child_idx == node.missing_node {
node_idx = child_idx;
continue;
}
let other_child = if child_idx == node.left_child {
&self.nodes[&node.right_child]
} else {
&self.nodes[&node.left_child]
};
let delta = self.nodes.get(&child_idx).unwrap().weight_value - other_child.weight_value;
contribs[node.split_feature] += delta as f64;
node_idx = child_idx;
}
}
pub fn predict_contributions_row_mode_difference(&self, row: &[f64], contribs: &mut [f64], missing: &f64) {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
if node.has_missing_branch() && child_idx == node.missing_node {
node_idx = child_idx;
continue;
}
let left_node = &self.nodes.get(&node.left_child).unwrap();
let right_node = &self.nodes.get(&node.right_child).unwrap();
let child_weight = self.nodes.get(&child_idx).unwrap().weight_value;
let delta = if left_node.hessian_sum == right_node.hessian_sum {
0.
} else if left_node.hessian_sum > right_node.hessian_sum {
child_weight - left_node.weight_value
} else {
child_weight - right_node.weight_value
};
contribs[node.split_feature] += delta as f64;
node_idx = child_idx;
}
}
pub fn predict_contributions_row_weight(&self, row: &[f64], contribs: &mut [f64], missing: &f64) {
contribs[contribs.len() - 1] += self.nodes.get(&0).unwrap().weight_value as f64;
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
let node_weight = self.nodes.get(&node_idx).unwrap().weight_value as f64;
let child_weight = self.nodes.get(&child_idx).unwrap().weight_value as f64;
let delta = child_weight - node_weight;
contribs[node.split_feature] += delta;
node_idx = child_idx
}
}
pub fn predict_contributions_weight(&self, data: &Matrix<f64>, contribs: &mut [f64], missing: &f64) {
data.index
.par_iter()
.zip(contribs.par_chunks_mut(data.cols + 1))
.for_each(|(row, contribs)| self.predict_contributions_row_weight(&data.get_row(*row), contribs, missing))
}
pub fn predict_contributions_row_average(
&self,
row: &[f64],
contribs: &mut [f64],
weights: &HashMap<usize, f64>,
missing: &f64,
) {
contribs[contribs.len() - 1] += weights[&0];
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
}
let child_idx = node.get_child_idx(&row[node.split_feature], missing);
let node_weight = weights[&node_idx];
let child_weight = weights[&child_idx];
let delta = child_weight - node_weight;
contribs[node.split_feature] += delta;
node_idx = child_idx
}
}
pub fn predict_contributions_average(
&self,
data: &Matrix<f64>,
contribs: &mut [f64],
weights: &HashMap<usize, f64>,
missing: &f64,
) {
data.index
.par_iter()
.zip(contribs.par_chunks_mut(data.cols + 1))
.for_each(|(row, contribs)| {
self.predict_contributions_row_average(&data.get_row(*row), contribs, weights, missing)
})
}
fn predict_row(&self, data: &Matrix<f64>, row: usize, missing: &f64) -> f64 {
let mut node_idx = 0;
loop {
let node = self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return node.weight_value as f64;
} else {
node_idx = node.get_child_idx(data.get(row, node.split_feature), missing);
}
}
}
pub fn predict_row_from_row_slice(&self, row: &[f64], missing: &f64) -> f64 {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return node.weight_value as f64;
} else {
node_idx = node.get_child_idx(&row[node.split_feature], missing);
}
}
}
fn predict_single_threaded(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
data.index.iter().map(|i| self.predict_row(data, *i, missing)).collect()
}
fn predict_parallel(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
data.index
.par_iter()
.map(|i| self.predict_row(data, *i, missing))
.collect()
}
pub fn predict(&self, data: &Matrix<f64>, parallel: bool, missing: &f64) -> Vec<f64> {
if parallel {
self.predict_parallel(data, missing)
} else {
self.predict_single_threaded(data, missing)
}
}
fn predict_weights_row(&self, data: &Matrix<f64>, row: usize, missing: &f64) -> [f32; 5] {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return node.stats.as_ref().map_or([node.weight_value; 5], |s| s.weights);
} else {
node_idx = node.get_child_idx(data.get(row, node.split_feature), missing);
}
}
}
pub fn predict_weights(&self, data: &Matrix<f64>, parallel: bool, missing: &f64) -> Vec<[f32; 5]> {
if parallel {
data.index
.par_iter()
.map(|i| self.predict_weights_row(data, *i, missing))
.collect()
} else {
data.index
.iter()
.map(|i| self.predict_weights_row(data, *i, missing))
.collect()
}
}
fn predict_nodes_row(&self, data: &Matrix<f64>, row: usize, missing: &f64) -> HashSet<usize> {
let mut node_idx = 0;
let mut v = HashSet::new();
v.insert(node_idx);
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
break;
} else {
node_idx = node.get_child_idx(data.get(row, node.split_feature), missing);
v.insert(node_idx);
}
}
v
}
fn predict_nodes_single_threaded(&self, data: &Matrix<f64>, missing: &f64) -> Vec<HashSet<usize>> {
data.index
.iter()
.map(|i| self.predict_nodes_row(data, *i, missing))
.collect()
}
fn predict_nodes_parallel(&self, data: &Matrix<f64>, missing: &f64) -> Vec<HashSet<usize>> {
data.index
.par_iter()
.map(|i| self.predict_nodes_row(data, *i, missing))
.collect()
}
pub fn predict_nodes(&self, data: &Matrix<f64>, parallel: bool, missing: &f64) -> Vec<HashSet<usize>> {
if parallel {
self.predict_nodes_parallel(data, missing)
} else {
self.predict_nodes_single_threaded(data, missing)
}
}
fn predict_row_columnar(&self, data: &ColumnarMatrix<f64>, row: usize, missing: &f64) -> f64 {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return node.weight_value as f64;
} else {
let val = if data.is_valid(row, node.split_feature) {
data.get(row, node.split_feature)
} else {
missing
};
node_idx = node.get_child_idx(val, missing);
}
}
}
fn predict_single_threaded_columnar(&self, data: &ColumnarMatrix<f64>, missing: &f64) -> Vec<f64> {
data.index
.iter()
.map(|i| self.predict_row_columnar(data, *i, missing))
.collect()
}
fn predict_parallel_columnar(&self, data: &ColumnarMatrix<f64>, missing: &f64) -> Vec<f64> {
data.index
.par_iter()
.map(|i| self.predict_row_columnar(data, *i, missing))
.collect()
}
pub fn predict_columnar(&self, data: &ColumnarMatrix<f64>, parallel: bool, missing: &f64) -> Vec<f64> {
if parallel {
self.predict_parallel_columnar(data, missing)
} else {
self.predict_single_threaded_columnar(data, missing)
}
}
fn predict_weights_row_columnar(&self, data: &ColumnarMatrix<f64>, row: usize, missing: &f64) -> [f32; 5] {
let mut node_idx = 0;
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return node.stats.as_ref().map_or([node.weight_value; 5], |s| s.weights);
} else {
let val = if data.is_valid(row, node.split_feature) {
data.get(row, node.split_feature)
} else {
missing
};
node_idx = node.get_child_idx(val, missing);
}
}
}
pub fn predict_weights_columnar(&self, data: &ColumnarMatrix<f64>, parallel: bool, missing: &f64) -> Vec<[f32; 5]> {
if parallel {
data.index
.par_iter()
.map(|i| self.predict_weights_row_columnar(data, *i, missing))
.collect()
} else {
data.index
.iter()
.map(|i| self.predict_weights_row_columnar(data, *i, missing))
.collect()
}
}
fn predict_nodes_row_columnar(&self, data: &ColumnarMatrix<f64>, row: usize, missing: &f64) -> HashSet<usize> {
let mut node_idx = 0;
let mut set = HashSet::new();
set.insert(0);
loop {
let node = &self.nodes.get(&node_idx).unwrap();
if node.is_leaf {
return set;
} else {
let val = if data.is_valid(row, node.split_feature) {
data.get(row, node.split_feature)
} else {
missing
};
node_idx = node.get_child_idx(val, missing);
set.insert(node_idx);
}
}
}
fn predict_nodes_single_threaded_columnar(&self, data: &ColumnarMatrix<f64>, missing: &f64) -> Vec<HashSet<usize>> {
data.index
.iter()
.map(|i| self.predict_nodes_row_columnar(data, *i, missing))
.collect()
}
fn predict_nodes_parallel_columnar(&self, data: &ColumnarMatrix<f64>, missing: &f64) -> Vec<HashSet<usize>> {
data.index
.par_iter()
.map(|i| self.predict_nodes_row_columnar(data, *i, missing))
.collect()
}
pub fn predict_nodes_columnar(
&self,
data: &ColumnarMatrix<f64>,
parallel: bool,
missing: &f64,
) -> Vec<HashSet<usize>> {
if parallel {
self.predict_nodes_parallel_columnar(data, missing)
} else {
self.predict_nodes_single_threaded_columnar(data, missing)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;
    use crate::node::Node;
    use std::collections::HashMap;

    /// Depth-1 stump splitting feature 0 at 0.5.
    /// Node 1 (left leaf): weight 0.1, hessian 10.
    /// Node 2 (right leaf): weight 0.2, hessian 20.
    fn create_mock_tree() -> Tree {
        let mut tree = Tree::new();
        tree.nodes.insert(
            0,
            Node {
                num: 0,
                weight_value: 0.0,
                hessian_sum: 30.0,
                split_value: 0.5,
                split_feature: 0,
                split_gain: 0.0,
                missing_node: 1,
                left_child: 1,
                right_child: 2,
                is_leaf: false,
                parent_node: 0,
                left_cats: None,
                stats: None,
            },
        );
        tree.nodes.insert(
            1,
            Node {
                num: 1,
                weight_value: 0.1,
                hessian_sum: 10.0,
                split_value: 0.0,
                split_feature: 0,
                split_gain: 0.0,
                missing_node: 1,
                left_child: 0,
                right_child: 0,
                is_leaf: true,
                parent_node: 0,
                left_cats: None,
                stats: None,
            },
        );
        tree.nodes.insert(
            2,
            Node {
                num: 2,
                weight_value: 0.2,
                hessian_sum: 20.0,
                split_value: 0.0,
                split_feature: 0,
                split_gain: 0.0,
                missing_node: 2,
                left_child: 0,
                right_child: 0,
                is_leaf: true,
                parent_node: 0,
                left_cats: None,
                stats: None,
            },
        );
        tree.n_leaves = 2;
        tree
    }

    #[test]
    fn test_tree_predict_row() {
        let tree = create_mock_tree();
        let data = Matrix::new(&[0.1, 0.6, 0.0, 0.0], 2, 2);
        let missing = f64::NAN;
        // Feature 0 is 0.1 for row 0 (goes left) and 0.6 for row 1 (goes right).
        assert_eq!(tree.predict_row(&data, 0, &missing), 0.1f32 as f64);
        assert_eq!(tree.predict_row(&data, 1, &missing), 0.2f32 as f64);
    }

    #[test]
    fn test_tree_predict() {
        let tree = create_mock_tree();
        let data = Matrix::new(&[0.1, 0.6, 0.0, 0.0], 2, 2);
        let preds = tree.predict(&data, false, &f64::NAN);
        assert_eq!(preds, vec![0.1f32 as f64, 0.2f32 as f64]);
    }

    #[test]
    fn test_tree_predict_contributions_weight() {
        let tree = create_mock_tree();
        let mut contribs = vec![0.0; 3];
        tree.predict_contributions_row_weight(&[0.1, 0.0], &mut contribs, &f64::NAN);
        // The root weight is 0.0, so the bias slot stays zero.
        assert_eq!(contribs[2], 0.0);
        // Feature 0 carries the move from root weight 0.0 to leaf weight 0.1.
        assert_eq!(contribs[0], 0.1f32 as f64);
    }

    #[test]
    fn test_tree_predict_contributions_average() {
        let tree = create_mock_tree();
        let mut contribs = vec![0.0; 3];
        let weights: HashMap<usize, f64> = vec![(0, 0.15), (1, 0.1), (2, 0.2)].into_iter().collect();
        tree.predict_contributions_row_average(&[0.1, 0.0], &mut contribs, &weights, &f64::NAN);
        // Bias slot takes the root's average weight.
        assert_eq!(contribs[2], 0.15);
        // Feature 0 takes the change from 0.15 (root) to 0.1 (left leaf).
        assert!((contribs[0] - (-0.05)).abs() < 1e-7);
    }

    #[test]
    fn test_tree_predict_contributions_probability_change() {
        let tree = create_mock_tree();
        let mut contribs = vec![0.0; 3];
        tree.predict_contributions_row_probability_change(&[0.1, 0.0], &mut contribs, &f64::NAN, 0.0);
        // Root weight 0.0 produces no probability shift in the bias slot.
        assert!(contribs[2].abs() < 1e-7);
        // sigmoid(0.1) - sigmoid(0.0) for the step into the left leaf.
        let expected = 1.0 / (1.0 + (-0.1f64).exp()) - 0.5;
        assert!((contribs[0] - expected).abs() < 1e-7);
    }

    #[test]
    fn test_tree_predict_contributions_midpoint() {
        let tree = create_mock_tree();
        let mut contribs = vec![0.0; 3];
        tree.predict_contributions_row_midpoint_difference(&[0.1, 0.0], &mut contribs, &f64::NAN);
        // Left-leaf weight minus the hessian-weighted midpoint of both leaves.
        let expected = 0.1f32 - (0.1f32 * 10.0 + 0.2f32 * 20.0) / 30.0;
        assert!((contribs[0] - expected as f64).abs() < 1e-7);
    }
}