use std::collections::HashMap;
/// Split-selection criterion used by [`DecisionTree::fit`].
#[derive(Debug, Clone)]
pub enum DecisionTreeAlgorithm {
/// Chooses splits by raw information gain (Quinlan's ID3).
ID3,
/// Chooses splits by gain ratio (information gain / split info), as in C4.5.
C45,
}
/// A node of a fitted tree: an internal split point or a terminal label.
#[derive(Debug, Clone)]
pub enum DecisionTreeNode<L> {
/// An internal decision point that routes a sample into one of `branches`.
Internal {
/// Column tested at this node.
/// NOTE(review): categorical splits remove the tested column from the
/// training data before recursing, so `feature_index` values stored in
/// descendants appear to index the *reduced* row layout, while
/// prediction indexes the full feature vector — verify before relying
/// on trees with nested categorical splits.
feature_index: usize,
/// `Some(threshold)` for numeric "≤ / >" splits; `None` for categorical splits.
split_value: Option<f64>,
/// Children keyed by branch label: "≤"/">" for numeric splits, the
/// raw category string for categorical splits.
branches: HashMap<String, DecisionTreeNode<L>>,
},
/// Terminal node carrying the predicted label.
Leaf(L),
}
/// A trained decision tree over rows of string-encoded features.
#[derive(Debug, Clone)]
pub struct DecisionTree<L> {
/// Root of the fitted tree.
pub root: DecisionTreeNode<L>,
/// Algorithm used during fitting (retained for inspection/debugging).
pub algorithm: DecisionTreeAlgorithm,
}
impl<L: Clone + Eq + std::hash::Hash> DecisionTree<L> {
    /// Trains a tree on `data` (rows of string-encoded features) and their
    /// parallel `labels`, using the requested `algorithm`. When
    /// `feature_names` is `None`, synthetic names `F0`, `F1`, … are used.
    ///
    /// # Panics
    /// Panics when `data` and `labels` differ in length, when `data` is
    /// empty, when a row's width differs from the first row's, or when
    /// `feature_names` does not match the column count.
    pub fn fit(
        data: &[Vec<String>],
        labels: &[L],
        feature_names: Option<&[String]>,
        algorithm: DecisionTreeAlgorithm,
    ) -> Self {
        assert_eq!(
            data.len(),
            labels.len(),
            "data and labels must match in length"
        );
        if data.is_empty() {
            panic!("No training data provided.");
        }
        let width = data[0].len();
        // Reject ragged input; report the first offending row index.
        if let Some(bad) = data.iter().position(|row| row.len() != width) {
            panic!("Row {} has inconsistent feature length.", bad);
        }
        let names: Vec<String> = match feature_names {
            Some(given) => {
                assert_eq!(given.len(), width, "feature_names must match data columns");
                given.to_vec()
            }
            None => (0..width).map(|i| format!("F{}", i)).collect(),
        };
        let root = build_tree(data, labels, &names, &algorithm);
        Self { root, algorithm }
    }

    /// Predicts the label for a single feature row.
    pub fn predict(&self, features: &[String]) -> L {
        traverse_tree(&self.root, features)
    }

