use baracuda_cutlass::Result;
use baracuda_driver::Stream;
use baracuda_kernels_types::{
Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
Workspace,
};
use super::group_norm_backward::{
GroupNormBackwardArgs, GroupNormBackwardDescriptor, GroupNormBackwardPlan,
};
/// Problem description for an instance-normalization backward pass over a
/// rank-`N` input.
///
/// `InstanceNormBackwardPlan::select` converts this into a
/// `GroupNormBackwardDescriptor`, copying every field across and setting
/// `num_groups` to the extent of the channel axis (clamped to at least 1).
#[derive(Copy, Clone, Debug)]
pub struct InstanceNormBackwardDescriptor<const N: usize> {
/// Extents of the `N` input dimensions; forwarded unchanged to the
/// group-norm descriptor.
pub input_shape: [i32; N],
/// Index into `input_shape` of the channel dimension. Must be `< N` when
/// `N >= 2`, because `select` indexes `input_shape` with it directly.
pub channel_axis: u8,
/// Forwarded unchanged to the group-norm descriptor.
/// NOTE(review): presumably selects whether gamma/dgamma/dbeta take part
/// in the computation — confirm against the group-norm backward plan.
pub has_affine: bool,
/// Element kind forwarded unchanged to the group-norm descriptor.
pub element: ElementKind,
}
/// Tensor arguments for one execution of an instance-norm backward plan.
///
/// `InstanceNormBackwardPlan::run` moves every field, one-to-one and under
/// the same name, into a `GroupNormBackwardArgs`; any shape or layout
/// requirements are therefore those of the group-norm backward plan.
///
/// NOTE(review): the per-field descriptions below are inferred from the
/// field names and the usual normalization-backward convention — confirm
/// against the group-norm backward kernel.
pub struct InstanceNormBackwardArgs<'a, T: Element, const N: usize> {
/// Read-only rank-`N` tensor; presumably the gradient w.r.t. the output.
pub dy: TensorRef<'a, T, N>,
/// Read-only rank-`N` tensor; presumably the forward-pass input.
pub x: TensorRef<'a, T, N>,
/// Optional read-only rank-1 tensor; presumably the affine scale.
pub gamma: Option<TensorRef<'a, T, 1>>,
/// Read-only rank-1 tensor; presumably the mean saved by the forward pass.
pub saved_mean: TensorRef<'a, T, 1>,
/// Read-only rank-1 tensor; presumably the reciprocal standard deviation
/// saved by the forward pass.
pub saved_rstd: TensorRef<'a, T, 1>,
/// Writable rank-`N` tensor; presumably receives the gradient w.r.t. `x`.
pub dx: TensorMut<'a, T, N>,
/// Optional writable rank-1 tensor; presumably receives the gradient
/// w.r.t. `gamma`.
pub dgamma: Option<TensorMut<'a, T, 1>>,
/// Optional writable rank-1 tensor; presumably receives the gradient
/// w.r.t. the affine shift.
pub dbeta: Option<TensorMut<'a, T, 1>>,
}
/// Instance-norm backward plan implemented as a thin wrapper around a
/// `GroupNormBackwardPlan`.
///
/// Every method on this type delegates to the wrapped plan.
pub struct InstanceNormBackwardPlan<T: Element, const N: usize> {
// The group-norm backward plan that performs all of the actual work.
inner: GroupNormBackwardPlan<T, N>,
}
impl<T: Element, const N: usize> InstanceNormBackwardPlan<T, N> {
pub fn select(
stream: &Stream,
desc: &InstanceNormBackwardDescriptor<N>,
pref: PlanPreference,
) -> Result<Self> {
let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
let inner_desc = GroupNormBackwardDescriptor::<N> {
input_shape: desc.input_shape,
channel_axis: desc.channel_axis,
num_groups: c.max(1) as u32,
has_affine: desc.has_affine,
element: desc.element,
};
let inner = GroupNormBackwardPlan::<T, N>::select(stream, &inner_desc, pref)?;
Ok(Self { inner })
}
#[inline]
pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
#[inline]
pub fn sku(&self) -> KernelSku { self.inner.sku() }
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: InstanceNormBackwardArgs<'_, T, N>,
) -> Result<()> {
let inner_args = GroupNormBackwardArgs::<T, N> {
dy: args.dy,
x: args.x,
gamma: args.gamma,
saved_mean: args.saved_mean,
saved_rstd: args.saved_rstd,
dx: args.dx,
dgamma: args.dgamma,
dbeta: args.dbeta,
};
self.inner.run(stream, workspace, inner_args)
}
}