use std::sync::Arc;
#[cfg(feature = "gpu")]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "gpu")]
use crate::autograd::{
    context, BackwardOp, CustomBackwardOp, TapeEntry, VersionSnapshot,
};
use crate::tensor::Tensor;
#[cfg(feature = "gpu")]
use crate::tensor::{AutogradState, DType, Layout, StorageHandle, TensorMeta};
#[cfg(feature = "gpu")]
use crate::backend::gpu::context::{CustomOpKey, GpuContext, STORAGE_USAGE};
#[cfg(feature = "gpu")]
use wgpu::util::DeviceExt;
pub use crate::autograd::CustomBackward;
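
/// A user-defined GPU compute kernel, executed via [`custom_forward`].
///
/// Implementors supply WGSL source plus shape and dispatch metadata. A
/// minimal elementwise sketch follows; `SquareOp` and its shader are
/// illustrative only, not part of this crate. Note the binding convention:
/// inputs at `0..num_inputs`, output at `num_inputs`, uniforms at
/// `num_inputs + 1`.
///
/// ```ignore
/// struct SquareOp;
///
/// impl CustomOp for SquareOp {
///     fn op_name(&self) -> &str { "square" }
///     fn entry_point(&self) -> &str { "main" }
///     fn num_inputs(&self) -> usize { 1 }
///     fn output_shape(&self, input_shapes: &[&[usize]]) -> Vec<usize> {
///         input_shapes[0].to_vec() // elementwise: same shape as the input
///     }
///     fn wgsl_source(&self) -> &str {
///         // Hypothetical shader computing y[i] = x[i] * x[i].
///         r#"
///         @group(0) @binding(0) var<storage, read> x: array<f32>;
///         @group(0) @binding(1) var<storage, read_write> y: array<f32>;
///         struct Params { numel: u32 }
///         @group(0) @binding(2) var<uniform> params: Params;
///         @compute @workgroup_size(256)
///         fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
///             if (gid.x < params.numel) {
///                 y[gid.x] = x[gid.x] * x[gid.x];
///             }
///         }
///         "#
///     }
/// }
/// ```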
pub trait CustomOp: Send + Sync {
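    /// Stable name for this op; combined with the dtype and input count to
    /// key the compiled-pipeline cache.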
fn op_name(&self) -> &str;
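    /// WGSL source for the kernel; it is run through `preprocess_shader`
    /// for the input dtype before compilation.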
fn wgsl_source(&self) -> &str;
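    /// Name of the `@compute` entry point in the WGSL module.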
fn entry_point(&self) -> &str;
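    /// Number of input tensors; the kernel sees them at bindings `0..num_inputs`.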
fn num_inputs(&self) -> usize;
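    /// Output shape as a function of the input shapes.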
fn output_shape(&self, input_shapes: &[&[usize]]) -> Vec<usize>;
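    /// Workgroup size; should match the shader's `@workgroup_size`. The
    /// default assumes a 1-D kernel with 256 invocations per group.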
fn workgroup_size(&self) -> (u32, u32, u32) {
(256, 1, 1)
}
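    /// Workgroup counts to dispatch; the default is a 1-D ceiling division
    /// of the output length by the workgroup width.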
fn dispatch(&self, output_numel: usize) -> (u32, u32, u32) {
let wg = self.workgroup_size().0;
((output_numel as u32 + wg - 1) / wg, 1, 1)
}
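    /// Bytes for the uniform buffer bound at `num_inputs + 1`. The default
    /// packs the largest input's element count into the first 4 bytes of a
    /// zeroed 16-byte block.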
fn uniform_data(&self, inputs: &[&Tensor]) -> Vec<u8> {
let numel = inputs.iter().map(|t| t.numel()).max().unwrap_or(0) as u32;
let mut buf = vec![0u8; 16];
buf[0..4].copy_from_slice(&numel.to_le_bytes());
buf
}
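    /// Return a handler to make the op differentiable; the default `None`
    /// leaves the output detached from autograd.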
fn backward_handler(&self) -> Option<Arc<dyn CustomBackward>> {
None
}
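    /// Tensors whose storage must be kept alive for the backward pass;
    /// defaults to all inputs.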
fn save_for_backward<'a>(
&self,
inputs: &[&'a Tensor],
_output: &'a Tensor,
) -> Vec<&'a Tensor> {
inputs.iter().copied().collect()
}
}
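
/// Runs `op` on the GPU over `inputs` and returns the output tensor.
///
/// When the op provides a [`CustomBackward`] handler, at least one input
/// requires grad, and grad mode is enabled, a tape entry is recorded so
/// gradients flow back through the custom kernel. Continuing the
/// hypothetical `SquareOp` example from the trait docs:
///
/// ```ignore
/// let y = custom_forward(&SquareOp, &[&x]);
/// ```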
#[cfg(feature = "gpu")]
pub fn custom_forward(op: &dyn CustomOp, inputs: &[&Tensor]) -> Tensor {
assert_eq!(
inputs.len(),
op.num_inputs(),
"custom_forward: expected {} inputs, got {}",
op.num_inputs(),
inputs.len(),
);
assert!(!inputs.is_empty(), "custom_forward: need at least one input");
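    // The pipeline is specialized on the first input's dtype; all inputs are
    // assumed to share it.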
let dtype = inputs[0].dtype();
let input_shapes: Vec<&[usize]> = inputs.iter().map(|t| t.shape()).collect();
let output_shape = op.output_shape(&input_shapes);
let out_numel: usize = output_shape.iter().product();
let ctx = GpuContext::get().expect("GPU required for custom_forward");
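    // Pipelines are compiled once and cached per (op name, dtype, arity).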
let key = CustomOpKey {
op_name: op.op_name().to_string(),
dtype_tag: match dtype {
DType::F32 => 0,
DType::F16 => 1,
DType::Q8 { .. } => panic!("custom_forward: Q8 inputs not supported"),
},
num_inputs: op.num_inputs(),
};
let wgsl = crate::backend::gpu::context::preprocess_shader(op.wgsl_source(), dtype);
let cached = ctx.custom_ops.get_or_compile(
&key,
&ctx.device,
&wgsl,
op.entry_point(),
op.num_inputs(),
);
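    // Acquire an output buffer from the pool and upload this call's uniforms.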
let out_buf = ctx.pool.acquire(&ctx.device, dtype.gpu_buf_size(out_numel), STORAGE_USAGE);
let uniform_bytes = op.uniform_data(inputs);
let uniform_buf = ctx.device.create_buffer_init(
&wgpu::util::BufferInitDescriptor {
label: Some("custom_op_uniform"),
contents: &uniform_bytes,
usage: wgpu::BufferUsages::UNIFORM,
},
);
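    // Bind inputs at 0..n, the output at n, and the uniform block at n + 1.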
let n = op.num_inputs();
let input_guards: Vec<_> = inputs.iter().map(|t| {
t.storage.ensure_gpu();
t.storage.gpu_buffer()
}).collect();
let mut bg_entries: Vec<wgpu::BindGroupEntry<'_>> = Vec::with_capacity(n + 2);
for (i, guard) in input_guards.iter().enumerate() {
bg_entries.push(wgpu::BindGroupEntry {
binding: i as u32,
resource: guard.as_entire_binding(),
});
}
bg_entries.push(wgpu::BindGroupEntry {
binding: n as u32,
resource: out_buf.as_entire_binding(),
});
bg_entries.push(wgpu::BindGroupEntry {
binding: (n + 1) as u32,
resource: uniform_buf.as_entire_binding(),
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &cached.layout,
entries: &bg_entries,
label: Some("custom_op_bg"),
});
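    // Encode a single compute pass and submit it.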
let (dx, dy, dz) = op.dispatch(out_numel);
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&cached.pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(dx, dy, dz);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
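    // The guards must stay alive until the work has been submitted.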
drop(input_guards);
let out_storage = StorageHandle::new_gpu(out_buf, out_numel);
    let out_layout = Layout::contiguous(output_shape);
let output = Tensor::from_storage_and_layout(out_storage.clone(), out_layout.clone());
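    // Record a tape entry only if the op is differentiable, some input
    // requires grad, and grad mode is on.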
    let handler = op.backward_handler();
    let wants_grad = inputs.iter().any(|t| t.requires_grad()) && !context::is_no_grad();
    if let (Some(handler), true) = (handler, wants_grad) {
let out_grad_id = context::next_grad_id();
let mut input_gids = Vec::with_capacity(n);
let mut input_versions = Vec::with_capacity(n);
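        // Assign each input a grad id, snapshot its storage version, and
        // bump its grad counter if it is tracked.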
for &inp in inputs {
let gid = inp.grad_id().unwrap_or_else(context::next_grad_id);
input_versions.push(VersionSnapshot::new(gid, &inp.storage));
if let Some(m) = inp.meta() {
m.total_grads.fetch_add(1, Ordering::Relaxed);
}
input_gids.push(gid);
}
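        // Clone the storages and layouts the backward handler asked to keep.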
let to_save = op.save_for_backward(inputs, &output);
let saved_storages: Vec<StorageHandle> = to_save.iter().map(|t| t.storage.clone()).collect();
let saved_layouts: Vec<Layout> = to_save.iter().map(|t| t.layout.clone()).collect();
let saved_shapes: Vec<Vec<usize>> = to_save.iter().map(|t| t.shape().to_vec()).collect();
let op_id = context::with_tape(|tape| {
tape.push(TapeEntry {
op: BackwardOp::Custom(CustomBackwardOp {
handler,
input_versions,
saved_storages,
saved_layouts,
saved_shapes,
}),
inputs: input_gids,
outputs: vec![out_grad_id],
})
});
Tensor {
storage: out_storage,
            layout: out_layout,
state: AutogradState::Tracked(Arc::new(TensorMeta {
requires_grad: true,
grad_id: Some(out_grad_id),
creator: op_id,
is_leaf: false,
retains_grad: false,
total_grads: AtomicUsize::new(0),
})),
}
} else {
output
}
}