use crate::expr::{trace, ExprId};
use super::buffer::GpuBuffer;
use super::device::GpuDevice;
use super::kernel::KernelCache;
/// A generated WGSL kernel plus the element counts needed to launch it.
///
/// Values of this type are produced by `fused_forward_backward`, which puts
/// the loss expression in output slot 0 followed by one gradient expression
/// per input variable.
pub struct FusedKernel {
    /// WGSL source text handed to the kernel cache when the kernel is run.
    wgsl: String,
    /// Number of `f32` values `run` requires in its input slice.
    n_inputs: usize,
    /// Number of elements requested for the output buffer in `run`.
    n_outputs: usize,
}
impl FusedKernel {
    /// Returns the number of input values [`FusedKernel::run`] expects.
    pub fn n_inputs(&self) -> usize {
        self.n_inputs
    }

    /// Returns the WGSL source of this kernel.
    pub fn wgsl(&self) -> &str {
        &self.wgsl
    }

    /// Runs the kernel on `inputs` and returns `(loss, gradients)`.
    ///
    /// An input buffer is created from `inputs`, an output buffer of
    /// `n_outputs` elements is created, the WGSL source is dispatched through
    /// `cache`, and the output buffer is read back. Element 0 of the readback
    /// is returned as the loss; every remaining element is returned as the
    /// gradient vector.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len()` differs from the kernel's input count, or if
    /// the readback contains no elements.
    pub fn run(
        &self,
        device: &GpuDevice,
        cache: &mut KernelCache,
        inputs: &[f32],
    ) -> (f32, Vec<f32>) {
        assert_eq!(
            inputs.len(),
            self.n_inputs,
            "expected {} inputs, got {}",
            self.n_inputs,
            inputs.len()
        );

        let input_buffer = GpuBuffer::from_slice(device, inputs);
        let output_buffer = GpuBuffer::uninit(device, self.n_outputs);

        // NOTE(review): the trailing `1` is presumably the workgroup count — confirm
        // against `KernelCache::dispatch`.
        cache.dispatch(device, &self.wgsl, &input_buffer, &output_buffer, 1);

        // Readback layout: [loss, grad_0, grad_1, ...].
        let values = output_buffer.to_vec_sync(device);
        (values[0], values[1..].to_vec())
    }
}
/// Traces `f` over `n_inputs` symbolic variables and builds a single kernel
/// that evaluates both the loss and its derivative with respect to every
/// input variable.
///
/// Output slot 0 is the simplified loss expression; slot `i + 1` is the
/// simplified derivative of the loss with respect to variable `i`. The
/// returned kernel therefore has `n_inputs + 1` outputs.
///
/// # Panics
///
/// Panics if `n_inputs` exceeds `u16::MAX`. Variables are indexed by `u16`,
/// and a larger count would otherwise wrap silently, yielding a kernel whose
/// variable and gradient counts disagree with the recorded `n_inputs`.
pub fn fused_forward_backward(n_inputs: usize, f: impl FnOnce(&[ExprId]) -> ExprId) -> FusedKernel {
    assert!(
        n_inputs <= usize::from(u16::MAX),
        "n_inputs must be at most {} (variables are indexed by u16), got {}",
        u16::MAX,
        n_inputs
    );
    // Lossless: the assertion above bounds `n_inputs` by `u16::MAX`.
    let n_vars = n_inputs as u16;

    let (mut graph, loss) = trace(|| {
        let vars: Vec<ExprId> = (0..n_vars).map(ExprId::var).collect();
        f(&vars)
    });

    // Slot 0 holds the raw loss for now; gradients follow in variable order.
    let mut all_outputs = Vec::with_capacity(n_inputs + 1);
    all_outputs.push(loss);
    all_outputs.extend((0..n_vars).map(|i| {
        let grad = graph.diff(loss, i);
        graph.simplify(grad)
    }));

    // Every derivative above is taken from the unsimplified loss; the loss
    // itself is simplified only after all gradients have been built.
    all_outputs[0] = graph.simplify(loss);

    let kernel = graph.to_wgsl(&all_outputs, n_inputs);
    FusedKernel {
        wgsl: kernel.source,
        n_inputs,
        n_outputs: all_outputs.len(),
    }
}
/// Builds a fused loss-and-gradient kernel for `f` and runs it once on `inputs`.
///
/// This is `fused_forward_backward(n_inputs, f)` followed by
/// `FusedKernel::run`; the kernel is rebuilt on every call. Returns
/// `(loss, gradients)` as produced by `run`.
///
/// # Panics
///
/// Panics under the same conditions as the two functions it calls, including
/// when `inputs.len()` differs from `n_inputs`.
pub fn forward_backward_gpu(
    device: &GpuDevice,
    cache: &mut KernelCache,
    n_inputs: usize,
    inputs: &[f32],
    f: impl FnOnce(&[ExprId]) -> ExprId,
) -> (f32, Vec<f32>) {
    fused_forward_backward(n_inputs, f).run(device, cache, inputs)
}