mod classifier;
mod gradient_boosting;
mod helpers;
mod random_forest;
mod regressor;
pub use classifier::DecisionTreeClassifier;
pub use gradient_boosting::GradientBoostingClassifier;
pub use helpers::{gini_impurity, gini_split};
pub use random_forest::{RandomForestClassifier, RandomForestRegressor};
pub use regressor::DecisionTreeRegressor;
use serde::{Deserialize, Serialize};
/// Internal (split) node of a [`TreeNode`] tree, i.e. the tree whose leaves
/// carry class labels.
///
/// Tests a single feature against `threshold`. Presumably samples are routed to
/// `left` or `right` by that comparison. NOTE(review): the exact operator
/// (`<` vs `<=`) and which side receives ties live in the fitting/prediction
/// code, not here; confirm there.
///
/// Derives serde's `Serialize`/`Deserialize`, so the node (and its whole
/// subtree) can be serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
/// Index of the feature this node splits on. Presumably a column index into
/// the feature matrix; verify against the classifier.
pub feature_idx: usize,
/// Split threshold compared against the value of feature `feature_idx`.
pub threshold: f32,
/// Left subtree. Boxed because `TreeNode` contains `Node`, making the type
/// recursive.
pub left: Box<TreeNode>,
/// Right subtree. Boxed for the same reason as `left`.
pub right: Box<TreeNode>,
}
/// Terminal node of a [`TreeNode`] classification tree.
///
/// Derives serde's `Serialize`/`Deserialize` so it can be serialized alongside
/// the rest of the tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Leaf {
/// Class assigned to samples that end at this leaf. Presumably an index into
/// the classifier's label set; confirm against the classifier.
pub class_label: usize,
/// Sample count recorded for this leaf. Presumably the number of training
/// samples that reached it; verify against the tree builder.
pub n_samples: usize,
}
/// A node of a classification decision tree: either a split or a leaf.
///
/// The enum is recursive through [`Node`]'s boxed `left`/`right` children.
/// Variant order is part of the serde representation for some formats, so do
/// not reorder the variants casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeNode {
/// Internal node that splits on a single feature.
Node(Node),
/// Terminal node carrying a class label.
Leaf(Leaf),
}
impl TreeNode {
    /// Returns the length, in edges, of the longest path from this node down to a leaf.
    ///
    /// A bare [`TreeNode::Leaf`] has depth `0`. Every [`TreeNode::Node`] on the way
    /// down adds one level. The walk keeps its own work list instead of recursing,
    /// so this function's call stack does not grow with the height of the tree.
    #[must_use]
    pub fn depth(&self) -> usize {
        // Each entry pairs a subtree still to visit with its distance from `self`.
        let mut frontier: Vec<(&TreeNode, usize)> = vec![(self, 0)];
        let mut deepest: usize = 0;
        while let Some((subtree, level)) = frontier.pop() {
            match subtree {
                TreeNode::Leaf(_) => deepest = deepest.max(level),
                TreeNode::Node(split) => {
                    frontier.push((&*split.left, level + 1));
                    frontier.push((&*split.right, level + 1));
                }
            }
        }
        deepest
    }
}
/// Terminal node of a [`RegressionTreeNode`] tree.
///
/// Derives serde's `Serialize`/`Deserialize` so it can be serialized alongside
/// the rest of the tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionLeaf {
/// Value predicted for samples that end at this leaf. Presumably the mean
/// target of the training samples that landed here; confirm against the
/// regressor.
pub value: f32,
/// Sample count recorded for this leaf. Presumably the number of training
/// samples that reached it; verify against the tree builder.
pub n_samples: usize,
}
/// Internal (split) node of a [`RegressionTreeNode`] tree.
///
/// Mirrors [`Node`] field-for-field, but its children are regression subtrees.
/// NOTE(review): the comparison against `threshold` and tie handling live in
/// the fitting/prediction code, not here; confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionNode {
/// Index of the feature this node splits on. Presumably a column index into
/// the feature matrix; verify against the regressor.
pub feature_idx: usize,
/// Split threshold compared against the value of feature `feature_idx`.
pub threshold: f32,
/// Left subtree. Boxed because `RegressionTreeNode` contains
/// `RegressionNode`, making the type recursive.
pub left: Box<RegressionTreeNode>,
/// Right subtree. Boxed for the same reason as `left`.
pub right: Box<RegressionTreeNode>,
}
/// A node of a regression decision tree: either a split or a leaf.
///
/// The enum is recursive through [`RegressionNode`]'s boxed `left`/`right`
/// children. Variant order is part of the serde representation for some
/// formats, so do not reorder the variants casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RegressionTreeNode {
/// Internal node that splits on a single feature.
Node(RegressionNode),
/// Terminal node carrying a real-valued prediction.
Leaf(RegressionLeaf),
}
impl RegressionTreeNode {
    /// Returns how many split levels separate this node from its deepest leaf.
    ///
    /// A [`RegressionTreeNode::Leaf`] has depth `0`. A [`RegressionTreeNode::Node`]
    /// is exactly one level deeper than the taller of its two children.
    #[must_use]
    pub fn depth(&self) -> usize {
        if let RegressionTreeNode::Node(split) = self {
            let taller_child = std::cmp::max(split.left.depth(), split.right.depth());
            taller_child + 1
        } else {
            0
        }
    }
}
#[cfg(test)]
mod tests;