use std::fmt;
use crate::learner::StreamingLearner;
use irithyll_core::error::ConfigError;
use irithyll_core::rng::xorshift64;
/// Maps one `xorshift64` draw to a float in the closed interval [0.0, 1.0],
/// advancing `state` in place.
#[inline]
fn rand_f64(state: &mut u64) -> f64 {
    let raw = xorshift64(state);
    raw as f64 / u64::MAX as f64
}
/// Configuration for a [`MondrianForest`]. Construct via [`MondrianForestConfig::builder`]
/// to get validation, or rely on `Default`.
#[derive(Clone, Debug)]
pub struct MondrianForestConfig {
    /// Number of trees in the ensemble (validated to be >= 1 by the builder).
    pub n_trees: usize,
    /// Maximum depth any tree may grow to (validated to be >= 1 by the builder).
    pub max_depth: usize,
    /// Lifetime parameter (validated to be > 0). NOTE(review): within this file
    /// it is only rescaled by `Tunable::adjust_config`, never read by the split
    /// logic — confirm intended use elsewhere.
    pub lifetime: f64,
    /// Base RNG seed from which distinct per-tree seeds are derived.
    pub seed: u64,
}
impl MondrianForestConfig {
    /// Starts a builder pre-populated with the default settings.
    #[inline]
    pub fn builder() -> MondrianForestConfigBuilder {
        MondrianForestConfigBuilder::default()
    }
}
impl Default for MondrianForestConfig {
    /// Defaults: 10 trees, max depth 8, lifetime 5.0, seed 42.
    fn default() -> Self {
        Self {
            n_trees: 10,
            max_depth: 8,
            lifetime: 5.0,
            seed: 42,
        }
    }
}
/// Builder for [`MondrianForestConfig`]; all validation happens in `build`.
pub struct MondrianForestConfigBuilder {
    // Mirrors of the config fields, filled in via the chained setters.
    n_trees: usize,
    max_depth: usize,
    lifetime: f64,
    seed: u64,
}
impl Default for MondrianForestConfigBuilder {
fn default() -> Self {
Self {
n_trees: 10,
max_depth: 8,
lifetime: 5.0,
seed: 42,
}
}
}
impl MondrianForestConfigBuilder {
    /// Sets the number of trees in the ensemble (must be >= 1 at build time).
    #[inline]
    pub fn n_trees(mut self, n: usize) -> Self {
        self.n_trees = n;
        self
    }

    /// Sets the maximum tree depth (must be >= 1 at build time).
    #[inline]
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    /// Sets the lifetime parameter (must be strictly positive at build time).
    #[inline]
    pub fn lifetime(mut self, l: f64) -> Self {
        self.lifetime = l;
        self
    }

