use crate::bytecode_tape::BytecodeTape;
use crate::opcode::OpCode;
#[cfg(feature = "gpu-wgpu")]
pub mod wgpu_backend;
#[cfg(feature = "gpu-cuda")]
pub mod cuda_backend;
#[cfg(feature = "stde")]
pub mod stde_gpu;
#[cfg(feature = "stde")]
pub mod taylor_codegen;
#[cfg(feature = "gpu-wgpu")]
pub use wgpu_backend::{WgpuContext, WgpuTapeBuffers};
#[cfg(feature = "gpu-cuda")]
pub use cuda_backend::{CudaContext, CudaTapeBuffers};
/// Compute backend capable of evaluating recorded bytecode tapes on a GPU.
///
/// Implementations (e.g. the wgpu or CUDA backends re-exported above) own
/// their device handles and expose batched evaluation, gradients, sparse
/// Jacobians/Hessians, and Hessian-vector products over `f32` inputs.
pub trait GpuBackend {
    /// Device-resident buffers holding one uploaded tape.
    type TapeBuffers;

    /// Copies the flattened tape description into device memory.
    fn upload_tape(&self, data: &GpuTapeData) -> Self::TapeBuffers;

    /// Number of outputs recorded on the uploaded tape.
    fn num_outputs(&self, tape: &Self::TapeBuffers) -> u32;

    /// Evaluates the tape for `batch_size` points packed in `inputs` and
    /// returns the flattened outputs.
    fn forward_batch(
        &self,
        tape: &Self::TapeBuffers,
        inputs: &[f32],
        batch_size: u32,
    ) -> Result<Vec<f32>, GpuError>;

    /// Evaluates the tape and its gradient for a batch; the pair is
    /// presumably (values, gradients) — confirm against backend impls.
    fn gradient_batch(
        &self,
        tape: &Self::TapeBuffers,
        inputs: &[f32],
        batch_size: u32,
    ) -> Result<(Vec<f32>, Vec<f32>), GpuError>;

    /// Computes a sparse Jacobian at `x`. Takes the CPU tape mutably —
    /// presumably for host-side sparsity detection/coloring; verify in the
    /// backend implementations.
    fn sparse_jacobian(
        &self,
        tape: &Self::TapeBuffers,
        tape_cpu: &mut BytecodeTape<f32>,
        x: &[f32],
    ) -> Result<(Vec<f32>, crate::sparse::JacobianSparsityPattern, Vec<f32>), GpuError>;

    /// Hessian-vector products at `x` for `batch_size` tangent directions;
    /// exact packing of the returned pair is backend-defined.
    fn hvp_batch(
        &self,
        tape: &Self::TapeBuffers,
        x: &[f32],
        tangent_dirs: &[f32],
        batch_size: u32,
    ) -> Result<(Vec<f32>, Vec<f32>), GpuError>;

    /// Computes a sparse Hessian at `x`; the tuple appears to be
    /// (value, gradient, sparsity pattern, hessian entries) — TODO confirm.
    fn sparse_hessian(
        &self,
        tape: &Self::TapeBuffers,
        tape_cpu: &mut BytecodeTape<f32>,
        x: &[f32],
    ) -> Result<(f32, Vec<f32>, crate::sparse::SparsityPattern, Vec<f32>), GpuError>;

    /// Second-order Taylor-mode forward evaluation for a batch.
    ///
    /// The default implementation delegates to
    /// [`Self::taylor_forward_kth_batch`] with `order` = 3 — apparently the
    /// number of coefficient streams (value, 1st, 2nd), not the highest
    /// derivative order; TODO confirm against backend impls — and splits the
    /// streams into a [`TaylorBatchResult`].
    ///
    /// # Panics
    /// Panics if the backend yields fewer than three coefficient vectors.
    #[cfg(feature = "stde")]
    fn taylor_forward_2nd_batch(
        &self,
        tape: &Self::TapeBuffers,
        primal_inputs: &[f32],
        direction_seeds: &[f32],
        batch_size: u32,
    ) -> Result<TaylorBatchResult<f32>, GpuError> {
        let kth =
            self.taylor_forward_kth_batch(tape, primal_inputs, direction_seeds, batch_size, 3)?;
        let mut coeffs = kth.coefficients.into_iter();
        Ok(TaylorBatchResult {
            values: coeffs.next().unwrap(),
            c1s: coeffs.next().unwrap(),
            c2s: coeffs.next().unwrap(),
        })
    }

