Skip to main content

synth/
synth.rs

1use cuneus::audio::PcmStreamManager;
2use cuneus::compute::*;
3use cuneus::prelude::*;
4use log::error;
5
/// Upper bound on the number of samples requested from the compute shader per frame.
/// The audio buffer is sized to twice this value (see `init`).
const MAX_SAMPLES_PER_FRAME: u32 = 1024;
/// Sample rate in samples per second, used for the PCM stream, the `sample_rate` uniform and
/// the elapsed-time -> target-sample conversion in `update`.
const SAMPLE_RATE: u32 = 44100;
8
// Uniform block shared with the synth compute shader: the type is registered through
// `with_custom_uniforms::<SynthParams>()` and values are uploaded with `set_custom_params`.
// NOTE(review): the field order presumably has to match the uniform struct declared in
// `shaders/synth.wgsl` — confirm before reordering or inserting fields.
cuneus::uniform_params! {
    struct SynthParams {
        // --- "Playback" section of the UI ---
        // "Tempo" slider, 60..=180 (presumably BPM — confirm in the shader).
        tempo: f32,
        // Index of the selected waveform button: 0 Sin, 1 Saw, 2 Sqr, 3 Tri, 4 Nse.
        waveform_type: u32,
        // "Octave" slider, 2..=7.
        octave: f32,
        // "Volume" slider, 0..=1; also applied as the PCM stream's master volume in `update`.
        volume: f32,
        // "Background Beat" checkbox, stored as 0 or 1.
        beat_enabled: u32,
        // --- "Effects" / "Filter" sections of the UI ---
        // "Reverb" slider, 0..=0.8.
        reverb_mix: f32,
        // "Delay" slider, 0.01..=1.
        delay_time: f32,
        // "Feedback" slider, 0..=0.8.
        delay_feedback: f32,
        // "Cutoff" slider, 0..=1.
        filter_cutoff: f32,
        // "Resonance" slider, 0..=0.9.
        filter_resonance: f32,
        // "Distortion" slider, 0..=0.9.
        distortion_amount: f32,
        // "Chorus Rate" slider, 0.1..=10.
        chorus_rate: f32,
        // "Chorus Depth" slider, 0..=0.5.
        chorus_depth: f32,
        // --- "Envelope (ADSR)" section; the time sliders are shown with an "s" suffix ---
        attack_time: f32,
        decay_time: f32,
        sustain_level: f32,
        release_time: f32,
        // --- Audio bookkeeping, rewritten by `update` every frame (not edited from the UI) ---
        // Number of samples the PCM stream has already received (truncated to 32 bits).
        sample_offset: u32,
        // How many samples the shader should produce this frame (<= MAX_SAMPLES_PER_FRAME).
        samples_to_generate: u32,
        // Set once to SAMPLE_RATE in `init`.
        sample_rate: u32,
        // --- Per-key timing; key i (0..9) lives at [i / 4][i % 4], the remaining slots stay 0 ---
        // Note-on time (>0 = pressed at this time, 0 = never pressed).
        key_states: [[f32; 4]; 3],
        // Release time (>0 = released at this time, 0 = still held or idle).
        key_decay: [[f32; 4]; 3],
    }
}
35
/// State of the synth example: render kit, compute shader, a CPU-side copy of the uniforms
/// and the PCM output stream.
struct SynthManager {
    /// Shared render/UI helper (controls, key handler, egui state, renderer, fps tracker).
    base: RenderKit,
    /// Compute shader built from `shaders/synth.wgsl`.
    compute_shader: ComputeShader,
    /// CPU-side copy of the uniforms; uploaded to the shader with `set_custom_params`.
    current_params: SynthParams,
    /// PCM output stream; `None` when creating or starting it failed in `init`.
    pcm_stream: Option<PcmStreamManager>,
    /// Whether each of the keys 1-9 is currently held; used to ignore repeated press events.
    keys_held: [bool; 9],
    /// Instant captured in `init`; `update` derives the target sample count from its elapsed time.
    audio_start: std::time::Instant,
    /// `samples_to_generate` requested by the previous `update`, i.e. how many samples to read
    /// back from the audio buffer on the next one.
    last_samples_generated: u32,
}
45
46impl SynthManager {
47    /// key_states stores note-on time (>0 = pressed at this time, 0 = never pressed)
48    fn set_key_press_time(&mut self, key_index: usize, time: f32) {
49        if key_index < 9 {
50            self.current_params.key_states[key_index / 4][key_index % 4] = time;
51        }
52    }
53
54    /// key_decay stores release time (>0 = released at this time, 0 = still held or idle)
55    fn set_key_release_time(&mut self, key_index: usize, time: f32) {
56        if key_index < 9 {
57            self.current_params.key_decay[key_index / 4][key_index % 4] = time;
58        }
59    }
60}
61
62impl ShaderManager for SynthManager {
63    fn init(core: &Core) -> Self {
64        let base = RenderKit::new(core);
65
66        let initial_params = SynthParams {
67            tempo: 120.0,
68            waveform_type: 1,
69            octave: 4.0,
70            volume: 0.7,
71            beat_enabled: 0,
72            reverb_mix: 0.15,
73            delay_time: 0.3,
74            delay_feedback: 0.3,
75            filter_cutoff: 0.9,
76            filter_resonance: 0.1,
77            distortion_amount: 0.0,
78            chorus_rate: 2.0,
79            chorus_depth: 0.1,
80            attack_time: 0.02,
81            decay_time: 0.15,
82            sustain_level: 0.7,
83            release_time: 0.4,
84            sample_offset: 0,
85            samples_to_generate: MAX_SAMPLES_PER_FRAME,
86            sample_rate: SAMPLE_RATE,
87            key_states: [[0.0; 4]; 3],
88            key_decay: [[0.0; 4]; 3],
89        };
90
91        let audio_buffer_size = (MAX_SAMPLES_PER_FRAME * 2) as usize;
92
93        let config = ComputeShader::builder()
94            .with_entry_point("main")
95            .with_custom_uniforms::<SynthParams>()
96            .with_audio(audio_buffer_size)
97            .with_workgroup_size([16, 16, 1])
98            .with_texture_format(COMPUTE_TEXTURE_FORMAT_RGBA16)
99            .with_label("Synth")
100            .build();
101
102        let compute_shader = cuneus::compute_shader!(core, "shaders/synth.wgsl", config);
103        compute_shader.set_custom_params(initial_params, &core.queue);
104
105        let pcm_stream = match PcmStreamManager::new(Some(SAMPLE_RATE)) {
106            Ok(mut stream) => {
107                if let Err(e) = stream.start() {
108                    error!("Failed to start PCM stream: {e}");
109                    None
110                } else {
111                    Some(stream)
112                }
113            }
114            Err(e) => {
115                error!("Failed to create PCM stream: {e}");
116                None
117            }
118        };
119
120        Self {
121            base,
122            compute_shader,
123            current_params: initial_params,
124            pcm_stream,
125            keys_held: [false; 9],
126            audio_start: std::time::Instant::now(),
127            last_samples_generated: 0,
128        }
129    }
130
131    fn update(&mut self, core: &Core) {
132        let current_time = self.base.controls.get_time(&self.base.start_time);
133        let delta = 1.0 / 60.0;
134        self.compute_shader
135            .set_time(current_time, delta, &core.queue);
136
137        if let Some(ref mut stream) = self.pcm_stream {
138            stream.set_master_volume(self.current_params.volume as f64);
139
140            // Push previous frame's audio first
141            let prev = self.last_samples_generated;
142            if prev > 0 {
143                if let Ok(audio_data) = pollster::block_on(
144                    self.compute_shader
145                        .read_audio_buffer(&core.device, &core.queue),
146                ) {
147                    let count = (prev * 2) as usize;
148                    if audio_data.len() >= count {
149                        let _ = stream.push_samples(&audio_data[..count]);
150                    }
151                }
152            }
153
154            // Calculate this frame's needs
155            let elapsed = self.audio_start.elapsed().as_secs_f64();
156            let target_samples = (elapsed * SAMPLE_RATE as f64) as u64;
157            let written = stream.samples_written();
158            let needed = (target_samples.saturating_sub(written) as u32).min(MAX_SAMPLES_PER_FRAME);
159            self.current_params.sample_offset = written as u32;
160            self.current_params.samples_to_generate = needed;
161            self.last_samples_generated = needed;
162        }
163        self.compute_shader
164            .set_custom_params(self.current_params, &core.queue);
165    }
166
    /// Handles a window resize by delegating to `RenderKit::default_resize`, passing the
    /// compute shader along.
    fn resize(&mut self, core: &Core) {
        self.base.default_resize(core, &mut self.compute_shader);
    }
170
171    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
172        let mut frame = self.base.begin_frame(core)?;
173
174        let mut params = self.current_params;
175        let mut changed = false;
176        let mut controls_request = self
177            .base
178            .controls
179            .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
180
181        let full_output = if self.base.key_handler.show_ui {
182            self.base.render_ui(core, |ctx| {
183                RenderKit::apply_default_style(ctx);
184
185                egui::Window::new("GPU Synth")
186                    .collapsible(true)
187                    .resizable(true)
188                    .default_width(280.0)
189                    .show(ctx, |ui| {
190                        egui::CollapsingHeader::new("Playback")
191                            .default_open(true)
192                            .show(ui, |ui| {
193                                ui.label("Keys 1-9: C D E F G A B C D");
194
195                                let mut beat_enabled = params.beat_enabled > 0;
196                                if ui.checkbox(&mut beat_enabled, "Background Beat").changed() {
197                                    params.beat_enabled = u32::from(beat_enabled);
198                                    changed = true;
199                                }
200
201                                changed |= ui
202                                    .add(egui::Slider::new(&mut params.tempo, 60.0..=180.0).text("Tempo"))
203                                    .changed();
204                                changed |= ui
205                                    .add(egui::Slider::new(&mut params.octave, 2.0..=7.0).text("Octave"))
206                                    .changed();
207                                changed |= ui
208                                    .add(egui::Slider::new(&mut params.volume, 0.0..=1.0).text("Volume"))
209                                    .changed();
210
211                                ui.horizontal(|ui| {
212                                    ui.label("Wave:");
213                                    for (i, name) in ["Sin", "Saw", "Sqr", "Tri", "Nse"].iter().enumerate() {
214                                        if ui.selectable_label(params.waveform_type == i as u32, *name).clicked() {
215                                            params.waveform_type = i as u32;
216                                            changed = true;
217                                        }
218                                    }
219                                });
220                            });
221
222                        egui::CollapsingHeader::new("Envelope (ADSR)")
223                            .default_open(true)
224                            .show(ui, |ui| {
225                                changed |= ui.add(egui::Slider::new(&mut params.attack_time, 0.001..=0.5).logarithmic(true).text("Attack").suffix("s")).changed();
226                                changed |= ui.add(egui::Slider::new(&mut params.decay_time, 0.01..=1.0).logarithmic(true).text("Decay").suffix("s")).changed();
227                                changed |= ui.add(egui::Slider::new(&mut params.sustain_level, 0.0..=1.0).text("Sustain")).changed();
228                                changed |= ui.add(egui::Slider::new(&mut params.release_time, 0.01..=2.0).logarithmic(true).text("Release").suffix("s")).changed();
229
230                                ui.separator();
231                                if ui.small_button("Piano").clicked() {
232                                    params.attack_time = 0.01; params.decay_time = 0.3;
233                                    params.sustain_level = 0.5; params.release_time = 0.8;
234                                    changed = true;
235                                }
236                                ui.horizontal(|ui| {
237                                    if ui.small_button("Pad").clicked() {
238                                        params.attack_time = 0.2; params.decay_time = 0.5;
239                                        params.sustain_level = 0.8; params.release_time = 1.5;
240                                        changed = true;
241                                    }
242                                    if ui.small_button("Pluck").clicked() {
243                                        params.attack_time = 0.005; params.decay_time = 0.1;
244                                        params.sustain_level = 0.3; params.release_time = 0.2;
245                                        changed = true;
246                                    }
247                                });
248                            });
249
250                        egui::CollapsingHeader::new("Filter")
251                            .default_open(false)
252                            .show(ui, |ui| {
253                                changed |= ui.add(egui::Slider::new(&mut params.filter_cutoff, 0.0..=1.0).text("Cutoff")).changed();
254                                changed |= ui.add(egui::Slider::new(&mut params.filter_resonance, 0.0..=0.9).text("Resonance")).changed();
255                            });
256
257                        egui::CollapsingHeader::new("Effects")
258                            .default_open(false)
259                            .show(ui, |ui| {
260                                changed |= ui.add(egui::Slider::new(&mut params.reverb_mix, 0.0..=0.8).text("Reverb")).changed();
261                                changed |= ui.add(egui::Slider::new(&mut params.delay_time, 0.01..=1.0).text("Delay")).changed();
262                                changed |= ui.add(egui::Slider::new(&mut params.delay_feedback, 0.0..=0.8).text("Feedback")).changed();
263                                changed |= ui.add(egui::Slider::new(&mut params.distortion_amount, 0.0..=0.9).text("Distortion")).changed();
264                                ui.separator();
265                                changed |= ui.add(egui::Slider::new(&mut params.chorus_rate, 0.1..=10.0).text("Chorus Rate")).changed();
266                                changed |= ui.add(egui::Slider::new(&mut params.chorus_depth, 0.0..=0.5).text("Chorus Depth")).changed();
267                            });
268
269                        ui.separator();
270                        ShaderControls::render_controls_widget(ui, &mut controls_request);
271                    });
272            })
273        } else {
274            self.base.render_ui(core, |_ctx| {})
275        };
276
277        if changed {
278            // Preserve audio fields that are managed by update()
279            params.sample_offset = self.current_params.sample_offset;
280            params.samples_to_generate = self.current_params.samples_to_generate;
281            params.sample_rate = self.current_params.sample_rate;
282            params.key_states = self.current_params.key_states;
283            params.key_decay = self.current_params.key_decay;
284            self.current_params = params;
285        }
286
287        self.base.apply_control_request(controls_request);
288
289        self.compute_shader.dispatch(&mut frame.encoder, core);
290
291        self.base.renderer.render_to_view(
292            &mut frame.encoder,
293            &frame.view,
294            &self.compute_shader.get_output_texture().bind_group,
295        );
296
297        self.base.end_frame(core, frame, full_output);
298
299        Ok(())
300    }
301
302    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
303        if self
304            .base
305            .egui_state
306            .on_window_event(core.window(), event)
307            .consumed
308        {
309            return true;
310        }
311
312        if let WindowEvent::KeyboardInput { event, .. } = event {
313            if let winit::keyboard::Key::Character(ref s) = event.logical_key {
314                if let Some(key_index) = s.chars().next().and_then(|c| c.to_digit(10)) {
315                    if (1..=9).contains(&key_index) {
316                        let index = (key_index - 1) as usize;
317
318                        let current_time = self.base.controls.get_time(&self.base.start_time);
319                        if event.state == winit::event::ElementState::Pressed && !self.keys_held[index] {
320                            self.keys_held[index] = true;
321                            let has_previous = self.current_params.key_states[index / 4][index % 4] > 0.0;
322                            let in_release = self.current_params.key_decay[index / 4][index % 4] > 0.0;
323                            if has_previous && in_release {
324                                // Retrigger: just cancel the release, note continues from current level
325                                self.set_key_release_time(index, 0.0);
326                            } else {
327                                // Fresh note
328                                self.set_key_press_time(index, current_time);
329                                self.set_key_release_time(index, 0.0);
330                            }
331                            self.compute_shader
332                                .set_custom_params(self.current_params, &core.queue);
333                        } else if event.state == winit::event::ElementState::Released {
334                            self.keys_held[index] = false;
335                            // Store release time — shader ADSR handles the fade
336                            self.set_key_release_time(index, current_time);
337                            self.compute_shader
338                                .set_custom_params(self.current_params, &core.queue);
339                        }
340                        return true;
341                    }
342                }
343            }
344            return self
345                .base
346                .key_handler
347                .handle_keyboard_input(core.window(), event);
348        }
349
350        false
351    }
352}
353
354fn main() -> Result<(), Box<dyn std::error::Error>> {
355    env_logger::init();
356    cuneus::gst::init()?;
357
358    let (app, event_loop) = ShaderApp::new("Synth", 800, 600);
359    app.run(event_loop, SynthManager::init)
360}