use super::helper_function::preliminary_check;
use crate::error::ModelError;
use crate::math::{entropy, gini, variance};
use crate::{Deserialize, Serialize};
use ahash::AHashMap;
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, Axis, Data, Ix1, Ix2};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
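/// Row count above which per-feature split search and batch prediction
/// switch to rayon-parallel iteration.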
const DECISION_TREE_PARALLEL_THRESHOLD: usize = 1000;
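/// Split criterion family: ID3 and C4.5 split on entropy; CART splits on
/// Gini impurity for classification and on variance (MSE) for regression.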
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
pub enum Algorithm {
ID3,
C45,
CART,
}
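/// Hyperparameters controlling tree growth. `Default` gives unbounded depth,
/// `min_samples_split = 2`, `min_samples_leaf = 1`, zero minimum impurity
/// decrease, and no random state.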
#[derive(Debug, Copy, Clone, Deserialize, Serialize)]
pub struct DecisionTreeParams {
pub max_depth: Option<usize>,
pub min_samples_split: usize,
pub min_samples_leaf: usize,
pub min_impurity_decrease: f64,
pub random_state: Option<u64>,
}
impl Default for DecisionTreeParams {
fn default() -> Self {
Self {
max_depth: None,
min_samples_split: 2,
min_samples_leaf: 1,
min_impurity_decrease: 0.0,
random_state: None,
}
}
}
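/// Either an internal split (numeric threshold, or a categorical split that
/// prediction does not yet support) or a leaf holding the predicted value
/// plus, for classifiers, the majority class and class probabilities.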
#[derive(Debug, Clone, Deserialize, Serialize)]
pub enum NodeType {
Internal {
feature_index: usize,
threshold: f64,
categories: Option<Vec<String>>,
},
Leaf {
value: f64,
class: Option<usize>,
probabilities: Option<Vec<f64>>,
},
}
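/// Tree node: binary `left`/`right` children for numeric splits, and an
/// optional category-keyed child map reserved for categorical splits.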
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Node {
pub node_type: NodeType,
pub left: Option<Box<Node>>,
pub right: Option<Box<Node>>,
pub children: Option<AHashMap<String, Box<Node>>>,
}
impl Node {
pub fn new_leaf(value: f64, class: Option<usize>, probabilities: Option<Vec<f64>>) -> Self {
Self {
node_type: NodeType::Leaf {
value,
class,
probabilities,
},
left: None,
right: None,
children: None,
}
}
pub fn new_internal(feature_index: usize, threshold: f64) -> Self {
Self {
node_type: NodeType::Internal {
feature_index,
threshold,
categories: None,
},
left: None,
right: None,
children: None,
}
}
pub fn new_categorical(feature_index: usize, categories: Vec<String>) -> Self {
Self {
node_type: NodeType::Internal {
feature_index,
threshold: 0.0,
categories: Some(categories),
},
left: None,
right: None,
children: Some(AHashMap::new()),
}
}
}
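/// Decision tree supporting ID3/C4.5/CART classification and CART
/// regression. Construct with `new`, then `fit` before calling any of the
/// predict methods.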
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTree {
algorithm: Algorithm,
root: Option<Box<Node>>,
n_features: usize,
n_classes: Option<usize>,
params: DecisionTreeParams,
is_classifier: bool,
}
impl DecisionTree {
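/// Creates an unfitted tree after validating that the hyperparameters are
/// mutually consistent and that regression is only requested with CART.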
pub fn new(
algorithm: Algorithm,
is_classifier: bool,
params: Option<DecisionTreeParams>,
) -> Result<Self, ModelError> {
if !is_classifier && algorithm != Algorithm::CART {
return Err(ModelError::InputValidationError(
"Only CART algorithm is supported for regression tasks".to_string(),
));
}
let params = params.unwrap_or_default();
if params.min_samples_split < 2 {
return Err(ModelError::InputValidationError(
"min_samples_split must be at least 2".to_string(),
));
}
if params.min_samples_leaf < 1 {
return Err(ModelError::InputValidationError(
"min_samples_leaf must be at least 1".to_string(),
));
}
if params.min_samples_leaf > params.min_samples_split {
return Err(ModelError::InputValidationError(format!(
"min_samples_leaf ({}) cannot be greater than min_samples_split ({})",
params.min_samples_leaf, params.min_samples_split
)));
}
if params.min_impurity_decrease < 0.0 || !params.min_impurity_decrease.is_finite() {
return Err(ModelError::InputValidationError(format!(
"min_impurity_decrease must be non-negative and finite, got {}",
params.min_impurity_decrease
)));
}
Ok(Self {
algorithm,
root: None,
n_features: 0,
n_classes: None,
params,
is_classifier,
})
}
get_field!(get_algorithm, algorithm, Algorithm);
get_field!(get_n_features, n_features, usize);
get_field!(get_n_classes, n_classes, Option<usize>);
get_field!(get_parameters, params, DecisionTreeParams);
get_field_as_ref!(get_root, root, Option<&Box<Node>>);
get_field!(get_is_classifier, is_classifier, bool);
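/// Fits the tree on `x` (n_samples x n_features) and targets `y`. For
/// classifiers, labels must be non-negative integers stored as `f64`
/// (0.0, 1.0, ...), and `n_classes` is inferred as `max(y) + 1`. Progress
/// is reported on an indicatif progress bar.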
pub fn fit<S>(
&mut self,
x: &ArrayBase<S, Ix2>,
y: &ArrayBase<S, Ix1>,
) -> Result<&mut Self, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
preliminary_check(x, Some(y))?;
if x.nrows() < self.params.min_samples_split {
return Err(ModelError::InputValidationError(format!(
"Number of samples ({}) is less than min_samples_split ({})",
x.nrows(),
self.params.min_samples_split
)));
}
if x.ncols() == 0 {
return Err(ModelError::InputValidationError(
"Input data must have at least one feature".to_string(),
));
}
self.n_features = x.ncols();
if self.is_classifier {
// Validate labels first so a non-finite label cannot panic the
// `partial_cmp` inside `max_by` below.
for &label in y.iter() {
if !label.is_finite() || label < 0.0 || label.fract() != 0.0 {
return Err(ModelError::InputValidationError(
"Class labels must be non-negative integers starting from 0".to_string(),
));
}
}
let max_class = y
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.ok_or_else(|| ModelError::ProcessingError("Cannot determine max class".to_string()))?;
self.n_classes = Some((*max_class as usize) + 1);
}
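// Bar length is only a rough upper bound: depth is capped at 20 and a full
// binary tree of depth d has 2^(d+1) - 1 nodes, so the bar usually finishes
// well before its end.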
let estimated_max_depth = self.params.max_depth.unwrap_or(20).min(20);
let estimated_nodes = (1 << (estimated_max_depth + 1)) - 1;
let progress_bar = ProgressBar::new(estimated_nodes as u64);
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos} nodes | Depth: {msg}")
.expect("Failed to set progress bar template")
.progress_chars("█▓░"),
);
progress_bar.set_message("0");
let indices: Vec<usize> = (0..x.nrows()).collect();
self.root = Some(Box::new(self.build_tree_with_progress(
x,
y,
&indices,
0,
&progress_bar,
)?));
let tree_depth = self.calculate_depth(self.root.as_ref().unwrap());
let total_nodes = self.count_nodes(self.root.as_ref().unwrap());
// The bar template renders its message as "Depth: {msg}", so finish with
// the final depth rather than the node count.
progress_bar.finish_with_message(format!("{}", tree_depth));
println!(
"\nDecision Tree training completed: {} samples, {} features, {} nodes, depth: {}",
x.nrows(),
self.n_features,
total_nodes,
tree_depth
);
Ok(self)
}
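/// Recursively grows the subtree over `indices`. A leaf is produced when the
/// node is too small to split, `max_depth` is reached, the targets are pure,
/// the best split's gain is below `min_impurity_decrease`, or a child would
/// fall under `min_samples_leaf`.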
fn build_tree_with_progress<S>(
&self,
x: &ArrayBase<S, Ix2>,
y: &ArrayBase<S, Ix1>,
indices: &[usize],
depth: usize,
progress_bar: &ProgressBar,
) -> Result<Node, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
progress_bar.inc(1);
progress_bar.set_message(format!("{}", depth));
let n_samples = indices.len();
if n_samples < self.params.min_samples_split
|| (self.params.max_depth.is_some() && depth >= self.params.max_depth.unwrap())
|| self.is_pure(y, indices)
{
return Ok(self.create_leaf(y, indices));
}
let split_result = self.find_best_split(x, y, indices)?;
if let Some((feature_idx, threshold, left_indices, right_indices, impurity_decrease)) =
split_result
{
if impurity_decrease < self.params.min_impurity_decrease {
return Ok(self.create_leaf(y, indices));
}
if left_indices.len() < self.params.min_samples_leaf
|| right_indices.len() < self.params.min_samples_leaf
{
return Ok(self.create_leaf(y, indices));
}
let mut node = Node::new_internal(feature_idx, threshold);
node.left = Some(Box::new(self.build_tree_with_progress(
x,
y,
&left_indices,
depth + 1,
progress_bar,
)?));
node.right = Some(Box::new(self.build_tree_with_progress(
x,
y,
&right_indices,
depth + 1,
progress_bar,
)?));
Ok(node)
} else {
Ok(self.create_leaf(y, indices))
}
}
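/// Exhaustive split search: for each feature, candidate thresholds are the
/// midpoints between consecutive distinct sorted values, and the split with
/// the largest impurity decrease wins. Features are scanned in parallel once
/// the node holds at least `DECISION_TREE_PARALLEL_THRESHOLD` samples.
/// Assumes finite feature values (the sort unwraps `partial_cmp`).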
fn find_best_split<S>(
&self,
x: &ArrayBase<S, Ix2>,
y: &ArrayBase<S, Ix1>,
indices: &[usize],
) -> Result<Option<(usize, f64, Vec<usize>, Vec<usize>, f64)>, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
let parent_impurity = self.calculate_impurity(y, indices);
let process_feature =
|feature_idx: usize| -> Option<(f64, (usize, f64, Vec<usize>, Vec<usize>, f64))> {
let mut feature_values: Vec<f64> =
indices.iter().map(|&i| x[[i, feature_idx]]).collect();
feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
feature_values.dedup();
let mut best_feature_gain = 0.0;
let mut best_feature_split: Option<(usize, f64, Vec<usize>, Vec<usize>, f64)> =
None;
for i in 0..feature_values.len().saturating_sub(1) {
let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;
let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
.iter()
.partition(|&&idx| x[[idx, feature_idx]] <= threshold);
if left_indices.is_empty() || right_indices.is_empty() {
continue;
}
let left_impurity = self.calculate_impurity(y, &left_indices);
let right_impurity = self.calculate_impurity(y, &right_indices);
let n_samples = indices.len() as f64;
let n_left = left_indices.len() as f64;
let n_right = right_indices.len() as f64;
let weighted_impurity = (n_left / n_samples) * left_impurity
+ (n_right / n_samples) * right_impurity;
let impurity_decrease = parent_impurity - weighted_impurity;
if impurity_decrease > best_feature_gain {
best_feature_gain = impurity_decrease;
best_feature_split = Some((
feature_idx,
threshold,
left_indices,
right_indices,
impurity_decrease,
));
}
}
best_feature_split.map(|split| (best_feature_gain, split))
};
let best_split = if indices.len() >= DECISION_TREE_PARALLEL_THRESHOLD {
(0..self.n_features)
.into_par_iter()
.filter_map(process_feature)
.max_by(|(gain_a, _), (gain_b, _)| gain_a.partial_cmp(gain_b).unwrap())
.map(|(_, split)| split)
} else {
(0..self.n_features)
.filter_map(process_feature)
.max_by(|(gain_a, _), (gain_b, _)| gain_a.partial_cmp(gain_b).unwrap())
.map(|(_, split)| split)
};
Ok(best_split)
}
fn calculate_impurity<S>(&self, y: &ArrayBase<S, Ix1>, indices: &[usize]) -> f64
where
S: Data<Elem = f64>,
{
if indices.is_empty() {
return 0.0;
}
if self.is_classifier {
let subset = y.select(Axis(0), indices);
let subset_view: ArrayView1<f64> = subset.view();
match self.algorithm {
Algorithm::CART => gini(&subset_view),
Algorithm::ID3 | Algorithm::C45 => entropy(&subset_view),
}
} else {
self.calculate_mse(y, indices)
}
}
fn calculate_mse<S>(&self, y: &ArrayBase<S, Ix1>, indices: &[usize]) -> f64
where
S: Data<Elem = f64>,
{
let subset: Array1<f64> = indices.iter().map(|&i| y[i]).collect();
variance(&subset)
}
fn is_pure<S>(&self, y: &ArrayBase<S, Ix1>, indices: &[usize]) -> bool
where
S: Data<Elem = f64>,
{
if indices.is_empty() {
return true;
}
let first_value = y[indices[0]];
indices.iter().all(|&i| (y[i] - first_value).abs() < 1e-10)
}
fn create_leaf<S>(&self, y: &ArrayBase<S, Ix1>, indices: &[usize]) -> Node
where
S: Data<Elem = f64>,
{
if self.is_classifier {
let n_classes = self.n_classes.unwrap();
let mut class_counts = vec![0.0; n_classes];
for &idx in indices {
let class = y[idx] as usize;
class_counts[class] += 1.0;
}
let majority_class = class_counts
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(idx, _)| idx)
.unwrap_or(0);
let total = indices.len() as f64;
let probabilities: Vec<f64> = class_counts.iter().map(|&count| count / total).collect();
Node::new_leaf(
majority_class as f64,
Some(majority_class),
Some(probabilities),
)
} else {
let mean = indices.iter().map(|&i| y[i]).sum::<f64>() / indices.len() as f64;
Node::new_leaf(mean, None, None)
}
}
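/// Predicts a single sample given as a feature slice of length `n_features`.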
pub fn predict_one(&self, x: &[f64]) -> Result<f64, ModelError> {
if self.root.is_none() {
return Err(ModelError::NotFitted);
}
if x.len() != self.n_features {
return Err(ModelError::TreeError("Feature dimension mismatch"));
}
self.traverse_tree(self.root.as_ref().unwrap(), x)
}
fn traverse_tree(&self, node: &Node, x: &[f64]) -> Result<f64, ModelError> {
match &node.node_type {
NodeType::Leaf { value, .. } => Ok(*value),
NodeType::Internal {
feature_index,
threshold,
categories,
} => {
if categories.is_some() {
return Err(ModelError::TreeError(
"Categorical splits not yet implemented",
));
}
if x[*feature_index] <= *threshold {
if let Some(ref left) = node.left {
self.traverse_tree(left, x)
} else {
Err(ModelError::TreeError("Missing left child"))
}
} else if let Some(ref right) = node.right {
self.traverse_tree(right, x)
} else {
Err(ModelError::TreeError("Missing right child"))
}
}
}
}
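/// Predicts every row of `x`, in parallel for large inputs. Returns class
/// indices (as `f64`) for classifiers and mean leaf targets for regressors.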
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
if self.root.is_none() {
return Err(ModelError::NotFitted);
}
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Cannot predict on empty dataset".to_string(),
));
}
if x.ncols() != self.n_features {
return Err(ModelError::InputValidationError(format!(
"Number of features does not match training data, x columns: {}, expected: {}",
x.ncols(),
self.n_features
)));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let predictions: Result<Vec<f64>, ModelError> =
if x.nrows() >= DECISION_TREE_PARALLEL_THRESHOLD {
x.axis_iter(Axis(0))
.into_par_iter()
.map(|row| {
let row_slice = row.to_vec();
self.predict_one(&row_slice)
})
.collect()
} else {
x.axis_iter(Axis(0))
.map(|row| {
let row_slice = row.to_vec();
self.predict_one(&row_slice)
})
.collect()
};
Ok(Array1::from_vec(predictions?))
}
pub fn fit_predict<S>(
&mut self,
x_train: &ArrayBase<S, Ix2>,
y_train: &ArrayBase<S, Ix1>,
) -> Result<Array1<f64>, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
self.fit(x_train, y_train)?;
self.predict(x_train)
}
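/// Returns an `(n_samples, n_classes)` matrix of class probabilities;
/// classification only.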
pub fn predict_proba<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>, ModelError>
where
S: Data<Elem = f64>,
{
if !self.is_classifier {
return Err(ModelError::TreeError(
"predict_proba is only available for classification",
));
}
if self.root.is_none() {
return Err(ModelError::NotFitted);
}
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Cannot predict on empty dataset".to_string(),
));
}
if x.ncols() != self.n_features {
return Err(ModelError::InputValidationError(format!(
"Number of features does not match training data, x columns: {}, expected: {}",
x.ncols(),
self.n_features
)));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let n_classes = self.n_classes.unwrap();
let probabilities: Result<Vec<Vec<f64>>, ModelError> =
if x.nrows() >= DECISION_TREE_PARALLEL_THRESHOLD {
x.axis_iter(Axis(0))
.into_par_iter()
.map(|row| {
let row_slice = row.to_vec();
self.predict_proba_one(&row_slice)
})
.collect()
} else {
x.axis_iter(Axis(0))
.map(|row| {
let row_slice = row.to_vec();
self.predict_proba_one(&row_slice)
})
.collect()
};
let probabilities = probabilities?;
let mut result = Array2::zeros((x.nrows(), n_classes));
for (i, proba) in probabilities.iter().enumerate() {
for (j, &p) in proba.iter().enumerate() {
result[[i, j]] = p;
}
}
Ok(result)
}
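/// Returns the leaf class-probability vector for a single sample.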
pub fn predict_proba_one(&self, x: &[f64]) -> Result<Vec<f64>, ModelError> {
if !self.is_classifier {
return Err(ModelError::TreeError(
"predict_proba is only available for classification",
));
}
if self.root.is_none() {
return Err(ModelError::NotFitted);
}
if x.len() != self.n_features {
return Err(ModelError::TreeError("Feature dimension mismatch"));
}
self.get_probabilities(self.root.as_ref().unwrap(), x)
}
fn get_probabilities(&self, node: &Node, x: &[f64]) -> Result<Vec<f64>, ModelError> {
match &node.node_type {
NodeType::Leaf { probabilities, .. } => probabilities
.as_ref()
.cloned()
.ok_or(ModelError::TreeError("No probabilities in leaf node")),
NodeType::Internal {
feature_index,
threshold,
categories,
} => {
if categories.is_some() {
return Err(ModelError::TreeError(
"Categorical splits not yet implemented",
));
}
if x[*feature_index] <= *threshold {
if let Some(ref left) = node.left {
self.get_probabilities(left, x)
} else {
Err(ModelError::TreeError("Missing left child"))
}
} else if let Some(ref right) = node.right {
self.get_probabilities(right, x)
} else {
Err(ModelError::TreeError("Missing right child"))
}
}
}
}
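/// Renders the fitted tree as a human-readable ASCII structure string.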
pub fn generate_tree_structure(&self) -> Result<String, ModelError> {
if self.root.is_none() {
return Err(ModelError::NotFitted);
}
let mut output = String::new();
output.push_str("Decision Tree Structure:\n");
self.print_node(self.root.as_ref().unwrap(), &mut output, "", true);
Ok(output)
}
fn calculate_depth(&self, node: &Node) -> usize {
match &node.node_type {
NodeType::Leaf { .. } => 0,
NodeType::Internal { .. } => {
let left_depth = node.left.as_ref().map_or(0, |n| self.calculate_depth(n));
let right_depth = node.right.as_ref().map_or(0, |n| self.calculate_depth(n));
1 + left_depth.max(right_depth)
}
}
}
fn count_nodes(&self, node: &Node) -> usize {
let mut count = 1;
match &node.node_type {
NodeType::Leaf { .. } => count,
NodeType::Internal { .. } => {
if let Some(ref left) = node.left {
count += self.count_nodes(left);
}
if let Some(ref right) = node.right {
count += self.count_nodes(right);
}
count
}
}
}
fn print_node(&self, node: &Node, output: &mut String, prefix: &str, is_last: bool) {
let connector = if is_last { "└── " } else { "├── " };
output.push_str(&format!("{}{}", prefix, connector));
match &node.node_type {
NodeType::Leaf {
value,
class,
probabilities,
} => {
if self.is_classifier {
output.push_str(&format!("Leaf: class={}", class.unwrap()));
if let Some(probs) = probabilities {
output.push_str(&format!(" probs={:?}", probs));
}
} else {
output.push_str(&format!("Leaf: value={:.4}", value));
}
output.push('\n');
}
NodeType::Internal {
feature_index,
threshold,
categories,
} => {
if categories.is_some() {
output.push_str(&format!(
"Split: feature[{}] (categorical)\n",
feature_index
));
} else {
output.push_str(&format!(
"Split: feature[{}] <= {:.4}\n",
feature_index, threshold
));
}
let new_prefix = format!("{}{}", prefix, if is_last { " " } else { "│ " });
if let Some(ref left) = node.left {
self.print_node(left, output, &new_prefix, false);
}
if let Some(ref right) = node.right {
self.print_node(right, output, &new_prefix, true);
}
}
}
}
model_save_and_load_methods!(DecisionTree);
}
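// A minimal end-to-end sketch of the API, kept as a unit test so the file
// stays compilable. The data is hypothetical, and it assumes that
// `preliminary_check` accepts such a small, well-formed input.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
#[test]
fn cart_classifier_separates_two_classes() {
// One feature, two cleanly separated classes.
let x = array![[0.0], [0.1], [0.2], [1.0], [1.1], [1.2]];
let y = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let mut tree = DecisionTree::new(Algorithm::CART, true, None).unwrap();
tree.fit(&x, &y).unwrap();
assert_eq!(tree.predict(&x).unwrap(), y);
// Each leaf is pure here, so the probabilities are degenerate.
let proba = tree.predict_proba(&x).unwrap();
assert_eq!(proba[[0, 0]], 1.0);
}
}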