Skip to main content

baracuda_kernels/norm/
instance_norm_backward.rs

//! InstanceNorm backward plan — a thin wrapper over the GroupNorm backward plan configured with one group per channel.
2
3use baracuda_cutlass::Result;
4use baracuda_driver::Stream;
5use baracuda_kernels_types::{
6    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
7    Workspace,
8};
9
10use super::group_norm_backward::{
11    GroupNormBackwardArgs, GroupNormBackwardDescriptor, GroupNormBackwardPlan,
12};
13
14/// Descriptor for an InstanceNorm BW op.
15#[derive(Copy, Clone, Debug)]
16pub struct InstanceNormBackwardDescriptor<const N: usize> {
17    /// Input shape `[N, C, ...]`.
18    pub input_shape: [i32; N],
19    /// Channel axis (must equal 1).
20    pub channel_axis: u8,
21    /// Affine.
22    pub has_affine: bool,
23    /// Element type.
24    pub element: ElementKind,
25}
26
27/// Args bundle for InstanceNorm BW.
28pub struct InstanceNormBackwardArgs<'a, T: Element, const N: usize> {
29    /// Upstream gradient.
30    pub dy: TensorRef<'a, T, N>,
31    /// Saved forward input.
32    pub x: TensorRef<'a, T, N>,
33    /// Per-channel gamma.
34    pub gamma: Option<TensorRef<'a, T, 1>>,
35    /// Saved per-`(N, C)` mean.
36    pub saved_mean: TensorRef<'a, T, 1>,
37    /// Saved per-`(N, C)` inv_std.
38    pub saved_rstd: TensorRef<'a, T, 1>,
39    /// Gradient w.r.t. forward input.
40    pub dx: TensorMut<'a, T, N>,
41    /// Gradient w.r.t. gamma.
42    pub dgamma: Option<TensorMut<'a, T, 1>>,
43    /// Gradient w.r.t. beta.
44    pub dbeta: Option<TensorMut<'a, T, 1>>,
45}
46
47/// InstanceNorm BW plan — wraps [`GroupNormBackwardPlan`].
48pub struct InstanceNormBackwardPlan<T: Element, const N: usize> {
49    inner: GroupNormBackwardPlan<T, N>,
50}
51
52impl<T: Element, const N: usize> InstanceNormBackwardPlan<T, N> {
53    /// Pick a kernel.
54    pub fn select(
55        stream: &Stream,
56        desc: &InstanceNormBackwardDescriptor<N>,
57        pref: PlanPreference,
58    ) -> Result<Self> {
59        let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
60        let inner_desc = GroupNormBackwardDescriptor::<N> {
61            input_shape: desc.input_shape,
62            channel_axis: desc.channel_axis,
63            num_groups: c.max(1) as u32,
64            has_affine: desc.has_affine,
65            element: desc.element,
66        };
67        let inner = GroupNormBackwardPlan::<T, N>::select(stream, &inner_desc, pref)?;
68        Ok(Self { inner })
69    }
70
71    /// Workspace bytes.
72    #[inline]
73    pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
74    /// Kernel SKU identity.
75    #[inline]
76    pub fn sku(&self) -> KernelSku { self.inner.sku() }
77    /// Numerical guarantees.
78    #[inline]
79    pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
80
81    /// Launch.
82    pub fn run(
83        &self,
84        stream: &Stream,
85        workspace: Workspace<'_>,
86        args: InstanceNormBackwardArgs<'_, T, N>,
87    ) -> Result<()> {
88        let inner_args = GroupNormBackwardArgs::<T, N> {
89            dy: args.dy,
90            x: args.x,
91            gamma: args.gamma,
92            saved_mean: args.saved_mean,
93            saved_rstd: args.saved_rstd,
94            dx: args.dx,
95            dgamma: args.dgamma,
96            dbeta: args.dbeta,
97        };
98        self.inner.run(stream, workspace, inner_args)
99    }
100}