1use crate::infer::GraphExt;
19use crate::op::Op;
20use crate::{DType, Graph, NodeId, Shape};
21
22#[derive(Clone, Debug)]
24pub struct GaussianSplatInputs {
25 pub positions: NodeId,
26 pub scales: NodeId,
27 pub rotations: NodeId,
28 pub opacities: NodeId,
29 pub colors: NodeId,
30 pub sh_coeffs: NodeId,
31 pub meta: NodeId,
32}
33
/// Forward rasterisation settings shared by the render, prepare and rasterize ops.
#[derive(Debug, Clone, Copy)]
pub struct GaussianSplatRenderParams {
    /// Output image width in pixels.
    pub width: u32,
    /// Output image height in pixels.
    pub height: u32,
    /// Side length, in pixels, of the square screen tiles the image is split into.
    pub tile_size: u32,
    /// Multiplier forwarded to the kernels as `radius_scale` (semantics defined by the kernel).
    pub radius_scale: f32,
    /// Alpha threshold forwarded to the kernels as `alpha_cutoff`.
    pub alpha_cutoff: f32,
    /// Step limit forwarded to the kernels as `max_splat_steps`.
    pub max_splat_steps: u32,
    /// Threshold forwarded to the kernels as `transmittance_threshold`.
    pub transmittance_threshold: f32,
    /// Number of list slots reserved in the prepare buffer.
    pub max_list_entries: u32,
}
46
47#[derive(Clone, Copy, Debug)]
49pub struct GaussianSplatBackwardParams {
50 pub render: GaussianSplatRenderParams,
51 pub loss_grad_clip: f32,
52 pub sh_band: u32,
53 pub max_anisotropy: f32,
54}
55
56impl Default for GaussianSplatBackwardParams {
57 fn default() -> Self {
58 Self {
59 render: GaussianSplatRenderParams::default(),
60 loss_grad_clip: 1.0,
61 sh_band: 0,
62 max_anisotropy: 10.0,
63 }
64 }
65}
66
/// Number of parameter slots reserved in the prepare-pass buffer, counted by
/// `gaussian_splat_prep_packed_len` in addition to the per-splat, list, tile and pixel
/// sections.
pub const GAUSSIAN_SPLAT_PREP_RASTER_PARAMS_FLOATS: usize = 11;
69
/// Number of square `tile_size`-pixel tiles needed to cover a `width` x `height` image.
///
/// Partial tiles along the right and bottom edges count as whole tiles, so a 65x64 image with
/// 16-pixel tiles needs 5 x 4 = 20 tiles.
///
/// # Panics
///
/// Panics if `tile_size` is zero, or if the tile count does not fit in a `u32`. Previously the
/// latter wrapped silently in release builds and produced an undersized buffer length.
pub fn gaussian_splat_tile_count(width: u32, height: u32, tile_size: u32) -> u32 {
    assert!(tile_size > 0, "gaussian splat tile_size must be non-zero");
    let tiles_x = width.div_ceil(tile_size);
    let tiles_y = height.div_ceil(tile_size);
    tiles_x
        .checked_mul(tiles_y)
        .expect("gaussian splat tile count overflows u32")
}
76
77pub fn gaussian_splat_prep_packed_len(
79 count: usize,
80 max_list_entries: u32,
81 width: u32,
82 height: u32,
83 tile_size: u32,
84) -> usize {
85 let n = count.max(1);
86 let max_list = max_list_entries as usize;
87 let tiles = gaussian_splat_tile_count(width, height, tile_size) as usize;
88 let pixels = (width as usize).saturating_mul(height as usize).max(1);
89 n * 4
90 + n
91 + n * 3
92 + n * 3
93 + n * 4
94 + max_list
95 + tiles * 2
96 + pixels * 3
97 + GAUSSIAN_SPLAT_PREP_RASTER_PARAMS_FLOATS
98}
99
/// Number of elements in the packed gradient buffer produced by the backward splat op.
///
/// Per Gaussian the buffer holds:
/// - 3 position, 3 scale, 4 rotation, 1 opacity and 3 colour entries,
/// - `sh_coeff_count.max(1) * 3` SH-coefficient entries.
///
/// This matches the slices taken by `unpack_gaussian_splat_packed_grads`.
pub fn gaussian_splat_packed_grad_len(count: usize, sh_coeff_count: usize) -> usize {
    const POSITION: usize = 3;
    const SCALE: usize = 3;
    const ROTATION: usize = 4;
    const OPACITY: usize = 1;
    const COLOR: usize = 3;

    let fixed_section = count * (POSITION + SCALE + ROTATION + OPACITY + COLOR);
    let sh_section = count * sh_coeff_count.max(1) * 3;
    fixed_section + sh_section
}
104
105pub fn unpack_gaussian_splat_packed_grads(
107 g: &mut Graph,
108 packed: NodeId,
109 count: usize,
110 sh_coeff_count: usize,
111) -> GaussianSplatInputs {
112 let mut off = 0usize;
113 let mut take = |len: usize| -> NodeId {
114 let id = g.narrow_(packed, 0, off, len);
115 off += len;
116 id
117 };
118 let positions = take(count * 3);
119 let scales = take(count * 3);
120 let rotations = take(count * 4);
121 let opacities = take(count);
122 let colors = take(count * 3);
123 let sh_coeffs = take(count * sh_coeff_count.max(1) * 3);
124 let _ = off;
125 GaussianSplatInputs {
126 positions,
127 scales,
128 rotations,
129 opacities,
130 colors,
131 sh_coeffs,
132 meta: packed,
133 }
134}
135
136impl Default for GaussianSplatRenderParams {
137 fn default() -> Self {
138 Self {
139 width: 64,
140 height: 64,
141 tile_size: 16,
142 radius_scale: 1.6,
143 alpha_cutoff: 1.0 / 255.0,
144 max_splat_steps: 32,
145 transmittance_threshold: 0.01,
146 max_list_entries: 18 * 32,
147 }
148 }
149}
150
151impl Graph {
152 pub fn gaussian_splat_render(
157 &mut self,
158 inputs: GaussianSplatInputs,
159 params: GaussianSplatRenderParams,
160 ) -> NodeId {
161 let out_elems = (params.width as usize) * (params.height as usize) * 4;
162 let dtype = self.shape(inputs.positions).dtype();
163 let out_shape = Shape::new(&[out_elems], dtype);
164 self.push(
165 Op::GaussianSplatRender {
166 width: params.width,
167 height: params.height,
168 tile_size: params.tile_size,
169 radius_scale: params.radius_scale,
170 alpha_cutoff: params.alpha_cutoff,
171 max_splat_steps: params.max_splat_steps,
172 transmittance_threshold: params.transmittance_threshold,
173 max_list_entries: params.max_list_entries,
174 },
175 vec![
176 inputs.positions,
177 inputs.scales,
178 inputs.rotations,
179 inputs.opacities,
180 inputs.colors,
181 inputs.sh_coeffs,
182 inputs.meta,
183 ],
184 out_shape,
185 None,
186 )
187 }
188
189 pub fn gaussian_splat_render_meta(
191 &mut self,
192 camera_position: [f32; 3],
193 camera_target: [f32; 3],
194 camera_up: [f32; 3],
195 fov_y_degrees: f32,
196 near: f32,
197 far: f32,
198 background: [f32; 3],
199 params: GaussianSplatRenderParams,
200 ) -> NodeId {
201 let values = vec![
202 camera_position[0],
203 camera_position[1],
204 camera_position[2],
205 camera_target[0],
206 camera_target[1],
207 camera_target[2],
208 camera_up[0],
209 camera_up[1],
210 camera_up[2],
211 fov_y_degrees,
212 near,
213 far,
214 background[0],
215 background[1],
216 background[2],
217 params.width as f32,
218 params.height as f32,
219 params.tile_size as f32,
220 params.radius_scale,
221 params.alpha_cutoff,
222 params.max_splat_steps as f32,
223 params.transmittance_threshold,
224 params.max_list_entries as f32,
225 ];
226 let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
227 self.add_node(
228 Op::Constant { data: bytes },
229 vec![],
230 Shape::new(&[23], DType::F32),
231 )
232 }
233
234 pub fn gaussian_splat_prepare(
236 &mut self,
237 inputs: GaussianSplatInputs,
238 params: GaussianSplatRenderParams,
239 ) -> NodeId {
240 let count = self.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
241 let packed_len = gaussian_splat_prep_packed_len(
242 count,
243 params.max_list_entries,
244 params.width,
245 params.height,
246 params.tile_size,
247 );
248 let dtype = self.shape(inputs.positions).dtype();
249 self.push(
250 Op::GaussianSplatPrepare {
251 width: params.width,
252 height: params.height,
253 tile_size: params.tile_size,
254 radius_scale: params.radius_scale,
255 alpha_cutoff: params.alpha_cutoff,
256 max_splat_steps: params.max_splat_steps,
257 transmittance_threshold: params.transmittance_threshold,
258 max_list_entries: params.max_list_entries,
259 },
260 vec![
261 inputs.positions,
262 inputs.scales,
263 inputs.rotations,
264 inputs.opacities,
265 inputs.colors,
266 inputs.sh_coeffs,
267 inputs.meta,
268 ],
269 Shape::new(&[packed_len], dtype),
270 None,
271 )
272 }
273
274 pub fn gaussian_splat_rasterize(
276 &mut self,
277 prep: NodeId,
278 meta: NodeId,
279 params: GaussianSplatRenderParams,
280 ) -> NodeId {
281 let out_elems = (params.width as usize) * (params.height as usize) * 4;
282 let dtype = self.shape(prep).dtype();
283 self.push(
284 Op::GaussianSplatRasterize {
285 width: params.width,
286 height: params.height,
287 tile_size: params.tile_size,
288 alpha_cutoff: params.alpha_cutoff,
289 max_splat_steps: params.max_splat_steps,
290 transmittance_threshold: params.transmittance_threshold,
291 max_list_entries: params.max_list_entries,
292 },
293 vec![prep, meta],
294 Shape::new(&[out_elems], dtype),
295 None,
296 )
297 }
298
299 pub fn gaussian_splat_render_decomposed(
301 &mut self,
302 inputs: GaussianSplatInputs,
303 params: GaussianSplatRenderParams,
304 ) -> NodeId {
305 let meta = inputs.meta;
306 let prep = self.gaussian_splat_prepare(inputs, params);
307 self.gaussian_splat_rasterize(prep, meta, params)
308 }
309
310 pub fn gaussian_splat_render_backward(
312 &mut self,
313 inputs: GaussianSplatInputs,
314 d_loss_rgba: NodeId,
315 params: GaussianSplatBackwardParams,
316 ) -> NodeId {
317 let count = self.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
318 let sh_len = self.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
319 let sh_coeff_count = if count == 0 {
320 1
321 } else {
322 (sh_len / (count * 3)).max(1)
323 };
324 let packed_len = gaussian_splat_packed_grad_len(count, sh_coeff_count);
325 let dtype = self.shape(inputs.positions).dtype();
326 let r = params.render;
327 self.push(
328 Op::GaussianSplatRenderBackward {
329 width: r.width,
330 height: r.height,
331 tile_size: r.tile_size,
332 radius_scale: r.radius_scale,
333 alpha_cutoff: r.alpha_cutoff,
334 max_splat_steps: r.max_splat_steps,
335 transmittance_threshold: r.transmittance_threshold,
336 max_list_entries: r.max_list_entries,
337 loss_grad_clip: params.loss_grad_clip,
338 sh_band: params.sh_band,
339 max_anisotropy: params.max_anisotropy,
340 },
341 vec![
342 inputs.positions,
343 inputs.scales,
344 inputs.rotations,
345 inputs.opacities,
346 inputs.colors,
347 inputs.sh_coeffs,
348 inputs.meta,
349 d_loss_rgba,
350 ],
351 Shape::new(&[packed_len], dtype),
352 None,
353 )
354 }
355}