1use ferrolearn_core::error::FerroError;
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2};
33use num_traits::Float;
34use rand::SeedableRng;
35use rand::rngs::StdRng;
36use serde::{Deserialize, Serialize};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
44enum IsoNode<F> {
45 Split {
47 feature: usize,
49 threshold: F,
51 left: usize,
53 right: usize,
55 n_samples: usize,
57 },
58 Leaf {
60 n_samples: usize,
62 },
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct IsolationForest<F> {
80 pub n_estimators: usize,
82 pub max_samples: usize,
84 pub contamination: f64,
86 pub random_state: Option<u64>,
88 _marker: std::marker::PhantomData<F>,
89}
90
91impl<F: Float> IsolationForest<F> {
92 #[must_use]
97 pub fn new() -> Self {
98 Self {
99 n_estimators: 100,
100 max_samples: 256,
101 contamination: 0.1,
102 random_state: None,
103 _marker: std::marker::PhantomData,
104 }
105 }
106
107 #[must_use]
109 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
110 self.n_estimators = n_estimators;
111 self
112 }
113
114 #[must_use]
116 pub fn with_max_samples(mut self, max_samples: usize) -> Self {
117 self.max_samples = max_samples;
118 self
119 }
120
121 #[must_use]
123 pub fn with_contamination(mut self, contamination: f64) -> Self {
124 self.contamination = contamination;
125 self
126 }
127
128 #[must_use]
130 pub fn with_random_state(mut self, seed: u64) -> Self {
131 self.random_state = Some(seed);
132 self
133 }
134}
135
136impl<F: Float> Default for IsolationForest<F> {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[derive(Debug, Clone)]
151pub struct FittedIsolationForest<F> {
152 trees: Vec<Vec<IsoNode<F>>>,
154 n_features: usize,
156 threshold: f64,
158 max_samples: usize,
160}
161
162impl<F: Float + Send + Sync + 'static> FittedIsolationForest<F> {
163 #[must_use]
165 pub fn n_estimators(&self) -> usize {
166 self.trees.len()
167 }
168
169 #[must_use]
171 pub fn n_features(&self) -> usize {
172 self.n_features
173 }
174
175 #[must_use]
177 pub fn threshold(&self) -> f64 {
178 self.threshold
179 }
180
181 pub fn score_samples(&self, x: &Array2<F>) -> Result<Array1<f64>, FerroError> {
192 if x.ncols() != self.n_features {
193 return Err(FerroError::ShapeMismatch {
194 expected: vec![self.n_features],
195 actual: vec![x.ncols()],
196 context: "number of features must match fitted model".into(),
197 });
198 }
199
200 let n_samples = x.nrows();
201 let c_n = average_path_length(self.max_samples);
202 let n_trees = self.trees.len() as f64;
203 let mut scores = Array1::zeros(n_samples);
204
205 for i in 0..n_samples {
206 let row = x.row(i);
207 let mut total_path = 0.0;
208 for tree_nodes in &self.trees {
209 total_path += path_length(tree_nodes, &row);
210 }
211 let mean_path = total_path / n_trees;
212 scores[i] = f64::powf(2.0, -mean_path / c_n);
213 }
214
215 Ok(scores)
216 }
217}
218
219impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for IsolationForest<F> {
220 type Fitted = FittedIsolationForest<F>;
221 type Error = FerroError;
222
223 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedIsolationForest<F>, FerroError> {
230 let (n_samples, n_features) = x.dim();
231
232 if n_samples == 0 {
233 return Err(FerroError::InsufficientSamples {
234 required: 1,
235 actual: 0,
236 context: "IsolationForest requires at least one sample".into(),
237 });
238 }
239 if self.n_estimators == 0 {
240 return Err(FerroError::InvalidParameter {
241 name: "n_estimators".into(),
242 reason: "must be at least 1".into(),
243 });
244 }
245 if self.max_samples == 0 {
246 return Err(FerroError::InvalidParameter {
247 name: "max_samples".into(),
248 reason: "must be at least 1".into(),
249 });
250 }
251 if !(0.0..=0.5).contains(&self.contamination) {
252 return Err(FerroError::InvalidParameter {
253 name: "contamination".into(),
254 reason: "must be in [0.0, 0.5]".into(),
255 });
256 }
257
258 let effective_max_samples = self.max_samples.min(n_samples);
259 let max_depth = (effective_max_samples as f64).log2().ceil() as usize;
260
261 let mut rng = if let Some(seed) = self.random_state {
262 StdRng::seed_from_u64(seed)
263 } else {
264 StdRng::from_os_rng()
265 };
266
267 let mut trees = Vec::with_capacity(self.n_estimators);
268 for _ in 0..self.n_estimators {
269 let sample_indices: Vec<usize> = (0..effective_max_samples)
271 .map(|_| {
272 use rand::RngCore;
273 (rng.next_u64() as usize) % n_samples
274 })
275 .collect();
276
277 let mut nodes = Vec::new();
278 let indices: Vec<usize> = (0..sample_indices.len()).collect();
279 build_isolation_tree(
281 x,
282 &sample_indices,
283 &indices,
284 &mut nodes,
285 0,
286 max_depth,
287 n_features,
288 &mut rng,
289 );
290 trees.push(nodes);
291 }
292
293 let fitted_no_threshold = FittedIsolationForest {
295 trees,
296 n_features,
297 threshold: 0.0,
298 max_samples: effective_max_samples,
299 };
300
301 let train_scores = fitted_no_threshold.score_samples(x)?;
302
303 let mut sorted_scores: Vec<f64> = train_scores.iter().copied().collect();
305 sorted_scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
306
307 let contamination_idx = ((self.contamination * n_samples as f64).ceil() as usize)
308 .max(1)
309 .min(n_samples);
310 let threshold = if contamination_idx < sorted_scores.len() {
311 sorted_scores[contamination_idx - 1]
312 } else {
313 sorted_scores[sorted_scores.len() - 1]
314 };
315
316 Ok(FittedIsolationForest {
317 trees: fitted_no_threshold.trees,
318 n_features,
319 threshold,
320 max_samples: effective_max_samples,
321 })
322 }
323}
324
325impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedIsolationForest<F> {
326 type Output = Array1<isize>;
327 type Error = FerroError;
328
329 fn predict(&self, x: &Array2<F>) -> Result<Array1<isize>, FerroError> {
338 let scores = self.score_samples(x)?;
339 let predictions = scores.mapv(|s| if s >= self.threshold { -1 } else { 1 });
340 Ok(predictions)
341 }
342}
343
/// Normalising constant `c(n)` used to scale isolation-tree path lengths.
///
/// Evaluates `2 * (ln(n - 1) + gamma) - 2 * (n - 1) / n`, with `gamma` the
/// Euler–Mascheroni constant. Returns `0.0` for `n <= 1`.
fn average_path_length(n: usize) -> f64 {
    // Euler–Mascheroni constant.
    const EULER_GAMMA: f64 = 0.5772156649015329;

    match n {
        0 | 1 => 0.0,
        _ => {
            let n_f = n as f64;
            let harmonic = (n_f - 1.0).ln() + EULER_GAMMA;
            2.0 * harmonic - 2.0 * (n_f - 1.0) / n_f
        }
    }
}
360
361fn path_length<F: Float>(nodes: &[IsoNode<F>], sample: &ndarray::ArrayView1<F>) -> f64 {
363 let mut idx = 0;
364 let mut depth: f64 = 0.0;
365 loop {
366 match &nodes[idx] {
367 IsoNode::Split {
368 feature,
369 threshold,
370 left,
371 right,
372 ..
373 } => {
374 if sample[*feature] <= *threshold {
375 idx = *left;
376 } else {
377 idx = *right;
378 }
379 depth += 1.0;
380 }
381 IsoNode::Leaf { n_samples } => {
382 return depth + average_path_length(*n_samples);
384 }
385 }
386 }
387}
388
389fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
391 use rand::RngCore;
392 let u = (rng.next_u64() as f64) / (u64::MAX as f64);
393 let range = max_val - min_val;
394 min_val + F::from(u).unwrap() * range
395}
396
397#[allow(clippy::too_many_arguments)]
402fn build_isolation_tree<F: Float>(
403 x: &Array2<F>,
404 sample_indices: &[usize],
405 indices: &[usize],
406 nodes: &mut Vec<IsoNode<F>>,
407 depth: usize,
408 max_depth: usize,
409 n_features: usize,
410 rng: &mut StdRng,
411) -> usize {
412 let n = indices.len();
413
414 if n <= 1 || depth >= max_depth {
416 let idx = nodes.len();
417 nodes.push(IsoNode::Leaf { n_samples: n });
418 return idx;
419 }
420
421 let max_attempts = n_features * 2;
423 for _ in 0..max_attempts {
424 use rand::RngCore;
425 let feature = (rng.next_u64() as usize) % n_features;
426
427 let mut min_val = x[[sample_indices[indices[0]], feature]];
429 let mut max_val = min_val;
430 for &i in &indices[1..] {
431 let v = x[[sample_indices[i], feature]];
432 if v < min_val {
433 min_val = v;
434 }
435 if v > max_val {
436 max_val = v;
437 }
438 }
439
440 if min_val >= max_val {
442 continue;
443 }
444
445 let threshold = random_threshold(rng, min_val, max_val);
446
447 let mut left_indices = Vec::new();
449 let mut right_indices = Vec::new();
450 for &i in indices {
451 if x[[sample_indices[i], feature]] <= threshold {
452 left_indices.push(i);
453 } else {
454 right_indices.push(i);
455 }
456 }
457
458 if left_indices.is_empty() || right_indices.is_empty() {
460 continue;
461 }
462
463 let node_idx = nodes.len();
465 nodes.push(IsoNode::Leaf { n_samples: 0 }); let left_child = build_isolation_tree(
468 x,
469 sample_indices,
470 &left_indices,
471 nodes,
472 depth + 1,
473 max_depth,
474 n_features,
475 rng,
476 );
477 let right_child = build_isolation_tree(
478 x,
479 sample_indices,
480 &right_indices,
481 nodes,
482 depth + 1,
483 max_depth,
484 n_features,
485 rng,
486 );
487
488 nodes[node_idx] = IsoNode::Split {
489 feature,
490 threshold,
491 left: left_child,
492 right: right_child,
493 n_samples: n,
494 };
495
496 return node_idx;
497 }
498
499 let idx = nodes.len();
501 nodes.push(IsoNode::Leaf { n_samples: n });
502 idx
503}
504
/// Unit tests for configuration, fitting, scoring and prediction.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Ten tightly clustered 2-D points.
    fn make_normal_data() -> Array2<f64> {
        #[rustfmt::skip]
        let values = vec![
            4.5, 4.8,
            5.1, 5.2,
            4.9, 5.0,
            5.3, 4.7,
            4.8, 5.1,
            5.0, 5.3,
            5.2, 4.9,
            4.7, 5.0,
            5.1, 4.8,
            4.9, 5.2,
        ];
        Array2::from_shape_vec((10, 2), values).unwrap()
    }

    /// Nine clustered 2-D points followed by one far-away point in the last row.
    fn make_data_with_anomaly() -> Array2<f64> {
        #[rustfmt::skip]
        let values = vec![
            4.5, 4.8,
            5.1, 5.2,
            4.9, 5.0,
            5.3, 4.7,
            4.8, 5.1,
            5.0, 5.3,
            5.2, 4.9,
            4.7, 5.0,
            5.1, 4.8,
            100.0, 100.0,
        ];
        Array2::from_shape_vec((10, 2), values).unwrap()
    }

    /// `new` sets the documented default hyperparameters.
    #[test]
    fn test_isolation_forest_default() {
        let forest = IsolationForest::<f64>::new();
        assert_eq!(forest.n_estimators, 100);
        assert_eq!(forest.max_samples, 256);
        assert!((forest.contamination - 0.1).abs() < 1e-10);
        assert!(forest.random_state.is_none());
    }

    /// Every builder method overrides exactly its own field.
    #[test]
    fn test_isolation_forest_builder() {
        let forest = IsolationForest::<f64>::new()
            .with_n_estimators(50)
            .with_max_samples(128)
            .with_contamination(0.05)
            .with_random_state(123);
        assert_eq!(forest.n_estimators, 50);
        assert_eq!(forest.max_samples, 128);
        assert!((forest.contamination - 0.05).abs() < 1e-10);
        assert_eq!(forest.random_state, Some(123));
    }

    /// Predictions have one label per row and only use the labels 1 and -1.
    #[test]
    fn test_fit_predict_basic() {
        let data = make_normal_data();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let labels = fitted.predict(&data).unwrap();
        assert_eq!(labels.len(), 10);
        assert!(labels.iter().all(|&label| matches!(label, 1 | -1)));
    }

    /// The far-away point is labelled as an anomaly.
    #[test]
    fn test_anomaly_detected() {
        let data = make_data_with_anomaly();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(200)
            .with_contamination(0.15)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let labels = fitted.predict(&data).unwrap();

        assert_eq!(labels[9], -1, "outlier should be detected as anomaly");
    }

    /// The far-away point scores higher than every clustered point.
    #[test]
    fn test_anomaly_scores() {
        let data = make_data_with_anomaly();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(200)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let scores = fitted.score_samples(&data).unwrap();

        assert_eq!(scores.len(), 10);
        let anomaly_score = scores[9];
        let max_normal_score = scores
            .iter()
            .take(9)
            .fold(0.0_f64, |best, &score| best.max(score));
        assert!(
            anomaly_score > max_normal_score,
            "anomaly score ({anomaly_score}) should be greater than max normal score ({max_normal_score})"
        );
    }

    /// Fitting on a matrix without rows is rejected.
    #[test]
    fn test_empty_input_error() {
        let empty = Array2::<f64>::zeros((0, 2));
        let forest = IsolationForest::<f64>::new();
        assert!(forest.fit(&empty, &()).is_err());
    }

    /// Zero trees is an invalid configuration.
    #[test]
    fn test_zero_estimators_error() {
        let data = make_normal_data();
        let forest = IsolationForest::<f64>::new().with_n_estimators(0);
        assert!(forest.fit(&data, &()).is_err());
    }

    /// A zero sub-sample size is an invalid configuration.
    #[test]
    fn test_zero_max_samples_error() {
        let data = make_normal_data();
        let forest = IsolationForest::<f64>::new().with_max_samples(0);
        assert!(forest.fit(&data, &()).is_err());
    }

    /// A contamination above 0.5 is rejected.
    #[test]
    fn test_invalid_contamination_error() {
        let data = make_normal_data();
        let forest = IsolationForest::<f64>::new().with_contamination(0.6);
        assert!(forest.fit(&data, &()).is_err());
    }

    /// `predict` rejects input with a different number of columns.
    #[test]
    fn test_predict_shape_mismatch() {
        let train = make_normal_data();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42)
            .fit(&train, &())
            .unwrap();

        let wrong_width = Array2::<f64>::zeros((3, 5));
        assert!(fitted.predict(&wrong_width).is_err());
    }

    /// `score_samples` rejects input with a different number of columns.
    #[test]
    fn test_score_shape_mismatch() {
        let train = make_normal_data();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42)
            .fit(&train, &())
            .unwrap();

        let wrong_width = Array2::<f64>::zeros((3, 5));
        assert!(fitted.score_samples(&wrong_width).is_err());
    }

    /// Sanity bounds for the path-length normaliser.
    #[test]
    fn test_average_path_length_values() {
        assert!(average_path_length(1).abs() < 1e-10);

        let c2 = average_path_length(2);
        assert!(c2 > 0.0 && c2 < 1.0, "c(2) = {c2}");

        let c256 = average_path_length(256);
        assert!(c256 > 5.0 && c256 < 15.0, "c(256) = {c256}");
    }

    /// Fitting twice with the same seed gives the same predictions.
    #[test]
    fn test_reproducibility() {
        let data = make_data_with_anomaly();
        let forest = IsolationForest::<f64>::new()
            .with_n_estimators(50)
            .with_random_state(999);

        let first_run = forest.fit(&data, &()).unwrap().predict(&data).unwrap();
        let second_run = forest.fit(&data, &()).unwrap().predict(&data).unwrap();

        assert_eq!(first_run, second_run);
    }

    /// A sub-sample size larger than the data set is accepted.
    #[test]
    fn test_max_samples_larger_than_data() {
        let data = make_normal_data();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(10)
            .with_max_samples(10000)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let labels = fitted.predict(&data).unwrap();
        assert_eq!(labels.len(), 10);
    }

    /// The model also works with `f32` data.
    #[test]
    fn test_f32() {
        #[rustfmt::skip]
        let values: Vec<f32> = vec![
            1.0, 2.0,
            2.0, 3.0,
            3.0, 3.0,
            5.0, 6.0,
            6.0, 7.0,
            100.0, 100.0,
        ];
        let data = Array2::<f32>::from_shape_vec((6, 2), values).unwrap();
        let fitted = IsolationForest::<f32>::new()
            .with_n_estimators(50)
            .with_contamination(0.2)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let labels = fitted.predict(&data).unwrap();
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&label| matches!(label, 1 | -1)));
    }

    /// Fitting and predicting on a single row does not fail.
    #[test]
    fn test_single_sample() {
        let data = Array2::<f64>::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(10)
            .with_contamination(0.0)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let labels = fitted.predict(&data).unwrap();
        assert_eq!(labels.len(), 1);
    }

    /// Accessors on the fitted model report the fitted configuration.
    #[test]
    fn test_fitted_accessors() {
        let data = make_normal_data();
        let fitted = IsolationForest::<f64>::new()
            .with_n_estimators(10)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        assert_eq!(fitted.n_estimators(), 10);
        assert_eq!(fitted.n_features(), 2);
        assert!(fitted.threshold() >= 0.0);
    }
}