1use crate::primitives::nvector::{NVector, to_tensor};
79use burn::prelude::*;
80use rand::Rng;
81use serde_with::serde_as;
82
/// Axis-aligned box in `N` dimensions, described by two opposite corners.
///
/// Membership tests (`contains`, `contains_tensor`, ...) treat `min` as the
/// componentwise lower bound and `max` as the componentwise upper bound; the
/// constructor does not check that `min <= max` on every axis.
#[serde_as]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct BoundingBox<const N: usize> {
    /// Lower corner (expected componentwise `<= max`).
    min: NVector<N>,
    /// Upper corner (expected componentwise `>= min`).
    max: NVector<N>,
}
90
91impl<const N: usize> BoundingBox<N> {
92 pub fn new(min: NVector<N>, max: NVector<N>) -> Self {
94 Self { min, max }
95 }
96
97 pub fn min(&self) -> &NVector<N> {
99 &self.min
100 }
101
102 pub fn max(&self) -> &NVector<N> {
104 &self.max
105 }
106
107 pub fn volume(&self) -> f32 {
109 self.min
110 .iter()
111 .zip(self.max.iter())
112 .map(|(min, max)| (max - min).abs())
113 .product()
114 }
115
116 pub fn center(&self) -> NVector<N> {
118 let mut center = [0.0; N];
119 for (i, val) in center.iter_mut().enumerate().take(N) {
120 *val = (self.min[i] + self.max[i]) / 2.0;
121 }
122 center
123 }
124
125 pub fn contains(&self, point: &NVector<N>) -> bool {
127 point
128 .iter()
129 .zip(self.min().iter())
130 .zip(self.max().iter())
131 .all(|((p, min), max)| p >= min && p <= max)
132 }
133
134 pub fn contains_tensor<B: Backend>(&self, points: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
136 let min = Tensor::<B, 1>::from_data(self.min, &points.device());
137 let max = Tensor::<B, 1>::from_data(self.max, &points.device());
138 let gt = points.clone().greater_equal(min.unsqueeze()).int();
139 let lt = points.clone().lower_equal(max.unsqueeze()).int();
140 gt.mul(lt).sum_dim(1).equal_elem(N as i32).squeeze(1)
141 }
142
143 pub fn corners(&self) -> impl Iterator<Item = NVector<N>> {
145 let mut corners = Vec::with_capacity(1 << N);
146 let min = self.min();
147 let max = self.max();
148 for i in 0..1 << N {
149 let mut corner = [0.0; N];
150 for (j, corner_j) in corner.iter_mut().enumerate().take(N) {
151 *corner_j = if i & (1 << j) == 0 { min[j] } else { max[j] };
152 }
153 corners.push(corner);
154 }
155 corners.into_iter()
156 }
157
158 pub fn sample(&self, rng: &mut impl Rng) -> NVector<N> {
160 let mut point = [0.0; N];
161 let min = self.min();
162 let max = self.max();
163 for (i, val) in point.iter_mut().enumerate().take(N) {
164 *val = min[i] + (max[i] - min[i]) * rng.random::<f32>();
165 }
166 point
167 }
168
169 pub fn sample_on_device<B: Backend>(
171 &self,
172 num_samples: usize,
173 rng: &mut impl Rng,
174 device: &B::Device,
175 ) -> Tensor<B, 2, Float> {
176 let scalars: Vec<f32> = rng.random_iter::<f32>().take(num_samples * N).collect();
178 let scalars =
179 Tensor::<B, 1>::from_data(scalars.as_slice(), device).reshape([num_samples, N]);
180
181 let min_tensor = Tensor::<B, 1>::from_data(self.min, device).unsqueeze();
183 let max_tensor = Tensor::<B, 1>::from_data(self.max, device).unsqueeze();
184 let range = max_tensor.clone().sub(min_tensor.clone());
185
186 min_tensor.add(scalars.mul(range))
187 }
188
189 pub fn split(&self) -> (BoundingBox<N>, BoundingBox<N>) {
191 let mut longest_axis_idx = 0;
192 let mut longest_axis = 0.0;
193 let min = self.min();
194 let max = self.max();
195 (0..N).for_each(|idx| {
196 let axis_length = (max[idx] - min[idx]).abs();
197 if axis_length > longest_axis {
198 longest_axis = axis_length;
199 longest_axis_idx = idx;
200 }
201 });
202
203 let mut left_max = *self.max();
205 left_max[longest_axis_idx] -= longest_axis / 2.0;
206
207 let mut right_min = *self.min();
209 right_min[longest_axis_idx] += longest_axis / 2.0;
210 (
211 BoundingBox::new(*self.min(), left_max),
212 BoundingBox::new(right_min, *self.max()),
213 )
214 }
215
216 pub fn multisplit(&self, depth: usize) -> Vec<BoundingBox<N>> {
220 fn multisplit_helper<const N: usize>(
221 bbox: BoundingBox<N>,
222 num_splits: usize,
223 ) -> Vec<BoundingBox<N>> {
224 if num_splits == 0 {
225 vec![bbox]
226 } else {
227 let (left, right) = bbox.split();
228 let mut subdivisions = Vec::new();
229 subdivisions.extend(multisplit_helper(left, num_splits - 1));
230 subdivisions.extend(multisplit_helper(right, num_splits - 1));
231 subdivisions
232 }
233 }
234 multisplit_helper(*self, N * depth)
236 .into_iter()
237 .rev()
238 .take(2i32.pow(N as u32 * depth as u32) as usize)
239 .rev()
240 .collect()
241 }
242
243 pub fn point_in_bounds(&self, point: &NVector<N>) -> bool {
245 for (i, point_i) in point.iter().enumerate().take(N) {
246 if *point_i < self.min[i] || *point_i > self.max[i] {
247 return false;
248 }
249 }
250 true
251 }
252
253 pub fn points_in_bounds<B: Backend>(&self, points: Tensor<B, 2, Float>) -> Tensor<B, 1, Bool> {
255 let min = to_tensor(self.min, &points.device());
256 let max = to_tensor(self.max, &points.device());
257 let gt = points.clone().greater(min.unsqueeze()).int();
258 let lt = points.clone().lower(max.unsqueeze()).int();
259 gt.mul(lt).reshape([points.shape().dims::<2>()[0]]).bool()
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rng;

    #[test]
    fn test_volume() {
        let bbox = BoundingBox::new([0.0], [1.0]);
        assert_eq!(bbox.volume(), 1.0);

        let bbox = BoundingBox::new([0.0, 0.0], [1.0, 1.0]);
        assert_eq!(bbox.volume(), 1.0);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert_eq!(bbox.volume(), 1.0);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]);
        assert_eq!(bbox.volume(), 1.0);
    }

    #[test]
    fn test_center() {
        let bbox = BoundingBox::new([0.0], [2.0]);
        assert_eq!(bbox.center(), [1.0]);

        let bbox = BoundingBox::new([0.0, 0.0], [2.0, 2.0]);
        assert_eq!(bbox.center(), [1.0, 1.0]);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        assert_eq!(bbox.center(), [1.0, 1.0, 1.0]);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]);
        assert_eq!(bbox.center(), [1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_contains() {
        let bbox = BoundingBox::new([0.0], [1.0]);
        assert!(bbox.contains(&[0.5]));
        assert!(!bbox.contains(&[1.5]));

        let bbox = BoundingBox::new([0.0, 0.0], [1.0, 1.0]);
        assert!(bbox.contains(&[0.5, 0.5]));
        assert!(!bbox.contains(&[1.5, 0.5]));

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(bbox.contains(&[0.5, 0.5, 0.5]));
        assert!(!bbox.contains(&[1.5, 0.5, 0.5]));

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]);
        assert!(bbox.contains(&[0.5, 0.5, 0.5, 0.5]));
        assert!(!bbox.contains(&[1.5, 0.5, 0.5, 0.5]));
    }

    #[test]
    fn test_point_in_bounds() {
        let bbox = BoundingBox::new([0.0, 0.0], [1.0, 1.0]);
        assert!(bbox.point_in_bounds(&[0.5, 0.5]));
        // Boundaries are inclusive.
        assert!(bbox.point_in_bounds(&[0.0, 1.0]));
        assert!(!bbox.point_in_bounds(&[1.5, 0.5]));
        assert!(!bbox.point_in_bounds(&[0.5, -0.5]));
    }

    #[test]
    fn test_corners() {
        let bbox = BoundingBox::new([0.0], [1.0]);
        let corners: Vec<_> = bbox.corners().collect();
        assert_eq!(corners, vec![[0.0], [1.0]]);

        let bbox = BoundingBox::new([0.0, 0.0], [1.0, 1.0]);
        let corners: Vec<_> = bbox.corners().collect();
        assert_eq!(
            corners,
            vec![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
        );

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let corners: Vec<_> = bbox.corners().collect();
        assert_eq!(
            corners,
            vec![
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0],
                [0.0, 1.0, 1.0],
                [1.0, 1.0, 1.0]
            ]
        );

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]);
        let corners: Vec<_> = bbox.corners().collect();
        assert_eq!(
            corners,
            vec![
                [0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [1.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 0.0],
                [1.0, 1.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 0.0, 1.0],
                [1.0, 1.0, 0.0, 1.0],
                [0.0, 0.0, 1.0, 1.0],
                [1.0, 0.0, 1.0, 1.0],
                [0.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0]
            ]
        );
    }

    #[test]
    fn test_sample() {
        let mut rng = rng();

        let bbox = BoundingBox::new([0.0], [1.0]);
        let sample = bbox.sample(&mut rng);
        assert!(sample[0] >= 0.0 && sample[0] <= 1.0);

        let bbox = BoundingBox::new([0.0, 0.0], [1.0, 1.0]);
        let sample = bbox.sample(&mut rng);
        assert!(sample[0] >= 0.0 && sample[0] <= 1.0);
        assert!(sample[1] >= 0.0 && sample[1] <= 1.0);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let sample = bbox.sample(&mut rng);
        assert!(sample[0] >= 0.0 && sample[0] <= 1.0);
        assert!(sample[1] >= 0.0 && sample[1] <= 1.0);
        assert!(sample[2] >= 0.0 && sample[2] <= 1.0);

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]);
        let sample = bbox.sample(&mut rng);
        assert!(sample[0] >= 0.0 && sample[0] <= 1.0);
        assert!(sample[1] >= 0.0 && sample[1] <= 1.0);
        assert!(sample[2] >= 0.0 && sample[2] <= 1.0);
        assert!(sample[3] >= 0.0 && sample[3] <= 1.0);
    }

    #[test]
    fn test_split() {
        let bbox = BoundingBox::new([0.0], [2.0]);
        let (left, right) = bbox.split();
        assert_eq!(left, BoundingBox::new([0.0], [1.0]));
        assert_eq!(right, BoundingBox::new([1.0], [2.0]));

        let bbox = BoundingBox::new([0.0, 0.0], [2.0, 2.0]);
        let (left, right) = bbox.split();
        assert_eq!(left, BoundingBox::new([0.0, 0.0], [1.0, 2.0]));
        assert_eq!(right, BoundingBox::new([1.0, 0.0], [2.0, 2.0]));

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let (left, right) = bbox.split();
        assert_eq!(left, BoundingBox::new([0.0, 0.0, 0.0], [1.0, 2.0, 2.0]));
        assert_eq!(right, BoundingBox::new([1.0, 0.0, 0.0], [2.0, 2.0, 2.0]));

        let bbox = BoundingBox::new([0.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]);
        let (left, right) = bbox.split();
        assert_eq!(
            left,
            BoundingBox::new([0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 2.0, 2.0])
        );
        assert_eq!(
            right,
            BoundingBox::new([1.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0])
        );
    }

    #[test]
    fn test_multisplit() {
        let bbox = BoundingBox::new([0.0, 0.0], [4.0, 4.0]);

        // Depth 0 performs no splits.
        assert_eq!(bbox.multisplit(0), vec![bbox]);

        // Depth 1 splits N = 2 times; lower halves come before upper halves.
        assert_eq!(
            bbox.multisplit(1),
            vec![
                BoundingBox::new([0.0, 0.0], [2.0, 2.0]),
                BoundingBox::new([0.0, 2.0], [2.0, 4.0]),
                BoundingBox::new([2.0, 0.0], [4.0, 2.0]),
                BoundingBox::new([2.0, 2.0], [4.0, 4.0]),
            ]
        );

        // Depth 2 yields 2^(N * depth) leaves that tile the original volume.
        let leaves = bbox.multisplit(2);
        assert_eq!(leaves.len(), 16);
        let total: f32 = leaves.iter().map(BoundingBox::volume).sum();
        assert_eq!(total, bbox.volume());
    }
}