    /// Sets the base RNG seed used to derive per-tree seeds.
    #[inline]
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Validates the accumulated settings and produces a config.
    ///
    /// # Errors
    /// Returns `ConfigError::out_of_range` when `n_trees == 0`,
    /// `max_depth == 0`, or `lifetime` is not strictly positive
    /// (including NaN).
    pub fn build(self) -> Result<MondrianForestConfig, ConfigError> {
        if self.n_trees == 0 {
            return Err(ConfigError::out_of_range(
                "n_trees",
                "must be >= 1",
                self.n_trees,
            ));
        }
        if self.max_depth == 0 {
            return Err(ConfigError::out_of_range(
                "max_depth",
                "must be >= 1",
                self.max_depth,
            ));
        }
        // BUG FIX: `self.lifetime <= 0.0` is false for NaN, so a NaN lifetime
        // previously slipped through validation. Reject NaN explicitly.
        if self.lifetime.is_nan() || self.lifetime <= 0.0 {
            return Err(ConfigError::out_of_range(
                "lifetime",
                "must be > 0",
                self.lifetime,
            ));
        }
        Ok(MondrianForestConfig {
            n_trees: self.n_trees,
            max_depth: self.max_depth,
            lifetime: self.lifetime,
            seed: self.seed,
        })
    }
}
/// A single Mondrian tree stored in struct-of-arrays form: each `Vec` below
/// holds one attribute per node, indexed by node id (root = 0).
#[derive(Clone)]
struct MondrianTree {
    /// Feature index each internal node splits on (0/unused for leaves).
    split_feature: Vec<usize>,
    /// Split threshold; samples with `features[f] <= threshold` go left.
    split_threshold: Vec<f64>,
    /// Left child id, `None` for leaves (splits always set both children).
    left_child: Vec<Option<usize>>,
    /// Right child id, `None` for leaves.
    right_child: Vec<Option<usize>>,
    /// Depth of each node (root = 0), used to enforce `max_depth`.
    depth: Vec<usize>,
    /// Per-node bounding box: componentwise minimum of observed features.
    lower: Vec<Vec<f64>>,
    /// Per-node bounding box: componentwise maximum of observed features.
    upper: Vec<Vec<f64>>,
    /// Running sum of `target * weight` at each node.
    sum_targets: Vec<f64>,
    /// Running sum of sample weights; denominator of the leaf mean.
    sum_weights: Vec<f64>,
    /// Unweighted sample count, used to gate splitting.
    count: Vec<u64>,
    /// Written at split time (with the total bbox range) but never read in
    /// this file — see the note in `train_one`.
    split_time: Vec<f64>,
    /// Per-tree xorshift RNG state; kept non-zero.
    rng_state: u64,
}
impl MondrianTree {
/// Creates a tree consisting of a single root leaf over `n_features`
/// dimensions. A zero seed is remapped to 1 because an xorshift generator
/// degenerates on an all-zero state.
fn new(seed: u64, n_features: usize) -> Self {
    // Pre-size the per-node vectors for a modest tree; they grow as needed.
    const INITIAL_CAPACITY: usize = 64;
    let mut tree = Self {
        split_feature: Vec::with_capacity(INITIAL_CAPACITY),
        split_threshold: Vec::with_capacity(INITIAL_CAPACITY),
        left_child: Vec::with_capacity(INITIAL_CAPACITY),
        right_child: Vec::with_capacity(INITIAL_CAPACITY),
        depth: Vec::with_capacity(INITIAL_CAPACITY),
        lower: Vec::with_capacity(INITIAL_CAPACITY),
        upper: Vec::with_capacity(INITIAL_CAPACITY),
        sum_targets: Vec::with_capacity(INITIAL_CAPACITY),
        sum_weights: Vec::with_capacity(INITIAL_CAPACITY),
        count: Vec::with_capacity(INITIAL_CAPACITY),
        split_time: Vec::with_capacity(INITIAL_CAPACITY),
        // max(1) maps only 0 -> 1 and leaves every other seed unchanged.
        rng_state: seed.max(1),
    };
    tree.alloc_leaf(0, n_features);
    tree
}
    /// Appends a fresh leaf node at `depth` and returns its index.
    ///
    /// The bounding box starts inverted (`lower = f64::MAX`, `upper = f64::MIN`)
    /// so the first observed sample sets both bounds in `expand_bbox`.
    /// Every per-node vector must be pushed in lockstep to keep the
    /// struct-of-arrays layout consistent.
    fn alloc_leaf(&mut self, depth: usize, n_features: usize) -> usize {
        let idx = self.split_feature.len();
        self.split_feature.push(0);
        self.split_threshold.push(0.0);
        self.left_child.push(None);
        self.right_child.push(None);
        self.depth.push(depth);
        self.lower.push(vec![f64::MAX; n_features]);
        self.upper.push(vec![f64::MIN; n_features]);
        self.sum_targets.push(0.0);
        self.sum_weights.push(0.0);
        self.count.push(0);
        self.split_time.push(0.0);
        idx
    }
#[inline]
fn is_leaf(&self, idx: usize) -> bool {
self.left_child[idx].is_none()
}
fn route_to_leaf(&self, features: &[f64]) -> usize {
let mut idx = 0;
loop {
if self.is_leaf(idx) {
return idx;
}
let f = self.split_feature[idx];
if features[f] <= self.split_threshold[idx] {
idx = self.left_child[idx].unwrap();
} else {
idx = self.right_child[idx].unwrap();
}
}
}
fn expand_bbox(&mut self, idx: usize, features: &[f64]) {
let lower = &mut self.lower[idx];
let upper = &mut self.upper[idx];
for (j, &x) in features.iter().enumerate() {
if x < lower[j] {
lower[j] = x;
}
if x > upper[j] {
upper[j] = x;
}
}
}
fn train_one(
&mut self,
features: &[f64],
target: f64,
weight: f64,
max_depth: usize,
min_split_count: u64,
) {
let n_features = features.len();
let leaf = self.route_to_leaf(features);
self.sum_targets[leaf] += target * weight;
self.sum_weights[leaf] += weight;
self.count[leaf] += 1;
self.expand_bbox(leaf, features);
if self.count[leaf] < min_split_count || self.depth[leaf] >= max_depth {
return;
}
let mut ranges = vec![0.0f64; n_features];
let mut total_range = 0.0f64;
for (j, range_j) in ranges.iter_mut().enumerate() {
let r = self.upper[leaf][j] - self.lower[leaf][j];
let r = if r.is_finite() && r > 0.0 { r } else { 0.0 };
*range_j = r;
total_range += r;
}
if total_range < 1e-15 {
return;
}
let dart = rand_f64(&mut self.rng_state) * total_range;
let mut cumulative = 0.0;
let mut chosen_feature = 0;
for (j, &range_j) in ranges.iter().enumerate() {
cumulative += range_j;
if dart <= cumulative {
chosen_feature = j;
break;
}
}
let lo = self.lower[leaf][chosen_feature];
let hi = self.upper[leaf][chosen_feature];
let jitter = rand_f64(&mut self.rng_state); let threshold = lo + (hi - lo) * (0.25 + 0.5 * jitter);
let leaf_depth = self.depth[leaf];
let left_idx = self.alloc_leaf(leaf_depth + 1, n_features);
let right_idx = self.alloc_leaf(leaf_depth + 1, n_features);
self.split_feature[leaf] = chosen_feature;
self.split_threshold[leaf] = threshold;
self.left_child[leaf] = Some(left_idx);
self.right_child[leaf] = Some(right_idx);
self.split_time[leaf] = total_range;
let half_target = self.sum_targets[leaf] / 2.0;
let half_weight = self.sum_weights[leaf] / 2.0;
let half_count = self.count[leaf] / 2;
self.sum_targets[left_idx] = half_target;
self.sum_weights[left_idx] = half_weight;
self.count[left_idx] = half_count.max(1);
self.sum_targets[right_idx] = self.sum_targets[leaf] - half_target;
self.sum_weights[right_idx] = self.sum_weights[leaf] - half_weight;
self.count[right_idx] = (self.count[leaf] - half_count).max(1);
self.lower[left_idx] = self.lower[leaf].clone();
self.upper[left_idx] = self.upper[leaf].clone();
self.upper[left_idx][chosen_feature] = threshold;
self.lower[right_idx] = self.lower[leaf].clone();
self.upper[right_idx] = self.upper[leaf].clone();
self.lower[right_idx][chosen_feature] = threshold;
}
#[inline]
fn predict(&self, features: &[f64]) -> f64 {
let leaf = self.route_to_leaf(features);
if self.sum_weights[leaf] > 0.0 {
self.sum_targets[leaf] / self.sum_weights[leaf]
} else {
0.0
}
}
fn reset(&mut self, n_features: usize) {
self.split_feature.clear();
self.split_threshold.clear();
self.left_child.clear();
self.right_child.clear();
self.depth.clear();
self.lower.clear();
self.upper.clear();
self.sum_targets.clear();
self.sum_weights.clear();
self.count.clear();
self.split_time.clear();
self.alloc_leaf(0, n_features);
}
    /// Total number of allocated nodes (internal nodes plus leaves).
    #[inline]
    fn n_nodes(&self) -> usize {
        self.split_feature.len()
    }
}
impl fmt::Debug for MondrianTree {
    /// Compact debug output: the bulky per-node vectors are elided; only the
    /// node count and RNG state are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MondrianTree")
            .field("n_nodes", &self.n_nodes())
            .field("rng_state", &self.rng_state)
            .finish()
    }
}
/// An online Mondrian-forest regressor: an ensemble of [`MondrianTree`]s
/// trained one sample at a time via [`StreamingLearner`].
#[derive(Clone)]
pub struct MondrianForest {
    /// Validated configuration; only `lifetime` mutates (via tuning).
    config: MondrianForestConfig,
    /// The ensemble; empty until the first training sample arrives.
    trees: Vec<MondrianTree>,
    /// Samples passed to `train_one` since construction or the last reset.
    samples_seen: u64,
    /// Input dimensionality, fixed by the first training sample.
    n_features: Option<usize>,
}
impl MondrianForest {
pub fn new(config: MondrianForestConfig) -> Self {
Self {
trees: Vec::with_capacity(config.n_trees),
samples_seen: 0,
n_features: None,
config,
}
}
#[inline]
pub fn n_trees(&self) -> usize {
self.config.n_trees
}
#[inline]
pub fn config(&self) -> &MondrianForestConfig {
&self.config
}
fn init_trees(&mut self, n_features: usize) {
self.n_features = Some(n_features);
self.trees.clear();
for i in 0..self.config.n_trees {
let seed = self
.config
.seed
.wrapping_add(i as u64)
.wrapping_mul(6364136223846793005)
.wrapping_add(1);
let seed = if seed == 0 { 1 } else { seed };
self.trees.push(MondrianTree::new(seed, n_features));
}
}
#[inline]
fn min_split_count(&self) -> u64 {
(2 * self.config.n_trees) as u64
}
}
impl Default for MondrianForest {
    /// A forest built from [`MondrianForestConfig::default`].
    fn default() -> Self {
        Self::new(MondrianForestConfig::default())
    }
}
impl StreamingLearner for MondrianForest {
    /// Feeds one weighted sample to every tree; the first call fixes the
    /// forest's input dimensionality and allocates the ensemble.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        if self.n_features.is_none() {
            self.init_trees(features.len());
        }
        let max_depth = self.config.max_depth;
        let min_split = self.min_split_count();
        self.trees
            .iter_mut()
            .for_each(|tree| tree.train_one(features, target, weight, max_depth, min_split));
        self.samples_seen += 1;
    }

    /// Ensemble prediction: the unweighted mean of the per-tree predictions,
    /// or 0.0 before any training has occurred.
    fn predict(&self, features: &[f64]) -> f64 {
        match self.trees.len() {
            0 => 0.0,
            n => {
                let total: f64 = self.trees.iter().map(|tree| tree.predict(features)).sum();
                total / n as f64
            }
        }
    }

    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Clears all learned structure while keeping the configuration. Trees
    /// are reset in place when the feature count is known, dropped otherwise.
    fn reset(&mut self) {
        match self.n_features {
            Some(nf) => self.trees.iter_mut().for_each(|tree| tree.reset(nf)),
            None => self.trees.clear(),
        }
        self.samples_seen = 0;
    }

    // Deprecated trait entry points forward to the `Tunable` implementation.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
    }

    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        <Self as crate::learner::Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
    }
}
impl crate::learner::Tunable for MondrianForest {
    /// Diagnostics layout: `[0, 0, 0, active tree count, uncertainty]`, where
    /// uncertainty shrinks as `1 / (1 + samples_seen)`.
    fn diagnostics_array(&self) -> [f64; 5] {
        let effective_dof = self.trees.len() as f64;
        let uncertainty = 1.0 / (1.0 + self.samples_seen as f64);
        [0.0, 0.0, 0.0, effective_dof, uncertainty]
    }

