use super::helper_function::preliminary_check;
use crate::error::ModelError;
use crate::machine_learning::decision_tree::{Node, NodeType};
use crate::math::average_path_length_factor;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, ArrayBase, Axis, Data, Ix2};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{Rng, SeedableRng, rng};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
/// Minimum number of estimators before `fit` builds trees in parallel via rayon.
const DEFAULT_PARALLEL_THRESHOLD_TREES: usize = 10;
/// Minimum number of input rows before `predict` scores samples in parallel via rayon.
const DEFAULT_PARALLEL_THRESHOLD_SAMPLES: usize = 100;
/// Isolation Forest anomaly detector.
///
/// Builds an ensemble of randomized isolation trees on subsamples of the
/// training data; anomalies are points with short average path lengths
/// across the ensemble (see `anomaly_score`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct IsolationForest {
    /// Fitted isolation trees; `None` until `fit` has been called.
    trees: Option<Vec<Box<Node>>>,
    /// Number of trees in the ensemble.
    n_estimators: usize,
    /// Maximum number of rows subsampled per tree.
    max_samples: usize,
    /// Depth limit for each tree.
    max_depth: usize,
    /// Optional seed; when set, each tree i is seeded with `seed + i` for reproducibility.
    random_state: Option<u64>,
    /// Number of feature columns seen at fit time; 0 before fitting.
    n_features: usize,
}
impl Default for IsolationForest {
    /// Returns an unfitted forest with the conventional defaults:
    /// 100 trees, 256 samples per tree, depth limit 8 (= ceil(log2(256))).
    fn default() -> Self {
        Self {
            trees: None,
            n_estimators: 100,
            max_samples: 256,
            max_depth: 8,
            random_state: None,
            n_features: 0,
        }
    }
}
impl IsolationForest {
    /// Creates a new, unfitted Isolation Forest.
    ///
    /// # Arguments
    /// * `n_estimators` - number of isolation trees to build (must be > 0)
    /// * `max_samples` - subsample size drawn per tree (must be > 0)
    /// * `max_depth` - optional depth limit; defaults to `ceil(log2(max_samples))`
    /// * `random_state` - optional seed for reproducible tree construction
    ///
    /// # Errors
    /// Returns `ModelError::InputValidationError` if `n_estimators`,
    /// `max_samples`, or a supplied `max_depth` is zero.
    pub fn new(
        n_estimators: usize,
        max_samples: usize,
        max_depth: Option<usize>,
        random_state: Option<u64>,
    ) -> Result<Self, ModelError> {
        if n_estimators == 0 {
            return Err(ModelError::InputValidationError(
                "n_estimators must be greater than 0".to_string(),
            ));
        }
        if max_samples == 0 {
            return Err(ModelError::InputValidationError(
                "max_samples must be greater than 0".to_string(),
            ));
        }
        if let Some(depth) = max_depth {
            if depth == 0 {
                return Err(ModelError::InputValidationError(
                    "max_depth must be greater than 0".to_string(),
                ));
            }
        }
        // Standard iForest default: a subsample of size psi is isolated after
        // roughly log2(psi) random splits, so deeper trees add little signal.
        let computed_max_depth =
            max_depth.unwrap_or_else(|| (max_samples as f64).log2().ceil() as usize);
        Ok(Self {
            trees: None,
            n_estimators,
            max_samples,
            max_depth: computed_max_depth,
            random_state,
            n_features: 0,
        })
    }

    get_field!(get_n_estimators, n_estimators, usize);
    get_field!(get_max_samples, max_samples, usize);
    get_field!(get_max_depth, max_depth, usize);
    get_field!(get_random_state, random_state, Option<u64>);
    get_field!(get_n_features, n_features, usize);
    get_field_as_ref!(get_trees, trees, Option<&Vec<Box<Node>>>);

