Skip to main content

proof_engine/geometry/
csg.rs

//! Constructive Solid Geometry (CSG): union, intersection, and difference of volumes defined by scalar fields, plus smoothly blended variants of each operation.
2
3use glam::Vec3;
4use super::implicit::ScalarField;
5
/// The boolean operation a [`CsgNode::Binary`] applies to its two children.
///
/// The smooth variants store their blend radius `k` as an integer count of
/// hundredths (`k: 25` means a radius of `0.25`). Storing an integer rather than
/// an `f32` is what allows this enum to derive `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsgOp {
    /// Inside either operand: `min(a, b)`.
    Union,
    /// Inside both operands: `max(a, b)`.
    Intersection,
    /// Left operand with the right one removed: `max(a, -b)`.
    Difference,
    /// Union with a blended seam; the smoothing radius is `k / 100`.
    SmoothUnion { k: u32 },
    /// Intersection with a blended seam; the smoothing radius is `k / 100`.
    SmoothIntersection { k: u32 },
    /// Difference with a blended seam; the smoothing radius is `k / 100`.
    SmoothDifference { k: u32 },
}
16
17/// A node in a CSG tree — either a leaf (scalar field) or a binary operation.
18pub enum CsgNode {
19    Leaf(Box<dyn ScalarField>),
20    Binary {
21        op: CsgOp,
22        left: Box<CsgNode>,
23        right: Box<CsgNode>,
24    },
25}
26
27impl CsgNode {
28    pub fn leaf(field: impl ScalarField + 'static) -> Self {
29        Self::Leaf(Box::new(field))
30    }
31
32    pub fn union(a: CsgNode, b: CsgNode) -> Self {
33        Self::Binary { op: CsgOp::Union, left: Box::new(a), right: Box::new(b) }
34    }
35
36    pub fn intersection(a: CsgNode, b: CsgNode) -> Self {
37        Self::Binary { op: CsgOp::Intersection, left: Box::new(a), right: Box::new(b) }
38    }
39
40    pub fn difference(a: CsgNode, b: CsgNode) -> Self {
41        Self::Binary { op: CsgOp::Difference, left: Box::new(a), right: Box::new(b) }
42    }
43
44    pub fn smooth_union(a: CsgNode, b: CsgNode, k: f32) -> Self {
45        Self::Binary { op: CsgOp::SmoothUnion { k: (k * 100.0) as u32 }, left: Box::new(a), right: Box::new(b) }
46    }
47
48    /// Evaluate the CSG tree at a point.
49    pub fn evaluate(&self, p: Vec3) -> f32 {
50        match self {
51            Self::Leaf(f) => f.evaluate(p),
52            Self::Binary { op, left, right } => {
53                let a = left.evaluate(p);
54                let b = right.evaluate(p);
55                match op {
56                    CsgOp::Union => a.min(b),
57                    CsgOp::Intersection => a.max(b),
58                    CsgOp::Difference => a.max(-b),
59                    CsgOp::SmoothUnion { k } => {
60                        let k = *k as f32 / 100.0;
61                        smooth_min(a, b, k)
62                    }
63                    CsgOp::SmoothIntersection { k } => {
64                        let k = *k as f32 / 100.0;
65                        -smooth_min(-a, -b, k)
66                    }
67                    CsgOp::SmoothDifference { k } => {
68                        let k = *k as f32 / 100.0;
69                        -smooth_min(-a, b, k)
70                    }
71                }
72            }
73        }
74    }
75}
76
77impl ScalarField for CsgNode {
78    fn evaluate(&self, p: Vec3) -> f32 { self.evaluate(p) }
79}
80
81/// A complete CSG tree that can be evaluated as a scalar field.
82pub struct CsgTree {
83    pub root: CsgNode,
84}
85
86impl CsgTree {
87    pub fn new(root: CsgNode) -> Self { Self { root } }
88    pub fn evaluate(&self, p: Vec3) -> f32 { self.root.evaluate(p) }
89}
90
91impl ScalarField for CsgTree {
92    fn evaluate(&self, p: Vec3) -> f32 { self.root.evaluate(p) }
93}
94
/// Polynomial smooth minimum of `a` and `b` with blend radius `k`.
///
/// When `a` and `b` differ by at least `k`, the result is the smaller of the two
/// (up to floating-point rounding). Closer together, the result is pulled below
/// `min(a, b)`, by as much as `k / 4` when `a == b`. A radius below `1e-6` falls
/// back to a plain `min`.
fn smooth_min(a: f32, b: f32, k: f32) -> f32 {
    if k < 1e-6 {
        a.min(b)
    } else {
        // Blend weight: 1 selects `a`, 0 selects `b`.
        let weight = (0.5 + 0.5 * (b - a) / k).clamp(0.0, 1.0);
        let lerp = b + (a - b) * weight;
        let dip = k * weight * (1.0 - weight);
        lerp - dip
    }
}
101
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::implicit::SdfSphere;

    /// Unit sphere centred at `(x, 0, 0)`.
    fn unit_sphere_at(x: f32) -> CsgNode {
        CsgNode::leaf(SdfSphere { center: Vec3::new(x, 0.0, 0.0), radius: 1.0 })
    }

    /// Two overlapping unit spheres (centres one unit apart) combined with `op`.
    fn overlapping(op: CsgOp) -> CsgNode {
        CsgNode::Binary {
            op,
            left: Box::new(unit_sphere_at(0.0)),
            right: Box::new(unit_sphere_at(1.0)),
        }
    }

    /// Sample points inside, between, and outside the two test spheres.
    fn sample_points() -> [Vec3; 6] {
        [
            Vec3::ZERO,
            Vec3::new(0.5, 0.0, 0.0),
            Vec3::new(1.5, 0.0, 0.0),
            Vec3::new(-2.0, 0.0, 0.0),
            Vec3::new(3.0, 0.0, 0.0),
            Vec3::new(0.5, 1.5, 0.0),
        ]
    }

    #[test]
    fn union_takes_minimum() {
        let a = CsgNode::leaf(SdfSphere { center: Vec3::ZERO, radius: 1.0 });
        let b = CsgNode::leaf(SdfSphere { center: Vec3::new(1.0, 0.0, 0.0), radius: 1.0 });
        let u = CsgNode::union(a, b);
        // Point at origin: inside sphere A (neg), should be negative
        assert!(u.evaluate(Vec3::ZERO) < 0.0);
    }

    #[test]
    fn difference_subtracts() {
        let a = CsgNode::leaf(SdfSphere { center: Vec3::ZERO, radius: 2.0 });
        let b = CsgNode::leaf(SdfSphere { center: Vec3::ZERO, radius: 1.0 });
        let d = CsgNode::difference(a, b);
        // Point at origin is inside B, so difference should be positive (carved out)
        assert!(d.evaluate(Vec3::ZERO) > 0.0);
    }

    #[test]
    fn intersection_of_disjoint_spheres_is_empty() {
        let i = CsgNode::intersection(unit_sphere_at(0.0), unit_sphere_at(3.0));
        // The origin is inside the first sphere but outside the second.
        assert!(i.evaluate(Vec3::ZERO) > 0.0);
        assert!(i.evaluate(Vec3::new(3.0, 0.0, 0.0)) > 0.0);
    }

    #[test]
    fn smooth_min_blends_only_within_radius() {
        assert_eq!(smooth_min(0.0, 5.0, 1.0), 0.0);
        assert_eq!(smooth_min(2.0, -1.0, 0.0), -1.0);
        assert!((smooth_min(1.0, 1.0, 1.0) - 0.75).abs() < 1e-6);
    }

    #[test]
    fn zero_radius_smooth_ops_match_hard_ops() {
        let cases = [
            (CsgOp::SmoothUnion { k: 0 }, CsgOp::Union),
            (CsgOp::SmoothIntersection { k: 0 }, CsgOp::Intersection),
            (CsgOp::SmoothDifference { k: 0 }, CsgOp::Difference),
        ];
        for &(smooth, hard) in cases.iter() {
            let (s, h) = (overlapping(smooth), overlapping(hard));
            for &p in sample_points().iter() {
                assert_eq!(s.evaluate(p), h.evaluate(p), "{:?} vs {:?} at {:?}", smooth, hard, p);
            }
        }
    }

    #[test]
    fn smooth_union_never_exceeds_hard_union() {
        let smooth = CsgNode::smooth_union(unit_sphere_at(0.0), unit_sphere_at(1.0), 0.5);
        let hard = CsgNode::union(unit_sphere_at(0.0), unit_sphere_at(1.0));
        for &p in sample_points().iter() {
            assert!(smooth.evaluate(p) <= hard.evaluate(p) + 1e-6);
        }
        // Midway between the centres both fields are equal, so the blend dips strictly below.
        let mid = Vec3::new(0.5, 0.0, 0.0);
        assert!(smooth.evaluate(mid) < hard.evaluate(mid));
    }

    #[test]
    fn tree_delegates_to_root() {
        let tree = CsgTree::new(CsgNode::union(unit_sphere_at(0.0), unit_sphere_at(1.0)));
        let node = CsgNode::union(unit_sphere_at(0.0), unit_sphere_at(1.0));
        for &p in sample_points().iter() {
            assert_eq!(tree.evaluate(p), node.evaluate(p));
            assert_eq!(ScalarField::evaluate(&tree, p), node.evaluate(p));
        }
    }
}