/// Extent of a two-dimensional tile, given as a row count and a column count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile2 {
    /// Number of rows in the tile.
    pub rows: usize,
    /// Number of columns in the tile.
    pub cols: usize,
}

impl Tile2 {
    /// Builds a tile from its row and column counts.
    pub const fn new(rows: usize, cols: usize) -> Self {
        Tile2 { rows, cols }
    }

    /// Returns the product `rows * cols`.
    pub const fn area(self) -> usize {
        let Tile2 { rows, cols } = self;
        rows * cols
    }
}
/// A pair of `usize` values named `row` and `col`.
///
/// NOTE(review): presumably a position inside a two-dimensional grid; nothing
/// in this part of the file consumes it, so confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Coord2 {
    /// Row component.
    pub row: usize,
    /// Column component.
    pub col: usize,
}
/// Per-axis strides for a two-dimensional layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Strides2 {
    /// Stride associated with the row axis.
    pub row: usize,
    /// Stride associated with the column axis.
    pub col: usize,
}

impl Strides2 {
    /// Row-major preset: the row stride is `cols` and the column stride is 1.
    pub const fn row_major(cols: usize) -> Self {
        let (row, col) = (cols, 1);
        Strides2 { row, col }
    }

    /// Column-major preset: the row stride is 1 and the column stride is `rows`.
    pub const fn col_major(rows: usize) -> Self {
        let (row, col) = (1, rows);
        Strides2 { row, col }
    }
}
/// A possibly nested tuple of extents.
///
/// A value is either a single `usize` ([`ShapeTuple::Leaf`]) or an ordered
/// list of further tuples ([`ShapeTuple::Nested`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeTuple {
    /// A single extent.
    Leaf(usize),
    /// An ordered group of sub-tuples.
    Nested(Vec<ShapeTuple>),
}

impl ShapeTuple {
    /// Wraps a single extent in a leaf node.
    pub fn leaf(n: usize) -> Self {
        ShapeTuple::Leaf(n)
    }

    /// Wraps already-built sub-tuples in a nested node.
    pub fn nested(parts: Vec<ShapeTuple>) -> Self {
        ShapeTuple::Nested(parts)
    }

    /// Builds a one-level nested node whose children are leaves, one per entry
    /// of `dims` and in the same order.
    pub fn flat(dims: &[usize]) -> Self {
        let leaves = dims.iter().copied().map(ShapeTuple::Leaf).collect();
        ShapeTuple::Nested(leaves)
    }

    /// Returns `true` only for the `Leaf` variant.
    pub fn is_leaf(&self) -> bool {
        match self {
            ShapeTuple::Leaf(_) => true,
            ShapeTuple::Nested(_) => false,
        }
    }

    /// Number of top-level entries: 1 for a leaf, the direct child count for a
    /// nested node (children are not expanded).
    pub fn rank(&self) -> usize {
        if let ShapeTuple::Nested(children) = self {
            children.len()
        } else {
            1
        }
    }

    /// Product of every leaf extent in the tree.
    ///
    /// A nested node with no children yields 1, the empty product.
    pub fn product(&self) -> usize {
        match self {
            ShapeTuple::Leaf(extent) => *extent,
            ShapeTuple::Nested(children) => children
                .iter()
                .fold(1, |running, child| running * child.product()),
        }
    }

    /// Collects all leaf extents into a fresh vector, visiting children left
    /// to right and descending into each child before moving to the next.
    pub fn flatten(&self) -> Vec<usize> {
        let mut leaves = Vec::new();
        self.flatten_into(&mut leaves);
        leaves
    }

    /// Appends this subtree's leaf extents to `out` in the same order that
    /// [`ShapeTuple::flatten`] reports them.
    fn flatten_into(&self, out: &mut Vec<usize>) {
        match self {
            ShapeTuple::Leaf(extent) => out.push(*extent),
            ShapeTuple::Nested(children) => {
                for child in children {
                    child.flatten_into(out);
                }
            }
        }
    }

    /// Follows `path` downwards, taking one child index per path entry.
    ///
    /// An empty path resolves to `self`. Returns `None` when the path tries to
    /// step into a leaf or names a child index that does not exist.
    pub fn get(&self, path: &[usize]) -> Option<&ShapeTuple> {
        let mut node = self;
        for &index in path {
            node = match node {
                // A leaf has no children, so any remaining step fails.
                ShapeTuple::Leaf(_) => return None,
                ShapeTuple::Nested(children) => children.get(index)?,
            };
        }
        Some(node)
    }
}
/// Size descriptor made of three counts: `rows`, `trailing`, and `total`.
///
/// NOTE(review): the names suggest a ragged (variable-length-row) layout with
/// `total` entries overall and `trailing` elements per entry; only the
/// arithmetic below is visible here, so confirm the intended meaning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ragged {
    /// Row count; determines the size reported by `offsets_elements`.
    pub rows: usize,
    /// Multiplier applied to `total` in `data_elements`.
    pub trailing: usize,
    /// Count multiplied by `trailing` in `data_elements`.
    pub total: usize,
}

