baracuda_kernels/norm/
instance_norm.rs1use baracuda_cutlass::Result;
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
20 Workspace,
21};
22
23use super::group_norm::{GroupNormArgs, GroupNormDescriptor, GroupNormPlan};
24
25#[derive(Copy, Clone, Debug)]
27pub struct InstanceNormDescriptor<const N: usize> {
28 pub input_shape: [i32; N],
30 pub channel_axis: u8,
32 pub eps: f32,
34 pub has_affine: bool,
36 pub element: ElementKind,
38}
39
40pub struct InstanceNormArgs<'a, T: Element, const N: usize> {
42 pub x: TensorRef<'a, T, N>,
44 pub gamma: Option<TensorRef<'a, T, 1>>,
46 pub beta: Option<TensorRef<'a, T, 1>>,
48 pub y: TensorMut<'a, T, N>,
50 pub saved_mean: TensorMut<'a, T, 1>,
52 pub saved_rstd: TensorMut<'a, T, 1>,
54}
55
56pub struct InstanceNormPlan<T: Element, const N: usize> {
59 inner: GroupNormPlan<T, N>,
60}
61
62impl<T: Element, const N: usize> InstanceNormPlan<T, N> {
63 pub fn select(
65 stream: &Stream,
66 desc: &InstanceNormDescriptor<N>,
67 pref: PlanPreference,
68 ) -> Result<Self> {
69 let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
70 let inner_desc = GroupNormDescriptor::<N> {
71 input_shape: desc.input_shape,
72 channel_axis: desc.channel_axis,
73 num_groups: c.max(1) as u32,
74 eps: desc.eps,
75 has_affine: desc.has_affine,
76 element: desc.element,
77 };
78 let inner = GroupNormPlan::<T, N>::select(stream, &inner_desc, pref)?;
79 Ok(Self { inner })
80 }
81
82 #[inline]
84 pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
85 #[inline]
87 pub fn sku(&self) -> KernelSku { self.inner.sku() }
88 #[inline]
90 pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
91
92 pub fn run(
94 &self,
95 stream: &Stream,
96 workspace: Workspace<'_>,
97 args: InstanceNormArgs<'_, T, N>,
98 ) -> Result<()> {
99 let inner_args = GroupNormArgs::<T, N> {
100 x: args.x,
101 gamma: args.gamma,
102 beta: args.beta,
103 y: args.y,
104 saved_mean: args.saved_mean,
105 saved_rstd: args.saved_rstd,
106 };
107 self.inner.run(stream, workspace, inner_args)
108 }
109}