use rand::{Rng, RngExt};
use crate::domain::bounding_box::BoundingBox;
use crate::error::{RcfError, RcfResult};
/// A single axis-aligned split of space: a coordinate index plus a threshold.
///
/// A point whose coordinate at `dim` is `<= value` lies on the left side of the
/// cut; every other point lies on the right side.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Cut {
    /// Index of the coordinate the cut is taken along.
    dim: usize,
    /// Threshold along coordinate `dim`.
    value: f64,
}
impl Cut {
    /// Creates a cut on coordinate `dim` at threshold `value`.
    ///
    /// No validation is performed: `dim` is not checked against any
    /// dimensionality and `value` may be any `f64`.
    #[must_use]
    pub fn new(dim: usize, value: f64) -> Self {
        Self { dim, value }
    }

    /// Index of the coordinate this cut splits on.
    #[must_use]
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Threshold of the cut along [`Cut::dim`].
    #[must_use]
    #[inline]
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Returns `true` when `point` lies on the left side of the cut, i.e.
    /// `point[self.dim()] <= self.value()`. Points exactly on the threshold go left.
    ///
    /// # Panics
    ///
    /// Panics if `point.len() <= self.dim()`.
    #[must_use]
    #[inline]
    pub fn left_of(&self, point: &[f64]) -> bool {
        point[self.dim] <= self.value
    }

    /// Draws a random cut for `bbox`.
    ///
    /// The dimension is chosen with probability proportional to its range
    /// (`bbox.range_at(d) / bbox.range_sum()`). The threshold is then drawn
    /// uniformly from that dimension's extent. The returned threshold always lies
    /// in the half-open interval `[min, max)` of the chosen dimension. As a result,
    /// under [`Cut::left_of`] a coordinate equal to the box minimum goes left and a
    /// coordinate equal to the box maximum goes right.
    ///
    /// The first `rng` draw selects the dimension. A second draw selects the
    /// threshold, unless the chosen range is below `f64::EPSILON`, in which case the
    /// threshold is the range's minimum.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::EmptyBoundingBox`] when `bbox.range_sum()` is not
    /// positive (including NaN), or when no dimension has a positive range.
    pub fn random_cut<const D: usize, R: Rng + ?Sized>(
        bbox: &BoundingBox<D>,
        rng: &mut R,
    ) -> RcfResult<Self> {
        let total = bbox.range_sum();
        // NaN compares false against everything, so it must be rejected explicitly.
        if total.is_nan() || total <= 0.0 {
            return Err(RcfError::EmptyBoundingBox);
        }
        let chosen = Self::pick_dimension(bbox, rng.random::<f64>() * total)
            .ok_or(RcfError::EmptyBoundingBox)?;
        let lo = bbox.min()[chosen];
        let hi = bbox.max()[chosen];
        let value = if (hi - lo).abs() < f64::EPSILON {
            lo
        } else {
            let candidate = lo + rng.random::<f64>() * (hi - lo);
            // Even with a draw in [0, 1), rounding can land exactly on `hi`. A cut at
            // `hi` sends every point of the box left and splits nothing, so step back
            // to the largest representable value below `hi`.
            if candidate >= hi {
                Self::next_below(hi)
            } else {
                candidate
            }
        };
        Ok(Self { dim: chosen, value })
    }

    /// Maps `target` (expected in `[0, range_sum)`) to the dimension whose
    /// cumulative-range bucket contains it.
    ///
    /// Dimensions with non-positive range are never returned. If floating-point
    /// roundoff leaves `target` past the final bucket, the last dimension with a
    /// positive range is returned instead. Returns `None` only when no dimension
    /// has a positive range.
    fn pick_dimension<const D: usize>(bbox: &BoundingBox<D>, mut target: f64) -> Option<usize> {
        let mut last_positive = None;
        for d in 0..bbox.dim() {
            let r = bbox.range_at(d);
            if r > 0.0 {
                if target < r {
                    return Some(d);
                }
                last_positive = Some(d);
            }
            target -= r;
        }
        last_positive
    }