impl Ragged {
    /// Stores the three counts unchanged.
    pub const fn new(rows: usize, trailing: usize, total: usize) -> Self {
        Ragged {
            rows,
            trailing,
            total,
        }
    }

    /// Returns `total * trailing`.
    pub const fn data_elements(self) -> usize {
        let Ragged {
            total, trailing, ..
        } = self;
        total * trailing
    }

    /// Returns `rows + 1`.
    pub const fn offsets_elements(self) -> usize {
        let row_count = self.rows;
        row_count + 1
    }
}
/// Extent of a three-dimensional tile: a batch count plus row and column counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile3 {
    /// Number of batch entries.
    pub batch: usize,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}

impl Tile3 {
    /// Builds a tile from its batch, row and column counts.
    pub const fn new(batch: usize, rows: usize, cols: usize) -> Self {
        Tile3 { batch, rows, cols }
    }

    /// Returns `batch * rows * cols`, i.e. the product of all three counts
    /// despite the two-dimensional sounding name.
    pub const fn area(self) -> usize {
        let Tile3 { batch, rows, cols } = self;
        // Multiply left to right, batch first.
        let per_plane = batch * rows;
        per_plane * cols
    }
}
/// Per-axis strides for a three-dimensional (batch, row, column) layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Strides3 {
    /// Stride associated with the batch axis.
    pub batch: usize,
    /// Stride associated with the row axis.
    pub row: usize,
    /// Stride associated with the column axis.
    pub col: usize,
}

impl Strides3 {
    /// Row-major preset: the batch stride is `rows * cols`, the row stride is
    /// `cols`, and the column stride is 1.
    pub const fn row_major(rows: usize, cols: usize) -> Self {
        let batch = rows * cols;
        let row = cols;
        Strides3 { batch, row, col: 1 }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A 3x4 tile reports twelve elements.
    #[test]
    fn tile2_area() {
        let tile = Tile2::new(3, 4);
        assert_eq!(tile.area(), 12);
    }

    // Both two-dimensional stride presets place the unit stride on the
    // expected axis.
    #[test]
    fn strides2_presets() {
        let expected_row_major = Strides2 { row: 8, col: 1 };
        let expected_col_major = Strides2 { row: 1, col: 8 };
        assert_eq!(Strides2::row_major(8), expected_row_major);
        assert_eq!(Strides2::col_major(8), expected_col_major);
    }

    // The three-dimensional row-major preset uses rows * cols for the batch
    // stride.
    #[test]
    fn strides3_row_major() {
        let expected = Strides3 {
            batch: 12,
            row: 4,
            col: 1,
        };
        assert_eq!(Strides3::row_major(3, 4), expected);
    }

    // A leaf flattens to itself and is its own product.
    #[test]
    fn tuple_leaf_constructors() {
        let leaf = ShapeTuple::leaf(8);
        assert_eq!(leaf.flatten(), vec![8]);
        assert_eq!(leaf.product(), 8);
        assert!(leaf.is_leaf());
    }

    // `flat` produces one leaf per dimension.
    #[test]
    fn tuple_flat_constructor() {
        let shape = ShapeTuple::flat(&[2, 3, 4]);
        assert_eq!(shape.flatten(), vec![2, 3, 4]);
        assert_eq!(shape.product(), 24);
        assert_eq!(shape.rank(), 3);
    }

    // Nested groups multiply through and flatten depth-first, while rank only
    // counts the top level.
    #[test]
    fn tuple_nested_product_and_flatten() {
        let first_group = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15)]);
        let second_group = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let shape = ShapeTuple::nested(vec![first_group, second_group]);
        assert_eq!(shape.product(), 8 * 15 * 12 * 64);
        assert_eq!(shape.flatten(), vec![8, 15, 12, 64]);
        assert_eq!(shape.rank(), 2);
    }

    // Paths resolve through nested children; an out-of-range index yields None.
    #[test]
    fn tuple_get_resolves_path() {
        let tail = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let shape = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15), tail]);
        assert_eq!(shape.get(&[0]), Some(&ShapeTuple::Leaf(8)));
        assert_eq!(shape.get(&[2, 1]), Some(&ShapeTuple::Leaf(64)));
        assert_eq!(shape.get(&[2, 99]), None);
    }

    // Data size is total * trailing; offsets size is rows + 1.
    #[test]
    fn ragged_element_counts() {
        let ragged = Ragged::new(4, 8, 30);
        assert_eq!(ragged.data_elements(), 240);
        assert_eq!(ragged.offsets_elements(), 5);
    }
}