crater/analysis/rays/
algo.rs

1//! Ray-surface intersection algorithms.
2//!
3//! This module implements various algorithms for finding intersections between rays and
4//! implicit surfaces. For the mathematical theory and algorithm descriptions, see the
5//! [Ray Casting](../../book/theory/ray-casting.md) chapter.
6
7use super::prelude::*;
8use crate::csg::prelude::*;
9use burn::prelude::*;
10use burn::tensor::{
11    Tensor,
12    backend::{AutodiffBackend, Backend},
13};
14use log::info;
15
/// Sentinel value marking a ray that did not hit the region's surface.
///
/// Every coordinate of a failed ray's intersection point is set to this value. The
/// analytical caster also uses it as the initial "no hit yet" distance.
pub const RAY_FAIL_VALUE: f32 = f32::MAX;
17
/// Ray-surface intersection algorithms for implicit surface ray casting.
///
/// This enum provides multiple algorithmic approaches for solving ray-surface intersection
/// problems, each suited to different scenarios in applications where precision and
/// reliability are paramount.
///
/// # Algorithm Selection Guidelines
///
/// * **Use [`Analytical`](RayCastAlgorithm::Analytical)** for:
///   - Regions built from primitives that provide closed-form ray intersections
///   - Maximum accuracy and performance requirements
///   - Real-time applications with tight computational budgets
///
/// * **Use [`Newton`](RayCastAlgorithm::Newton)** for:
///   - Smooth, well-conditioned implicit surfaces
///   - When autodifferentiation is available
///   - Applications requiring fast convergence with few iterations
///
/// * **Use [`BracketAndBisect`](RayCastAlgorithm::BracketAndBisect)** for:
///   - Complex CSG constructions with sharp edges
///   - Robustness over speed requirements
///   - Surfaces with potential numerical instabilities
///
/// For theoretical background, see the [Ray Casting](../../book/theory/ray-casting.md) chapter.
#[derive(Debug, Clone, Copy)]
pub enum RayCastAlgorithm<const N: usize> {
    /// Robust numerical method using bracketing followed by bisection.
    ///
    /// This algorithm first marches along each ray to bracket a sign change in the
    /// scalar field, then uses bisection to refine the intersection point.
    /// It is the most reliable method for complex CSG constructions.
    ///
    /// # Parameters
    ///
    /// * `lambda` - Maximum distance along each ray searched during the bracketing phase
    /// * `d_lambda` - Distance the bracket advances per bracketing step; should be
    ///   positive, otherwise the search never advances
    /// * `max_bisection_iterations` - Number of bisection refinements (typically 20-50)
    ///
    /// # Convergence
    ///
    /// Once a sign change of a continuous scalar field has been bracketed, every bisection
    /// iteration halves the bracket. After n iterations the intersection is therefore
    /// located to within (bracket width)/2ⁿ. The bracket width is set by the bracketing
    /// step, not by `lambda`.
    ///
    /// A root is only detected where the field changes sign (or is exactly zero) between
    /// two consecutive samples. Tangential contacts, and pairs of crossings closer
    /// together than one step, can be missed.
    BracketAndBisect {
        /// Maximum distance to extend rays during bracketing
        lambda: f32,
        /// Step size for bracketing phase
        d_lambda: f32,
        /// Number of bisection refinement iterations (always run in full)
        max_bisection_iterations: usize,
    },

    /// Direct analytical solutions for primitive geometric shapes.
    ///
    /// Computes exact ray-surface intersections using the closed-form ray field supplied
    /// by each half-space's isosurface. This is the fastest and most accurate method
    /// when applicable.
    ///
    /// # Supported Primitives
    ///
    /// Any isosurface that provides an analytic ray field, for example:
    /// - Spheres: Quadratic equation solutions
    /// - Planes: Linear ray-plane intersection
    /// - Cylinders: Analytical cylinder intersection
    /// - Cones: Quadratic cone intersection formulas
    ///
    /// # Limitations
    ///
    /// - Every half-space of the region must provide an analytic ray field
    /// - Arbitrary scalar field expressions are not supported
    /// - Transformed or deformed primitives
    /// - Each half-space contributes a single candidate hit per ray. A candidate is kept
    ///   only if it lies on the surface of the whole region, so composite regions whose
    ///   visible hit is not the candidate returned by a half-space's ray field may be
    ///   missed
    Analytical,

