Skip to main content

rlx_cpu/
splat.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! CPU dispatch hooks for the Gaussian-splat ops ([`rlx_ir::Op::GaussianSplatRender`] and its
//! prepare, rasterize and backward variants); the executor bodies are registered from `rlx-splat`.
16
17use std::sync::OnceLock;
18
19type RenderExec = Box<dyn Fn(ArenaRenderArgs) + Send + Sync>;
20type RenderBwdExec = Box<dyn Fn(ArenaRenderBwdArgs) + Send + Sync>;
21type PrepareExec = Box<dyn Fn(ArenaPrepareArgs) + Send + Sync>;
22type RasterizeExec = Box<dyn Fn(ArenaRasterizeArgs) + Send + Sync>;
23type HostRenderExec = Box<dyn Fn(HostRenderArgs) -> Vec<f32> + Send + Sync>;
24type HostBackwardExec = Box<dyn Fn(HostBackwardArgs) -> Vec<f32> + Send + Sync>;
25
26static RENDER: OnceLock<RenderExec> = OnceLock::new();
27static RENDER_BWD: OnceLock<RenderBwdExec> = OnceLock::new();
28static PREPARE: OnceLock<PrepareExec> = OnceLock::new();
29static RASTERIZE: OnceLock<RasterizeExec> = OnceLock::new();
30static HOST_RENDER: OnceLock<HostRenderExec> = OnceLock::new();
31static HOST_BACKWARD: OnceLock<HostBackwardExec> = OnceLock::new();
32
/// Arena arguments for the forward splat ([`rlx_ir::Op::GaussianSplatRender`]).
///
/// Each buffer is described by an `*_off` / `*_len` pair, presumably relative to `base`.
/// NOTE(review): whether these count bytes or `f32` elements is decided by the executor
/// registered from `rlx-splat`; confirm there. Unlike [`ArenaPrepareArgs`], no `meta_len`
/// is carried for the `meta` buffer.
#[derive(Debug, Clone, Copy)]
pub struct ArenaRenderArgs {
    // Input buffers.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    pub meta_off: usize,
    // Destination buffer (presumably the rendered image).
    pub dst_off: usize,
    pub dst_len: usize,
    // Render settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Base pointer of the arena the buffers above live in.
    pub base: *mut u8,
}
61
/// Arena arguments for [`rlx_ir::Op::GaussianSplatPrepare`].
///
/// Each buffer is described by an `*_off` / `*_len` pair, presumably relative to `base`.
/// NOTE(review): offset/length units (bytes vs `f32` elements) are defined by the executor
/// registered from `rlx-splat`; confirm there.
#[derive(Debug, Clone, Copy)]
pub struct ArenaPrepareArgs {
    // Input buffers.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    pub meta_off: usize,
    pub meta_len: usize,
    // Prepared-data buffer (presumably the output later consumed by rasterize).
    pub prep_off: usize,
    pub prep_len: usize,
    // Render settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Base pointer of the arena the buffers above live in.
    pub base: *mut u8,
}
90
/// Arena arguments for [`rlx_ir::Op::GaussianSplatRasterize`].
///
/// Each buffer is described by an `*_off` / `*_len` pair, presumably relative to `base`.
/// NOTE(review): offset/length units and the meaning of `count` are defined by the executor
/// registered from `rlx-splat`; confirm there.
#[derive(Debug, Clone, Copy)]
pub struct ArenaRasterizeArgs {
    // Prepared-data and metadata inputs.
    pub prep_off: usize,
    pub prep_len: usize,
    pub meta_off: usize,
    pub meta_len: usize,
    // Destination buffer.
    pub dst_off: usize,
    pub dst_len: usize,
    pub count: usize,
    // Render settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    /// Base pointer of the arena the buffers above live in.
    pub base: *mut u8,
}
109
/// Arena arguments for the backward splat ([`rlx_ir::Op::GaussianSplatRenderBackward`]).
///
/// Each buffer is described by an `*_off` / `*_len` pair, presumably relative to `base`.
/// NOTE(review): offset/length units (bytes vs `f32` elements) and the layout of the `packed`
/// buffer are defined by the executor registered from `rlx-splat`; confirm there. As with
/// [`ArenaRenderArgs`], no `meta_len` is carried.
#[derive(Debug, Clone, Copy)]
pub struct ArenaRenderBwdArgs {
    // Forward-pass input buffers.
    pub positions_off: usize,
    pub positions_len: usize,
    pub scales_off: usize,
    pub scales_len: usize,
    pub rotations_off: usize,
    pub rotations_len: usize,
    pub opacities_off: usize,
    pub opacities_len: usize,
    pub colors_off: usize,
    pub colors_len: usize,
    pub sh_coeffs_off: usize,
    pub sh_coeffs_len: usize,
    pub meta_off: usize,
    // Incoming loss gradient (presumably d(loss)/d(image)).
    pub d_loss_off: usize,
    pub d_loss_len: usize,
    // Packed destination buffer (presumably the concatenated parameter gradients).
    pub packed_off: usize,
    pub packed_len: usize,
    // Render and gradient settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    pub loss_grad_clip: f32,
    pub sh_band: u32,
    pub max_anisotropy: f32,
    /// Base pointer of the arena the buffers above live in.
    pub base: *mut u8,
}
142
/// Owned host-buffer inputs for the forward splat.
///
/// Built by [`render_host_slices`], which copies its borrowed slices into these
/// vectors before handing the struct to the registered host render executor.
#[derive(Debug, Clone)]
pub struct HostRenderArgs {
    // Input buffers.
    pub positions: Vec<f32>,
    pub scales: Vec<f32>,
    pub rotations: Vec<f32>,
    pub opacities: Vec<f32>,
    pub colors: Vec<f32>,
    pub sh_coeffs: Vec<f32>,
    pub meta: Vec<f32>,
    // Render settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
}
161
/// Owned host-buffer inputs for the backward splat.
///
/// Built by [`backward_host_slices`], which copies its borrowed slices into these
/// vectors before handing the struct to the registered host backward executor.
#[derive(Debug, Clone)]
pub struct HostBackwardArgs {
    // Forward-pass input buffers.
    pub positions: Vec<f32>,
    pub scales: Vec<f32>,
    pub rotations: Vec<f32>,
    pub opacities: Vec<f32>,
    pub colors: Vec<f32>,
    pub sh_coeffs: Vec<f32>,
    pub meta: Vec<f32>,
    // Incoming loss gradient (presumably per-pixel RGBA, per the name; confirm).
    pub d_loss_rgba: Vec<f32>,
    // Render and gradient settings, passed through to the executor unchanged.
    pub width: u32,
    pub height: u32,
    pub tile_size: u32,
    pub radius_scale: f32,
    pub alpha_cutoff: f32,
    pub max_splat_steps: u32,
    pub transmittance_threshold: f32,
    pub max_list_entries: u32,
    pub loss_grad_clip: f32,
    pub sh_band: u32,
    pub max_anisotropy: f32,
}
184
185/// Register arena + host splat executors (`rlx_splat::register()`).
186pub fn register_splat_executors(
187    render: RenderExec,
188    backward: RenderBwdExec,
189    prepare: PrepareExec,
190    rasterize: RasterizeExec,
191    host_render: HostRenderExec,
192    host_backward: HostBackwardExec,
193) {
194    let _ = RENDER.set(render);
195    let _ = RENDER_BWD.set(backward);
196    let _ = PREPARE.set(prepare);
197    let _ = RASTERIZE.set(rasterize);
198    let _ = HOST_RENDER.set(host_render);
199    let _ = HOST_BACKWARD.set(host_backward);
200}
201
202#[allow(clippy::too_many_arguments)]
203pub fn render_host_slices(
204    positions: &[f32],
205    scales: &[f32],
206    rotations: &[f32],
207    opacities: &[f32],
208    colors: &[f32],
209    sh_coeffs: &[f32],
210    meta: &[f32],
211    width: u32,
212    height: u32,
213    tile_size: u32,
214    radius_scale: f32,
215    alpha_cutoff: f32,
216    max_splat_steps: u32,
217    transmittance_threshold: f32,
218    max_list_entries: u32,
219) -> Vec<f32> {
220    HOST_RENDER
221        .get()
222        .expect("call `rlx_splat::register()` before host splat render")(HostRenderArgs {
223        positions: positions.to_vec(),
224        scales: scales.to_vec(),
225        rotations: rotations.to_vec(),
226        opacities: opacities.to_vec(),
227        colors: colors.to_vec(),
228        sh_coeffs: sh_coeffs.to_vec(),
229        meta: meta.to_vec(),
230        width,
231        height,
232        tile_size,
233        radius_scale,
234        alpha_cutoff,
235        max_splat_steps,
236        transmittance_threshold,
237        max_list_entries,
238    })
239}
240
241#[allow(clippy::too_many_arguments)]
242pub fn backward_host_slices(
243    positions: &[f32],
244    scales: &[f32],
245    rotations: &[f32],
246    opacities: &[f32],
247    colors: &[f32],
248    sh_coeffs: &[f32],
249    meta: &[f32],
250    d_loss_rgba: &[f32],
251    width: u32,
252    height: u32,
253    tile_size: u32,
254    radius_scale: f32,
255    alpha_cutoff: f32,
256    max_splat_steps: u32,
257    transmittance_threshold: f32,
258    max_list_entries: u32,
259    loss_grad_clip: f32,
260    sh_band: u32,
261    max_anisotropy: f32,
262) -> Vec<f32> {
263    HOST_BACKWARD
264        .get()
265        .expect("call `rlx_splat::register()` before host splat backward")(HostBackwardArgs {
266        positions: positions.to_vec(),
267        scales: scales.to_vec(),
268        rotations: rotations.to_vec(),
269        opacities: opacities.to_vec(),
270        colors: colors.to_vec(),
271        sh_coeffs: sh_coeffs.to_vec(),
272        meta: meta.to_vec(),
273        d_loss_rgba: d_loss_rgba.to_vec(),
274        width,
275        height,
276        tile_size,
277        radius_scale,
278        alpha_cutoff,
279        max_splat_steps,
280        transmittance_threshold,
281        max_list_entries,
282        loss_grad_clip,
283        sh_band,
284        max_anisotropy,
285    })
286}
287
288/// Execute [`Op::GaussianSplatPrepare`].
289#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
290pub unsafe fn execute_gaussian_splat_prepare(
291    positions_off: usize,
292    positions_len: usize,
293    scales_off: usize,
294    scales_len: usize,
295    rotations_off: usize,
296    rotations_len: usize,
297    opacities_off: usize,
298    opacities_len: usize,
299    colors_off: usize,
300    colors_len: usize,
301    sh_coeffs_off: usize,
302    sh_coeffs_len: usize,
303    meta_off: usize,
304    meta_len: usize,
305    prep_off: usize,
306    prep_len: usize,
307    width: u32,
308    height: u32,
309    tile_size: u32,
310    radius_scale: f32,
311    alpha_cutoff: f32,
312    max_splat_steps: u32,
313    transmittance_threshold: f32,
314    max_list_entries: u32,
315    base: *mut u8,
316) {
317    PREPARE
318        .get()
319        .expect("call `rlx_splat::register()` before GaussianSplatPrepare")(ArenaPrepareArgs {
320        positions_off,
321        positions_len,
322        scales_off,
323        scales_len,
324        rotations_off,
325        rotations_len,
326        opacities_off,
327        opacities_len,
328        colors_off,
329        colors_len,
330        sh_coeffs_off,
331        sh_coeffs_len,
332        meta_off,
333        meta_len,
334        prep_off,
335        prep_len,
336        width,
337        height,
338        tile_size,
339        radius_scale,
340        alpha_cutoff,
341        max_splat_steps,
342        transmittance_threshold,
343        max_list_entries,
344        base,
345    });
346}
347
348/// Execute [`Op::GaussianSplatRasterize`].
349#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
350pub unsafe fn execute_gaussian_splat_rasterize(
351    prep_off: usize,
352    prep_len: usize,
353    meta_off: usize,
354    meta_len: usize,
355    dst_off: usize,
356    dst_len: usize,
357    count: usize,
358    width: u32,
359    height: u32,
360    tile_size: u32,
361    alpha_cutoff: f32,
362    max_splat_steps: u32,
363    transmittance_threshold: f32,
364    max_list_entries: u32,
365    base: *mut u8,
366) {
367    RASTERIZE
368        .get()
369        .expect("call `rlx_splat::register()` before GaussianSplatRasterize")(
370        ArenaRasterizeArgs {
371            prep_off,
372            prep_len,
373            meta_off,
374            meta_len,
375            dst_off,
376            dst_len,
377            count,
378            width,
379            height,
380            tile_size,
381            alpha_cutoff,
382            max_splat_steps,
383            transmittance_threshold,
384            max_list_entries,
385            base,
386        },
387    );
388}
389
390/// Execute [`Op::GaussianSplatRender`] against the arena `base` pointer.
391#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
392pub unsafe fn execute_gaussian_splat_render(
393    positions_off: usize,
394    positions_len: usize,
395    scales_off: usize,
396    scales_len: usize,
397    rotations_off: usize,
398    rotations_len: usize,
399    opacities_off: usize,
400    opacities_len: usize,
401    colors_off: usize,
402    colors_len: usize,
403    sh_coeffs_off: usize,
404    sh_coeffs_len: usize,
405    meta_off: usize,
406    dst_off: usize,
407    dst_len: usize,
408    width: u32,
409    height: u32,
410    tile_size: u32,
411    radius_scale: f32,
412    alpha_cutoff: f32,
413    max_splat_steps: u32,
414    transmittance_threshold: f32,
415    max_list_entries: u32,
416    base: *mut u8,
417) {
418    RENDER
419        .get()
420        .expect("call `rlx_splat::register()` before GaussianSplatRender")(ArenaRenderArgs {
421        positions_off,
422        positions_len,
423        scales_off,
424        scales_len,
425        rotations_off,
426        rotations_len,
427        opacities_off,
428        opacities_len,
429        colors_off,
430        colors_len,
431        sh_coeffs_off,
432        sh_coeffs_len,
433        meta_off,
434        dst_off,
435        dst_len,
436        width,
437        height,
438        tile_size,
439        radius_scale,
440        alpha_cutoff,
441        max_splat_steps,
442        transmittance_threshold,
443        max_list_entries,
444        base,
445    });
446}
447
448/// Execute [`Op::GaussianSplatRenderBackward`].
449#[allow(unsafe_op_in_unsafe_fn, clippy::too_many_arguments)]
450pub unsafe fn execute_gaussian_splat_render_backward(
451    positions_off: usize,
452    positions_len: usize,
453    scales_off: usize,
454    scales_len: usize,
455    rotations_off: usize,
456    rotations_len: usize,
457    opacities_off: usize,
458    opacities_len: usize,
459    colors_off: usize,
460    colors_len: usize,
461    sh_coeffs_off: usize,
462    sh_coeffs_len: usize,
463    meta_off: usize,
464    d_loss_off: usize,
465    d_loss_len: usize,
466    packed_off: usize,
467    packed_len: usize,
468    width: u32,
469    height: u32,
470    tile_size: u32,
471    radius_scale: f32,
472    alpha_cutoff: f32,
473    max_splat_steps: u32,
474    transmittance_threshold: f32,
475    max_list_entries: u32,
476    loss_grad_clip: f32,
477    sh_band: u32,
478    max_anisotropy: f32,
479    base: *mut u8,
480) {
481    RENDER_BWD
482        .get()
483        .expect("call `rlx_splat::register()` before GaussianSplatRenderBackward")(
484        ArenaRenderBwdArgs {
485            positions_off,
486            positions_len,
487            scales_off,
488            scales_len,
489            rotations_off,
490            rotations_len,
491            opacities_off,
492            opacities_len,
493            colors_off,
494            colors_len,
495            sh_coeffs_off,
496            sh_coeffs_len,
497            meta_off,
498            d_loss_off,
499            d_loss_len,
500            packed_off,
501            packed_len,
502            width,
503            height,
504            tile_size,
505            radius_scale,
506            alpha_cutoff,
507            max_splat_steps,
508            transmittance_threshold,
509            max_list_entries,
510            loss_grad_clip,
511            sh_band,
512            max_anisotropy,
513            base,
514        },
515    );
516}