use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_ptx::prelude::*;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::ptx_helpers::*;
use crate::tensor_util::{nchw_dims, nchw_dims_mut};
use crate::types::{TensorDesc, TensorDescMut, pool_output_size};
/// Threads per block for the average-pooling launch. The same value is passed
/// to `KernelBuilder::max_threads_per_block` when the PTX is generated, so the
/// launch configuration and the kernel's declared limit stay in sync.
const AVG_POOL_BLOCK: u32 = 256;

/// 2-D average pooling over NCHW tensors, enqueued on `handle`'s stream.
///
/// Each output element `(n, c, oh, ow)` is the mean of the `kernel_size`
/// window whose first input row is `oh * stride.0 - padding.0` and whose first
/// input column is `ow * stride.1 - padding.1`. Window positions outside the
/// input add nothing to the sum.
///
/// # Arguments
/// * `handle` - provides the SM version the PTX is generated for and the launch stream.
/// * `input` - 4-D NCHW input tensor.
/// * `output` - 4-D NCHW output tensor. Batch and channel counts must match
///   `input`, and each spatial extent must equal `pool_output_size` applied to
///   the corresponding input extent.
/// * `kernel_size`, `stride`, `padding` - `(height, width)` pairs.
/// * `count_include_pad` - if `true`, the divisor counts every window position,
///   including padding. If `false`, only positions inside the input are counted.
///
/// # Errors
/// * `DnnError::InvalidDimension` in any of these cases:
///   - a kernel or stride extent is zero;
///   - the output shape does not match;
///   - either tensor has more elements than the kernel's 32-bit indexing can address.
/// * `DnnError::LaunchFailed` if the kernel launch fails.
///
/// Errors from shape extraction, PTX generation, module loading and kernel
/// lookup are propagated with `?`.
pub fn avg_pool2d<T: GpuFloat>(
    handle: &DnnHandle,
    input: &TensorDesc<T>,
    output: &mut TensorDescMut<T>,
    kernel_size: (u32, u32),
    stride: (u32, u32),
    padding: (u32, u32),
    count_include_pad: bool,
) -> DnnResult<()> {
    let (in_n, in_c, in_h, in_w) = nchw_dims(input)?;
    let (out_n, out_c, out_h, out_w) = nchw_dims_mut(output)?;

    // A zero-sized window makes the divisor zero (0/0 = NaN for every output),
    // and a zero stride has no meaningful output size.
    if kernel_size.0 == 0 || kernel_size.1 == 0 || stride.0 == 0 || stride.1 == 0 {
        return Err(DnnError::InvalidDimension(format!(
            "avg_pool2d: kernel size and stride must be non-zero, got kernel=({}, {}), stride=({}, {})",
            kernel_size.0, kernel_size.1, stride.0, stride.1
        )));
    }

    let expected_oh =
        pool_output_size(in_h, kernel_size.0, stride.0, padding.0).ok_or_else(|| {
            DnnError::InvalidDimension(format!(
                "avg_pool2d: invalid output height for h={in_h}, kh={}, sh={}, ph={}",
                kernel_size.0, stride.0, padding.0
            ))
        })?;
    let expected_ow =
        pool_output_size(in_w, kernel_size.1, stride.1, padding.1).ok_or_else(|| {
            DnnError::InvalidDimension(format!(
                "avg_pool2d: invalid output width for w={in_w}, kw={}, sw={}, pw={}",
                kernel_size.1, stride.1, padding.1
            ))
        })?;
    if out_n != in_n || out_c != in_c || out_h != expected_oh || out_w != expected_ow {
        return Err(DnnError::InvalidDimension(format!(
            "avg_pool2d: output ({out_n},{out_c},{out_h},{out_w}) != expected ({in_n},{in_c},{expected_oh},{expected_ow})"
        )));
    }

    // The generated kernel computes element offsets with 32-bit unsigned
    // arithmetic. Larger tensors would wrap and touch the wrong elements, so
    // reject them instead of truncating.
    let in_numel = [in_n, in_c, in_h, in_w]
        .iter()
        .try_fold(1u64, |acc, &d| acc.checked_mul(d as u64));
    if !matches!(in_numel, Some(n) if n <= u64::from(u32::MAX)) {
        return Err(DnnError::InvalidDimension(format!(
            "avg_pool2d: input ({in_n},{in_c},{in_h},{in_w}) exceeds the kernel's 32-bit index range"
        )));
    }
    let out_numel = output.numel() as u64;
    if out_numel > u64::from(u32::MAX) {
        return Err(DnnError::InvalidDimension(format!(
            "avg_pool2d: output has {out_numel} elements, exceeding the kernel's 32-bit index range"
        )));
    }
    // Lossless: bounded by u32::MAX just above.
    let total_output = out_numel as u32;
    if total_output == 0 {
        return Ok(());
    }

    let ptx = generate_avg_pool2d_ptx::<T>(handle.sm_version(), count_include_pad)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let name = avg_pool2d_kernel_name::<T>(count_include_pad);
    let kernel = Kernel::from_module(module, &name)?;

    // One thread per output element.
    let grid = grid_size_for(total_output, AVG_POOL_BLOCK);
    let params = LaunchParams::new(grid, AVG_POOL_BLOCK);
    // Order must match the `.param` declarations in `generate_avg_pool2d_ptx`.
    let args = (
        input.ptr,
        output.ptr,
        in_n,
        in_c,
        in_h,
        in_w,
        out_h,
        out_w,
        kernel_size.0,
        kernel_size.1,
        stride.0,
        stride.1,
        padding.0,
        padding.1,
        total_output,
    );
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| DnnError::LaunchFailed(format!("avg_pool2d: {e}")))?;
    Ok(())
}
/// Returns the PTX entry-point name for the average-pooling kernel, e.g.
/// `dnn_avg_pool2d_cip_f32`.
///
/// The padding-count mode is part of the name, so the `cip` and `nocip`
/// specialisations of the same element type never collide.
fn avg_pool2d_kernel_name<T: GpuFloat>(count_include_pad: bool) -> String {
    const PREFIX: &str = "dnn_avg_pool2d";
    let pad_tag = match count_include_pad {
        true => "cip",
        false => "nocip",
    };
    format!("{PREFIX}_{pad_tag}_{}", T::NAME)
}
/// Generates PTX for the average-pooling kernel, specialised for element type
/// `T` and the given `count_include_pad` mode.
///
/// Each thread handles one output element; threads whose global id `gid` is
/// `>= total` do nothing. The kernel works as follows:
/// * `gid` is decomposed as a contiguous NCHW output index (`ow` fastest, then
///   `oh`, `c`, `n`).
/// * The input window is scanned with signed coordinates, because padding can
///   make its origin negative.
/// * In-bounds elements of the contiguous NCHW input are summed.
/// * The mean is stored at element `gid` of the output.
///
/// All element offsets are 32-bit unsigned.
///
/// # Errors
/// Returns `DnnError::PtxGeneration` if the kernel builder fails.
fn generate_avg_pool2d_ptx<T: GpuFloat>(
    sm: SmVersion,
    count_include_pad: bool,
) -> DnnResult<String> {
    let name = avg_pool2d_kernel_name::<T>(count_include_pad);
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(AVG_POOL_BLOCK)
        .param("in_ptr", PtxType::U64)
        .param("out_ptr", PtxType::U64)
        .param("batch", PtxType::U32)
        .param("channels", PtxType::U32)
        .param("in_h", PtxType::U32)
        .param("in_w", PtxType::U32)
        .param("out_h", PtxType::U32)
        .param("out_w", PtxType::U32)
        .param("kh", PtxType::U32)
        .param("kw", PtxType::U32)
        .param("sh", PtxType::U32)
        .param("sw", PtxType::U32)
        .param("ph", PtxType::U32)
        .param("pw", PtxType::U32)
        .param("total", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let total = b.load_param_u32("total");
            b.if_lt_u32(gid.clone(), total, move |b| {
                let out_w = b.load_param_u32("out_w");
                let out_h = b.load_param_u32("out_h");
                let channels = b.load_param_u32("channels");
                let in_h = b.load_param_u32("in_h");
                let in_w = b.load_param_u32("in_w");

                // gid = ((n * C + c) * out_h + oh) * out_w + ow
                let ow_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("rem.u32 {ow_idx}, {gid}, {out_w};"));
                let tmp1 = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("div.u32 {tmp1}, {gid}, {out_w};"));
                let oh_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("rem.u32 {oh_idx}, {tmp1}, {out_h};"));
                let tmp2 = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("div.u32 {tmp2}, {tmp1}, {out_h};"));
                let c_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("rem.u32 {c_idx}, {tmp2}, {channels};"));
                let n_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("div.u32 {n_idx}, {tmp2}, {channels};"));

                let sh = b.load_param_u32("sh");
                let sw = b.load_param_u32("sw");
                let ph = b.load_param_u32("ph");
                let pw = b.load_param_u32("pw");
                let kh = b.load_param_u32("kh");
                let kw = b.load_param_u32("kw");

                // Window origin in input coordinates. It may be negative when
                // padded, so it is kept in signed registers.
                let h_start_raw = b.mul_lo_u32(oh_idx, sh);
                let w_start_raw = b.mul_lo_u32(ow_idx, sw);
                let h_start = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("sub.s32 {h_start}, {h_start_raw}, {ph};"));
                let w_start = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("sub.s32 {w_start}, {w_start_raw}, {pw};"));

                // Element offset of the (n, c) input plane: n*C*H*W + c*H*W.
                let in_hw_val = b.mul_lo_u32(in_h.clone(), in_w.clone());
                let c_hw = b.mul_lo_u32(c_idx, in_hw_val.clone());
                let chw = b.mul_lo_u32(channels, in_hw_val);
                let n_offset = b.mul_lo_u32(n_idx, chw);
                let base_offset = b.add_u32(n_offset, c_hw);
                let in_ptr = b.load_param_u64("in_ptr");

                let sum = load_float_imm::<T>(b, 0.0);
                let count = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mov.u32 {count}, 0;"));

                // Exclusive window end and signed copies of the input extents
                // for the bounds tests.
                let h_end = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("add.s32 {h_end}, {h_start}, {kh};"));
                let w_end = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("add.s32 {w_end}, {w_start}, {kw};"));
                let in_h_s32 = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("mov.s32 {in_h_s32}, {in_h};"));
                let in_w_s32 = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("mov.s32 {in_w_s32}, {in_w};"));

                // for ih in h_start..h_end
                let ih = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("mov.s32 {ih}, {h_start};"));
                let loop_h = b.fresh_label("avg_h");
                let end_h = b.fresh_label("avg_h_end");
                b.label(&loop_h);
                let ph_cmp = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.ge.s32 {ph_cmp}, {ih}, {h_end};"));
                b.branch_if(ph_cmp, &end_h);

                // for iw in w_start..w_end
                let iw = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("mov.s32 {iw}, {w_start};"));
                let loop_w = b.fresh_label("avg_w");
                let end_w = b.fresh_label("avg_w_end");
                b.label(&loop_w);
                let pw_cmp = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.ge.s32 {pw_cmp}, {iw}, {w_end};"));
                b.branch_if(pw_cmp, &end_w);

                // count_include_pad: every window position counts, in-bounds or not.
                if count_include_pad {
                    b.raw_ptx(&format!("add.u32 {count}, {count}, 1;"));
                }

                // In bounds iff 0 <= ih < in_h && 0 <= iw < in_w.
                // The lower bound is a plain compare, and the upper-bound compare
                // ANDs it in. (PTX has no `{true}` predicate operand, so the lower
                // bound must not use the `.and` form.)
                let h_ok = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.ge.s32 {h_ok}, {ih}, 0;"));
                let h_ok2 = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!(
                    "setp.lt.and.s32 {h_ok2}, {ih}, {in_h_s32}, {h_ok};"
                ));
                let w_ok = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.ge.s32 {w_ok}, {iw}, 0;"));
                let w_ok2 = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!(
                    "setp.lt.and.s32 {w_ok2}, {iw}, {in_w_s32}, {w_ok};"
                ));
                let hw_ok = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("and.pred {hw_ok}, {h_ok2}, {w_ok2};"));
                let skip = b.fresh_label("avg_skip");
                let inv = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("not.pred {inv}, {hw_ok};"));
                b.branch_if(inv, &skip);

                // Coordinates are known non-negative here, so reinterpret as u32.
                let ih_u32 = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mov.b32 {ih_u32}, {ih};"));
                let iw_u32 = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mov.b32 {iw_u32}, {iw};"));
                let row_off = b.mul_lo_u32(ih_u32, in_w.clone());
                let hw_off = b.add_u32(row_off, iw_u32);
                let elem_idx = b.add_u32(base_offset.clone(), hw_off);
                let addr = b.byte_offset_addr(in_ptr.clone(), elem_idx, T::size_u32());
                let val = load_global_float::<T>(b, addr);
                let new_sum = add_float::<T>(b, sum.clone(), val);
                b.raw_ptx(&format!(
                    "mov.{ptx} {sum}, {new_sum};",
                    ptx = crate::ptx_helpers::ptx_type_name::<T>()
                ));
                // Without count_include_pad only in-bounds positions count.
                if !count_include_pad {
                    b.raw_ptx(&format!("add.u32 {count}, {count}, 1;"));
                }
                b.label(&skip);
                b.raw_ptx(&format!("add.s32 {iw}, {iw}, 1;"));
                b.branch(&loop_w);
                b.label(&end_w);
                b.raw_ptx(&format!("add.s32 {ih}, {ih}, 1;"));
                b.branch(&loop_h);
                b.label(&end_h);

                // If count_include_pad is false and the window lies entirely in
                // padding, count is 0 (and so is sum). Clamp the divisor so such
                // windows produce 0 rather than 0/0 = NaN.
                b.raw_ptx(&format!("max.u32 {count}, {count}, 1;"));
                let count_f = cvt_u32_to_float::<T>(b, count);
                let result = div_float::<T>(b, sum, count_f);
                let out_ptr = b.load_param_u64("out_ptr");
                let out_addr = b.byte_offset_addr(out_ptr, gid, T::size_u32());
                store_global_float::<T>(b, out_addr, result);
            });
            b.ret();
        })
        .build()
        .map_err(|e| DnnError::PtxGeneration(format!("avg_pool2d: {e}")))?;
    Ok(ptx)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates PTX for `T` and checks that it declares the exact entry point
    /// `avg_pool2d` later looks up with `avg_pool2d_kernel_name`.
    fn assert_ptx_has_entry<T: GpuFloat>(count_include_pad: bool) {
        let ptx = generate_avg_pool2d_ptx::<T>(SmVersion::Sm80, count_include_pad)
            .expect("PTX generation should succeed");
        let name = avg_pool2d_kernel_name::<T>(count_include_pad);
        assert!(
            ptx.contains(name.as_str()),
            "generated PTX does not declare entry `{name}`"
        );
    }

    #[test]
    fn avg_pool2d_ptx_f32_cip() {
        let ptx = generate_avg_pool2d_ptx::<f32>(SmVersion::Sm80, true);
        assert!(ptx.is_ok());
        let s = ptx.expect("should gen");
        assert!(s.contains("dnn_avg_pool2d_cip_f32"));
    }

    #[test]
    fn avg_pool2d_ptx_f32_nocip() {
        assert_ptx_has_entry::<f32>(false);
    }

    #[test]
    fn avg_pool2d_ptx_f64() {
        assert_ptx_has_entry::<f64>(true);
    }

    #[test]
    fn avg_pool2d_ptx_f64_nocip() {
        assert_ptx_has_entry::<f64>(false);
    }

    #[test]
    fn kernel_names_distinguish_pad_modes() {
        assert_ne!(
            avg_pool2d_kernel_name::<f32>(true),
            avg_pool2d_kernel_name::<f32>(false)
        );
    }
}