    /// Rescales `lifetime` by `lr_multiplier`, clamped below at 1e-6 (the
    /// clamp also absorbs negative or NaN multipliers); `_lambda_delta` is
    /// intentionally ignored.
    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
        let scaled = self.config.lifetime * lr_multiplier;
        self.config.lifetime = scaled.max(1e-6);
    }
}
impl fmt::Debug for MondrianForest {
    /// Debug view exposing the config and summary counters; the trees
    /// themselves use `MondrianTree`'s compact Debug impl via `n_trees`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MondrianForest")
            .field("config", &self.config)
            .field("n_trees", &self.trees.len())
            .field("samples_seen", &self.samples_seen)
            .field("n_features", &self.n_features)
            .finish()
    }
}
impl crate::automl::DiagnosticSource for MondrianForest {
    /// Reports the live tree count as effective degrees of freedom and an
    /// uncertainty of `1 / (1 + samples_seen)`; remaining fields default.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        Some(crate::automl::ConfigDiagnostics {
            effective_dof: self.trees.len() as f64,
            uncertainty: 1.0 / (1.0 + self.samples_seen as f64),
            ..Default::default()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: forest with `n` trees and fixed depth/lifetime/seed.
    fn forest_with_trees(n: usize) -> MondrianForest {
        let config = MondrianForestConfig::builder()
            .n_trees(n)
            .max_depth(8)
            .lifetime(5.0)
            .seed(42)
            .build()
            .expect("valid config");
        MondrianForest::new(config)
    }

