1use burn::prelude::*;
8use dyn_clone::DynClone;
9use std::fmt::Debug;
10
11use crate::{
12 analysis::prelude::*,
13 primitives::nvector::{NVector, length, normalize, to_tensor},
14};
15
16#[derive(Debug, Clone, Copy, PartialEq)]
17pub enum Classification {
24 Inside(f32),
25 Outside(f32),
26 On,
27}
28
29impl Classification {
30 pub fn is_inside(&self) -> bool {
31 matches!(self, Classification::Inside(_))
32 }
33
34 pub fn is_outside(&self) -> bool {
35 matches!(self, Classification::Outside(_))
36 }
37
38 pub fn is_on(&self) -> bool {
39 matches!(self, Classification::On)
40 }
41}
42
43impl From<f32> for Classification {
44 fn from(value: f32) -> Self {
45 if value < -EPSILON {
46 Classification::Inside(value)
47 } else if value > EPSILON {
48 Classification::Outside(value)
49 } else {
50 Classification::On
51 }
52 }
53}
54
55pub trait Classify<B: Backend, const D: usize> {
56 fn is_inside_mask(&self) -> Tensor<B, D, Bool>;
58
59 fn is_outside_mask(&self) -> Tensor<B, D, Bool>;
61
62 fn is_on_mask(&self) -> Tensor<B, D, Bool>;
64
65 fn classification_of_index(&self, index: usize) -> Classification;
67}
68
69impl<B: Backend, const D: usize> Classify<B, D> for Tensor<B, D, Float> {
70 fn is_inside_mask(&self) -> Tensor<B, D, Bool> {
71 self.clone().lower_elem(-EPSILON)
72 }
73
74 fn is_outside_mask(&self) -> Tensor<B, D, Bool> {
75 self.clone().greater_elem(EPSILON)
76 }
77
78 fn is_on_mask(&self) -> Tensor<B, D, Bool> {
79 self.clone().abs().lower_elem(EPSILON)
80 }
81
82 fn classification_of_index(&self, index: usize) -> Classification {
83 match self.clone().into_data().as_slice().unwrap()[index] {
84 x if x < -EPSILON => Classification::Inside(x),
85 x if x > EPSILON => Classification::Outside(x),
86 _ => Classification::On,
87 }
88 }
89}
90
91#[allow(type_alias_bounds)]
93pub type Origins<B: Backend, const N: usize> = Tensor<B, 2, Float>;
94#[allow(type_alias_bounds)]
96pub type Directions<B: Backend, const N: usize> = Tensor<B, 2, Float>;
97#[allow(type_alias_bounds)]
99pub type Scalars<B: Backend> = Tensor<B, 1, Float>;
100
101pub trait ScalarField<const N: usize, B: Backend>: DynClone + Send + Sync {
103 fn evaluate(&self, origins: Origins<B, N>) -> Scalars<B>;
105
106 fn device(&self) -> &B::Device;
108}
109
110dyn_clone::clone_trait_object!(<const N: usize, B: Backend> ScalarField<N, B>);
111
/// Tolerance band for classifying field values: `Classification::from` treats
/// any value with magnitude at most this as lying on the surface.
pub const EPSILON: f32 = 1.0e-5;
115
116#[derive(Clone)]
119pub struct Isosurface<const N: usize, B: Backend> {
120 pub field: Box<dyn ScalarField<N, B>>,
122 pub ray_field: Option<Box<dyn RayField<N, B>>>,
124 pub constant: Tensor<B, 1, Float>,
126}
127
128impl<const N: usize, B: Backend> PartialEq for Isosurface<N, B> {
129 fn eq(&self, _other: &Self) -> bool {
130 false
132 }
133}
134
135impl<const N: usize, B: Backend> Debug for Isosurface<N, B> {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 write!(f, "Isosurface<{N}>")
138 }
139}
140
141impl<const N: usize, B: Backend> ScalarField<N, B> for Isosurface<N, B> {
142 fn evaluate(&self, points: Origins<B, N>) -> Scalars<B> {
143 self.field.evaluate(points) - self.constant.clone()
144 }
145
146 fn device(&self) -> &B::Device {
147 self.field.device()
148 }
149}
150
151impl<const N: usize, B: Backend> Isosurface<N, B> {
152 pub fn classify_point(&self, point: &NVector<N>) -> Classification {
154 let points = Tensor::from_data([*point], self.device());
155 self.evaluate(points).classification_of_index(0)
156 }
157}
158
159pub trait IntoIsosurface<const N: usize, B: Backend> {
161 fn into_isosurface(self, constant: f32) -> Isosurface<N, B>;
163}
164
165#[derive(Debug, Clone)]
167pub enum Field2D<B: Backend> {
168 Circle {
170 r: Tensor<B, 2, Float>,
171 device: B::Device,
172 },
173 Ellipse {
175 a: Tensor<B, 2, Float>,
176 b: Tensor<B, 2, Float>,
177 device: B::Device,
178 },
179 Line {
181 normal: Tensor<B, 2, Float>,
182 device: B::Device,
183 },
184 Cylinder {
186 r: Tensor<B, 2, Float>,
187 device: B::Device,
188 },
189}
190
191impl<B: Backend> Field2D<B> {
192 pub fn circle(r: f32, device: B::Device) -> Self {
194 Self::Circle {
195 r: to_tensor([r], &device.clone()),
196 device: device.clone(),
197 }
198 }
199
200 pub fn ellipse(a: f32, b: f32, device: B::Device) -> Self {
202 Self::Ellipse {
203 a: to_tensor([a], &device.clone()),
204 b: to_tensor([b], &device.clone()),
205 device: device.clone(),
206 }
207 }
208
209 pub fn line(normal: NVector<2>, device: B::Device) -> Self {
211 Self::Line {
212 normal: to_tensor(normal, &device.clone()),
213 device: device.clone(),
214 }
215 }
216
217 pub fn cylinder(r: f32, device: B::Device) -> Self {
220 Self::Cylinder {
221 r: to_tensor([r], &device.clone()),
222 device: device.clone(),
223 }
224 }
225}
226
227impl<B: Backend> ScalarField<2, B> for Field2D<B> {
228 fn evaluate(&self, points: Origins<B, 2>) -> Scalars<B> {
229 assert!(
230 points.shape().dims::<2>()[1] == 2,
231 "Points must be of shape (num_points, 2)"
232 );
233
234 match self {
235 Field2D::Circle { r, .. } => FieldND::<2, B>::Hypersphere {
236 r: r.clone(),
237 device: self.device().clone(),
238 }
239 .evaluate(points),
240 Field2D::Ellipse { a, b, .. } => (points.clone().slice([0..1, 0..1])
241 / a.clone().unsqueeze().powi_scalar(2)
242 + points.clone().slice([0..1, 1..2]) / b.clone().unsqueeze().powi_scalar(2)
243 - 1.0)
244 .squeeze(1),
245
246 Field2D::Line { normal, .. } => FieldND::<2, B>::Hyperplane {
247 normal: normal.clone(),
248 device: self.device().clone(),
249 }
250 .evaluate(points),
251 Field2D::Cylinder { r, .. } => FieldND::<2, B>::Hypercylinder {
252 r: r.clone(),
253 device: self.device().clone(),
254 }
255 .evaluate(points),
256 }
257 }
258
259 fn device(&self) -> &B::Device {
260 match self {
261 Field2D::Circle { device, .. } => device,
262 Field2D::Ellipse { device, .. } => device,
263 Field2D::Line { device, .. } => device,
264 Field2D::Cylinder { device, .. } => device,
265 }
266 }
267}
268
269impl<B: Backend> IntoIsosurface<2, B> for Field2D<B> {
270 fn into_isosurface(self, constant: f32) -> Isosurface<2, B> {
271 Isosurface {
272 field: Box::new(self.clone()),
273 ray_field: Some(Box::new(self.clone())),
274 constant: Tensor::from_data([constant], ScalarField::device(&self)),
275 }
276 }
277}
278
279#[derive(Debug, Clone)]
280pub enum Field3D<B: Backend> {
281 Cone {
283 axis: Tensor<B, 2, Float>,
284 theta: Tensor<B, 2, Float>,
285 device: B::Device,
286 },
287 Cylinder {
289 r: Tensor<B, 2, Float>,
290 device: B::Device,
291 },
292 Sphere {
294 r: Tensor<B, 2, Float>,
295 device: B::Device,
296 },
297 Torus {
299 r1: Tensor<B, 2, Float>,
300 r2: Tensor<B, 2, Float>,
301 device: B::Device,
302 },
303 Plane {
305 normal: Tensor<B, 2, Float>,
306 device: B::Device,
307 },
308}
309
310impl<B: Backend> Field3D<B> {
311 pub fn sphere(r: f32, device: B::Device) -> Self {
391 Self::Sphere {
392 r: to_tensor([r], &device.clone()),
393 device,
394 }
395 }
396
397 pub fn cylinder(r: f32, device: B::Device) -> Self {
399 Self::Cylinder {
400 r: to_tensor([r], &device.clone()),
401 device,
402 }
403 }
404
405 pub fn cone(direction: NVector<3>, theta: f32, device: B::Device) -> Self {
407 assert!(theta >= 0.0, "Theta must be non-negative");
408 assert!(length(&direction) > 0.0, "direction must be non-zero");
409 Self::Cone {
410 axis: to_tensor(normalize(&direction), &device.clone()),
411 theta: to_tensor([theta], &device.clone()),
412 device: device.clone(),
413 }
414 }
415
416 pub fn torus(r1: f32, r2: f32, device: B::Device) -> Self {
418 Self::Torus {
419 r1: to_tensor([r1], &device.clone()),
421 r2: to_tensor([r2], &device.clone()),
422 device: device.clone(),
423 }
424 }
425
426 pub fn plane(normal: NVector<3>, device: B::Device) -> Self {
428 Self::Plane {
429 normal: to_tensor(normal, &device.clone()),
431 device: device.clone(),
432 }
433 }
434
435 pub fn hypercylinder(r: f32, device: B::Device) -> Self {
438 Self::Cylinder {
439 r: to_tensor([r], &device.clone()),
440 device: device.clone(),
441 }
442 }
443
444 pub fn hypertorus(r1: f32, r2: f32, device: B::Device) -> Self {
447 Self::Torus {
448 r1: to_tensor([r1], &device.clone()),
449 r2: to_tensor([r2], &device.clone()),
450 device: device.clone(),
451 }
452 }
453}
454
455impl<B: Backend> ScalarField<3, B> for Field3D<B> {
456 fn evaluate(&self, points: Origins<B, 3>) -> Scalars<B> {
457 assert!(
458 points.shape().dims::<2>()[1] == 3,
459 "Points must be of shape (num_points, 3)"
460 );
461
462 match self {
463 Field3D::Sphere { r, .. } => FieldND::<3, B>::Hypersphere {
464 r: r.clone(),
465 device: self.device().clone(),
466 }
467 .evaluate(points),
468 Field3D::Cone { axis, theta, .. } => FieldND::<3, B>::Hypercone {
469 axis: axis.clone(),
470 theta: theta.clone(),
471 device: self.device().clone(),
472 }
473 .evaluate(points),
474 Field3D::Cylinder { r, .. } => FieldND::<3, B>::Hypercylinder {
475 r: r.clone(),
476 device: self.device().clone(),
477 }
478 .evaluate(points),
479 Field3D::Torus { r1, r2, .. } => FieldND::<3, B>::Hypertorus {
480 r1: r1.clone(),
481 r2: r2.clone(),
482 device: self.device().clone(),
483 }
484 .evaluate(points),
485 Field3D::Plane { normal, .. } => FieldND::<3, B>::Hyperplane {
486 normal: normal.clone(),
487 device: self.device().clone(),
488 }
489 .evaluate(points),
490 }
491 }
492
493 fn device(&self) -> &B::Device {
494 match self {
495 Field3D::Sphere { device, .. } => device,
496 Field3D::Cone { device, .. } => device,
497 Field3D::Cylinder { device, .. } => device,
498 Field3D::Torus { device, .. } => device,
499 Field3D::Plane { device, .. } => device,
500 }
501 }
502}
503
504impl<B: Backend> IntoIsosurface<3, B> for Field3D<B> {
505 fn into_isosurface(self, constant: f32) -> Isosurface<3, B> {
506 Isosurface {
507 field: Box::new(self.clone()),
508 ray_field: Some(Box::new(self.clone())),
509 constant: Tensor::from_data([constant], ScalarField::device(&self)),
510 }
511 }
512}
513
514#[derive(Clone)]
517pub enum FieldND<const N: usize, B: Backend> {
518 Hyperplane {
520 normal: Tensor<B, 2, Float>,
521 device: B::Device,
522 },
523
524 Hypersphere {
526 r: Tensor<B, 2, Float>,
527 device: B::Device,
528 },
529 Hypercone {
531 axis: Tensor<B, 2, Float>,
532 theta: Tensor<B, 2, Float>,
533 device: B::Device,
534 },
535 Hypercylinder {
537 r: Tensor<B, 2, Float>,
538 device: B::Device,
539 },
540
541 Hypertorus {
543 r1: Tensor<B, 2, Float>,
544 r2: Tensor<B, 2, Float>,
545 device: B::Device,
546 },
547}
548
549impl<const N: usize, B: Backend> FieldND<N, B> {
550 pub fn hyperplane(normal: NVector<N>, device: B::Device) -> Self {
552 Self::Hyperplane {
553 normal: to_tensor(normalize(&normal), &device.clone()),
554 device: device.clone(),
555 }
556 }
557
558 pub fn hypersphere(r: f32, device: B::Device) -> Self {
560 Self::Hypersphere {
561 r: to_tensor([r], &device.clone()),
562 device: device.clone(),
563 }
564 }
565
566 pub fn hypercylinder(r: f32, device: B::Device) -> Self {
569 Self::Hypercylinder {
570 r: to_tensor([r], &device.clone()),
571 device: device.clone(),
572 }
573 }
574
575 pub fn hypertorus(r1: f32, r2: f32, device: B::Device) -> Self {
579 Self::Hypertorus {
580 r1: to_tensor([r1], &device.clone()),
581 r2: to_tensor([r2], &device.clone()),
582 device: device.clone(),
583 }
584 }
585}
586
587impl<const N: usize, B: Backend> ScalarField<N, B> for FieldND<N, B> {
588 fn evaluate(&self, points: Origins<B, N>) -> Scalars<B> {
589 assert!(
590 points.shape().dims::<2>()[1] == N,
591 "Points must be of shape (num_points, {N})"
592 );
593
594 match self {
595 FieldND::Hyperplane { normal, .. } => points.matmul(normal.clone()).squeeze(1),
596 FieldND::Hypersphere { r, .. } => (points.clone().powf_scalar(2.0).sum_dim(1)
597 - r.clone().powf_scalar(2.0).unsqueeze())
598 .squeeze(1),
599 FieldND::Hypercone { axis, theta, .. } => {
600 let axis_dot_x = points.clone().matmul(axis.clone()).squeeze(1);
601 let cos_theta_sq = theta.clone().cos().powf_scalar(2.0).squeeze(1);
602 let x_squared = points.clone().mul(points.clone()).sum_dim(1).squeeze(1);
603
604 cos_theta_sq.clone() * x_squared - axis_dot_x.clone().powf_scalar(2.0)
607 }
608 FieldND::Hypercylinder { r, .. } => {
609 if N <= 1 {
612 panic!("Hypercylinder requires at least 2 dimensions");
613 }
614 let first_n_minus_1 = points.clone().slice([None, Some((0i64, (N - 1) as i64))]);
615 (first_n_minus_1.powf_scalar(2.0).sum_dim(1)
616 - r.clone().powf_scalar(2.0).unsqueeze())
617 .squeeze(1)
618 }
619 FieldND::Hypertorus { r1, r2, .. } => {
620 if N <= 2 {
623 panic!("HyperTorus requires at least 3 dimensions");
624 }
625
626 let device = self.device();
628 let last_coord = points
629 .clone()
630 .select(1, Tensor::<B, 1, Int>::from_data([N - 1], device));
631
632 let first_n_minus_1 = points.clone().slice([None, Some((0i64, (N - 1) as i64))]);
634 let radial_distance_sq = first_n_minus_1.powf_scalar(2.0).sum_dim(1);
635 let radial_distance = radial_distance_sq.clone().sqrt();
636
637 let major_term = (radial_distance - r1.clone().unsqueeze()).powf_scalar(2.0);
639 let minor_term = last_coord.powf_scalar(2.0);
640 let r2_sq = r2.clone().powf_scalar(2.0).unsqueeze();
641
642 (major_term + minor_term - r2_sq).squeeze(1)
643 }
644 }
645 }
646
647 fn device(&self) -> &B::Device {
648 match self {
649 FieldND::Hyperplane { device, .. } => device,
650 FieldND::Hypersphere { device, .. } => device,
651 FieldND::Hypercone { device, .. } => device,
652 FieldND::Hypercylinder { device, .. } => device,
653 FieldND::Hypertorus { device, .. } => device,
654 }
655 }
656}
657
658impl<const N: usize, B: Backend> IntoIsosurface<N, B> for FieldND<N, B> {
659 fn into_isosurface(self, constant: f32) -> Isosurface<N, B> {
660 Isosurface {
661 field: Box::new(self.clone()),
662 ray_field: Some(Box::new(self.clone())),
663 constant: Tensor::from_data([constant], ScalarField::device(&self)),
664 }
665 }
666}
667
#[cfg(test)]
mod tests {

