Skip to main content

baracuda_kernels/norm/
instance_norm.rs

1//! InstanceNorm forward plan.
2//!
3//! Per-`(sample, channel)` normalization across spatial dims. Equivalent
4//! to GroupNorm with `num_groups == num_channels` — this plan is sugar
5//! that builds a [`super::group_norm::GroupNormPlan`] internally with
6//! that setting. **Same kernel symbols** dispatch behind the scenes
7//! (no separate `.cu` file).
8//!
9//! ## Why a separate plan?
10//!
11//! PyTorch ships `InstanceNorm1d/2d/3d` as their own modules — the API
12//! split matches their layer-shape semantics and lets callers be
13//! explicit about intent. Internally the kernel is identical to
14//! `GroupNorm(num_groups=C)`.
15
16use baracuda_cutlass::Result;
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
20    Workspace,
21};
22
23use super::group_norm::{GroupNormArgs, GroupNormDescriptor, GroupNormPlan};
24
25/// Descriptor for an InstanceNorm forward op.
26#[derive(Copy, Clone, Debug)]
27pub struct InstanceNormDescriptor<const N: usize> {
28    /// Input tensor shape `[N, C, *spatial]`.
29    pub input_shape: [i32; N],
30    /// Channel axis (must equal 1).
31    pub channel_axis: u8,
32    /// Epsilon.
33    pub eps: f32,
34    /// Whether gamma + beta participate.
35    pub has_affine: bool,
36    /// Element type.
37    pub element: ElementKind,
38}
39
40/// Args bundle for InstanceNorm FW.
41pub struct InstanceNormArgs<'a, T: Element, const N: usize> {
42    /// Input.
43    pub x: TensorRef<'a, T, N>,
44    /// Per-channel gamma.
45    pub gamma: Option<TensorRef<'a, T, 1>>,
46    /// Per-channel beta.
47    pub beta: Option<TensorRef<'a, T, 1>>,
48    /// Output.
49    pub y: TensorMut<'a, T, N>,
50    /// Saved per-`(N, C)` mean — length == `N * C`.
51    pub saved_mean: TensorMut<'a, T, 1>,
52    /// Saved per-`(N, C)` inv_std — length == `N * C`.
53    pub saved_rstd: TensorMut<'a, T, 1>,
54}
55
56/// InstanceNorm forward plan. Thin wrapper over [`GroupNormPlan`] with
57/// `num_groups == num_channels`.
58pub struct InstanceNormPlan<T: Element, const N: usize> {
59    inner: GroupNormPlan<T, N>,
60}
61
62impl<T: Element, const N: usize> InstanceNormPlan<T, N> {
63    /// Pick a kernel.
64    pub fn select(
65        stream: &Stream,
66        desc: &InstanceNormDescriptor<N>,
67        pref: PlanPreference,
68    ) -> Result<Self> {
69        let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
70        let inner_desc = GroupNormDescriptor::<N> {
71            input_shape: desc.input_shape,
72            channel_axis: desc.channel_axis,
73            num_groups: c.max(1) as u32,
74            eps: desc.eps,
75            has_affine: desc.has_affine,
76            element: desc.element,
77        };
78        let inner = GroupNormPlan::<T, N>::select(stream, &inner_desc, pref)?;
79        Ok(Self { inner })
80    }
81
82    /// Workspace bytes.
83    #[inline]
84    pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
85    /// Kernel SKU identity.
86    #[inline]
87    pub fn sku(&self) -> KernelSku { self.inner.sku() }
88    /// Numerical guarantees.
89    #[inline]
90    pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
91
92    /// Launch.
93    pub fn run(
94        &self,
95        stream: &Stream,
96        workspace: Workspace<'_>,
97        args: InstanceNormArgs<'_, T, N>,
98    ) -> Result<()> {
99        let inner_args = GroupNormArgs::<T, N> {
100            x: args.x,
101            gamma: args.gamma,
102            beta: args.beta,
103            y: args.y,
104            saved_mean: args.saved_mean,
105            saved_rstd: args.saved_rstd,
106        };
107        self.inner.run(stream, workspace, inner_args)
108    }
109}