1use cuneus::compute::{ComputeShader, ComputeShaderBuilder, PassDescription, StorageBufferSpec, COMPUTE_TEXTURE_FORMAT_RGBA16};
2use cuneus::{Core, RenderKit, ShaderApp, ShaderControls, ShaderManager};
3use cuneus::{ExportManager, UniformProvider};
4use log::error;
5use cuneus::WindowEvent;
6
7#[repr(C)]
8#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
9struct GaussianParams {
10 num_gaussians: u32,
11 learning_rate: f32,
12 color_learning_rate: f32,
13 reset_training: u32,
14 show_target: u32,
15 show_error: u32,
16 temperature: f32,
17 error_scale: f32,
18 min_sigma: f32,
19 max_sigma: f32,
20 position_noise: f32,
21 random_seed: u32,
22 iteration: u32,
23 sigma_learning_rate: f32,
24 _padding0: u32,
25 _padding1: u32,
26}
27
28impl Default for GaussianParams {
29 fn default() -> Self {
30 Self {
31 num_gaussians: 20000,
32
33 learning_rate: 0.01,
34
35
36 color_learning_rate: 0.008,
37
38 reset_training: 0,
39 show_target: 0,
40 show_error: 0,
41
42
43 temperature: 1.0,
44
45 error_scale: 2.0,
46
47 min_sigma: 0.001,
48
49 max_sigma: 0.05,
50
51 position_noise: 0.5,
52
53 random_seed: 42,
54 iteration: 0,
55
56 sigma_learning_rate: 0.001,
57
58 _padding0: 0,
59 _padding1: 0,
60 }
61 }
62}
63
64impl UniformProvider for GaussianParams {
65 fn as_bytes(&self) -> &[u8] {
66 bytemuck::bytes_of(self)
67 }
68}
69
/// Application state for the 2D Gaussian-splatting trainer.
///
/// Field declaration order is also drop order; keep it unless that is intended.
struct GaussianShader {
    /// Shared cuneus helpers: UI, media textures, export manager, timing.
    base: RenderKit,
    /// Multi-pass compute pipeline built from `shaders/gaussian.wgsl`.
    compute_shader: ComputeShader,
    /// Host-side mirror of the uniform parameters pushed via `set_custom_params`.
    current_params: GaussianParams,
}
75
76impl ShaderManager for GaussianShader {
77 fn init(core: &Core) -> Self {
78 let base = RenderKit::new(core);
79
80 let passes = vec![
85 PassDescription::new("init_gaussians", &[]),
86 PassDescription::new("clear_gradients", &[]),
87 PassDescription::new("render_display", &[]),
88 PassDescription::new("update_gaussians", &[]),
89 ];
90
91 let max_gaussians = 20000u32;
95 let gaussian_buffer_size = (max_gaussians * 40) as u64;
96 let gradient_buffer_size = (max_gaussians * 36) as u64;
97 let adam_buffer_size = (max_gaussians * 36) as u64;
98
99 let config = ComputeShaderBuilder::new()
100 .with_label("Gaussian Splatting Training")
101 .with_multi_pass(&passes)
102 .with_channels(1)
103 .with_custom_uniforms::<GaussianParams>()
104 .with_storage_buffer(StorageBufferSpec::new("gaussian_params", gaussian_buffer_size))
105 .with_storage_buffer(StorageBufferSpec::new("gradient_buffer", gradient_buffer_size))
106 .with_storage_buffer(StorageBufferSpec::new("adam_first_moment", adam_buffer_size))
107 .with_storage_buffer(StorageBufferSpec::new("adam_second_moment", adam_buffer_size))
108 .with_texture_format(COMPUTE_TEXTURE_FORMAT_RGBA16)
109 .build();
110
111 let compute_shader = cuneus::compute_shader!(core, "shaders/gaussian.wgsl", config);
112
113 let initial_params = GaussianParams::default();
114 let shader = Self {
115 base,
116 compute_shader,
117 current_params: initial_params,
118 };
119
120 shader
121 .compute_shader
122 .set_custom_params(initial_params, &core.queue);
123
124 shader
125 }
126
127 fn update(&mut self, core: &Core) {
128 let current_time = self.base.controls.get_time(&self.base.start_time);
129 let delta = 1.0 / 60.0;
130 self.compute_shader
131 .set_time(current_time, delta, &core.queue);
132
133 self.base.update_current_texture(core, &core.queue);
135 if let Some(texture_manager) = self.base.get_current_texture_manager() {
136 self.compute_shader.update_channel_texture(
137 0,
138 &texture_manager.view,
139 &texture_manager.sampler,
140 &core.device,
141 &core.queue,
142 );
143 }
144
145 if self.current_params.reset_training == 0 {
147 self.current_params.iteration = self.current_params.iteration.wrapping_add(1);
148 self.compute_shader.set_custom_params(self.current_params, &core.queue);
149 }
150 self.compute_shader.handle_export(core, &mut self.base);
151 }
152
153 fn resize(&mut self, core: &Core) {
154 self.base.default_resize(core, &mut self.compute_shader);
155 }
156
157 fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
158 let mut frame = self.base.begin_frame(core)?;
159
160 let mut controls_request = self
161 .base
162 .controls
163 .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
164
165 let mut params = self.current_params;
166 let mut changed = false;
167 let mut should_start_export = false;
168 let mut export_request = self.base.export_manager.get_ui_request();
169
170 let using_video_texture = self.base.using_video_texture;
171 let using_hdri_texture = self.base.using_hdri_texture;
172 let using_webcam_texture = self.base.using_webcam_texture;
173 let video_info = self.base.get_video_info();
174 let hdri_info = self.base.get_hdri_info();
175 let webcam_info = self.base.get_webcam_info();
176
177 let full_output = if self.base.key_handler.show_ui {
178 self.base.render_ui(core, |ctx| {
179 RenderKit::apply_default_style(ctx);
180
181 egui::Window::new("gaussian splatting")
182 .collapsible(true)
183 .resizable(true)
184 .default_width(280.0)
185 .show(ctx, |ui| {
186 ui.label(format!("Iteration: {}", params.iteration));
187
188 egui::CollapsingHeader::new("Training")
189 .default_open(false)
190 .show(ui, |ui| {
191 changed |= ui
192 .add(
193 egui::Slider::new(&mut params.num_gaussians, 100..=20000)
194 .text("N Gauss")
195 .logarithmic(true),
196 )
197 .changed();
198
199 changed |= ui
200 .add(
201 egui::Slider::new(&mut params.learning_rate, 0.0001..=0.1)
202 .text("pos LR")
203 .logarithmic(true),
204 )
205 .changed();
206
207 changed |= ui
208 .add(
209 egui::Slider::new(&mut params.color_learning_rate, 0.001..=0.2)
210 .text("col LR")
211 .logarithmic(true),
212 )
213 .changed();
214
215 changed |= ui
216 .add(
217 egui::Slider::new(&mut params.temperature, 0.1..=5.0)
218 .text("temp"),
219 )
220 .changed();
221
222 ui.separator();
223
224 ui.horizontal(|ui| {
225 changed |= ui
226 .add(
227 egui::Slider::new(&mut params.random_seed, 1..=10000)
228 .text("seed"),
229 )
230 .changed();
231 if ui.button("🎲").on_hover_text("Randomize seed").clicked() {
232 params.random_seed = (std::time::SystemTime::now()
233 .duration_since(std::time::UNIX_EPOCH)
234 .unwrap()
235 .as_millis() % 10000) as u32;
236 params.reset_training = 1;
237 changed = true;
238 }
239 });
240
241 if ui.button("res training").clicked() {
242 params.reset_training = 1;
243 params.iteration = 0;
244 changed = true;
245 }
246 });
247
248 egui::CollapsingHeader::new("vis")
249 .default_open(false)
250 .show(ui, |ui| {
251 let mut show_target = params.show_target != 0;
252 if ui.checkbox(&mut show_target, "Show Target").changed() {
253 params.show_target = if show_target { 1 } else { 0 };
254 changed = true;
255 }
256
257 let mut show_error = params.show_error != 0;
258 if ui.checkbox(&mut show_error, "Show Error").changed() {
259 params.show_error = if show_error { 1 } else { 0 };
260 changed = true;
261 }
262
263 if params.show_error != 0 {
264 changed |= ui
265 .add(
266 egui::Slider::new(&mut params.error_scale, 0.5..=10.0)
267 .text("Error Scale"),
268 )
269 .changed();
270 }
271 });
272
273 egui::CollapsingHeader::new("Advanced")
274 .default_open(false)
275 .show(ui, |ui| {
276 changed |= ui
277 .add(
278 egui::Slider::new(&mut params.sigma_learning_rate, 0.001..=0.1)
279 .text("Sigma LR")
280 .logarithmic(true),
281 )
282 .changed();
283
284 changed |= ui
285 .add(
286 egui::Slider::new(&mut params.min_sigma, 0.001..=0.05)
287 .text("Min Sigma")
288 .logarithmic(true),
289 )
290 .changed();
291
292 changed |= ui
293 .add(
294 egui::Slider::new(&mut params.max_sigma, 0.02..=0.3)
295 .text("Max Sigma")
296 .logarithmic(true),
297 )
298 .changed();
299
300 changed |= ui
301 .add(
302 egui::Slider::new(&mut params.position_noise, 0.0..=1.0)
303 .text("Position"),
304 )
305 .changed();
306 });
307
308 ui.separator();
309
310 ui.separator();
311
312 ShaderControls::render_media_panel(
313 ui,
314 &mut controls_request,
315 using_video_texture,
316 video_info,
317 using_hdri_texture,
318 hdri_info,
319 using_webcam_texture,
320 webcam_info,
321 );
322
323 ui.separator();
324 ShaderControls::render_controls_widget(ui, &mut controls_request);
325
326 ui.separator();
327 should_start_export =
328 ExportManager::render_export_ui_widget(ui, &mut export_request);
329 });
330 })
331 } else {
332 self.base.render_ui(core, |_ctx| {})
333 };
334
335 self.base.export_manager.apply_ui_request(export_request);
336 self.base.apply_media_requests(core, &controls_request);
337
338 if controls_request.should_clear_buffers || params.reset_training != 0 {
339 self.compute_shader.current_frame = 0;
340 self.compute_shader.time_uniform.data.frame = 0;
341 self.compute_shader.time_uniform.update(&core.queue);
342
343 params.iteration = 0;
344 params.reset_training = 0;
345 changed = true;
346 }
347
348 if changed {
349 self.current_params = params;
350 self.compute_shader.set_custom_params(params, &core.queue);
351 }
352
353 if should_start_export {
354 self.base.export_manager.start_export();
355 }
356
357
358 self.compute_shader.dispatch(&mut frame.encoder, core);
359
360 self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
361
362 self.base.end_frame(core, frame, full_output);
363
364 Ok(())
365 }
366
367 fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
368 if self.base.default_handle_input(core, event) {
369 return true;
370 }
371 if let WindowEvent::DroppedFile(path) = event {
372 if let Err(e) = self.base.load_media(core, path) {
373 error!("Failed to load dropped file: {e:?}");
374 }
375 self.current_params.reset_training = 1;
376 self.current_params.iteration = 0;
377 self.compute_shader.set_custom_params(self.current_params, &core.queue);
378 return true;
379 }
380 false
381 }
382}
383
384fn main() -> Result<(), Box<dyn std::error::Error>> {
385 cuneus::gst::init()?;
386 env_logger::init();
387 let (app, event_loop) = ShaderApp::new("2D Gaussian Splatting", 450, 350);
388 app.run(event_loop, GaussianShader::init)
389}