    /// Order-`k` Taylor-mode forward evaluation for a batch; returns one
    /// flattened stream per Taylor coefficient.
    #[cfg(feature = "stde")]
    fn taylor_forward_kth_batch(
        &self,
        tape: &Self::TapeBuffers,
        primal_inputs: &[f32],
        direction_seeds: &[f32],
        batch_size: u32,
        order: usize,
    ) -> Result<TaylorKthBatchResult<f32>, GpuError>;
}
/// Batched second-order Taylor coefficients, one flattened stream per order.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, copied, and
/// compared in tests (a public result type without `Debug` is hard to use).
#[derive(Debug, Clone, PartialEq)]
pub struct TaylorBatchResult<F> {
    /// Zeroth-order coefficients (function values), flattened across the batch.
    pub values: Vec<F>,
    /// First-order Taylor coefficients, flattened across the batch.
    pub c1s: Vec<F>,
    /// Second-order Taylor coefficients, flattened across the batch.
    pub c2s: Vec<F>,
}
/// Batched order-`k` Taylor coefficients.
///
/// Derives `Debug`/`Clone`/`PartialEq` so results can be logged, copied, and
/// compared in tests.
#[cfg(feature = "stde")]
#[derive(Debug, Clone, PartialEq)]
pub struct TaylorKthBatchResult<F> {
    /// Coefficient streams, lowest order first, each flattened across the
    /// batch. Length appears to equal `order` — the trait's default 2nd-order
    /// method consumes exactly `order` streams; TODO confirm in backends.
    pub coefficients: Vec<Vec<F>>,
    /// The `order` value this result was computed for.
    pub order: usize,
}
/// Errors produced by GPU backends.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers can store, compare, and
/// test against error values (all payloads are plain `String`s).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuError {
    /// No usable GPU device could be found (see `Display`: "no suitable GPU
    /// device found").
    NoDevice,
    /// Shader/kernel compilation failed; the payload is the compiler message.
    ShaderCompilation(String),
    /// A device allocation failed.
    OutOfMemory,
    /// The tape records custom ops that cannot be lowered to the GPU
    /// (see [`GpuTapeData::from_tape`]).
    CustomOpsNotSupported,
    /// Any other backend-specific failure, described by the payload.
    Other(String),
}
impl std::fmt::Display for GpuError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GpuError::NoDevice => write!(f, "no suitable GPU device found"),
GpuError::ShaderCompilation(msg) => write!(f, "shader compilation failed: {msg}"),
GpuError::OutOfMemory => write!(f, "GPU out of memory"),
GpuError::CustomOpsNotSupported => {
write!(f, "tape contains custom ops which cannot run on GPU")
}
GpuError::Other(msg) => write!(f, "GPU error: {msg}"),
}
}
}
// Marker impl: all variants carry plain data, so the default trait methods
// (e.g. `source()`) suffice.
impl std::error::Error for GpuError {}
// Compile-time check that GpuError is Send + Sync and can cross threads.
crate::assert_send_sync!(GpuError);
/// Flattened, structure-of-arrays description of a bytecode tape, ready for
/// upload to device buffers.
///
/// All index/opcode streams are `u32` and constants are `f32` (possibly
/// narrowed from `f64` via [`GpuTapeData::from_tape_f64_lossy`]).
/// Derives `Debug`/`Clone` so the host-side description can be inspected and
/// duplicated (a public data type without `Debug` is hard to work with).
#[derive(Debug, Clone)]
pub struct GpuTapeData {
    /// One opcode per tape entry, cast from `OpCode` discriminants.
    pub opcodes: Vec<u32>,
    /// First argument index of each op.
    pub arg0: Vec<u32>,
    /// Second argument index of each op.
    pub arg1: Vec<u32>,
    /// Constant/value pool copied from the tape's values.
    pub constants: Vec<f32>,
    /// Number of ops; equals `opcodes.len()`.
    pub num_ops: u32,
    /// Number of tape inputs.
    pub num_inputs: u32,
    /// Total number of tape variables.
    pub num_variables: u32,
    /// Index of the primary output variable.
    pub output_index: u32,
    /// Indices of all outputs (for multi-output tapes).
    pub output_indices: Vec<u32>,
}
impl GpuTapeData {
    /// Flattens `tape` into structure-of-arrays form with the supplied
    /// constant pool.
    ///
    /// The `has_custom_ops` precondition lives here so both public
    /// constructors share it instead of duplicating the check.
    ///
    /// # Errors
    /// Returns [`GpuError::CustomOpsNotSupported`] if the tape records custom
    /// ops, which have no GPU lowering.
    fn build_from_tape<F: crate::float::Float>(
        tape: &BytecodeTape<F>,
        constants: Vec<f32>,
    ) -> Result<Self, GpuError> {
        if tape.has_custom_ops() {
            return Err(GpuError::CustomOpsNotSupported);
        }
        let opcodes_raw = tape.opcodes_slice();
        let args = tape.arg_indices_slice();
        let n = opcodes_raw.len();
        Ok(GpuTapeData {
            // Cast each OpCode to its raw u32 discriminant for the GPU.
            opcodes: opcodes_raw.iter().map(|op| *op as u32).collect(),
            arg0: args.iter().map(|a| a[0]).collect(),
            arg1: args.iter().map(|a| a[1]).collect(),
            constants,
            num_ops: n as u32,
            num_inputs: tape.num_inputs() as u32,
            num_variables: tape.num_variables_count() as u32,
            output_index: tape.output_index() as u32,
            output_indices: tape.all_output_indices().to_vec(),
        })
    }

