use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ConvKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
/// Problem description for a 2-D im2col operation.
///
/// `Im2ColPlan::select` validates every field before a plan is built; the
/// per-field constraints listed below are the ones enforced there.
#[derive(Copy, Clone, Debug)]
pub struct Im2ColDescriptor {
/// Batch extent (first input dimension); must be > 0.
pub batch: i32,
/// Channel extent (second input dimension); must be > 0.
pub channels: i32,
/// Input height (third input dimension); must be > 0.
pub h_in: i32,
/// Input width (fourth input dimension); must be > 0.
pub w_in: i32,
/// Kernel height; must be > 0.
pub kh: i32,
/// Kernel width; must be > 0.
pub kw: i32,
/// Step between consecutive windows along the height axis; must be > 0.
pub stride_h: i32,
/// Step between consecutive windows along the width axis; must be > 0.
pub stride_w: i32,
/// Height padding; must be >= 0. Counted twice (`2 * pad_h`) in the output-size formula.
pub pad_h: i32,
/// Width padding; must be >= 0. Counted twice (`2 * pad_w`) in the output-size formula.
pub pad_w: i32,
/// Height dilation; must be > 0. The effective kernel height is `dilation_h * (kh - 1) + 1`.
pub dilation_h: i32,
/// Width dilation; must be > 0. The effective kernel width is `dilation_w * (kw - 1) + 1`.
pub dilation_w: i32,
/// Element type of the tensors; must equal `T::KIND` of the `Im2ColPlan<T>` being selected.
pub element: ElementKind,
}
/// Tensor arguments for one `Im2ColPlan::run` call.
///
/// Both shapes are checked against the plan's descriptor before the kernel is invoked.
pub struct Im2ColArgs<'a, T: Element> {
/// Rank-4 source tensor; its shape must equal `[batch, channels, h_in, w_in]`.
pub input: TensorRef<'a, T, 4>,
/// Rank-3 destination tensor; its shape must equal
/// `[batch, channels * kh * kw, h_out * w_out]`.
pub output: TensorMut<'a, T, 3>,
}
/// A validated im2col launch plan for element type `T`.
///
/// Built by `Im2ColPlan::select`, which checks the descriptor and caches the
/// derived output extents and kernel SKU.
pub struct Im2ColPlan<T: Element> {
// Copy of the descriptor that was validated at selection time.
desc: Im2ColDescriptor,
// Output height, as returned by `compute_im2col_2d_dims` for `desc`.
h_out: i32,
// Output width, as returned by `compute_im2col_2d_dims` for `desc`.
w_out: i32,
// SKU built by `build_im2col_sku::<T>` for `ConvKind::Im2Col2d`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> Im2ColPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Im2ColDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::Im2ColPlan: descriptor.element != T::KIND",
));
}
if !matches!(
T::KIND,
ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
) {
return Err(Error::Unsupported(
"baracuda-kernels::Im2ColPlan: dtype must be f32 / f64 / f16 / bf16",
));
}
if desc.batch <= 0 || desc.channels <= 0 || desc.h_in <= 0 || desc.w_in <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: input shape extents must be > 0",
));
}
if desc.kh <= 0 || desc.kw <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: kernel extents must be > 0",
));
}
if desc.stride_h <= 0 || desc.stride_w <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: stride must be > 0",
));
}
if desc.dilation_h <= 0 || desc.dilation_w <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: dilation must be > 0",
));
}
if desc.pad_h < 0 || desc.pad_w < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: padding must be >= 0",
));
}
let (h_out, w_out) = compute_im2col_2d_dims(desc);
if h_out <= 0 || w_out <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: computed output dims <= 0 — \
padding / stride / dilation combination produces an empty output",
));
}
let sku = build_im2col_sku::<T>(ConvKind::Im2Col2d);
Ok(Self {
desc: *desc,
h_out,
w_out,
sku,
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn output_dims(&self) -> (i32, i32) {
(self.h_out, self.w_out)
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Im2ColArgs<'_, T>,
) -> Result<()> {
self.check_args(&args)?;
let input_ptr = args.input.data.as_raw().0 as *const c_void;
let output_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let d = &self.desc;
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_im2col_2d_f32_run(
d.batch, d.channels, d.h_in, d.w_in, self.h_out, self.w_out,
d.kh, d.kw, d.stride_h, d.stride_w, d.pad_h, d.pad_w,
d.dilation_h, d.dilation_w,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::F64 => unsafe {
baracuda_kernels_sys::baracuda_kernels_im2col_2d_f64_run(
d.batch, d.channels, d.h_in, d.w_in, self.h_out, self.w_out,
d.kh, d.kw, d.stride_h, d.stride_w, d.pad_h, d.pad_w,
d.dilation_h, d.dilation_w,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_im2col_2d_f16_run(
d.batch, d.channels, d.h_in, d.w_in, self.h_out, self.w_out,
d.kh, d.kw, d.stride_h, d.stride_w, d.pad_h, d.pad_w,
d.dilation_h, d.dilation_w,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_im2col_2d_bf16_run(
d.batch, d.channels, d.h_in, d.w_in, self.h_out, self.w_out,
d.kh, d.kw, d.stride_h, d.stride_w, d.pad_h, d.pad_w,
d.dilation_h, d.dilation_w,
input_ptr, output_ptr, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::Im2ColPlan: unexpected dtype after select()",
));
}
};
map_im2col_status(status)
}
fn check_args(&self, args: &Im2ColArgs<'_, T>) -> Result<()> {
let in_shape = [self.desc.batch, self.desc.channels, self.desc.h_in, self.desc.w_in];
let col_rows = self.desc.channels * self.desc.kh * self.desc.kw;
let spatial = self.h_out * self.w_out;
let out_shape = [self.desc.batch, col_rows, spatial];
if args.input.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: input shape != [N, C, H_in, W_in]",
));
}
if args.output.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::Im2ColPlan: output shape != [N, C·kh·kw, h_out·w_out]",
));
}
Ok(())
}
}
/// Computes the im2col output extents `(h_out, w_out)` for `d`.
///
/// Each axis uses `floor((input + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1`.
/// The division is a true floor division, so a dilated kernel that is larger than
/// the padded input always yields a non-positive extent. (A truncating division
/// would round a small negative numerator up to 0 for strides >= 2 and report 1.)
///
/// Intermediate arithmetic is done in `i64` so it cannot overflow for any `i32`
/// inputs. An extent too large for `i32` is reported as 0, i.e. as an invalid /
/// empty output, so that callers checking for `<= 0` reject it.
///
/// # Panics
///
/// Panics if a stride is 0 (division by zero). Strides are expected to be > 0.
#[inline]
pub(super) fn compute_im2col_2d_dims(d: &Im2ColDescriptor) -> (i32, i32) {
    /// Output extent along a single spatial axis.
    fn axis_extent(input: i32, pad: i32, kernel: i32, dilation: i32, stride: i32) -> i32 {
        let window = i64::from(dilation) * (i64::from(kernel) - 1) + 1;
        let span = i64::from(input) + 2 * i64::from(pad) - window;
        // For a positive stride `div_euclid` is floor division.
        let extent = span.div_euclid(i64::from(stride)) + 1;
        if extent > i64::from(i32::MAX) {
            0
        } else {
            // In range after the guard above and the lower clamp, so the cast is lossless.
            extent.max(i64::from(i32::MIN)) as i32
        }
    }

    let h_out = axis_extent(d.h_in, d.pad_h, d.kh, d.dilation_h, d.stride_h);
    let w_out = axis_extent(d.w_in, d.pad_w, d.kw, d.dilation_w, d.stride_w);
    (h_out, w_out)
}
/// Builds the [`KernelSku`] describing the bespoke kernel for `op` with element type `T`.
///
/// Math precision follows `T::KIND` (anything other than f64 / f16 / bf16 is reported
/// as f32); the accumulator is f64 for f64 and f32 for every other element type.
/// Every op is flagged deterministic, and bit-stable on the same hardware unless it
/// is `ConvKind::Col2Im1d`. The architecture is fixed to `ArchSku::Sm80` and the
/// backend to `BackendKind::Bespoke`.
pub(super) fn build_im2col_sku<T: Element>(op: ConvKind) -> KernelSku {
    // One table row per element kind: (math precision, accumulator type).
    let (math_precision, accumulator) = match T::KIND {
        ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
        ElementKind::F16 => (MathPrecision::F16, ElementKind::F32),
        ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::F32),
        _ => (MathPrecision::F32, ElementKind::F32),
    };
    let bit_stable_on_same_hardware = match op {
        ConvKind::Col2Im1d => false,
        _ => true,
    };

    KernelSku {
        category: OpCategory::Convolution,
        op: op as u16,
        element: T::KIND,
        aux_element: None,
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee: PrecisionGuarantee {
            math_precision,
            accumulator,
            bit_stable_on_same_hardware,
            deterministic: true,
        },
    }
}
pub(crate) fn map_im2col_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys::im2col reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys::im2col reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}