/// Axis positions of the batch, channel and spatial dimensions within a
/// tensor shape laid out as a particular [`ConvLayout`].
///
/// Obtained from [`ConvLayout::indices`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DimensionIndices {
    /// Axis index of the batch dimension.
    pub batch: usize,
    /// Axis index of the channel dimension.
    pub channel: usize,
    // Fixed-capacity storage keeps the struct `Copy` without allocating. Only
    // the first `spatial_len` entries are meaningful; unused slots are always
    // zero so that the derived `PartialEq`/`Hash` compare logical contents.
    spatial: [usize; 2],
    spatial_len: usize,
}

impl DimensionIndices {
    /// Axis indices of the spatial dimensions, in layout order
    /// (e.g. `[2, 3]` for [`ConvLayout::NCHW`], `[1]` for [`ConvLayout::NLC`]).
    #[must_use]
    pub fn spatial(&self) -> &[usize] {
        &self.spatial[..self.spatial_len]
    }
}

/// Axis order of the input or output tensor of a convolution.
///
/// Letters name the axes in shape order (index 0 first): `N` batch,
/// `C` channels, `H`/`W` the two spatial axes of the 2-D layouts and `L` the
/// single spatial axis of the 1-D layouts.
// The all-caps variant names are the conventional spellings of these layouts,
// so clippy's acronym-casing lint is deliberately silenced here.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvLayout {
    /// `[N, C, H, W]` — channels-first, two spatial axes.
    NCHW,
    /// `[N, H, W, C]` — channels-last, two spatial axes.
    NHWC,
    /// `[N, C, L]` — channels-first, one spatial axis.
    NCL,
    /// `[N, L, C]` — channels-last, one spatial axis.
    NLC,
}

impl ConvLayout {
    /// Returns where the batch, channel and spatial axes sit in a shape of
    /// this layout.
    #[must_use]
    pub fn indices(self) -> DimensionIndices {
        match self {
            Self::NCHW => DimensionIndices {
                batch: 0,
                channel: 1,
                spatial: [2, 3],
                spatial_len: 2,
            },
            Self::NHWC => DimensionIndices {
                batch: 0,
                channel: 3,
                spatial: [1, 2],
                spatial_len: 2,
            },
            // 1-D layouts: the second spatial slot is unused zero padding.
            Self::NCL => DimensionIndices {
                batch: 0,
                channel: 1,
                spatial: [2, 0],
                spatial_len: 1,
            },
            Self::NLC => DimensionIndices {
                batch: 0,
                channel: 2,
                spatial: [1, 0],
                spatial_len: 1,
            },
        }
    }

    /// Rank of a tensor in this layout: 4 for `NCHW`/`NHWC`, 3 for `NCL`/`NLC`.
    #[must_use]
    pub fn ndim(self) -> usize {
        match self {
            Self::NCHW | Self::NHWC => 4,
            Self::NCL | Self::NLC => 3,
        }
    }

    /// Splits `shape` into `(batch, channels, spatial)` according to this layout.
    ///
    /// `spatial` is returned in layout order (e.g. `[H, W]`).
    ///
    /// # Panics
    ///
    /// Panics if `shape.len()` differs from [`ndim`](Self::ndim).
    #[must_use]
    pub fn parse_shape(self, shape: &[usize]) -> (usize, usize, Vec<usize>) {
        // Reject wrong-rank shapes up front instead of silently ignoring
        // trailing axes or failing later with an opaque index panic.
        assert_eq!(
            shape.len(),
            self.ndim(),
            "shape {:?} does not have the rank expected by layout {:?}",
            shape,
            self
        );
        let idx = self.indices();
        let spatial = idx.spatial().iter().map(|&axis| shape[axis]).collect();
        (shape[idx.batch], shape[idx.channel], spatial)
    }