    /// Builds GPU tape data from an `f32` tape.
    ///
    /// # Errors
    /// Returns [`GpuError::CustomOpsNotSupported`] if the tape has custom ops.
    pub fn from_tape(tape: &BytecodeTape<f32>) -> Result<Self, GpuError> {
        Self::build_from_tape(tape, tape.values_slice().to_vec())
    }

    /// Builds GPU tape data from an `f64` tape, narrowing every constant to
    /// `f32` (lossy for values outside `f32` precision or range).
    ///
    /// # Errors
    /// Returns [`GpuError::CustomOpsNotSupported`] if the tape has custom ops.
    pub fn from_tape_f64_lossy(tape: &BytecodeTape<f64>) -> Result<Self, GpuError> {
        let constants = tape.values_slice().iter().map(|&v| v as f32).collect();
        Self::build_from_tape(tape, constants)
    }
}
/// Per-dispatch tape metadata uploaded alongside the tape buffers.
///
/// `#[repr(C)]` plus the bytemuck `Pod`/`Zeroable` derives allow the struct
/// to be copied byte-wise into a GPU buffer. `_pad` rounds the size to
/// 8 x u32 = 32 bytes — presumably to satisfy the shader-side struct layout;
/// TODO confirm against the WGSL definition.
#[cfg(feature = "gpu-wgpu")]
#[repr(C)]
#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct TapeMeta {
    /// Number of ops on the tape.
    pub num_ops: u32,
    /// Number of tape inputs.
    pub num_inputs: u32,
    /// Total number of tape variables.
    pub num_variables: u32,
    /// Number of tape outputs.
    pub num_outputs: u32,
    /// Number of elements in the current batch.
    pub batch_size: u32,
    /// Explicit padding; keeps the layout identical across host and device.
    pub _pad: [u32; 3],
}
/// Converts an [`OpCode`] to its raw `u32` discriminant for GPU upload.
///
/// Assumes the shader side uses the same discriminant numbering — TODO
/// confirm the kernels are kept in sync with `OpCode`.
#[inline]
#[must_use]
pub fn opcode_to_gpu(op: OpCode) -> u32 {
    op as u32
}
/// Upper bound on a single buffer allocation for Taylor batches (128 MiB).
#[cfg(feature = "stde")]
pub const WGPU_MAX_BUFFER_BYTES: u64 = 128 * 1024 * 1024;
/// Maximum workgroups per dispatch dimension (65535, the WebGPU
/// `maxComputeWorkgroupsPerDimension` default limit).
#[cfg(feature = "stde")]
const MAX_WORKGROUPS_PER_DIM: u64 = 65535;
/// Threads per workgroup used by the Taylor kernels — presumably must match
/// the shader's `@workgroup_size`; verify when changing.
#[cfg(feature = "stde")]
const TAYLOR_WORKGROUP_SIZE: u64 = 256;
/// Runs [`GpuBackend::taylor_forward_2nd_batch`] over an arbitrarily large
/// batch by splitting it into chunks that respect the buffer budget, the
/// dispatch limit, and `u32` indexing on the device.
///
/// `primal_inputs` and `direction_seeds` are flattened as
/// `batch_size * num_inputs` values; the per-chunk results are concatenated
/// in batch order.
///
/// # Errors
/// Fails when `num_variables` is zero, when `max_buffer_bytes` cannot hold
/// even one element, or when any underlying backend call fails.
#[cfg(feature = "stde")]
#[allow(clippy::too_many_arguments)]
pub fn taylor_forward_2nd_batch_chunked<B: GpuBackend>(
    backend: &B,
    tape: &B::TapeBuffers,
    primal_inputs: &[f32],
    direction_seeds: &[f32],
    batch_size: u32,
    num_inputs: u32,
    num_variables: u32,
    max_buffer_bytes: u64,
) -> Result<TaylorBatchResult<f32>, GpuError> {
    // Empty batch: nothing to dispatch.
    if batch_size == 0 {
        return Ok(TaylorBatchResult {
            values: Vec::new(),
            c1s: Vec::new(),
            c2s: Vec::new(),
        });
    }
    // Each batch element stores 3 f32 coefficient streams per tape variable.
    let nv = num_variables as u64;
    let bytes_per_element = nv * 3 * 4;
    if bytes_per_element == 0 {
        return Err(GpuError::Other("num_variables is zero".into()));
    }
    // Cap 1: how many elements fit in the buffer budget.
    let by_buffer = max_buffer_bytes / bytes_per_element;
    if by_buffer == 0 {
        return Err(GpuError::Other(format!(
            "max_buffer_bytes ({max_buffer_bytes}) too small for a single element ({bytes_per_element} bytes per element)"
        )));
    }
    // Cap 2: how many elements one compute dispatch can cover.
    let by_dispatch = MAX_WORKGROUPS_PER_DIM * TAYLOR_WORKGROUP_SIZE;
    // Cap 3: keep chunk * num_variables * 3 within u32 range for device-side
    // indexing. `nv > 0` is guaranteed by the bytes_per_element guard above,
    // so the division cannot fault.
    let by_index = (u32::MAX as u64) / (nv * 3);
    // The dispatch cap already bounds this below u32::MAX, so the cast is safe.
    let chunk = by_buffer.min(by_dispatch).min(by_index) as u32;
    // Small enough to run in a single backend call.
    if batch_size <= chunk {
        return backend.taylor_forward_2nd_batch(tape, primal_inputs, direction_seeds, batch_size);
    }
    let stride = num_inputs as usize;
    let mut values = Vec::new();
    let mut c1s = Vec::new();
    let mut c2s = Vec::new();
    let mut done = 0u32;
    while done < batch_size {
        let count = chunk.min(batch_size - done);
        let lo = (done as usize) * stride;
        let hi = lo + (count as usize) * stride;
        let part = backend.taylor_forward_2nd_batch(
            tape,
            &primal_inputs[lo..hi],
            &direction_seeds[lo..hi],
            count,
        )?;
        values.extend(part.values);
        c1s.extend(part.c1s);
        c2s.extend(part.c2s);
        done += count;
    }
    Ok(TaylorBatchResult {
        values,
        c1s,
        c2s,
    })
}
#[cfg(all(test, feature = "stde"))]
mod tests {
    use super::*;