    /// Fits the forest on `x` (rows = samples, columns = features).
    ///
    /// Builds `n_estimators` trees, each on a random subsample of at most
    /// `max_samples` rows; construction is parallelized with rayon once the
    /// ensemble reaches `DEFAULT_PARALLEL_THRESHOLD_TREES` trees. Progress is
    /// reported on an indicatif progress bar.
    ///
    /// # Errors
    /// Propagates `ModelError` from `preliminary_check` or tree construction.
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<&mut Self, ModelError>
    where
        S: Data<Elem = f64> + Send + Sync,
    {
        preliminary_check(x, None)?;
        self.n_features = x.ncols();
        let progress_bar = ProgressBar::new(self.n_estimators as u64);
        progress_bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | {msg}")
                .expect("Failed to set progress bar template")
                .progress_chars("█▓░"),
        );
        progress_bar.set_message("Building isolation trees");
        let build_tree = |i: usize| -> Result<Box<Node>, ModelError> {
            // Derive a distinct, deterministic stream per tree when seeded;
            // otherwise seed each tree from the thread-local generator.
            let mut rng = if let Some(seed) = self.random_state {
                StdRng::seed_from_u64(seed.wrapping_add(i as u64))
            } else {
                StdRng::from_rng(&mut rng())
            };
            // Never sample more rows than the data actually has.
            let sample_size = self.max_samples.min(x.nrows());
            let sample_indices = self.sample_indices(x.nrows(), sample_size, &mut rng);
            let result = self
                .build_isolation_tree(x, &sample_indices, 0, &mut rng)
                .map(Box::new);
            progress_bar.inc(1);
            result
        };
        let trees: Result<Vec<Box<Node>>, ModelError> =
            if self.n_estimators >= DEFAULT_PARALLEL_THRESHOLD_TREES {
                (0..self.n_estimators)
                    .into_par_iter()
                    .map(build_tree)
                    .collect()
            } else {
                (0..self.n_estimators).map(build_tree).collect()
            };
        progress_bar.finish_with_message("Trees built successfully");
        self.trees = Some(trees?);
        println!(
            "\nIsolation Forest training completed: {} samples, {} features, {} trees (max depth: {}, max samples per tree: {})",
            x.nrows(),
            x.ncols(),
            self.n_estimators,
            self.max_depth,
            self.max_samples
        );
        Ok(self)
    }

    /// Draws `sample_size` distinct indices from `0..n` uniformly at random
    /// using a partial Fisher-Yates shuffle (O(n) memory, O(sample_size) swaps).
    fn sample_indices(&self, n: usize, sample_size: usize, rng: &mut StdRng) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..n).collect();
        for i in 0..sample_size {
            let j = rng.random_range(i..n);
            indices.swap(i, j);
        }
        indices.truncate(sample_size);
        indices
    }

    /// Recursively grows an isolation tree over the rows in `indices`.
    ///
    /// A node becomes a leaf (storing the sample count for later path-length
    /// correction) when the depth limit is hit, at most one sample remains, or
    /// the randomly chosen feature is constant over the subsample.
    fn build_isolation_tree<S>(
        &self,
        x: &ArrayBase<S, Ix2>,
        indices: &[usize],
        current_depth: usize,
        rng: &mut StdRng,
    ) -> Result<Node, ModelError>
    where
        S: Data<Elem = f64>,
    {
        if current_depth >= self.max_depth || indices.len() <= 1 {
            return Ok(Node::new_leaf(indices.len() as f64, None, None));
        }
        // Pick a random feature and find its range over the current subsample.
        let feature_index = rng.random_range(0..self.n_features);
        let mut min_val = f64::INFINITY;
        let mut max_val = f64::NEG_INFINITY;
        for &idx in indices {
            let val = x[[idx, feature_index]];
            min_val = min_val.min(val);
            max_val = max_val.max(val);
        }
        // A (near-)constant feature cannot split the subsample.
        if (max_val - min_val).abs() < 1e-10 {
            return Ok(Node::new_leaf(indices.len() as f64, None, None));
        }
        let threshold = rng.random_range(min_val..max_val);
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
            .iter()
            .partition(|&&idx| x[[idx, feature_index]] < threshold);
        // Defensive: a degenerate split (all samples on one side) ends growth.
        if left_indices.is_empty() || right_indices.is_empty() {
            return Ok(Node::new_leaf(indices.len() as f64, None, None));
        }
        let mut node = Node::new_internal(feature_index, threshold);
        node.left = Some(Box::new(self.build_isolation_tree(
            x,
            &left_indices,
            current_depth + 1,
            rng,
        )?));
        node.right = Some(Box::new(self.build_isolation_tree(
            x,
            &right_indices,
            current_depth + 1,
            rng,
        )?));
        Ok(node)
    }

    /// Traces `sample` down one tree and returns its path length, adjusted at
    /// the terminating leaf by the expected remaining depth `c(leaf size)`.
    fn path_length(&self, sample: &[f64], node: &Node, current_depth: usize) -> f64 {
        match &node.node_type {
            // Leaf `value` holds the number of training samples that reached it;
            // `average_path_length_factor` estimates the unexpanded subtree depth.
            NodeType::Leaf { value, .. } => {
                current_depth as f64 + average_path_length_factor(*value as usize)
            }
            NodeType::Internal {
                feature_index,
                threshold,
                ..
            } => {
                let child = if sample[*feature_index] < *threshold {
                    node.left.as_ref()
                } else {
                    node.right.as_ref()
                };
                match child {
                    Some(next) => self.path_length(sample, next, current_depth + 1),
                    // Defensive: internal nodes built by this forest always
                    // carry both children; stop here if one is missing.
                    None => current_depth as f64,
                }
            }
        }
    }

    /// Computes the anomaly score `2^(-E[h(x)] / c(max_samples))` for one sample.
    ///
    /// Scores lie in (0, 1]; values near 1 indicate likely anomalies, values
    /// near 0.5 or below indicate normal points.
    ///
    /// # Errors
    /// * `ModelError::NotFitted` if `fit` has not been called
    /// * `ModelError::InputValidationError` if `sample` has the wrong length
    pub fn anomaly_score(&self, sample: &[f64]) -> Result<f64, ModelError> {
        let Some(trees) = self.trees.as_ref() else {
            return Err(ModelError::NotFitted);
        };
        if sample.len() != self.n_features {
            return Err(ModelError::InputValidationError(
                "Sample feature dimension mismatch".to_string(),
            ));
        }
        let avg_path_length: f64 = trees
            .iter()
            .map(|tree| self.path_length(sample, tree, 0))
            .sum::<f64>()
            / trees.len() as f64;
        // NOTE(review): normalization uses c(max_samples) even though each tree
        // subsamples min(max_samples, n_rows) at fit time; when the training set
        // is smaller than max_samples this slightly biases scores. The effective
        // subsample size is not stored — TODO confirm and persist it at fit time.
        let c_n = average_path_length_factor(self.max_samples);
        let score = 2.0_f64.powf(-avg_path_length / c_n);
        Ok(score)
    }

    /// Computes anomaly scores for every row of `x`.
    ///
    /// Scoring is parallelized with rayon once `x` has at least
    /// `DEFAULT_PARALLEL_THRESHOLD_SAMPLES` rows.
    ///
    /// # Errors
    /// * `ModelError::NotFitted` if `fit` has not been called
    /// * `ModelError::InputValidationError` on invalid data or a feature-count mismatch
    pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
    where
        S: Data<Elem = f64>,
    {
        if self.trees.is_none() {
            return Err(ModelError::NotFitted);
        }
        preliminary_check(x, None)?;
        if x.ncols() != self.n_features {
            return Err(ModelError::InputValidationError(format!(
                "Feature dimension mismatch: expected {} features, got {}",
                self.n_features,
                x.ncols()
            )));
        }
        // Rows of an arbitrary view may not be contiguous in memory (e.g. `x`
        // is a transposed or strided view), in which case `as_slice` returns
        // `None`. Fall back to copying the row rather than unwrapping, which
        // previously panicked on such inputs.
        let scores: Result<Vec<f64>, ModelError> =
            if x.nrows() >= DEFAULT_PARALLEL_THRESHOLD_SAMPLES {
                x.axis_iter(Axis(0))
                    .into_par_iter()
                    .map(|row| match row.as_slice() {
                        Some(slice) => self.anomaly_score(slice),
                        None => self.anomaly_score(&row.to_vec()),
                    })
                    .collect()
            } else {
                x.axis_iter(Axis(0))
                    .map(|row| match row.as_slice() {
                        Some(slice) => self.anomaly_score(slice),
                        None => self.anomaly_score(&row.to_vec()),
                    })
                    .collect()
            };
        Ok(Array1::from_vec(scores?))
    }

    /// Convenience method: fits the forest on `x`, then scores the same data.
    ///
    /// # Errors
    /// Propagates any error from `fit` or `predict`.
    pub fn fit_predict<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
    where
        S: Data<Elem = f64> + Send + Sync,
    {
        self.fit(x)?;
        self.predict(x)
    }

    model_save_and_load_methods!(IsolationForest);
}