1use std::sync::OnceLock;
18
19type RenderExec = Box<dyn Fn(ArenaRenderArgs) + Send + Sync>;
20type RenderBwdExec = Box<dyn Fn(ArenaRenderBwdArgs) + Send + Sync>;
21type PrepareExec = Box<dyn Fn(ArenaPrepareArgs) + Send + Sync>;
22type RasterizeExec = Box<dyn Fn(ArenaRasterizeArgs) + Send + Sync>;
23type HostRenderExec = Box<dyn Fn(HostRenderArgs) -> Vec<f32> + Send + Sync>;
24type HostBackwardExec = Box<dyn Fn(HostBackwardArgs) -> Vec<f32> + Send + Sync>;
25
26static RENDER: OnceLock<RenderExec> = OnceLock::new();
27static RENDER_BWD: OnceLock<RenderBwdExec> = OnceLock::new();
28static PREPARE: OnceLock<PrepareExec> = OnceLock::new();
29static RASTERIZE: OnceLock<RasterizeExec> = OnceLock::new();
30static HOST_RENDER: OnceLock<HostRenderExec> = OnceLock::new();
31static HOST_BACKWARD: OnceLock<HostBackwardExec> = OnceLock::new();
32
/// Arguments handed to the registered arena render executor ([`RenderExec`]).
///
/// Each `*_off` / `*_len` pair presumably locates one region inside the arena
/// addressed by `base`. Whether the values count bytes or `f32` elements is up to
/// the registered executor, so confirm this against its implementation.
/// Unlike [`ArenaPrepareArgs`], no length is carried for the `meta` region.
pub struct ArenaRenderArgs {
    // Per-Gaussian input regions.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    // Metadata region (offset only).
    pub meta_off: usize,
    // Output (`dst`) region, presumably written by the executor.
    pub dst_off: usize,
    pub dst_len: usize,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Raw arena pointer; the offsets above are presumably relative to it.
    pub base: *mut u8,
}
61
/// Arguments handed to the registered arena prepare executor ([`PrepareExec`]).
///
/// Every `*_off` / `*_len` pair presumably describes one region of the arena
/// behind `base`. The unit (bytes vs. `f32` elements) is defined by the
/// registered executor, so confirm it there.
pub struct ArenaPrepareArgs {
    // Per-Gaussian input regions.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    // Metadata region.
    pub meta_off: usize,
    pub meta_len: usize,
    // Prepared-data (`prep`) region, presumably the executor's output.
    pub prep_off: usize,
    pub prep_len: usize,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Raw arena pointer; the offsets above are presumably relative to it.
    pub base: *mut u8,
}
90
/// Arguments handed to the registered arena rasterize executor ([`RasterizeExec`]).
///
/// The `*_off` / `*_len` pairs presumably locate regions inside the arena behind
/// `base`. The registered executor defines whether they count bytes or elements.
pub struct ArenaRasterizeArgs {
    // Prepared-data (`prep`) input region.
    pub prep_off: usize,
    pub prep_len: usize,
    // Metadata region.
    pub meta_off: usize,
    pub meta_len: usize,
    // Output (`dst`) region, presumably written by the executor.
    pub dst_off: usize,
    pub dst_len: usize,
    /// Item count forwarded to the executor (presumably the number of prepared splats).
    pub count: usize,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Raw arena pointer; the offsets above are presumably relative to it.
    pub base: *mut u8,
}
109
/// Arguments handed to the registered arena render-backward executor ([`RenderBwdExec`]).
///
/// Each `*_off` / `*_len` pair presumably locates a region inside the arena
/// behind `base`. The registered executor decides whether they count bytes or
/// `f32` elements. As with [`ArenaRenderArgs`], the `meta` region has no length.
pub struct ArenaRenderBwdArgs {
    // Per-Gaussian input regions.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    // Metadata region (offset only).
    pub meta_off: usize,
    // `d_loss` region, presumably the upstream loss gradient.
    pub d_loss_off: usize,
    pub d_loss_len: usize,
    // `packed` region, presumably where the executor writes its results.
    pub packed_off: usize,
    pub packed_len: usize,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    pub loss_grad_clip: f32,
    pub sh_band: u32,
    pub max_anisotropy: f32,
    /// Raw arena pointer; the offsets above are presumably relative to it.
    pub base: *mut u8,
}
142
/// Owned inputs handed to the registered host render executor ([`HostRenderExec`]).
///
/// The buffers are plain owned `Vec<f32>` copies rather than arena offsets, so
/// the executor can consume them without touching raw memory.
pub struct HostRenderArgs {
    // Per-Gaussian attribute buffers.
    pub positions: Vec<f32>,
    pub scales: Vec<f32>,
    pub rotations: Vec<f32>,
    pub opacities: Vec<f32>,
    pub colors: Vec<f32>,
    pub sh_coeffs: Vec<f32>,
    // Metadata buffer (layout defined by the executor).
    pub meta: Vec<f32>,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
}
161
/// Owned inputs handed to the registered host backward executor ([`HostBackwardExec`]).
///
/// Mirrors [`HostRenderArgs`]. It adds the `d_loss_rgba` buffer (presumably the
/// per-pixel loss gradient) and the backward-only tuning knobs.
pub struct HostBackwardArgs {
    // Per-Gaussian attribute buffers.
    pub positions: Vec<f32>,
    pub scales: Vec<f32>,
    pub rotations: Vec<f32>,
    pub opacities: Vec<f32>,
    pub colors: Vec<f32>,
    pub sh_coeffs: Vec<f32>,
    // Metadata buffer (layout defined by the executor).
    pub meta: Vec<f32>,
    // Incoming gradient buffer.
    pub d_loss_rgba: Vec<f32>,
    // Image and tiling configuration.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    // Tuning knobs forwarded verbatim to the executor.
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    pub loss_grad_clip: f32,
    pub sh_band: u32,
    pub max_anisotropy: f32,
}
184
185pub fn register_splat_executors(
187 render: RenderExec,
188 backward: RenderBwdExec,
189 prepare: PrepareExec,
190 rasterize: RasterizeExec,
191 host_render: HostRenderExec,
192 host_backward: HostBackwardExec,
193) {
194 let _ = RENDER.set(render);
195 let _ = RENDER_BWD.set(backward);
196 let _ = PREPARE.set(prepare);
197 let _ = RASTERIZE.set(rasterize);
198 let _ = HOST_RENDER.set(host_render);
199 let _ = HOST_BACKWARD.set(host_backward);
200}
201
202#[allow(clippy::too_many_arguments)]
203pub fn render_host_slices(
204 positions: &[f32],
205 scales: &[f32],
206 rotations: &[f32],
207 opacities: &[f32],
208 colors: &[f32],
209 sh_coeffs: &[f32],
210 meta: &[f32],
211 width: u32,
212 height: u32,
213 tile_size: u32,
214 radius_scale: f32,
215 alpha_cutoff: f32,
216 max_splat_steps: u32,
217 transmittance_threshold: f32,
218 max_list_entries: u32,
219) -> Vec<f32> {
220 HOST_RENDER
221 .get()
222 .expect("call `rlx_splat::register()` before host splat render")(HostRenderArgs {
223 positions: positions.to_vec(),
224 scales: scales.to_vec(),
225 rotations: rotations.to_vec(),
226 opacities: opacities.to_vec(),
227 colors: colors.to_vec(),
228 sh_coeffs: sh_coeffs.to_vec(),
229 meta: meta.to_vec(),
230 width,
231 height,
232 tile_size,
233 radius_scale,
234 alpha_cutoff,
235 max_splat_steps,
236 transmittance_threshold,
237 max_list_entries,
238 })
239}
240
241#[allow(clippy::too_many_arguments)]
242pub fn backward_host_slices(
243 positions: &[f32],
244 scales: &[f32],
245 rotations: &[f32],
246 opacities: &[f32],
247 colors: &[f32],
248 sh_coeffs: &[f32],
249 meta: &[f32],
250 d_loss_rgba: &[f32],
251 width: u32,
252 height: u32,
253 tile_size: u32,
254 radius_scale: f32,
255 alpha_cutoff: f32,
256 max_splat_steps: u32,
257 transmittance_threshold: f32,
258 max_list_entries: u32,
259 loss_grad_clip: f32,
260 sh_band: u32,
261 max_anisotropy: f32,
262) -> Vec<f32> {
263 HOST_BACKWARD
264 .get()
265 .expect("call `rlx_splat::register()` before host splat backward")(HostBackwardArgs {
266 positions: positions.to_vec(),
267 scales: scales.to_vec(),
268 rotations: rotations.to_vec(),
269 opacities: opacities.to_vec(),
270 colors: colors.to_vec(),
271 sh_coeffs: sh_coeffs.to_vec(),
272 meta: meta.to_vec(),
273 d_loss_rgba: d_loss_rgba.to_vec(),
274 width,
275 height,
276 tile_size,
277 radius_scale,
278 alpha_cutoff,
279 max_splat_steps,
280 transmittance_threshold,
281 max_list_entries,
282 loss_grad_clip,
283 sh_band,
284 max_anisotropy,
285 })
286}
287
288#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
290pub unsafe fn execute_gaussian_splat_prepare(
291 positions_off: usize,
292 positions_len: usize,
293 scales_off: usize,
294 scales_len: usize,
295 rotations_off: usize,
296 rotations_len: usize,
297 opacities_off: usize,
298 opacities_len: usize,
299 colors_off: usize,
300 colors_len: usize,
301 sh_coeffs_off: usize,
302 sh_coeffs_len: usize,
303 meta_off: usize,
304 meta_len: usize,
305 prep_off: usize,
306 prep_len: usize,
307 width: u32,
308 height: u32,
309 tile_size: u32,
310 radius_scale: f32,
311 alpha_cutoff: f32,
312 max_splat_steps: u32,
313 transmittance_threshold: f32,
314 max_list_entries: u32,
315 base: *mut u8,
316) {
317 PREPARE
318 .get()
319 .expect("call `rlx_splat::register()` before GaussianSplatPrepare")(ArenaPrepareArgs {
320 positions_off,
321 positions_len,
322 scales_off,
323 scales_len,
324 rotations_off,
325 rotations_len,
326 opacities_off,
327 opacities_len,
328 colors_off,
329 colors_len,
330 sh_coeffs_off,
331 sh_coeffs_len,
332 meta_off,
333 meta_len,
334 prep_off,
335 prep_len,
336 width,
337 height,
338 tile_size,
339 radius_scale,
340 alpha_cutoff,
341 max_splat_steps,
342 transmittance_threshold,
343 max_list_entries,
344 base,
345 });
346}
347
348#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
350pub unsafe fn execute_gaussian_splat_rasterize(
351 prep_off: usize,
352 prep_len: usize,
353 meta_off: usize,
354 meta_len: usize,
355 dst_off: usize,
356 dst_len: usize,
357 count: usize,
358 width: u32,
359 height: u32,
360 tile_size: u32,
361 alpha_cutoff: f32,
362 max_splat_steps: u32,
363 transmittance_threshold: f32,
364 max_list_entries: u32,
365 base: *mut u8,
366) {
367 RASTERIZE
368 .get()
369 .expect("call `rlx_splat::register()` before GaussianSplatRasterize")(
370 ArenaRasterizeArgs {
371 prep_off,
372 prep_len,
373 meta_off,
374 meta_len,
375 dst_off,
376 dst_len,
377 count,
378 width,
379 height,
380 tile_size,
381 alpha_cutoff,
382 max_splat_steps,
383 transmittance_threshold,
384 max_list_entries,
385 base,
386 },
387 );
388}
389
390#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
392pub unsafe fn execute_gaussian_splat_render(
393 positions_off: usize,
394 positions_len: usize,
395 scales_off: usize,
396 scales_len: usize,
397 rotations_off: usize,
398 rotations_len: usize,
399 opacities_off: usize,
400 opacities_len: usize,
401 colors_off: usize,
402 colors_len: usize,
403 sh_coeffs_off: usize,
404 sh_coeffs_len: usize,
405 meta_off: usize,
406 dst_off: usize,
407 dst_len: usize,
408 width: u32,
409 height: u32,
410 tile_size: u32,
411 radius_scale: f32,
412 alpha_cutoff: f32,
413 max_splat_steps: u32,
414 transmittance_threshold: f32,
415 max_list_entries: u32,
416 base: *mut u8,
417) {
418 RENDER
419 .get()
420 .expect("call `rlx_splat::register()` before GaussianSplatRender")(ArenaRenderArgs {
421 positions_off,
422 positions_len,
423 scales_off,
424 scales_len,
425 rotations_off,
426 rotations_len,
427 opacities_off,
428 opacities_len,
429 colors_off,
430 colors_len,
431 sh_coeffs_off,
432 sh_coeffs_len,
433 meta_off,
434 dst_off,
435 dst_len,
436 width,
437 height,
438 tile_size,
439 radius_scale,
440 alpha_cutoff,
441 max_splat_steps,
442 transmittance_threshold,
443 max_list_entries,
444 base,
445 });
446}
447
448#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
450pub unsafe fn execute_gaussian_splat_render_backward(
451 positions_off: usize,
452 positions_len: usize,
453 scales_off: usize,
454 scales_len: usize,
455 rotations_off: usize,
456 rotations_len: usize,
457 opacities_off: usize,
458 opacities_len: usize,
459 colors_off: usize,
460 colors_len: usize,
461 sh_coeffs_off: usize,
462 sh_coeffs_len: usize,
463 meta_off: usize,
464 d_loss_off: usize,
465 d_loss_len: usize,
466 packed_off: usize,
467 packed_len: usize,
468 width: u32,
469 height: u32,
470 tile_size: u32,
471 radius_scale: f32,
472 alpha_cutoff: f32,
473 max_splat_steps: u32,
474 transmittance_threshold: f32,
475 max_list_entries: u32,
476 loss_grad_clip: f32,
477 sh_band: u32,
478 max_anisotropy: f32,
479 base: *mut u8,
480) {
481 RENDER_BWD
482 .get()
483 .expect("call `rlx_splat::register()` before GaussianSplatRenderBackward")(
484 ArenaRenderBwdArgs {
485 positions_off,
486 positions_len,
487 scales_off,
488 scales_len,
489 rotations_off,
490 rotations_len,
491 opacities_off,
492 opacities_len,
493 colors_off,
494 colors_len,
495 sh_coeffs_off,
496 sh_coeffs_len,
497 meta_off,
498 d_loss_off,
499 d_loss_len,
500 packed_off,
501 packed_len,
502 width,
503 height,
504 tile_size,
505 radius_scale,
506 alpha_cutoff,
507 max_splat_steps,
508 transmittance_threshold,
509 max_list_entries,
510 loss_grad_clip,
511 sh_band,
512 max_anisotropy,
513 base,
514 },
515 );
516}