use super::{GpuBackend, GpuError, GpuTapeData, TapeMeta};
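/// GPU-resident copy of a recorded bytecode tape: one read-only storage
/// buffer per instruction stream (opcodes, argument indices, constants,
/// output indices) plus the scalar counts needed to size dispatches.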
pub struct WgpuTapeBuffers {
pub(crate) opcodes_buf: wgpu::Buffer,
pub(crate) arg0_buf: wgpu::Buffer,
pub(crate) arg1_buf: wgpu::Buffer,
pub(crate) constants_buf: wgpu::Buffer,
pub(crate) output_indices_buf: wgpu::Buffer,
pub(crate) num_ops: u32,
pub(crate) num_inputs: u32,
pub(crate) num_variables: u32,
pub(crate) num_outputs: u32,
}
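/// A wgpu device/queue pair with every compute pipeline compiled up front.
///
/// Pipelines are split by differentiation mode: plain forward evaluation,
/// reverse mode (gradients), tangent-forward (JVPs), tangent-over-reverse
/// (HVPs), and, behind the `stde` feature, generated kth-order Taylor
/// forward pipelines. Bind group layouts are cached so per-call bind groups
/// can be created cheaply.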
pub struct WgpuContext {
device: wgpu::Device,
queue: wgpu::Queue,
forward_pipeline: wgpu::ComputePipeline,
reverse_pipeline: wgpu::ComputePipeline,
tangent_fwd_pipeline: wgpu::ComputePipeline,
tangent_rev_pipeline: wgpu::ComputePipeline,
#[cfg(feature = "stde")]
taylor_fwd_kth_pipelines: [wgpu::ComputePipeline; 5],
tape_bind_group_layout: wgpu::BindGroupLayout,
forward_io_bind_group_layout: wgpu::BindGroupLayout,
reverse_io_bind_group_layout: wgpu::BindGroupLayout,
tangent_fwd_io_bind_group_layout: wgpu::BindGroupLayout,
tangent_rev_io_bind_group_layout: wgpu::BindGroupLayout,
taylor_fwd_2nd_io_bind_group_layout: wgpu::BindGroupLayout,
}
impl WgpuContext {
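    /// Requests a high-performance adapter and a default device, blocking on
    /// the async setup with `pollster`. Returns `None` when no usable adapter
    /// or device is found, so callers can fall back to a CPU path.
    ///
    /// A minimal sketch of the intended call flow (illustrative only; the
    /// `GpuTapeData` construction and error handling are elided):
    ///
    /// ```ignore
    /// let ctx = WgpuContext::new().expect("no usable GPU");
    /// let tape = ctx.upload_tape(&tape_data);
    /// // Two samples of a 3-input function, packed sample-major.
    /// let inputs = [0.1f32, 0.2, 0.3, 1.0, 2.0, 3.0];
    /// let outputs = ctx.forward_batch(&tape, &inputs, 2)?;
    /// assert_eq!(outputs.len(), 2 * ctx.num_outputs(&tape) as usize);
    /// ```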
#[must_use]
pub fn new() -> Option<Self> {
pollster::block_on(Self::new_async())
}
async fn new_async() -> Option<Self> {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.ok()?;
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor::default())
.await
.ok()?;
let tape_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_tape_bgl"),
entries: &[
bgl_storage_ro(0),
bgl_storage_ro(1),
bgl_storage_ro(2),
bgl_storage_ro(3),
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
bgl_storage_ro(5),
],
});
let forward_io_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_fwd_io_bgl"),
entries: &[
bgl_storage_ro(0),
bgl_storage_rw(1),
bgl_storage_rw(2),
],
});
let reverse_io_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_rev_io_bgl"),
entries: &[
bgl_storage_ro(0),
bgl_storage_rw(1),
bgl_storage_rw(2),
],
});
let fwd_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("echidna_forward_pl"),
bind_group_layouts: &[&tape_bind_group_layout, &forward_io_bind_group_layout],
immediate_size: 0,
});
let fwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("echidna_forward_shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/forward.wgsl").into()),
});
let forward_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("echidna_forward_pipeline"),
layout: Some(&fwd_layout),
module: &fwd_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let rev_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("echidna_reverse_pl"),
bind_group_layouts: &[&tape_bind_group_layout, &reverse_io_bind_group_layout],
immediate_size: 0,
});
let rev_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("echidna_reverse_shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/reverse.wgsl").into()),
});
let reverse_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("echidna_reverse_pipeline"),
layout: Some(&rev_layout),
module: &rev_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let tangent_fwd_io_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_tfwd_io_bgl"),
entries: &[
                    bgl_storage_ro(0),
                    bgl_storage_ro(1),
                    bgl_storage_rw(2),
                    bgl_storage_rw(3),
                    bgl_storage_rw(4),
                ],
});
let tfwd_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("echidna_tangent_fwd_pl"),
bind_group_layouts: &[&tape_bind_group_layout, &tangent_fwd_io_bind_group_layout],
immediate_size: 0,
});
let tfwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("echidna_tangent_fwd_shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/tangent_forward.wgsl").into()),
});
let tangent_fwd_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("echidna_tangent_fwd_pipeline"),
layout: Some(&tfwd_layout),
module: &tfwd_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let tangent_rev_io_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_trev_io_bgl"),
entries: &[
                    bgl_storage_ro(0),
                    bgl_storage_ro(1),
                    bgl_storage_rw(2),
                    bgl_storage_rw(3),
                    bgl_storage_rw(4),
                    bgl_storage_rw(5),
                    bgl_storage_rw(6),
                    bgl_storage_rw(7),
                ],
});
let trev_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("echidna_tangent_rev_pl"),
bind_group_layouts: &[&tape_bind_group_layout, &tangent_rev_io_bind_group_layout],
immediate_size: 0,
});
let trev_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("echidna_tangent_rev_shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/tangent_reverse.wgsl").into()),
});
let tangent_rev_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("echidna_tangent_rev_pipeline"),
layout: Some(&trev_layout),
module: &trev_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let taylor_fwd_2nd_io_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("echidna_taylor2_io_bgl"),
entries: &[
                    bgl_storage_ro(0),
                    bgl_storage_ro(1),
                    bgl_storage_rw(2),
                    bgl_storage_rw(3),
                ],
});
        #[cfg(feature = "stde")]
        let taylor_fwd_kth_pipelines = {
            use super::taylor_codegen::generate_taylor_wgsl;
            // All kth-order pipelines share one pipeline layout (and the
            // 2nd-order I/O bind group layout); only the generated WGSL
            // differs per order. Creating the layout inside this block avoids
            // an unused-variable warning when `stde` is disabled.
            let taylor2_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("echidna_taylor_fwd_kth_pl"),
                bind_group_layouts: &[
                    &tape_bind_group_layout,
                    &taylor_fwd_2nd_io_bind_group_layout,
                ],
                immediate_size: 0,
            });
std::array::from_fn(|idx| {
let k = idx + 1;
let wgsl_src = generate_taylor_wgsl(k);
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(&format!("echidna_taylor_fwd_k{k}_shader")),
source: wgpu::ShaderSource::Wgsl(wgsl_src.into()),
});
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(&format!("echidna_taylor_fwd_k{k}_pipeline")),
layout: Some(&taylor2_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
})
})
};
Some(WgpuContext {
device,
queue,
forward_pipeline,
reverse_pipeline,
tangent_fwd_pipeline,
tangent_rev_pipeline,
#[cfg(feature = "stde")]
taylor_fwd_kth_pipelines,
tape_bind_group_layout,
forward_io_bind_group_layout,
reverse_io_bind_group_layout,
tangent_fwd_io_bind_group_layout,
tangent_rev_io_bind_group_layout,
taylor_fwd_2nd_io_bind_group_layout,
})
}
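    /// Binds the tape's storage buffers and the per-dispatch `TapeMeta`
    /// uniform at group 0, mirroring `tape_bind_group_layout` (bindings 0-3
    /// read-only storage, 4 uniform, 5 read-only storage).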
fn create_tape_bind_group(
&self,
tape: &WgpuTapeBuffers,
meta_buf: &wgpu::Buffer,
) -> wgpu::BindGroup {
self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("tape_bg"),
layout: &self.tape_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: tape.opcodes_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: tape.arg0_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: tape.arg1_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: tape.constants_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: meta_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: tape.output_indices_buf.as_entire_binding(),
},
],
})
}
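    /// Runs the generated kth-order Taylor-mode forward shader for a batch of
    /// (primal, direction) pairs and returns the Taylor coefficients grouped
    /// by order.
    ///
    /// Inputs are packed sample-major (`[b * num_inputs + i]`). The shader is
    /// assumed to write the `order` coefficients of each output contiguously
    /// (`raw[i * order + c]` in the read-back loop below), which the loop
    /// de-interleaves into `coefficients[c]`.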
#[cfg(feature = "stde")]
pub fn taylor_forward_kth_batch(
&self,
tape: &WgpuTapeBuffers,
primal_inputs: &[f32],
direction_seeds: &[f32],
batch_size: u32,
order: usize,
) -> Result<super::TaylorKthBatchResult<f32>, GpuError> {
use wgpu::util::DeviceExt;
if !(1..=5).contains(&order) {
return Err(GpuError::Other(format!(
"unsupported Taylor order {order}, must be 1..=5"
)));
}
let k = order as u32;
let ni = tape.num_inputs;
let nv = tape.num_variables;
let no = tape.num_outputs;
assert!(
(batch_size as u64) * (nv as u64) * (k as u64) <= u32::MAX as u64,
"batch_size * num_variables * order overflows u32 in WGSL shader index arithmetic"
);
let total_inputs = (batch_size as usize) * (ni as usize);
assert_eq!(
primal_inputs.len(),
total_inputs,
"primal_inputs length mismatch"
);
assert_eq!(
direction_seeds.len(),
total_inputs,
"direction_seeds length mismatch"
);
let meta = TapeMeta {
num_ops: tape.num_ops,
num_inputs: ni,
num_variables: nv,
num_outputs: no,
batch_size,
_pad: [0; 3],
};
let meta_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("taylor_kth_meta"),
contents: bytemuck::bytes_of(&meta),
usage: wgpu::BufferUsages::UNIFORM,
});
let primal_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("taylor_kth_primals"),
contents: bytemuck::cast_slice(primal_inputs),
usage: wgpu::BufferUsages::STORAGE,
});
let seed_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("taylor_kth_seeds"),
contents: bytemuck::cast_slice(direction_seeds),
usage: wgpu::BufferUsages::STORAGE,
});
let jets_size = (batch_size as u64) * (nv as u64) * (k as u64) * 4;
let jets_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("taylor_kth_jets"),
size: jets_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let out_count = (batch_size as u64) * (no as u64) * (k as u64);
let out_size = out_count * 4;
let jet_out_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("taylor_kth_jet_out"),
size: out_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("taylor_kth_staging"),
size: out_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let tape_bg = self.create_tape_bind_group(tape, &meta_buf);
let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("taylor_kth_io_bg"),
layout: &self.taylor_fwd_2nd_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: primal_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: seed_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: jets_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: jet_out_buf.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("taylor_kth_enc"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("taylor_kth_pass"),
timestamp_writes: None,
});
            pass.set_pipeline(&self.taylor_fwd_kth_pipelines[order - 1]);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &io_bg, &[]);
pass.dispatch_workgroups(batch_size.div_ceil(256), 1, 1);
}
encoder.copy_buffer_to_buffer(&jet_out_buf, 0, &staging_buf, 0, out_size);
let sub_idx = self.queue.submit(std::iter::once(encoder.finish()));
let slice = staging_buf.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
self.device
.poll(wgpu::PollType::Wait {
submission_index: Some(sub_idx),
timeout: None,
})
.map_err(|e| GpuError::Other(format!("device poll failed: {e}")))?;
rx.recv()
.map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
.map_err(|e| GpuError::Other(format!("buffer map failed: {e}")))?;
let data = slice.get_mapped_range();
let raw: &[f32] = bytemuck::cast_slice(&data);
let total_out = (batch_size as usize) * (no as usize);
let mut coefficients: Vec<Vec<f32>> =
(0..order).map(|_| Vec::with_capacity(total_out)).collect();
for i in 0..total_out {
for c in 0..order {
coefficients[c].push(raw[i * order + c]);
}
}
drop(data);
staging_buf.unmap();
Ok(super::TaylorKthBatchResult {
coefficients,
order,
})
}
}
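/// `GpuBackend` implementation: upload a tape once, then drive batched
/// forward, gradient, sparse-Jacobian/Hessian, and HVP kernels against it.
/// All dispatches assume a workgroup size of 256 in the WGSL shaders (hence
/// the `div_ceil(256)` workgroup counts).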
impl GpuBackend for WgpuContext {
type TapeBuffers = WgpuTapeBuffers;
fn num_outputs(&self, tape: &WgpuTapeBuffers) -> u32 {
tape.num_outputs
}
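    /// Uploads every tape array as a read-only storage buffer. When
    /// `output_indices` is empty, the single legacy `output_index` is
    /// uploaded instead and `num_outputs` is reported as 1.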
fn upload_tape(&self, data: &GpuTapeData) -> WgpuTapeBuffers {
use wgpu::util::DeviceExt;
let opcodes_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("opcodes"),
contents: bytemuck::cast_slice(&data.opcodes),
usage: wgpu::BufferUsages::STORAGE,
});
let arg0_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("arg0"),
contents: bytemuck::cast_slice(&data.arg0),
usage: wgpu::BufferUsages::STORAGE,
});
let arg1_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("arg1"),
contents: bytemuck::cast_slice(&data.arg1),
usage: wgpu::BufferUsages::STORAGE,
});
let constants_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("constants"),
contents: bytemuck::cast_slice(&data.constants),
usage: wgpu::BufferUsages::STORAGE,
});
let num_outputs = if data.output_indices.is_empty() {
1u32
} else {
data.output_indices.len() as u32
};
let output_indices = if data.output_indices.is_empty() {
vec![data.output_index]
} else {
data.output_indices.clone()
};
let output_indices_buf =
self.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("output_indices"),
contents: bytemuck::cast_slice(&output_indices),
usage: wgpu::BufferUsages::STORAGE,
});
WgpuTapeBuffers {
opcodes_buf,
arg0_buf,
arg1_buf,
constants_buf,
output_indices_buf,
num_ops: data.num_ops,
num_inputs: data.num_inputs,
num_variables: data.num_variables,
num_outputs,
}
}
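    /// Evaluates the tape for `batch_size` independent samples in one
    /// dispatch. `inputs` is packed sample-major (`inputs[b * num_inputs + i]`
    /// is input `i` of sample `b`); the returned vector is packed the same
    /// way and is `batch_size * num_outputs` long.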
fn forward_batch(
&self,
tape: &WgpuTapeBuffers,
inputs: &[f32],
batch_size: u32,
) -> Result<Vec<f32>, GpuError> {
use wgpu::util::DeviceExt;
let num_inputs = tape.num_inputs;
let num_variables = tape.num_variables;
let num_outputs = tape.num_outputs;
assert_eq!(
inputs.len(),
(batch_size as usize) * (num_inputs as usize),
"inputs length must be batch_size * num_inputs"
);
assert!(
(batch_size as u64) * (num_variables as u64) <= u32::MAX as u64,
"batch_size * num_variables overflows u32 in WGSL shader index arithmetic"
);
let meta = TapeMeta {
num_ops: tape.num_ops,
num_inputs,
num_variables,
num_outputs,
batch_size,
_pad: [0; 3],
};
let meta_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("tape_meta"),
contents: bytemuck::bytes_of(&meta),
usage: wgpu::BufferUsages::UNIFORM,
});
let input_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("inputs"),
contents: bytemuck::cast_slice(inputs),
usage: wgpu::BufferUsages::STORAGE,
});
let values_size = (batch_size as u64) * (num_variables as u64) * 4;
let values_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("values"),
size: values_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let output_size = (batch_size as u64) * (num_outputs as u64) * 4;
let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("outputs"),
size: output_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging"),
size: output_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let tape_bg = self.create_tape_bind_group(tape, &meta_buf);
let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("io_bg"),
layout: &self.forward_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: values_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: output_buf.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("forward_enc"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("forward_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.forward_pipeline);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &io_bg, &[]);
pass.dispatch_workgroups(batch_size.div_ceil(256), 1, 1);
}
encoder.copy_buffer_to_buffer(&output_buf, 0, &staging_buf, 0, output_size);
let sub_idx = self.queue.submit(std::iter::once(encoder.finish()));
let slice = staging_buf.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
self.device
.poll(wgpu::PollType::Wait {
submission_index: Some(sub_idx),
timeout: None,
})
.map_err(|e| GpuError::Other(format!("device poll failed: {e}")))?;
rx.recv()
.map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
.map_err(|e| GpuError::Other(format!("buffer map failed: {e}")))?;
let data = slice.get_mapped_range();
let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging_buf.unmap();
Ok(result)
}
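    /// Computes outputs and input gradients in a single submission: a forward
    /// pass fills the shared `values` buffer, then a reverse pass reads it to
    /// accumulate adjoints into `grad_out`. Returns `(outputs, grads)`, both
    /// packed sample-major (`grads[b * num_inputs + i]`).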
fn gradient_batch(
&self,
tape: &WgpuTapeBuffers,
inputs: &[f32],
batch_size: u32,
) -> Result<(Vec<f32>, Vec<f32>), GpuError> {
use wgpu::util::DeviceExt;
let num_inputs = tape.num_inputs;
let num_variables = tape.num_variables;
let num_outputs = tape.num_outputs;
assert_eq!(
inputs.len(),
(batch_size as usize) * (num_inputs as usize),
"inputs length must be batch_size * num_inputs"
);
assert!(
(batch_size as u64) * (num_variables as u64) <= u32::MAX as u64,
"batch_size * num_variables overflows u32 in WGSL shader index arithmetic"
);
let meta = TapeMeta {
num_ops: tape.num_ops,
num_inputs,
num_variables,
num_outputs,
batch_size,
_pad: [0; 3],
};
let meta_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("tape_meta"),
contents: bytemuck::bytes_of(&meta),
usage: wgpu::BufferUsages::UNIFORM,
});
let input_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("inputs"),
contents: bytemuck::cast_slice(inputs),
usage: wgpu::BufferUsages::STORAGE,
});
let values_size = (batch_size as u64) * (num_variables as u64) * 4;
let values_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("values"),
size: values_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let output_count = (batch_size as u64) * (num_outputs as u64);
let output_size = output_count * 4;
let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("outputs"),
size: output_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let adjoints_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("adjoints"),
size: values_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let grad_count = (batch_size as u64) * (num_inputs as u64);
let grad_size = grad_count * 4;
let grad_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("grad_out"),
size: grad_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let output_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("output_staging"),
size: output_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let grad_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("grad_staging"),
size: grad_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let tape_bg = self.create_tape_bind_group(tape, &meta_buf);
let fwd_io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("fwd_io_bg"),
layout: &self.forward_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: values_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: output_buf.as_entire_binding(),
},
],
});
let rev_io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("rev_io_bg"),
layout: &self.reverse_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: values_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: adjoints_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: grad_buf.as_entire_binding(),
},
],
});
let workgroups = batch_size.div_ceil(256);
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("gradient_enc"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("forward_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.forward_pipeline);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &fwd_io_bg, &[]);
pass.dispatch_workgroups(workgroups, 1, 1);
}
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("reverse_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.reverse_pipeline);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &rev_io_bg, &[]);
pass.dispatch_workgroups(workgroups, 1, 1);
}
encoder.copy_buffer_to_buffer(&output_buf, 0, &output_staging, 0, output_size);
encoder.copy_buffer_to_buffer(&grad_buf, 0, &grad_staging, 0, grad_size);
let sub_idx = self.queue.submit(std::iter::once(encoder.finish()));
let out_slice = output_staging.slice(..);
let grad_slice = grad_staging.slice(..);
let (tx1, rx1) = std::sync::mpsc::channel();
let (tx2, rx2) = std::sync::mpsc::channel();
out_slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx1.send(r);
});
grad_slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx2.send(r);
});
self.device
.poll(wgpu::PollType::Wait {
submission_index: Some(sub_idx),
timeout: None,
})
.map_err(|e| GpuError::Other(format!("device poll failed: {e}")))?;
rx1.recv()
.map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
.map_err(|e| GpuError::Other(format!("output map failed: {e}")))?;
rx2.recv()
.map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
.map_err(|e| GpuError::Other(format!("grad map failed: {e}")))?;
let out_data = out_slice.get_mapped_range();
let outputs: Vec<f32> = bytemuck::cast_slice(&out_data).to_vec();
drop(out_data);
output_staging.unmap();
let grad_data = grad_slice.get_mapped_range();
let grads: Vec<f32> = bytemuck::cast_slice(&grad_data).to_vec();
drop(grad_data);
grad_staging.unmap();
Ok((outputs, grads))
}
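    /// Compressed Jacobian via column coloring: columns that never share a
    /// row are seeded together, so a single tangent-mode forward dispatch
    /// with `num_colors` batch entries recovers every structural nonzero.
    /// Entry `(o, i)` is then read from batch entry `colors[i]` at output
    /// `o`. Sparsity detection and the primal re-evaluation run on the CPU
    /// tape.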
fn sparse_jacobian(
&self,
tape: &WgpuTapeBuffers,
tape_cpu: &mut crate::BytecodeTape<f32>,
x: &[f32],
) -> Result<(Vec<f32>, crate::sparse::JacobianSparsityPattern, Vec<f32>), GpuError> {
use wgpu::util::DeviceExt;
let num_inputs = tape.num_inputs as usize;
let num_outputs = tape.num_outputs as usize;
let num_variables = tape.num_variables;
let pattern = tape_cpu.detect_jacobian_sparsity();
let (colors, num_colors) = crate::sparse::column_coloring(&pattern);
if num_colors == 0 {
tape_cpu.forward(x);
let vals = tape_cpu.output_values();
let vals_f32: Vec<f32> = vals.to_vec();
return Ok((vals_f32, pattern, vec![]));
}
assert!(
(num_colors as u64) * (num_variables as u64) <= u32::MAX as u64,
"num_colors * num_variables overflows u32 in WGSL shader index arithmetic"
);
let batch = num_colors;
let mut primal_inputs = Vec::with_capacity(batch as usize * num_inputs);
let mut tangent_seeds = Vec::with_capacity(batch as usize * num_inputs);
for c in 0..num_colors {
for i in 0..num_inputs {
primal_inputs.push(x[i]);
tangent_seeds.push(if colors[i] == c { 1.0f32 } else { 0.0f32 });
}
}
let meta = TapeMeta {
num_ops: tape.num_ops,
num_inputs: tape.num_inputs,
num_variables,
num_outputs: tape.num_outputs,
batch_size: batch,
_pad: [0; 3],
};
let meta_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("tape_meta"),
contents: bytemuck::bytes_of(&meta),
usage: wgpu::BufferUsages::UNIFORM,
});
let primal_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("primal_inputs"),
contents: bytemuck::cast_slice(&primal_inputs),
usage: wgpu::BufferUsages::STORAGE,
});
let seed_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("tangent_seeds"),
contents: bytemuck::cast_slice(&tangent_seeds),
usage: wgpu::BufferUsages::STORAGE,
});
let buf_size = (batch as u64) * (num_variables as u64) * 4;
let primals_work = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("primals_work"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let tangents_work = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("tangents_work"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let out_size = (batch as u64) * (tape.num_outputs as u64) * 4;
let tangent_out_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("tangent_outputs"),
size: out_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging"),
size: out_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let tape_bg = self.create_tape_bind_group(tape, &meta_buf);
let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("tfwd_io_bg"),
layout: &self.tangent_fwd_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: primal_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: seed_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: primals_work.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: tangents_work.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: tangent_out_buf.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("sparse_jac_enc"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("tangent_fwd_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.tangent_fwd_pipeline);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &io_bg, &[]);
pass.dispatch_workgroups(batch.div_ceil(256), 1, 1);
}
encoder.copy_buffer_to_buffer(&tangent_out_buf, 0, &staging, 0, out_size);
let sub_idx = self.queue.submit(std::iter::once(encoder.finish()));
let slice = staging.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx.send(r);
});
self.device
.poll(wgpu::PollType::Wait {
submission_index: Some(sub_idx),
timeout: None,
})
.map_err(|e| GpuError::Other(format!("device poll failed: {e}")))?;
        rx.recv()
            .map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
            .map_err(|e| GpuError::Other(format!("buffer map failed: {e}")))?;
let data = slice.get_mapped_range();
let tangent_results: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging.unmap();
let nnz = pattern.nnz();
let mut jac_values = vec![0.0f32; nnz];
for (k, (&row, &col)) in pattern.rows.iter().zip(pattern.cols.iter()).enumerate() {
            let o = row as usize;
            let i = col as usize;
            let c = colors[i] as usize;
jac_values[k] = tangent_results[c * num_outputs + o];
}
tape_cpu.forward(x);
        let output_values: Vec<f32> = tape_cpu.output_values().to_vec();
Ok((output_values, pattern, jac_values))
}
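    /// Batched Hessian-vector products by forward-over-reverse: every batch
    /// entry carries the same primal point `x` with its own tangent
    /// direction, and the tangent-reverse shader propagates dual-number
    /// adjoints (the `adj_re`/`adj_eps` working buffers), so the epsilon part
    /// of the input adjoints is `H·v`. Returns `(grads, hvps)`, both packed
    /// sample-major.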
fn hvp_batch(
&self,
tape: &WgpuTapeBuffers,
x: &[f32],
tangent_dirs: &[f32],
batch_size: u32,
) -> Result<(Vec<f32>, Vec<f32>), GpuError> {
use wgpu::util::DeviceExt;
let ni = tape.num_inputs;
let nv = tape.num_variables;
assert_eq!(x.len(), ni as usize);
assert_eq!(tangent_dirs.len(), (batch_size as usize) * (ni as usize));
assert!(
(batch_size as u64) * (nv as u64) <= u32::MAX as u64,
"batch_size * num_variables overflows u32 in WGSL shader index arithmetic"
);
let mut primal_inputs = Vec::with_capacity((batch_size as usize) * (ni as usize));
for _ in 0..batch_size {
primal_inputs.extend_from_slice(x);
}
let meta = TapeMeta {
num_ops: tape.num_ops,
num_inputs: ni,
num_variables: nv,
num_outputs: tape.num_outputs,
batch_size,
_pad: [0; 3],
};
let meta_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("meta"),
contents: bytemuck::bytes_of(&meta),
usage: wgpu::BufferUsages::UNIFORM,
});
let primal_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("primals_in"),
contents: bytemuck::cast_slice(&primal_inputs),
usage: wgpu::BufferUsages::STORAGE,
});
let seed_buf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("seeds"),
contents: bytemuck::cast_slice(tangent_dirs),
usage: wgpu::BufferUsages::STORAGE,
});
let buf_size = (batch_size as u64) * (nv as u64) * 4;
let grad_size = (batch_size as u64) * (ni as u64) * 4;
let primals_work = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("pw"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let tangents_work = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("tw"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let adj_re_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("ar"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let adj_eps_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("ae"),
size: buf_size,
usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
});
let grad_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("go"),
size: grad_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let hvp_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("ho"),
size: grad_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let grad_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("gs"),
size: grad_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let hvp_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("hs"),
size: grad_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let tape_bg = self.create_tape_bind_group(tape, &meta_buf);
let io_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("trev_io"),
layout: &self.tangent_rev_io_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: primal_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: seed_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: primals_work.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: tangents_work.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: adj_re_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: adj_eps_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 6,
resource: grad_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 7,
resource: hvp_buf.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("hvp_enc"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("trev_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.tangent_rev_pipeline);
pass.set_bind_group(0, &tape_bg, &[]);
pass.set_bind_group(1, &io_bg, &[]);
pass.dispatch_workgroups(batch_size.div_ceil(256), 1, 1);
}
encoder.copy_buffer_to_buffer(&grad_buf, 0, &grad_staging, 0, grad_size);
encoder.copy_buffer_to_buffer(&hvp_buf, 0, &hvp_staging, 0, grad_size);
let sub_idx = self.queue.submit(std::iter::once(encoder.finish()));
let gs = grad_staging.slice(..);
let hs = hvp_staging.slice(..);
let (tx1, rx1) = std::sync::mpsc::channel();
let (tx2, rx2) = std::sync::mpsc::channel();
gs.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx1.send(r);
});
hs.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx2.send(r);
});
self.device
.poll(wgpu::PollType::Wait {
submission_index: Some(sub_idx),
timeout: None,
})
.map_err(|e| GpuError::Other(format!("device poll failed: {e}")))?;
        rx1.recv()
            .map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
            .map_err(|e| GpuError::Other(format!("grad map failed: {e}")))?;
        rx2.recv()
            .map_err(|e| GpuError::Other(format!("channel recv failed: {e}")))?
            .map_err(|e| GpuError::Other(format!("hvp map failed: {e}")))?;
let gd = gs.get_mapped_range();
let grads: Vec<f32> = bytemuck::cast_slice(&gd).to_vec();
drop(gd);
grad_staging.unmap();
let hd = hs.get_mapped_range();
let hvps: Vec<f32> = bytemuck::cast_slice(&hd).to_vec();
drop(hd);
hvp_staging.unmap();
Ok((grads, hvps))
}
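    /// Compressed sparse Hessian: colors the sparsity pattern on the CPU,
    /// issues one `hvp_batch` with a unit-sum seed per color, then scatters
    /// each `H·e_c` back into the nonzero list. The gradient is taken from
    /// batch entry 0, which is valid because every entry shares the same `x`.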
fn sparse_hessian(
&self,
tape: &WgpuTapeBuffers,
tape_cpu: &mut crate::BytecodeTape<f32>,
x: &[f32],
) -> Result<(f32, Vec<f32>, crate::sparse::SparsityPattern, Vec<f32>), GpuError> {
let ni = tape.num_inputs as usize;
let pattern = tape_cpu.detect_sparsity();
let (colors, num_colors) = crate::sparse::greedy_coloring(&pattern);
if num_colors == 0 {
tape_cpu.forward(x);
let val = tape_cpu.output_value();
let grad = tape_cpu.gradient(x);
return Ok((val, grad, pattern, vec![]));
}
let batch = num_colors;
let mut tangent_dirs = Vec::with_capacity(batch as usize * ni);
for c in 0..num_colors {
for &color in &colors[..ni] {
tangent_dirs.push(if color == c { 1.0f32 } else { 0.0f32 });
}
}
let (grads, hvps) = self.hvp_batch(tape, x, &tangent_dirs, batch)?;
let gradient: Vec<f32> = grads[..ni].to_vec();
let nnz = pattern.nnz();
let mut hess_values = vec![0.0f32; nnz];
for (k, (&row, &col)) in pattern.rows.iter().zip(pattern.cols.iter()).enumerate() {
let i = row as usize;
let j = col as usize;
let c = colors[j] as usize;
hess_values[k] = hvps[c * ni + i];
}
tape_cpu.forward(x);
let value = tape_cpu.output_value();
Ok((value, gradient, pattern, hess_values))
}
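    /// Delegates to the inherent method of the same name; Rust's method
    /// resolution prefers inherent methods, so this is not self-recursive.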
#[cfg(feature = "stde")]
fn taylor_forward_kth_batch(
&self,
tape: &WgpuTapeBuffers,
primal_inputs: &[f32],
direction_seeds: &[f32],
batch_size: u32,
order: usize,
) -> Result<super::TaylorKthBatchResult<f32>, GpuError> {
self.taylor_forward_kth_batch(tape, primal_inputs, direction_seeds, batch_size, order)
}
}
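/// Layout entry for a compute-visible read-only storage buffer at `binding`.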
fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
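/// Layout entry for a compute-visible read-write storage buffer at `binding`.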
fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}