// rlx_ir/ops/splat.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! 3D Gaussian splatting graph builders.
17
18use crate::infer::GraphExt;
19use crate::op::Op;
20use crate::{DType, Graph, NodeId, Shape};
21
22/// Packed scene + camera tensors for [`Op::GaussianSplatRender`].
23#[derive(Clone, Debug)]
24pub struct GaussianSplatInputs {
25    pub positions: NodeId,
26    pub scales: NodeId,
27    pub rotations: NodeId,
28    pub opacities: NodeId,
29    pub colors: NodeId,
30    pub sh_coeffs: NodeId,
31    pub meta: NodeId,
32}
33
/// Framebuffer and rasterization settings for the Gaussian splat ops.
///
/// These values are copied into the op variants themselves; they are not passed as
/// separate tensors.
#[derive(Debug, Clone, Copy)]
pub struct GaussianSplatRenderParams {
    /// Framebuffer width in pixels (forward output holds `width * height * 4` values).
    pub width: u32,
    /// Framebuffer height in pixels.
    pub height: u32,
    /// Edge length, in pixels, of the square screen tiles (see [`gaussian_splat_tile_count`]).
    pub tile_size: u32,
    /// Forwarded to the render / prepare / backward ops (not to rasterize); presumably
    /// scales each splat's screen-space radius — confirm against the backend.
    pub radius_scale: f32,
    /// Forwarded to every splat op (default `1/255`); presumably the minimum alpha a
    /// splat must reach to contribute — confirm against the backend.
    pub alpha_cutoff: f32,
    /// Forwarded to every splat op (default `32`); presumably a per-pixel iteration cap.
    pub max_splat_steps: u32,
    /// Forwarded to every splat op (default `0.01`); presumably an early-termination
    /// transmittance threshold — confirm against the backend.
    pub transmittance_threshold: f32,
    /// Length of the list section reserved in the packed prepare buffer
    /// (see [`gaussian_splat_prep_packed_len`]).
    pub max_list_entries: u32,
}
46
47/// Training backward parameters for [`Op::GaussianSplatRenderBackward`].
48#[derive(Clone, Copy, Debug)]
49pub struct GaussianSplatBackwardParams {
50    pub render: GaussianSplatRenderParams,
51    pub loss_grad_clip: f32,
52    pub sh_band: u32,
53    pub max_anisotropy: f32,
54}
55
56impl Default for GaussianSplatBackwardParams {
57    fn default() -> Self {
58        Self {
59            render: GaussianSplatRenderParams::default(),
60            loss_grad_clip: 1.0,
61            sh_band: 0,
62            max_anisotropy: 10.0,
63        }
64    }
65}
66
/// Number of trailing raster-parameter floats (`width`, `height`, …) at the end of a
/// packed prepare buffer; counted by [`gaussian_splat_prep_packed_len`].
///
/// NOTE(review): the exact order of the 11 fields is defined by the packer
/// (`rlx_splat::prep_layout`), which is not visible here — keep the two in sync.
pub const GAUSSIAN_SPLAT_PREP_RASTER_PARAMS_FLOATS: usize = 11;
69
/// Number of `tile_size × tile_size` screen tiles covering a `width × height`
/// framebuffer, rounding partial tiles up on each axis.
///
/// Must match `rlx_splat::prep_layout::tile_count`.
///
/// # Panics
///
/// Panics if `tile_size` is zero. In debug builds it also panics if the tile count
/// overflows `u32`.
pub fn gaussian_splat_tile_count(width: u32, height: u32, tile_size: u32) -> u32 {
    [width, height]
        .iter()
        .map(|&extent| extent.div_ceil(tile_size))
        .product()
}
76
77/// Packed prepare-buffer length for `N` splats (must match `rlx_splat::prep_layout::pack_prepared`).
78pub fn gaussian_splat_prep_packed_len(
79    count: usize,
80    max_list_entries: u32,
81    width: u32,
82    height: u32,
83    tile_size: u32,
84) -> usize {
85    let n = count.max(1);
86    let max_list = max_list_entries as usize;
87    let tiles = gaussian_splat_tile_count(width, height, tile_size) as usize;
88    let pixels = (width as usize).saturating_mul(height as usize).max(1);
89    n * 4
90        + n
91        + n * 3
92        + n * 3
93        + n * 4
94        + max_list
95        + tiles * 2
96        + pixels * 3
97        + GAUSSIAN_SPLAT_PREP_RASTER_PARAMS_FLOATS
98}
99
/// Length of the packed scene-gradient buffer for `count` splats.
///
/// Each splat contributes:
/// - 14 fixed values: position 3, scale 3, rotation 4, opacity 1, color 3;
/// - 3 values per SH coefficient, with `sh_coeff_count` clamped to at least 1.
pub fn gaussian_splat_packed_grad_len(count: usize, sh_coeff_count: usize) -> usize {
    // positions (3) + scales (3) + rotations (4) + opacities (1) + colors (3)
    const FIXED_PER_SPLAT: usize = 3 + 3 + 4 + 1 + 3;
    let fixed_total = count * FIXED_PER_SPLAT;
    let sh_total = count * sh_coeff_count.max(1) * 3;
    fixed_total + sh_total
}
104
105/// Unpack [`Op::GaussianSplatRenderBackward`] output into per-parameter gradients.
106pub fn unpack_gaussian_splat_packed_grads(
107    g: &mut Graph,
108    packed: NodeId,
109    count: usize,
110    sh_coeff_count: usize,
111) -> GaussianSplatInputs {
112    let mut off = 0usize;
113    let mut take = |len: usize| -> NodeId {
114        let id = g.narrow_(packed, 0, off, len);
115        off += len;
116        id
117    };
118    let positions = take(count * 3);
119    let scales = take(count * 3);
120    let rotations = take(count * 4);
121    let opacities = take(count);
122    let colors = take(count * 3);
123    let sh_coeffs = take(count * sh_coeff_count.max(1) * 3);
124    let _ = off;
125    GaussianSplatInputs {
126        positions,
127        scales,
128        rotations,
129        opacities,
130        colors,
131        sh_coeffs,
132        meta: packed,
133    }
134}
135
136impl Default for GaussianSplatRenderParams {
137    fn default() -> Self {
138        Self {
139            width: 64,
140            height: 64,
141            tile_size: 16,
142            radius_scale: 1.6,
143            alpha_cutoff: 1.0 / 255.0,
144            max_splat_steps: 32,
145            transmittance_threshold: 0.01,
146            max_list_entries: 18 * 32,
147        }
148    }
149}
150
151impl Graph {
152    /// First-class CPU reference Gaussian splat forward render.
153    ///
154    /// See [`Op::GaussianSplatRender`] for the seven-input contract and
155    /// [`GaussianSplatRenderParams`] for framebuffer settings.
156    pub fn gaussian_splat_render(
157        &mut self,
158        inputs: GaussianSplatInputs,
159        params: GaussianSplatRenderParams,
160    ) -> NodeId {
161        let out_elems = (params.width as usize) * (params.height as usize) * 4;
162        let dtype = self.shape(inputs.positions).dtype();
163        let out_shape = Shape::new(&[out_elems], dtype);
164        self.push(
165            Op::GaussianSplatRender {
166                width: params.width,
167                height: params.height,
168                tile_size: params.tile_size,
169                radius_scale: params.radius_scale,
170                alpha_cutoff: params.alpha_cutoff,
171                max_splat_steps: params.max_splat_steps,
172                transmittance_threshold: params.transmittance_threshold,
173                max_list_entries: params.max_list_entries,
174            },
175            vec![
176                inputs.positions,
177                inputs.scales,
178                inputs.rotations,
179                inputs.opacities,
180                inputs.colors,
181                inputs.sh_coeffs,
182                inputs.meta,
183            ],
184            out_shape,
185            None,
186        )
187    }
188
189    /// Build the 23-float `meta` vector expected by [`Op::GaussianSplatRender`].
190    pub fn gaussian_splat_render_meta(
191        &mut self,
192        camera_position: [f32; 3],
193        camera_target: [f32; 3],
194        camera_up: [f32; 3],
195        fov_y_degrees: f32,
196        near: f32,
197        far: f32,
198        background: [f32; 3],
199        params: GaussianSplatRenderParams,
200    ) -> NodeId {
201        let values = vec![
202            camera_position[0],
203            camera_position[1],
204            camera_position[2],
205            camera_target[0],
206            camera_target[1],
207            camera_target[2],
208            camera_up[0],
209            camera_up[1],
210            camera_up[2],
211            fov_y_degrees,
212            near,
213            far,
214            background[0],
215            background[1],
216            background[2],
217            params.width as f32,
218            params.height as f32,
219            params.tile_size as f32,
220            params.radius_scale,
221            params.alpha_cutoff,
222            params.max_splat_steps as f32,
223            params.transmittance_threshold,
224            params.max_list_entries as f32,
225        ];
226        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
227        self.add_node(
228            Op::Constant { data: bytes },
229            vec![],
230            Shape::new(&[23], DType::F32),
231        )
232    }
233
234    /// Strict IR stage 1: project + bin + sort + rays → packed prepare buffer.
235    pub fn gaussian_splat_prepare(
236        &mut self,
237        inputs: GaussianSplatInputs,
238        params: GaussianSplatRenderParams,
239    ) -> NodeId {
240        let count = self.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
241        let packed_len = gaussian_splat_prep_packed_len(
242            count,
243            params.max_list_entries,
244            params.width,
245            params.height,
246            params.tile_size,
247        );
248        let dtype = self.shape(inputs.positions).dtype();
249        self.push(
250            Op::GaussianSplatPrepare {
251                width: params.width,
252                height: params.height,
253                tile_size: params.tile_size,
254                radius_scale: params.radius_scale,
255                alpha_cutoff: params.alpha_cutoff,
256                max_splat_steps: params.max_splat_steps,
257                transmittance_threshold: params.transmittance_threshold,
258                max_list_entries: params.max_list_entries,
259            },
260            vec![
261                inputs.positions,
262                inputs.scales,
263                inputs.rotations,
264                inputs.opacities,
265                inputs.colors,
266                inputs.sh_coeffs,
267                inputs.meta,
268            ],
269            Shape::new(&[packed_len], dtype),
270            None,
271        )
272    }
273
274    /// Strict IR stage 2: rasterize from prepare buffer + meta.
275    pub fn gaussian_splat_rasterize(
276        &mut self,
277        prep: NodeId,
278        meta: NodeId,
279        params: GaussianSplatRenderParams,
280    ) -> NodeId {
281        let out_elems = (params.width as usize) * (params.height as usize) * 4;
282        let dtype = self.shape(prep).dtype();
283        self.push(
284            Op::GaussianSplatRasterize {
285                width: params.width,
286                height: params.height,
287                tile_size: params.tile_size,
288                alpha_cutoff: params.alpha_cutoff,
289                max_splat_steps: params.max_splat_steps,
290                transmittance_threshold: params.transmittance_threshold,
291                max_list_entries: params.max_list_entries,
292            },
293            vec![prep, meta],
294            Shape::new(&[out_elems], dtype),
295            None,
296        )
297    }
298
299    /// Decomposed strict-IR forward: prepare → rasterize.
300    pub fn gaussian_splat_render_decomposed(
301        &mut self,
302        inputs: GaussianSplatInputs,
303        params: GaussianSplatRenderParams,
304    ) -> NodeId {
305        let meta = inputs.meta;
306        let prep = self.gaussian_splat_prepare(inputs, params);
307        self.gaussian_splat_rasterize(prep, meta, params)
308    }
309
310    /// Backward pass for [`Op::GaussianSplatRender`] (packed scene gradients).
311    pub fn gaussian_splat_render_backward(
312        &mut self,
313        inputs: GaussianSplatInputs,
314        d_loss_rgba: NodeId,
315        params: GaussianSplatBackwardParams,
316    ) -> NodeId {
317        let count = self.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
318        let sh_len = self.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
319        let sh_coeff_count = if count == 0 {
320            1
321        } else {
322            (sh_len / (count * 3)).max(1)
323        };
324        let packed_len = gaussian_splat_packed_grad_len(count, sh_coeff_count);
325        let dtype = self.shape(inputs.positions).dtype();
326        let r = params.render;
327        self.push(
328            Op::GaussianSplatRenderBackward {
329                width: r.width,
330                height: r.height,
331                tile_size: r.tile_size,
332                radius_scale: r.radius_scale,
333                alpha_cutoff: r.alpha_cutoff,
334                max_splat_steps: r.max_splat_steps,
335                transmittance_threshold: r.transmittance_threshold,
336                max_list_entries: r.max_list_entries,
337                loss_grad_clip: params.loss_grad_clip,
338                sh_band: params.sh_band,
339                max_anisotropy: params.max_anisotropy,
340            },
341            vec![
342                inputs.positions,
343                inputs.scales,
344                inputs.rotations,
345                inputs.opacities,
346                inputs.colors,
347                inputs.sh_coeffs,
348                inputs.meta,
349                d_loss_rgba,
350            ],
351            Shape::new(&[packed_len], dtype),
352            None,
353        )
354    }
355}