    // Builder values propagate into the config; trees stay unallocated
    // until the first training sample.
    #[test]
    fn test_creation() {
        let config = MondrianForestConfig::builder()
            .n_trees(15)
            .max_depth(6)
            .lifetime(3.0)
            .seed(99)
            .build()
            .expect("valid config");
        let forest = MondrianForest::new(config);
        assert_eq!(forest.n_samples_seen(), 0);
        assert_eq!(forest.n_trees(), 15);
        assert_eq!(forest.config().max_depth, 6);
        assert!((forest.config().lifetime - 3.0).abs() < 1e-12);
        assert_eq!(forest.config().seed, 99);
        assert!(forest.trees.is_empty());
    }

    // Default config values and a default-constructed forest.
    #[test]
    fn test_default_config() {
        let config = MondrianForestConfig::default();
        assert_eq!(config.n_trees, 10);
        assert_eq!(config.max_depth, 8);
        assert!((config.lifetime - 5.0).abs() < 1e-12);
        assert_eq!(config.seed, 42);
        let forest = MondrianForest::default();
        assert_eq!(forest.n_trees(), 10);
        assert_eq!(forest.n_samples_seen(), 0);
    }

    // After exactly one sample, every tree's single leaf averages to the
    // sample's target, so the ensemble prediction is exact.
    #[test]
    fn test_single_sample() {
        let mut forest = forest_with_trees(10);
        forest.train(&[3.0, 4.0], 7.0);
        assert_eq!(forest.n_samples_seen(), 1);
        let pred = forest.predict(&[3.0, 4.0]);
        assert!(
            (pred - 7.0).abs() < 1e-12,
            "single sample prediction should be 7.0, got {}",
            pred,
        );
    }

