/// Logical extent of a transform: one, two or three dimensions.
///
/// Each extent is stored as a `u32`. Derived element counts are computed in
/// `u64`, which holds every 1-D and 2-D count exactly. A 3-D count can exceed
/// `u64`; see [`Shape::checked_elements`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Shape {
    /// One-dimensional shape with a single extent.
    D1(u32),
    /// Two-dimensional shape; both extents, in storage order.
    D2([u32; 2]),
    /// Three-dimensional shape; all three extents, in storage order.
    D3([u32; 3]),
}

impl Shape {
    /// Returns the extents as a slice, in the order they are stored in the
    /// variant. The slice is never empty, and its length equals [`Shape::rank`].
    pub fn dims(&self) -> &[u32] {
        match self {
            Shape::D1(n) => std::slice::from_ref(n),
            Shape::D2(d) => &d[..],
            Shape::D3(d) => &d[..],
        }
    }

    /// Total number of elements: the product of all extents.
    ///
    /// # Panics
    ///
    /// Panics if the product does not fit in `u64`. This can only happen for
    /// [`Shape::D3`]. Use [`Shape::checked_elements`] to handle that case.
    pub fn elements(&self) -> u64 {
        // Checked on purpose: a wrapped product would be a plausible-looking
        // but far too small count. Fail loudly in every build profile instead.
        self.checked_elements()
            .expect("Shape::elements: element count overflows u64")
    }

    /// Like [`Shape::elements`], but returns `None` instead of panicking when
    /// the product does not fit in `u64`.
    pub fn checked_elements(&self) -> Option<u64> {
        self.dims()
            .iter()
            .try_fold(1u64, |acc, &d| acc.checked_mul(u64::from(d)))
    }

    /// Number of elements after the last stored extent `n` is replaced by
    /// `n / 2 + 1` (integer division). All other extents are kept.
    ///
    /// This matches the non-redundant half-spectrum size of a real-to-complex
    /// transform (FFTW's r2c layout), which is presumably what callers use it for.
    ///
    /// # Panics
    ///
    /// Panics if the product does not fit in `u64`. This can only happen for
    /// [`Shape::D3`]. Use [`Shape::checked_complex_half_elements`] to handle it.
    pub fn complex_half_elements(&self) -> u64 {
        self.checked_complex_half_elements()
            .expect("Shape::complex_half_elements: element count overflows u64")
    }

    /// Like [`Shape::complex_half_elements`], but returns `None` instead of
    /// panicking when the product does not fit in `u64`.
    pub fn checked_complex_half_elements(&self) -> Option<u64> {
        // `dims()` always has at least one entry, so `split_last` never fails.
        let (&last, outer) = self.dims().split_last()?;
        outer
            .iter()
            .try_fold(u64::from(last) / 2 + 1, |acc, &d| {
                acc.checked_mul(u64::from(d))
            })
    }

    /// Number of dimensions: 1, 2 or 3.
    pub fn rank(&self) -> u8 {
        match self {
            Shape::D1(_) => 1,
            Shape::D2(_) => 2,
            Shape::D3(_) => 3,
        }
    }
}
/// Direction of a transform.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Forward transform; encoded as `0` by [`Direction::as_int`].
    Forward = 0,
    /// Inverse transform; encoded as `1` by [`Direction::as_int`].
    Inverse = 1,
}

impl Direction {
    /// Integer code for this direction: `0` for `Forward`, `1` for `Inverse`.
    ///
    /// NOTE(review): these are not the FFTW/cuFFT sign constants (-1 / +1).
    /// Presumably the consumer expects exactly 0/1; confirm before changing.
    pub fn as_int(self) -> i32 {
        // The discriminants are pinned explicitly above, so the cast yields
        // the intended code regardless of future variant reordering.
        self as i32
    }
}
/// Parameters describing a transform plan: the shape of one transform plus
/// batching and normalization options.
///
/// Derives `PartialEq`/`Eq`, matching `Shape` and `Direction`, so two
/// descriptors can be compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PlanDesc {
    /// Extents of a single transform.
    pub shape: Shape,
    /// Batch count. Presumably this is the number of transforms of `shape`
    /// processed together; confirm with the plan executor.
    pub batch: u32,
    /// Whether the result should be normalized. The scaling actually applied
    /// is decided by whatever consumes this descriptor, not here.
    pub normalize: bool,
}

impl Default for PlanDesc {
    /// A 1-D shape of length 1, a batch of 1, and normalization disabled.
    fn default() -> Self {
        Self {
            shape: Shape::D1(1),
            batch: 1,
            normalize: false,
        }
    }
}