1use wide::f64x4;
25
26use crate::domain::cut::Cut;
27use crate::error::{RcfError, RcfResult};
28
/// Axis-aligned bounding box in `D` dimensions, stored as its lower (`min`) and upper (`max`)
/// corners.
///
/// A box starts as a single point via [`BoundingBox::from_point`] and only ever grows, through
/// [`BoundingBox::extend`] and [`BoundingBox::merge_with`]. Those constructors/mutators keep
/// `min[d] <= max[d]` in every dimension for non-NaN inputs.
///
/// NOTE(review): the serde path does not appear to re-check that invariant on deserialization —
/// confirm in `crate::serde_util::fixed_array_f64`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BoundingBox<const D: usize> {
    /// Per-dimension lower corner.
    // Serialized through a custom helper, presumably because serde's built-in array impls do
    // not cover const-generic lengths — TODO confirm.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_util::fixed_array_f64"))]
    min: [f64; D],
    /// Per-dimension upper corner.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_util::fixed_array_f64"))]
    max: [f64; D],
}
54
55impl<const D: usize> BoundingBox<D> {
56 pub fn from_point(point: &[f64]) -> RcfResult<Self> {
63 if D == 0 {
64 return Err(RcfError::EmptyBoundingBox);
65 }
66 if point.len() != D {
67 return Err(RcfError::DimensionMismatch {
68 expected: D,
69 got: point.len(),
70 });
71 }
72 let mut min = [0.0_f64; D];
73 let mut max = [0.0_f64; D];
74 min.copy_from_slice(point);
75 max.copy_from_slice(point);
76 Ok(Self { min, max })
77 }
78
79 #[must_use]
81 #[inline]
82 pub const fn dim(&self) -> usize {
83 D
84 }
85
86 #[must_use]
88 #[inline]
89 pub fn min(&self) -> &[f64; D] {
90 &self.min
91 }
92
93 #[must_use]
95 #[inline]
96 pub fn max(&self) -> &[f64; D] {
97 &self.max
98 }
99
100 #[must_use]
107 #[inline]
108 pub fn range_at(&self, d: usize) -> f64 {
109 self.max[d] - self.min[d]
110 }
111
112 #[must_use]
121 #[inline]
122 pub fn range_sum(&self) -> f64 {
123 let chunks = D / 4;
124 let mut acc_simd = f64x4::splat(0.0);
125 for i in 0..chunks {
126 let off = i * 4;
127 let mn = f64x4::from([
128 self.min[off],
129 self.min[off + 1],
130 self.min[off + 2],
131 self.min[off + 3],
132 ]);
133 let mx = f64x4::from([
134 self.max[off],
135 self.max[off + 1],
136 self.max[off + 2],
137 self.max[off + 3],
138 ]);
139 acc_simd += mx - mn;
140 }
141 let mut s = acc_simd.reduce_add();
142 for d in (chunks * 4)..D {
143 s += self.max[d] - self.min[d];
144 }
145 s
146 }
147
148 pub fn extend(&mut self, point: &[f64]) -> RcfResult<()> {
154 if point.len() != D {
155 return Err(RcfError::DimensionMismatch {
156 expected: D,
157 got: point.len(),
158 });
159 }
160 for (d, &v) in point.iter().enumerate() {
161 if v < self.min[d] {
162 self.min[d] = v;
163 }
164 if v > self.max[d] {
165 self.max[d] = v;
166 }
167 }
168 Ok(())
169 }
170
171 pub fn merge_with(&mut self, other: &Self) {
174 for d in 0..D {
175 if other.min[d] < self.min[d] {
176 self.min[d] = other.min[d];
177 }
178 if other.max[d] > self.max[d] {
179 self.max[d] = other.max[d];
180 }
181 }
182 }
183
184 #[must_use]
186 pub fn merged(&self, other: &Self) -> Self {
187 let mut out = self.clone();
188 out.merge_with(other);
189 out
190 }
191
192 pub fn extension_per_dim(&self, point: &[f64]) -> RcfResult<[f64; D]> {
202 if point.len() != D {
203 return Err(RcfError::DimensionMismatch {
204 expected: D,
205 got: point.len(),
206 });
207 }
208 let mut out = [0.0_f64; D];
209 for d in 0..D {
210 let above = point[d] - self.max[d];
211 let below = self.min[d] - point[d];
212 let mut delta = 0.0;
213 if above > 0.0 {
214 delta += above;
215 }
216 if below > 0.0 {
217 delta += below;
218 }
219 out[d] = delta;
220 }
221 Ok(out)
222 }
223
224 pub fn probability_of_cut(&self, point: &[f64]) -> RcfResult<(f64, [f64; D])> {
234 let extension = self.extension_per_dim(point)?;
235 let extension_sum: f64 = extension.iter().sum();
236 let denom = self.range_sum() + extension_sum;
237 if denom == 0.0 {
238 return Ok((0.0, [0.0; D]));
239 }
240 let mut per_dim = [0.0_f64; D];
241 for d in 0..D {
242 per_dim[d] = extension[d] / denom;
243 }
244 let total: f64 = per_dim.iter().sum();
245 Ok((total, per_dim))
246 }
247
248 pub fn per_dim_cut_probabilities(&self, point: &[f64]) -> RcfResult<[f64; D]> {
255 Ok(self.probability_of_cut(point)?.1)
256 }
257
258 #[inline]
265 #[must_use]
266 pub fn augmented_range_at(&self, d: usize, point: &[f64]) -> f64 {
267 let lo = self.min[d].min(point[d]);
268 let hi = self.max[d].max(point[d]);
269 hi - lo
270 }
271
272 #[inline]
279 #[must_use]
280 pub fn augmented_range_sum(&self, point: &[f64]) -> f64 {
281 let chunks = D / 4;
282 let mut acc_simd = f64x4::splat(0.0);
283 for i in 0..chunks {
284 let off = i * 4;
285 let p = f64x4::from([point[off], point[off + 1], point[off + 2], point[off + 3]]);
286 let mn = f64x4::from([
287 self.min[off],
288 self.min[off + 1],
289 self.min[off + 2],
290 self.min[off + 3],
291 ]);
292 let mx = f64x4::from([
293 self.max[off],
294 self.max[off + 1],
295 self.max[off + 2],
296 self.max[off + 3],
297 ]);
298 let lo = mn.fast_min(p);
299 let hi = mx.fast_max(p);
300 acc_simd += hi - lo;
301 }
302 let mut s = acc_simd.reduce_add();
303 let tail_start = chunks * 4;
304 for ((&p, &mn), &mx) in point[tail_start..D]
305 .iter()
306 .zip(self.min[tail_start..D].iter())
307 .zip(self.max[tail_start..D].iter())
308 {
309 let lo = mn.min(p);
310 let hi = mx.max(p);
311 s += hi - lo;
312 }
313 s
314 }
315
316 #[inline]
324 pub fn augmented_random_cut<R: rand::Rng + ?Sized>(
325 &self,
326 point: &[f64],
327 rng: &mut R,
328 ) -> RcfResult<Cut> {
329 let total = self.augmented_range_sum(point);
330 if total <= 0.0 {
331 return Err(RcfError::EmptyBoundingBox);
332 }
333 let mut target = rand::RngExt::random::<f64>(rng) * total;
334 let mut chosen = 0_usize;
335 for d in 0..D {
336 let r = self.augmented_range_at(d, point);
337 if target < r {
338 chosen = d;
339 break;
340 }
341 target -= r;
342 chosen = d;
343 }
344 let lo = self.min[chosen].min(point[chosen]);
345 let hi = self.max[chosen].max(point[chosen]);
346 let value = if (hi - lo).abs() < f64::EPSILON {
347 lo
348 } else {
349 lo + rand::RngExt::random::<f64>(rng) * (hi - lo)
350 };
351 Ok(Cut::new(chosen, value))
352 }
353
354 pub fn total_probability_of_cut(&self, point: &[f64]) -> RcfResult<f64> {
366 if point.len() != D {
367 return Err(RcfError::DimensionMismatch {
368 expected: D,
369 got: point.len(),
370 });
371 }
372 let chunks = D / 4;
373 let zero = f64x4::splat(0.0);
374 let mut range_acc = f64x4::splat(0.0);
375 let mut ext_acc = f64x4::splat(0.0);
376 for i in 0..chunks {
377 let off = i * 4;
378 let p = f64x4::from([point[off], point[off + 1], point[off + 2], point[off + 3]]);
379 let mn = f64x4::from([
380 self.min[off],
381 self.min[off + 1],
382 self.min[off + 2],
383 self.min[off + 3],
384 ]);
385 let mx = f64x4::from([
386 self.max[off],
387 self.max[off + 1],
388 self.max[off + 2],
389 self.max[off + 3],
390 ]);
391 range_acc += mx - mn;
392 let above = (p - mx).fast_max(zero);
393 let below = (mn - p).fast_max(zero);
394 ext_acc += above + below;
395 }
396 let mut range_sum = range_acc.reduce_add();
397 let mut extension_sum = ext_acc.reduce_add();
398 let tail_start = chunks * 4;
399 for ((&p, &mn), &mx) in point[tail_start..D]
400 .iter()
401 .zip(self.min[tail_start..D].iter())
402 .zip(self.max[tail_start..D].iter())
403 {
404 range_sum += mx - mn;
405 let above = p - mx;
406 let below = mn - p;
407 if above > 0.0 {
408 extension_sum += above;
409 }
410 if below > 0.0 {
411 extension_sum += below;
412 }
413 }
414 let denom = range_sum + extension_sum;
415 if denom == 0.0 {
416 return Ok(0.0);
417 }
418 Ok(extension_sum / denom)
419 }
420}
421
#[cfg(test)]
// Exact float comparisons are intentional: the fixtures use small integers whose sums are
// exactly representable. Non-exact quantities are compared with a tolerance instead.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// 5-D box from the origin to `[1, 2, 3, 4, 5]`. `D = 5` exercises one 4-lane SIMD chunk
    /// plus a one-element scalar tail.
    fn simd_box() -> BoundingBox<5> {
        let mut b = BoundingBox::<5>::from_point(&[0.0; 5]).unwrap();
        b.extend(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        b
    }

    #[test]
    fn from_point_creates_degenerate_box() {
        let b = BoundingBox::<3>::from_point(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(b.dim(), 3);
        assert_eq!(b.min(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.max(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.range_sum(), 0.0);
    }

    #[test]
    fn from_point_rejects_zero_dim() {
        assert!(matches!(
            BoundingBox::<0>::from_point(&[]).unwrap_err(),
            RcfError::EmptyBoundingBox
        ));
    }

    #[test]
    fn from_point_rejects_dim_mismatch() {
        assert!(matches!(
            BoundingBox::<3>::from_point(&[1.0, 2.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn extend_grows_box() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[3.0, -2.0]).unwrap();
        assert_eq!(b.min(), &[0.0, -2.0]);
        assert_eq!(b.max(), &[3.0, 0.0]);
        assert!((b.range_sum() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn extend_rejects_dim_mismatch() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.extend(&[1.0, 2.0, 3.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn range_at_per_dim() {
        let mut b = BoundingBox::<3>::from_point(&[0.0, 0.0, 0.0]).unwrap();
        b.extend(&[2.0, 4.0, 8.0]).unwrap();
        assert_eq!(b.range_at(0), 2.0);
        assert_eq!(b.range_at(1), 4.0);
        assert_eq!(b.range_at(2), 8.0);
        assert_eq!(b.range_sum(), 14.0);
    }

    #[test]
    fn range_sum_covers_simd_chunk_and_tail() {
        let b = simd_box();
        assert_eq!(b.range_sum(), 15.0);
        let by_dim: f64 = (0..5).map(|d| b.range_at(d)).sum();
        assert_eq!(b.range_sum(), by_dim);
    }

    #[test]
    fn merge_with_unions_corners() {
        let mut a = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        a.extend(&[2.0, 2.0]).unwrap();
        let mut b = BoundingBox::<2>::from_point(&[-1.0, 1.0]).unwrap();
        b.extend(&[1.0, 5.0]).unwrap();
        a.merge_with(&b);
        assert_eq!(a.min(), &[-1.0, 0.0]);
        assert_eq!(a.max(), &[2.0, 5.0]);
    }

    #[test]
    fn merged_returns_new_box() {
        let a = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        let b = BoundingBox::<2>::from_point(&[5.0, 5.0]).unwrap();
        let union = a.merged(&b);
        assert_eq!(union.min(), &[0.0, 0.0]);
        assert_eq!(union.max(), &[5.0, 5.0]);
        assert_eq!(a.min(), &[0.0, 0.0]);
        assert_eq!(b.max(), &[5.0, 5.0]);
    }

    #[test]
    fn extension_zero_when_point_inside() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let ext = b.extension_per_dim(&[5.0, 5.0]).unwrap();
        assert_eq!(ext, [0.0, 0.0]);
    }

    #[test]
    fn extension_picks_above_and_below() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let ext = b.extension_per_dim(&[-3.0, 15.0]).unwrap();
        assert_eq!(ext, [3.0, 5.0]);
    }

    #[test]
    fn probability_of_cut_zero_when_inside() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let (p, per_dim) = b.probability_of_cut(&[5.0, 5.0]).unwrap();
        assert_eq!(p, 0.0);
        assert_eq!(per_dim, [0.0, 0.0]);
    }

    #[test]
    fn probability_of_cut_concentrated_on_extending_dim() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let (total, per_dim) = b.probability_of_cut(&[1000.0, 5.0]).unwrap();
        assert!(per_dim[0] > per_dim[1]);
        assert!((per_dim[0] + per_dim[1] - total).abs() < 1e-12);
    }

    #[test]
    fn probability_of_cut_handles_degenerate_box() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        let (p, per_dim) = b.probability_of_cut(&[0.0, 0.0]).unwrap();
        assert_eq!(p, 0.0);
        assert_eq!(per_dim, [0.0, 0.0]);
    }

    #[test]
    fn probability_of_cut_per_dim_sums_to_total() {
        let mut b = BoundingBox::<3>::from_point(&[0.0, 0.0, 0.0]).unwrap();
        b.extend(&[1.0, 1.0, 1.0]).unwrap();
        let (total, per_dim) = b.probability_of_cut(&[5.0, -3.0, 0.5]).unwrap();
        let sum: f64 = per_dim.iter().sum();
        assert!((sum - total).abs() < 1e-12);
        assert!(per_dim[0] > 0.0);
        assert!(per_dim[1] > 0.0);
        assert_eq!(per_dim[2], 0.0);
    }

    #[test]
    fn probability_of_cut_rejects_dim_mismatch() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.probability_of_cut(&[1.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn per_dim_cut_probabilities_matches_full_call() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[1.0, 1.0]).unwrap();
        let (_, full) = b.probability_of_cut(&[5.0, -3.0]).unwrap();
        let only_per_dim = b.per_dim_cut_probabilities(&[5.0, -3.0]).unwrap();
        assert_eq!(full, only_per_dim);
    }

    #[test]
    fn augmented_range_matches_scalar_definition() {
        let b = simd_box();
        let p = [-1.0, 1.0, 10.0, 4.0, 6.0];
        assert_eq!(b.augmented_range_at(0, &p), 2.0);
        assert_eq!(b.augmented_range_at(1, &p), 2.0);
        assert_eq!(b.augmented_range_at(2, &p), 10.0);
        assert_eq!(b.augmented_range_at(3, &p), 4.0);
        assert_eq!(b.augmented_range_at(4, &p), 6.0);
        assert_eq!(b.augmented_range_sum(&p), 24.0);
    }

    #[test]
    fn augmented_range_sum_equals_range_sum_for_inside_point() {
        let b = simd_box();
        assert_eq!(b.augmented_range_sum(&[0.5, 1.0, 1.5, 2.0, 2.5]), b.range_sum());
    }

    #[test]
    fn total_probability_of_cut_agrees_with_probability_of_cut() {
        let b = simd_box();
        let p = [-1.0, 1.0, 10.0, 4.0, 6.0];
        let fast = b.total_probability_of_cut(&p).unwrap();
        let (total, _) = b.probability_of_cut(&p).unwrap();
        // extension = 1 + 7 + 1 = 9 and range_sum = 15, so the probability is 9 / 24.
        assert!((fast - 0.375).abs() < 1e-12);
        assert!((fast - total).abs() < 1e-12);
    }

    #[test]
    fn total_probability_of_cut_zero_inside_and_degenerate() {
        let mut b = BoundingBox::<5>::from_point(&[0.0; 5]).unwrap();
        b.extend(&[1.0; 5]).unwrap();
        assert_eq!(b.total_probability_of_cut(&[0.5; 5]).unwrap(), 0.0);
        let point_box = BoundingBox::<5>::from_point(&[2.0; 5]).unwrap();
        assert_eq!(point_box.total_probability_of_cut(&[2.0; 5]).unwrap(), 0.0);
    }

    #[test]
    fn total_probability_of_cut_rejects_dim_mismatch() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.total_probability_of_cut(&[1.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }
}