1use cuneus::compute::{ComputeShader, ComputeShaderBuilder, StorageBufferSpec};
2use cuneus::prelude::*;
3use cuneus::{GaussianCamera, GaussianCloud, GaussianExporter, GaussianRenderer, GaussianSorter};
4use log::{error, info};
5use std::collections::HashSet;
6
/// Upper bound on the number of Gaussians taken from a PLY file; it also sizes
/// every per-Gaussian GPU buffer allocated in `init`.
const MAX_GAUSSIANS: u32 = 2 * 1_000_000;
8
9
// Parameter block shared by the preprocess compute shader (uploaded through
// `set_custom_params`) and the render pipeline (uploaded into `params_buffer`
// with `bytemuck::bytes_of`). Field order and types define the byte layout the
// shaders see, so do not reorder or retype fields.
cuneus::uniform_params! {
    struct GaussianParams {
        // Number of Gaussians currently uploaded (0 until a PLY has been loaded).
        num_gaussians: u32,
        // Splat size multiplier ("Gaussian Size" slider, 0.1..=2.0).
        gaussian_size: f32,
        // Scene scale factor ("Scene Scale" slider, logarithmic 0.01..=100.0).
        scene_scale: f32,
        // Gamma value ("Gamma" slider, 0.1..=2.2).
        gamma: f32,
        // Integer 1..=30 driven by the "Depth Blur" slider; presumably a bit shift
        // used when quantising depth into sort keys (TODO: confirm in the WGSL).
        depth_shift: u32,
        // Padding that brings the struct to 32 bytes (8 x 4-byte fields),
        // presumably for WGSL uniform-buffer size/alignment rules.
        _pad0: u32,
        _pad1: u32,
        _pad2: u32
    }
}
21
22impl Default for GaussianParams {
23 fn default() -> Self {
24 Self {
25 num_gaussians: 0,
26 gaussian_size: 1.0,
27 scene_scale: 10.0,
28 gamma: 1.2,
29 depth_shift: 16,
30 _pad0: 0,
31 _pad1: 0,
32 _pad2: 0}
33 }
34}
35
/// Interactive orbit-camera state, updated by mouse, keyboard and UI input.
struct CameraState {
    /// Horizontal orbit angle in radians (mouse drag adds `dx * 0.01`).
    yaw: f32,
    /// Vertical orbit angle, presumably radians; clamped to ±1.5 by the drag
    /// handler and the UI slider.
    pitch: f32,
    /// Orbit distance; also scales WASD/QE movement speed.
    distance: f32,
    /// Field of view in degrees (converted with `to_radians` before upload).
    fov: f32,
    /// Point passed to `GaussianCamera::from_orbit` as the orbit target.
    target: [f32; 3],
    /// `true` while the left mouse button is held for rotation.
    is_dragging: bool,
    /// Last cursor position seen, used to compute drag deltas.
    last_mouse: [f32; 2],
    /// Lower-cased movement keys currently held (`w`, `a`, `s`, `d`, `q`, `e`).
    keys_held: HashSet<String>,
}
45
46impl Default for CameraState {
47 fn default() -> Self {
48 Self {
49 yaw: 0.0,
50 pitch: 0.0,
51 distance: 1.0,
52 fov: 51.0,
53 target: [0.0; 3],
54 is_dragging: false,
55 last_mouse: [0.0; 2],
56 keys_held: HashSet::new()}
57 }
58}
59
60impl CameraState {
61 fn new() -> Self {
62 Self {
63 yaw: 6.28,
64 pitch: -0.05,
65 distance: 4.0,
66 fov: 51.0,
67 target: [0.0, 0.0, -6.0],
68 ..Default::default()
69 }
70 }
71
72 fn reset(&mut self) {
73 let keys = std::mem::take(&mut self.keys_held);
74 *self = Self::new();
75 self.keys_held = keys;
76 }
77
78 fn apply_held_keys(&mut self, dt: f32) {
79 if self.keys_held.is_empty() {
80 return;
81 }
82 let speed = 2.0 * self.distance * dt;
83 let (sy, cy) = (self.yaw.sin(), self.yaw.cos());
84 let forward = [sy, 0.0, cy];
85 let right = [-cy, 0.0, sy];
86
87 for key in &self.keys_held {
88 match key.as_str() {
89 "w" => { self.target[0] += forward[0] * speed; self.target[2] += forward[2] * speed; }
90 "s" => { self.target[0] -= forward[0] * speed; self.target[2] -= forward[2] * speed; }
91 "a" => { self.target[0] -= right[0] * speed; self.target[2] -= right[2] * speed; }
92 "d" => { self.target[0] += right[0] * speed; self.target[2] += right[2] * speed; }
93 "q" => { self.target[1] += speed; }
94 "e" => { self.target[1] -= speed; }
95 _ => {}
96 }
97 }
98 }
99}
100
101struct Gaussian3DShader {
102 base: RenderKit,
103 preprocess: ComputeShader,
104 sorter: GaussianSorter,
105 renderer: GaussianRenderer,
106 render_bind_group: Option<wgpu::BindGroup>,
107 camera_buffer: wgpu::Buffer,
108 params_buffer: wgpu::Buffer,
109 params: GaussianParams,
110 camera: CameraState,
111 surface_format: wgpu::TextureFormat}
112
113impl Gaussian3DShader {
114 fn load_ply(&mut self, core: &Core, path: &std::path::Path) {
115 info!("Loading: {:?}", path);
116 match GaussianCloud::from_ply(path) {
117 Ok(cloud) => {
118 let count = cloud.metadata.num_gaussians.min(MAX_GAUSSIANS);
119 info!("Loaded {} Gaussians", count);
120
121 let bytes = cloud.as_bytes();
122 let size = (count as usize * 64).min(bytes.len());
123 core.queue.write_buffer(&self.preprocess.storage_buffers[0], 0, &bytes[..size]);
124
125 self.params.num_gaussians = count;
126 self.sync_params(core);
127
128 self.sorter.prepare_with_buffers(
129 &core.device,
130 &self.preprocess.storage_buffers[2],
131 &self.preprocess.storage_buffers[3],
132 count,
133 );
134
135 self.render_bind_group = Some(self.renderer.create_bind_group(
136 &core.device,
137 &self.params_buffer,
138 &self.camera_buffer,
139 &self.preprocess.storage_buffers[1],
140 &self.preprocess.storage_buffers[3],
141 ));
142
143 self.sorter.force_sort();
144 self.camera.reset();
145 }
146 Err(e) => error!("Load error: {:?}", e)}
147 }
148
149 fn sync_params(&self, core: &Core) {
150 self.preprocess.set_custom_params(self.params, &core.queue);
151 core.queue.write_buffer(&self.params_buffer, 0, bytemuck::bytes_of(&self.params));
152 }
153
154 fn export_frame(&mut self, core: &Core, frame: u32, time: f32) {
155 let settings = self.base.export_manager.settings().clone();
156 let camera = GaussianCamera::from_orbit(
157 self.camera.yaw, self.camera.pitch, self.camera.distance,
158 self.camera.target, self.camera.fov.to_radians(),
159 [settings.width as f32, settings.height as f32],
160 );
161 core.queue.write_buffer(&self.camera_buffer, 0, bytemuck::bytes_of(&camera));
162 core.queue.write_buffer(&self.preprocess.storage_buffers[4], 0, bytemuck::bytes_of(&camera));
163 self.preprocess.set_time(time, 1.0 / settings.fps as f32, &core.queue);
164
165 if let Some(ref bg) = self.render_bind_group {
166 GaussianExporter::export_frame(
167 core, &mut self.preprocess, &self.sorter, &self.renderer,
168 bg, self.params.num_gaussians, frame, &settings, self.surface_format,
169 );
170 }
171 }
172
173 fn update_camera(&self, core: &Core) {
174 let camera = GaussianCamera::from_orbit(
175 self.camera.yaw,
176 self.camera.pitch,
177 self.camera.distance,
178 self.camera.target,
179 self.camera.fov.to_radians(),
180 [core.size.width as f32, core.size.height as f32],
181 );
182 core.queue.write_buffer(&self.camera_buffer, 0, bytemuck::bytes_of(&camera));
183 core.queue.write_buffer(&self.preprocess.storage_buffers[4], 0, bytemuck::bytes_of(&camera));
184 }
185}
186
187impl ShaderManager for Gaussian3DShader {
188 fn init(core: &Core) -> Self {
189 let base = RenderKit::new(core);
190
191 let gaussian_size = (MAX_GAUSSIANS as u64) * 64;
192 let gaussian_2d_size = (MAX_GAUSSIANS as u64) * 48;
193 let keys_size = (MAX_GAUSSIANS as u64) * 4;
194 let indices_size = (MAX_GAUSSIANS as u64) * 4;
195 let camera_size = std::mem::size_of::<GaussianCamera>() as u64;
196
197 let config = ComputeShaderBuilder::new()
198 .with_label("Gaussian Preprocess")
199 .with_entry_point("preprocess")
200 .with_custom_uniforms::<GaussianParams>()
201 .with_workgroup_size([256, 1, 1])
202 .with_storage_buffer(StorageBufferSpec::new("gaussians", gaussian_size))
203 .with_storage_buffer(StorageBufferSpec::new("gaussian_2d", gaussian_2d_size))
204 .with_storage_buffer(StorageBufferSpec::new("depth_keys", keys_size))
205 .with_storage_buffer(StorageBufferSpec::new("sorted_indices", indices_size))
206 .with_storage_buffer(StorageBufferSpec::new("camera", camera_size))
207 .build();
208
209 let preprocess = cuneus::compute_shader!(core, "shaders/gaussian3d.wgsl", config);
210
211 let camera_buffer = core.device.create_buffer(&wgpu::BufferDescriptor {
212 label: Some("Gaussian Camera"),
213 size: std::mem::size_of::<GaussianCamera>() as u64,
214 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
215 mapped_at_creation: false});
216
217 let params_buffer = core.device.create_buffer(&wgpu::BufferDescriptor {
218 label: Some("Gaussian Params"),
219 size: std::mem::size_of::<GaussianParams>() as u64,
220 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
221 mapped_at_creation: false});
222
223 let sorter = GaussianSorter::new_16bit(&core.device);
224 let renderer = GaussianRenderer::new(
225 &core.device,
226 core.config.format,
227 include_str!("shaders/gaussian3d.wgsl"),
228 );
229
230 Self {
231 base,
232 preprocess,
233 sorter,
234 renderer,
235 render_bind_group: None,
236 camera_buffer,
237 params_buffer,
238 params: GaussianParams::default(),
239 camera: CameraState::new(),
240 surface_format: core.config.format}
241 }
242
243 fn update(&mut self, core: &Core) {
244 self.preprocess.check_hot_reload(&core.device);
245
246 if let Some((frame, time)) = self.base.export_manager.try_get_next_frame() {
247 self.export_frame(core, frame, time);
248 } else {
249 self.base.export_manager.complete_export();
250 }
251
252 let dt = self.base.fps_tracker.delta_time();
253 self.camera.apply_held_keys(dt);
254 self.update_camera(core);
255
256 let current_time = self.base.controls.get_time(&self.base.start_time);
257 self.preprocess.set_time(current_time, dt, &core.queue);
258 }
259
260 fn resize(&mut self, core: &Core) {
261 self.base.update_resolution(&core.queue, core.size);
262 }
263
264 fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
265 let output = match core.surface.get_current_texture() {
266 wgpu::CurrentSurfaceTexture::Success(texture)
267 | wgpu::CurrentSurfaceTexture::Suboptimal(texture) => texture,
268 wgpu::CurrentSurfaceTexture::Timeout
269 | wgpu::CurrentSurfaceTexture::Occluded => {
270 return Err(cuneus::SurfaceError::SkipFrame);
271 }
272 wgpu::CurrentSurfaceTexture::Outdated => {
273 return Err(cuneus::SurfaceError::Outdated);
274 }
275 wgpu::CurrentSurfaceTexture::Lost => {
276 return Err(cuneus::SurfaceError::Lost);
277 }
278 wgpu::CurrentSurfaceTexture::Validation => {
279 return Err(cuneus::SurfaceError::Lost);
280 }
281 };
282 let view = output.texture.create_view(&wgpu::TextureViewDescriptor::default());
283
284 let mut params = self.params;
285 let mut changed = false;
286 let mut load_ply_path: Option<std::path::PathBuf> = None;
287 let mut should_start_export = false;
288 let mut export_request = self.base.export_manager.get_ui_request();
289 let mut controls_request = self.base.controls.get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
290
291 let full_output = if self.base.key_handler.show_ui {
292 self.base.render_ui(core, |ctx| {
293 RenderKit::apply_default_style(ctx);
294
295 egui::Window::new("3D Gaussian Splatting")
296 .collapsible(true)
297 .resizable(true)
298 .default_width(300.0)
299 .show(ctx, |ui| {
300 if params.num_gaussians > 0 {
301 ui.label(format!("Gaussians: {}", params.num_gaussians));
302 } else {
303 ui.label("Drag & drop a .ply file");
304 }
305 ui.small("WASD: move | QE: up/down | R: reset | Drag: rotate");
306
307 if ui.button("Load PLY...").clicked() {
308 if let Some(p) = rfd::FileDialog::new().add_filter("PLY", &["ply"]).pick_file() {
309 load_ply_path = Some(p);
310 }
311 }
312
313 ui.separator();
314
315 egui::CollapsingHeader::new("Visual Settings")
316 .default_open(true)
317 .show(ui, |ui| {
318 changed |= ui.add(egui::Slider::new(&mut params.scene_scale, 0.01..=100.0)
319 .logarithmic(true).text("Scene Scale")).changed();
320 changed |= ui.add(egui::Slider::new(&mut params.gaussian_size, 0.1..=2.0)
321 .text("Gaussian Size")).changed();
322 changed |= ui.add(egui::Slider::new(&mut params.gamma, 0.1..=2.2)
323 .text("Gamma")).changed();
324
325 let mut depth_shift_f = params.depth_shift as f32;
326 if ui.add(egui::Slider::new(&mut depth_shift_f, 1.0..=30.0)
327 .step_by(1.0)
328 .text("Depth Blur")).changed() {
329 params.depth_shift = depth_shift_f as u32;
330 changed = true;
331 }
332 });
333
334 egui::CollapsingHeader::new("Camera Settings")
335 .default_open(false)
336 .show(ui, |ui| {
337 changed |= ui.add(egui::Slider::new(&mut self.camera.distance, 0.1..=100.0)
338 .logarithmic(true).text("Distance")).changed();
339 changed |= ui.add(egui::Slider::new(&mut self.camera.fov, 20.0..=120.0)
340 .text("FOV")).changed();
341 changed |= ui.add(egui::DragValue::new(&mut self.camera.yaw)
342 .speed(0.05).prefix("Yaw: ")).changed();
343 changed |= ui.add(egui::Slider::new(&mut self.camera.pitch, -1.5..=1.5)
344 .text("Pitch")).changed();
345
346 if ui.button("Reset Camera").clicked() {
347 self.camera.reset();
348 changed = true;
349 }
350 });
351
352 ui.separator();
353 ShaderControls::render_controls_widget(ui, &mut controls_request);
354
355 ui.separator();
356 should_start_export =
357 ExportManager::render_export_ui_widget(ui, &mut export_request);
358 });
359 })
360 } else {
361 self.base.render_ui(core, |_ctx| {})
362 };
363
364 self.base.export_manager.apply_ui_request(export_request);
365 self.base.apply_control_request(controls_request);
366
367 if should_start_export {
368 self.base.export_manager.start_export();
369 }
370
371 if let Some(path) = load_ply_path {
372 self.load_ply(core, &path);
373 }
374 if changed {
375 self.params = params;
376 self.sync_params(core);
377 }
378
379 let mut encoder = core.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
380 label: Some("Gaussian3D")});
381
382 let count = self.params.num_gaussians;
383 if count > 0 && self.render_bind_group.is_some() {
384 self.update_camera(core);
385
386 let workgroups = (count + 255) / 256;
388 self.preprocess.dispatch_stage_with_workgroups(&mut encoder, 0, [workgroups, 1, 1]);
389
390 self.sorter.sort(&mut encoder, count);
392
393 encoder = core.flush_encoder(encoder);
395
396 {
398 let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
399 label: Some("Gaussian Render"),
400 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
401 view: &view,
402 resolve_target: None,
403 ops: wgpu::Operations {
404 load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
405 store: wgpu::StoreOp::Store},
406 depth_slice: None})],
407 ..Default::default()
408 });
409 self.renderer.render(&mut pass, self.render_bind_group.as_ref().unwrap(), count);
410 }
411 } else {
412 let _pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
413 label: Some("Clear"),
414 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
415 view: &view,
416 resolve_target: None,
417 ops: wgpu::Operations {
418 load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
419 store: wgpu::StoreOp::Store},
420 depth_slice: None})],
421 ..Default::default()
422 });
423 }
424
425 self.base.handle_render_output(core, &view, full_output, &mut encoder);
426 core.queue.submit(Some(encoder.finish()));
427 output.present();
428 Ok(())
429 }
430
431 fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
432 if self.base.egui_state.on_window_event(core.window(), event).consumed {
433 return true;
434 }
435
436 if let WindowEvent::KeyboardInput { event, .. } = event {
437 if self.base.key_handler.handle_keyboard_input(core.window(), event) {
438 return true;
439 }
440 if let winit::keyboard::Key::Character(ch) = &event.logical_key {
441 let key = ch.as_str().to_lowercase();
442 match event.state {
443 winit::event::ElementState::Pressed => {
444 if key == "r" {
445 self.camera.reset();
446 self.sorter.force_sort();
447 return true;
448 }
449 if matches!(key.as_str(), "w" | "a" | "s" | "d" | "q" | "e") {
450 self.camera.keys_held.insert(key);
451 return true;
452 }
453 }
454 winit::event::ElementState::Released => {
455 self.camera.keys_held.remove(&key);
456 }
457 }
458 }
459 }
460
461 if let WindowEvent::MouseInput { state, button, .. } = event {
462 if *button == winit::event::MouseButton::Left {
463 self.camera.is_dragging = *state == winit::event::ElementState::Pressed;
464 return true;
465 }
466 }
467
468 if let WindowEvent::CursorMoved { position, .. } = event {
469 let x = position.x as f32;
470 let y = position.y as f32;
471 if self.camera.is_dragging {
472 let dx = x - self.camera.last_mouse[0];
473 let dy = y - self.camera.last_mouse[1];
474 self.camera.yaw += dx * 0.01;
475 self.camera.pitch = (self.camera.pitch + dy * 0.01).clamp(-1.5, 1.5);
476 }
477 self.camera.last_mouse = [x, y];
478 return self.camera.is_dragging;
479 }
480
481 if let WindowEvent::MouseWheel { delta, .. } = event {
482 let d = match delta {
483 winit::event::MouseScrollDelta::LineDelta(_, y) => *y,
484 winit::event::MouseScrollDelta::PixelDelta(p) => (p.y as f32 / 100.0).clamp(-3.0, 3.0)};
485 let factor = (1.0 + d * 0.1).clamp(0.5, 2.0);
486 self.camera.distance = (self.camera.distance * factor).clamp(0.1, 500.0);
487 return true;
488 }
489
490 if let WindowEvent::DroppedFile(path) = event {
491 if path.extension().map(|e| e == "ply").unwrap_or(false) {
492 self.load_ply(core, path);
493 }
494 return true;
495 }
496
497 false
498 }
499}
500
501fn main() -> Result<(), Box<dyn std::error::Error>> {
502 env_logger::init();
503 let (app, event_loop) = ShaderApp::new("3D Gaussian Splatting", 800, 600);
504 app.run(event_loop, Gaussian3DShader::init)
505}