    // Sanity on a linear target: predictions remain finite and positive.
    #[test]
    fn test_multiple_samples() {
        let mut forest = forest_with_trees(10);
        for i in 0..200 {
            let x1 = (i as f64) * 0.05;
            let x2 = (i as f64) * 0.03;
            forest.train(&[x1, x2], x1 + x2);
        }
        assert_eq!(forest.n_samples_seen(), 200);
        let pred = forest.predict(&[5.0, 3.0]);
        assert!(pred.is_finite(), "prediction must be finite, got {}", pred);
        assert!(
            pred > 0.0,
            "prediction should be positive for positive inputs, got {}",
            pred,
        );
    }

    // A constant target must be recovered (leaf means equal the constant).
    #[test]
    fn test_convergence() {
        let mut forest = forest_with_trees(10);
        let constant_target = 42.0;
        for i in 0..500 {
            let x = (i as f64) * 0.01;
            forest.train(&[x, x * 2.0], constant_target);
        }
        let pred = forest.predict(&[2.5, 5.0]);
        assert!(
            (pred - constant_target).abs() < 1.0,
            "expected prediction near {}, got {}",
            constant_target,
            pred,
        );
    }

    // Two well-separated clusters with different targets: the forest must
    // at least order the regions correctly.
    #[test]
    fn test_different_regions() {
        let mut forest = forest_with_trees(20);
        for i in 0..300 {
            let x = (i as f64) * 0.001;
            forest.train(&[x, x], 10.0);
        }
        for i in 0..300 {
            let x = 100.0 + (i as f64) * 0.001;
            forest.train(&[x, x], 90.0);
        }
        let pred_a = forest.predict(&[0.1, 0.1]);
        let pred_b = forest.predict(&[100.1, 100.1]);
        assert!(
            pred_b > pred_a,
            "region B prediction ({}) should exceed region A ({})",
            pred_b,
            pred_a,
        );
    }