    /// Assembles a shape in this layout from its semantic parts; the inverse
    /// of [`parse_shape`](Self::parse_shape).
    ///
    /// # Panics
    ///
    /// Panics if `spatial.len()` is not this layout's number of spatial axes
    /// (2 for `NCHW`/`NHWC`, 1 for `NCL`/`NLC`).
    #[must_use]
    pub fn build_shape(self, batch: usize, channels: usize, spatial: &[usize]) -> Vec<usize> {
        let idx = self.indices();
        assert_eq!(
            spatial.len(),
            idx.spatial().len(),
            "layout {:?} expects {} spatial extent(s), got {:?}",
            self,
            idx.spatial().len(),
            spatial
        );
        // Derive positions from `indices()` so the two methods cannot drift apart.
        let mut shape = vec![0; self.ndim()];
        shape[idx.batch] = batch;
        shape[idx.channel] = channels;
        for (&axis, &extent) in idx.spatial().iter().zip(spatial) {
            shape[axis] = extent;
        }
        shape
    }

    /// Whether the channel axis directly follows the batch axis (`NCHW`, `NCL`).
    #[must_use]
    pub fn is_channels_first(self) -> bool {
        matches!(self, Self::NCHW | Self::NCL)
    }

    /// Returns the axis permutation that transposes a tensor from this layout
    /// into `target`.
    ///
    /// Output axis `i` takes input axis `perm[i]`, e.g. `NCHW` → `NHWC` yields
    /// `[0, 2, 3, 1]`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `target` have different ranks (e.g. `NCHW` → `NCL`),
    /// since no axis permutation exists between them.
    #[must_use]
    pub fn permutation_to(self, target: Self) -> Vec<usize> {
        let ndim = self.ndim();
        assert_eq!(
            ndim,
            target.ndim(),
            "cannot permute layout {:?} into layout {:?}: ranks differ",
            self,
            target
        );
        let src = self.indices();
        let tgt = target.indices();
        let mut perm = vec![0; ndim];
        perm[tgt.batch] = src.batch;
        perm[tgt.channel] = src.channel;
        for (&to, &from) in tgt.spatial().iter().zip(src.spatial()) {
            perm[to] = from;
        }
        perm
    }
}
/// Axis order of a convolution kernel tensor.
///
/// `OIHW`/`HWIO` carry two spatial axes (`H`, `W`); `OIL`/`LIO` carry one (`L`).
// NOTE(review): `O`/`I` presumably mean output/input channels, as in common
// framework conventions — nothing in this file pins that down; confirm.
// The all-caps variant names are the conventional spellings of these layouts,
// so clippy's acronym-casing lint is deliberately silenced here.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelLayout {
    /// `[O, I, H, W]` axis order.
    OIHW,
    /// `[H, W, I, O]` axis order.
    HWIO,
    /// `[O, I, L]` axis order.
    OIL,
    /// `[L, I, O]` axis order.
    LIO,
}
/// Layouts of the input, kernel and output tensors of a convolution.
///
/// The [`Default`] is `NCHW` input and output with an `OIHW` kernel.
// Every field is `Copy + Hash`, so the aggregate is too: it can be passed by
// value freely and used as a map key (e.g. to cache per-layout plans).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvDimensionNumbers {
    /// Layout of the convolution input.
    pub input_layout: ConvLayout,
    /// Layout of the convolution kernel.
    pub kernel_layout: KernelLayout,
    /// Layout of the convolution output.
    pub output_layout: ConvLayout,
}
impl Default for ConvDimensionNumbers {
fn default() -> Self {
Self {
input_layout: ConvLayout::NCHW,
kernel_layout: KernelLayout::OIHW,
output_layout: ConvLayout::NCHW,
}
}
}
/// Formats the layout as its conventional name (e.g. `NCHW`).
///
/// Goes through [`Formatter::pad`](std::fmt::Formatter::pad), so width,
/// fill/alignment and precision flags behave as they do for `str`
/// (e.g. `format!("{:>6}", ConvLayout::NCL)` yields `"   NCL"`).
impl std::fmt::Display for ConvLayout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::NCHW => "NCHW",
            Self::NHWC => "NHWC",
            Self::NCL => "NCL",
            Self::NLC => "NLC",
        };
        f.pad(name)
    }
}
/// Formats the kernel layout as its conventional name (e.g. `OIHW`).
///
/// Goes through [`Formatter::pad`](std::fmt::Formatter::pad), so width,
/// fill/alignment and precision flags behave as they do for `str`
/// (e.g. `format!("{:<5}|", KernelLayout::OIL)` yields `"OIL  |"`).
impl std::fmt::Display for KernelLayout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::OIHW => "OIHW",
            Self::HWIO => "HWIO",
            Self::OIL => "OIL",
            Self::LIO => "LIO",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nchw_indices() {
        let idx = ConvLayout::NCHW.indices();
        assert_eq!(idx.batch, 0);
        assert_eq!(idx.channel, 1);
        assert_eq!(idx.spatial(), &[2, 3]);
    }

    #[test]
    fn test_nhwc_indices() {
        let idx = ConvLayout::NHWC.indices();
        assert_eq!(idx.batch, 0);
        assert_eq!(idx.channel, 3);
        assert_eq!(idx.spatial(), &[1, 2]);
    }

    #[test]
    fn test_ncl_indices() {
        let idx = ConvLayout::NCL.indices();
        assert_eq!(idx.batch, 0);
        assert_eq!(idx.channel, 1);
        assert_eq!(idx.spatial(), &[2]);
    }

    #[test]
    fn test_nlc_indices() {
        let idx = ConvLayout::NLC.indices();
        assert_eq!(idx.batch, 0);
        assert_eq!(idx.channel, 2);
        assert_eq!(idx.spatial(), &[1]);
    }

    #[test]
    fn test_parse_shape_nchw() {
        let (b, c, s) = ConvLayout::NCHW.parse_shape(&[2, 3, 32, 32]);
        assert_eq!(b, 2);
        assert_eq!(c, 3);
        assert_eq!(s, vec![32, 32]);
    }

    #[test]
    fn test_parse_shape_nhwc() {
        let (b, c, s) = ConvLayout::NHWC.parse_shape(&[2, 32, 32, 3]);
        assert_eq!(b, 2);
        assert_eq!(c, 3);
        assert_eq!(s, vec![32, 32]);
    }

    #[test]
    fn test_parse_shape_ncl() {
        let (b, c, s) = ConvLayout::NCL.parse_shape(&[4, 16, 100]);
        assert_eq!(b, 4);
        assert_eq!(c, 16);
        assert_eq!(s, vec![100]);
    }

    #[test]
    fn test_parse_shape_nlc() {
        let (b, c, s) = ConvLayout::NLC.parse_shape(&[4, 100, 16]);
        assert_eq!(b, 4);
        assert_eq!(c, 16);
        assert_eq!(s, vec![100]);
    }

    #[test]
    fn test_build_shape_nchw() {
        let shape = ConvLayout::NCHW.build_shape(2, 3, &[32, 32]);
        assert_eq!(shape, vec![2, 3, 32, 32]);
    }

    #[test]
    fn test_build_shape_nhwc() {
        let shape = ConvLayout::NHWC.build_shape(2, 3, &[32, 32]);
        assert_eq!(shape, vec![2, 32, 32, 3]);
    }

    #[test]
    fn test_build_shape_ncl() {
        let shape = ConvLayout::NCL.build_shape(4, 16, &[100]);
        assert_eq!(shape, vec![4, 16, 100]);
    }

    #[test]
    fn test_build_shape_nlc() {
        let shape = ConvLayout::NLC.build_shape(4, 16, &[100]);
        assert_eq!(shape, vec![4, 100, 16]);
    }

    /// `build_shape` must invert `parse_shape` for every layout; distinct
    /// extents make any axis mix-up visible.
    #[test]
    fn test_parse_build_round_trip() {
        let cases = [
            (ConvLayout::NCHW, vec![2, 3, 5, 7]),
            (ConvLayout::NHWC, vec![2, 5, 7, 3]),
            (ConvLayout::NCL, vec![4, 16, 100]),
            (ConvLayout::NLC, vec![4, 100, 16]),
        ];
        for (layout, shape) in cases.iter() {
            let (b, c, s) = layout.parse_shape(shape);
            assert_eq!(&layout.build_shape(b, c, &s), shape, "layout {}", layout);
        }
    }

    #[test]
    fn test_is_channels_first() {
        assert!(ConvLayout::NCHW.is_channels_first());
        assert!(ConvLayout::NCL.is_channels_first());
        assert!(!ConvLayout::NHWC.is_channels_first());
        assert!(!ConvLayout::NLC.is_channels_first());
    }

    #[test]
    fn test_permutation_nchw_to_nhwc() {
        let perm = ConvLayout::NCHW.permutation_to(ConvLayout::NHWC);
        assert_eq!(perm, vec![0, 2, 3, 1]);
    }

    #[test]
    fn test_permutation_nhwc_to_nchw() {
        let perm = ConvLayout::NHWC.permutation_to(ConvLayout::NCHW);
        assert_eq!(perm, vec![0, 3, 1, 2]);
    }

    #[test]
    fn test_permutation_identity() {
        let perm = ConvLayout::NCHW.permutation_to(ConvLayout::NCHW);
        assert_eq!(perm, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_permutation_ncl_to_nlc() {
        let perm = ConvLayout::NCL.permutation_to(ConvLayout::NLC);
        assert_eq!(perm, vec![0, 2, 1]);
    }

    #[test]
    fn test_permutation_nlc_to_ncl() {
        let perm = ConvLayout::NLC.permutation_to(ConvLayout::NCL);
        assert_eq!(perm, vec![0, 2, 1]);
    }

    /// Applying `src.permutation_to(dst)` to a shape built in `src` must yield
    /// the same logical shape built in `dst` (output axis `i` = input axis `perm[i]`).
    #[test]
    fn test_permutation_transposes_shape() {
        let pairs = [
            (ConvLayout::NCHW, ConvLayout::NHWC),
            (ConvLayout::NHWC, ConvLayout::NCHW),
            (ConvLayout::NCHW, ConvLayout::NCHW),
            (ConvLayout::NHWC, ConvLayout::NHWC),
            (ConvLayout::NCL, ConvLayout::NLC),
            (ConvLayout::NLC, ConvLayout::NCL),
            (ConvLayout::NCL, ConvLayout::NCL),
            (ConvLayout::NLC, ConvLayout::NLC),
        ];
        let extents: [usize; 2] = [5, 7];
        for &(src, dst) in pairs.iter() {
            let spatial = &extents[..src.indices().spatial().len()];
            let shape = src.build_shape(2, 3, spatial);
            let permuted: Vec<usize> = src
                .permutation_to(dst)
                .iter()
                .map(|&axis| shape[axis])
                .collect();
            assert_eq!(
                permuted,
                dst.build_shape(2, 3, spatial),
                "{} -> {}",
                src,
                dst
            );
        }
    }

    #[test]
    fn test_display_conv_layout() {
        assert_eq!(format!("{}", ConvLayout::NCHW), "NCHW");
        assert_eq!(format!("{}", ConvLayout::NHWC), "NHWC");
        assert_eq!(format!("{}", ConvLayout::NCL), "NCL");
        assert_eq!(format!("{}", ConvLayout::NLC), "NLC");
    }

    #[test]
    fn test_display_kernel_layout() {
        assert_eq!(format!("{}", KernelLayout::OIHW), "OIHW");
        assert_eq!(format!("{}", KernelLayout::HWIO), "HWIO");
        assert_eq!(format!("{}", KernelLayout::OIL), "OIL");
        assert_eq!(format!("{}", KernelLayout::LIO), "LIO");
    }

    #[test]
    fn test_conv_dimension_numbers_default() {
        let cdn = ConvDimensionNumbers::default();
        assert_eq!(cdn.input_layout, ConvLayout::NCHW);
        assert_eq!(cdn.kernel_layout, KernelLayout::OIHW);
        assert_eq!(cdn.output_layout, ConvLayout::NCHW);
    }
}