crater/csg/
fields.rs

1//! Scalar fields and isosurface definitions.
2//!
3//! This module provides the core abstractions for scalar fields and isosurfaces.
4//! For theoretical background on scalar fields and CSG operations, see the
5//! [Scalar Fields](../../book/theory/scalar-fields.md) chapter.
6
7use burn::prelude::*;
8use dyn_clone::DynClone;
9use std::fmt::Debug;
10
11use crate::{
12    analysis::prelude::*,
13    primitives::nvector::{NVector, length, normalize, to_tensor},
14};
15
/// Where a point lies relative to an [`Isosurface`].
///
/// By convention [`Classification::Inside`] carries the (negative) field value of the point
/// and [`Classification::Outside`] the (positive) field value. [`Classification::On`] is used
/// when the value is too close to zero to be called either, i.e. its magnitude does not
/// exceed [`EPSILON`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Classification {
    /// The point is inside the surface; the payload is its (negative) field value.
    Inside(f32),
    /// The point is outside the surface; the payload is its (positive) field value.
    Outside(f32),
    /// The point is on the surface (field value within the [`EPSILON`] tolerance).
    On,
}

impl Classification {
    /// Returns `true` only for [`Classification::Inside`].
    pub fn is_inside(&self) -> bool {
        match self {
            Self::Inside(_) => true,
            Self::Outside(_) | Self::On => false,
        }
    }

    /// Returns `true` only for [`Classification::Outside`].
    pub fn is_outside(&self) -> bool {
        match self {
            Self::Outside(_) => true,
            Self::Inside(_) | Self::On => false,
        }
    }

