Skip to main content

gaussian/
gaussian.rs

1use cuneus::compute::{ComputeShader, ComputeShaderBuilder, PassDescription, StorageBufferSpec, COMPUTE_TEXTURE_FORMAT_RGBA16};
2use cuneus::{Core, RenderKit, ShaderApp, ShaderControls, ShaderManager};
3use cuneus::{ExportManager, UniformProvider};
4use log::error;
5use cuneus::WindowEvent;
6
/// Training and visualisation parameters uploaded as the custom uniform block
/// of the compute shader (registered via `with_custom_uniforms::<GaussianParams>()`
/// in `init`).
///
/// The struct is `#[repr(C)]` + `Pod` and is uploaded byte-for-byte, so field
/// order and padding must match the uniform struct declared in
/// `shaders/gaussian.wgsl` (NOTE(review): confirm against the shader when
/// editing). Boolean flags are stored as `u32` (0 = off, non-zero = on).
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
struct GaussianParams {
    /// Number of Gaussians in use ("N Gauss" slider, 100..=20000). The storage
    /// buffers allocated in `init` hold at most 20000 entries.
    num_gaussians: u32,
    /// Learning rate for Gaussian positions ("pos LR").
    learning_rate: f32,
    /// Learning rate for Gaussian colours ("col LR").
    color_learning_rate: f32,
    /// Non-zero requests a training restart. The CPU clears it again in
    /// `render` before the parameters are re-uploaded.
    reset_training: u32,
    /// "Show Target" checkbox flag.
    show_target: u32,
    /// "Show Error" checkbox flag.
    show_error: u32,
    /// "temp" slider value (0.1..=5.0); its exact role is defined by the shader.
    temperature: f32,
    /// "Error Scale" slider value (0.5..=10.0), only editable while `show_error` is on.
    error_scale: f32,
    /// Lower sigma bound ("Min Sigma"); presumably used as a clamp by the shader.
    min_sigma: f32,
    /// Upper sigma bound ("Max Sigma"); presumably used as a clamp by the shader.
    max_sigma: f32,
    /// "Position" slider value (0.0..=1.0); its exact role is defined by the shader.
    position_noise: f32,
    /// Seed for the shader's randomness. Changing it via the dice button also
    /// triggers a reset.
    random_seed: u32,
    /// Training iteration counter, advanced by the CPU once per `update` while
    /// no reset is pending.
    iteration: u32,
    /// Learning rate for Gaussian sigmas ("Sigma LR").
    sigma_learning_rate: f32,
    // 14 four-byte fields above + 2 padding words = 64 bytes, presumably to keep
    // the size a multiple of 16 bytes as WGSL uniform layout requires.
    _padding0: u32,
    _padding1: u32,
}
27
28impl Default for GaussianParams {
29    fn default() -> Self {
30        Self {
31            num_gaussians: 20000,
32
33            learning_rate: 0.01,
34
35
36            color_learning_rate: 0.008,
37
38            reset_training: 0,
39            show_target: 0,
40            show_error: 0,
41
42
43            temperature: 1.0,
44
45            error_scale: 2.0,
46
47            min_sigma: 0.001,
48
49            max_sigma: 0.05,
50
51            position_noise: 0.5,
52
53            random_seed: 42,
54            iteration: 0,
55
56            sigma_learning_rate: 0.001,
57
58            _padding0: 0,
59            _padding1: 0,
60        }
61    }
62}
63
64impl UniformProvider for GaussianParams {
65    fn as_bytes(&self) -> &[u8] {
66        bytemuck::bytes_of(self)
67    }
68}
69
/// 2D Gaussian-splatting trainer: fits a set of Gaussians to the current media
/// texture using a multi-pass compute shader and displays the result.
struct GaussianShader {
    /// Shared cuneus state: controls, media textures, UI rendering, export
    /// manager and the final blit renderer.
    base: RenderKit,
    /// Multi-pass training pipeline (init / clear gradients / render + backprop / Adam update).
    compute_shader: ComputeShader,
    /// CPU-side mirror of the parameters most recently uploaded to the GPU.
    current_params: GaussianParams,
}
75
76impl ShaderManager for GaussianShader {
77    fn init(core: &Core) -> Self {
78        let base = RenderKit::new(core);
79
80        // 1. init_gaussians: Initialize/reset Gaussian parameters
81        // 2. clear_gradients: Clear gradient buffer before each iteration
82        // 3. render_display: Render Gaussians + compute gradients via backprop
83        // 4. update_gaussians: Adam to update parameters
84        let passes = vec![
85            PassDescription::new("init_gaussians", &[]),
86            PassDescription::new("clear_gradients", &[]),
87            PassDescription::new("render_display", &[]),
88            PassDescription::new("update_gaussians", &[]),
89        ];
90
91        // Storage buffers for training
92        // Each Gaussian: position(2f32) + sigma(3f32) + color(3f32) + opacity(1f32) = 9 f32 (gradient data)
93        // GaussianData struct: 10 f32 (includes padding)
94        let max_gaussians = 20000u32;
95        let gaussian_buffer_size = (max_gaussians * 40) as u64;
96        let gradient_buffer_size = (max_gaussians * 36) as u64;
97        let adam_buffer_size = (max_gaussians * 36) as u64;
98
99        let config = ComputeShaderBuilder::new()
100            .with_label("Gaussian Splatting Training")
101            .with_multi_pass(&passes)
102            .with_channels(1)
103            .with_custom_uniforms::<GaussianParams>()
104            .with_storage_buffer(StorageBufferSpec::new("gaussian_params", gaussian_buffer_size))
105            .with_storage_buffer(StorageBufferSpec::new("gradient_buffer", gradient_buffer_size))
106            .with_storage_buffer(StorageBufferSpec::new("adam_first_moment", adam_buffer_size))
107            .with_storage_buffer(StorageBufferSpec::new("adam_second_moment", adam_buffer_size))
108            .with_texture_format(COMPUTE_TEXTURE_FORMAT_RGBA16)
109            .build();
110
111        let compute_shader = cuneus::compute_shader!(core, "shaders/gaussian.wgsl", config);
112
113        let initial_params = GaussianParams::default();
114        let shader = Self {
115            base,
116            compute_shader,
117            current_params: initial_params,
118        };
119
120        shader
121            .compute_shader
122            .set_custom_params(initial_params, &core.queue);
123
124        shader
125    }
126
127    fn update(&mut self, core: &Core) {
128        let current_time = self.base.controls.get_time(&self.base.start_time);
129        let delta = 1.0 / 60.0;
130        self.compute_shader
131            .set_time(current_time, delta, &core.queue);
132
133        // Update target texture from media
134        self.base.update_current_texture(core, &core.queue);
135        if let Some(texture_manager) = self.base.get_current_texture_manager() {
136            self.compute_shader.update_channel_texture(
137                0,
138                &texture_manager.view,
139                &texture_manager.sampler,
140                &core.device,
141                &core.queue,
142            );
143        }
144
145        // Auto-increment iteration counter
146        if self.current_params.reset_training == 0 {
147            self.current_params.iteration = self.current_params.iteration.wrapping_add(1);
148            self.compute_shader.set_custom_params(self.current_params, &core.queue);
149        }
150        self.compute_shader.handle_export(core, &mut self.base);
151    }
152
153    fn resize(&mut self, core: &Core) {
154        self.base.default_resize(core, &mut self.compute_shader);
155    }
156
157    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
158        let mut frame = self.base.begin_frame(core)?;
159
160        let mut controls_request = self
161            .base
162            .controls
163            .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
164
165        let mut params = self.current_params;
166        let mut changed = false;
167        let mut should_start_export = false;
168        let mut export_request = self.base.export_manager.get_ui_request();
169
170        let using_video_texture = self.base.using_video_texture;
171        let using_hdri_texture = self.base.using_hdri_texture;
172        let using_webcam_texture = self.base.using_webcam_texture;
173        let video_info = self.base.get_video_info();
174        let hdri_info = self.base.get_hdri_info();
175        let webcam_info = self.base.get_webcam_info();
176
177        let full_output = if self.base.key_handler.show_ui {
178            self.base.render_ui(core, |ctx| {
179                RenderKit::apply_default_style(ctx);
180
181                egui::Window::new("gaussian splatting")
182                    .collapsible(true)
183                    .resizable(true)
184                    .default_width(280.0)
185                    .show(ctx, |ui| {
186                        ui.label(format!("Iteration: {}", params.iteration));
187
188                        egui::CollapsingHeader::new("Training")
189                            .default_open(false)
190                            .show(ui, |ui| {
191                                changed |= ui
192                                    .add(
193                                        egui::Slider::new(&mut params.num_gaussians, 100..=20000)
194                                            .text("N Gauss")
195                                            .logarithmic(true),
196                                    )
197                                    .changed();
198
199                                changed |= ui
200                                    .add(
201                                        egui::Slider::new(&mut params.learning_rate, 0.0001..=0.1)
202                                            .text("pos LR")
203                                            .logarithmic(true),
204                                    )
205                                    .changed();
206
207                                changed |= ui
208                                    .add(
209                                        egui::Slider::new(&mut params.color_learning_rate, 0.001..=0.2)
210                                            .text("col LR")
211                                            .logarithmic(true),
212                                    )
213                                    .changed();
214
215                                changed |= ui
216                                    .add(
217                                        egui::Slider::new(&mut params.temperature, 0.1..=5.0)
218                                            .text("temp"),
219                                    )
220                                    .changed();
221
222                                ui.separator();
223
224                                ui.horizontal(|ui| {
225                                    changed |= ui
226                                        .add(
227                                            egui::Slider::new(&mut params.random_seed, 1..=10000)
228                                                .text("seed"),
229                                        )
230                                        .changed();
231                                    if ui.button("🎲").on_hover_text("Randomize seed").clicked() {
232                                        params.random_seed = (std::time::SystemTime::now()
233                                            .duration_since(std::time::UNIX_EPOCH)
234                                            .unwrap()
235                                            .as_millis() % 10000) as u32;
236                                        params.reset_training = 1;
237                                        changed = true;
238                                    }
239                                });
240
241                                if ui.button("res training").clicked() {
242                                    params.reset_training = 1;
243                                    params.iteration = 0;
244                                    changed = true;
245                                }
246                            });
247
248                        egui::CollapsingHeader::new("vis")
249                            .default_open(false)
250                            .show(ui, |ui| {
251                                let mut show_target = params.show_target != 0;
252                                if ui.checkbox(&mut show_target, "Show Target").changed() {
253                                    params.show_target = if show_target { 1 } else { 0 };
254                                    changed = true;
255                                }
256
257                                let mut show_error = params.show_error != 0;
258                                if ui.checkbox(&mut show_error, "Show Error").changed() {
259                                    params.show_error = if show_error { 1 } else { 0 };
260                                    changed = true;
261                                }
262
263                                if params.show_error != 0 {
264                                    changed |= ui
265                                        .add(
266                                            egui::Slider::new(&mut params.error_scale, 0.5..=10.0)
267                                                .text("Error Scale"),
268                                        )
269                                        .changed();
270                                }
271                            });
272
273                        egui::CollapsingHeader::new("Advanced")
274                            .default_open(false)
275                            .show(ui, |ui| {
276                                changed |= ui
277                                    .add(
278                                        egui::Slider::new(&mut params.sigma_learning_rate, 0.001..=0.1)
279                                            .text("Sigma LR")
280                                            .logarithmic(true),
281                                    )
282                                    .changed();
283
284                                changed |= ui
285                                    .add(
286                                        egui::Slider::new(&mut params.min_sigma, 0.001..=0.05)
287                                            .text("Min Sigma")
288                                            .logarithmic(true),
289                                    )
290                                    .changed();
291
292                                changed |= ui
293                                    .add(
294                                        egui::Slider::new(&mut params.max_sigma, 0.02..=0.3)
295                                            .text("Max Sigma")
296                                            .logarithmic(true),
297                                    )
298                                    .changed();
299
300                                changed |= ui
301                                    .add(
302                                        egui::Slider::new(&mut params.position_noise, 0.0..=1.0)
303                                            .text("Position"),
304                                    )
305                                    .changed();
306                            });
307
308                        ui.separator();
309
310                        ui.separator();
311
312                        ShaderControls::render_media_panel(
313                            ui,
314                            &mut controls_request,
315                            using_video_texture,
316                            video_info,
317                            using_hdri_texture,
318                            hdri_info,
319                            using_webcam_texture,
320                            webcam_info,
321                        );
322
323                        ui.separator();
324                        ShaderControls::render_controls_widget(ui, &mut controls_request);
325
326                        ui.separator();
327                        should_start_export =
328                            ExportManager::render_export_ui_widget(ui, &mut export_request);
329                    });
330            })
331        } else {
332            self.base.render_ui(core, |_ctx| {})
333        };
334
335        self.base.export_manager.apply_ui_request(export_request);
336        self.base.apply_media_requests(core, &controls_request);
337
338        if controls_request.should_clear_buffers || params.reset_training != 0 {
339            self.compute_shader.current_frame = 0;
340            self.compute_shader.time_uniform.data.frame = 0;
341            self.compute_shader.time_uniform.update(&core.queue);
342
343            params.iteration = 0;
344            params.reset_training = 0;
345            changed = true;
346        }
347
348        if changed {
349            self.current_params = params;
350            self.compute_shader.set_custom_params(params, &core.queue);
351        }
352
353        if should_start_export {
354            self.base.export_manager.start_export();
355        }
356
357
358        self.compute_shader.dispatch(&mut frame.encoder, core);
359
360        self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
361
362        self.base.end_frame(core, frame, full_output);
363
364        Ok(())
365    }
366
367    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
368        if self.base.default_handle_input(core, event) {
369            return true;
370        }
371        if let WindowEvent::DroppedFile(path) = event {
372            if let Err(e) = self.base.load_media(core, path) {
373                error!("Failed to load dropped file: {e:?}");
374            }
375            self.current_params.reset_training = 1;
376            self.current_params.iteration = 0;
377            self.compute_shader.set_custom_params(self.current_params, &core.queue);
378            return true;
379        }
380        false
381    }
382}
383
384fn main() -> Result<(), Box<dyn std::error::Error>> {
385    cuneus::gst::init()?;
386    env_logger::init();
387    let (app, event_loop) = ShaderApp::new("2D Gaussian Splatting", 450, 350);
388    app.run(event_loop, GaussianShader::init)
389}