    /// Largest `f64` strictly less than `x`, for finite (or `+inf`) `x`.
    ///
    /// Equivalent to `f64::next_down` (stable since Rust 1.86). It is hand-rolled
    /// here to avoid raising the minimum supported Rust version.
    fn next_below(x: f64) -> f64 {
        if x == 0.0 {
            // Covers both +0.0 and -0.0: the next value down is the smallest
            // negative subnormal.
            -f64::from_bits(1)
        } else if x > 0.0 {
            f64::from_bits(x.to_bits() - 1)
        } else {
            // For negatives, moving toward -inf increases the magnitude bits.
            f64::from_bits(x.to_bits() + 1)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Builds the unit hypercube `[0, 1]^D` without a heap allocation.
    fn unit_box<const D: usize>() -> BoundingBox<D> {
        let mut b = BoundingBox::<D>::from_point(&[0.0; D]).unwrap();
        b.extend(&[1.0; D]).unwrap();
        b
    }

    #[test]
    fn left_of_strictly_below_is_left() {
        let cut = Cut::new(0, 0.5);
        assert!(cut.left_of(&[0.4, 9.9]));
    }

    #[test]
    fn left_of_at_value_is_left() {
        let cut = Cut::new(1, 2.0);
        assert!(cut.left_of(&[1.0, 2.0]));
    }

    #[test]
    fn left_of_strictly_above_is_right() {
        let cut = Cut::new(0, 0.5);
        assert!(!cut.left_of(&[0.6, 9.9]));
    }

    #[test]
    fn random_cut_is_in_range() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);
        let bbox: BoundingBox<3> = unit_box();
        for _ in 0..100 {
            let cut = Cut::random_cut(&bbox, &mut rng).unwrap();
            assert!(cut.dim() < bbox.dim());
            // Cut values live in the half-open interval [min, max).
            assert!(cut.value() >= bbox.min()[cut.dim()]);
            assert!(cut.value() < bbox.max()[cut.dim()]);
        }
    }

    #[test]
    fn random_cut_degenerate_box_fails() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);
        let bbox = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        let err = Cut::random_cut(&bbox, &mut rng).unwrap_err();
        assert!(matches!(err, RcfError::EmptyBoundingBox));
    }

    #[test]
    fn random_cut_never_picks_zero_range_dim() {
        // The last dimension has zero range, so it must never be cut.
        let mut bbox = BoundingBox::<3>::from_point(&[0.0, 0.0, 5.0]).unwrap();
        bbox.extend(&[1.0, 2.0, 5.0]).unwrap();
        let mut rng = ChaCha8Rng::seed_from_u64(3);
        for _ in 0..200 {
            let cut = Cut::random_cut(&bbox, &mut rng).unwrap();
            assert_ne!(cut.dim(), 2, "dimension with zero range was cut");
        }
    }

    #[test]
    fn random_cut_dim_distribution_proportional_to_range() {
        let mut bbox = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        bbox.extend(&[1.0, 9.0]).unwrap();
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let mut counts = [0_u32; 2];
        let trials = 5000;
        for _ in 0..trials {
            let cut = Cut::random_cut(&bbox, &mut rng).unwrap();
            counts[cut.dim()] += 1;
        }
        let p1 = f64::from(counts[1]) / f64::from(trials);
        assert!(
            (0.87..=0.93).contains(&p1),
            "dim-1 share = {p1} outside [0.87, 0.93]"
        );
    }

    #[test]
    fn random_cut_deterministic_for_same_seed() {
        let bbox: BoundingBox<4> = unit_box();
        let mut rng_a = ChaCha8Rng::seed_from_u64(42);
        let mut rng_b = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..20 {
            let a = Cut::random_cut(&bbox, &mut rng_a).unwrap();
            let b = Cut::random_cut(&bbox, &mut rng_b).unwrap();
            assert_eq!(a, b);
        }
    }

    #[test]
    fn cut_constructor_accessors() {
        let c = Cut::new(7, 1.25);
        assert_eq!(c.dim(), 7);
        assert!((c.value() - 1.25).abs() < f64::EPSILON);
    }
}