    /// Iterative Newton-Raphson root finding using automatic differentiation.
    ///
    /// Uses the directional derivative of the field along each ray to converge to the
    /// surface. It converges quadratically near simple roots but requires smooth,
    /// differentiable scalar fields to perform well.
    ///
    /// # Parameters
    ///
    /// * `max_iterations` - Maximum Newton iterations (typically 5-20)
    /// * `nudge_distance` - Value used in place of a directional derivative that is
    ///   exactly zero, so that the Newton step stays finite
    /// * `step_size` - Relaxation factor applied to each Newton step (usually 0.5-1.0)
    ///
    /// # Convergence Properties
    ///
    /// - **Quadratic**: e_{k+1} ∝ e_k² near simple roots
    /// - **Sensitive**: May fail near sharp edges or gradient discontinuities
    /// - **Fast**: Typically converges in 3-10 iterations for smooth surfaces
    Newton {
        /// Maximum number of Newton iterations
        max_iterations: usize,
        /// Substitute for a zero directional derivative (smaller values give larger steps)
        nudge_distance: f32,
        /// Step size relaxation factor (0.0 to 1.0)
        step_size: f32,
    },
}
116
117impl<const N: usize> RayCastAlgorithm<N> {
118    /// Computes ray-surface intersections using the specified algorithm.
119    ///
120    /// This method finds the points where rays intersect with the boundary of an implicit
121    /// surface region. It supports multiple algorithmic approaches, each with different
122    /// trade-offs between accuracy, performance, and applicability.
123    ///
124    /// # Mathematical Problem
125    ///
126    /// Given rays **r**(t) = **o** + t**d** and a region R with characteristic function f(x),
127    /// find parameter values t where f(**r**(t)) = 0, indicating surface intersection.
128    ///
129    /// # Algorithm Variants
130    ///
131    /// * **Analytical**: Direct closed-form solutions for primitive shapes (spheres, planes, etc.)
132    ///   - Fastest and most accurate when applicable
133    ///   - Limited to simple geometric primitives with known ray intersection formulas
134    ///
135    /// * **Newton**: Iterative root-finding using gradient-based optimization
136    ///   - Quadratic convergence near solutions
137    ///   - Requires autodifferentiable scalar fields
138    ///   - Best for smooth, well-conditioned surfaces
139    ///
140    /// * **BracketAndBisect**: Robust bracketing followed by bisection refinement
141    ///   - Guaranteed convergence for continuous functions
142    ///   - Works with any scalar field implementation
143    ///   - Slower but most reliable for complex CSG constructions
144    ///
145    /// # Parameters
146    ///
147    /// * `rays` - Collection of rays to cast against the surface
148    /// * `region` - Implicit surface region to intersect with
149    /// * `algebra` - CSG algebra for evaluating boolean operations in composite regions
150    ///
151    /// # Returns
152    ///
153    /// [`RayCastResult`] containing intersection points, distances, field values, and metadata
154    /// for each input ray. Rays that miss the surface are marked with sentinel values.
155    ///
156    /// # Performance Notes
157    ///
158    /// - **Analytical**: O(1) per ray for primitive surfaces
159    /// - **Newton**: O(k) where k is iteration count (typically 3-10)
160    /// - **BracketAndBisect**: O(log ε⁻¹) where ε is desired precision
161    ///
162    /// For complex CSG regions, cost scales with the number of constituent half-spaces.
163    ///
164    /// # Examples
165    ///
166    /// ```rust
167    /// use crater::analysis::prelude::*;
168    /// use crater::csg::prelude::*;
169    /// use burn::prelude::*;
170    /// use burn::backend::ndarray::{NdArrayDevice, NdArray};
171    /// use burn::backend::Autodiff;
172    ///
173    /// // Create a sphere region
174    /// let sphere = Field3D::<Autodiff<NdArray>>::sphere(1.0, NdArrayDevice::Cpu).into_isosurface(0.0);
175    /// let region = Region::HalfSpace(sphere, Side::Negative);
176    /// let algebra = Algebra::<Autodiff<NdArray>>::default();
177    ///
178    /// // Create rays pointing at the sphere
179    /// let origins = Tensor::from_data([[-2.0, 0.0, 0.0], [0.0, -2.0, 0.0]], &NdArrayDevice::Cpu);
180    /// let directions = Tensor::from_data([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], &NdArrayDevice::Cpu);
181    /// let rays = Rays::<Autodiff<NdArray>, 3>::new(origins, directions);
182    ///
183    /// // Use analytical algorithm for primitive shapes
184    /// let algorithm = RayCastAlgorithm::<3>::Analytical;
185    /// // let result = algorithm.solve(rays, &region, &algebra);
186    ///
187    /// // Check intersection distances (should be ≈ 1.0 for unit sphere)
188    /// // let distances = result.distances();
189    /// ```
190    ///
191    /// ```rust
192    /// use crater::analysis::prelude::*;
193    /// use crater::csg::prelude::*;
194    /// use burn::prelude::*;
195    /// use burn::backend::wgpu::WgpuDevice;
196    ///
197    /// // Complex CSG region requiring numerical methods
198    /// let sphere = Field3D::<burn::backend::wgpu::Wgpu>::sphere(1.0, WgpuDevice::default()).into_isosurface(0.0);
199    /// let cube = Field3D::<burn::backend::wgpu::Wgpu>::sphere(1.5, WgpuDevice::default()).into_isosurface(0.0);
200    ///
201    /// let sphere_region = Region::HalfSpace(sphere, Side::Negative);
202    /// let cube_region = Region::HalfSpace(cube, Side::Negative);
203    /// let intersection = Region::Intersection(Box::new(sphere_region), Box::new(cube_region));
204    ///
205    /// // Use bracket-and-bisect for robust intersection with CSG
206    /// let algorithm = RayCastAlgorithm::<3>::BracketAndBisect {
207    ///     lambda: 10.0,
208    ///     d_lambda: 0.01,
209    ///     max_bisection_iterations: 50,
210    /// };
211    ///
212    /// // let result = algorithm.solve(rays, &intersection, &algebra);
213    /// ```
214    pub fn solve<B: AutodiffBackend>(
215        &self,
216        rays: Rays<B, N>,
217        region: &Region<N, B>,
218        algebra: &impl CSGAlgebra<B>,
219    ) -> RayCastResult<B, N> {
220        match self {
221            Self::BracketAndBisect { .. } => self.bracket_and_bisect(rays, region, algebra),
222            Self::Analytical => self.analytical(rays, region, algebra),
223            Self::Newton { .. } => self.newton(rays, region, algebra),
224        }
225    }
226
227    fn analytical<B: Backend>(
228        &self,
229        rays: Rays<B, N>,
230        region: &Region<N, B>,
231        algebra: &impl CSGAlgebra<B>,
232    ) -> RayCastResult<B, N> {
233        info!("Analytical ray casting: {}", rays);
234        // First, iterate over all halfspaces, and evaluate the ray field
235        let halfspace_evaluations = region
236            .halfspaces()
237            .enumerate()
238            .map(|(i, halfspace)| {
239                info!("Evaluating halfspace {}: {:?}", i, halfspace);
240                let isosurface = match halfspace {
241                    Region::HalfSpace(surf, _) => surf,
242                    _ => panic!(),
243                };
244
245                assert!(
246                    isosurface.ray_field.is_some(),
247                    "Cannot use analytical ray casting with non-analytic isosurfaces"
248                );
249
250                // Evaluate ray field at the isosurface
251                let result = isosurface
252                    .ray_field
253                    .as_ref()
254                    .unwrap()
255                    .solve(rays.clone(), isosurface.constant.clone());
256
257                result
258            })
259            .collect::<Vec<_>>();
260
261        // Next, at least one of the halfspaces must be on the surface of the REGION as well
262        let mut intersections =
263            Tensor::<B, 2, Float>::ones([rays.ray_count(), N], &rays.device()) * RAY_FAIL_VALUE;
264        let mut distance_to_hit =
265            Tensor::<B, 2, Float>::ones([rays.ray_count(), 1], &rays.device()) * RAY_FAIL_VALUE;
266
267        halfspace_evaluations.iter().for_each(|result| {
268            // True if the ray hits the surface of the REGION at this point
269
270            let is_on_mask = region.evaluate(result.extensions(), algebra).is_on_mask();
271
272            // If no intersection with this halfspace, skip
273            if !is_on_mask.clone().any().into_scalar() {
274                return;
275            }
276
277            let is_closer = distance_to_hit
278                .clone()
279                .greater_equal(result.distances().clone().unsqueeze::<2>().transpose());
280
281            // Create a mask for rays that hit the surface AND are closer than current hit
282            let should_replace_mask = is_on_mask
283                .clone()
284                .int()
285                .mul(is_closer.clone().squeeze(1).int())
286                .bool();
287
288            // If no rays should be replaced, skip
289            if !should_replace_mask.clone().any().into_scalar() {
290                return;
291            }
292
293            // Use scatter_replace with boolean masks to replace entire rows
294            intersections = intersections
295                .scatter_replace(should_replace_mask.clone(), result.extensions().clone());
296            distance_to_hit = distance_to_hit.scatter_replace(
297                should_replace_mask.clone(),
298                result.distances().clone().unsqueeze::<2>().transpose(),
299            );
300        });
301        RayCastResult::new(
302            rays.clone(),
303            intersections.clone(),
304            region.evaluate(intersections.clone(), algebra),
305            Tensor::<B, 1, Float>::zeros([rays.ray_count()], &rays.device()),
306        )
307    }
308
309    fn newton<B: AutodiffBackend>(
310        &self,
311        rays: Rays<B, N>,
312        region: &Region<N, B>,
313        algebra: &impl CSGAlgebra<B>,
314    ) -> RayCastResult<B, N> {
315        if let Self::Newton {
316            max_iterations,
317            nudge_distance,
318            step_size,
319        } = self
320        {
321            let device = rays.device();
322            let origins = rays.origins();
323            let batch_size = origins.dims()[0];
324
325            let zeros = Tensor::<B, 1>::zeros([batch_size], &device);
326            // Start with a mask of all rays active
327            let mut m_i = Tensor::<B, 1>::ones([batch_size], &device);
328            // Start from ray origins
329            let mut points = origins.clone();
330
331            for _ in 0..*max_iterations {
332                // Compute field values at the current points
333                let field_values = region.evaluate(points.clone(), algebra);
334
335                // For points that are determined to be ON the surface, we set
336                // their mask positions to 0.
337                let is_on_mask = field_values.is_on_mask();
338                m_i = m_i.clone().mul(is_on_mask.clone().bool_not().float());
339
340                // Then, if m_i == 1, it means we should take a gradient step
341                // If m_i is all 0, then we can stop
342                if m_i.clone().any().into_scalar() {
343                    let mut directional_derivatives = region.directional_derivative(
344                        Rays::new(points.clone(), rays.directions()),
345                        algebra,
346                    ); // [batch_size, 1]
347
348                    // Degenerate indices are those where the directional derivative is 0
349                    let degenerate_mask = directional_derivatives.clone().equal_elem(0.0);
350
351                    // Set the directional derivative to 1 for degenerate indices
352                    if degenerate_mask.clone().any().into_scalar() {
353                        directional_derivatives = directional_derivatives
354                            .clone()
355                            .unsqueeze::<2>()
356                            .transpose()
357                            .scatter_replace(
358                                degenerate_mask.clone(),
359                                Tensor::ones([degenerate_mask.clone().dims()[0], 1], &device)
360                                    .mul_scalar(*nudge_distance),
361                            )
362                            .squeeze(1);
363                    }
364
365                    let step_size = -field_values
366                        .clone()
367                        .inner()
368                        .div(directional_derivatives)
369                        .mul_scalar(*step_size)
370                        .unsqueeze::<2>()
371                        .transpose();
372
373                    // Take a gradient step using directional derivative
374                    let step = rays
375                        .directions()
376                        .clone()
377                        .inner()
378                        .mul(step_size.clone())
379                        .mul(m_i.clone().inner().unsqueeze::<2>().transpose());
380
381                    points = points.add(Tensor::<B, 2>::from_inner(step));
382                } else {
383                    println!("No rays are still active, stopping");
384                    break;
385                }
386            }
387
388            // Ensure that rays are not moving backwards
389            let displacement = points.clone() - rays.origins().clone();
390            let backwards_mask = displacement
391                .clone()
392                .mul(rays.directions().clone())
393                .sum_dim(1)
394                .squeeze::<1>(1)
395                .lower_elem(0.0);
396
397            if backwards_mask.clone().any().into_scalar() {
398                points = points.scatter_replace(
399                    backwards_mask.clone(),
400                    Tensor::<B, 2>::ones([backwards_mask.dims()[0], N], &device) * RAY_FAIL_VALUE,
401                );
402            }
403
404            // For all rays that never had m_i == 1, set their points to RAY_FAIL_VALUE
405            let failed_ray_mask = m_i.clone().bool();
406            if failed_ray_mask.clone().any().into_scalar() {
407                points = points.scatter_replace(
408                    failed_ray_mask.clone(),
409                    Tensor::<B, 2>::ones([failed_ray_mask.dims()[0], N], &device) * RAY_FAIL_VALUE,
410                );
411            }
412
413            RayCastResult::new(
414                rays.clone(),
415                points.clone(),
416                region.evaluate(points, algebra),
417                zeros,
418            )
419        } else {
420            panic!("Expected GradientDescent variant")
421        }
422    }
423
424    fn bracket_and_bisect<B: Backend>(
425        &self,
426        rays: Rays<B, N>,
427        region: &Region<N, B>,
428        algebra: &impl CSGAlgebra<B>,
429    ) -> RayCastResult<B, N> {
430        if let Self::BracketAndBisect {
431            lambda,
432            d_lambda,
433            max_bisection_iterations,
434        } = self
435        {
436            let device = rays.device();
437            let origins = rays.origins();
438            let directions = rays.directions();
439            let batch_size = origins.dims()[0];
440
441            let zeros = Tensor::<B, 1>::zeros([batch_size], &device);
442            let ones = Tensor::<B, 1>::ones([batch_size], &device);
443            let lambda_0 = LAMBDA;
444
445            let mut m_i = ones.clone();
446            let mut l_i = zeros.clone();
447            let l_0 = l_i.clone();
448            let mut r_i = l_i.clone() + lambda_0;
449            let mut f_l = region.evaluate(
450                extend_rays::<B, N>(origins.clone(), directions.clone(), l_i.clone()),
451                algebra,
452            );
453
454            // Bracketing phase:
455            // Find intervals [l_i, r_i] along each ray that contain the isosurface
456            // by advancing until the field value changes sign or we exceed max_distance
457            loop {
458                // Compute the right side of the bracket
459                let f_r = region.evaluate(
460                    extend_rays::<B, N>(origins.clone(), directions.clone(), r_i.clone()),
461                    algebra,
462                );
463
464                // Update the mask to reflect those brackets that now bound a root
465                m_i = m_i.clone().mul(
466                    ones.clone()
467                        .sub(contains_root_mask(f_l.clone().mul(f_r.clone()))),
468                );
469
470                // Check if we've exceeded max_distance for any active rays
471                let exceeded_max = (r_i.clone() - l_0.clone()).greater_elem(*lambda);
472
473                if exceeded_max.clone().any().into_scalar() {
474                    // For rays that exceeded max_distance, mark them as not intersecting
475                    m_i = m_i.clone().mul(ones.clone().sub(exceeded_max.float()));
476                }
477
478                // If no rays are still active, we can stop
479                if !m_i.clone().any().into_scalar() {
480                    break;
481                }
482
483                // Update the left and right sides of the bracket
484                // Step fixed distance
485                let step_size = m_i.clone().mul_scalar(*d_lambda);
486                l_i = l_i.clone().add(step_size.clone());
487                r_i = r_i.clone().add(step_size.clone());
488                f_l = f_r;
489            }
490
491            // Bisection phase:
492            // Refine the brackets to precisely locate the intersection points
493            // using the bisection method for root finding
494            (0..*max_bisection_iterations).for_each(|_| {
495                // Compute the midpoint of the bracket
496                let c_i = l_i.clone().add(r_i.clone()).div_scalar(2.0);
497                // Evaluate the field at left, center, and right points
498                let f_c = region.evaluate(
499                    extend_rays::<B, N>(origins.clone(), directions.clone(), c_i.clone()),
500                    algebra,
501                );
502                let f_l = region.evaluate(
503                    extend_rays::<B, N>(origins.clone(), directions.clone(), l_i.clone()),
504                    algebra,
505                );
506                // Create masks for brackets:
507                // - is_bracketed_left: true if the root is in [l_i, c_i]
508                // - is_bracketed_right: true if the root is in [c_i, r_i]
509                // Note: there is a chance that the root is in both brackets (i.e., c_i = 0), but we
510                // just break ties by assuming anything that is bracketed left must also not be bracketed right
511                let is_bracketed_left = contains_root_mask(f_c.clone().mul(f_l.clone()));
512                let is_bracketed_right = ones.clone() - is_bracketed_left.clone();
513
514                // Update brackets based on which half contains the root
515                l_i = c_i.clone().mul(ones.clone().sub(is_bracketed_left.clone()))
516                    + l_i.clone().mul(is_bracketed_left.clone());
517                r_i = c_i
518                    .clone()
519                    .mul(ones.clone().sub(is_bracketed_right.clone()))
520                    + r_i.clone().mul(is_bracketed_right.clone());
521            });
522            let mut intersections =
523                extend_rays::<B, N>(origins.clone(), directions.clone(), r_i.clone());
524
525            // If the ray's origin is on the surface of the region, overwrite
526            let is_on_mask = region.evaluate(origins.clone(), algebra).is_on_mask();
527
528            if is_on_mask.clone().any().into_scalar() {
529                intersections = intersections.scatter_replace(is_on_mask.clone(), origins.clone());
530            }
531
532            // If we did not find an extension that is on the surface of the region,
533            // provide sentinel values
534            let is_not_on_mask = region
535                .evaluate(intersections.clone(), algebra)
536                .is_on_mask()
537                .bool_not();
538            if is_not_on_mask.clone().any().into_scalar() {
539                intersections = intersections.scatter_replace(
540                    is_not_on_mask.clone(),
541                    Tensor::<B, 2, Float>::ones([is_not_on_mask.clone().dims()[0], N], &device)
542                        * RAY_FAIL_VALUE,
543                );
544            }
545
546            // Return the final ray intersection points and their field values
547            let final_field_values = region.evaluate(intersections.clone(), algebra);
548            RayCastResult::new(rays, intersections, final_field_values, zeros)
549        } else {
550            panic!("Expected MarchAndBisect variant")
551        }
552    }
553}
554
/// Width of the first interval examined along every ray in the bracketing phase of
/// [`RayCastAlgorithm::BracketAndBisect`]. That interval is `[0, LAMBDA]`; the bracket
/// then advances by the variant's `d_lambda`.
///
/// Evaluates to about 0.119 (`f32::EPSILON` ≈ 1.19e-7 scaled by 1e6). The value is
/// deliberately non-round; the original note cites avoiding precision issues.
const LAMBDA: f32 = f32::EPSILON * 1.0e6;
558
559/// Extend a collection of rays by a vector of scalars
560pub fn extend_rays<B: Backend, const N: usize>(
561    positions: Origins<B, N>,
562    directions: Directions<B, N>,
563    lambda_i: Tensor<B, 1>,
564) -> Origins<B, N> {
565    // Compute the extended points
566    positions.clone().add(
567        lambda_i
568            .clone()
569            .unsqueeze::<2>()
570            .transpose()
571            .mul(directions),
572    )
573}
574
575/// Creates a mask tensor where negative values become 1.0 and non-negative values become 0.0
576fn contains_root_mask<B: Backend>(tensor: Tensor<B, 1>) -> Tensor<B, 1> {
577    tensor.clone().lower_equal_elem(0.0).float()
578}
579
580impl<const N: usize, B: AutodiffBackend> Region<N, B> {
581    /// Solve the ray casting problem for a given region and algebra
582    pub fn ray_cast(
583        &self,
584        rays: Rays<B, N>,
585        algebra: &impl CSGAlgebra<B>,
586        algorithm: RayCastAlgorithm<N>,
587    ) -> RayCastResult<B, N> {
588        algorithm.solve(rays, self, algebra)
589    }
590    /// Returns the directional derivative along a collection of rays
591    ///
592    /// This is the scalar quantity:
593    ///
594    /// $$ \frac{\partial f}{\partial \mathbf{d}} = \nabla f \cdot \mathbf{d} $$
595    pub fn directional_derivative(
596        &self,
597        rays: Rays<B, N>,
598        algebra: &impl CSGAlgebra<B>,
599    ) -> Scalars<B::InnerBackend> {
600        let points = rays.origins();
601        let grads = self.gradient(points, algebra);
602        // Project onto the ray direction
603        grads
604            .mul(rays.directions().require_grad().inner())
605            .sum_dim(1)
606            .squeeze(1)
607    }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::prelude::*;
    use crate::test_utils::assert_tensor_almost_eq;
    use backend_macro::with_backend;
    use burn::tensor::Tensor;

    /// Checks `Region::directional_derivative` on a unit sphere against hand-computed
    /// values. Every expected entry below equals 2 (o·d) for its row.
    #[with_backend]
    #[test]
    fn test_directional_derivative_3d_sphere() {
        let sphere = Field3D::<Backend>::sphere(1.0, device()).into_isosurface(0.0);
        let region = sphere.region();
        let algebra = Algebra::default();

        // Rows: centre, inside, outside (looking along +z), on the surface, and inside
        // looking along the main diagonal.
        let ray_origins = Tensor::<Backend, 2>::from_data(
            [
                [0.0, 0.0, 0.0],
                [0.5, 0.0, 0.0],
                [2.0, 0.0, 1.0],
                [1.0, 0.0, 0.0],
                [0.1, 0.1, 0.1],
            ],
            &device(),
        );
        let ray_directions = Tensor::<Backend, 2>::from_data(
            [
                [1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0],
                normalize(&[1.0, 1.0, 1.0]),
            ],
            &device(),
        );
        let expected =
            Tensor::<Backend, 1>::from_data([0.0, 1.0, 2.0, 2.0, 0.2 * 3.0_f32.sqrt()], &device());

        let rays = Rays::new(ray_origins, ray_directions);
        let derivatives = region.directional_derivative(rays, &algebra);
        println!("derivatives: {}", derivatives);

        // One derivative per ray, each matching the analytic value.
        assert_eq!(derivatives.dims(), [5]);
        assert_tensor_almost_eq(derivatives, expected.inner(), Some(EPSILON));
    }
}