crater/analysis/rays/
bench_utils.rs

1//! Shared utilities for ray casting benchmarks and examples
2
3use crate::analysis::rays::prelude::*;
4use crate::csg::prelude::*;
5use burn::prelude::*;
6use burn::tensor::cast::ToElement;
7
8use burn::tensor::backend::AutodiffBackend;
9use rand::rngs::StdRng;
10use rand::{Rng, SeedableRng};
11use std::fmt;
12
/// Summary statistics describing one ray casting benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Name of the algorithm, exactly as the caller supplied it.
    pub algorithm: String,
    /// Number of spheres in the benchmarked scene.
    pub sphere_count: usize,
    /// Number of rays that were cast.
    pub ray_count: usize,
    /// Number of rays whose reported distance was below `f32::MAX`.
    pub successful_hits: usize,
    /// Share of rays that hit, expressed as a percentage (0 to 100).
    pub hit_rate: f32,
    /// Mean distance over the successful hits; `0.0` when nothing was hit.
    pub average_distance: f32,
    /// Wall-clock time measured around the ray cast.
    pub ray_cast_time: std::time::Duration,
    /// Throughput: `ray_count` divided by `ray_cast_time` in seconds.
    pub rays_per_second: f64,
}
25
26impl BenchmarkResult {
27    /// Create a new benchmark result from ray cast data
28    pub fn new<B: Backend>(
29        result: &RayCastResult<B, 3>,
30        elapsed: std::time::Duration,
31        sphere_count: usize,
32        ray_count: usize,
33        algorithm: &str,
34    ) -> Self {
35        // Count successful hits (rays that didn't fail)
36        let successful_mask = result.distances().lower_elem(f32::MAX);
37        let successful_hits: usize = successful_mask.clone().int().sum().into_scalar().to_usize();
38
39        // Calculate average distance for successful hits only
40        let distances = result.distances();
41        let valid_distances = distances.clone().mask_fill(successful_mask.bool_not(), 0.0);
42        let average_distance = if successful_hits > 0 {
43            valid_distances.sum().into_scalar().to_f32() / successful_hits as f32
44        } else {
45            0.0
46        };
47
48        let hit_rate = (successful_hits as f32 / ray_count as f32) * 100.0;
49        let rays_per_second = ray_count as f64 / elapsed.as_secs_f64();
50
51        Self {
52            algorithm: algorithm.to_string(),
53            sphere_count,
54            ray_count,
55            successful_hits,
56            hit_rate,
57            average_distance,
58            ray_cast_time: elapsed,
59            rays_per_second,
60        }
61    }
62}
63
64impl fmt::Display for BenchmarkResult {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        writeln!(
67            f,
68            "Results for {} algorithm with {} spheres and {} rays:",
69            self.algorithm, self.sphere_count, self.ray_count
70        )?;
71        writeln!(f, "  Total rays: {}", self.ray_count)?;
72        writeln!(f, "  Successful hits: {}", self.successful_hits)?;
73        writeln!(f, "  Hit rate: {:.2}%", self.hit_rate)?;
74        writeln!(f, "  Average distance: {:.4}", self.average_distance)?;
75        writeln!(f, "  Ray-cast time: {:.2?}", self.ray_cast_time)?;
76        writeln!(f, "  Rays per second: {:.0}", self.rays_per_second)?;
77        Ok(())
78    }
79}
80
81/// Create a region consisting of multiple random spheres
82pub fn make_random_spheres<B: Backend>(
83    num_spheres: usize,
84    seed: u64,
85    device: &B::Device,
86) -> Region<3, B> {
87    let mut rng = StdRng::seed_from_u64(seed);
88    let mut spheres = Vec::new();
89
90    for _ in 0..num_spheres {
91        // Generate random center in unit cube [-1, 1]^3
92        let center = [
93            rng.random_range(-0.8..=0.8),
94            rng.random_range(-0.8..=0.8),
95            rng.random_range(-0.8..=0.8),
96        ];
97
98        // Generate random radius between 0.1 and 0.3
99        let radius = rng.random_range(0.1..=0.3);
100
101        // Create sphere and translate it to the random center
102        let sphere = Field3D::<B>::sphere(radius, device.clone())
103            .into_isosurface(0.0)
104            .transform(Translate(center))
105            .region();
106
107        spheres.push(sphere);
108    }
109
110    // Create union of all spheres
111    let mut union = spheres.remove(0);
112    for sphere in spheres {
113        union = &union | &sphere;
114    }
115
116    union
117}
118
119/// Generate random rays with the given seed
120pub fn random_rays<B: Backend>(
121    batch_size: usize,
122    seed: u64,
123    device: &B::Device,
124) -> (Tensor<B, 2, Float>, Tensor<B, 2, Float>) {
125    let mut rng = StdRng::seed_from_u64(seed);
126
127    // Generate random directions (unit vectors)
128    let mut directions = Vec::with_capacity(batch_size * 3);
129    for _ in 0..batch_size {
130        let mut v = [0.0f32; 3];
131        loop {
132            for val in &mut v {
133                *val = rng.random_range(-1.0..=1.0);
134            }
135            let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
136            if norm > 1e-6 {
137                for val in &v {
138                    directions.push(val / norm);
139                }
140                break;
141            }
142        }
143    }
144
145    // Generate random origins inside the unit cube
146    let mut origins = Vec::with_capacity(batch_size * 3);
147    for _ in 0..batch_size {
148        let x = rng.random_range(-1.0..=1.0);
149        let y = rng.random_range(-1.0..=1.0);
150        let z = rng.random_range(-1.0..=1.0);
151        origins.push(x);
152        origins.push(y);
153        origins.push(z);
154    }
155
156    let directions =
157        Tensor::<B, 1, Float>::from_data(directions.as_slice(), device).reshape([batch_size, 3]);
158    let origins =
159        Tensor::<B, 1, Float>::from_data(origins.as_slice(), device).reshape([batch_size, 3]);
160    (origins, directions)
161}
162
163/// Parse algorithm string into RayCastAlgorithm
164pub fn parse_algorithm(algorithm: &str) -> RayCastAlgorithm<3> {
165    match algorithm.to_lowercase().as_str() {
166        "analytical" => RayCastAlgorithm::<3>::Analytical,
167        "march_and_bisect" => RayCastAlgorithm::<3>::BracketAndBisect {
168            lambda: 10.0,
169            d_lambda: f32::EPSILON * 1e6,
170            max_bisection_iterations: 30,
171        },
172        "newton" => RayCastAlgorithm::<3>::Newton {
173            max_iterations: 250,
174            nudge_distance: 0.01,
175            step_size: 0.1,
176        },
177        _ => {
178            eprintln!("Unknown algorithm: {}. Using analytical.", algorithm);
179            RayCastAlgorithm::<3>::Analytical
180        }
181    }
182}
183
184/// Run a ray casting benchmark with the given parameters
185pub fn run_raycast_benchmark<B: AutodiffBackend>(
186    sphere_count: usize,
187    ray_count: usize,
188    algorithm: &str,
189    sphere_seed: u64,
190    ray_seed: u64,
191    device: &B::Device,
192) -> BenchmarkResult {
193    let region = make_random_spheres::<B>(sphere_count, sphere_seed, device);
194    let (origins, directions) = random_rays::<B>(ray_count, ray_seed, device);
195    let algebra = Algebra::default();
196    let rays = Rays::new(origins, directions);
197    let parsed_algorithm = parse_algorithm(algorithm);
198
199    let start_time = std::time::Instant::now();
200    let result = region.ray_cast(rays, &algebra, parsed_algorithm);
201    let elapsed = start_time.elapsed();
202
203    BenchmarkResult::new(&result, elapsed, sphere_count, ray_count, algorithm)
204}
205
206/// Run a ray casting benchmark and return both the benchmark result and the ray cast result
207pub fn run_raycast_benchmark_with_result<B: AutodiffBackend>(
208    sphere_count: usize,
209    ray_count: usize,
210    algorithm: &str,
211    sphere_seed: u64,
212    ray_seed: u64,
213    device: &B::Device,
214) -> (BenchmarkResult, RayCastResult<B, 3>) {
215    let region = make_random_spheres::<B>(sphere_count, sphere_seed, device);
216    let (origins, directions) = random_rays::<B>(ray_count, ray_seed, device);
217    let algebra = Algebra::default();
218    let rays = Rays::new(origins, directions);
219    let parsed_algorithm = parse_algorithm(algorithm);
220
221    let start_time = std::time::Instant::now();
222    let result = region.ray_cast(rays, &algebra, parsed_algorithm);
223    let elapsed = start_time.elapsed();
224
225    let benchmark_result =
226        BenchmarkResult::new(&result, elapsed, sphere_count, ray_count, algorithm);
227    (benchmark_result, result)
228}