// crater/csg/regions.rs

//! [`Region`]s are sets of points described by half-spaces of [`Isosurface`]s, combined with boolean (CSG) operations.
2
3use crate::{csg::prelude::*, primitives::nvector::NVector};
4use burn::prelude::*;
5
6use super::{fields::IntoIsosurface, transformations::Translate};
7
/// Which side of an isosurface a half-space region keeps.
///
/// `Negative` keeps the points where the field is below zero, `Positive` keeps the
/// points where it is above zero.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Side {
    /// Keep the points where the field is negative (`f(x) < 0`).
    Negative,
    /// Keep the points where the field is positive (`f(x) > 0`).
    Positive,
}
13
14/// A region in N-dimensional space defined by constructive solid geometry (CSG).
15///
16/// [`Region`] represents a collection of points satisfying boolean combinations of implicit
17/// surface constraints. It forms the foundation of the CSG system, enabling construction
18/// of complex 3D models from simpler primitive shapes through union, intersection, and
19/// complement operations.
20///
21/// # Mathematical Foundation
22///
23/// A region R ⊂ ℝⁿ is defined recursively:
24/// - **Half-space**: R = {x : f(x) ⋈ 0} where ⋈ ∈ {<, >} and f is a scalar field
25/// - **Union**: R = R₁ ∪ R₂ using algebraic max operations  
26/// - **Intersection**: R = R₁ ∩ R₂ using algebraic min operations
27///
28/// The choice of [`CSGAlgebra`] determines the specific boolean operators used,
29/// affecting surface smoothness and computational properties.
30///
31/// # Variants
32///
33/// * [`HalfSpace`](Region::HalfSpace) - Points on one side of an implicit surface
34/// * [`Union`](Region::Union) - Points belonging to either of two regions
35/// * [`Intersection`](Region::Intersection) - Points belonging to both regions
36///
37/// # Usage Patterns
38///
39/// Regions are typically constructed by starting with primitive shapes and combining
40/// them using boolean operations:
41///
42/// ```rust
43/// use crater::csg::prelude::*;
44/// use crater::primitives::prelude::*;
45/// use burn::prelude::*;
46/// use burn::backend::ndarray::NdArrayDevice;
47///
48/// // Create primitive shapes
49/// let sphere = Field3D::<burn::backend::ndarray::NdArray>::sphere(1.0, NdArrayDevice::Cpu).into_isosurface(0.0);
50/// let cube = Field3D::<burn::backend::ndarray::NdArray>::sphere(1.5, NdArrayDevice::Cpu).into_isosurface(0.0);
51///
52/// // Build CSG tree: sphere inside cube
53/// let sphere_region = Region::HalfSpace(sphere, Side::Negative);
54/// let cube_region = Region::HalfSpace(cube, Side::Negative);
55/// let rounded_cube = Region::Union(
56///     Box::new(sphere_region),
57///     Box::new(cube_region)
58/// );
59/// ```
60///
61/// # Performance Considerations
62///
63/// - **Tree depth**: Deep CSG trees increase evaluation cost linearly
64/// - **Algebraic choice**: R-functions provide smoothness but higher computational cost
65/// - **Surface complexity**: Complex implicit surfaces require careful numerical handling
66/// - **Memory layout**: Boxed recursive structure has pointer indirection overhead
67///
68/// For detailed theory, see the [Constructive Solid Geometry](../../book/theory/csg.md) chapter.
69#[derive(Debug, Clone)]
70pub enum Region<const N: usize, B: Backend> {
71    /// Points on one side of an implicit surface.
72    ///
73    /// Represents the set {x : side(f(x)) ⋈ 0} where f is the isosurface scalar field.
74    /// The [`Side`] determines whether we include points where f < 0 or f > 0.
75    HalfSpace(Isosurface<N, B>, Side),
76
77    /// Union of two regions: R₁ ∪ R₂.
78    ///
79    /// Points belong to the union if they belong to either the left or right region.
80    /// The specific implementation depends on the chosen [`CSGAlgebra`].
81    Union(Box<Region<N, B>>, Box<Region<N, B>>),
82
83    /// Intersection of two regions: R₁ ∩ R₂.
84    ///
85    /// Points belong to the intersection only if they belong to both regions.
86    /// The specific implementation depends on the chosen [`CSGAlgebra`].
87    Intersection(Box<Region<N, B>>, Box<Region<N, B>>),
88}
89
90impl<const N: usize, B: Backend> Region<N, B> {
91    /// Evaluates the region's isosurface function at multiple origins.
92    ///
93    /// # Parameters
94    ///
95    /// * `points` - Tensor of shape `[M, N]` where `M` is the number of origins and `N` is the spatial dimension
96    /// * `algebra` - CSG algebra defining how boolean operations are computed numerically
97    ///
98    /// # Returns
99    ///
100    /// Tensor of shape `[M]` containing the isosurface function values for each input origin.
101    ///
102    /// # Examples
103    ///
104    /// ```rust
105    /// use crater::csg::prelude::*;
106    /// use burn::backend::ndarray::NdArrayDevice;
107    /// use burn::prelude::*;
108    ///
109    /// // Create a unit sphere region
110    /// let sphere = Field3D::<burn::backend::ndarray::NdArray>::sphere(1.0, NdArrayDevice::Cpu).into_isosurface(0.0);
111    /// let region = Region::HalfSpace(sphere, Side::Negative);
112    /// let algebra = Algebra::default();
113    ///
114    /// // Test points: center, surface, and exterior
115    /// let points = Tensor::from_data([
116    ///     [0.0, 0.0, 0.0],  // Center (inside)
117    ///     [1.0, 0.0, 0.0],  // Surface
118    ///     [2.0, 0.0, 0.0],  // Outside
119    /// ], &NdArrayDevice::Cpu);
120    ///
121    /// let values = region.evaluate(points, &algebra);
122    /// // values ≈ [-1.0, 0.0, 1.0] (inside, on surface, outside)
123    /// ```
124    pub fn evaluate(
125        &self,
126        points: Tensor<B, 2, Float>,
127        algebra: &impl CSGAlgebra<B>,
128    ) -> Tensor<B, 1, Float> {
129        match self {
130            Region::HalfSpace(surf, side) => {
131                // When we're at a raw halfspace, we just let the
132                // surface do the classification, and change the sign based
133                // on whether we're inside or outside.
134                // NOTE: Negative means we want the f < 0 points to be considered "inside"
135                //       Positive means we want the f > 0 points to be considered "inside"
136                //       So we negate positive halfspaces.
137                match side {
138                    Side::Negative => surf.evaluate(points),
139                    Side::Positive => -surf.evaluate(points),
140                }
141            }
142            // For Unions and Intersections, we evaluate the two regions, and
143            // use the algebra to define the combination
144            Region::Union(l, r) => {
145                let a = l.evaluate(points.clone(), algebra);
146                let b = r.evaluate(points.clone(), algebra);
147                algebra.union(a, b)
148            }
149            Region::Intersection(l, r) => {
150                let a = l.evaluate(points.clone(), algebra);
151                let b = r.evaluate(points.clone(), algebra);
152                algebra.intersection(a, b)
153            }
154        }
155    }
156
157    /// Classify a point as being inside, outside, or on the region.
158    pub fn classify_point(
159        &self,
160        point: &NVector<N>,
161        algebra: &impl CSGAlgebra<B>,
162    ) -> Classification {
163        let device = self.device();
164        self.evaluate(Tensor::from_data([*point], device), algebra)
165            .classification_of_index(0)
166    }
167
168    /// Get the device on which the [`Region`] is allocated.
169    pub fn device(&self) -> &B::Device {
170        // Always the same device for all sub-regions
171        match self {
172            Region::HalfSpace(surf, _) => surf.device(),
173            Region::Union(l, _) => l.device(),
174            Region::Intersection(l, _) => l.device(),
175        }
176    }
177
178    pub fn halfspaces(&self) -> HalfSpaceIter<'_, N, B> {
179        HalfSpaceIter::new(self)
180    }
181}
182
183pub struct HalfSpaceIter<'a, const N: usize, B: Backend> {
184    stack: Vec<&'a Region<N, B>>,
185}
186
187impl<'a, const N: usize, B: Backend> HalfSpaceIter<'a, N, B> {
188    fn new(region: &'a Region<N, B>) -> Self {
189        Self {
190            stack: vec![region],
191        }
192    }
193}
194
195impl<'a, const N: usize, B: Backend> Iterator for HalfSpaceIter<'a, N, B> {
196    type Item = &'a Region<N, B>;
197
198    fn next(&mut self) -> Option<Self::Item> {
199        while let Some(node) = self.stack.pop() {
200            match node {
201                Region::HalfSpace(_, _) => return Some(node),
202                Region::Union(l, r) | Region::Intersection(l, r) => {
203                    self.stack.push(r);
204                    self.stack.push(l);
205                }
206            }
207        }
208        None
209    }
210}
211
212impl<const N: usize, B: Backend> IntoRegion<N, B> for Isosurface<N, B> {
213    fn into_region(self, _device: B::Device) -> Region<N, B> {
214        Region::HalfSpace(self, Side::Negative)
215    }
216}
217
218impl<const N: usize, B: Backend> Isosurface<N, B> {
219    pub fn region(self) -> Region<N, B> {
220        let device = self.device().clone();
221        self.into_region(device)
222    }
223}
224
225impl std::ops::Neg for Side {
226    type Output = Side;
227    fn neg(self) -> Self::Output {
228        match self {
229            Side::Negative => Side::Positive,
230            Side::Positive => Side::Negative,
231        }
232    }
233}
234
235impl<const N: usize, B: Backend> std::ops::Neg for &Region<N, B> {
236    type Output = Region<N, B>;
237    fn neg(self) -> Self::Output {
238        match self {
239            Region::HalfSpace(surf, side) => Region::HalfSpace(surf.clone(), -*side),
240            Region::Union(a, b) => {
241                Region::Intersection(Box::new(-*a.clone()), Box::new(-*b.clone()))
242            }
243            Region::Intersection(a, b) => {
244                Region::Union(Box::new(-*a.clone()), Box::new(-*b.clone()))
245            }
246        }
247    }
248}
249impl<const N: usize, B: Backend> std::ops::Neg for Region<N, B> {
250    type Output = Region<N, B>;
251    fn neg(self) -> Self::Output {
252        match self {
253            Region::HalfSpace(surf, side) => Region::HalfSpace(surf, -side),
254            Region::Union(a, b) => Region::Intersection(Box::new(-*a), Box::new(-*b)),
255            Region::Intersection(a, b) => Region::Union(Box::new(-*a), Box::new(-*b)),
256        }
257    }
258}
259impl<const N: usize, B: Backend> std::ops::BitAnd for &Region<N, B> {
260    type Output = Region<N, B>;
261    fn bitand(self, rhs: Self) -> Self::Output {
262        Region::Intersection(Box::new(self.clone()), Box::new(rhs.clone()))
263    }
264}
265impl<const N: usize, B: Backend> std::ops::BitAnd for Region<N, B> {
266    type Output = Region<N, B>;
267    fn bitand(self, rhs: Self) -> Self::Output {
268        Region::Intersection(Box::new(self), Box::new(rhs))
269    }
270}
271impl<const N: usize, B: Backend> std::ops::BitOr for &Region<N, B> {
272    type Output = Region<N, B>;
273    fn bitor(self, rhs: Self) -> Self::Output {
274        Region::Union(Box::new(self.clone()), Box::new(rhs.clone()))
275    }
276}
277impl<const N: usize, B: Backend> std::ops::BitOr for Region<N, B> {
278    type Output = Region<N, B>;
279    fn bitor(self, rhs: Self) -> Self::Output {
280        Region::Union(Box::new(self), Box::new(rhs))
281    }
282}
283
284// impl<'de, const N: usize, B: Backend> Deserialize<'de> for Region<N, B> {
285//     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
286//     where
287//         D: serde::Deserializer<'de>,
288//     {
289//         // Deserialize and parse into region
290//         let r_string = String::deserialize(deserializer)?.to_owned();
291//         str_to_region(&r_string).map_err(serde::de::Error::custom)
292//     }
293// }
294
295/// Convert a type into a [`Region`]
296pub trait IntoRegion<const N: usize, B: Backend> {
297    fn into_region(self, device: B::Device) -> Region<N, B>;
298}
299
300impl<const N: usize, B: Backend> IntoRegion<N, B> for crate::primitives::bounding::BoundingBox<N> {
301    fn into_region(self, device: B::Device) -> Region<N, B> {
302        // A box in dimension N has 2N faces. Generate each as a hyperplane,
303        // then take the intersection of all of them.
304        let faces = (0..N)
305            .flat_map(|i| {
306                let mut normal = [0.0; N];
307                let mut offset_min = [0.0; N];
308                let mut offset_max = [0.0; N];
309
310                normal[i] = 1.0;
311                offset_min[i] = self.min()[i];
312                offset_max[i] = self.max()[i];
313                [
314                    -FieldND::hyperplane(normal, device.clone())
315                        .into_isosurface(0.0)
316                        .transform(Translate(offset_min))
317                        .region(),
318                    FieldND::hyperplane(normal, device.clone())
319                        .into_isosurface(0.0)
320                        .transform(Translate(offset_max))
321                        .region(),
322                ]
323            })
324            .collect::<Vec<_>>();
325
326        faces.into_iter().reduce(|a, b| a & b).unwrap()
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::bounding::BoundingBox;
    use backend_macro::with_backend;

    use rstest::rstest;

    #[with_backend]
    #[rstest]
    fn test_halfspace() {
        let algebra = Algebra::default();

        // z > 0 half-space: Side::Positive keeps the points where the plane field is positive.
        let plane = Field3D::<Backend>::plane([0.0, 0.0, 1.0], device());
        let above_plane = Region::HalfSpace(plane.into_isosurface(0.0), Side::Positive);
        for (point, expected) in [
            ([0.0, 0.0, 0.0], Classification::On),
            ([0.0, 0.0, 1.0], Classification::Inside(-1.0)),
            ([0.0, 0.0, -2.0], Classification::Outside(2.0)),
        ] {
            assert_eq!(above_plane.classify_point(&point, &algebra), expected);
        }

        // Everything inside the unit sphere.
        let ball = Field3D::<Backend>::sphere(1.0, device());
        let inside_ball = Region::HalfSpace(ball.into_isosurface(0.0), Side::Negative);
        for (point, expected) in [
            ([0.0, 0.0, 0.0], Classification::Inside(-1.0)),
            ([0.0, 0.0, 1.0], Classification::On),
            ([0.0, 0.0, -2.0], Classification::Outside(3.0)),
        ] {
            assert_eq!(inside_ball.classify_point(&point, &algebra), expected);
        }
    }

    #[with_backend]
    #[rstest]
    fn test_regions() {
        let algebra = Algebra::default();

        // Fresh leaves for each tree, since CSG nodes take ownership of their children.
        // z > 0 half-space:
        let above_plane = || {
            Region::HalfSpace(
                Field3D::<Backend>::plane([0.0, 0.0, 1.0], device()).into_isosurface(0.0),
                Side::Positive,
            )
        };
        // Everything inside the unit sphere:
        let inside_ball = || {
            Region::HalfSpace(
                Field3D::<Backend>::sphere(1.0, device()).into_isosurface(0.0),
                Side::Negative,
            )
        };

        // Everything inside the sphere OR above the z=0 plane.
        let union = Region::Union(Box::new(above_plane()), Box::new(inside_ball()));
        for (point, expected) in [
            // "On" the plane but inside the sphere, so the sphere wins.
            ([0.0, 0.0, 0.0], Classification::Inside(-1.0)),
            // "On" the sphere but inside the plane half-space, so the plane wins.
            ([0.0, 0.0, 1.0], Classification::Inside(-1.0)),
            ([-1.0, 0.0, 0.0], Classification::On),
            ([0.0, 0.0, 0.5], Classification::Inside(-0.75)),
            ([0.0, 0.0, -2.0], Classification::Outside(2.0)),
            ([0.0, 0.0, 3.0], Classification::Inside(-3.0)),
        ] {
            assert_eq!(union.classify_point(&point, &algebra), expected);
        }

        // Everything inside the sphere AND above the z=0 plane.
        let intersection =
            Region::Intersection(Box::new(above_plane()), Box::new(inside_ball()));
        for (point, expected) in [
            ([0.0, 0.0, 0.0], Classification::On),
            ([0.0, 0.0, 2.0], Classification::Outside(3.0)),
        ] {
            assert_eq!(intersection.classify_point(&point, &algebra), expected);
        }
    }

    #[with_backend]
    #[rstest]
    fn test_bounding_box() {
        let algebra = Algebra::default();

        // 1D
        let bbox = BoundingBox::new([0.0], [1.0]);
        let region: Region<1, Backend> = bbox.into_region(device());
        for (point, expected) in [
            ([0.0], Classification::On),
            ([1.0], Classification::On),
            ([0.5], Classification::Inside(-0.5)),
        ] {
            assert_eq!(region.classify_point(&point, &algebra), expected);
        }

        // 3D
        let bbox = BoundingBox::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let region: Region<3, Backend> = bbox.into_region(device());
        for (point, expected) in [
            ([0.0, 0.0, 0.0], Classification::On),
            ([1.0, 1.0, 1.0], Classification::On),
            ([0.5, 0.5, 0.5], Classification::Inside(-0.5)),
        ] {
            assert_eq!(region.classify_point(&point, &algebra), expected);
        }
    }
}