Skip to main content

tml_utils/
shape.rs

1use std::marker::PhantomData;
2
mod sealed {
    /// Private supertrait that closes [`TensorShape`](super::TensorShape) to
    /// implementations outside the enclosing module.
    ///
    /// The trait itself is `pub` so it may appear in public trait bounds, but
    /// it lives in a private module, so code outside cannot name it and
    /// therefore cannot implement it.
    pub trait Sealed {}
}

/// Terminal element of a type-level shape list: the shape of a scalar.
#[derive(Debug, Clone, Copy)]
pub struct Nil;

impl sealed::Sealed for Nil {}

/// One axis of length `N` followed by the remaining axes `Rest`.
///
/// Shapes are written as nested types, e.g. `Dim<2, Dim<3, Nil>>` for a
/// 2 x 3 shape. The value carries no runtime data; the `PhantomData` field
/// only records `Rest` in the type.
#[derive(Debug, Clone, Copy)]
pub struct Dim<const N: usize, Rest>(pub(crate) PhantomData<Rest>);

impl<const N: usize, Rest> sealed::Sealed for Dim<N, Rest> {}
15
16pub trait TensorShape: sealed::Sealed {
17    const SIZE: usize;
18    const RANK: usize;
19
20    fn offset(index: &[usize]) -> usize;
21}
22
23pub trait NonScalarShape: TensorShape {
24    type Subshape: TensorShape;
25    const AXIS_LEN: usize;
26}
27
28impl TensorShape for Nil {
29    const SIZE: usize = 1;
30    const RANK: usize = 0;
31
32    fn offset(index: &[usize]) -> usize {
33        assert!(index.is_empty(), "expected scalar index");
34        0
35    }
36}
37
38impl<const N: usize, Rest> TensorShape for Dim<N, Rest>
39where
40    Rest: TensorShape,
41{
42    const SIZE: usize = N * Rest::SIZE;
43    const RANK: usize = 1 + Rest::RANK;
44
45    fn offset(index: &[usize]) -> usize {
46        assert_eq!(index.len(), Self::RANK, "index rank mismatch");
47        let head = index[0];
48        assert!(head < N, "index out of bounds");
49        head * Rest::SIZE + Rest::offset(&index[1..])
50    }
51}
52
53impl<const N: usize, Rest> NonScalarShape for Dim<N, Rest>
54where
55    Rest: TensorShape,
56{
57    type Subshape = Rest;
58    const AXIS_LEN: usize = N;
59}
60
/// Builds a shape *type* from a list of axis lengths, outermost axis first.
///
/// `shape!()` is `Nil` (a scalar), and `shape!(2, 3)` is
/// `Dim<{ 2 }, Dim<{ 3 }, Nil>>`. Each length is any expression usable as a
/// const generic argument, and a trailing comma is accepted. The macro
/// expands to a type, so invoke it in type position
/// (e.g. `type Matrix = shape!(2, 3);`).
#[macro_export]
macro_rules! shape {
    // Two or more axes: emit the outer axis and recurse on the inner ones.
    ($outer:expr, $($inner:expr),+ $(,)?) => {
        $crate::shape::Dim<{ $outer }, $crate::shape!($($inner),+)>
    };

    // Exactly one axis: it wraps the scalar shape directly.
    ($len:expr $(,)?) => {
        $crate::shape::Dim<{ $len }, $crate::shape::Nil>
    };

    // No axes: the scalar shape.
    () => {
        $crate::shape::Nil
    };
}