use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ConvKind, Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut,
TensorRef, Workspace,
};
use super::im2col::{build_im2col_sku, map_im2col_status};
use super::im2col1d::{compute_im2col_1d_l_out, validate_im2col_1d, Im2Col1dDescriptor};
/// Problem geometry for a 1-D col2im operation.
///
/// `Col2Im1dPlan::select` copies every field into an `Im2Col1dDescriptor` and
/// validates it with the im2col-1d helpers, so the same geometry rules apply.
/// The resulting plan expects an input of shape `[batch, channels * kl, l_out]`
/// and an output of shape `[batch, channels, l_in]`, where `l_out` is derived
/// from these fields by `compute_im2col_1d_l_out`.
#[derive(Copy, Clone, Debug)]
pub struct Col2Im1dDescriptor {
/// Batch size `N`: leading axis of both the input and the output tensor.
pub batch: i32,
/// Channel count `C` of the output tensor (second axis).
pub channels: i32,
/// Length of the output tensor's last axis.
pub l_in: i32,
/// Kernel length; the input's second axis must equal `channels * kl`.
pub kl: i32,
/// Stride along the length axis. Presumably the sliding-window step; confirm
/// against `compute_im2col_1d_l_out`.
pub stride_l: i32,
/// Padding along the length axis. NOTE(review): whether it is applied to one
/// end or both is not visible here; confirm.
pub pad_l: i32,
/// Dilation along the length axis.
pub dilation_l: i32,
/// Element dtype tag. It is checked together with `T` by
/// `validate_im2col_1d::<T>`; presumably it must equal `T::KIND` (confirm).
pub element: ElementKind,
}
/// Tensors passed to a single `Col2Im1dPlan::run` call.
///
/// NOTE(review): the FFI kernel is not visible here. Confirm whether it
/// overwrites `output` or accumulates into it; if it accumulates, callers must
/// zero `output` first.
pub struct Col2Im1dArgs<'a, T: Element> {
/// Column buffer to fold back. `run` rejects any shape other than
/// `[batch, channels * kl, l_out]`.
pub input: TensorRef<'a, T, 3>,
/// Destination tensor. `run` rejects any shape other than
/// `[batch, channels, l_in]`.
pub output: TensorMut<'a, T, 3>,
}
/// A validated col2im-1d plan for element type `T`, created by `select`.
pub struct Col2Im1dPlan<T: Element> {
/// Descriptor the plan was selected for (copied at `select` time).
desc: Col2Im1dDescriptor,
/// Length of the input tensor's last axis, computed once in `select` by
/// `compute_im2col_1d_l_out`.
l_out: i32,
/// Kernel SKU built for `ConvKind::Col2Im1d` and element type `T`.
sku: KernelSku,
/// Ties the plan to `T` without storing a value of it.
_marker: PhantomData<T>,
}
impl<T: Element> Col2Im1dPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Col2Im1dDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
let im2col_desc = Im2Col1dDescriptor {
batch: desc.batch,
channels: desc.channels,
l_in: desc.l_in,
kl: desc.kl,
stride_l: desc.stride_l,
pad_l: desc.pad_l,
dilation_l: desc.dilation_l,
element: desc.element,
};
validate_im2col_1d::<T>(&im2col_desc).map_err(|e| match e {
Error::Unsupported(_) => Error::Unsupported(
"baracuda-kernels::Col2Im1dPlan: dtype/descriptor unsupported",
),
Error::InvalidProblem(_) => Error::InvalidProblem(
"baracuda-kernels::Col2Im1dPlan: invalid problem dimensions",
),
other => other,
})?;
let l_out = compute_im2col_1d_l_out(&im2col_desc).map_err(|_| {
Error::InvalidProblem("baracuda-kernels::Col2Im1dPlan: computed l_out <= 0")
})?;
let sku = build_im2col_sku::<T>(ConvKind::Col2Im1d);
Ok(Self {
desc: *desc,
l_out,
sku,
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn input_l_out(&self) -> i32 {
self.l_out
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Col2Im1dArgs<'_, T>,
) -> Result<()> {
self.check_args(&args)?;
let input_ptr = args.input.data.as_raw().0 as *const c_void;
let output_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let d = &self.desc;
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_col2im_1d_f32_run(
d.batch, d.channels, d.l_in, self.l_out,
d.kl, d.stride_l, d.pad_l, d.dilation_l,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::F64 => unsafe {
baracuda_kernels_sys::baracuda_kernels_col2im_1d_f64_run(
d.batch, d.channels, d.l_in, self.l_out,
d.kl, d.stride_l, d.pad_l, d.dilation_l,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_col2im_1d_f16_run(
d.batch, d.channels, d.l_in, self.l_out,
d.kl, d.stride_l, d.pad_l, d.dilation_l,
input_ptr, output_ptr, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_col2im_1d_bf16_run(
d.batch, d.channels, d.l_in, self.l_out,
d.kl, d.stride_l, d.pad_l, d.dilation_l,
input_ptr, output_ptr, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::Col2Im1dPlan: unexpected dtype after select()",
));
}
};
map_im2col_status(status)
}
fn check_args(&self, args: &Col2Im1dArgs<'_, T>) -> Result<()> {
let in_shape = [self.desc.batch, self.desc.channels * self.desc.kl, self.l_out];
let out_shape = [self.desc.batch, self.desc.channels, self.desc.l_in];
if args.input.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::Col2Im1dPlan: input shape != [N, C·kl, l_out]",
));
}
if args.output.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::Col2Im1dPlan: output shape != [N, C, L_in]",
));
}
Ok(())
}
}