    /// Returns `true` only for [`Classification::On`].
    pub fn is_on(&self) -> bool {
        match self {
            Self::On => true,
            Self::Inside(_) | Self::Outside(_) => false,
        }
    }
}
42
43impl From<f32> for Classification {
44    fn from(value: f32) -> Self {
45        if value < -EPSILON {
46            Classification::Inside(value)
47        } else if value > EPSILON {
48            Classification::Outside(value)
49        } else {
50            Classification::On
51        }
52    }
53}
54
55pub trait Classify<B: Backend, const D: usize> {
56    /// Mask for points inside the surface.
57    fn is_inside_mask(&self) -> Tensor<B, D, Bool>;
58
59    /// Mask for points outside the surface.
60    fn is_outside_mask(&self) -> Tensor<B, D, Bool>;
61
62    /// Mask for points on the surface.
63    fn is_on_mask(&self) -> Tensor<B, D, Bool>;
64
65    /// Get the classification of a point at a given index.
66    fn classification_of_index(&self, index: usize) -> Classification;
67}
68
69impl<B: Backend, const D: usize> Classify<B, D> for Tensor<B, D, Float> {
70    fn is_inside_mask(&self) -> Tensor<B, D, Bool> {
71        self.clone().lower_elem(-EPSILON)
72    }
73
74    fn is_outside_mask(&self) -> Tensor<B, D, Bool> {
75        self.clone().greater_elem(EPSILON)
76    }
77
78    fn is_on_mask(&self) -> Tensor<B, D, Bool> {
79        self.clone().abs().lower_elem(EPSILON)
80    }
81
82    fn classification_of_index(&self, index: usize) -> Classification {
83        match self.clone().into_data().as_slice().unwrap()[index] {
84            x if x < -EPSILON => Classification::Inside(x),
85            x if x > EPSILON => Classification::Outside(x),
86            _ => Classification::On,
87        }
88    }
89}
90
91/// A collection of points in N-dimensional space
92#[allow(type_alias_bounds)]
93pub type Origins<B: Backend, const N: usize> = Tensor<B, 2, Float>;
94/// A collection of directions in N-dimensional space
95#[allow(type_alias_bounds)]
96pub type Directions<B: Backend, const N: usize> = Tensor<B, 2, Float>;
97/// A collection of scalars
98#[allow(type_alias_bounds)]
99pub type Scalars<B: Backend> = Tensor<B, 1, Float>;
100
101/// A [`ScalarField`] is a function that maps a collection of points to a collection of scalars.
102pub trait ScalarField<const N: usize, B: Backend>: DynClone + Send + Sync {
103    /// Evaluates the scalar field at a collection of points.
104    fn evaluate(&self, origins: Origins<B, N>) -> Scalars<B>;
105
106    /// Get the device on which the scalar field is allocated.
107    fn device(&self) -> &B::Device;
108}
109
110dyn_clone::clone_trait_object!(<const N: usize, B: Backend> ScalarField<N, B>);
111
/// Tolerance below which a scalar-field value is treated as zero.
///
/// `From<f32> for Classification` classifies a value `v` as on the surface when
/// `-EPSILON <= v <= EPSILON`, as inside below that band and as outside above it.
pub const EPSILON: f32 = 1e-5;
115
116/// Represents a surface in N-dimensional space. The surface is defined by a mathematical
117/// expression that takes N variables as input.
118#[derive(Clone)]
119pub struct Isosurface<const N: usize, B: Backend> {
120    /// The scalar field
121    pub field: Box<dyn ScalarField<N, B>>,
122    /// The ray field (optional)
123    pub ray_field: Option<Box<dyn RayField<N, B>>>,
124    /// The constant value of the surface.
125    pub constant: Tensor<B, 1, Float>,
126}
127
128impl<const N: usize, B: Backend> PartialEq for Isosurface<N, B> {
129    fn eq(&self, _other: &Self) -> bool {
130        // All surfaces are unequal
131        false
132    }
133}
134
135impl<const N: usize, B: Backend> Debug for Isosurface<N, B> {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        write!(f, "Isosurface<{N}>")
138    }
139}
140
141impl<const N: usize, B: Backend> ScalarField<N, B> for Isosurface<N, B> {
142    fn evaluate(&self, points: Origins<B, N>) -> Scalars<B> {
143        self.field.evaluate(points) - self.constant.clone()
144    }
145
146    fn device(&self) -> &B::Device {
147        self.field.device()
148    }
149}
150
151impl<const N: usize, B: Backend> Isosurface<N, B> {
152    /// Classifies a single point relative to the [`Isosurface`].
153    pub fn classify_point(&self, point: &NVector<N>) -> Classification {
154        let points = Tensor::from_data([*point], self.device());
155        self.evaluate(points).classification_of_index(0)
156    }
157}
158
159/// Conversion of a type into a [`Isosurface`] object.
160pub trait IntoIsosurface<const N: usize, B: Backend> {
161    /// Converts the object into a [`Isosurface<N, B>`].
162    fn into_isosurface(self, constant: f32) -> Isosurface<N, B>;
163}
164
165/// Represents a 2D geometric shape that can be converted into a [`Isosurface`].
166#[derive(Debug, Clone)]
167pub enum Field2D<B: Backend> {
168    /// A circle defined by its radius
169    Circle {
170        r: Tensor<B, 2, Float>,
171        device: B::Device,
172    },
173    /// An ellipse defined by its semi-major and semi-minor axes and center.
174    Ellipse {
175        a: Tensor<B, 2, Float>,
176        b: Tensor<B, 2, Float>,
177        device: B::Device,
178    },
179    /// Line in 2D space
180    Line {
181        normal: Tensor<B, 2, Float>,
182        device: B::Device,
183    },
184    /// A cylinder in 2D space (infinite circle)
185    Cylinder {
186        r: Tensor<B, 2, Float>,
187        device: B::Device,
188    },
189}
190
191impl<B: Backend> Field2D<B> {
192    /// Allocate a circle [`ScalarField`] in the given device's memory.
193    pub fn circle(r: f32, device: B::Device) -> Self {
194        Self::Circle {
195            r: to_tensor([r], &device.clone()),
196            device: device.clone(),
197        }
198    }
199
200    /// Allocate an ellipse [`ScalarField`] in the given device's memory.
201    pub fn ellipse(a: f32, b: f32, device: B::Device) -> Self {
202        Self::Ellipse {
203            a: to_tensor([a], &device.clone()),
204            b: to_tensor([b], &device.clone()),
205            device: device.clone(),
206        }
207    }
208
209    /// Allocate a line [`ScalarField`] in the given device's memory.
210    pub fn line(normal: NVector<2>, device: B::Device) -> Self {
211        Self::Line {
212            normal: to_tensor(normal, &device.clone()),
213            device: device.clone(),
214        }
215    }
216
217    /// Allocate a cylinder [`ScalarField`] in the given device's memory.
218    /// This is equivalent to a hypercylinder in 2D space.
219    pub fn cylinder(r: f32, device: B::Device) -> Self {
220        Self::Cylinder {
221            r: to_tensor([r], &device.clone()),
222            device: device.clone(),
223        }
224    }
225}
226
227impl<B: Backend> ScalarField<2, B> for Field2D<B> {
228    fn evaluate(&self, points: Origins<B, 2>) -> Scalars<B> {
229        assert!(
230            points.shape().dims::<2>()[1] == 2,
231            "Points must be of shape (num_points, 2)"
232        );
233
234        match self {
235            Field2D::Circle { r, .. } => FieldND::<2, B>::Hypersphere {
236                r: r.clone(),
237                device: self.device().clone(),
238            }
239            .evaluate(points),
240            Field2D::Ellipse { a, b, .. } => (points.clone().slice([0..1, 0..1])
241                / a.clone().unsqueeze().powi_scalar(2)
242                + points.clone().slice([0..1, 1..2]) / b.clone().unsqueeze().powi_scalar(2)
243                - 1.0)
244                .squeeze(1),
245
246            Field2D::Line { normal, .. } => FieldND::<2, B>::Hyperplane {
247                normal: normal.clone(),
248                device: self.device().clone(),
249            }
250            .evaluate(points),
251            Field2D::Cylinder { r, .. } => FieldND::<2, B>::Hypercylinder {
252                r: r.clone(),
253                device: self.device().clone(),
254            }
255            .evaluate(points),
256        }
257    }
258
259    fn device(&self) -> &B::Device {
260        match self {
261            Field2D::Circle { device, .. } => device,
262            Field2D::Ellipse { device, .. } => device,
263            Field2D::Line { device, .. } => device,
264            Field2D::Cylinder { device, .. } => device,
265        }
266    }
267}
268
269impl<B: Backend> IntoIsosurface<2, B> for Field2D<B> {
270    fn into_isosurface(self, constant: f32) -> Isosurface<2, B> {
271        Isosurface {
272            field: Box::new(self.clone()),
273            ray_field: Some(Box::new(self.clone())),
274            constant: Tensor::from_data([constant], ScalarField::device(&self)),
275        }
276    }
277}
278
279#[derive(Debug, Clone)]
280pub enum Field3D<B: Backend> {
281    /// A cone defined by its direction and opening angle
282    Cone {
283        axis: Tensor<B, 2, Float>,
284        theta: Tensor<B, 2, Float>,
285        device: B::Device,
286    },
287    /// A cylinder defined by its radius
288    Cylinder {
289        r: Tensor<B, 2, Float>,
290        device: B::Device,
291    },
292    /// A sphere defined by its radius.
293    Sphere {
294        r: Tensor<B, 2, Float>,
295        device: B::Device,
296    },
297    /// A torus defined by its two radii.
298    Torus {
299        r1: Tensor<B, 2, Float>,
300        r2: Tensor<B, 2, Float>,
301        device: B::Device,
302    },
303    /// Planes in 3D space
304    Plane {
305        normal: Tensor<B, 2, Float>,
306        device: B::Device,
307    },
308}
309
310impl<B: Backend> Field3D<B> {
311    /// Creates a 3D sphere scalar field with specified radius.
312    ///
313    /// Generates a sphere centered at the origin with the given radius. The scalar field
314    /// represents the signed distance from any point to the sphere surface, where:
315    /// - Negative values indicate points inside the sphere
316    /// - Zero values indicate points exactly on the sphere surface  
317    /// - Positive values indicate points outside the sphere
318    ///
319    /// # Mathematical Definition
320    ///
321    /// For a point **x** = (x, y, z) and sphere radius r, the field function is:
322    ///
323    /// f(**x**) = ||**x**|| - r = √(x² + y² + z²) - r
324    ///
325    /// This is the signed distance function (SDF) for a sphere, providing exact
326    /// geometric distances that enable precise CSG operations.
327    ///
328    /// # Parameters
329    ///
330    /// * `r` - Sphere radius in world units. Must be positive.
331    /// * `device` - Backend device for tensor allocation and computation
332    ///
333    /// # Returns
334    ///
335    /// A [`Field3D`] representing the sphere scalar field, ready for use in
336    /// CSG operations, isosurface extraction, or ray casting.
337    ///
338    /// # Panics
339    ///
340    /// This function does not panic, but negative radii will produce inverted spheres
341    /// where the interior and exterior are swapped.
342    ///
343    /// # Examples
344    ///
345    /// ```rust
346    /// use crater::csg::prelude::*;
347    /// use crater::primitives::prelude::*;
348    /// use burn::prelude::*;
349    /// use burn::backend::ndarray::NdArrayDevice;
350    ///
351    /// // Create a unit sphere
352    /// let sphere = Field3D::<burn::backend::ndarray::NdArray>::sphere(1.0, NdArrayDevice::Cpu);
353    ///
354    /// // Convert to isosurface for CSG operations
355    /// let isosurface = sphere.into_isosurface(0.0);
356    /// let region = Region::HalfSpace(isosurface, Side::Negative);
357    ///
358    /// // Test point classification
359    /// let algebra = Algebra::default();
360    /// let center = Tensor::from_data([[0.0, 0.0, 0.0]], &NdArrayDevice::Cpu);
361    /// let surface = Tensor::from_data([[1.0, 0.0, 0.0]], &NdArrayDevice::Cpu);
362    /// let exterior = Tensor::from_data([[2.0, 0.0, 0.0]], &NdArrayDevice::Cpu);
363    ///
364    /// let center_val = region.evaluate(center, &algebra);    // ≈ -1.0 (inside)
365    /// let surface_val = region.evaluate(surface, &algebra);  // ≈  0.0 (on surface)  
366    /// let exterior_val = region.evaluate(exterior, &algebra); // ≈  1.0 (outside)
367    /// ```
368    ///
369    /// ```rust
370    /// use crater::csg::prelude::*;
371    /// use crater::primitives::prelude::*;
372    /// use burn::prelude::*;
373    /// use burn::backend::ndarray::NdArrayDevice;
374    ///
375    /// // Create a large sphere for mesh generation
376    /// let large_sphere = Field3D::<burn::backend::ndarray::NdArray>::sphere(5.0, NdArrayDevice::Cpu).into_isosurface(0.0);
377    /// let region = Region::HalfSpace(large_sphere, Side::Negative);
378    ///
379    /// // Generate high-resolution mesh
380    /// let params = MarchingCubesParams {
381    ///     region,
382    ///     bounds: BoundingBox::new([-6.0, -6.0, -6.0], [6.0, 6.0, 6.0]),
383    ///     resolution: (64, 64, 64),
384    ///     algebra: Algebra::default(),
385    /// };
386    ///
387    /// let mesh = marching_cubes(&params, &NdArrayDevice::Cpu);
388    /// println!("Generated sphere mesh with {} triangles", mesh.triangles.len());
389    /// ```
390    pub fn sphere(r: f32, device: B::Device) -> Self {
391        Self::Sphere {
392            r: to_tensor([r], &device.clone()),
393            device,
394        }
395    }
396
397    /// Allocate a cylinder [`ScalarField`] in the given device's memory.
398    pub fn cylinder(r: f32, device: B::Device) -> Self {
399        Self::Cylinder {
400            r: to_tensor([r], &device.clone()),
401            device,
402        }
403    }
404
405    /// Allocate a cone [`ScalarField`] in the given device's memory.
406    pub fn cone(direction: NVector<3>, theta: f32, device: B::Device) -> Self {
407        assert!(theta >= 0.0, "Theta must be non-negative");
408        assert!(length(&direction) > 0.0, "direction must be non-zero");
409        Self::Cone {
410            axis: to_tensor(normalize(&direction), &device.clone()),
411            theta: to_tensor([theta], &device.clone()),
412            device: device.clone(),
413        }
414    }
415
416    /// Allocate a torus [`ScalarField`] in the given device's memory.
417    pub fn torus(r1: f32, r2: f32, device: B::Device) -> Self {
418        Self::Torus {
419            // Radii must be a column vectors
420            r1: to_tensor([r1], &device.clone()),
421            r2: to_tensor([r2], &device.clone()),
422            device: device.clone(),
423        }
424    }
425
426    /// Allocate a plane [`ScalarField`] in the given device's memory.      
427    pub fn plane(normal: NVector<3>, device: B::Device) -> Self {
428        Self::Plane {
429            // Normal vector must be a column vector
430            normal: to_tensor(normal, &device.clone()),
431            device: device.clone(),
432        }
433    }
434
435    /// Allocate a hypercylinder [`ScalarField`] in the given device's memory.
436    /// This is equivalent to a hypercylinder in 3D space.
437    pub fn hypercylinder(r: f32, device: B::Device) -> Self {
438        Self::Cylinder {
439            r: to_tensor([r], &device.clone()),
440            device: device.clone(),
441        }
442    }
443
444    /// Allocate a hypertorus [`ScalarField`] in the given device's memory.
445    /// This is equivalent to a hypertorus in 3D space.
446    pub fn hypertorus(r1: f32, r2: f32, device: B::Device) -> Self {
447        Self::Torus {
448            r1: to_tensor([r1], &device.clone()),
449            r2: to_tensor([r2], &device.clone()),
450            device: device.clone(),
451        }
452    }
453}
454
455impl<B: Backend> ScalarField<3, B> for Field3D<B> {
456    fn evaluate(&self, points: Origins<B, 3>) -> Scalars<B> {
457        assert!(
458            points.shape().dims::<2>()[1] == 3,
459            "Points must be of shape (num_points, 3)"
460        );
461
462        match self {
463            Field3D::Sphere { r, .. } => FieldND::<3, B>::Hypersphere {
464                r: r.clone(),
465                device: self.device().clone(),
466            }
467            .evaluate(points),
468            Field3D::Cone { axis, theta, .. } => FieldND::<3, B>::Hypercone {
469                axis: axis.clone(),
470                theta: theta.clone(),
471                device: self.device().clone(),
472            }
473            .evaluate(points),
474            Field3D::Cylinder { r, .. } => FieldND::<3, B>::Hypercylinder {
475                r: r.clone(),
476                device: self.device().clone(),
477            }
478            .evaluate(points),
479            Field3D::Torus { r1, r2, .. } => FieldND::<3, B>::Hypertorus {
480                r1: r1.clone(),
481                r2: r2.clone(),
482                device: self.device().clone(),
483            }
484            .evaluate(points),
485            Field3D::Plane { normal, .. } => FieldND::<3, B>::Hyperplane {
486                normal: normal.clone(),
487                device: self.device().clone(),
488            }
489            .evaluate(points),
490        }
491    }
492
493    fn device(&self) -> &B::Device {
494        match self {
495            Field3D::Sphere { device, .. } => device,
496            Field3D::Cone { device, .. } => device,
497            Field3D::Cylinder { device, .. } => device,
498            Field3D::Torus { device, .. } => device,
499            Field3D::Plane { device, .. } => device,
500        }
501    }
502}
503
504impl<B: Backend> IntoIsosurface<3, B> for Field3D<B> {
505    fn into_isosurface(self, constant: f32) -> Isosurface<3, B> {
506        Isosurface {
507            field: Box::new(self.clone()),
508            ray_field: Some(Box::new(self.clone())),
509            constant: Tensor::from_data([constant], ScalarField::device(&self)),
510        }
511    }
512}
513
514/// Represents a surface in N-dimensional space. The surface is defined by a mathematical
515/// expression that takes N variables as input.
516#[derive(Clone)]
517pub enum FieldND<const N: usize, B: Backend> {
518    /// A hyperplane defined by its normal vector and distance from the origin.
519    Hyperplane {
520        normal: Tensor<B, 2, Float>,
521        device: B::Device,
522    },
523
524    /// A hypersphere defined by its radius.
525    Hypersphere {
526        r: Tensor<B, 2, Float>,
527        device: B::Device,
528    },
529    /// A hypercone defined by its axis and opening angle.
530    Hypercone {
531        axis: Tensor<B, 2, Float>,
532        theta: Tensor<B, 2, Float>,
533        device: B::Device,
534    },
535    /// A hypercylinder defined by its radius, extending along the last dimension.
536    Hypercylinder {
537        r: Tensor<B, 2, Float>,
538        device: B::Device,
539    },
540
541    /// A hypertorus defined by its two radii.
542    Hypertorus {
543        r1: Tensor<B, 2, Float>,
544        r2: Tensor<B, 2, Float>,
545        device: B::Device,
546    },
547}
548
549impl<const N: usize, B: Backend> FieldND<N, B> {
550    /// Allocate a hyperplane [`ScalarField`] in the given device's memory.
551    pub fn hyperplane(normal: NVector<N>, device: B::Device) -> Self {
552        Self::Hyperplane {
553            normal: to_tensor(normalize(&normal), &device.clone()),
554            device: device.clone(),
555        }
556    }
557
558    /// Allocate a hypersphere [`ScalarField`] in the given device's memory.
559    pub fn hypersphere(r: f32, device: B::Device) -> Self {
560        Self::Hypersphere {
561            r: to_tensor([r], &device.clone()),
562            device: device.clone(),
563        }
564    }
565
566    /// Allocate a hypercylinder [`ScalarField`] in the given device's memory.
567    /// The cylinder extends infinitely along the last dimension.
568    pub fn hypercylinder(r: f32, device: B::Device) -> Self {
569        Self::Hypercylinder {
570            r: to_tensor([r], &device.clone()),
571            device: device.clone(),
572        }
573    }
574
575    /// Allocate a hypertorus [`ScalarField`] in the given device's memory.
576    /// r1 is the major radius, r2 is the minor radius.
577    /// The torus lies in the first N-1 dimensions.
578    pub fn hypertorus(r1: f32, r2: f32, device: B::Device) -> Self {
579        Self::Hypertorus {
580            r1: to_tensor([r1], &device.clone()),
581            r2: to_tensor([r2], &device.clone()),
582            device: device.clone(),
583        }
584    }
585}
586
587impl<const N: usize, B: Backend> ScalarField<N, B> for FieldND<N, B> {
588    fn evaluate(&self, points: Origins<B, N>) -> Scalars<B> {
589        assert!(
590            points.shape().dims::<2>()[1] == N,
591            "Points must be of shape (num_points, {N})"
592        );
593
594        match self {
595            FieldND::Hyperplane { normal, .. } => points.matmul(normal.clone()).squeeze(1),
596            FieldND::Hypersphere { r, .. } => (points.clone().powf_scalar(2.0).sum_dim(1)
597                - r.clone().powf_scalar(2.0).unsqueeze())
598            .squeeze(1),
599            FieldND::Hypercone { axis, theta, .. } => {
600                let axis_dot_x = points.clone().matmul(axis.clone()).squeeze(1);
601                let cos_theta_sq = theta.clone().cos().powf_scalar(2.0).squeeze(1);
602                let x_squared = points.clone().mul(points.clone()).sum_dim(1).squeeze(1);
603
604                // Double cone: cos²(θ) * |x|² - (axis·x)² = 0
605                // This defines a cone that extends in both directions from the origin
606                cos_theta_sq.clone() * x_squared - axis_dot_x.clone().powf_scalar(2.0)
607            }
608            FieldND::Hypercylinder { r, .. } => {
609                // Hypercylinder: sum of squares of first N-1 dimensions minus r²
610                // This creates a cylinder that extends infinitely along the last dimension
611                if N <= 1 {
612                    panic!("Hypercylinder requires at least 2 dimensions");
613                }
614                let first_n_minus_1 = points.clone().slice([None, Some((0i64, (N - 1) as i64))]);
615                (first_n_minus_1.powf_scalar(2.0).sum_dim(1)
616                    - r.clone().powf_scalar(2.0).unsqueeze())
617                .squeeze(1)
618            }
619            FieldND::Hypertorus { r1, r2, .. } => {
620                // Hypertorus: generalization of the 3D torus to N dimensions
621                // The torus lies in the first N-1 dimensions
622                if N <= 2 {
623                    panic!("HyperTorus requires at least 3 dimensions");
624                }
625
626                // Extract the last coordinate (N-1 dimension)
627                let device = self.device();
628                let last_coord = points
629                    .clone()
630                    .select(1, Tensor::<B, 1, Int>::from_data([N - 1], device));
631
632                // Distance in the first N-1 dimensions
633                let first_n_minus_1 = points.clone().slice([None, Some((0i64, (N - 1) as i64))]);
634                let radial_distance_sq = first_n_minus_1.powf_scalar(2.0).sum_dim(1);
635                let radial_distance = radial_distance_sq.clone().sqrt();
636
637                // Hypertorus equation: (sqrt(x₁² + x₂² + ... + x_{N-1}²) - r1)² + x_N² - r2²
638                let major_term = (radial_distance - r1.clone().unsqueeze()).powf_scalar(2.0);
639                let minor_term = last_coord.powf_scalar(2.0);
640                let r2_sq = r2.clone().powf_scalar(2.0).unsqueeze();
641
642                (major_term + minor_term - r2_sq).squeeze(1)
643            }
644        }
645    }
646
647    fn device(&self) -> &B::Device {
648        match self {
649            FieldND::Hyperplane { device, .. } => device,
650            FieldND::Hypersphere { device, .. } => device,
651            FieldND::Hypercone { device, .. } => device,
652            FieldND::Hypercylinder { device, .. } => device,
653            FieldND::Hypertorus { device, .. } => device,
654        }
655    }
656}
657
658impl<const N: usize, B: Backend> IntoIsosurface<N, B> for FieldND<N, B> {
659    fn into_isosurface(self, constant: f32) -> Isosurface<N, B> {
660        Isosurface {
661            field: Box::new(self.clone()),
662            ray_field: Some(Box::new(self.clone())),
663            constant: Tensor::from_data([constant], ScalarField::device(&self)),
664        }
665    }
666}
667
#[cfg(test)]
mod tests {
    // Point-wise checks of each field's closed-form value against hand-computed
    // expectations (negative = inside, zero = on the surface, positive = outside).

