1use crate::{csg::prelude::*, primitives::nvector::NVector};
4use burn::prelude::*;
5
6use super::{fields::IntoIsosurface, transformations::Translate};
7
/// Which side of an isosurface a [`Region::HalfSpace`] selects.
///
/// When a half-space is evaluated, `Negative` uses the surface's field unchanged and
/// `Positive` uses its negation, so each variant keeps the points where the raw field
/// has the corresponding sign.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Side {
    /// Keep the points where the surface's field is negative.
    Negative,
    /// Keep the points where the surface's field is positive.
    Positive,
}
13
14#[derive(Debug, Clone)]
70pub enum Region<const N: usize, B: Backend> {
71 HalfSpace(Isosurface<N, B>, Side),
76
77 Union(Box<Region<N, B>>, Box<Region<N, B>>),
82
83 Intersection(Box<Region<N, B>>, Box<Region<N, B>>),
88}
89
90impl<const N: usize, B: Backend> Region<N, B> {
91 pub fn evaluate(
125 &self,
126 points: Tensor<B, 2, Float>,
127 algebra: &impl CSGAlgebra<B>,
128 ) -> Tensor<B, 1, Float> {
129 match self {
130 Region::HalfSpace(surf, side) => {
131 match side {
138 Side::Negative => surf.evaluate(points),
139 Side::Positive => -surf.evaluate(points),
140 }
141 }
142 Region::Union(l, r) => {
145 let a = l.evaluate(points.clone(), algebra);
146 let b = r.evaluate(points.clone(), algebra);
147 algebra.union(a, b)
148 }
149 Region::Intersection(l, r) => {
150 let a = l.evaluate(points.clone(), algebra);
151 let b = r.evaluate(points.clone(), algebra);
152 algebra.intersection(a, b)
153 }
154 }
155 }
156
157 pub fn classify_point(
159 &self,
160 point: &NVector<N>,
161 algebra: &impl CSGAlgebra<B>,
162 ) -> Classification {
163 let device = self.device();
164 self.evaluate(Tensor::from_data([*point], device), algebra)
165 .classification_of_index(0)
166 }
167
168 pub fn device(&self) -> &B::Device {
170 match self {
172 Region::HalfSpace(surf, _) => surf.device(),
173 Region::Union(l, _) => l.device(),
174 Region::Intersection(l, _) => l.device(),
175 }
176 }
177
178 pub fn halfspaces(&self) -> HalfSpaceIter<'_, N, B> {
179 HalfSpaceIter::new(self)
180 }
181}
182
183pub struct HalfSpaceIter<'a, const N: usize, B: Backend> {
184 stack: Vec<&'a Region<N, B>>,
185}
186
187impl<'a, const N: usize, B: Backend> HalfSpaceIter<'a, N, B> {
188 fn new(region: &'a Region<N, B>) -> Self {
189 Self {
190 stack: vec![region],
191 }
192 }
193}
194
195impl<'a, const N: usize, B: Backend> Iterator for HalfSpaceIter<'a, N, B> {
196 type Item = &'a Region<N, B>;
197
198 fn next(&mut self) -> Option<Self::Item> {
199 while let Some(node) = self.stack.pop() {
200 match node {
201 Region::HalfSpace(_, _) => return Some(node),
202 Region::Union(l, r) | Region::Intersection(l, r) => {
203 self.stack.push(r);
204 self.stack.push(l);
205 }
206 }
207 }
208 None
209 }
210}
211
212impl<const N: usize, B: Backend> IntoRegion<N, B> for Isosurface<N, B> {
213 fn into_region(self, _device: B::Device) -> Region<N, B> {
214 Region::HalfSpace(self, Side::Negative)
215 }
216}
217
218impl<const N: usize, B: Backend> Isosurface<N, B> {
219 pub fn region(self) -> Region<N, B> {
220 let device = self.device().clone();
221 self.into_region(device)
222 }
223}
224
225impl std::ops::Neg for Side {
226 type Output = Side;
227 fn neg(self) -> Self::Output {
228 match self {
229 Side::Negative => Side::Positive,
230 Side::Positive => Side::Negative,
231 }
232 }
233}
234
235impl<const N: usize, B: Backend> std::ops::Neg for &Region<N, B> {
236 type Output = Region<N, B>;
237 fn neg(self) -> Self::Output {
238 match self {
239 Region::HalfSpace(surf, side) => Region::HalfSpace(surf.clone(), -*side),
240 Region::Union(a, b) => {
241 Region::Intersection(Box::new(-*a.clone()), Box::new(-*b.clone()))
242 }
243 Region::Intersection(a, b) => {
244 Region::Union(Box::new(-*a.clone()), Box::new(-*b.clone()))
245 }
246 }
247 }
248}
249impl<const N: usize, B: Backend> std::ops::Neg for Region<N, B> {
250 type Output = Region<N, B>;
251 fn neg(self) -> Self::Output {
252 match self {
253 Region::HalfSpace(surf, side) => Region::HalfSpace(surf, -side),
254 Region::Union(a, b) => Region::Intersection(Box::new(-*a), Box::new(-*b)),
255 Region::Intersection(a, b) => Region::Union(Box::new(-*a), Box::new(-*b)),
256 }
257 }
258}
259impl<const N: usize, B: Backend> std::ops::BitAnd for &Region<N, B> {
260 type Output = Region<N, B>;
261 fn bitand(self, rhs: Self) -> Self::Output {
262 Region::Intersection(Box::new(self.clone()), Box::new(rhs.clone()))
263 }
264}
265impl<const N: usize, B: Backend> std::ops::BitAnd for Region<N, B> {
266 type Output = Region<N, B>;
267 fn bitand(self, rhs: Self) -> Self::Output {
268 Region::Intersection(Box::new(self), Box::new(rhs))
269 }
270}
271impl<const N: usize, B: Backend> std::ops::BitOr for &Region<N, B> {
272 type Output = Region<N, B>;
273 fn bitor(self, rhs: Self) -> Self::Output {
274 Region::Union(Box::new(self.clone()), Box::new(rhs.clone()))
275 }
276}
277impl<const N: usize, B: Backend> std::ops::BitOr for Region<N, B> {
278 type Output = Region<N, B>;
279 fn bitor(self, rhs: Self) -> Self::Output {
280 Region::Union(Box::new(self), Box::new(rhs))
281 }
282}
283
284pub trait IntoRegion<const N: usize, B: Backend> {
297 fn into_region(self, device: B::Device) -> Region<N, B>;
298}
299
300impl<const N: usize, B: Backend> IntoRegion<N, B> for crate::primitives::bounding::BoundingBox<N> {
301 fn into_region(self, device: B::Device) -> Region<N, B> {
302 let faces = (0..N)
305 .flat_map(|i| {
306 let mut normal = [0.0; N];
307 let mut offset_min = [0.0; N];
308 let mut offset_max = [0.0; N];
309
310 normal[i] = 1.0;
311 offset_min[i] = self.min()[i];
312 offset_max[i] = self.max()[i];
313 [
314 -FieldND::hyperplane(normal, device.clone())
315 .into_isosurface(0.0)
316 .transform(Translate(offset_min))
317 .region(),
318 FieldND::hyperplane(normal, device.clone())
319 .into_isosurface(0.0)
320 .transform(Translate(offset_max))
321 .region(),
322 ]
323 })
324 .collect::<Vec<_>>();
325
326 faces.into_iter().reduce(|a, b| a & b).unwrap()
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::bounding::BoundingBox;
    use backend_macro::with_backend;

    use rstest::rstest;

    // `#[with_backend]` is a project proc macro. It presumably supplies the `Backend`
    // type and the `device()` helper used in these bodies; TODO confirm against
    // `backend_macro`.

    // Single half-spaces.
    // The expected values imply:
    // - `plane([0, 0, 1])` evaluates to the z coordinate;
    // - `sphere(r)` evaluates |p|^2 - r^2 (e.g. 3.0 at z = -2 for r = 1).
    // Negative results classify as `Inside`, zero as `On` and positive as `Outside`.
    #[with_backend]
    #[rstest]
    fn test_halfspace() {
        let surface = Field3D::<Backend>::plane([0.0, 0.0, 1.0], device());
        // `Side::Positive` negates the plane's field, so points with z > 0 are inside.
        let halfspace = Region::HalfSpace(surface.into_isosurface(0.0), Side::Positive);
        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, 0.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, 1.0], &Algebra::default()),
            Classification::Inside(-1.0)
        );
        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, -2.0], &Algebra::default()),
            Classification::Outside(2.0)
        );

        let surface = Field3D::<Backend>::sphere(1.0, device());
        // `Side::Negative` keeps the field unchanged, so the ball's interior is inside.
        let halfspace = Region::HalfSpace(surface.into_isosurface(0.0), Side::Negative);

        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, 0.0], &Algebra::default()),
            Classification::Inside(-1.0)
        );
        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, 1.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            halfspace.classify_point(&[0.0, 0.0, -2.0], &Algebra::default()),
            Classification::Outside(3.0)
        );
    }

    // Union and intersection of the same two half-spaces: the upper side of the z = 0
    // plane and the unit ball. The expected values are consistent with the default
    // algebra taking the minimum for a union and the maximum for an intersection;
    // TODO confirm against `Algebra`.
    #[with_backend]
    #[rstest]
    fn test_regions() {
        let surface = Field3D::<Backend>::plane([0.0, 0.0, 1.0], device());
        let halfspace = Region::HalfSpace(surface.into_isosurface(0.0), Side::Positive);

        let surface = Field3D::<Backend>::sphere(1.0, device());
        let halfspace2 = Region::HalfSpace(surface.into_isosurface(0.0), Side::Negative);

        let union = Region::Union(Box::new(halfspace), Box::new(halfspace2));

        assert_eq!(
            union.classify_point(&[0.0, 0.0, 0.0], &Algebra::default()),
            Classification::Inside(-1.0)
        );
        assert_eq!(
            union.classify_point(&[0.0, 0.0, 1.0], &Algebra::default()),
            Classification::Inside(-1.0)
        );
        // On both boundaries at once: plane value 0, sphere value 0.
        assert_eq!(
            union.classify_point(&[-1.0, 0.0, 0.0], &Algebra::default()),
            Classification::On
        );
        // The plane gives -0.5 and the sphere gives -0.75; the union reports -0.75.
        assert_eq!(
            union.classify_point(&[0.0, 0.0, 0.5], &Algebra::default()),
            Classification::Inside(-0.75)
        );
        assert_eq!(
            union.classify_point(&[0.0, 0.0, -2.0], &Algebra::default()),
            Classification::Outside(2.0)
        );
        assert_eq!(
            union.classify_point(&[0.0, 0.0, 3.0], &Algebra::default()),
            Classification::Inside(-3.0)
        );

        let surface = Field3D::<Backend>::plane([0.0, 0.0, 1.0], device());
        let halfspace = Region::HalfSpace(surface.into_isosurface(0.0), Side::Positive);

        let surface = Field3D::<Backend>::sphere(1.0, device());
        let halfspace2 = Region::HalfSpace(surface.into_isosurface(0.0), Side::Negative);

        let intersection = Region::Intersection(Box::new(halfspace), Box::new(halfspace2));

        // The plane gives 0 and the sphere gives -1; the intersection reports 0 (On).
        assert_eq!(
            intersection.classify_point(&[0.0, 0.0, 0.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            intersection.classify_point(&[0.0, 0.0, 2.0], &Algebra::default()),
            Classification::Outside(3.0)
        );
    }

    // Axis-aligned boxes in 1D and 3D: the corners classify as `On` and the centre of
    // a unit box classifies as `Inside(-0.5)`.
    #[with_backend]
    #[rstest]
    fn test_bounding_box() {
        let bbox = BoundingBox::new([0.0], [1.0]);
        let region: Region<1, Backend> = bbox.into_region(device());
        assert_eq!(
            region.classify_point(&[0.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            region.classify_point(&[1.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            region.classify_point(&[0.5], &Algebra::default()),
            Classification::Inside(-0.5)
        );

        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let region: Region<3, Backend> = bbox.into_region(device());
        assert_eq!(
            region.classify_point(&[0.0, 0.0, 0.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            region.classify_point(&[1.0, 1.0, 1.0], &Algebra::default()),
            Classification::On
        );
        assert_eq!(
            region.classify_point(&[0.5, 0.5, 0.5], &Algebra::default()),
            Classification::Inside(-0.5)
        );
    }
}