    use crate::csg::prelude::*;
    use crate::primitives::nvector::to_tensor;
    use crate::test_utils::{assert_tensor_almost_eq, assert_tensor_eq};
    use backend_macro::with_backend;
    use burn::prelude::*;
    use rstest::rstest;

    /// Generic N-D primitives evaluated at three points each. Expected values
    /// follow from the implicit formulas (negative = inside, zero = on surface).
    #[with_backend]
    #[rstest]
    // x · (1, 0, 0)
    #[case(
        "hyperplane",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [0.0, 1.0, 0.0]
    )]
    // |x|² - 1
    #[case(
        "hypersphere",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [-1.0, 0.0, 0.0]
    )]
    // cos²(π/6)·|x|² - (x · z)²; the third point lies on the cone (0.866025² ≈ 0.75)
    #[case(
        "hypercone",
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.0, 0.866025]],
        [0.0, -0.25, 0.0]
    )]
    // x² + y² - 1 (the last coordinate is the cylinder axis)
    #[case(
        "hypercylinder",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 5.0]],
        [-1.0, 0.0, 0.0]
    )]
    fn test_scalar_field_nd<const N: usize>(
        #[case] field: &str,
        #[case] points: [[f32; N]; 3],
        #[case] expected: [f32; 3],
    ) {
        let field = match field {
            "hyperplane" => FieldND::<3, Backend>::hyperplane([1.0, 0.0, 0.0], device()),
            "hypersphere" => FieldND::<3, Backend>::hypersphere(1.0, device()),
            "hypercone" => {
                let axis = to_tensor([0.0, 0.0, 1.0], &device());
                let theta = to_tensor([std::f32::consts::FRAC_PI_6], &device());
                FieldND::<3, Backend>::Hypercone {
                    axis,
                    theta,
                    device: device(),
                }
            }
            "hypercylinder" => FieldND::<3, Backend>::hypercylinder(1.0, device()),
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, ScalarField::device(&field));
        let expected = Tensor::from_data(expected, ScalarField::device(&field));
        let values = field.evaluate(points);
        assert_tensor_almost_eq(values, expected, Some(EPSILON));
    }

    /// 3-D primitives evaluated at three points each.
    #[with_backend]
    #[rstest]
    // |x|² - 1
    #[case(
        "sphere",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        [-1.0, 0.0, 0.0]
    )]
    // cos²(π/6)·|x|² - (x · z)²
    #[case(
        "cone",
        [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.0, 0.866025]],
        [0.0, -0.25, 0.0]
    )]
    // x² + y² - 1; z is ignored
    #[case(
        "cylinder",
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 5.0]],
        [-1.0, 0.0, 0.0]
    )]
    // (√(x² + y²) - 2)² + z² - 0.5²: tube centre, tube surface, origin
    #[case(
        "torus",
        [[2.0, 0.0, 0.0], [2.0, 0.0, 0.5], [0.0, 0.0, 0.0]],
        [-0.25, 0.0, 3.75]
    )]
    // x · (0, 0, 1)
    #[case(
        "plane",
        [[0.0, 0.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, -1.0]],
        [0.0, 2.0, -1.0]
    )]
    fn test_scalar_field_3d(
        #[case] field: &str,
        #[case] points: [[f32; 3]; 3],
        #[case] expected: [f32; 3],
    ) {
        let field = match field {
            "sphere" => Field3D::<Backend>::sphere(1.0, device()),
            "cone" => {
                Field3D::<Backend>::cone([0.0, 0.0, 1.0], std::f32::consts::FRAC_PI_6, device())
            }
            "cylinder" => Field3D::<Backend>::cylinder(1.0, device()),
            "torus" => Field3D::<Backend>::torus(2.0, 0.5, device()),
            "plane" => Field3D::<Backend>::plane([0.0, 0.0, 1.0], device()),
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, ScalarField::device(&field));
        let expected = Tensor::from_data(expected, ScalarField::device(&field));
        let values = field.evaluate(points);
        assert_tensor_almost_eq(values, expected, Some(EPSILON));
    }

    /// 2-D primitives evaluated at three points each (exact comparison: all
    /// values are exactly representable).
    #[with_backend]
    #[rstest]
    // x · (1, 0)
    #[case(
        "line",
        [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        [0.0, 1.0, 0.0]
    )]
    // x² + y² - 1
    #[case(
        "circle",
        [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        [-1.0, 0.0, 0.0]
    )]
    fn test_scalar_field_2d(
        #[case] field: &str,
        #[case] points: [[f32; 2]; 3],
        #[case] expected: [f32; 3],
    ) {
        let field = match field {
            "line" => Field2D::<Backend>::line([1.0, 0.0], device()),
            "circle" => Field2D::<Backend>::circle(1.0, device()),
            _ => panic!("Invalid field"),
        };
        let points = Tensor::from_data(points, &device());
        let expected = Tensor::from_data(expected, &device());
        let values = field.evaluate(points);
        assert_tensor_eq(values, expected);
    }

    /// 4-D hypertorus with r1 = 2, r2 = 0.5:
    /// f(x) = (|x₀..x₂| - 2)² + x₃² - 0.25.
    #[with_backend]
    #[test]
    fn test_hypertorus_4d() {
        let hypertorus = FieldND::<4, Backend>::hypertorus(2.0, 0.5, device());

        let points = Tensor::from_data(
            [
                [2.0, 0.0, 0.0, 0.0], // centre of the tube: 0 + 0 - 0.25
                [2.0, 0.0, 0.0, 0.5], // on the surface: 0 + 0.25 - 0.25
                [0.0, 0.0, 0.0, 0.0], // origin: (0 - 2)² - 0.25
                [1.0, 0.0, 0.0, 0.0], // between origin and tube: (1 - 2)² - 0.25
            ],
            &device(),
        );

        let values = hypertorus.evaluate(points);

        let expected = Tensor::from_data([-0.25, 0.0, 3.75, 0.75], &device());

        assert_tensor_almost_eq(values, expected, Some(1e-5));
    }
}