mod execute;
pub use execute::dispatch::PipelineCache;
#[cfg(test)]
mod tests;
use super::GpuDevice;
use std::collections::HashMap;
use std::sync::Arc;
use wgpu;
/// Opaque handle to a buffer registered with a `GpuCommandBatch`.
///
/// Handles are minted sequentially by the batch that owns the buffer; the
/// wrapped index is crate-visible only so callers cannot forge handles.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct BufferId(pub(crate) usize);
/// A single recorded GPU operation.
///
/// Every variant names its input buffer(s) and the `output` buffer it writes.
/// NOTE(review): presumably executed in recording order by the `execute`
/// module — confirm there.
#[derive(Debug)]
pub(crate) enum GpuOp {
    /// ReLU of `input`; `output` has the same length as `input`.
    Relu { input: BufferId, output: BufferId },
    /// `input` scaled by `scalar`; `output` has the same length as `input`.
    Scale { input: BufferId, output: BufferId, scalar: f32 },
    /// Combination of two equally sized buffers `a` and `b`.
    Add { a: BufferId, b: BufferId, output: BufferId },
    /// Combination of two equally sized buffers `a` and `b`.
    Mul { a: BufferId, b: BufferId, output: BufferId },
    /// Reduction of two equally sized buffers into a one-element `output`.
    Dot { a: BufferId, b: BufferId, output: BufferId },
    /// Sigmoid of `input`; `output` has the same length as `input`.
    Sigmoid { input: BufferId, output: BufferId },
    /// Tanh of `input`; `output` has the same length as `input`.
    Tanh { input: BufferId, output: BufferId },
    /// Swish of `input`; `output` has the same length as `input`.
    Swish { input: BufferId, output: BufferId },
    /// GELU of `input`; `output` has the same length as `input`.
    Gelu { input: BufferId, output: BufferId },
    /// Combination of two equally sized buffers `a` and `b`
    /// (presumably `a - b` — confirm against the shader).
    Sub { a: BufferId, b: BufferId, output: BufferId },
    /// Matrix product: `a` holds `m * k` elements, `b` holds `k * n`, and
    /// `output` holds `m * n`. NOTE(review): memory layout (row- vs
    /// column-major) is defined by the shader, not visible here.
    Matmul {
        a: BufferId,
        b: BufferId,
        output: BufferId,
        m: u32,
        k: u32,
        n: u32,
    },
}
/// Records GPU operations and the buffers they touch, for later execution.
pub struct GpuCommandBatch {
    /// Device whose wgpu handles are exposed via `wgpu_device` / `wgpu_queue`.
    pub(crate) device: Arc<GpuDevice>,
    /// Operations in the order they were recorded.
    pub(crate) operations: Vec<GpuOp>,
    /// Every buffer registered with this batch, keyed by its handle.
    pub(crate) buffers: HashMap<BufferId, BufferInfo>,
    /// Index that the next registered buffer's `BufferId` will receive.
    pub(crate) next_buffer_id: usize,
}
/// Bookkeeping for one buffer registered with a `GpuCommandBatch`.
#[derive(Debug)]
pub(crate) struct BufferInfo {
    /// Length in `f32` elements (e.g. `data.len()` for uploaded buffers).
    pub(crate) size: usize,
    /// Host-side contents supplied via `upload`; `None` for operation
    /// outputs and imported buffers.
    pub(crate) data: Option<Vec<f32>>,
    /// Backing GPU buffer. Set up front by `import_buffer`; otherwise
    /// presumably filled in during execution — confirm in `execute`.
    pub(crate) gpu_buffer: Option<Arc<wgpu::Buffer>>,
}
impl GpuCommandBatch {
pub fn new(device: GpuDevice) -> Self {
Self {
device: Arc::new(device),
operations: Vec::new(),
buffers: HashMap::new(),
next_buffer_id: 0,
}
}
fn alloc_buffer(&mut self, size: usize, data: Option<Vec<f32>>) -> BufferId {
let id = BufferId(self.next_buffer_id);
self.next_buffer_id += 1;
self.buffers.insert(id, BufferInfo { size, data, gpu_buffer: None });
id
}
pub fn upload(&mut self, data: &[f32]) -> BufferId {
self.alloc_buffer(data.len(), Some(data.to_vec()))
}
fn alloc_output(&mut self, size: usize) -> BufferId {
self.alloc_buffer(size, None)
}
pub fn relu(&mut self, input: BufferId) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Relu { input, output });
output
}
pub fn scale(&mut self, input: BufferId, scalar: f32) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Scale { input, output, scalar });
output
}
pub fn add(&mut self, a: BufferId, b: BufferId) -> BufferId {
let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
let output = self.alloc_output(size_a);
self.operations.push(GpuOp::Add { a, b, output });
output
}
pub fn mul(&mut self, a: BufferId, b: BufferId) -> BufferId {
let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
let output = self.alloc_output(size_a);
self.operations.push(GpuOp::Mul { a, b, output });
output
}
pub fn dot(&mut self, a: BufferId, b: BufferId) -> BufferId {
let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
let output = self.alloc_output(1);
self.operations.push(GpuOp::Dot { a, b, output });
output
}
pub fn sigmoid(&mut self, input: BufferId) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Sigmoid { input, output });
output
}
pub fn tanh(&mut self, input: BufferId) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Tanh { input, output });
output
}
pub fn swish(&mut self, input: BufferId) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Swish { input, output });
output
}
pub fn gelu(&mut self, input: BufferId) -> BufferId {
let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
let output = self.alloc_output(size);
self.operations.push(GpuOp::Gelu { input, output });
output
}
pub fn sub(&mut self, a: BufferId, b: BufferId) -> BufferId {
let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
let output = self.alloc_output(size_a);
self.operations.push(GpuOp::Sub { a, b, output });
output
}
pub fn matmul(&mut self, a: BufferId, b: BufferId, m: u32, k: u32, n: u32) -> BufferId {
let size_a = self.buffers.get(&a).expect("Invalid buffer A ID").size;
let size_b = self.buffers.get(&b).expect("Invalid buffer B ID").size;
assert_eq!(
size_a,
(m * k) as usize,
"Buffer A size {} doesn't match M×K = {}",
size_a,
m * k
);
assert_eq!(
size_b,
(k * n) as usize,
"Buffer B size {} doesn't match K×N = {}",
size_b,
k * n
);
let output = self.alloc_output((m * n) as usize);
self.operations.push(GpuOp::Matmul { a, b, output, m, k, n });
output
}
pub fn import_buffer(&mut self, buffer: Arc<wgpu::Buffer>, size: usize) -> BufferId {
let id = BufferId(self.next_buffer_id);
self.next_buffer_id += 1;
self.buffers.insert(id, BufferInfo { size, data: None, gpu_buffer: Some(buffer) });
id
}
pub fn wgpu_device(&self) -> &wgpu::Device {
&self.device.device
}
pub fn wgpu_queue(&self) -> &wgpu::Queue {
&self.device.queue
}
pub fn num_operations(&self) -> usize {
self.operations.len()
}
pub fn num_buffers(&self) -> usize {
self.buffers.len()
}
}