Skip to main content

gaussian3d/
gaussian3d.rs

1use cuneus::compute::{ComputeShader, ComputeShaderBuilder, StorageBufferSpec};
2use cuneus::prelude::*;
3use cuneus::{GaussianCamera, GaussianCloud, GaussianExporter, GaussianRenderer, GaussianSorter};
4use log::{error, info};
5use std::collections::HashSet;
6
/// Capacity, in splats, of the preprocess storage buffers allocated in `init`;
/// clouds with more Gaussians are truncated when loaded.
const MAX_GAUSSIANS: u32 = 2 * 1_000_000;
8
9
// Uniform parameters shared by the preprocess compute pass (uploaded through
// `set_custom_params`) and the splat renderer (uploaded into `params_buffer`).
// Field order defines the GPU-side layout, so it must match gaussian3d.wgsl.
cuneus::uniform_params! {
    struct GaussianParams {
    // Number of splats currently uploaded; set by `load_ply`.
    num_gaussians: u32,
    // "Gaussian Size" slider (0.1..=2.0).
    gaussian_size: f32,
    // "Scene Scale" slider (0.01..=100.0, logarithmic).
    scene_scale: f32,
    // "Gamma" slider (0.1..=2.2).
    gamma: f32,
    // "Depth Blur" slider (1..=30). NOTE(review): presumably the shift used when
    // quantising depth into the 16-bit sort keys — confirm in gaussian3d.wgsl.
    depth_shift: u32,
    // Padding: eight 4-byte fields make the struct 32 bytes, presumably to keep
    // it a multiple of 16 bytes for WGSL uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32}
}
21
22impl Default for GaussianParams {
23    fn default() -> Self {
24        Self {
25            num_gaussians: 0,
26            gaussian_size: 1.0,
27            scene_scale: 10.0,
28            gamma: 1.2,
29            depth_shift: 16,
30            _pad0: 0,
31            _pad1: 0,
32            _pad2: 0}
33    }
34}
35
/// Orbit camera pose plus the live input state that drives it.
///
/// The pose fields (`yaw`, `pitch`, `distance`, `fov`, `target`) are turned into
/// a `GaussianCamera` every frame; the input fields (`is_dragging`,
/// `last_mouse`, `keys_held`) are maintained by the window-event handler.
struct CameraState {
    /// Horizontal orbit angle, in radians.
    yaw: f32,
    /// Vertical orbit angle, in radians (mouse and UI clamp it to ±1.5).
    pitch: f32,
    /// Distance from the orbit target.
    distance: f32,
    /// Field of view, in degrees (converted to radians on upload).
    fov: f32,
    /// Point the camera orbits around; panned by WASD/QE.
    target: [f32; 3],
    /// Whether the left mouse button is currently held for orbiting.
    is_dragging: bool,
    /// Last cursor position seen, in window coordinates.
    last_mouse: [f32; 2],
    /// Lower-cased movement keys currently held down.
    keys_held: HashSet<String>,
}

impl Default for CameraState {
    /// Neutral pose (no rotation, unit distance) with no input active.
    fn default() -> Self {
        Self {
            yaw: 0.0,
            pitch: 0.0,
            distance: 1.0,
            fov: 51.0,
            target: [0.0; 3],
            is_dragging: false,
            last_mouse: [0.0; 2],
            keys_held: HashSet::new(),
        }
    }
}

impl CameraState {
    /// Initial viewing pose used at startup and by [`CameraState::reset`].
    fn new() -> Self {
        Self {
            yaw: 6.28,
            pitch: -0.05,
            distance: 4.0,
            fov: 51.0,
            target: [0.0, 0.0, -6.0],
            ..Default::default()
        }
    }

    /// Restores the initial pose while keeping the live input state.
    ///
    /// Only the pose is reset. Held keys, the drag flag and the last cursor
    /// position describe what the user is physically doing right now, so they
    /// are carried over. Zeroing `last_mouse` in particular would make the next
    /// drag compute its delta from (0, 0) and make the camera jump.
    fn reset(&mut self) {
        let keys_held = std::mem::take(&mut self.keys_held);
        let (is_dragging, last_mouse) = (self.is_dragging, self.last_mouse);
        *self = Self {
            is_dragging,
            last_mouse,
            keys_held,
            ..Self::new()
        };
    }

    /// Pans `target` for every held movement key.
    ///
    /// The step is `2 * distance * dt`, so panning keeps a similar on-screen
    /// speed at any zoom. W/S move along the horizontal forward direction given
    /// by `yaw`, A/D strafe perpendicular to it, and Q/E move along world Y.
    fn apply_held_keys(&mut self, dt: f32) {
        if self.keys_held.is_empty() {
            return;
        }
        let speed = 2.0 * self.distance * dt;
        let (sy, cy) = (self.yaw.sin(), self.yaw.cos());
        let forward = [sy, 0.0, cy];
        let right = [-cy, 0.0, sy];

        for key in &self.keys_held {
            match key.as_str() {
                "w" => { self.target[0] += forward[0] * speed; self.target[2] += forward[2] * speed; }
                "s" => { self.target[0] -= forward[0] * speed; self.target[2] -= forward[2] * speed; }
                "a" => { self.target[0] -= right[0] * speed; self.target[2] -= right[2] * speed; }
                "d" => { self.target[0] += right[0] * speed; self.target[2] += right[2] * speed; }
                "q" => { self.target[1] += speed; }
                "e" => { self.target[1] -= speed; }
                _ => {}
            }
        }
    }
}
100
101struct Gaussian3DShader {
102    base: RenderKit,
103    preprocess: ComputeShader,
104    sorter: GaussianSorter,
105    renderer: GaussianRenderer,
106    render_bind_group: Option<wgpu::BindGroup>,
107    camera_buffer: wgpu::Buffer,
108    params_buffer: wgpu::Buffer,
109    params: GaussianParams,
110    camera: CameraState,
111    surface_format: wgpu::TextureFormat}
112
113impl Gaussian3DShader {
114    fn load_ply(&mut self, core: &Core, path: &std::path::Path) {
115        info!("Loading: {:?}", path);
116        match GaussianCloud::from_ply(path) {
117            Ok(cloud) => {
118                let count = cloud.metadata.num_gaussians.min(MAX_GAUSSIANS);
119                info!("Loaded {} Gaussians", count);
120
121                let bytes = cloud.as_bytes();
122                let size = (count as usize * 64).min(bytes.len());
123                core.queue.write_buffer(&self.preprocess.storage_buffers[0], 0, &bytes[..size]);
124
125                self.params.num_gaussians = count;
126                self.sync_params(core);
127
128                self.sorter.prepare_with_buffers(
129                    &core.device,
130                    &self.preprocess.storage_buffers[2],
131                    &self.preprocess.storage_buffers[3],
132                    count,
133                );
134
135                self.render_bind_group = Some(self.renderer.create_bind_group(
136                    &core.device,
137                    &self.params_buffer,
138                    &self.camera_buffer,
139                    &self.preprocess.storage_buffers[1],
140                    &self.preprocess.storage_buffers[3],
141                ));
142
143                self.sorter.force_sort();
144                self.camera.reset();
145            }
146            Err(e) => error!("Load error: {:?}", e)}
147    }
148
149    fn sync_params(&self, core: &Core) {
150        self.preprocess.set_custom_params(self.params, &core.queue);
151        core.queue.write_buffer(&self.params_buffer, 0, bytemuck::bytes_of(&self.params));
152    }
153
154    fn export_frame(&mut self, core: &Core, frame: u32, time: f32) {
155        let settings = self.base.export_manager.settings().clone();
156        let camera = GaussianCamera::from_orbit(
157            self.camera.yaw, self.camera.pitch, self.camera.distance,
158            self.camera.target, self.camera.fov.to_radians(),
159            [settings.width as f32, settings.height as f32],
160        );
161        core.queue.write_buffer(&self.camera_buffer, 0, bytemuck::bytes_of(&camera));
162        core.queue.write_buffer(&self.preprocess.storage_buffers[4], 0, bytemuck::bytes_of(&camera));
163        self.preprocess.set_time(time, 1.0 / settings.fps as f32, &core.queue);
164
165        if let Some(ref bg) = self.render_bind_group {
166            GaussianExporter::export_frame(
167                core, &mut self.preprocess, &self.sorter, &self.renderer,
168                bg, self.params.num_gaussians, frame, &settings, self.surface_format,
169            );
170        }
171    }
172
173    fn update_camera(&self, core: &Core) {
174        let camera = GaussianCamera::from_orbit(
175            self.camera.yaw,
176            self.camera.pitch,
177            self.camera.distance,
178            self.camera.target,
179            self.camera.fov.to_radians(),
180            [core.size.width as f32, core.size.height as f32],
181        );
182        core.queue.write_buffer(&self.camera_buffer, 0, bytemuck::bytes_of(&camera));
183        core.queue.write_buffer(&self.preprocess.storage_buffers[4], 0, bytemuck::bytes_of(&camera));
184    }
185}
186
187impl ShaderManager for Gaussian3DShader {
188    fn init(core: &Core) -> Self {
189        let base = RenderKit::new(core);
190
191        let gaussian_size = (MAX_GAUSSIANS as u64) * 64;
192        let gaussian_2d_size = (MAX_GAUSSIANS as u64) * 48;
193        let keys_size = (MAX_GAUSSIANS as u64) * 4;
194        let indices_size = (MAX_GAUSSIANS as u64) * 4;
195        let camera_size = std::mem::size_of::<GaussianCamera>() as u64;
196
197        let config = ComputeShaderBuilder::new()
198            .with_label("Gaussian Preprocess")
199            .with_entry_point("preprocess")
200            .with_custom_uniforms::<GaussianParams>()
201            .with_workgroup_size([256, 1, 1])
202            .with_storage_buffer(StorageBufferSpec::new("gaussians", gaussian_size))
203            .with_storage_buffer(StorageBufferSpec::new("gaussian_2d", gaussian_2d_size))
204            .with_storage_buffer(StorageBufferSpec::new("depth_keys", keys_size))
205            .with_storage_buffer(StorageBufferSpec::new("sorted_indices", indices_size))
206            .with_storage_buffer(StorageBufferSpec::new("camera", camera_size))
207            .build();
208
209        let preprocess = cuneus::compute_shader!(core, "shaders/gaussian3d.wgsl", config);
210
211        let camera_buffer = core.device.create_buffer(&wgpu::BufferDescriptor {
212            label: Some("Gaussian Camera"),
213            size: std::mem::size_of::<GaussianCamera>() as u64,
214            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
215            mapped_at_creation: false});
216
217        let params_buffer = core.device.create_buffer(&wgpu::BufferDescriptor {
218            label: Some("Gaussian Params"),
219            size: std::mem::size_of::<GaussianParams>() as u64,
220            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
221            mapped_at_creation: false});
222
223        let sorter = GaussianSorter::new_16bit(&core.device);
224        let renderer = GaussianRenderer::new(
225            &core.device,
226            core.config.format,
227            include_str!("shaders/gaussian3d.wgsl"),
228        );
229
230        Self {
231            base,
232            preprocess,
233            sorter,
234            renderer,
235            render_bind_group: None,
236            camera_buffer,
237            params_buffer,
238            params: GaussianParams::default(),
239            camera: CameraState::new(),
240            surface_format: core.config.format}
241    }
242
243    fn update(&mut self, core: &Core) {
244        self.preprocess.check_hot_reload(&core.device);
245
246        if let Some((frame, time)) = self.base.export_manager.try_get_next_frame() {
247            self.export_frame(core, frame, time);
248        } else {
249            self.base.export_manager.complete_export();
250        }
251
252        let dt = self.base.fps_tracker.delta_time();
253        self.camera.apply_held_keys(dt);
254        self.update_camera(core);
255
256        let current_time = self.base.controls.get_time(&self.base.start_time);
257        self.preprocess.set_time(current_time, dt, &core.queue);
258    }
259
260    fn resize(&mut self, core: &Core) {
261        self.base.update_resolution(&core.queue, core.size);
262    }
263
264    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
265        let output = match core.surface.get_current_texture() {
266            wgpu::CurrentSurfaceTexture::Success(texture)
267            | wgpu::CurrentSurfaceTexture::Suboptimal(texture) => texture,
268            wgpu::CurrentSurfaceTexture::Timeout
269            | wgpu::CurrentSurfaceTexture::Occluded => {
270                return Err(cuneus::SurfaceError::SkipFrame);
271            }
272            wgpu::CurrentSurfaceTexture::Outdated => {
273                return Err(cuneus::SurfaceError::Outdated);
274            }
275            wgpu::CurrentSurfaceTexture::Lost => {
276                return Err(cuneus::SurfaceError::Lost);
277            }
278            wgpu::CurrentSurfaceTexture::Validation => {
279                return Err(cuneus::SurfaceError::Lost);
280            }
281        };
282        let view = output.texture.create_view(&wgpu::TextureViewDescriptor::default());
283
284        let mut params = self.params;
285        let mut changed = false;
286        let mut load_ply_path: Option<std::path::PathBuf> = None;
287        let mut should_start_export = false;
288        let mut export_request = self.base.export_manager.get_ui_request();
289        let mut controls_request = self.base.controls.get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
290
291        let full_output = if self.base.key_handler.show_ui {
292            self.base.render_ui(core, |ctx| {
293                RenderKit::apply_default_style(ctx);
294
295                egui::Window::new("3D Gaussian Splatting")
296                    .collapsible(true)
297                    .resizable(true)
298                    .default_width(300.0)
299                    .show(ctx, |ui| {
300                        if params.num_gaussians > 0 {
301                            ui.label(format!("Gaussians: {}", params.num_gaussians));
302                        } else {
303                            ui.label("Drag & drop a .ply file");
304                        }
305                        ui.small("WASD: move | QE: up/down | R: reset | Drag: rotate");
306
307                        if ui.button("Load PLY...").clicked() {
308                            if let Some(p) = rfd::FileDialog::new().add_filter("PLY", &["ply"]).pick_file() {
309                                load_ply_path = Some(p);
310                            }
311                        }
312
313                        ui.separator();
314
315                        egui::CollapsingHeader::new("Visual Settings")
316                            .default_open(true)
317                            .show(ui, |ui| {
318                                changed |= ui.add(egui::Slider::new(&mut params.scene_scale, 0.01..=100.0)
319                                    .logarithmic(true).text("Scene Scale")).changed();
320                                changed |= ui.add(egui::Slider::new(&mut params.gaussian_size, 0.1..=2.0)
321                                    .text("Gaussian Size")).changed();
322                                changed |= ui.add(egui::Slider::new(&mut params.gamma, 0.1..=2.2)
323                                    .text("Gamma")).changed();
324
325                                let mut depth_shift_f = params.depth_shift as f32;
326                                if ui.add(egui::Slider::new(&mut depth_shift_f, 1.0..=30.0)
327                                    .step_by(1.0)
328                                    .text("Depth Blur")).changed() {
329                                    params.depth_shift = depth_shift_f as u32;
330                                    changed = true;
331                                }
332                            });
333
334                        egui::CollapsingHeader::new("Camera Settings")
335                            .default_open(false)
336                            .show(ui, |ui| {
337                                changed |= ui.add(egui::Slider::new(&mut self.camera.distance, 0.1..=100.0)
338                                    .logarithmic(true).text("Distance")).changed();
339                                changed |= ui.add(egui::Slider::new(&mut self.camera.fov, 20.0..=120.0)
340                                    .text("FOV")).changed();
341                                changed |= ui.add(egui::DragValue::new(&mut self.camera.yaw)
342                                    .speed(0.05).prefix("Yaw: ")).changed();
343                                changed |= ui.add(egui::Slider::new(&mut self.camera.pitch, -1.5..=1.5)
344                                    .text("Pitch")).changed();
345
346                                if ui.button("Reset Camera").clicked() {
347                                    self.camera.reset();
348                                    changed = true;
349                                }
350                            });
351
352                        ui.separator();
353                        ShaderControls::render_controls_widget(ui, &mut controls_request);
354
355                        ui.separator();
356                        should_start_export =
357                            ExportManager::render_export_ui_widget(ui, &mut export_request);
358                    });
359            })
360        } else {
361            self.base.render_ui(core, |_ctx| {})
362        };
363
364        self.base.export_manager.apply_ui_request(export_request);
365        self.base.apply_control_request(controls_request);
366
367        if should_start_export {
368            self.base.export_manager.start_export();
369        }
370
371        if let Some(path) = load_ply_path {
372            self.load_ply(core, &path);
373        }
374        if changed {
375            self.params = params;
376            self.sync_params(core);
377        }
378
379        let mut encoder = core.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
380            label: Some("Gaussian3D")});
381
382        let count = self.params.num_gaussians;
383        if count > 0 && self.render_bind_group.is_some() {
384            self.update_camera(core);
385
386            // Compute preprocess
387            let workgroups = (count + 255) / 256;
388            self.preprocess.dispatch_stage_with_workgroups(&mut encoder, 0, [workgroups, 1, 1]);
389
390            // GPU Radix Sort
391            self.sorter.sort(&mut encoder, count);
392
393            // Split submission: submit preprocess+sort, start new encoder for render
394            encoder = core.flush_encoder(encoder);
395
396            // Fragment render
397            {
398                let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
399                    label: Some("Gaussian Render"),
400                    color_attachments: &[Some(wgpu::RenderPassColorAttachment {
401                        view: &view,
402                        resolve_target: None,
403                        ops: wgpu::Operations {
404                            load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
405                            store: wgpu::StoreOp::Store},
406                        depth_slice: None})],
407                    ..Default::default()
408                });
409                self.renderer.render(&mut pass, self.render_bind_group.as_ref().unwrap(), count);
410            }
411        } else {
412            let _pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
413                label: Some("Clear"),
414                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
415                    view: &view,
416                    resolve_target: None,
417                    ops: wgpu::Operations {
418                        load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
419                        store: wgpu::StoreOp::Store},
420                    depth_slice: None})],
421                ..Default::default()
422            });
423        }
424
425        self.base.handle_render_output(core, &view, full_output, &mut encoder);
426        core.queue.submit(Some(encoder.finish()));
427        output.present();
428        Ok(())
429    }
430
431    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
432        if self.base.egui_state.on_window_event(core.window(), event).consumed {
433            return true;
434        }
435
436        if let WindowEvent::KeyboardInput { event, .. } = event {
437            if self.base.key_handler.handle_keyboard_input(core.window(), event) {
438                return true;
439            }
440            if let winit::keyboard::Key::Character(ch) = &event.logical_key {
441                let key = ch.as_str().to_lowercase();
442                match event.state {
443                    winit::event::ElementState::Pressed => {
444                        if key == "r" {
445                            self.camera.reset();
446                            self.sorter.force_sort();
447                            return true;
448                        }
449                        if matches!(key.as_str(), "w" | "a" | "s" | "d" | "q" | "e") {
450                            self.camera.keys_held.insert(key);
451                            return true;
452                        }
453                    }
454                    winit::event::ElementState::Released => {
455                        self.camera.keys_held.remove(&key);
456                    }
457                }
458            }
459        }
460
461        if let WindowEvent::MouseInput { state, button, .. } = event {
462            if *button == winit::event::MouseButton::Left {
463                self.camera.is_dragging = *state == winit::event::ElementState::Pressed;
464                return true;
465            }
466        }
467
468        if let WindowEvent::CursorMoved { position, .. } = event {
469            let x = position.x as f32;
470            let y = position.y as f32;
471            if self.camera.is_dragging {
472                let dx = x - self.camera.last_mouse[0];
473                let dy = y - self.camera.last_mouse[1];
474                self.camera.yaw += dx * 0.01;
475                self.camera.pitch = (self.camera.pitch + dy * 0.01).clamp(-1.5, 1.5);
476            }
477            self.camera.last_mouse = [x, y];
478            return self.camera.is_dragging;
479        }
480
481        if let WindowEvent::MouseWheel { delta, .. } = event {
482            let d = match delta {
483                winit::event::MouseScrollDelta::LineDelta(_, y) => *y,
484                winit::event::MouseScrollDelta::PixelDelta(p) => (p.y as f32 / 100.0).clamp(-3.0, 3.0)};
485            let factor = (1.0 + d * 0.1).clamp(0.5, 2.0);
486            self.camera.distance = (self.camera.distance * factor).clamp(0.1, 500.0);
487            return true;
488        }
489
490        if let WindowEvent::DroppedFile(path) = event {
491            if path.extension().map(|e| e == "ply").unwrap_or(false) {
492                self.load_ply(core, path);
493            }
494            return true;
495        }
496
497        false
498    }
499}
500
501fn main() -> Result<(), Box<dyn std::error::Error>> {
502    env_logger::init();
503    let (app, event_loop) = ShaderApp::new("3D Gaussian Splatting", 800, 600);
504    app.run(event_loop, Gaussian3DShader::init)
505}