use baracuda_cutlass::Result;
use baracuda_driver::Stream;
use baracuda_kernels_types::{
Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
Workspace,
};
use super::group_norm::{GroupNormArgs, GroupNormDescriptor, GroupNormPlan};
/// Problem description used to select an [`InstanceNormPlan`] for a rank-`N` input.
///
/// `InstanceNormPlan::select` copies every field into a `GroupNormDescriptor`
/// and derives that descriptor's `num_groups` from the extent of the channel
/// axis, so the fields here mirror the group-norm descriptor minus the group
/// count.
#[derive(Copy, Clone, Debug)]
pub struct InstanceNormDescriptor<const N: usize> {
/// Per-axis extents of the input tensor (forwarded unchanged).
pub input_shape: [i32; N],
/// Index into `input_shape` naming the channel axis. When `N >= 2`,
/// `select` reads `input_shape[channel_axis]` as the channel count, so the
/// value must be `< N` or that lookup panics.
pub channel_axis: u8,
/// Epsilon forwarded to the group-norm descriptor.
/// NOTE(review): presumably the variance stabiliser — confirm in `group_norm`.
pub eps: f32,
/// Forwarded unchanged to the group-norm descriptor.
/// NOTE(review): presumably signals that `gamma`/`beta` are supplied at run
/// time — confirm in `group_norm`.
pub has_affine: bool,
/// Element kind forwarded unchanged to the group-norm descriptor.
pub element: ElementKind,
}
/// Tensor arguments for [`InstanceNormPlan::run`].
///
/// `run` moves each field, one-for-one and under the same name, into a
/// `GroupNormArgs` value; no field is inspected or modified by this wrapper.
pub struct InstanceNormArgs<'a, T: Element, const N: usize> {
/// Read-only rank-`N` input tensor.
pub x: TensorRef<'a, T, N>,
/// Optional read-only rank-1 tensor.
/// NOTE(review): presumably the per-channel scale — confirm in `group_norm`.
pub gamma: Option<TensorRef<'a, T, 1>>,
/// Optional read-only rank-1 tensor.
/// NOTE(review): presumably the per-channel shift — confirm in `group_norm`.
pub beta: Option<TensorRef<'a, T, 1>>,
/// Mutable rank-`N` tensor; presumably receives the normalised output.
pub y: TensorMut<'a, T, N>,
/// Mutable rank-1 tensor; presumably receives the computed means — the
/// expected length is defined by the group-norm kernel, not visible here.
pub saved_mean: TensorMut<'a, T, 1>,
/// Mutable rank-1 tensor; presumably receives the reciprocal standard
/// deviations — the expected length is defined by the group-norm kernel.
pub saved_rstd: TensorMut<'a, T, 1>,
}
/// Instance-normalisation plan implemented as a thin wrapper around a
/// [`GroupNormPlan`].
///
/// All work is delegated to the wrapped plan, which `select` configures with
/// `num_groups` set to the extent of the channel axis (at least 1).
pub struct InstanceNormPlan<T: Element, const N: usize> {
// The group-norm plan that every method of this type forwards to.
inner: GroupNormPlan<T, N>,
}
impl<T: Element, const N: usize> InstanceNormPlan<T, N> {
pub fn select(
stream: &Stream,
desc: &InstanceNormDescriptor<N>,
pref: PlanPreference,
) -> Result<Self> {
let c = if N >= 2 { desc.input_shape[desc.channel_axis as usize] } else { 1 };
let inner_desc = GroupNormDescriptor::<N> {
input_shape: desc.input_shape,
channel_axis: desc.channel_axis,
num_groups: c.max(1) as u32,
eps: desc.eps,
has_affine: desc.has_affine,
element: desc.element,
};
let inner = GroupNormPlan::<T, N>::select(stream, &inner_desc, pref)?;
Ok(Self { inner })
}
#[inline]
pub fn workspace_size(&self) -> usize { self.inner.workspace_size() }
#[inline]
pub fn sku(&self) -> KernelSku { self.inner.sku() }
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee { self.inner.precision_guarantee() }
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: InstanceNormArgs<'_, T, N>,
) -> Result<()> {
let inner_args = GroupNormArgs::<T, N> {
x: args.x,
gamma: args.gamma,
beta: args.beta,
y: args.y,
saved_mean: args.saved_mean,
saved_rstd: args.saved_rstd,
};
self.inner.run(stream, workspace, inner_args)
}
}