    use crate::csg::prelude::*;
    use crate::primitives::nvector::to_tensor;
    use crate::test_utils::{assert_tensor_almost_eq, assert_tensor_eq};
    use backend_macro::with_backend;
    use burn::prelude::*;
    use rstest::rstest;

    #[with_backend]
    #[rstest]
    // N-D
    // hyperplane, normal (1, 0, 0): the value is the x coordinate.
    #[case(
        "hyperplane",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [0.0, 1.0, 0.0]
    )]
    // hypersphere, r = 1: |x|² - 1.
    #[case(
        "hypersphere",  
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [-1.0, 0.0, 0.0]
    )]
    // hypercone, axis z, θ = π/6: 0.75·|x|² - z². The apex and a point at 30° from the
    // axis give 0; the on-axis point (0, 0, 1) gives 0.75 - 1 = -0.25.
    #[case(
        "hypercone",
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.0, 0.866025]],
        [0.0, -0.25, 0.0]
    )]
    // hypercylinder, r = 1 around the last (z) axis: x² + y² - 1, z is ignored.
    #[case(
        "hypercylinder",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 5.0]],
        [-1.0, 0.0, 0.0]
    )]
    fn test_scalar_field_nd<const N: usize>(
        #[case] field: &str,
        #[case] points: [[f32; N]; 3],
        #[case] expected: [f32; 3],
    ) {
        // Every case above uses 3-dimensional fields.
        let field = match field {
            "hyperplane" => FieldND::<3, Backend>::hyperplane([1.0, 0.0, 0.0], device()),
            "hypersphere" => FieldND::<3, Backend>::hypersphere(1.0, device()),
            "hypercone" => {
                // Built from the variant directly; `axis` is already a unit vector.
                let axis = to_tensor([0.0, 0.0, 1.0], &device());
                let theta = to_tensor([std::f32::consts::FRAC_PI_6], &device());
                FieldND::<3, Backend>::Hypercone {
                    axis,
                    theta,
                    device: device(),
                }
            }
            "hypercylinder" => FieldND::<3, Backend>::hypercylinder(1.0, device()),
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, ScalarField::device(&field));
        let expected = Tensor::from_data(expected, ScalarField::device(&field));
        let values = field.evaluate(points);
        assert_tensor_almost_eq(values, expected, Some(EPSILON));
    }

    #[with_backend]
    #[rstest]
    // sphere, r = 1: |x|² - 1.
    #[case(
        "sphere",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [-1.0, 0.0, 0.0]
    )]
    // cone, axis z, θ = π/6: same expectations as the N-D hypercone case.
    #[case(
        "cone",
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.0, 0.866025]],
        [0.0, -0.25, 0.0]
    )]
    fn test_scalar_field_3d(
        #[case] field: &str,
        #[case] points: [[f32; 3]; 3],
        #[case] expected: [f32; 3],
    ) {
        let field = match field {
            "sphere" => Field3D::<Backend>::sphere(1.0, device()),
            "cone" => {
                Field3D::<Backend>::cone([0.0, 0.0, 1.0], std::f32::consts::FRAC_PI_6, device())
            }
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, ScalarField::device(&field));
        let expected = Tensor::from_data(expected, ScalarField::device(&field));
        let values = field.evaluate(points);
        println!("values: {}", values);
        println!("expected: {}", expected);
        assert_tensor_almost_eq(values, expected, Some(EPSILON));
    }

    #[with_backend]
    #[rstest]
    // line, normal (1, 0): the value is the x coordinate.
    #[case(
        "line",
        [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        [0.0, 1.0, 0.0]
    )]
    // circle, r = 1: x² + y² - 1.
    #[case(
        "circle",
        [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        [-1.0, 0.0, 0.0]
    )]
    fn test_scalar_field_2d(
        #[case] field: &str,
        #[case] points: [[f32; 2]; 3],
        #[case] expected: [f32; 3],
    ) {
        let field = match field {
            "line" => Field2D::<Backend>::line([1.0, 0.0], device()),
            "circle" => Field2D::<Backend>::circle(1.0, device()),
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, &device());
        let expected = Tensor::from_data(expected, &device());
        let values = field.evaluate(points);
        // Exact comparison: these inputs produce exactly representable results.
        assert_tensor_eq(values, expected);
    }

    #[with_backend]
    #[test]
    fn test_hypertorus_4d() {
        // Test 4D hypertorus with major radius 2.0, minor radius 0.5
        // (r1 is measured in the first three coordinates, r2 along the fourth).
        let hypertorus = FieldND::<4, Backend>::hypertorus(2.0, 0.5, device());

        let points = Tensor::from_data(
            [
                [2.0, 0.0, 0.0, 0.0], // On major circle in first 3 dims
                [2.0, 0.0, 0.0, 0.5], // On torus surface
                [0.0, 0.0, 0.0, 0.0], // At origin
                [1.0, 0.0, 0.0, 0.0], // Inside major radius
            ],
            &device(),
        );

        let values = hypertorus.evaluate(points);

        // Expected values based on hypertorus equation
        let expected = Tensor::from_data(
            [
                -0.25, // (2-2)² + 0² - 0.25 = -0.25
                0.0,   // (2-2)² + 0.25 - 0.25 = 0.0 (on surface)
                3.75,  // (0-2)² + 0² - 0.25 = 4 - 0.25 = 3.75
                0.75,  // (1-2)² + 0² - 0.25 = 1 - 0.25 = 0.75
            ],
            &device(),
        );

        assert_tensor_almost_eq(values, expected, Some(1e-5));
    }
}