use std::sync::Arc;
#[cfg(feature = "stde")]
use std::sync::Mutex;
// cudarc's `CudaContext` is renamed so it does not clash with this module's
// own `CudaContext` backend type.
use cudarc::driver::{
CudaContext as CudarContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg,
};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions};
use super::{GpuBackend, GpuError, GpuTapeData};
const KERNEL_SRC: &str = include_str!("kernels/tape_eval.cu");
const BLOCK_SIZE: u32 = 256;
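// The macros below share kernel-launch plumbing between the `f32` and `f64`
// paths: `$F` is the scalar type, `$constants` names the constant-buffer
// field on the tape buffers, and the kernel idents pick the matching compiled
// function. Each batch element is evaluated by one GPU thread.
//
// Shared body for batched forward evaluation of a compiled tape.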
macro_rules! cuda_forward_batch_body {
($self:expr, $tape:expr, $inputs:expr, $batch_size:expr, $F:ty, $constants:ident, $kernel:ident) => {{
let s = &$self.stream;
let ni = $tape.num_inputs;
let nv = $tape.num_variables;
let no = $tape.num_outputs;
assert_eq!($inputs.len(), ($batch_size as usize) * (ni as usize));
let d_inputs = s.clone_htod($inputs).map_err(cuda_err)?;
let mut d_values = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_outputs = s
.alloc_zeros::<$F>(($batch_size as usize) * (no as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: CudaContext::grid_dim($batch_size),
block_dim: CudaContext::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&$self.$kernel);
builder.arg(&$tape.opcodes);
builder.arg(&$tape.arg0);
builder.arg(&$tape.arg1);
builder.arg(&$tape.$constants);
builder.arg(&d_inputs);
builder.arg(&mut d_values);
builder.arg(&mut d_outputs);
builder.arg(&$tape.output_indices);
builder.arg(&$tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&no);
builder.arg(&$batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let results = s.clone_dtoh(&d_outputs).map_err(cuda_err)?;
Ok(results)
}};
}
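// Shared body for batched gradient evaluation: a forward pass materializes
// every intermediate value, then a reverse sweep propagates adjoints from the
// outputs back to the inputs. Both kernels reuse one launch configuration,
// and the stream is synchronized once, after the reverse sweep.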
macro_rules! cuda_gradient_batch_body {
($self:expr, $tape:expr, $inputs:expr, $batch_size:expr, $F:ty, $constants:ident, $fwd_kernel:ident, $rev_kernel:ident) => {{
let s = &$self.stream;
let ni = $tape.num_inputs;
let nv = $tape.num_variables;
let no = $tape.num_outputs;
assert_eq!($inputs.len(), ($batch_size as usize) * (ni as usize));
let d_inputs = s.clone_htod($inputs).map_err(cuda_err)?;
let mut d_values = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_outputs = s
.alloc_zeros::<$F>(($batch_size as usize) * (no as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: CudaContext::grid_dim($batch_size),
block_dim: CudaContext::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&$self.$fwd_kernel);
builder.arg(&$tape.opcodes);
builder.arg(&$tape.arg0);
builder.arg(&$tape.arg1);
builder.arg(&$tape.$constants);
builder.arg(&d_inputs);
builder.arg(&mut d_values);
builder.arg(&mut d_outputs);
builder.arg(&$tape.output_indices);
builder.arg(&$tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&no);
builder.arg(&$batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
let mut d_adjoints = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_grads = s
.alloc_zeros::<$F>(($batch_size as usize) * (ni as usize))
.map_err(cuda_err)?;
let mut builder = s.launch_builder(&$self.$rev_kernel);
builder.arg(&$tape.opcodes);
builder.arg(&$tape.arg0);
builder.arg(&$tape.arg1);
builder.arg(&d_values);
builder.arg(&mut d_adjoints);
builder.arg(&mut d_grads);
builder.arg(&$tape.output_indices);
builder.arg(&$tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&$batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let output_vals = s.clone_dtoh(&d_outputs).map_err(cuda_err)?;
let grads = s.clone_dtoh(&d_grads).map_err(cuda_err)?;
Ok((output_vals, grads))
}};
}
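// Shared body for batched Hessian-vector products via forward-over-reverse:
// the kernel propagates primals and tangents forward, then runs a reverse
// sweep with paired real/epsilon adjoints, yielding the gradient and the HVP
// in one launch. The primal point `x` is replicated across the batch; only
// the tangent direction differs per element.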
macro_rules! cuda_hvp_batch_body {
($self:expr, $tape:expr, $x:expr, $tangent_dirs:expr, $batch_size:expr, $F:ty, $constants:ident, $kernel:ident) => {{
let s = &$self.stream;
let ni = $tape.num_inputs;
let nv = $tape.num_variables;
assert_eq!($x.len(), ni as usize);
assert_eq!($tangent_dirs.len(), ($batch_size as usize) * (ni as usize));
let mut primal_inputs = Vec::with_capacity(($batch_size as usize) * (ni as usize));
for _ in 0..$batch_size {
primal_inputs.extend_from_slice($x);
}
let d_primal_in = s.clone_htod(&primal_inputs).map_err(cuda_err)?;
let d_seeds = s.clone_htod($tangent_dirs).map_err(cuda_err)?;
let mut d_primals = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_tans = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_adj_re = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_adj_eps = s
.alloc_zeros::<$F>(($batch_size as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_grads = s
.alloc_zeros::<$F>(($batch_size as usize) * (ni as usize))
.map_err(cuda_err)?;
let mut d_hvps = s
.alloc_zeros::<$F>(($batch_size as usize) * (ni as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: CudaContext::grid_dim($batch_size),
block_dim: CudaContext::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&$self.$kernel);
builder.arg(&$tape.opcodes);
builder.arg(&$tape.arg0);
builder.arg(&$tape.arg1);
builder.arg(&$tape.$constants);
builder.arg(&d_primal_in);
builder.arg(&d_seeds);
builder.arg(&mut d_primals);
builder.arg(&mut d_tans);
builder.arg(&mut d_adj_re);
builder.arg(&mut d_adj_eps);
builder.arg(&mut d_grads);
builder.arg(&mut d_hvps);
builder.arg(&$tape.output_indices);
builder.arg(&$tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&$batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let grads = s.clone_dtoh(&d_grads).map_err(cuda_err)?;
let hvps = s.clone_dtoh(&d_hvps).map_err(cuda_err)?;
Ok((grads, hvps))
}};
}
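// Shared body for sparse Jacobians via column coloring: structurally
// orthogonal input columns share one tangent seed, so the full Jacobian needs
// only `num_colors` tangent-mode passes, run as a single batched launch.
// Entry (row, col) is then read out of the tangent output for col's color.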
macro_rules! cuda_sparse_jacobian_body {
($self:expr, $tape:expr, $tape_cpu:expr, $x:expr, $F:ty, $constants:ident, $tangent_fwd_kernel:ident) => {{
let ni = $tape.num_inputs as usize;
let no = $tape.num_outputs as usize;
let pattern = $tape_cpu.detect_jacobian_sparsity();
let (colors, num_colors) = crate::sparse::column_coloring(&pattern);
if num_colors == 0 {
$tape_cpu.forward($x);
let vals = $tape_cpu.output_values();
return Ok((vals, pattern, vec![]));
}
let batch = num_colors as u32;
let mut seeds = Vec::with_capacity(batch as usize * ni);
for c in 0..num_colors {
for i in 0..ni {
seeds.push(if colors[i] == c as u32 {
(1.0 as $F)
} else {
(0.0 as $F)
});
}
}
let s = &$self.stream;
let nv = $tape.num_variables;
let d_primals_in = {
let mut replicated = Vec::with_capacity(batch as usize * ni);
for _ in 0..batch {
replicated.extend_from_slice($x);
}
s.clone_htod(&replicated).map_err(cuda_err)?
};
let d_seeds = s.clone_htod(&seeds).map_err(cuda_err)?;
let mut d_primals = s
.alloc_zeros::<$F>((batch as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_tangents = s
.alloc_zeros::<$F>((batch as usize) * (nv as usize))
.map_err(cuda_err)?;
let mut d_tangent_out = s
.alloc_zeros::<$F>((batch as usize) * ($tape.num_outputs as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: CudaContext::grid_dim(batch),
block_dim: CudaContext::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&$self.$tangent_fwd_kernel);
builder.arg(&$tape.opcodes);
builder.arg(&$tape.arg0);
builder.arg(&$tape.arg1);
builder.arg(&$tape.$constants);
builder.arg(&d_primals_in);
builder.arg(&d_seeds);
builder.arg(&mut d_primals);
builder.arg(&mut d_tangents);
builder.arg(&mut d_tangent_out);
builder.arg(&$tape.output_indices);
builder.arg(&$tape.num_ops);
builder.arg(&$tape.num_inputs);
builder.arg(&nv);
builder.arg(&$tape.num_outputs);
builder.arg(&batch);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let tangent_outs = s.clone_dtoh(&d_tangent_out).map_err(cuda_err)?;
$tape_cpu.forward($x);
let output_values = $tape_cpu.output_values();
let nnz = pattern.nnz();
let mut jac_values = vec![(0.0 as $F); nnz];
for (k, (&row, &col)) in pattern.rows.iter().zip(pattern.cols.iter()).enumerate() {
let c = colors[col as usize] as usize;
jac_values[k] = tangent_outs[c * no + row as usize];
}
Ok((output_values, pattern, jac_values))
}};
}
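// Shared body for sparse Hessians: greedy coloring compresses the columns,
// one HVP per color is evaluated in a single batched launch, and entry
// (row, col) is recovered from the HVP of col's color. The gradient is the
// same for every batch element (identical `x`), so the first `ni` entries of
// `grads` suffice.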
macro_rules! cuda_sparse_hessian_body {
($self:expr, $tape:expr, $tape_cpu:expr, $x:expr, $F:ty, $hvp_method:ident) => {{
let ni = $tape.num_inputs as usize;
let pattern = $tape_cpu.detect_sparsity();
let (colors, num_colors) = crate::sparse::greedy_coloring(&pattern);
if num_colors == 0 {
$tape_cpu.forward($x);
let val = $tape_cpu.output_value();
let grad = $tape_cpu.gradient($x);
return Ok((val, grad, pattern, vec![]));
}
let batch = num_colors as u32;
let mut tangent_dirs = Vec::with_capacity(batch as usize * ni);
for c in 0..num_colors {
for i in 0..ni {
tangent_dirs.push(if colors[i] == c as u32 {
(1.0 as $F)
} else {
(0.0 as $F)
});
}
}
let (grads, hvps) = $self.$hvp_method($tape, $x, &tangent_dirs, batch)?;
let gradient: Vec<$F> = grads[..ni].to_vec();
let nnz = pattern.nnz();
let mut hess_values = vec![(0.0 as $F); nnz];
for (k, (&row, &col)) in pattern.rows.iter().zip(pattern.cols.iter()).enumerate() {
let c = colors[col as usize] as usize;
hess_values[k] = hvps[c * ni + row as usize];
}
$tape_cpu.forward($x);
let value = $tape_cpu.output_value();
Ok((value, gradient, pattern, hess_values))
}};
}
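/// Maps any displayable cudarc error into `GpuError::Other`.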
fn cuda_err(e: impl std::fmt::Display) -> GpuError {
GpuError::Other(format!("{e}"))
}
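/// Device-resident copy of a compiled tape with `f32` constants.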
pub struct CudaTapeBuffers {
pub(crate) opcodes: CudaSlice<u32>,
pub(crate) arg0: CudaSlice<u32>,
pub(crate) arg1: CudaSlice<u32>,
pub(crate) constants_f32: CudaSlice<f32>,
pub(crate) output_indices: CudaSlice<u32>,
pub(crate) num_ops: u32,
pub(crate) num_inputs: u32,
pub(crate) num_variables: u32,
pub(crate) num_outputs: u32,
}
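/// Device-resident copy of a compiled tape with `f64` constants.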
pub struct CudaTapeBuffersF64 {
pub(crate) opcodes: CudaSlice<u32>,
pub(crate) arg0: CudaSlice<u32>,
pub(crate) arg1: CudaSlice<u32>,
pub(crate) constants_f64: CudaSlice<f64>,
pub(crate) output_indices: CudaSlice<u32>,
pub(crate) num_ops: u32,
pub(crate) num_inputs: u32,
pub(crate) num_variables: u32,
pub(crate) num_outputs: u32,
}
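/// CUDA execution backend: owns the driver context, its default stream, and
/// the tape-evaluation kernels compiled from `kernels/tape_eval.cu` in both
/// `f32` and `f64` variants (plus lazily compiled Taylor kernels under the
/// `stde` feature).
///
/// A minimal usage sketch. It is `ignore`d because it needs a CUDA device,
/// and how `tape_data: GpuTapeData` is produced is assumed rather than shown
/// in this module:
///
/// ```ignore
/// // Assumes this module is reachable as `crate::gpu`.
/// use crate::gpu::{CudaContext, GpuBackend};
///
/// let ctx = CudaContext::new().expect("no usable CUDA device");
/// let bufs = ctx.upload_tape(&tape_data);
/// // One f(x) evaluation per batch element; `inputs` is batch-major.
/// let outputs = ctx.forward_batch(&bufs, &inputs, batch_size)?;
/// ```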
pub struct CudaContext {
ctx: Arc<CudarContext>,
stream: Arc<CudaStream>,
forward_f32: CudaFunction,
reverse_f32: CudaFunction,
tangent_fwd_f32: CudaFunction,
tangent_rev_f32: CudaFunction,
forward_f64: CudaFunction,
reverse_f64: CudaFunction,
tangent_fwd_f64: CudaFunction,
tangent_rev_f64: CudaFunction,
#[cfg(feature = "stde")]
taylor_fwd_kth_f32: Mutex<[Option<CudaFunction>; 5]>,
#[cfg(feature = "stde")]
taylor_fwd_kth_f64: Mutex<[Option<CudaFunction>; 5]>,
}
impl CudaContext {
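/// Builds a backend on device 0 and compiles both kernel modules from
/// `KERNEL_SRC` (once per scalar type, via the `FLOAT_TYPE` define). Returns
/// `None` if no device is available or compilation fails. `fmad` is disabled
/// so the device does not contract multiply-adds, keeping results closer to
/// the CPU evaluator.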
pub fn new() -> Option<Self> {
let ctx = CudarContext::new(0).ok()?;
let stream = ctx.default_stream();
let nvrtc_opts = || CompileOptions {
fmad: Some(false),
..Default::default()
};
let src_f32 = format!("#define FLOAT_TYPE float\n{}", KERNEL_SRC);
let ptx_f32 = compile_ptx_with_opts(&src_f32, nvrtc_opts()).ok()?;
let module_f32 = ctx.load_module(ptx_f32).ok()?;
let forward_f32 = module_f32.load_function("forward_eval").ok()?;
let reverse_f32 = module_f32.load_function("reverse_sweep").ok()?;
let tangent_fwd_f32 = module_f32.load_function("tangent_forward").ok()?;
let tangent_rev_f32 = module_f32.load_function("tangent_reverse").ok()?;
let src_f64 = format!("#define FLOAT_TYPE double\n{}", KERNEL_SRC);
let ptx_f64 = compile_ptx_with_opts(&src_f64, nvrtc_opts()).ok()?;
let module_f64 = ctx.load_module(ptx_f64).ok()?;
let forward_f64 = module_f64.load_function("forward_eval").ok()?;
let reverse_f64 = module_f64.load_function("reverse_sweep").ok()?;
let tangent_fwd_f64 = module_f64.load_function("tangent_forward").ok()?;
let tangent_rev_f64 = module_f64.load_function("tangent_reverse").ok()?;
Some(CudaContext {
ctx,
stream,
forward_f32,
reverse_f32,
tangent_fwd_f32,
tangent_rev_f32,
forward_f64,
reverse_f64,
tangent_fwd_f64,
tangent_rev_f64,
#[cfg(feature = "stde")]
taylor_fwd_kth_f32: Mutex::new([None, None, None, None, None]),
#[cfg(feature = "stde")]
taylor_fwd_kth_f64: Mutex::new([None, None, None, None, None]),
})
}
/// One thread per batch element, rounded up to whole blocks.
fn grid_dim(batch_size: u32) -> (u32, u32, u32) {
(batch_size.div_ceil(BLOCK_SIZE), 1, 1)
}
fn block_dim() -> (u32, u32, u32) {
(BLOCK_SIZE, 1, 1)
}
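/// Uploads an `f64` tape to the device. Unlike the infallible
/// `GpuBackend::upload_tape`, this returns `Result`, and it rejects tapes
/// with custom ops since the kernels cannot evaluate them.
///
/// A hedged sketch of the `f64` pipeline (`ignore`d: it needs a device, and
/// the tape construction is assumed):
///
/// ```ignore
/// let bufs = ctx.upload_tape_f64(&tape)?;
/// let (values, grads) = ctx.gradient_batch_f64(&bufs, &inputs, batch_size)?;
/// ```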
pub fn upload_tape_f64(
&self,
tape: &crate::BytecodeTape<f64>,
) -> Result<CudaTapeBuffersF64, GpuError> {
if tape.has_custom_ops() {
return Err(GpuError::CustomOpsNotSupported);
}
let s = &self.stream;
let opcodes: Vec<u32> = tape.opcodes_slice().iter().map(|op| *op as u32).collect();
let args = tape.arg_indices_slice();
let arg0: Vec<u32> = args.iter().map(|a| a[0]).collect();
let arg1: Vec<u32> = args.iter().map(|a| a[1]).collect();
let constants: Vec<f64> = tape.values_slice().to_vec();
let output_indices_raw = tape.all_output_indices().to_vec();
let (output_indices, num_outputs) = if output_indices_raw.is_empty() {
(vec![tape.output_index() as u32], 1u32)
} else {
let n = output_indices_raw.len() as u32;
(output_indices_raw, n)
};
Ok(CudaTapeBuffersF64 {
opcodes: s.clone_htod(&opcodes).map_err(cuda_err)?,
arg0: s.clone_htod(&arg0).map_err(cuda_err)?,
arg1: s.clone_htod(&arg1).map_err(cuda_err)?,
constants_f64: s.clone_htod(&constants).map_err(cuda_err)?,
output_indices: s.clone_htod(&output_indices).map_err(cuda_err)?,
num_ops: tape.opcodes_slice().len() as u32,
num_inputs: tape.num_inputs() as u32,
num_variables: tape.num_variables_count() as u32,
num_outputs,
})
}
}
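// The `GpuBackend` trait is `f32`-only; the `f64` counterparts of these
// methods live on the inherent impl further below.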
impl GpuBackend for CudaContext {
type TapeBuffers = CudaTapeBuffers;
fn num_outputs(&self, tape: &CudaTapeBuffers) -> u32 {
tape.num_outputs
}
fn upload_tape(&self, data: &GpuTapeData) -> CudaTapeBuffers {
let s = &self.stream;
let (output_indices_src, num_outputs) = if data.output_indices.is_empty() {
(vec![data.output_index], 1u32)
} else {
(
data.output_indices.clone(),
data.output_indices.len() as u32,
)
};
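// The trait signature is infallible, so device allocation failures panic
// here; the `f64` path (`upload_tape_f64`) returns `Result` instead.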
CudaTapeBuffers {
opcodes: s.clone_htod(&data.opcodes).unwrap(),
arg0: s.clone_htod(&data.arg0).unwrap(),
arg1: s.clone_htod(&data.arg1).unwrap(),
constants_f32: s.clone_htod(&data.constants).unwrap(),
output_indices: s.clone_htod(&output_indices_src).unwrap(),
num_ops: data.num_ops,
num_inputs: data.num_inputs,
num_variables: data.num_variables,
num_outputs,
}
}
fn forward_batch(
&self,
tape: &CudaTapeBuffers,
inputs: &[f32],
batch_size: u32,
) -> Result<Vec<f32>, GpuError> {
cuda_forward_batch_body!(
self,
tape,
inputs,
batch_size,
f32,
constants_f32,
forward_f32
)
}
fn gradient_batch(
&self,
tape: &CudaTapeBuffers,
inputs: &[f32],
batch_size: u32,
) -> Result<(Vec<f32>, Vec<f32>), GpuError> {
cuda_gradient_batch_body!(
self,
tape,
inputs,
batch_size,
f32,
constants_f32,
forward_f32,
reverse_f32
)
}
fn sparse_jacobian(
&self,
tape: &CudaTapeBuffers,
tape_cpu: &mut crate::BytecodeTape<f32>,
x: &[f32],
) -> Result<(Vec<f32>, crate::sparse::JacobianSparsityPattern, Vec<f32>), GpuError> {
cuda_sparse_jacobian_body!(self, tape, tape_cpu, x, f32, constants_f32, tangent_fwd_f32)
}
fn hvp_batch(
&self,
tape: &CudaTapeBuffers,
x: &[f32],
tangent_dirs: &[f32],
batch_size: u32,
) -> Result<(Vec<f32>, Vec<f32>), GpuError> {
cuda_hvp_batch_body!(
self,
tape,
x,
tangent_dirs,
batch_size,
f32,
constants_f32,
tangent_rev_f32
)
}
fn sparse_hessian(
&self,
tape: &CudaTapeBuffers,
tape_cpu: &mut crate::BytecodeTape<f32>,
x: &[f32],
) -> Result<(f32, Vec<f32>, crate::sparse::SparsityPattern, Vec<f32>), GpuError> {
cuda_sparse_hessian_body!(self, tape, tape_cpu, x, f32, hvp_batch)
}
#[cfg(feature = "stde")]
fn taylor_forward_kth_batch(
&self,
tape: &CudaTapeBuffers,
primal_inputs: &[f32],
direction_seeds: &[f32],
batch_size: u32,
order: usize,
) -> Result<super::TaylorKthBatchResult<f32>, GpuError> {
// Delegates to the inherent method below; inherent methods take precedence
// in method resolution, so this is not a recursive trait call. The explicit
// `CudaContext::` path makes that intent unambiguous.
CudaContext::taylor_forward_kth_batch(self, tape, primal_inputs, direction_seeds, batch_size, order)
}
}
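// Inherent methods: a deprecated 2nd-order Taylor shim, k-th order Taylor
// propagation (feature `stde`), and the `f64` counterparts of the trait API.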
impl CudaContext {
#[cfg(feature = "stde")]
#[deprecated(
since = "0.5.0",
note = "import GpuBackend trait and call taylor_forward_2nd_batch() directly"
)]
pub fn taylor_forward_2nd_batch(
&self,
tape: &CudaTapeBuffers,
primal_inputs: &[f32],
direction_seeds: &[f32],
batch_size: u32,
) -> Result<super::TaylorBatchResult<f32>, GpuError> {
<Self as GpuBackend>::taylor_forward_2nd_batch(
self,
tape,
primal_inputs,
direction_seeds,
batch_size,
)
}
#[cfg(feature = "stde")]
pub fn taylor_forward_2nd_batch_f64(
&self,
tape: &CudaTapeBuffersF64,
primal_inputs: &[f64],
direction_seeds: &[f64],
batch_size: u32,
) -> Result<super::TaylorBatchResult<f64>, GpuError> {
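// `order` counts jet coefficients, so 3 yields the value plus the first- and
// second-order directional Taylor coefficients.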
let kth =
self.taylor_forward_kth_batch_f64(tape, primal_inputs, direction_seeds, batch_size, 3)?;
let mut coeffs = kth.coefficients.into_iter();
Ok(super::TaylorBatchResult {
values: coeffs.next().unwrap(),
c1s: coeffs.next().unwrap(),
c2s: coeffs.next().unwrap(),
})
}
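/// Returns the cached Taylor kernel for `order` (1..=5), compiling it with
/// NVRTC on first use. Double-checked locking: probe under the lock, compile
/// outside it, then re-check so a racing compilation is inserted only once.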
#[cfg(feature = "stde")]
fn get_taylor_kth_kernel(
&self,
order: usize,
f32_mode: bool,
) -> Result<CudaFunction, GpuError> {
let mutex = if f32_mode {
&self.taylor_fwd_kth_f32
} else {
&self.taylor_fwd_kth_f64
};
{
let kernels = mutex.lock().map_err(|e| GpuError::Other(e.to_string()))?;
if let Some(ref func) = kernels[order - 1] {
return Ok(func.clone());
}
}
let float_type = if f32_mode { "float" } else { "double" };
let src = format!(
"#define FLOAT_TYPE {float_type}\n{}",
super::taylor_codegen::generate_taylor_cuda(order)
);
let opts = CompileOptions {
fmad: Some(false),
..Default::default()
};
let ptx = compile_ptx_with_opts(&src, opts).map_err(cuda_err)?;
let module = self.ctx.load_module(ptx).map_err(cuda_err)?;
let func = module
.load_function("taylor_forward_kth")
.map_err(cuda_err)?;
let mut kernels = mutex.lock().map_err(|e| GpuError::Other(e.to_string()))?;
if kernels[order - 1].is_none() {
kernels[order - 1] = Some(func.clone());
}
Ok(kernels[order - 1].as_ref().unwrap().clone())
}
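/// Batched k-th order Taylor-mode forward propagation (`f32`). The device
/// returns one jet of `order` coefficients per (batch element, output); this
/// method transposes them so that `coefficients[c]` holds every order-`c`
/// coefficient, with coefficient 0 being the primal value (see
/// `taylor_forward_2nd_batch_f64` for the 2nd-order mapping).
///
/// A hedged usage sketch (`ignore`d: it needs a CUDA device, and the
/// batch-major layout of the flattened result is an assumption of the
/// example):
///
/// ```ignore
/// let res = ctx.taylor_forward_kth_batch(&bufs, &xs, &dirs, batch, 3)?;
/// let values = &res.coefficients[0]; // primal f(x), one per (batch, output)
/// let firsts = &res.coefficients[1]; // first-order coefficients along dirs
/// ```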
#[cfg(feature = "stde")]
pub fn taylor_forward_kth_batch(
&self,
tape: &CudaTapeBuffers,
primal_inputs: &[f32],
direction_seeds: &[f32],
batch_size: u32,
order: usize,
) -> Result<super::TaylorKthBatchResult<f32>, GpuError> {
if !(1..=5).contains(&order) {
return Err(GpuError::Other(format!(
"unsupported Taylor order {order}, must be 1..=5"
)));
}
let k = order as u32;
let s = &self.stream;
let ni = tape.num_inputs;
let nv = tape.num_variables;
let no = tape.num_outputs;
let total_in = (batch_size as usize) * (ni as usize);
assert_eq!(
primal_inputs.len(),
total_in,
"primal_inputs length mismatch"
);
assert_eq!(
direction_seeds.len(),
total_in,
"direction_seeds length mismatch"
);
let kernel = self.get_taylor_kth_kernel(order, true)?;
let d_primals = s.clone_htod(primal_inputs).map_err(cuda_err)?;
let d_seeds = s.clone_htod(direction_seeds).map_err(cuda_err)?;
let mut d_jets = s
.alloc_zeros::<f32>((batch_size as usize) * (nv as usize) * (k as usize))
.map_err(cuda_err)?;
let mut d_jet_out = s
.alloc_zeros::<f32>((batch_size as usize) * (no as usize) * (k as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: Self::grid_dim(batch_size),
block_dim: Self::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&kernel);
builder.arg(&tape.opcodes);
builder.arg(&tape.arg0);
builder.arg(&tape.arg1);
builder.arg(&tape.constants_f32);
builder.arg(&d_primals);
builder.arg(&d_seeds);
builder.arg(&mut d_jets);
builder.arg(&mut d_jet_out);
builder.arg(&tape.output_indices);
builder.arg(&tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&no);
builder.arg(&batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let raw = s.clone_dtoh(&d_jet_out).map_err(cuda_err)?;
let total_out = (batch_size as usize) * (no as usize);
let mut coefficients: Vec<Vec<f32>> =
(0..order).map(|_| Vec::with_capacity(total_out)).collect();
for i in 0..total_out {
for c in 0..order {
coefficients[c].push(raw[i * order + c]);
}
}
Ok(super::TaylorKthBatchResult {
coefficients,
order,
})
}
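/// `f64` twin of [`Self::taylor_forward_kth_batch`]; the coefficient layout
/// is identical.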
#[cfg(feature = "stde")]
pub fn taylor_forward_kth_batch_f64(
&self,
tape: &CudaTapeBuffersF64,
primal_inputs: &[f64],
direction_seeds: &[f64],
batch_size: u32,
order: usize,
) -> Result<super::TaylorKthBatchResult<f64>, GpuError> {
if !(1..=5).contains(&order) {
return Err(GpuError::Other(format!(
"unsupported Taylor order {order}, must be 1..=5"
)));
}
let k = order as u32;
let s = &self.stream;
let ni = tape.num_inputs;
let nv = tape.num_variables;
let no = tape.num_outputs;
let total_in = (batch_size as usize) * (ni as usize);
assert_eq!(
primal_inputs.len(),
total_in,
"primal_inputs length mismatch"
);
assert_eq!(
direction_seeds.len(),
total_in,
"direction_seeds length mismatch"
);
let kernel = self.get_taylor_kth_kernel(order, false)?;
let d_primals = s.clone_htod(primal_inputs).map_err(cuda_err)?;
let d_seeds = s.clone_htod(direction_seeds).map_err(cuda_err)?;
let mut d_jets = s
.alloc_zeros::<f64>((batch_size as usize) * (nv as usize) * (k as usize))
.map_err(cuda_err)?;
let mut d_jet_out = s
.alloc_zeros::<f64>((batch_size as usize) * (no as usize) * (k as usize))
.map_err(cuda_err)?;
let cfg = LaunchConfig {
grid_dim: Self::grid_dim(batch_size),
block_dim: Self::block_dim(),
shared_mem_bytes: 0,
};
let mut builder = s.launch_builder(&kernel);
builder.arg(&tape.opcodes);
builder.arg(&tape.arg0);
builder.arg(&tape.arg1);
builder.arg(&tape.constants_f64);
builder.arg(&d_primals);
builder.arg(&d_seeds);
builder.arg(&mut d_jets);
builder.arg(&mut d_jet_out);
builder.arg(&tape.output_indices);
builder.arg(&tape.num_ops);
builder.arg(&ni);
builder.arg(&nv);
builder.arg(&no);
builder.arg(&batch_size);
unsafe { builder.launch(cfg) }.map_err(cuda_err)?;
s.synchronize().map_err(cuda_err)?;
let raw = s.clone_dtoh(&d_jet_out).map_err(cuda_err)?;
let total_out = (batch_size as usize) * (no as usize);
let mut coefficients: Vec<Vec<f64>> =
(0..order).map(|_| Vec::with_capacity(total_out)).collect();
for i in 0..total_out {
for c in 0..order {
coefficients[c].push(raw[i * order + c]);
}
}
Ok(super::TaylorKthBatchResult {
coefficients,
order,
})
}
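/// `f64` counterpart of `GpuBackend::forward_batch`.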
pub fn forward_batch_f64(
&self,
tape: &CudaTapeBuffersF64,
inputs: &[f64],
batch_size: u32,
) -> Result<Vec<f64>, GpuError> {
cuda_forward_batch_body!(
self,
tape,
inputs,
batch_size,
f64,
constants_f64,
forward_f64
)
}
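/// `f64` counterpart of `GpuBackend::gradient_batch`.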
pub fn gradient_batch_f64(
&self,
tape: &CudaTapeBuffersF64,
inputs: &[f64],
batch_size: u32,
) -> Result<(Vec<f64>, Vec<f64>), GpuError> {
cuda_gradient_batch_body!(
self,
tape,
inputs,
batch_size,
f64,
constants_f64,
forward_f64,
reverse_f64
)
}
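/// `f64` counterpart of `GpuBackend::sparse_jacobian`.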
pub fn sparse_jacobian_f64(
&self,
tape: &CudaTapeBuffersF64,
tape_cpu: &mut crate::BytecodeTape<f64>,
x: &[f64],
) -> Result<(Vec<f64>, crate::sparse::JacobianSparsityPattern, Vec<f64>), GpuError> {
cuda_sparse_jacobian_body!(self, tape, tape_cpu, x, f64, constants_f64, tangent_fwd_f64)
}
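/// `f64` counterpart of `GpuBackend::sparse_hessian`.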
pub fn sparse_hessian_f64(
&self,
tape: &CudaTapeBuffersF64,
tape_cpu: &mut crate::BytecodeTape<f64>,
x: &[f64],
) -> Result<(f64, Vec<f64>, crate::sparse::SparsityPattern, Vec<f64>), GpuError> {
cuda_sparse_hessian_body!(self, tape, tape_cpu, x, f64, hvp_batch_f64)
}
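/// `f64` counterpart of `GpuBackend::hvp_batch`.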
pub fn hvp_batch_f64(
&self,
tape: &CudaTapeBuffersF64,
x: &[f64],
tangent_dirs: &[f64],
batch_size: u32,
) -> Result<(Vec<f64>, Vec<f64>), GpuError> {
cuda_hvp_batch_body!(
self,
tape,
x,
tangent_dirs,
batch_size,
f64,
constants_f64,
tangent_rev_f64
)
}
}