use super::device::GpuDevice;
use crate::synthesis::waveform::Waveform;
use crate::track::NoteEvent;
use anyhow::{Context, Result};
use wgpu::util::DeviceExt;
/// Renders notes on the GPU with the compute shader embedded from `synthesis.wgsl`.
///
/// The pipeline objects are built once in `GpuSynthesizer::new`; every
/// `GpuSynthesizer::synthesize_note` call creates its own per-note buffers and bind group.
pub struct GpuSynthesizer {
// Wrapper exposing the wgpu `device` and `queue` used for all GPU work.
device: GpuDevice,
// Compute pipeline for the shader's `main` entry point.
compute_pipeline: wgpu::ComputePipeline,
// Layout of bind group 0: binding 0 = note params (read-only storage),
// binding 1 = output samples (read-write storage).
bind_group_layout: wgpu::BindGroupLayout,
}
impl std::fmt::Debug for GpuSynthesizer {
    /// Shows the wrapped device and marks the remaining wgpu pipeline objects as omitted
    /// (`..`), so the output does not suggest `device` is the only field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuSynthesizer")
            .field("device", &self.device)
            .finish_non_exhaustive()
    }
}
/// Per-note parameters uploaded byte-for-byte into the shader's read-only storage buffer
/// (binding 0).
///
/// The field order and widths form the host side of the layout contract with
/// `synthesis.wgsl`. NOTE(review): the WGSL struct is not visible here. When changing this
/// struct, update the shader and the size assertion below together.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct GpuNoteParams {
    /// Oscillator frequency (the note's first frequency only).
    frequency: f32,
    /// Note duration as stored on the `NoteEvent`.
    duration: f32,
    /// Output sample rate.
    sample_rate: f32,
    /// Waveform id: 0 = sine, 1 = sawtooth, 2 = square, 3 = triangle.
    waveform: u32,
    /// Envelope attack parameter.
    attack: f32,
    /// Envelope decay parameter.
    decay: f32,
    /// Envelope sustain parameter.
    sustain: f32,
    /// Envelope release parameter.
    release: f32,
    /// 1 when FM is enabled (modulation index > 0), otherwise 0.
    fm_enabled: u32,
    /// FM modulator frequency ratio.
    fm_mod_ratio: f32,
    /// FM modulation index.
    fm_mod_index: f32,
    /// Note velocity.
    velocity: f32,
    /// Always written as 0. NOTE(review): presumably mirrors a padding member of the
    /// shader-side struct; confirm against `synthesis.wgsl`.
    _padding: u32,
}

// Compile-time guard on the host/shader layout: 13 tightly packed 4-byte fields. Adding or
// widening a field fails the build here instead of silently desynchronizing the GPU upload.
const _: () = assert!(std::mem::size_of::<GpuNoteParams>() == 13 * 4);
const _: () = assert!(std::mem::align_of::<GpuNoteParams>() == 4);
// SAFETY: `GpuNoteParams` is `#[repr(C)]`, `Copy`, and made only of 4-byte `f32`/`u32`
// fields, so it has no padding bytes and every bit pattern is a valid value.
unsafe impl bytemuck::Pod for GpuNoteParams {}
// SAFETY: every field is `f32` or `u32`, for which all-zero bits are valid (0.0 / 0).
unsafe impl bytemuck::Zeroable for GpuNoteParams {}
impl GpuSynthesizer {
pub fn new(device: GpuDevice) -> Result<Self> {
let shader_source = include_str!("synthesis.wgsl");
let shader = device
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Synthesis Shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout =
device
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Synthesis Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout =
device
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Synthesis Pipeline Layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let compute_pipeline =
device
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Synthesis Pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
device,
compute_pipeline,
bind_group_layout,
})
}
pub fn synthesize_note(&self, note: &NoteEvent, sample_rate: f32) -> Result<Vec<f32>> {
let gpu_params = self.note_to_gpu_params(note, sample_rate);
let total_duration = note.envelope.total_duration(note.duration);
let total_samples = (total_duration * sample_rate) as usize;
let params_buffer =
self.device
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Note Params Buffer"),
contents: bytemuck::cast_slice(&[gpu_params]),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let output_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output Buffer"),
size: (total_samples * std::mem::size_of::<f32>()) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: (total_samples * std::mem::size_of::<f32>()) as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let bind_group = self
.device
.device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Synthesis Bind Group"),
layout: &self.bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: params_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
],
});
let mut encoder =
self.device
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Synthesis Encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Synthesis Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&self.compute_pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let workgroups = (total_samples as u32).div_ceil(256);
compute_pass.dispatch_workgroups(workgroups, 1, 1);
}
encoder.copy_buffer_to_buffer(
&output_buffer,
0,
&staging_buffer,
0,
(total_samples * std::mem::size_of::<f32>()) as u64,
);
self.device.queue.submit(Some(encoder.finish()));
let buffer_slice = staging_buffer.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).ok();
});
self.device.device.poll(wgpu::Maintain::Wait);
pollster::block_on(async { receiver.receive().await })
.context("Failed to map buffer")?
.context("Buffer mapping failed")?;
let data = buffer_slice.get_mapped_range();
let samples: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging_buffer.unmap();
Ok(samples)
}
fn note_to_gpu_params(&self, note: &NoteEvent, sample_rate: f32) -> GpuNoteParams {
let waveform_id = match note.waveform {
Waveform::Sine => 0,
Waveform::Sawtooth => 1,
Waveform::Square => 2,
Waveform::Triangle => 3,
};
let fm_enabled = if note.fm_params.mod_index > 0.0 { 1 } else { 0 };
GpuNoteParams {
frequency: note.frequencies[0], duration: note.duration,
sample_rate,
waveform: waveform_id,
attack: note.envelope.attack,
decay: note.envelope.decay,
sustain: note.envelope.sustain,
release: note.envelope.release,
fm_enabled,
fm_mod_ratio: note.fm_params.mod_ratio,
fm_mod_index: note.fm_params.mod_index,
velocity: note.velocity,
_padding: 0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::synthesis::envelope::Envelope;
    use crate::synthesis::fm_synthesis::FMParams;

    /// Renders a 0.5 s, 440 Hz sine note on the GPU and checks the returned samples.
    /// Skipped (passes trivially) when no GPU device can be created.
    #[test]
    fn test_gpu_synthesis() {
        let Ok(device) = GpuDevice::new() else {
            println!("GPU not available, skipping test");
            return;
        };
        let synthesizer = GpuSynthesizer::new(device).expect("Failed to create synthesizer");
        let note = NoteEvent {
            frequencies: [440.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            num_freqs: 1,
            start_time: 0.0,
            duration: 0.5,
            waveform: Waveform::Sine,
            envelope: Envelope::default(),
            filter_envelope: Default::default(),
            fm_params: FMParams::default(),
            pitch_bend_semitones: 0.0,
            custom_wavetable: None,
            velocity: 1.0,
            spatial_position: None,
        };
        let sample_rate = 44100.0;
        let samples = synthesizer
            .synthesize_note(&note, sample_rate)
            .expect("GPU synthesis failed");

        // The output length is defined by the envelope's total duration at this rate.
        let expected_len = (note.envelope.total_duration(note.duration) * sample_rate) as usize;
        assert_eq!(samples.len(), expected_len);
        assert!(samples.len() > 1000);
        assert!(
            samples.iter().all(|s| s.is_finite()),
            "GPU produced non-finite samples"
        );
        println!("✅ GPU synthesized {} samples", samples.len());
    }
}