use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::types::{TensorDesc, TensorDescMut};
use super::super::descriptor::ConvProblem;
use crate::conv::fprop::winograd::WinogradTileSize;
pub struct WinogradDgrad {
problem: ConvProblem,
tile_size: WinogradTileSize,
sm_version: SmVersion,
}
impl WinogradDgrad {
pub fn new(problem: ConvProblem, sm_version: SmVersion) -> DnnResult<Self> {
Self::validate_problem(&problem)?;
let out_h = problem.output_h()?;
let out_w = problem.output_w()?;
let tile_size = WinogradTileSize::auto_select(out_h, out_w);
Ok(Self {
problem,
tile_size,
sm_version,
})
}
pub fn with_tile_size(
problem: ConvProblem,
tile_size: WinogradTileSize,
sm_version: SmVersion,
) -> DnnResult<Self> {
Self::validate_problem(&problem)?;
Ok(Self {
problem,
tile_size,
sm_version,
})
}
fn validate_problem(problem: &ConvProblem) -> DnnResult<()> {
let r = problem.filter_dims.first().copied().unwrap_or(0);
let s = problem.filter_dims.get(1).copied().unwrap_or(0);
if r != 3 || s != 3 {
return Err(DnnError::InvalidArgument(format!(
"Winograd dgrad requires 3x3 filter, got {r}x{s}"
)));
}
let stride_h = problem.stride.first().copied().unwrap_or(1);
let stride_w = problem.stride.get(1).copied().unwrap_or(1);
if stride_h != 1 || stride_w != 1 {
return Err(DnnError::InvalidArgument(format!(
"Winograd dgrad requires unit stride, got {stride_h}x{stride_w}"
)));
}
let dil_h = problem.dilation.first().copied().unwrap_or(1);
let dil_w = problem.dilation.get(1).copied().unwrap_or(1);
if dil_h != 1 || dil_w != 1 {
return Err(DnnError::InvalidArgument(format!(
"Winograd dgrad requires unit dilation, got {dil_h}x{dil_w}"
)));
}
Ok(())
}
#[must_use]
pub fn tile_size(&self) -> WinogradTileSize {
self.tile_size
}
pub fn workspace_bytes(&self) -> DnnResult<usize> {
let in_h = self.problem.in_dims.first().copied().unwrap_or(0);
let in_w = self.problem.in_dims.get(1).copied().unwrap_or(0);
let ot = self.tile_size.output_tile();
let alpha2 = self.tile_size.transform_elements() as u64;
let tiles_h = in_h.div_ceil(ot);
let tiles_w = in_w.div_ceil(ot);
let num_tiles = tiles_h as u64 * tiles_w as u64 * self.problem.batch as u64;
let c = self.problem.in_channels as u64;
let k = self.problem.out_channels as u64;
let elem_size = self.problem.input_type.size_bytes() as u64;
let grad_output_buf = alpha2 * k * num_tiles * elem_size;
let filter_buf = alpha2 * c * k * elem_size;
let grad_input_buf = alpha2 * c * num_tiles * elem_size;
Ok((grad_output_buf + filter_buf + grad_input_buf) as usize)
}
pub fn generate_grad_output_transform_ptx(&self) -> DnnResult<String> {
let tile = self.tile_size.transform_tile();
let output_tile = self.tile_size.output_tile();
let name = format!("winograd_dgrad_output_transform_f{output_tile}x3");
let ptx = KernelBuilder::new(&name)
.target(self.sm_version)
.param("grad_output", PtxType::U64)
.param("transformed", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("out_channels", PtxType::U32)
.param("out_h", PtxType::U32)
.param("out_w", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("pad_h", PtxType::U32)
.param("pad_w", PtxType::U32)
.param("num_tiles", PtxType::U32)
.body(move |b| {
b.comment(&format!(
"=== Winograd F({output_tile},3) Dgrad Output Transform ===",
));
b.comment(&format!(
"Transform tile: {tile}x{tile}, applying B^T * grad_output_tile * B"
));
let gid = b.global_thread_id_x();
let total = b.load_param_u32("num_tiles");
b.if_lt_u32(gid, total, |b| {
b.comment("1. Map tile index to (batch, channel, tile_h, tile_w)");
b.comment("2. Load grad_output tile from output spatial domain");
b.comment(" (with padding/boundary checks for tiles at edges)");
b.comment("3. Apply B^T * tile (left multiply by input transform)");
b.comment("4. Apply result * B (right multiply by input transform)");
b.comment("5. Store transformed tile to workspace");
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_grad_input_transform_ptx(&self) -> DnnResult<String> {
let output_tile = self.tile_size.output_tile();
let name = format!("winograd_dgrad_input_transform_f{output_tile}x3");
let ptx = KernelBuilder::new(&name)
.target(self.sm_version)
.param("transformed", PtxType::U64)
.param("grad_input", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("in_channels", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("num_tiles", PtxType::U32)
.body(move |b| {
b.comment(&format!(
"=== Winograd F({output_tile},3) Dgrad Input Transform ===",
));
b.comment("Apply A^T * tile * A to recover spatial grad_input");
let gid = b.global_thread_id_x();
let total = b.load_param_u32("num_tiles");
b.if_lt_u32(gid, total, |b| {
b.comment("1. Load Winograd-domain result tile from workspace");
b.comment("2. Apply A^T * tile (left multiply by output transform)");
b.comment("3. Apply result * A (right multiply by output transform)");
b.comment("4. Store grad_input tile (boundary-clamped for edge tiles)");
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn execute<T: GpuFloat>(
&self,
handle: &DnnHandle,
grad_output: &TensorDesc<T>,
filter: &TensorDesc<T>,
grad_input: &mut TensorDescMut<T>,
workspace: &mut oxicuda_memory::DeviceBuffer<u8>,
) -> DnnResult<()> {
let required = self.workspace_bytes()?;
if workspace.len() < required {
return Err(DnnError::WorkspaceRequired(required));
}
self.launch_grad_output_transform(handle, grad_output, workspace)?;
self.launch_winograd_gemm_transposed(handle, filter, workspace)?;
self.launch_grad_input_transform(handle, grad_input, workspace)?;
Ok(())
}
fn launch_grad_output_transform<T: GpuFloat>(
&self,
handle: &DnnHandle,
grad_output: &TensorDesc<T>,
workspace: &mut oxicuda_memory::DeviceBuffer<u8>,
) -> DnnResult<()> {
let ptx = self.generate_grad_output_transform_ptx()?;
let name = format!(
"winograd_dgrad_output_transform_f{}x3",
self.tile_size.output_tile()
);
let module = Arc::new(Module::from_ptx(&ptx)?);
let kernel = Kernel::from_module(module, &name)?;
let out_h = self.problem.output_h()?;
let out_w = self.problem.output_w()?;
let in_h = self.problem.in_dims.first().copied().unwrap_or(1);
let in_w = self.problem.in_dims.get(1).copied().unwrap_or(1);
let ot = self.tile_size.output_tile();
let tiles_h = in_h.div_ceil(ot);
let tiles_w = in_w.div_ceil(ot);
let num_tiles = tiles_h * tiles_w * self.problem.batch * self.problem.out_channels;
let block = 256u32;
let grid = grid_size_for(num_tiles, block);
let params = LaunchParams::new(grid, block);
let args = (
grad_output.ptr,
workspace.as_device_ptr(),
self.problem.batch,
self.problem.out_channels,
out_h,
out_w,
in_h,
in_w,
self.problem.padding.first().copied().unwrap_or(0),
self.problem.padding.get(1).copied().unwrap_or(0),
num_tiles,
);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| DnnError::LaunchFailed(e.to_string()))?;
Ok(())
}
fn launch_winograd_gemm_transposed<T: GpuFloat>(
&self,
handle: &DnnHandle,
_filter: &TensorDesc<T>,
_workspace: &mut oxicuda_memory::DeviceBuffer<u8>,
) -> DnnResult<()> {
let _ = handle;
Ok(())
}
fn launch_grad_input_transform<T: GpuFloat>(
&self,
handle: &DnnHandle,
grad_input: &mut TensorDescMut<T>,
workspace: &mut oxicuda_memory::DeviceBuffer<u8>,
) -> DnnResult<()> {
let ptx = self.generate_grad_input_transform_ptx()?;
let name = format!(
"winograd_dgrad_input_transform_f{}x3",
self.tile_size.output_tile()
);
let module = Arc::new(Module::from_ptx(&ptx)?);
let kernel = Kernel::from_module(module, &name)?;
let in_h = self.problem.in_dims.first().copied().unwrap_or(1);
let in_w = self.problem.in_dims.get(1).copied().unwrap_or(1);
let ot = self.tile_size.output_tile();
let tiles_h = in_h.div_ceil(ot);
let tiles_w = in_w.div_ceil(ot);
let num_tiles = tiles_h * tiles_w * self.problem.batch * self.problem.in_channels;
let block = 256u32;
let grid = grid_size_for(num_tiles, block);
let params = LaunchParams::new(grid, block);
let args = (
workspace.as_device_ptr(),
grad_input.ptr,
self.problem.batch,
self.problem.in_channels,
in_h,
in_w,
num_tiles,
);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| DnnError::LaunchFailed(e.to_string()))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::conv::fprop::winograd::{AT_F2X3, AT_F4X3, BT_F2X3, BT_F4X3, G_F2X3, G_F4X3};
use crate::types::TensorLayout;
fn make_3x3_problem() -> ConvProblem {
ConvProblem {
batch: 1,
in_channels: 64,
in_dims: vec![32, 32],
out_channels: 128,
filter_dims: vec![3, 3],
padding: vec![1, 1],
stride: vec![1, 1],
dilation: vec![1, 1],
groups: 1,
input_type: PtxType::F32,
output_type: PtxType::F32,
layout: TensorLayout::Nchw,
}
}
fn make_small_problem() -> ConvProblem {
ConvProblem {
batch: 2,
in_channels: 16,
in_dims: vec![4, 4],
out_channels: 32,
filter_dims: vec![3, 3],
padding: vec![1, 1],
stride: vec![1, 1],
dilation: vec![1, 1],
groups: 1,
input_type: PtxType::F32,
output_type: PtxType::F32,
layout: TensorLayout::Nchw,
}
}
#[test]
fn dgrad_winograd_rejects_non_3x3() {
let mut p = make_3x3_problem();
p.filter_dims = vec![5, 5];
assert!(WinogradDgrad::new(p, SmVersion::Sm80).is_err());
}
#[test]
fn dgrad_winograd_rejects_strided() {
let mut p = make_3x3_problem();
p.stride = vec![2, 2];
assert!(WinogradDgrad::new(p, SmVersion::Sm80).is_err());
}
#[test]
fn dgrad_winograd_rejects_dilated() {
let mut p = make_3x3_problem();
p.dilation = vec![2, 2];
assert!(WinogradDgrad::new(p, SmVersion::Sm80).is_err());
}
#[test]
fn dgrad_winograd_creates_ok() {
let result = WinogradDgrad::new(make_3x3_problem(), SmVersion::Sm80);
assert!(result.is_ok());
if let Ok(dgrad) = result {
assert_eq!(dgrad.tile_size(), WinogradTileSize::F4x3);
}
}
#[test]
fn dgrad_winograd_small_selects_f2x3() {
let result = WinogradDgrad::new(make_small_problem(), SmVersion::Sm80);
assert!(result.is_ok());
if let Ok(dgrad) = result {
assert_eq!(dgrad.tile_size(), WinogradTileSize::F2x3);
}
}
#[test]
fn dgrad_winograd_with_tile_size() {
let result = WinogradDgrad::with_tile_size(
make_3x3_problem(),
WinogradTileSize::F2x3,
SmVersion::Sm80,
);
assert!(result.is_ok());
if let Ok(dgrad) = result {
assert_eq!(dgrad.tile_size(), WinogradTileSize::F2x3);
}
}
#[test]
fn dgrad_workspace_bytes_positive() {
let dgrad = WinogradDgrad::new(make_3x3_problem(), SmVersion::Sm80);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let bytes = d.workspace_bytes();
assert!(bytes.is_ok());
assert!(bytes.unwrap_or(0) > 0);
}
}
#[test]
fn dgrad_workspace_f2x3_calculation() {
let dgrad = WinogradDgrad::with_tile_size(
make_small_problem(),
WinogradTileSize::F2x3,
SmVersion::Sm80,
);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let bytes = d.workspace_bytes().unwrap_or(0);
assert_eq!(bytes, 57344);
}
}
#[test]
fn dgrad_grad_output_transform_ptx() {
let dgrad = WinogradDgrad::new(make_3x3_problem(), SmVersion::Sm80);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let ptx = d.generate_grad_output_transform_ptx();
assert!(ptx.is_ok());
if let Ok(text) = ptx {
assert!(text.contains("winograd_dgrad_output_transform"));
}
}
}
#[test]
fn dgrad_grad_input_transform_ptx() {
let dgrad = WinogradDgrad::new(make_3x3_problem(), SmVersion::Sm80);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let ptx = d.generate_grad_input_transform_ptx();
assert!(ptx.is_ok());
if let Ok(text) = ptx {
assert!(text.contains("winograd_dgrad_input_transform"));
}
}
}
#[test]
fn dgrad_f16_workspace() {
let mut p = make_3x3_problem();
p.input_type = PtxType::F16;
p.output_type = PtxType::F16;
let dgrad = WinogradDgrad::new(p, SmVersion::Sm80);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let bytes_f16 = d.workspace_bytes().unwrap_or(0);
let dgrad_f32 = WinogradDgrad::new(make_3x3_problem(), SmVersion::Sm80);
if let Ok(d32) = dgrad_f32 {
let bytes_f32 = d32.workspace_bytes().unwrap_or(0);
assert_eq!(bytes_f16 * 2, bytes_f32);
}
}
}
#[test]
fn forward_matrices_accessible() {
assert_eq!(BT_F2X3.len(), 4);
assert_eq!(AT_F2X3.len(), 2);
assert_eq!(G_F2X3.len(), 4);
assert_eq!(BT_F4X3.len(), 6);
assert_eq!(AT_F4X3.len(), 4);
assert_eq!(G_F4X3.len(), 6);
}
#[test]
fn dgrad_batch_4_workspace() {
let mut p = make_3x3_problem();
p.batch = 4;
let dgrad = WinogradDgrad::new(p, SmVersion::Sm80);
assert!(dgrad.is_ok());
if let Ok(d) = dgrad {
let bytes = d.workspace_bytes().unwrap_or(0);
assert!(bytes > 0);
}
}
}