crater/analysis/rays/algo.rs
1//! Ray-surface intersection algorithms.
2//!
3//! This module implements various algorithms for finding intersections between rays and
4//! implicit surfaces. For the mathematical theory and algorithm descriptions, see the
5//! [Ray Casting](../../book/theory/ray-casting.md) chapter.
6
7use super::prelude::*;
8use crate::csg::prelude::*;
9use burn::prelude::*;
10use burn::tensor::{
11 Tensor,
12 backend::{AutodiffBackend, Backend},
13};
14use log::info;
15
16pub const RAY_FAIL_VALUE: f32 = f32::MAX;
17
18/// Ray-surface intersection algorithms for implicit surface ray casting.
19///
20/// This enum provides multiple algorithmic approaches for solving ray-surface intersection
21/// problems, each optimized for different scenarios in safety-critical applications where
22/// precision and reliability are paramount.
23///
24/// # Algorithm Selection Guidelines
25///
26/// * **Use [`Analytical`](RayCastAlgorithm::Analytical)** for:
27/// - Simple primitive shapes (spheres, planes, cylinders)
28/// - Maximum accuracy and performance requirements
29/// - Real-time applications with tight computational budgets
30///
31/// * **Use [`Newton`](RayCastAlgorithm::Newton)** for:
32/// - Smooth, well-conditioned implicit surfaces
33/// - When autodifferentiation is available
34/// - Applications requiring fast convergence with few iterations
35///
36/// * **Use [`BracketAndBisect`](RayCastAlgorithm::BracketAndBisect)** for:
37/// - Complex CSG constructions with sharp edges
38/// - Robustness over speed requirements
39/// - Surfaces with potential numerical instabilities
40///
41/// For theoretical background, see the [Ray Casting](../../book/theory/ray-casting.md) chapter.
42#[derive(Debug, Clone, Copy)]
43pub enum RayCastAlgorithm<const N: usize> {
44 /// Robust numerical method using bracketing followed by bisection.
45 ///
46 /// This algorithm first advances along each ray to bracket sign changes in the
47 /// scalar field, then uses bisection to refine intersection points to high precision.
48 /// It's the most reliable method for complex CSG constructions.
49 ///
50 /// # Parameters
51 ///
52 /// * `lambda` - Maximum ray extension distance for bracketing phase
53 /// * `d_lambda` - Step size during bracketing search
54 /// * `max_bisection_iterations` - Maximum refinement iterations (typically 20-50)
55 ///
56 /// # Convergence
57 ///
58 /// Guaranteed to converge for continuous scalar fields with precision ε ≈ λ/2ⁿ
59 /// where n is the number of bisection iterations.
60 BracketAndBisect {
61 /// Maximum distance to extend rays during bracketing
62 lambda: f32,
63 /// Step size for bracketing phase
64 d_lambda: f32,
65 /// Maximum iterations for bisection refinement
66 max_bisection_iterations: usize,
67 },
68
69 /// Direct analytical solutions for primitive geometric shapes.
70 ///
71 /// Computes exact ray-surface intersections using closed-form mathematical
72 /// formulas. This is the fastest and most accurate method when applicable,
73 /// but is limited to simple primitives with known intersection equations.
74 ///
75 /// # Supported Primitives
76 ///
77 /// - Spheres: Quadratic equation solutions
78 /// - Planes: Linear ray-plane intersection
79 /// - Cylinders: Analytical cylinder intersection
80 /// - Cones: Quadratic cone intersection formulas
81 ///
82 /// # Limitations
83 ///
84 /// Cannot be used with:
85 /// - Arbitrary scalar field expressions
86 /// - CSG operations beyond simple unions
87 /// - Transformed or deformed primitives
88 Analytical,
89
90 /// Iterative Newton-Raphson root finding using automatic differentiation.
91 ///
92 /// Uses gradient information to rapidly converge to surface intersections.
93 /// Exhibits quadratic convergence near solutions but requires smooth,
94 /// differentiable scalar fields for optimal performance.
95 ///
96 /// # Parameters
97 ///
98 /// * `max_iterations` - Maximum Newton iterations (typically 5-20)
99 /// * `nudge_distance` - Perturbation for degenerate gradient cases
100 /// * `step_size` - Relaxation factor for Newton steps (usually 0.5-1.0)
101 ///
102 /// # Convergence Properties
103 ///
104 /// - **Quadratic**: Error ∝ ε² per iteration near solutions
105 /// - **Sensitive**: May fail near sharp edges or gradient discontinuities
106 /// - **Fast**: Typically converges in 3-10 iterations for smooth surfaces
107 Newton {
108 /// Maximum number of Newton iterations
109 max_iterations: usize,
110 /// Distance to perturb points with zero gradients
111 nudge_distance: f32,
112 /// Step size relaxation factor (0.0 to 1.0)
113 step_size: f32,
114 },
115}
116
117impl<const N: usize> RayCastAlgorithm<N> {
118 /// Computes ray-surface intersections using the specified algorithm.
119 ///
120 /// This method finds the points where rays intersect with the boundary of an implicit
121 /// surface region. It supports multiple algorithmic approaches, each with different
122 /// trade-offs between accuracy, performance, and applicability.
123 ///
124 /// # Mathematical Problem
125 ///
126 /// Given rays **r**(t) = **o** + t**d** and a region R with characteristic function f(x),
127 /// find parameter values t where f(**r**(t)) = 0, indicating surface intersection.
128 ///
129 /// # Algorithm Variants
130 ///
131 /// * **Analytical**: Direct closed-form solutions for primitive shapes (spheres, planes, etc.)
132 /// - Fastest and most accurate when applicable
133 /// - Limited to simple geometric primitives with known ray intersection formulas
134 ///
135 /// * **Newton**: Iterative root-finding using gradient-based optimization
136 /// - Quadratic convergence near solutions
137 /// - Requires autodifferentiable scalar fields
138 /// - Best for smooth, well-conditioned surfaces
139 ///
140 /// * **BracketAndBisect**: Robust bracketing followed by bisection refinement
141 /// - Guaranteed convergence for continuous functions
142 /// - Works with any scalar field implementation
143 /// - Slower but most reliable for complex CSG constructions
144 ///
145 /// # Parameters
146 ///
147 /// * `rays` - Collection of rays to cast against the surface
148 /// * `region` - Implicit surface region to intersect with
149 /// * `algebra` - CSG algebra for evaluating boolean operations in composite regions
150 ///
151 /// # Returns
152 ///
153 /// [`RayCastResult`] containing intersection points, distances, field values, and metadata
154 /// for each input ray. Rays that miss the surface are marked with sentinel values.
155 ///
156 /// # Performance Notes
157 ///
158 /// - **Analytical**: O(1) per ray for primitive surfaces
159 /// - **Newton**: O(k) where k is iteration count (typically 3-10)
160 /// - **BracketAndBisect**: O(log ε⁻¹) where ε is desired precision
161 ///
162 /// For complex CSG regions, cost scales with the number of constituent half-spaces.
163 ///
164 /// # Examples
165 ///
166 /// ```rust
167 /// use crater::analysis::prelude::*;
168 /// use crater::csg::prelude::*;
169 /// use burn::prelude::*;
170 /// use burn::backend::ndarray::{NdArrayDevice, NdArray};
171 /// use burn::backend::Autodiff;
172 ///
173 /// // Create a sphere region
174 /// let sphere = Field3D::<Autodiff<NdArray>>::sphere(1.0, NdArrayDevice::Cpu).into_isosurface(0.0);
175 /// let region = Region::HalfSpace(sphere, Side::Negative);
176 /// let algebra = Algebra::<Autodiff<NdArray>>::default();
177 ///
178 /// // Create rays pointing at the sphere
179 /// let origins = Tensor::from_data([[-2.0, 0.0, 0.0], [0.0, -2.0, 0.0]], &NdArrayDevice::Cpu);
180 /// let directions = Tensor::from_data([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], &NdArrayDevice::Cpu);
181 /// let rays = Rays::<Autodiff<NdArray>, 3>::new(origins, directions);
182 ///
183 /// // Use analytical algorithm for primitive shapes
184 /// let algorithm = RayCastAlgorithm::<3>::Analytical;
185 /// // let result = algorithm.solve(rays, ®ion, &algebra);
186 ///
187 /// // Check intersection distances (should be ≈ 1.0 for unit sphere)
188 /// // let distances = result.distances();
189 /// ```
190 ///
191 /// ```rust
192 /// use crater::analysis::prelude::*;
193 /// use crater::csg::prelude::*;
194 /// use burn::prelude::*;
195 /// use burn::backend::wgpu::WgpuDevice;
196 ///
197 /// // Complex CSG region requiring numerical methods
198 /// let sphere = Field3D::<burn::backend::wgpu::Wgpu>::sphere(1.0, WgpuDevice::default()).into_isosurface(0.0);
199 /// let cube = Field3D::<burn::backend::wgpu::Wgpu>::sphere(1.5, WgpuDevice::default()).into_isosurface(0.0);
200 ///
201 /// let sphere_region = Region::HalfSpace(sphere, Side::Negative);
202 /// let cube_region = Region::HalfSpace(cube, Side::Negative);
203 /// let intersection = Region::Intersection(Box::new(sphere_region), Box::new(cube_region));
204 ///
205 /// // Use bracket-and-bisect for robust intersection with CSG
206 /// let algorithm = RayCastAlgorithm::<3>::BracketAndBisect {
207 /// lambda: 10.0,
208 /// d_lambda: 0.01,
209 /// max_bisection_iterations: 50,
210 /// };
211 ///
212 /// // let result = algorithm.solve(rays, &intersection, &algebra);
213 /// ```
214 pub fn solve<B: AutodiffBackend>(
215 &self,
216 rays: Rays<B, N>,
217 region: &Region<N, B>,
218 algebra: &impl CSGAlgebra<B>,
219 ) -> RayCastResult<B, N> {
220 match self {
221 Self::BracketAndBisect { .. } => self.bracket_and_bisect(rays, region, algebra),
222 Self::Analytical => self.analytical(rays, region, algebra),
223 Self::Newton { .. } => self.newton(rays, region, algebra),
224 }
225 }
226
227 fn analytical<B: Backend>(
228 &self,
229 rays: Rays<B, N>,
230 region: &Region<N, B>,
231 algebra: &impl CSGAlgebra<B>,
232 ) -> RayCastResult<B, N> {
233 info!("Analytical ray casting: {}", rays);
234 // First, iterate over all halfspaces, and evaluate the ray field
235 let halfspace_evaluations = region
236 .halfspaces()
237 .enumerate()
238 .map(|(i, halfspace)| {
239 info!("Evaluating halfspace {}: {:?}", i, halfspace);
240 let isosurface = match halfspace {
241 Region::HalfSpace(surf, _) => surf,
242 _ => panic!(),
243 };
244
245 assert!(
246 isosurface.ray_field.is_some(),
247 "Cannot use analytical ray casting with non-analytic isosurfaces"
248 );
249
250 // Evaluate ray field at the isosurface
251 let result = isosurface
252 .ray_field
253 .as_ref()
254 .unwrap()
255 .solve(rays.clone(), isosurface.constant.clone());
256
257 result
258 })
259 .collect::<Vec<_>>();
260
261 // Next, at least one of the halfspaces must be on the surface of the REGION as well
262 let mut intersections =
263 Tensor::<B, 2, Float>::ones([rays.ray_count(), N], &rays.device()) * RAY_FAIL_VALUE;
264 let mut distance_to_hit =
265 Tensor::<B, 2, Float>::ones([rays.ray_count(), 1], &rays.device()) * RAY_FAIL_VALUE;
266
267 halfspace_evaluations.iter().for_each(|result| {
268 // True if the ray hits the surface of the REGION at this point
269
270 let is_on_mask = region.evaluate(result.extensions(), algebra).is_on_mask();
271
272 // If no intersection with this halfspace, skip
273 if !is_on_mask.clone().any().into_scalar() {
274 return;
275 }
276
277 let is_closer = distance_to_hit
278 .clone()
279 .greater_equal(result.distances().clone().unsqueeze::<2>().transpose());
280
281 // Create a mask for rays that hit the surface AND are closer than current hit
282 let should_replace_mask = is_on_mask
283 .clone()
284 .int()
285 .mul(is_closer.clone().squeeze(1).int())
286 .bool();
287
288 // If no rays should be replaced, skip
289 if !should_replace_mask.clone().any().into_scalar() {
290 return;
291 }
292
293 // Use scatter_replace with boolean masks to replace entire rows
294 intersections = intersections
295 .scatter_replace(should_replace_mask.clone(), result.extensions().clone());
296 distance_to_hit = distance_to_hit.scatter_replace(
297 should_replace_mask.clone(),
298 result.distances().clone().unsqueeze::<2>().transpose(),
299 );
300 });
301 RayCastResult::new(
302 rays.clone(),
303 intersections.clone(),
304 region.evaluate(intersections.clone(), algebra),
305 Tensor::<B, 1, Float>::zeros([rays.ray_count()], &rays.device()),
306 )
307 }
308
309 fn newton<B: AutodiffBackend>(
310 &self,
311 rays: Rays<B, N>,
312 region: &Region<N, B>,
313 algebra: &impl CSGAlgebra<B>,
314 ) -> RayCastResult<B, N> {
315 if let Self::Newton {
316 max_iterations,
317 nudge_distance,
318 step_size,
319 } = self
320 {
321 let device = rays.device();
322 let origins = rays.origins();
323 let batch_size = origins.dims()[0];
324
325 let zeros = Tensor::<B, 1>::zeros([batch_size], &device);
326 // Start with a mask of all rays active
327 let mut m_i = Tensor::<B, 1>::ones([batch_size], &device);
328 // Start from ray origins
329 let mut points = origins.clone();
330
331 for _ in 0..*max_iterations {
332 // Compute field values at the current points
333 let field_values = region.evaluate(points.clone(), algebra);
334
335 // For points that are determined to be ON the surface, we set
336 // their mask positions to 0.
337 let is_on_mask = field_values.is_on_mask();
338 m_i = m_i.clone().mul(is_on_mask.clone().bool_not().float());
339
340 // Then, if m_i == 1, it means we should take a gradient step
341 // If m_i is all 0, then we can stop
342 if m_i.clone().any().into_scalar() {
343 let mut directional_derivatives = region.directional_derivative(
344 Rays::new(points.clone(), rays.directions()),
345 algebra,
346 ); // [batch_size, 1]
347
348 // Degenerate indices are those where the directional derivative is 0
349 let degenerate_mask = directional_derivatives.clone().equal_elem(0.0);
350
351 // Set the directional derivative to 1 for degenerate indices
352 if degenerate_mask.clone().any().into_scalar() {
353 directional_derivatives = directional_derivatives
354 .clone()
355 .unsqueeze::<2>()
356 .transpose()
357 .scatter_replace(
358 degenerate_mask.clone(),
359 Tensor::ones([degenerate_mask.clone().dims()[0], 1], &device)
360 .mul_scalar(*nudge_distance),
361 )
362 .squeeze(1);
363 }
364
365 let step_size = -field_values
366 .clone()
367 .inner()
368 .div(directional_derivatives)
369 .mul_scalar(*step_size)
370 .unsqueeze::<2>()
371 .transpose();
372
373 // Take a gradient step using directional derivative
374 let step = rays
375 .directions()
376 .clone()
377 .inner()
378 .mul(step_size.clone())
379 .mul(m_i.clone().inner().unsqueeze::<2>().transpose());
380
381 points = points.add(Tensor::<B, 2>::from_inner(step));
382 } else {
383 println!("No rays are still active, stopping");
384 break;
385 }
386 }
387
388 // Ensure that rays are not moving backwards
389 let displacement = points.clone() - rays.origins().clone();
390 let backwards_mask = displacement
391 .clone()
392 .mul(rays.directions().clone())
393 .sum_dim(1)
394 .squeeze::<1>(1)
395 .lower_elem(0.0);
396
397 if backwards_mask.clone().any().into_scalar() {
398 points = points.scatter_replace(
399 backwards_mask.clone(),
400 Tensor::<B, 2>::ones([backwards_mask.dims()[0], N], &device) * RAY_FAIL_VALUE,
401 );
402 }
403
404 // For all rays that never had m_i == 1, set their points to RAY_FAIL_VALUE
405 let failed_ray_mask = m_i.clone().bool();
406 if failed_ray_mask.clone().any().into_scalar() {
407 points = points.scatter_replace(
408 failed_ray_mask.clone(),
409 Tensor::<B, 2>::ones([failed_ray_mask.dims()[0], N], &device) * RAY_FAIL_VALUE,
410 );
411 }
412
413 RayCastResult::new(
414 rays.clone(),
415 points.clone(),
416 region.evaluate(points, algebra),
417 zeros,
418 )
419 } else {
420 panic!("Expected GradientDescent variant")
421 }
422 }
423
424 fn bracket_and_bisect<B: Backend>(
425 &self,
426 rays: Rays<B, N>,
427 region: &Region<N, B>,
428 algebra: &impl CSGAlgebra<B>,
429 ) -> RayCastResult<B, N> {
430 if let Self::BracketAndBisect {
431 lambda,
432 d_lambda,
433 max_bisection_iterations,
434 } = self
435 {
436 let device = rays.device();
437 let origins = rays.origins();
438 let directions = rays.directions();
439 let batch_size = origins.dims()[0];
440
441 let zeros = Tensor::<B, 1>::zeros([batch_size], &device);
442 let ones = Tensor::<B, 1>::ones([batch_size], &device);
443 let lambda_0 = LAMBDA;
444
445 let mut m_i = ones.clone();
446 let mut l_i = zeros.clone();
447 let l_0 = l_i.clone();
448 let mut r_i = l_i.clone() + lambda_0;
449 let mut f_l = region.evaluate(
450 extend_rays::<B, N>(origins.clone(), directions.clone(), l_i.clone()),
451 algebra,
452 );
453
454 // Bracketing phase:
455 // Find intervals [l_i, r_i] along each ray that contain the isosurface
456 // by advancing until the field value changes sign or we exceed max_distance
457 loop {
458 // Compute the right side of the bracket
459 let f_r = region.evaluate(
460 extend_rays::<B, N>(origins.clone(), directions.clone(), r_i.clone()),
461 algebra,
462 );
463
464 // Update the mask to reflect those brackets that now bound a root
465 m_i = m_i.clone().mul(
466 ones.clone()
467 .sub(contains_root_mask(f_l.clone().mul(f_r.clone()))),
468 );
469
470 // Check if we've exceeded max_distance for any active rays
471 let exceeded_max = (r_i.clone() - l_0.clone()).greater_elem(*lambda);
472
473 if exceeded_max.clone().any().into_scalar() {
474 // For rays that exceeded max_distance, mark them as not intersecting
475 m_i = m_i.clone().mul(ones.clone().sub(exceeded_max.float()));
476 }
477
478 // If no rays are still active, we can stop
479 if !m_i.clone().any().into_scalar() {
480 break;
481 }
482
483 // Update the left and right sides of the bracket
484 // Step fixed distance
485 let step_size = m_i.clone().mul_scalar(*d_lambda);
486 l_i = l_i.clone().add(step_size.clone());
487 r_i = r_i.clone().add(step_size.clone());
488 f_l = f_r;
489 }
490
491 // Bisection phase:
492 // Refine the brackets to precisely locate the intersection points
493 // using the bisection method for root finding
494 (0..*max_bisection_iterations).for_each(|_| {
495 // Compute the midpoint of the bracket
496 let c_i = l_i.clone().add(r_i.clone()).div_scalar(2.0);
497 // Evaluate the field at left, center, and right points
498 let f_c = region.evaluate(
499 extend_rays::<B, N>(origins.clone(), directions.clone(), c_i.clone()),
500 algebra,
501 );
502 let f_l = region.evaluate(
503 extend_rays::<B, N>(origins.clone(), directions.clone(), l_i.clone()),
504 algebra,
505 );
506 // Create masks for brackets:
507 // - is_bracketed_left: true if the root is in [l_i, c_i]
508 // - is_bracketed_right: true if the root is in [c_i, r_i]
509 // Note: there is a chance that the root is in both brackets (i.e., c_i = 0), but we
510 // just break ties by assuming anything that is bracketed left must also not be bracketed right
511 let is_bracketed_left = contains_root_mask(f_c.clone().mul(f_l.clone()));
512 let is_bracketed_right = ones.clone() - is_bracketed_left.clone();
513
514 // Update brackets based on which half contains the root
515 l_i = c_i.clone().mul(ones.clone().sub(is_bracketed_left.clone()))
516 + l_i.clone().mul(is_bracketed_left.clone());
517 r_i = c_i
518 .clone()
519 .mul(ones.clone().sub(is_bracketed_right.clone()))
520 + r_i.clone().mul(is_bracketed_right.clone());
521 });
522 let mut intersections =
523 extend_rays::<B, N>(origins.clone(), directions.clone(), r_i.clone());
524
525 // If the ray's origin is on the surface of the region, overwrite
526 let is_on_mask = region.evaluate(origins.clone(), algebra).is_on_mask();
527
528 if is_on_mask.clone().any().into_scalar() {
529 intersections = intersections.scatter_replace(is_on_mask.clone(), origins.clone());
530 }
531
532 // If we did not find an extension that is on the surface of the region,
533 // provide sentinel values
534 let is_not_on_mask = region
535 .evaluate(intersections.clone(), algebra)
536 .is_on_mask()
537 .bool_not();
538 if is_not_on_mask.clone().any().into_scalar() {
539 intersections = intersections.scatter_replace(
540 is_not_on_mask.clone(),
541 Tensor::<B, 2, Float>::ones([is_not_on_mask.clone().dims()[0], N], &device)
542 * RAY_FAIL_VALUE,
543 );
544 }
545
546 // Return the final ray intersection points and their field values
547 let final_field_values = region.evaluate(intersections.clone(), algebra);
548 RayCastResult::new(rays, intersections, final_field_values, zeros)
549 } else {
550 panic!("Expected MarchAndBisect variant")
551 }
552 }
553}
554
555/// Min step distance in bracketing phase
556/// Use non-rounded epsilon to avoid precision issues
557const LAMBDA: f32 = f32::EPSILON * 1.0e6;
558
559/// Extend a collection of rays by a vector of scalars
560pub fn extend_rays<B: Backend, const N: usize>(
561 positions: Origins<B, N>,
562 directions: Directions<B, N>,
563 lambda_i: Tensor<B, 1>,
564) -> Origins<B, N> {
565 // Compute the extended points
566 positions.clone().add(
567 lambda_i
568 .clone()
569 .unsqueeze::<2>()
570 .transpose()
571 .mul(directions),
572 )
573}
574
575/// Creates a mask tensor where negative values become 1.0 and non-negative values become 0.0
576fn contains_root_mask<B: Backend>(tensor: Tensor<B, 1>) -> Tensor<B, 1> {
577 tensor.clone().lower_equal_elem(0.0).float()
578}
579
580impl<const N: usize, B: AutodiffBackend> Region<N, B> {
581 /// Solve the ray casting problem for a given region and algebra
582 pub fn ray_cast(
583 &self,
584 rays: Rays<B, N>,
585 algebra: &impl CSGAlgebra<B>,
586 algorithm: RayCastAlgorithm<N>,
587 ) -> RayCastResult<B, N> {
588 algorithm.solve(rays, self, algebra)
589 }
590 /// Returns the directional derivative along a collection of rays
591 ///
592 /// This is the scalar quantity:
593 ///
594 /// $$ \frac{\partial f}{\partial \mathbf{d}} = \nabla f \cdot \mathbf{d} $$
595 pub fn directional_derivative(
596 &self,
597 rays: Rays<B, N>,
598 algebra: &impl CSGAlgebra<B>,
599 ) -> Scalars<B::InnerBackend> {
600 let points = rays.origins();
601 let grads = self.gradient(points, algebra);
602 // Project onto the ray direction
603 grads
604 .mul(rays.directions().require_grad().inner())
605 .sum_dim(1)
606 .squeeze(1)
607 }
608}
609
610#[cfg(test)]
611mod tests {
612 use super::*;
613 use crate::primitives::prelude::*;
614 use crate::test_utils::assert_tensor_almost_eq;
615 use backend_macro::with_backend;
616 use burn::tensor::Tensor;
617
618 #[with_backend]
619 #[test]
620 fn test_directional_derivative_3d_sphere() {
621 // Create a simple sphere field
622 let sphere = Field3D::<Backend>::sphere(1.0, device()).into_isosurface(0.0);
623 let region = sphere.region();
624 let algebra = Algebra::default();
625
626 // Create rays pointing in different directions
627 let origins = Tensor::<Backend, 2>::from_data(
628 [
629 [0.0, 0.0, 0.0], // Origin at center
630 [0.5, 0.0, 0.0], // Origin inside sphere
631 [2.0, 0.0, 1.0], // Origin outside sphere
632 [1.0, 0.0, 0.0], // Origin on surface
633 [0.1, 0.1, 0.1], // Origin inside sphere
634 ],
635 &device(),
636 );
637
638 let directions = Tensor::<Backend, 2>::from_data(
639 [
640 [1.0, 0.0, 0.0], // Pointing right
641 [1.0, 0.0, 0.0], // Pointing right
642 [0.0, 0.0, 1.0], // Pointing up
643 [1.0, 0.0, 0.0], // Pointing right
644 normalize(&[1.0, 1.0, 1.0]), // Pointing along diagonal
645 ],
646 &device(),
647 );
648 let expected_derivatives =
649 Tensor::<Backend, 1>::from_data([0.0, 1.0, 2.0, 2.0, 0.2 * 3.0_f32.sqrt()], &device());
650
651 let rays = Rays::new(origins.clone(), directions.clone());
652
653 // Compute directional derivatives
654 let derivatives = region.directional_derivative(rays.clone(), &algebra);
655 println!("derivatives: {}", derivatives);
656 // Check that we get the expected shape
657 assert_eq!(derivatives.dims(), [5]);
658
659 // Check that the derivatives are correct
660 assert_tensor_almost_eq(derivatives, expected_derivatives.inner(), Some(EPSILON));
661 }
662}