    // reset() must collapse every tree to a single root leaf and zero the
    // sample counter, after which predictions are 0.0.
    #[test]
    fn test_reset() {
        let mut forest = forest_with_trees(10);
        for i in 0..100 {
            forest.train(&[i as f64, (i as f64) * 0.5], i as f64);
        }
        assert_eq!(forest.n_samples_seen(), 100);
        forest.reset();
        assert_eq!(forest.n_samples_seen(), 0);
        for tree in &forest.trees {
            assert_eq!(
                tree.n_nodes(),
                1,
                "tree should have exactly 1 node after reset"
            );
            assert!(tree.is_leaf(0));
        }
        let pred = forest.predict(&[5.0, 2.5]);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be 0.0, got {}",
            pred,
        );
    }

    // predict_batch must agree element-wise with individual predict calls.
    #[test]
    fn test_predict_batch() {
        let mut forest = forest_with_trees(10);
        for i in 0..100 {
            let x = i as f64;
            forest.train(&[x, x * 0.5], x);
        }
        let rows: Vec<&[f64]> = vec![&[1.0, 0.5], &[50.0, 25.0], &[99.0, 49.5]];
        let batch = forest.predict_batch(&rows);
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = forest.predict(row);
            assert!(
                (batch[i] - individual).abs() < 1e-12,
                "batch[{}] = {} != individual = {}",
                i,
                batch[i],
                individual,
            );
        }
    }

    // The forest must be usable through `Box<dyn StreamingLearner>`.
    #[test]
    fn test_trait_object() {
        let forest = forest_with_trees(5);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(forest);
        boxed.train(&[1.0, 2.0], 3.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    // Clones must predict identically, then diverge independently: training
    // the clone leaves the original untouched.
    #[test]
    fn test_clone() {
        let mut forest = forest_with_trees(10);
        for i in 0..100 {
            forest.train(&[i as f64, (i as f64) * 2.0], i as f64);
        }
        let mut cloned = forest.clone();
        assert_eq!(cloned.n_samples_seen(), forest.n_samples_seen());
        let features = [50.0, 100.0];
        let pred_orig = forest.predict(&features);
        let pred_clone = cloned.predict(&features);
        assert!(
            (pred_orig - pred_clone).abs() < 1e-12,
            "clone prediction should match original: {} vs {}",
            pred_orig,
            pred_clone,
        );
        for i in 0..50 {
            cloned.train(&[i as f64, (i as f64) * 2.0], 999.0);
        }
        assert_eq!(forest.n_samples_seen(), 100);
        assert_eq!(cloned.n_samples_seen(), 150);
        let pred_orig_after = forest.predict(&features);
        assert!(
            (pred_orig - pred_orig_after).abs() < 1e-12,
            "original should be unchanged after training clone",
        );
    }

    // Smoke test on noisy y = x data with small and large ensembles; only
    // finiteness is asserted (no claim that more trees means lower MSE).
    #[test]
    fn test_multi_tree() {
        let mut forest_5 = forest_with_trees(5);
        let mut forest_50 = MondrianForest::new(
            MondrianForestConfig::builder()
                .n_trees(50)
                .max_depth(8)
                .seed(42)
                .build()
                .expect("valid config"),
        );
        let mut rng_state: u64 = 12345;
        let mut data = Vec::new();
        for _ in 0..300 {
            let x = rand_f64(&mut rng_state) * 10.0;
            let noise = (rand_f64(&mut rng_state) - 0.5) * 2.0;
            data.push((x, x + noise));
        }
        for &(x, y) in &data {
            forest_5.train(&[x], y);
            forest_50.train(&[x], y);
        }
        let mut mse_5 = 0.0;
        let mut mse_50 = 0.0;
        let test_points = 50;
        for i in 0..test_points {
            let x = (i as f64) * 0.2;
            let true_y = x; let p5 = forest_5.predict(&[x]);
            let p50 = forest_50.predict(&[x]);
            mse_5 += (p5 - true_y).powi(2);
            mse_50 += (p50 - true_y).powi(2);
        }
        mse_5 /= test_points as f64;
        mse_50 /= test_points as f64;
        assert!(mse_5.is_finite(), "MSE for 5 trees should be finite");
        assert!(mse_50.is_finite(), "MSE for 50 trees should be finite");
    }

    // The sample counter increments once per call, for both the default
    // `train` and the explicit weighted `train_one`.
    #[test]
    fn test_n_samples_seen() {
        let mut forest = forest_with_trees(10);
        assert_eq!(forest.n_samples_seen(), 0);
        for i in 1..=75 {
            forest.train(&[i as f64], i as f64);
            assert_eq!(forest.n_samples_seen(), i);
        }
        forest.train_one(&[100.0], 100.0, 5.0);
        assert_eq!(forest.n_samples_seen(), 76);
    }
}