// fluidsim/fluidsim.rs

// Navier-Stokes Fluid Simulation
// Ported from Pavel Dobryakov's WebGL Fluid Simulation
// https://github.com/PavelDoGreat/WebGL-Fluid-Simulation (MIT License)
//
// This example demonstrates Cuneus's dispatch_stage() capability for running
// multiple simulation steps per frame, including 20+ pressure iterations.
// Note that this shader is storage-buffer based rather than texture based. That lets us loop a
// single dispatch stage (with explicit workgroup counts) 20 times for the pressure solve, instead
// of declaring 20 separate pressure passes (prs_01, prs_02, ..., prs_20).
9
10use cuneus::compute::*;
11use cuneus::prelude::*;
12
/// Per-axis ratio between the dye resolution and the coarser velocity/pressure grid.
const SIM_SCALE: u32 = 2;
/// Number of Jacobi iterations run each frame to solve for pressure.
const PRESSURE_ITERATIONS: u32 = 20;
/// Width, in cells, of the fixed internal (dye) resolution; independent of the window size.
const INTERNAL_WIDTH: u32 = 2048;
/// Height, in cells, of the fixed internal (dye) resolution; independent of the window size.
const INTERNAL_HEIGHT: u32 = 1152;
17
18cuneus::uniform_params! {
19    struct FluidParams {
20    sim_width: u32,
21    sim_height: u32,
22    display_width: u32,
23    display_height: u32,
24    dt: f32,
25    time: f32,
26    velocity_dissipation: f32,
27    density_dissipation: f32,
28    pressure: f32,
29    curl_strength: f32,
30    splat_radius: f32,
31    splat_x: f32,
32    splat_y: f32,
33    splat_dx: f32,
34    splat_dy: f32,
35    splat_force: f32,
36    splat_color_r: f32,
37    splat_color_g: f32,
38    splat_color_b: f32,
39    // ping-pong tracking for each field
40    vel_ping: u32, // which velocity buffer to READ from
41    prs_ping: u32, // which pressure buffer to READ from
42    dye_ping: u32, // which dye buffer to READ from
43    do_splat: u32,
44    _pad: u32}
45}
46
47struct FluidSim {
48    base: RenderKit,
49    compute_shader: ComputeShader,
50    params: FluidParams,
51    prev_mouse_pos: [f32; 2],
52    mouse_initialized: bool,
53    current_color: [f32; 3],
54    color_timer: f32,
55    last_time: std::time::Instant,
56    first_frame: bool,
57    needs_clear: bool}
58
59impl ShaderManager for FluidSim {
60    fn init(core: &Core) -> Self {
61        let base = RenderKit::new(core);
62
63
64        let sim_width = INTERNAL_WIDTH / SIM_SCALE;
65        let sim_height = INTERNAL_HEIGHT / SIM_SCALE;
66        let sim_cells = (sim_width * sim_height) as u64;
67        let dye_cells = (INTERNAL_WIDTH * INTERNAL_HEIGHT) as u64;
68
69        // all at fixed internal resolution
70        let velocity_size = sim_cells * 2 * 4;
71        let pressure_size = sim_cells * 4;
72        let divergence_size = sim_cells * 4;
73        let curl_size = sim_cells * 4;
74        let dye_size = dye_cells * 4 * 4;
75
76        let total_size = velocity_size * 2 + pressure_size * 2 + divergence_size + curl_size + dye_size * 2;
77
78        let passes = vec![
79            PassDescription::new("clear_buffers", &[]), // 0
80            PassDescription::new("splat_velocity", &[]), // 1
81            PassDescription::new("splat_dye", &[]), // 2
82            PassDescription::new("curl_compute", &[]), // 3
83            PassDescription::new("vorticity_apply", &[]), // 4
84            PassDescription::new("divergence_compute", &[]), // 5
85            PassDescription::new("pressure_clear", &[]), // 6
86            PassDescription::new("pressure_iterate", &[]), // 7
87            PassDescription::new("gradient_subtract", &[]), // 8
88            PassDescription::new("advect_velocity", &[]), // 9
89            PassDescription::new("advect_dye", &[]), // 10
90            PassDescription::new("main_image", &[]), // 11
91        ];
92
93        // Note: We don't use .with_mouse() because cuneus MouseTracker doesn't provide
94        // velocity (dx/dy). Fluid simulation needs velocity for force injection, so we
95        // track mouse manually and pass data through FluidParams.
96        let config = ComputeShader::builder()
97            .with_multi_pass(&passes)
98            .with_custom_uniforms::<FluidParams>()
99            .with_storage_buffer(StorageBufferSpec::new("fluid_data", total_size))
100            .with_label("Fluid Simulation")
101            .build();
102
103        let compute_shader = cuneus::compute_shader!(core, "shaders/fluidsim.wgsl", config);
104
105        let params = FluidParams {
106            sim_width,
107            sim_height,
108            display_width: INTERNAL_WIDTH,
109            display_height: INTERNAL_HEIGHT,
110            dt: 1.0 / 60.0,
111            time: 0.0,
112            velocity_dissipation: 0.2,
113            density_dissipation: 1.0,
114            pressure: 0.8,
115            curl_strength: 30.0,
116            splat_radius: 0.25,
117            splat_x: 0.0,
118            splat_y: 0.0,
119            splat_dx: 0.0,
120            splat_dy: 0.0,
121            splat_force: 6000.0,
122            splat_color_r: 0.0,
123            splat_color_g: 0.0,
124            splat_color_b: 0.0,
125            vel_ping: 0,
126            prs_ping: 0,
127            dye_ping: 0,
128            do_splat: 0,
129            _pad: 0};
130
131        compute_shader.set_custom_params(params, &core.queue);
132
133        Self {
134            base,
135            compute_shader,
136            params,
137            prev_mouse_pos: [0.5, 0.5],
138            mouse_initialized: false,
139            current_color: Self::generate_color(),
140            color_timer: 0.0,
141            last_time: std::time::Instant::now(),
142            first_frame: true,
143            needs_clear: true}
144    }
145
146    fn update(&mut self, core: &Core) {
147        let now = std::time::Instant::now();
148        let dt = now.duration_since(self.last_time).as_secs_f32();
149        self.last_time = now;
150        let dt = dt.min(1.0 / 30.0);
151
152        self.params.time += dt;
153        self.params.dt = dt;
154
155        self.color_timer += dt * 10.0;
156        if self.color_timer >= 1.0 {
157            self.color_timer = 0.0;
158            self.current_color = Self::generate_color();
159        }
160
161
162        let current_mouse_pos = self.base.mouse_tracker.uniform.position;
163        let mouse_down = (self.base.mouse_tracker.uniform.buttons[0] & 1) != 0; // Left button
164
165        if mouse_down {
166            if !self.mouse_initialized {
167                self.prev_mouse_pos = current_mouse_pos;
168                self.mouse_initialized = true;
169            }
170
171            let mut dx = current_mouse_pos[0] - self.prev_mouse_pos[0];
172            let mut dy = current_mouse_pos[1] - self.prev_mouse_pos[1];
173
174            let aspect = core.size.width as f32 / core.size.height as f32;
175            if aspect < 1.0 {
176                dx *= aspect;
177            } else {
178                dy /= aspect;
179            }
180
181            self.params.splat_x = current_mouse_pos[0];
182            self.params.splat_y = current_mouse_pos[1];
183            self.params.splat_dx = dx * self.params.splat_force;
184            self.params.splat_dy = dy * self.params.splat_force;
185            self.params.splat_color_r = self.current_color[0];
186            self.params.splat_color_g = self.current_color[1];
187            self.params.splat_color_b = self.current_color[2];
188            self.params.do_splat = if dx.abs() > 0.0001 || dy.abs() > 0.0001 { 1 } else { 0 };
189
190            self.prev_mouse_pos = current_mouse_pos;
191        } else if self.first_frame {
192            self.first_frame = false;
193            let color = Self::generate_color();
194            self.params.splat_x = 0.5;
195            self.params.splat_y = 0.5;
196            self.params.splat_dx = 500.0;
197            self.params.splat_dy = 300.0;
198            self.params.splat_color_r = color[0] * 10.0;
199            self.params.splat_color_g = color[1] * 10.0;
200            self.params.splat_color_b = color[2] * 10.0;
201            self.params.do_splat = 1;
202        } else {
203            self.params.do_splat = 0;
204            self.mouse_initialized = false;
205        }
206        self.compute_shader.handle_export(core, &mut self.base);
207    }
208
209    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
210        let mut frame = self.base.begin_frame(core)?;
211
212        // Update params
213        self.compute_shader.set_custom_params(self.params, &core.queue);
214
215        let sim_workgroups = [
216            self.params.sim_width.div_ceil(16),
217            self.params.sim_height.div_ceil(16),
218            1,
219        ];
220        let display_workgroups = [
221            self.params.display_width.div_ceil(16),
222            self.params.display_height.div_ceil(16),
223            1,
224        ];
225        let output_workgroups = [
226            core.size.width.div_ceil(16),
227            core.size.height.div_ceil(16),
228            1,
229        ];
230
231        // Stage indices
232        const CLEAR_BUFFERS: usize = 0;
233        const SPLAT_VELOCITY: usize = 1;
234        const SPLAT_DYE: usize = 2;
235        const CURL_COMPUTE: usize = 3;
236        const VORTICITY_APPLY: usize = 4;
237        const DIVERGENCE_COMPUTE: usize = 5;
238        const PRESSURE_CLEAR: usize = 6;
239        const PRESSURE_ITERATE: usize = 7;
240        const GRADIENT_SUBTRACT: usize = 8;
241        const ADVECT_VELOCITY: usize = 9;
242        const ADVECT_DYE: usize = 10;
243        const MAIN_IMAGE: usize = 11;
244
245
246        if self.needs_clear {
247            self.needs_clear = false;
248            let max_workgroups = [
249                self.params.display_width.div_ceil(16),
250                self.params.display_height.div_ceil(16),
251                1,
252            ];
253            self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, CLEAR_BUFFERS, max_workgroups);
254            frame.encoder = core.flush_encoder(frame.encoder);
255            self.compute_shader.set_custom_params(self.params, &core.queue);
256        }
257
258        // Apply splat (additive, in-place on current read buffer)
259        if self.params.do_splat == 1 {
260            self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, SPLAT_VELOCITY, sim_workgroups);
261            self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, SPLAT_DYE, display_workgroups);
262        }
263
264        // Curl: reads vel[vel_ping]
265        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, CURL_COMPUTE, sim_workgroups);
266
267        // Vorticity: reads vel[vel_ping], writes vel[1-vel_ping]
268        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, VORTICITY_APPLY, sim_workgroups);
269
270        // Submit before changing ping
271        frame.encoder = core.flush_encoder(frame.encoder);
272        self.params.vel_ping = 1 - self.params.vel_ping;
273        self.compute_shader.set_custom_params(self.params, &core.queue);
274
275        // Divergence: reads vel[vel_ping] (where vorticity wrote)
276        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, DIVERGENCE_COMPUTE, sim_workgroups);
277
278        // Pressure clear: reads prs[prs_ping], writes prs[1-prs_ping]
279        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, PRESSURE_CLEAR, sim_workgroups);
280
281        frame.encoder = core.flush_encoder(frame.encoder);
282        self.params.prs_ping = 1 - self.params.prs_ping;
283        self.compute_shader.set_custom_params(self.params, &core.queue);
284
285        // Jacobi solver
286        for _ in 0..PRESSURE_ITERATIONS {
287            self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, PRESSURE_ITERATE, sim_workgroups);
288            frame.encoder = core.flush_encoder(frame.encoder);
289            self.params.prs_ping = 1 - self.params.prs_ping;
290            self.compute_shader.set_custom_params(self.params, &core.queue);
291        }
292
293        // Gradient subtract: reads vel[vel_ping], prs[prs_ping], writes vel[1-vel_ping]
294        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, GRADIENT_SUBTRACT, sim_workgroups);
295
296        frame.encoder = core.flush_encoder(frame.encoder);
297        self.params.vel_ping = 1 - self.params.vel_ping;
298        self.compute_shader.set_custom_params(self.params, &core.queue);
299
300        // Advect velocity: reads vel[vel_ping], writes vel[1-vel_ping]
301        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, ADVECT_VELOCITY, sim_workgroups);
302
303        frame.encoder = core.flush_encoder(frame.encoder);
304        self.params.vel_ping = 1 - self.params.vel_ping;
305        self.compute_shader.set_custom_params(self.params, &core.queue);
306
307        // Advect dye: reads vel[vel_ping], dye[dye_ping], writes dye[1-dye_ping]
308        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, ADVECT_DYE, display_workgroups);
309
310        frame.encoder = core.flush_encoder(frame.encoder);
311        self.params.dye_ping = 1 - self.params.dye_ping;
312        self.compute_shader.set_custom_params(self.params, &core.queue);
313
314        // Display: reads dye[dye_ping]
315        self.compute_shader.dispatch_stage_with_workgroups(&mut frame.encoder, MAIN_IMAGE, output_workgroups);
316
317        self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
318
319        let mut params = self.params;
320        let mut should_clear = false;
321        let mut should_start_export = false;
322        let mut export_request = self.base.export_manager.get_ui_request();
323        let mut controls_request = self.base.controls.get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
324
325        let full_output = if self.base.key_handler.show_ui {
326            self.base.render_ui(core, |ctx| {
327                ctx.global_style_mut(|style| {
328                    style.visuals.window_fill = egui::Color32::from_rgba_premultiplied(0, 0, 0, 180);
329                    style.text_styles.get_mut(&egui::TextStyle::Body).unwrap().size = 11.0;
330                    style.text_styles.get_mut(&egui::TextStyle::Button).unwrap().size = 10.0;
331                });
332
333                egui::Window::new("Fluid Simulation")
334                    .collapsible(true)
335                    .resizable(true)
336                    .default_width(280.0)
337                    .show(ctx, |ui| {
338                        egui::CollapsingHeader::new("Fluid Parameters")
339                            .default_open(true)
340                            .show(ui, |ui| {
341                                ui.add(egui::Slider::new(&mut params.curl_strength, 0.0..=50.0).text("Vorticity"));
342                                ui.add(egui::Slider::new(&mut params.velocity_dissipation, 0.0..=4.0).text("Vel Dissipation"));
343                                ui.add(egui::Slider::new(&mut params.density_dissipation, 0.0..=4.0).text("Dye Dissipation"));
344                                ui.add(egui::Slider::new(&mut params.pressure, 0.0..=1.0).text("Pressure"));
345                            });
346
347                        egui::CollapsingHeader::new("Splat Settings")
348                            .default_open(false)
349                            .show(ui, |ui| {
350                                ui.add(egui::Slider::new(&mut params.splat_radius, 0.01..=1.0).text("Radius"));
351                                ui.add(egui::Slider::new(&mut params.splat_force, 1000.0..=20000.0).text("Force"));
352                            });
353
354                        ui.separator();
355                        ShaderControls::render_controls_widget(ui, &mut controls_request);
356
357                        ui.separator();
358                        should_start_export = ExportManager::render_export_ui_widget(ui, &mut export_request);
359
360                        ui.separator();
361                        if ui.button("Clear Fluid").clicked() {
362                            should_clear = true;
363                        }
364                        ui.label(format!("Internal: {}x{}", INTERNAL_WIDTH, INTERNAL_HEIGHT));
365                        ui.label("Drag mouse to add fluid");
366                    });
367            })
368        } else {
369            self.base.render_ui(core, |_| {})
370        };
371
372        if controls_request.should_clear_buffers || should_clear {
373            self.params.vel_ping = 0;
374            self.params.prs_ping = 0;
375            self.params.dye_ping = 0;
376            self.first_frame = true;
377            self.needs_clear = true;
378        }
379
380        // Apply UI changes
381        self.base.apply_control_request(controls_request);
382        self.base.export_manager.apply_ui_request(export_request);
383        self.params = params;
384
385        if should_start_export {
386            self.base.export_manager.start_export();
387        }
388
389        self.base.end_frame(core, frame, full_output);
390
391        Ok(())
392    }
393
394    fn resize(&mut self, core: &Core) {
395        self.base.default_resize(core, &mut self.compute_shader);
396    }
397
398    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
399        if self.base.egui_state.on_window_event(core.window(), event).consumed {
400            return true;
401        }
402        if self.base.handle_mouse_input(core, event, false) {
403            if let WindowEvent::MouseInput { state, button, .. } = event {
404                if *button == winit::event::MouseButton::Left && state.is_pressed() {
405                    self.current_color = Self::generate_color();
406                }
407            }
408            return true;
409        }
410
411        if let WindowEvent::KeyboardInput { event, .. } = event {
412            return self.base.key_handler.handle_keyboard_input(core.window(), event);
413        }
414
415        false
416    }
417}
418
419impl FluidSim {
420    fn generate_color() -> [f32; 3] {
421        use std::collections::hash_map::RandomState;
422        use std::hash::{BuildHasher, Hasher};
423        let state = RandomState::new();
424        let mut hasher = state.build_hasher();
425        hasher.write_u64(std::time::SystemTime::now()
426            .duration_since(std::time::UNIX_EPOCH)
427            .unwrap()
428            .as_nanos() as u64);
429        let h = (hasher.finish() as f32) / (u64::MAX as f32);
430
431        let i = (h * 6.0).floor() as i32;
432        let f = h * 6.0 - i as f32;
433        let q = 1.0 - f;
434        let t = f;
435
436        let (r, g, b) = match i % 6 {
437            0 => (1.0, t, 0.0),
438            1 => (q, 1.0, 0.0),
439            2 => (0.0, 1.0, t),
440            3 => (0.0, q, 1.0),
441            4 => (t, 0.0, 1.0),
442            _ => (1.0, 0.0, q)};
443
444        [r * 0.2, g * 0.2, b * 0.2]
445    }
446}
447
448fn main() -> Result<(), Box<dyn std::error::Error>> {
449    env_logger::init();
450    let (app, event_loop) = ShaderApp::new("Fluid Sim", 1024, 768);
451    app.run(event_loop, FluidSim::init)
452}