    /// Host-side mirror of the chunk-size computation in
    /// `taylor_forward_2nd_batch_chunked`; keep the two in sync.
    /// Returns `None` where the real function would return an error.
    fn compute_chunk_size(num_variables: u32, max_buffer_bytes: u64) -> Option<u32> {
        // 3 coefficient streams of 4-byte f32 per tape variable.
        let bytes_per_element = (num_variables as u64) * 3 * 4;
        if bytes_per_element == 0 {
            return None;
        }
        let mut chunk_size = max_buffer_bytes / bytes_per_element;
        if chunk_size == 0 {
            return None;
        }
        // Cap by the elements one dispatch can cover.
        let dispatch_limit = MAX_WORKGROUPS_PER_DIM * TAYLOR_WORKGROUP_SIZE;
        chunk_size = chunk_size.min(dispatch_limit);
        // Cap so chunk * num_variables * 3 fits in u32 — presumably for
        // device-side u32 indexing.
        let nv_k = (num_variables as u64) * 3;
        if nv_k > 0 {
            chunk_size = chunk_size.min(u32::MAX as u64 / nv_k);
        }
        Some(chunk_size as u32)
    }

    // With a huge variable count, the u32-indexing cap must dominate.
    #[test]
    fn chunking_caps_for_large_num_variables() {
        let chunk = compute_chunk_size(500_000, u64::MAX).unwrap();
        let product = chunk as u64 * 500_000 * 3;
        assert!(
            product <= u32::MAX as u64,
            "chunk_size * nv * K = {} exceeds u32::MAX",
            product
        );
    }

    // Same invariant at an even larger variable count.
    #[test]
    fn chunking_caps_for_very_large_num_variables() {
        let chunk = compute_chunk_size(1_000_000, u64::MAX).unwrap();
        let product = chunk as u64 * 1_000_000 * 3;
        assert!(
            product <= u32::MAX as u64,
            "chunk_size * nv * K = {} exceeds u32::MAX",
            product
        );
    }

    // A 1-byte budget cannot hold a single element.
    #[test]
    fn chunking_with_small_buffer() {
        let result = compute_chunk_size(1000, 1);
        assert!(result.is_none(), "should fail with buffer too small");
    }

    // Degenerate but valid case: one variable.
    #[test]
    fn chunking_single_variable() {
        let chunk = compute_chunk_size(1, WGPU_MAX_BUFFER_BYTES).unwrap();
        assert!(chunk > 0, "should handle single variable");
        let product = chunk as u64 * 1 * 3;
        assert!(product <= u32::MAX as u64);
    }

    // Zero variables is rejected, matching the production error path.
    #[test]
    fn chunking_zero_variables() {
        let result = compute_chunk_size(0, WGPU_MAX_BUFFER_BYTES);
        assert!(result.is_none(), "should fail with zero variables");
    }
}