    /// Predicts a label for every row in `data`.
    pub fn predict_batch(&self, data: &[Vec<String>]) -> Vec<L> {
        let mut out = Vec::with_capacity(data.len());
        for row in data {
            out.push(self.predict(row));
        }
        out
    }
}
/// Recursively builds a (sub)tree for the given data/label partition.
///
/// Stops with a leaf when the partition is pure, or when no columns /
/// feature names remain (majority vote). Otherwise delegates split
/// selection to the algorithm-specific scorer and recurses.
///
/// # Panics
/// Panics when called with an empty label partition — callers always pass
/// non-empty partitions (thresholds guarantee non-empty numeric sides, and
/// categorical subsets are built from existing rows). The original code
/// would have panicked here too, but with an opaque `labels[0]`
/// index-out-of-bounds (since `all_same` is vacuously true for `[]`);
/// the precondition is now explicit.
fn build_tree<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
    feature_names: &[String],
    algo: &DecisionTreeAlgorithm,
) -> DecisionTreeNode<L> {
    assert!(!labels.is_empty(), "build_tree requires a non-empty partition");
    // Pure partition: every label agrees.
    if all_same(labels) {
        return DecisionTreeNode::Leaf(labels[0].clone());
    }
    // Nothing left to split on: fall back to the majority label.
    if data.is_empty() || data[0].is_empty() || feature_names.is_empty() {
        return DecisionTreeNode::Leaf(majority_label(labels));
    }
    // The two arms differed only in the scoring function; select it once
    // instead of duplicating the node-construction logic per algorithm.
    let (best_feat_idx, best_split) = match algo {
        DecisionTreeAlgorithm::ID3 => find_best_split_id3(data, labels),
        DecisionTreeAlgorithm::C45 => find_best_split_c45(data, labels),
    };
    match best_split {
        Some(split_val) => {
            make_continuous_node(data, labels, feature_names, best_feat_idx, split_val, algo)
        }
        None => make_categorical_node(data, labels, feature_names, best_feat_idx, algo),
    }
}
/// Builds an internal node for a numeric split at `split_value` on column
/// `feat_idx`, recursing into the "≤" and ">" partitions.
///
/// Rows whose value fails to parse as `f64` (treated as NaN) are dropped
/// from both partitions.
fn make_continuous_node<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
    feature_names: &[String],
    feat_idx: usize,
    split_value: f64,
    algo: &DecisionTreeAlgorithm,
) -> DecisionTreeNode<L> {
    let mut lo: (Vec<Vec<String>>, Vec<L>) = (Vec::new(), Vec::new());
    let mut hi: (Vec<Vec<String>>, Vec<L>) = (Vec::new(), Vec::new());
    for (row, label) in data.iter().zip(labels.iter()) {
        let parsed = row[feat_idx].parse::<f64>().unwrap_or(f64::NAN);
        if parsed.is_nan() {
            continue;
        }
        // Route the row to the matching side of the threshold.
        let side = if parsed <= split_value { &mut lo } else { &mut hi };
        side.0.push(row.clone());
        side.1.push(label.clone());
    }
    let mut branches = HashMap::new();
    branches.insert(
        "≤".to_string(),
        build_tree(&lo.0, &lo.1, feature_names, algo),
    );
    branches.insert(
        ">".to_string(),
        build_tree(&hi.0, &hi.1, feature_names, algo),
    );
    DecisionTreeNode::Internal {
        feature_index: feat_idx,
        split_value: Some(split_value),
        branches,
    }
}
/// Builds an internal node for a categorical split on column `feat_idx`,
/// with one branch per distinct value observed in the training partition.
fn make_categorical_node<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
    feature_names: &[String],
    feat_idx: usize,
    algo: &DecisionTreeAlgorithm,
) -> DecisionTreeNode<L> {
    // Group rows and labels by the categorical value. The original did two
    // separate `entry` lookups per row (hashing the same key twice); one
    // lookup now fills both halves of the tuple.
    let mut subsets: HashMap<String, (Vec<Vec<String>>, Vec<L>)> = HashMap::new();
    for (row, lbl) in data.iter().zip(labels.iter()) {
        let bucket = subsets.entry(row[feat_idx].clone()).or_default();
        bucket.0.push(row.clone());
        bucket.1.push(lbl.clone());
    }
    // A categorical feature is consumed by the split: drop its column and
    // name so it cannot be chosen again below this node. (The dead
    // `is_categorical = true` flag that guarded this unconditionally has
    // been removed.)
    //
    // NOTE(review): removing the column shifts the indices of all later
    // features, so `feature_index` values stored in *descendant* nodes refer
    // to the reduced row layout while `traverse_tree` indexes the full,
    // original feature vector. Trees with nested categorical splits will
    // read the wrong column at prediction time — confirm and remap child
    // indices back to the original layout.
    let mut remaining_names = feature_names.to_vec();
    remaining_names.remove(feat_idx);
    let mut branches = HashMap::new();
    for (value, (sub_data, sub_labels)) in subsets {
        let reduced = remove_column(&sub_data, feat_idx);
        let child = build_tree(&reduced, &sub_labels, &remaining_names, algo);
        branches.insert(value, child);
    }
    DecisionTreeNode::Internal {
        feature_index: feat_idx,
        split_value: None,
        branches,
    }
}
/// ID3 split selection: returns the feature (and, for numeric columns, the
/// threshold) with the highest raw information gain.
///
/// Returns `(0, None)` unchanged when no candidate beats the initial
/// negative-infinity score.
fn find_best_split_id3<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
) -> (usize, Option<f64>) {
    let base = entropy(labels);
    // (gain, feature index, optional threshold) of the best candidate so far.
    let mut best: (f64, usize, Option<f64>) = (f64::NEG_INFINITY, 0, None);
    for feat_idx in 0..data[0].len() {
        if is_continuous_column(data, feat_idx) {
            for th in possible_thresholds(data, feat_idx) {
                let gain = info_gain_continuous(data, labels, feat_idx, th, base);
                if gain > best.0 {
                    best = (gain, feat_idx, Some(th));
                }
            }
        } else {
            let gain = info_gain_categorical(data, labels, feat_idx, base);
            if gain > best.0 {
                best = (gain, feat_idx, None);
            }
        }
    }
    (best.1, best.2)
}
/// C4.5 split selection: returns the feature (and threshold for numeric
/// columns) with the highest gain ratio, considering only candidates with
/// strictly positive information gain.
///
/// Returns `(0, None)` unchanged when no candidate qualifies.
fn find_best_split_c45<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
) -> (usize, Option<f64>) {
    let base = entropy(labels);
    // Gain ratio is gain / split-info, defined as 0 when split-info vanishes.
    let ratio_of = |gain: f64, split_info: f64| {
        if split_info.abs() < 1e-12 {
            0.0
        } else {
            gain / split_info
        }
    };
    // (ratio, feature index, optional threshold) of the best candidate so far.
    let mut best: (f64, usize, Option<f64>) = (f64::NEG_INFINITY, 0, None);
    for feat_idx in 0..data[0].len() {
        if is_continuous_column(data, feat_idx) {
            for th in possible_thresholds(data, feat_idx) {
                let gain = info_gain_continuous(data, labels, feat_idx, th, base);
                if gain <= 0.0 {
                    continue;
                }
                let ratio = ratio_of(gain, split_info_continuous(data, feat_idx, th));
                if ratio > best.0 {
                    best = (ratio, feat_idx, Some(th));
                }
            }
        } else {
            let gain = info_gain_categorical(data, labels, feat_idx, base);
            if gain <= 0.0 {
                continue;
            }
            let ratio = ratio_of(gain, split_info_categorical(data, feat_idx));
            if ratio > best.0 {
                best = (ratio, feat_idx, None);
            }
        }
    }
    (best.1, best.2)
}
/// Information gain of splitting column `feat_idx` at `threshold`.
///
/// Rows whose value fails to parse (treated as NaN) are excluded from both
/// sides; the weighted child entropies are computed over the remaining rows.
fn info_gain_continuous<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
    feat_idx: usize,
    threshold: f64,
    base_entropy: f64,
) -> f64 {
    let mut left_labels = Vec::new();
    let mut right_labels = Vec::new();
    for (row, lbl) in data.iter().zip(labels.iter()) {
        let val = row[feat_idx].parse::<f64>().unwrap_or(f64::NAN);
        if val.is_nan() {
            continue;
        }
        if val <= threshold {
            left_labels.push(lbl.clone());
        } else {
            right_labels.push(lbl.clone());
        }
    }
    let n = (left_labels.len() + right_labels.len()) as f64;
    // Previously, when every value in the column was NaN, n == 0 divided by
    // zero and yielded a NaN gain. NaN never won a `>` comparison so nothing
    // was selected, but the NaN silently propagated to callers; return
    // NEG_INFINITY explicitly instead — same selection behavior, no NaN.
    if n == 0.0 {
        return f64::NEG_INFINITY;
    }
    let h_left = entropy(&left_labels);
    let h_right = entropy(&right_labels);
    let w_left = left_labels.len() as f64 / n;
    let w_right = right_labels.len() as f64 / n;
    base_entropy - (w_left * h_left + w_right * h_right)
}
/// Information gain of splitting on categorical column `feat_idx`:
/// base entropy minus the size-weighted entropy of each value's subset.
fn info_gain_categorical<L: Clone + Eq + std::hash::Hash>(
    data: &[Vec<String>],
    labels: &[L],
    feat_idx: usize,
    base_entropy: f64,
) -> f64 {
    // Bucket the labels by their row's categorical value.
    let mut groups: HashMap<String, Vec<L>> = HashMap::new();
    for (row, lbl) in data.iter().zip(labels.iter()) {
        groups
            .entry(row[feat_idx].clone())
            .or_default()
            .push(lbl.clone());
    }
    let total = labels.len() as f64;
    // Expected entropy after the split (the "remainder").
    let remainder: f64 = groups
        .values()
        .map(|subset| (subset.len() as f64 / total) * entropy(subset))
        .sum();
    base_entropy - remainder
}
fn split_info_continuous(data: &[Vec<String>], feat_idx: usize, threshold: f64) -> f64 {
let mut left_count = 0;
let mut right_count = 0;
for row in data {
let val = row[feat_idx].parse::<f64>().unwrap_or(f64::NAN);
if !val.is_nan() {
if val <= threshold {
left_count += 1;
} else {
right_count += 1;
}
}
}
let n = (left_count + right_count) as f64;
let mut si = 0.0;
if left_count > 0 {
let p = left_count as f64 / n;
si -= p * log2(p);
}
if right_count > 0 {
let p = right_count as f64 / n;
si -= p * log2(p);
}
si
}
/// Split information of a categorical split: the entropy of the value
/// frequency distribution in column `feat_idx`.
fn split_info_categorical(data: &[Vec<String>], feat_idx: usize) -> f64 {
    // Count occurrences of each distinct value.
    let mut counts: HashMap<String, usize> = HashMap::new();
    for row in data {
        *counts.entry(row[feat_idx].clone()).or_insert(0) += 1;
    }
    let n = data.len() as f64;
    counts
        .values()
        .map(|&count| {
            let p = count as f64 / n;
            -p * log2(p)
        })
        .sum()
}
/// Candidate thresholds for a numeric column: the midpoints between each
/// pair of adjacent distinct parseable values. Returns an empty vector when
/// fewer than two distinct numeric values exist.
fn possible_thresholds(data: &[Vec<String>], feat_idx: usize) -> Vec<f64> {
    // Distinct, parseable, non-NaN values of the column in ascending order.
    let mut vals: Vec<f64> = data
        .iter()
        .filter_map(|row| row[feat_idx].parse::<f64>().ok())
        .filter(|v| !v.is_nan())
        .collect();
    vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
    vals.dedup();
    vals.windows(2).map(|pair| 0.5 * (pair[0] + pair[1])).collect()
}
/// Returns true when every entry in column `feat_idx` parses as `f64`
/// (vacuously true for empty data). Note: the literal string "NaN" parses
/// successfully, so it counts as numeric.
fn is_continuous_column(data: &[Vec<String>], feat_idx: usize) -> bool {
    data.iter().all(|row| row[feat_idx].parse::<f64>().is_ok())
}
/// Returns a copy of `data` with the element at `col_idx` removed from
/// every row. Panics if any row is shorter than `col_idx + 1`.
fn remove_column(data: &[Vec<String>], col_idx: usize) -> Vec<Vec<String>> {
    data.iter()
        .map(|row| {
            let mut trimmed = row.clone();
            trimmed.remove(col_idx);
            trimmed
        })
        .collect()
}
/// Walks the tree from `node` down to a leaf for the given feature row and
/// returns a clone of the leaf label.
///
/// For an unseen branch key (a category never observed during training, or
/// a missing "≤"/">" branch), falls back to the first reachable leaf of the
/// current node — arbitrary, since HashMap order is unspecified.
fn traverse_tree<L: Clone>(node: &DecisionTreeNode<L>, features: &[String]) -> L {
    let mut current = node;
    loop {
        match current {
            DecisionTreeNode::Leaf(lbl) => return lbl.clone(),
            DecisionTreeNode::Internal {
                feature_index,
                split_value,
                branches,
            } => {
                // Branch key: numeric comparison for continuous splits
                // (unparseable values become NaN, which fails `<=` and
                // therefore takes the ">" branch), or the raw feature
                // string for categorical splits.
                let key: &str = match split_value {
                    Some(th) => {
                        let v = features[*feature_index].parse::<f64>().unwrap_or(f64::NAN);
                        if v <= *th {
                            "≤"
                        } else {
                            ">"
                        }
                    }
                    None => features[*feature_index].as_str(),
                };
                match branches.get(key) {
                    Some(next) => current = next,
                    None => return branches.values().next().unwrap().clone_leaf(),
                }
            }
        }
    }
}
impl<L: Clone> DecisionTreeNode<L> {
    /// Descends through first branches until a leaf is reached and clones its
    /// label. HashMap iteration order is unspecified, so the chosen leaf is
    /// arbitrary; panics if an internal node has no branches.
    fn clone_leaf(&self) -> L {
        let mut node = self;
        loop {
            match node {
                DecisionTreeNode::Leaf(lbl) => return lbl.clone(),
                DecisionTreeNode::Internal { branches, .. } => {
                    node = branches.values().next().expect("No branches in node");
                }
            }
        }
    }
}
/// Returns true when every element equals the first one; vacuously true for
/// the empty slice.
fn all_same<L: PartialEq>(labels: &[L]) -> bool {
    match labels.split_first() {
        None => true,
        Some((first, rest)) => rest.iter().all(|x| x == first),
    }
}
/// Returns the most frequent label. Ties are broken arbitrarily (HashMap
/// iteration order is unspecified); panics on an empty slice.
fn majority_label<L: Clone + Eq + std::hash::Hash>(labels: &[L]) -> L {
    let mut tally: HashMap<L, usize> = HashMap::new();
    for lbl in labels {
        *tally.entry(lbl.clone()).or_insert(0) += 1;
    }
    // Manual max scan; `>=` on equal counts mirrors `max_by_key`, which
    // keeps the last of several maxima.
    let mut best: Option<(L, usize)> = None;
    for (lbl, count) in tally {
        let replace = match &best {
            Some((_, top)) => count >= *top,
            None => true,
        };
        if replace {
            best = Some((lbl, count));
        }
    }
    best.unwrap().0
}
/// Shannon entropy (base 2) of the label distribution; 0.0 for empty input.
fn entropy<L: Eq + std::hash::Hash>(labels: &[L]) -> f64 {
    // Count each distinct label by reference — no clones needed.
    let mut counts: HashMap<&L, usize> = HashMap::new();
    for lbl in labels {
        *counts.entry(lbl).or_insert(0) += 1;
    }
    let n = labels.len() as f64;
    counts
        .values()
        .map(|&count| {
            let p = count as f64 / n;
            -p * log2(p)
        })
        .sum()
}
/// Base-2 logarithm. Delegates to the standard library's `f64::log2`
/// instead of hand-rolling `ln(x) / ln(2)` — same result, clearer intent.
fn log2(x: f64) -> f64 {
    x.log2()
}
#[cfg(test)]
mod tests {
use super::*;
// Purely categorical data where the label depends only on the first
// feature (color): ID3 should split on it and classify training points
// exactly.
#[test]
fn test_id3_basic() {
let data = vec![
vec!["Red".to_string(), "Round".to_string()],
vec!["Blue".to_string(), "Square".to_string()],
vec!["Red".to_string(), "Square".to_string()],
vec!["Blue".to_string(), "Round".to_string()],
];
let labels = vec!["Yes", "No", "Yes", "No"];
let tree = DecisionTree::fit(&data, &labels, None, DecisionTreeAlgorithm::ID3);
let pred = tree.predict(&["Red".to_string(), "Round".to_string()]);
assert_eq!(pred, "Yes");
}
// Mixed data with a numeric first column: C4.5 should pick a threshold
// split and generalize to unseen numeric values on either side of it.
#[test]
fn test_c45_continuous() {
let data = vec![
vec!["30.5".to_string(), "Red".to_string()],
vec!["35.0".to_string(), "Blue".to_string()],
vec!["40.0".to_string(), "Red".to_string()],
vec!["45.0".to_string(), "Blue".to_string()],
vec!["50.0".to_string(), "Blue".to_string()],
];
let labels = vec!["Buy", "NoBuy", "Buy", "NoBuy", "NoBuy"];
let tree = DecisionTree::fit(&data, &labels, None, DecisionTreeAlgorithm::C45);
let pred1 = tree.predict(&["38.0".to_string(), "Red".to_string()]);
assert_eq!(pred1, "Buy");
let pred2 = tree.predict(&["47.0".to_string(), "Blue".to_string()]);
assert_eq!(pred2, "NoBuy");
}
}