baracuda_kernels/norm/
instance_norm_backward.rs1use baracuda_cutlass::Result;
4use baracuda_driver::Stream;
5use baracuda_kernels_types::{
6 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
7 Workspace,
8};
9
10use super::group_norm_backward::{
11 GroupNormBackwardArgs, GroupNormBackwardDescriptor, GroupNormBackwardPlan,
12};
13
14#[derive(Copy, Clone, Debug)]
16pub struct InstanceNormBackwardDescriptor<const N: usize> {
17 pub input_shape: [i32; N],
19 pub channel_axis: u8,
21 pub has_affine: bool,
23 pub element: ElementKind,
25}
26
27pub struct InstanceNormBackwardArgs<'a, T: Element, const N: usize> {
29 pub dy: TensorRef<'a, T, N>,
31 pub x: TensorRef<'a, T, N>,
33 pub gamma: Option<TensorRef<'a, T, 1>>,
35 pub saved_mean: TensorRef<'a, T, 1>,
37 pub saved_rstd: TensorRef<'a, T, 1>,
39 pub dx: TensorMut<'a, T, N>,
41 pub dgamma: Option<TensorMut<'a, T, 1>>,
43 pub dbeta: Option<TensorMut<'a, T, 1>>,
45}
46
47pub struct InstanceNormBackwardPlan<T: Element, const N: usize> {
49 inner: GroupNormBackwardPlan<T, N>,
50}
51
52impl<T: Element, const N: usize> InstanceNormBackwardPlan<T, N> {
53 pub fn select(
55 stream: &Stream,
56 desc: &InstanceNormBackwardDescriptor<N>,
57 pref: PlanPreference,
58 ) -> Result<Self> {
59 let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
60 let inner_desc = GroupNormBackwardDescriptor::<N> {
61 input_shape: desc.input_shape,
62 channel_axis: desc.channel_axis,
63 num_groups: c.max(1) as u32,
64 has_affine: desc.has_affine,
65 element: desc.element,
66 };
67 let inner = GroupNormBackwardPlan::<T, N>::select(stream, &inner_desc, pref)?;
68 Ok(Self { inner })
69 }
70
71 #[inline]
73 pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
74 #[inline]
76 pub fn sku(&self) -> KernelSku { self.inner.sku() }
77 #[inline]
79 pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
80
81 pub fn run(
83 &self,
84 stream: &Stream,
85 workspace: Workspace<'_>,
86 args: InstanceNormBackwardArgs<'_, T, N>,
87 ) -> Result<()> {
88 let inner_args = GroupNormBackwardArgs::<T, N> {
89 dy: args.dy,
90 x: args.x,
91 gamma: args.gamma,
92 saved_mean: args.saved_mean,
93 saved_rstd: args.saved_rstd,
94 dx: args.dx,
95 dgamma: args.dgamma,
96 dbeta: args.dbeta,
97 };
98 self.inner.run(stream, workspace, inner_args)
99 }
100}