1use crate::analysis::rays::prelude::*;
4use crate::csg::prelude::*;
5use burn::prelude::*;
6use burn::tensor::cast::ToElement;
7
8use burn::tensor::backend::AutodiffBackend;
9use rand::rngs::StdRng;
10use rand::{Rng, SeedableRng};
11use std::fmt;
12
13#[derive(Debug, Clone)]
15pub struct BenchmarkResult {
16 pub algorithm: String,
17 pub sphere_count: usize,
18 pub ray_count: usize,
19 pub successful_hits: usize,
20 pub hit_rate: f32,
21 pub average_distance: f32,
22 pub ray_cast_time: std::time::Duration,
23 pub rays_per_second: f64,
24}
25
26impl BenchmarkResult {
27 pub fn new<B: Backend>(
29 result: &RayCastResult<B, 3>,
30 elapsed: std::time::Duration,
31 sphere_count: usize,
32 ray_count: usize,
33 algorithm: &str,
34 ) -> Self {
35 let successful_mask = result.distances().lower_elem(f32::MAX);
37 let successful_hits: usize = successful_mask.clone().int().sum().into_scalar().to_usize();
38
39 let distances = result.distances();
41 let valid_distances = distances.clone().mask_fill(successful_mask.bool_not(), 0.0);
42 let average_distance = if successful_hits > 0 {
43 valid_distances.sum().into_scalar().to_f32() / successful_hits as f32
44 } else {
45 0.0
46 };
47
48 let hit_rate = (successful_hits as f32 / ray_count as f32) * 100.0;
49 let rays_per_second = ray_count as f64 / elapsed.as_secs_f64();
50
51 Self {
52 algorithm: algorithm.to_string(),
53 sphere_count,
54 ray_count,
55 successful_hits,
56 hit_rate,
57 average_distance,
58 ray_cast_time: elapsed,
59 rays_per_second,
60 }
61 }
62}
63
64impl fmt::Display for BenchmarkResult {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 writeln!(
67 f,
68 "Results for {} algorithm with {} spheres and {} rays:",
69 self.algorithm, self.sphere_count, self.ray_count
70 )?;
71 writeln!(f, " Total rays: {}", self.ray_count)?;
72 writeln!(f, " Successful hits: {}", self.successful_hits)?;
73 writeln!(f, " Hit rate: {:.2}%", self.hit_rate)?;
74 writeln!(f, " Average distance: {:.4}", self.average_distance)?;
75 writeln!(f, " Ray-cast time: {:.2?}", self.ray_cast_time)?;
76 writeln!(f, " Rays per second: {:.0}", self.rays_per_second)?;
77 Ok(())
78 }
79}
80
81pub fn make_random_spheres<B: Backend>(
83 num_spheres: usize,
84 seed: u64,
85 device: &B::Device,
86) -> Region<3, B> {
87 let mut rng = StdRng::seed_from_u64(seed);
88 let mut spheres = Vec::new();
89
90 for _ in 0..num_spheres {
91 let center = [
93 rng.random_range(-0.8..=0.8),
94 rng.random_range(-0.8..=0.8),
95 rng.random_range(-0.8..=0.8),
96 ];
97
98 let radius = rng.random_range(0.1..=0.3);
100
101 let sphere = Field3D::<B>::sphere(radius, device.clone())
103 .into_isosurface(0.0)
104 .transform(Translate(center))
105 .region();
106
107 spheres.push(sphere);
108 }
109
110 let mut union = spheres.remove(0);
112 for sphere in spheres {
113 union = &union | &sphere;
114 }
115
116 union
117}
118
119pub fn random_rays<B: Backend>(
121 batch_size: usize,
122 seed: u64,
123 device: &B::Device,
124) -> (Tensor<B, 2, Float>, Tensor<B, 2, Float>) {
125 let mut rng = StdRng::seed_from_u64(seed);
126
127 let mut directions = Vec::with_capacity(batch_size * 3);
129 for _ in 0..batch_size {
130 let mut v = [0.0f32; 3];
131 loop {
132 for val in &mut v {
133 *val = rng.random_range(-1.0..=1.0);
134 }
135 let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
136 if norm > 1e-6 {
137 for val in &v {
138 directions.push(val / norm);
139 }
140 break;
141 }
142 }
143 }
144
145 let mut origins = Vec::with_capacity(batch_size * 3);
147 for _ in 0..batch_size {
148 let x = rng.random_range(-1.0..=1.0);
149 let y = rng.random_range(-1.0..=1.0);
150 let z = rng.random_range(-1.0..=1.0);
151 origins.push(x);
152 origins.push(y);
153 origins.push(z);
154 }
155
156 let directions =
157 Tensor::<B, 1, Float>::from_data(directions.as_slice(), device).reshape([batch_size, 3]);
158 let origins =
159 Tensor::<B, 1, Float>::from_data(origins.as_slice(), device).reshape([batch_size, 3]);
160 (origins, directions)
161}
162
163pub fn parse_algorithm(algorithm: &str) -> RayCastAlgorithm<3> {
165 match algorithm.to_lowercase().as_str() {
166 "analytical" => RayCastAlgorithm::<3>::Analytical,
167 "march_and_bisect" => RayCastAlgorithm::<3>::BracketAndBisect {
168 lambda: 10.0,
169 d_lambda: f32::EPSILON * 1e6,
170 max_bisection_iterations: 30,
171 },
172 "newton" => RayCastAlgorithm::<3>::Newton {
173 max_iterations: 250,
174 nudge_distance: 0.01,
175 step_size: 0.1,
176 },
177 _ => {
178 eprintln!("Unknown algorithm: {}. Using analytical.", algorithm);
179 RayCastAlgorithm::<3>::Analytical
180 }
181 }
182}
183
184pub fn run_raycast_benchmark<B: AutodiffBackend>(
186 sphere_count: usize,
187 ray_count: usize,
188 algorithm: &str,
189 sphere_seed: u64,
190 ray_seed: u64,
191 device: &B::Device,
192) -> BenchmarkResult {
193 let region = make_random_spheres::<B>(sphere_count, sphere_seed, device);
194 let (origins, directions) = random_rays::<B>(ray_count, ray_seed, device);
195 let algebra = Algebra::default();
196 let rays = Rays::new(origins, directions);
197 let parsed_algorithm = parse_algorithm(algorithm);
198
199 let start_time = std::time::Instant::now();
200 let result = region.ray_cast(rays, &algebra, parsed_algorithm);
201 let elapsed = start_time.elapsed();
202
203 BenchmarkResult::new(&result, elapsed, sphere_count, ray_count, algorithm)
204}
205
206pub fn run_raycast_benchmark_with_result<B: AutodiffBackend>(
208 sphere_count: usize,
209 ray_count: usize,
210 algorithm: &str,
211 sphere_seed: u64,
212 ray_seed: u64,
213 device: &B::Device,
214) -> (BenchmarkResult, RayCastResult<B, 3>) {
215 let region = make_random_spheres::<B>(sphere_count, sphere_seed, device);
216 let (origins, directions) = random_rays::<B>(ray_count, ray_seed, device);
217 let algebra = Algebra::default();
218 let rays = Rays::new(origins, directions);
219 let parsed_algorithm = parse_algorithm(algorithm);
220
221 let start_time = std::time::Instant::now();
222 let result = region.ray_cast(rays, &algebra, parsed_algorithm);
223 let elapsed = start_time.elapsed();
224
225 let benchmark_result =
226 BenchmarkResult::new(&result, elapsed, sphere_count, ray_count, algorithm);
227 (benchmark_result, result)
228}