use std::sync::Arc;
use bytemuck::{Pod, Zeroable};
use wgpu;
use super::buffers;
use super::context::GpuContext;
use crate::layer::DenseLayer;
/// Uniform parameter block for the "encode" compute pass.
///
/// `forward_batch_gpu` fills one of these per call and uploads its raw bytes
/// into `encode_uniform_buf`. The struct is eight `u32`s (32 bytes), which is
/// the size the encode uniform buffer is created with, so the field order and
/// padding must not change without updating the shader and the buffer size.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct EncodeParams {
// Number of inputs per sample (taken from the CPU layer's `n_inputs`).
n_inputs: u32,
// Packed words per input (taken from the CPU layer's `words_per_input`).
words_per_input: u32,
// Low 32 bits of the caller-supplied 64-bit seed.
seed_lo: u32,
// High 32 bits of the caller-supplied 64-bit seed.
seed_hi: u32,
// Number of samples in the current batch.
n_samples: u32,
// The CPU layer's `length`; presumably the bitstream length in bits — TODO confirm.
length: u32,
// Padding, always written as 0; keeps the struct at 32 bytes.
_pad0: u32,
_pad1: u32,
}
/// Uniform parameter block for the "accumulate" compute pass.
///
/// `forward_batch_gpu` fills one of these per call and uploads its raw bytes
/// into `accum_uniform_buf`. The struct is eight 4-byte fields (32 bytes),
/// which is the size the accumulate uniform buffer is created with, so the
/// field order and padding must not change without updating the shader and
/// the buffer size.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct AccumParams {
// Number of inputs per sample (taken from the CPU layer's `n_inputs`).
n_inputs: u32,
// Number of neurons in the layer (taken from the CPU layer's `n_neurons`).
n_neurons: u32,
// Packed words per input (taken from the CPU layer's `words_per_input`).
words_per_input: u32,
// The CPU layer's `inv_length` narrowed to f32; presumably 1 / length — TODO confirm.
inv_length: f32,
// Number of samples in the current batch.
n_samples: u32,
// Padding, always written as 0; keeps the struct at 32 bytes.
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
/// A dense layer whose forward pass runs as two GPU compute passes
/// ("encode" then "accumulate") on a shared `GpuContext`.
///
/// All GPU buffers are allocated once in `try_new`, sized for `max_batch_size`
/// samples, and reused by every forward call.
pub struct GpuDenseLayer {
// CPU-side layer built from the same constructor arguments; supplies the
// dimensions and the packed weights that are uploaded to `weight_buf`.
pub cpu: DenseLayer,
// Shared GPU context: device, queue, the two compute pipelines, their
// bind group layouts and the adapter name.
ctx: Arc<GpuContext>,
// Bytes of `cpu.packed_weights_flat()`, uploaded once at construction;
// binding 0 of the accumulate bind group.
weight_buf: wgpu::Buffer,
// Per-call input values converted to f32; binding 0 of the encode bind group.
input_prob_buf: wgpu::Buffer,
// Binding 1 of both bind groups; presumably written by the encode pass and
// read by the accumulate pass — TODO confirm against the shaders.
packed_input_buf: wgpu::Buffer,
// Binding 2 of the accumulate bind group; copied into `output_staging`
// after the passes are recorded.
output_buf: wgpu::Buffer,
// Buffer that is mapped for reading to bring results back to the CPU.
output_staging: wgpu::Buffer,
// Holds the bytes of one `EncodeParams`; binding 2 of the encode bind group.
encode_uniform_buf: wgpu::Buffer,
// Holds the bytes of one `AccumParams`; binding 3 of the accumulate bind group.
accum_uniform_buf: wgpu::Buffer,
// Largest `n_samples` that `forward_batch_gpu` accepts; the batch-sized
// buffers above are allocated for this many samples.
max_batch_size: usize,
}
impl GpuDenseLayer {
pub fn try_new(
n_inputs: usize,
n_neurons: usize,
length: usize,
seed: u64,
max_batch: usize,
) -> Option<Self> {
let ctx = super::context::get_context()?;
let cpu = DenseLayer::new(n_inputs, n_neurons, length, seed);
let words = length.div_ceil(64);
let dev = &ctx.device;
let weight_bytes: &[u8] = bytemuck::cast_slice(cpu.packed_weights_flat());
let weight_buf = dev.create_buffer(&wgpu::BufferDescriptor {
label: Some("weights"),
size: weight_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
ctx.queue.write_buffer(&weight_buf, 0, weight_bytes);
let input_prob_size = (max_batch * n_inputs * 4) as u64; let packed_size = (max_batch * n_inputs * words * 8) as u64; let output_size = (max_batch * n_neurons * 4) as u64;
let input_prob_buf = buffers::storage_buffer(dev, "input_probs", input_prob_size, true);
let packed_input_buf = buffers::storage_buffer(dev, "packed_inputs", packed_size, false);
let output_buf = buffers::storage_buffer(dev, "output", output_size, false);
let output_staging = buffers::staging_buffer(dev, "output_staging", output_size);
let encode_uniform_buf = buffers::uniform_buffer(dev, "encode_params", 32);
let accum_uniform_buf = buffers::uniform_buffer(dev, "accum_params", 32);
Some(GpuDenseLayer {
cpu,
ctx,
weight_buf,
input_prob_buf,
packed_input_buf,
output_buf,
output_staging,
encode_uniform_buf,
accum_uniform_buf,
max_batch_size: max_batch,
})
}
pub fn forward_gpu(&self, inputs: &[f64], seed: u64) -> Vec<f64> {
self.forward_batch_gpu(inputs, 1, seed)
}
pub fn forward_batch_gpu(&self, inputs_flat: &[f64], n_samples: usize, seed: u64) -> Vec<f64> {
let n_inputs = self.cpu.n_inputs;
let n_neurons = self.cpu.n_neurons;
let words = self.cpu.words_per_input;
let length = self.cpu.length;
assert_eq!(inputs_flat.len(), n_samples * n_inputs);
assert!(
n_samples <= self.max_batch_size,
"Batch size {} exceeds max {}",
n_samples,
self.max_batch_size
);
let dev = &self.ctx.device;
let queue = &self.ctx.queue;
let inputs_f32: Vec<f32> = inputs_flat.iter().map(|&x| x as f32).collect();
queue.write_buffer(&self.input_prob_buf, 0, bytemuck::cast_slice(&inputs_f32));
let encode_params = EncodeParams {
n_inputs: n_inputs as u32,
words_per_input: words as u32,
seed_lo: seed as u32,
seed_hi: (seed >> 32) as u32,
n_samples: n_samples as u32,
length: length as u32,
_pad0: 0,
_pad1: 0,
};
queue.write_buffer(
&self.encode_uniform_buf,
0,
bytemuck::bytes_of(&encode_params),
);
let accum_params = AccumParams {
n_inputs: n_inputs as u32,
n_neurons: n_neurons as u32,
words_per_input: words as u32,
inv_length: self.cpu.inv_length as f32,
n_samples: n_samples as u32,
_pad0: 0,
_pad1: 0,
_pad2: 0,
};
queue.write_buffer(
&self.accum_uniform_buf,
0,
bytemuck::bytes_of(&accum_params),
);
let encode_bg = dev.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("encode_bg"),
layout: &self.ctx.encode_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.input_prob_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.packed_input_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.encode_uniform_buf.as_entire_binding(),
},
],
});
let accum_bg = dev.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("accum_bg"),
layout: &self.ctx.accumulate_bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.weight_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.packed_input_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.output_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.accum_uniform_buf.as_entire_binding(),
},
],
});
let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("dense_forward"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("encode"),
timestamp_writes: None,
});
pass.set_pipeline(&self.ctx.encode_pipeline);
pass.set_bind_group(0, &encode_bg, &[]);
let x_groups = ((n_inputs * words) as u32).div_ceil(256);
pass.dispatch_workgroups(x_groups, n_samples as u32, 1);
}
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("accumulate"),
timestamp_writes: None,
});
pass.set_pipeline(&self.ctx.accumulate_pipeline);
pass.set_bind_group(0, &accum_bg, &[]);
pass.dispatch_workgroups(n_neurons as u32, n_samples as u32, 1);
}
let out_bytes = (n_samples * n_neurons * 4) as u64;
encoder.copy_buffer_to_buffer(&self.output_buf, 0, &self.output_staging, 0, out_bytes);
queue.submit(std::iter::once(encoder.finish()));
let slice = self.output_staging.slice(..out_bytes);
slice.map_async(wgpu::MapMode::Read, |_| {});
dev.poll(wgpu::Maintain::Wait);
let data = slice.get_mapped_range();
let output_f32: &[f32] = bytemuck::cast_slice(&data);
let result: Vec<f64> = output_f32[..n_samples * n_neurons]
.iter()
.map(|&x| x as f64)
.collect();
drop(data);
self.output_staging.unmap();
result
}
pub fn gpu_name(&self) -> &str {
&self.ctx.adapter_name
}
}