// fft/fft.rs

1use cuneus::compute::{
2    ComputeShader, PassDescription, StorageBufferSpec, COMPUTE_TEXTURE_FORMAT_RGBA16};
3use cuneus::{Core, ExportManager, RenderKit, ShaderControls, ShaderManager};
4use log::error;
5use cuneus::WindowEvent;
6
7cuneus::uniform_params! {
8    struct FFTParams {
9    filter_type: i32,
10    filter_strength: f32,
11    filter_direction: f32,
12    filter_radius: f32,
13    show_freqs: i32,
14    resolution: u32,
15    is_bw: i32,
16    _padding: u32}
17}
18
19struct FFTShader {
20    base: RenderKit,
21    compute_shader: ComputeShader,
22    should_initialize: bool,
23    current_params: FFTParams, // Store current parameters
24}
25
26impl ShaderManager for FFTShader {
27    fn init(core: &Core) -> Self {
28        let initial_params = FFTParams {
29            filter_type: 1,
30            filter_strength: 0.3,
31            filter_direction: 0.0,
32            filter_radius: 3.0,
33            show_freqs: 0,
34            resolution: 1024,
35            is_bw: 0,
36            _padding: 0};
37        let base = RenderKit::new(core);
38
39        // Define the FFT multi-pass pipeline
40        let passes = vec![
41            PassDescription::new("initialize_data", &[]), // Stage 0: Initialize from input texture
42            PassDescription::new("fft_horizontal", &["initialize_data"]), // Stage 1: FFT horizontal pass
43            PassDescription::new("fft_vertical", &["fft_horizontal"]), // Stage 2: FFT vertical pass
44            PassDescription::new("modify_frequencies", &["fft_vertical"]), // Stage 3: Apply frequency domain filters
45            PassDescription::new("ifft_horizontal", &["modify_frequencies"]), // Stage 4: Inverse FFT horizontal
46            PassDescription::new("ifft_vertical", &["ifft_horizontal"]), // Stage 5: Inverse FFT vertical
47            PassDescription::new("main_image", &["ifft_vertical"]),      // Stage 6: Final display
48        ];
49
50        let config = ComputeShader::builder()
51            .with_entry_point("initialize_data") // Start with data initialization
52            .with_multi_pass(&passes)
53            .with_input_texture() // Re-enable input texture support
54            .with_custom_uniforms::<FFTParams>()
55            .with_storage_buffer(StorageBufferSpec::new("image_data", 2048 * 2048 * 3 * 8)) // FFT working memory: max res to avoid crash
56            .with_workgroup_size([16, 16, 1])
57            .with_texture_format(COMPUTE_TEXTURE_FORMAT_RGBA16)
58            .with_label("FFT Multi-Pass")
59            .build();
60
61        let compute_shader = cuneus::compute_shader!(core, "shaders/fft.wgsl", config);
62
63        // Initialize custom uniform with initial parameters
64        compute_shader.set_custom_params(initial_params, &core.queue);
65
66        Self {
67            base,
68            compute_shader,
69            should_initialize: true,
70            current_params: initial_params}
71    }
72
73    fn update(&mut self, core: &Core) {
74        // Update time
75        let current_time = self.base.controls.get_time(&self.base.start_time);
76        let delta = 1.0 / 60.0;
77        self.compute_shader
78            .set_time(current_time, delta, &core.queue);
79
80        // Update input textures for image proc.
81        self.base.update_current_texture(core, &core.queue);
82        if let Some(texture_manager) = self.base.get_current_texture_manager() {
83            // Update input texture in unified ComputeShader
84            self.compute_shader.update_input_texture(
85                &texture_manager.view,
86                &texture_manager.sampler,
87                &core.device,
88            );
89        }
90        // Handle export
91        self.compute_shader.handle_export(core, &mut self.base);
92    }
93
94    fn resize(&mut self, core: &Core) {
95        self.compute_shader
96            .resize(core, core.size.width, core.size.height);
97    }
98
99    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
100        let mut frame = self.base.begin_frame(core)?;
101
102        // Handle UI and controls - using original transparent UI design
103        let mut params = self.current_params;
104        let mut changed = false;
105        let mut should_start_export = false;
106        let mut export_request = self.base.export_manager.get_ui_request();
107        let mut controls_request = self
108            .base
109            .controls
110            .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
111
112        let using_video_texture = self.base.using_video_texture;
113        let using_hdri_texture = self.base.using_hdri_texture;
114        let using_webcam_texture = self.base.using_webcam_texture;
115        let video_info = self.base.get_video_info();
116        let hdri_info = self.base.get_hdri_info();
117        let webcam_info = self.base.get_webcam_info();
118        let full_output = if self.base.key_handler.show_ui {
119            self.base.render_ui(core, |ctx| {
120                RenderKit::apply_default_style(ctx);
121
122                egui::Window::new("fourier workflow")
123                    .collapsible(true)
124                    .resizable(true)
125                    .default_width(250.0)
126                    .show(ctx, |ui| {
127                        // Media controls
128                        ShaderControls::render_media_panel(
129                            ui,
130                            &mut controls_request,
131                            using_video_texture,
132                            video_info,
133                            using_hdri_texture,
134                            hdri_info,
135                            using_webcam_texture,
136                            webcam_info,
137                        );
138
139                        ui.separator();
140
141                        egui::CollapsingHeader::new("FFT Settings")
142                            .default_open(false)
143                            .show(ui, |ui| {
144                                ui.label("Resolution:");
145
146                                ui.horizontal(|ui| {
147                                    changed |= ui
148                                        .radio_value(&mut params.resolution, 256, "256")
149                                        .changed();
150                                    changed |= ui
151                                        .radio_value(&mut params.resolution, 512, "512")
152                                        .changed();
153                                    changed |= ui
154                                        .radio_value(&mut params.resolution, 1024, "1024")
155                                        .changed();
156                                    changed |= ui
157                                        .radio_value(&mut params.resolution, 2048, "2048")
158                                        .changed();
159                                });
160
161                                if changed {
162                                    self.should_initialize = true;
163                                }
164
165                                ui.separator();
166                                ui.label("View Mode:");
167                                changed |= ui
168                                    .radio_value(&mut params.show_freqs, 0, "Filtered")
169                                    .changed();
170                                changed |= ui
171                                    .radio_value(&mut params.show_freqs, 1, "Frequency Domain")
172                                    .changed();
173                                
174                                let mut is_bw_bool = params.is_bw != 0;
175                                if ui.checkbox(&mut is_bw_bool, "Black & White").changed() {
176                                    params.is_bw = if is_bw_bool { 1 } else { 0 };
177                                    changed = true;
178                                }
179
180                                ui.separator();
181                            });
182
183                        egui::CollapsingHeader::new("Filter Settings")
184                            .default_open(false)
185                            .show(ui, |ui| {
186                                ui.label("Filter Type:");
187                                // Keep the improved ComboBox as requested
188                                changed |= egui::ComboBox::from_label("")
189                                    .selected_text(match params.filter_type {
190                                        0 => "LP",
191                                        1 => "HP",
192                                        2 => "BP",
193                                        3 => "Directional",
194                                        _ => "None"})
195                                    .show_ui(ui, |ui| {
196                                        ui.selectable_value(&mut params.filter_type, 0, "LP")
197                                            .changed()
198                                            || ui
199                                                .selectable_value(&mut params.filter_type, 1, "HP")
200                                                .changed()
201                                            || ui
202                                                .selectable_value(&mut params.filter_type, 2, "BP")
203                                                .changed()
204                                            || ui
205                                                .selectable_value(
206                                                    &mut params.filter_type,
207                                                    3,
208                                                    "Directional",
209                                                )
210                                                .changed()
211                                    })
212                                    .inner
213                                    .unwrap_or(false);
214
215                                ui.separator();
216
217                                changed |= ui
218                                    .add(
219                                        egui::Slider::new(&mut params.filter_strength, 0.0..=1.0)
220                                            .text("Filter Strength"),
221                                    )
222                                    .changed();
223
224                                if params.filter_type == 2 {
225                                    changed |= ui
226                                        .add(
227                                            egui::Slider::new(
228                                                &mut params.filter_radius,
229                                                0.0..=6.28,
230                                            )
231                                            .text("Band Radius"),
232                                        )
233                                        .changed();
234                                }
235
236                                if params.filter_type == 3 {
237                                    changed |= ui
238                                        .add(
239                                            egui::Slider::new(
240                                                &mut params.filter_direction,
241                                                0.0..=6.28,
242                                            )
243                                            .text("Direction"),
244                                        )
245                                        .changed();
246                                }
247                            });
248
249                        ui.separator();
250
251                        ShaderControls::render_controls_widget(ui, &mut controls_request);
252
253                        ui.separator();
254
255                        should_start_export =
256                            ExportManager::render_export_ui_widget(ui, &mut export_request);
257                    });
258            })
259        } else {
260            self.base.render_ui(core, |_ctx| {})
261        };
262
263        // Keep current parameters - don't reset to defaults
264        // The UI will modify 'params' directly, and we'll apply changes at the end
265
266        // Apply controls
267        self.base.apply_media_requests(core, &controls_request);
268
269        // Handle export requests
270        self.base.export_manager.apply_ui_request(export_request);
271        if should_start_export {
272            self.base.export_manager.start_export();
273        }
274
275        if controls_request.load_media_path.is_some() {
276            self.should_initialize = true;
277        }
278        if controls_request.start_webcam {
279            self.should_initialize = true;
280        }
281
282        // Apply parameter changes
283        if changed {
284            self.current_params = params;
285            self.compute_shader.set_custom_params(params, &core.queue);
286            self.should_initialize = true; // Trigger FFT reprocessing
287        }
288
289        // FFT dispatch - only run full pipeline when needed, otherwise just display
290        let mut should_run_full_fft = self.should_initialize
291            || self.base.using_video_texture
292            || self.base.using_webcam_texture
293            || changed; // Also run when parameters change
294
295        // FORCE run FFT if there's any texture to debug the issue
296        let has_any_texture = self.base.get_current_texture_manager().is_some();
297        if has_any_texture && !should_run_full_fft {
298            should_run_full_fft = true;
299        }
300        // Get FFT resolution for proper workgroup calculation
301        let n = params.resolution;
302        if should_run_full_fft {
303            // Stage 0: Initialize data from input texture (16x16 workgroups)
304            self.compute_shader.dispatch_stage_with_workgroups(
305                &mut frame.encoder,
306                0,
307                [n.div_ceil(16), n.div_ceil(16), 1],
308            );
309
310            // Stage 1: FFT horizontal (Nx1 workgroups)
311            self.compute_shader
312                .dispatch_stage_with_workgroups(&mut frame.encoder, 1, [n, 1, 1]);
313
314            // Stage 2: FFT vertical (Nx1 workgroups)
315            self.compute_shader
316                .dispatch_stage_with_workgroups(&mut frame.encoder, 2, [n, 1, 1]);
317
318            // Stage 3: Modify frequencies - apply filter (16x16 workgroups)
319            self.compute_shader.dispatch_stage_with_workgroups(
320                &mut frame.encoder,
321                3,
322                [n.div_ceil(16), n.div_ceil(16), 1],
323            );
324
325            if params.show_freqs == 0 {
326                // Stage 4: Inverse FFT horizontal
327                self.compute_shader
328                    .dispatch_stage_with_workgroups(&mut frame.encoder, 4, [n, 1, 1]);
329
330                // Stage 5: Inverse FFT vertical
331                self.compute_shader
332                    .dispatch_stage_with_workgroups(&mut frame.encoder, 5, [n, 1, 1]);
333            }
334
335            self.should_initialize = false;
336            log::info!("Completed full FFT pipeline");
337        } else {
338            log::debug!("Skipping full FFT pipeline - using cached result");
339        }
340
341        // Stage 6: Main rendering - always run for display (uses screen size)
342        self.compute_shader.dispatch_stage(&mut frame.encoder, core, 6);
343
344        self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
345
346        self.base.end_frame(core, frame, full_output);
347
348        Ok(())
349    }
350
351    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
352        if self.base.default_handle_input(core, event) {
353            return true;
354        }
355        if let WindowEvent::DroppedFile(path) = event {
356            if let Err(e) = self.base.load_media(core, path) {
357                error!("Failed to load dropped file: {e:?}");
358            } else {
359                self.should_initialize = true;
360            }
361            return true;
362        }
363        false
364    }
365}
366
367fn main() -> Result<(), Box<dyn std::error::Error>> {
368    cuneus::gst::init()?;
369    env_logger::init();
370    let (app, event_loop) = cuneus::ShaderApp::new("FFT", 800, 600);
371    app.run(event_loop, FFTShader::init)
372}