Skip to main content

pathtracing/
pathtracing.rs

1use cuneus::compute::*;
2use cuneus::prelude::*;
3use log::error;
4
/// Input state for the free-fly camera.
///
/// The six movement flags are written from key events and integrated once per
/// frame by `update_camera`; `yaw`/`pitch` are driven by cursor motion while
/// mouse look is enabled.
struct CameraMovement {
    // Currently-held movement directions.
    forward: bool,
    backward: bool,
    left: bool,
    right: bool,
    up: bool,
    down: bool,
    // Movement speed in world units per second, and the instant of the last
    // integration step (used to compute the frame's time delta).
    speed: f32,
    last_update: std::time::Instant,

    // View orientation in radians plus the radians-per-pixel mouse factor.
    yaw: f32,
    pitch: f32,
    mouse_sensitivity: f32,

    // Previous cursor sample, whether one has been taken since mouse look was
    // (re)enabled, whether mouse look is on, and whether the orientation changed
    // since the last `update_camera` call.
    last_mouse_x: f32,
    last_mouse_y: f32,
    mouse_initialized: bool,
    mouse_look_enabled: bool,
    look_changed: bool,
}
24
25impl Default for CameraMovement {
26    fn default() -> Self {
27        Self {
28            forward: false,
29            backward: false,
30            left: false,
31            right: false,
32            up: false,
33            down: false,
34            speed: 2.0,
35            last_update: std::time::Instant::now(),
36
37            yaw: 0.0,
38            pitch: 0.0,
39            mouse_sensitivity: 0.005,
40
41            last_mouse_x: 0.0,
42            last_mouse_y: 0.0,
43            mouse_initialized: false,
44            mouse_look_enabled: true,
45            look_changed: false}
46    }
47}
48
49impl CameraMovement {
50    fn update_camera(&mut self, params: &mut PathTracingParams) -> bool {
51        let now = std::time::Instant::now();
52        let dt = now.duration_since(self.last_update).as_secs_f32();
53        self.last_update = now;
54
55        let mut changed = false;
56
57        if self.look_changed {
58            changed = true;
59            self.look_changed = false;
60        }
61
62        let forward = [
63            self.pitch.cos() * self.yaw.cos(),
64            self.pitch.sin(),
65            self.pitch.cos() * self.yaw.sin(),
66        ];
67
68        let world_up = [0.0, 1.0, 0.0];
69        let right = [
70            forward[1] * world_up[2] - forward[2] * world_up[1],
71            forward[2] * world_up[0] - forward[0] * world_up[2],
72            forward[0] * world_up[1] - forward[1] * world_up[0],
73        ];
74
75        let right_len = (right[0] * right[0] + right[1] * right[1] + right[2] * right[2]).sqrt();
76        let right = [
77            right[0] / right_len,
78            right[1] / right_len,
79            right[2] / right_len,
80        ];
81
82        let delta = self.speed * dt;
83        let mut move_vec = [0.0, 0.0, 0.0];
84
85        if self.forward {
86            move_vec[0] += forward[0] * delta;
87            move_vec[1] += forward[1] * delta;
88            move_vec[2] += forward[2] * delta;
89            changed = true;
90        }
91        if self.backward {
92            move_vec[0] -= forward[0] * delta;
93            move_vec[1] -= forward[1] * delta;
94            move_vec[2] -= forward[2] * delta;
95            changed = true;
96        }
97        if self.right {
98            move_vec[0] += right[0] * delta;
99            move_vec[1] += right[1] * delta;
100            move_vec[2] += right[2] * delta;
101            changed = true;
102        }
103        if self.left {
104            move_vec[0] -= right[0] * delta;
105            move_vec[1] -= right[1] * delta;
106            move_vec[2] -= right[2] * delta;
107            changed = true;
108        }
109        if self.up {
110            move_vec[1] += delta;
111            changed = true;
112        }
113        if self.down {
114            move_vec[1] -= delta;
115            changed = true;
116        }
117
118        params.camera_pos_x += move_vec[0];
119        params.camera_pos_y += move_vec[1];
120        params.camera_pos_z += move_vec[2];
121
122        let look_distance = 1.0;
123        params.camera_target_x = params.camera_pos_x + forward[0] * look_distance;
124        params.camera_target_y = params.camera_pos_y + forward[1] * look_distance;
125        params.camera_target_z = params.camera_pos_z + forward[2] * look_distance;
126
127        changed
128    }
129
130    fn handle_mouse_movement(&mut self, x: f32, y: f32) -> bool {
131        if !self.mouse_look_enabled {
132            return false;
133        }
134
135        if !self.mouse_initialized {
136            self.last_mouse_x = x;
137            self.last_mouse_y = y;
138            self.mouse_initialized = true;
139            return false;
140        }
141
142        let dx = x - self.last_mouse_x;
143        let dy = y - self.last_mouse_y;
144
145        self.last_mouse_x = x;
146        self.last_mouse_y = y;
147
148        self.yaw += dx * self.mouse_sensitivity;
149        self.pitch -= dy * self.mouse_sensitivity;
150
151        self.pitch = self
152            .pitch
153            .clamp(-std::f32::consts::PI * 0.49, std::f32::consts::PI * 0.49);
154
155        self.look_changed = true;
156
157        true
158    }
159
160    fn toggle_mouse_look(&mut self) {
161        self.mouse_look_enabled = !self.mouse_look_enabled;
162        self.mouse_initialized = false;
163    }
164}
165
166cuneus::uniform_params! {
167    struct PathTracingParams {
168    camera_pos_x: f32,
169    camera_pos_y: f32,
170    camera_pos_z: f32,
171    camera_target_x: f32,
172    camera_target_y: f32,
173    camera_target_z: f32,
174    fov: f32,
175    aperture: f32,
176
177    max_bounces: u32,
178    samples_per_pixel: u32,
179    accumulate: u32,
180
181    num_spheres: u32,
182    _padding1: f32,
183    _padding2: f32,
184
185    rotation_speed: f32,
186
187    exposure: f32}
188}
189
190struct PathTracingShader {
191    base: RenderKit,
192    compute_shader: ComputeShader,
193    current_params: PathTracingParams,
194    camera_movement: CameraMovement,
195    frame_count: u32,
196    should_reset_accumulation: bool}
197
198impl PathTracingShader {
199    fn clear_buffers(&mut self, core: &Core) {
200        self.compute_shader.clear_all_buffers(core);
201        self.frame_count = 0;
202        self.should_reset_accumulation = false;
203    }
204}
205
206impl ShaderManager for PathTracingShader {
207    fn init(core: &Core) -> Self {
208        let base = RenderKit::new(core);
209
210        let initial_params = PathTracingParams {
211            camera_pos_x: 0.0,
212            camera_pos_y: 1.0,
213            camera_pos_z: 6.0,
214            camera_target_x: 0.0,
215            camera_target_y: 0.0,
216            camera_target_z: -1.0,
217            fov: 40.0,
218            aperture: 0.00,
219            max_bounces: 4,
220            samples_per_pixel: 2,
221            accumulate: 1,
222            num_spheres: 15,
223            _padding1: 0.0,
224            _padding2: 0.0,
225            rotation_speed: 0.2,
226            exposure: 1.5};
227
228        let config = ComputeShader::builder()
229            .with_entry_point("main")
230            .with_input_texture() // Enable input texture support for background
231            .with_custom_uniforms::<PathTracingParams>()
232            .with_mouse()
233            .with_storage_buffer(StorageBufferSpec::new(
234                "atomic_buffer",
235                (core.size.width * core.size.height * 3 * 4) as u64,
236            )) // 3 channels * u32 per pixel
237            .with_workgroup_size([16, 16, 1])
238            .with_texture_format(COMPUTE_TEXTURE_FORMAT_RGBA16)
239            .with_label("Path Tracing Unified")
240            .build();
241
242        let compute_shader = cuneus::compute_shader!(core, "shaders/pathtracing.wgsl", config);
243
244        compute_shader.set_custom_params(initial_params, &core.queue);
245
246        Self {
247            base,
248            compute_shader,
249            current_params: initial_params,
250            camera_movement: CameraMovement::default(),
251            frame_count: 0,
252            should_reset_accumulation: true}
253    }
254
255    fn update(&mut self, core: &Core) {
256        // Update time
257        let current_time = self.base.controls.get_time(&self.base.start_time);
258        let delta = 1.0 / 60.0;
259        self.compute_shader
260            .set_time(current_time, delta, &core.queue);
261
262        // Update input textures for background
263        self.base.update_current_texture(core, &core.queue);
264        if let Some(texture_manager) = self.base.get_current_texture_manager() {
265            self.compute_shader.update_input_texture(
266                &texture_manager.view,
267                &texture_manager.sampler,
268                &core.device,
269            );
270        }
271
272        if self.camera_movement.update_camera(&mut self.current_params) {
273            self.compute_shader
274                .set_custom_params(self.current_params, &core.queue);
275            self.should_reset_accumulation = true;
276        }
277        // Handle export
278        self.compute_shader.handle_export(core, &mut self.base);
279    }
280
281    fn resize(&mut self, core: &Core) {
282        self.base.default_resize(core, &mut self.compute_shader);
283        self.should_reset_accumulation = true;
284    }
285
286    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
287        let mut frame = self.base.begin_frame(core)?;
288
289        // Handle UI and parameter updates
290        let mut params = self.current_params;
291        let mut changed = false;
292        let mut should_start_export = false;
293        let mut export_request = self.base.export_manager.get_ui_request();
294        let mut controls_request = self
295            .base
296            .controls
297            .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
298
299        let current_fps = self.base.fps_tracker.fps();
300        let using_video_texture = self.base.using_video_texture;
301        let using_hdri_texture = self.base.using_hdri_texture;
302        let using_webcam_texture = self.base.using_webcam_texture;
303        let video_info = self.base.get_video_info();
304        let hdri_info = self.base.get_hdri_info();
305        let webcam_info = self.base.get_webcam_info();
306
307        let full_output = if self.base.key_handler.show_ui {
308            self.base.render_ui(core, |ctx| {
309                RenderKit::apply_default_style(ctx);
310
311                egui::Window::new("Path Tracer")
312                    .collapsible(true)
313                    .resizable(true)
314                    .default_width(300.0)
315                    .show(ctx, |ui| {
316                        ui.label("Camera Controls:");
317                        ui.label("W/A/S/D - Movements");
318                        ui.label("Q/E - down/up");
319                        ui.label("Mouse - Look around");
320                        ui.label("Right Click - Toggle mouse look");
321                        ui.label("Space - Toggle progressive rendering");
322                        ui.separator();
323                        ShaderControls::render_media_panel(
324                            ui,
325                            &mut controls_request,
326                            using_video_texture,
327                            video_info,
328                            using_hdri_texture,
329                            hdri_info,
330                            using_webcam_texture,
331                            webcam_info,
332                        );
333                        ui.separator();
334
335                        egui::CollapsingHeader::new("Render Settings")
336                            .default_open(false)
337                            .show(ui, |ui| {
338                                let old_samples = params.samples_per_pixel;
339                                changed |= ui
340                                    .add(
341                                        egui::Slider::new(&mut params.samples_per_pixel, 1..=16)
342                                            .text("Samples/pixel"),
343                                    )
344                                    .changed();
345                                if params.samples_per_pixel != old_samples {
346                                    self.should_reset_accumulation = true;
347                                }
348
349                                let old_bounces = params.max_bounces;
350                                changed |= ui
351                                    .add(
352                                        egui::Slider::new(&mut params.max_bounces, 1..=16)
353                                            .text("Max Bounces"),
354                                    )
355                                    .changed();
356                                if params.max_bounces != old_bounces {
357                                    self.should_reset_accumulation = true;
358                                }
359
360                                let old_accumulate = params.accumulate;
361                                let mut accumulate_bool = params.accumulate > 0;
362                                changed |= ui
363                                    .checkbox(&mut accumulate_bool, "Progressive Rendering")
364                                    .changed();
365                                params.accumulate = if accumulate_bool { 1 } else { 0 };
366                                if params.accumulate != old_accumulate {
367                                    self.should_reset_accumulation = true;
368                                }
369
370                                changed |= ui
371                                    .add(
372                                        egui::Slider::new(&mut params.exposure, 0.1..=5.0)
373                                            .text("Exposure"),
374                                    )
375                                    .changed();
376                                changed |= ui
377                                    .add(
378                                        egui::Slider::new(&mut params.aperture, 0.0..=0.5)
379                                            .text("Depth of Field"),
380                                    )
381                                    .changed();
382                                changed |= ui
383                                    .add(
384                                        egui::Slider::new(&mut params.rotation_speed, 0.0..=2.0)
385                                            .text("Animation Speed"),
386                                    )
387                                    .changed();
388
389                                if ui.button("Reset Accumulation").clicked() {
390                                    self.should_reset_accumulation = true;
391                                    changed = true;
392                                }
393                            });
394
395                        ui.separator();
396                        ShaderControls::render_controls_widget(ui, &mut controls_request);
397                        ui.separator();
398                        should_start_export =
399                            ExportManager::render_export_ui_widget(ui, &mut export_request);
400                        ui.separator();
401                        ui.label(format!("Accumulated Samples: {}", self.frame_count));
402                        ui.label(format!(
403                            "Resolution: {}x{}",
404                            core.size.width, core.size.height
405                        ));
406                        ui.label(format!("FPS: {current_fps:.1}"));
407                    });
408            })
409        } else {
410            self.base.render_ui(core, |_ctx| {})
411        };
412
413        // Apply controls
414        self.base.export_manager.apply_ui_request(export_request);
415        if controls_request.should_clear_buffers || self.should_reset_accumulation {
416            self.clear_buffers(core);
417        }
418        self.base.apply_media_requests(core, &controls_request);
419
420        if should_start_export {
421            self.base.export_manager.start_export();
422        }
423
424        // Update mouse 
425        self.compute_shader
426            .update_mouse_uniform(&self.base.mouse_tracker.uniform, &core.queue);
427
428        if changed {
429            self.current_params = params;
430            self.compute_shader.set_custom_params(params, &core.queue);
431        }
432
433        // Set frame count for random number generation
434        self.compute_shader.time_uniform.data.frame = self.frame_count;
435        self.compute_shader.time_uniform.update(&core.queue);
436
437        // Single stage dispatch
438        self.compute_shader.dispatch(&mut frame.encoder, core);
439
440        self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
441
442        self.base.end_frame(core, frame, full_output);
443
444        // Increment frame count for progressive rendering and noise generation
445        if self.current_params.accumulate > 0 {
446            self.frame_count += 1;
447        } else {
448            self.frame_count = (self.frame_count + 1) % 1000;
449        }
450
451        Ok(())
452    }
453
454    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
455        if self
456            .base
457            .egui_state
458            .on_window_event(core.window(), event)
459            .consumed
460        {
461            return true;
462        }
463
464        if let WindowEvent::KeyboardInput { event, .. } = event {
465            if let winit::keyboard::Key::Character(ch) = &event.logical_key {
466                match ch.as_str() {
467                    "w" | "W" => {
468                        self.camera_movement.forward =
469                            event.state == winit::event::ElementState::Pressed;
470                        self.should_reset_accumulation = true;
471                        return true;
472                    }
473                    "s" | "S" => {
474                        self.camera_movement.backward =
475                            event.state == winit::event::ElementState::Pressed;
476                        self.should_reset_accumulation = true;
477                        return true;
478                    }
479                    "a" | "A" => {
480                        self.camera_movement.left =
481                            event.state == winit::event::ElementState::Pressed;
482                        self.should_reset_accumulation = true;
483                        return true;
484                    }
485                    "d" | "D" => {
486                        self.camera_movement.right =
487                            event.state == winit::event::ElementState::Pressed;
488                        self.should_reset_accumulation = true;
489                        return true;
490                    }
491                    "q" | "Q" => {
492                        self.camera_movement.down =
493                            event.state == winit::event::ElementState::Pressed;
494                        self.should_reset_accumulation = true;
495                        return true;
496                    }
497                    "e" | "E" => {
498                        self.camera_movement.up =
499                            event.state == winit::event::ElementState::Pressed;
500                        self.should_reset_accumulation = true;
501                        return true;
502                    }
503                    " " => {
504                        if event.state == winit::event::ElementState::Released {
505                            self.current_params.accumulate = 1 - self.current_params.accumulate;
506                            self.should_reset_accumulation = true;
507                            self.compute_shader
508                                .set_custom_params(self.current_params, &core.queue);
509                            return true;
510                        }
511                    }
512                    _ => {}
513                }
514            }
515        }
516
517        if let WindowEvent::CursorMoved { position, .. } = event {
518            let x = position.x as f32;
519            let y = position.y as f32;
520
521            self.base.handle_mouse_input(core, event, false);
522
523            if self.camera_movement.handle_mouse_movement(x, y) {
524                self.should_reset_accumulation = true;
525                return true;
526            }
527        }
528
529        if let WindowEvent::MouseInput { state, button, .. } = event {
530            if *button == winit::event::MouseButton::Right
531                && *state == winit::event::ElementState::Released
532            {
533                self.camera_movement.toggle_mouse_look();
534                return true;
535            }
536        }
537
538        if let WindowEvent::DroppedFile(path) = event {
539            if let Err(e) = self.base.load_media(core, path) {
540                error!("Failed to load dropped file: {e:?}");
541            }
542            return true;
543        }
544
545        if let WindowEvent::KeyboardInput { event, .. } = event {
546            if self
547                .base
548                .key_handler
549                .handle_keyboard_input(core.window(), event)
550            {
551                return true;
552            }
553        }
554
555        false
556    }
557}
558
559fn main() -> Result<(), Box<dyn std::error::Error>> {
560    env_logger::init();
561    cuneus::gst::init()?;
562    let (app, event_loop) = ShaderApp::new("Path Tracer", 800, 600);
563
564    app.run(event_loop, PathTracingShader::init)
565}