1use std::marker::PhantomData;
2
/// Private home of the sealing supertrait.
///
/// `sealed` is not `pub`, so code outside this module cannot name
/// `sealed::Sealed`. Any public trait that lists it as a supertrait therefore
/// cannot be implemented elsewhere.
mod sealed {
    /// Marker supertrait with no items; it only restricts who may implement
    /// the public shape traits.
    pub trait Sealed {}
}
6
/// The rank-0 (scalar) shape, used as the terminator of every `Dim` chain.
///
/// Derives `Default`, equality and hashing in addition to `Debug`/`Clone`/`Copy`.
/// Generic code can then create (`S::default()`), compare and hash shape
/// marker values.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Nil;
9
/// A shape whose leading axis has length `N`, followed by the axes of `Rest`.
///
/// The struct places no bound on `Rest`. The `TensorShape` impl requires
/// `Rest: TensorShape`, which in practice means another `Dim` or `Nil`.
/// The type only records `Rest` through `PhantomData`, so it is zero-sized.
///
/// The field is crate-private. Code outside the crate obtains a value via
/// `Default::default()`, which is available whenever `Rest: Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Dim<const N: usize, Rest>(pub(crate) PhantomData<Rest>);
12
13impl sealed::Sealed for Nil {}
14impl<const N: usize, Rest> sealed::Sealed for Dim<N, Rest> {}
15
16pub trait TensorShape: sealed::Sealed {
17 const SIZE: usize;
18 const RANK: usize;
19
20 fn offset(index: &[usize]) -> usize;
21}
22
23pub trait NonScalarShape: TensorShape {
24 type Subshape: TensorShape;
25 const AXIS_LEN: usize;
26}
27
28impl TensorShape for Nil {
29 const SIZE: usize = 1;
30 const RANK: usize = 0;
31
32 fn offset(index: &[usize]) -> usize {
33 assert!(index.is_empty(), "expected scalar index");
34 0
35 }
36}
37
38impl<const N: usize, Rest> TensorShape for Dim<N, Rest>
39where
40 Rest: TensorShape,
41{
42 const SIZE: usize = N * Rest::SIZE;
43 const RANK: usize = 1 + Rest::RANK;
44
45 fn offset(index: &[usize]) -> usize {
46 assert_eq!(index.len(), Self::RANK, "index rank mismatch");
47 let head = index[0];
48 assert!(head < N, "index out of bounds");
49 head * Rest::SIZE + Rest::offset(&index[1..])
50 }
51}
52
53impl<const N: usize, Rest> NonScalarShape for Dim<N, Rest>
54where
55 Rest: TensorShape,
56{
57 type Subshape = Rest;
58 const AXIS_LEN: usize = N;
59}
60
/// Builds a shape type from a comma-separated list of axis lengths.
///
/// Each length must be an expression usable as a `usize` const generic
/// argument. For example, `shape!()` is `Nil` and `shape!(2, 3)` is
/// `Dim<2, Dim<3, Nil>>`. A trailing comma is accepted after a non-empty list.
///
/// The expansion names `$crate::shape::Nil` and `$crate::shape::Dim`, so it
/// relies on those types living in a `shape` module at the crate root.
#[macro_export]
macro_rules! shape {
    // Rank 0: the scalar shape.
    () => {
        $crate::shape::Nil
    };

    // Rank 1: a single axis terminated by `Nil`.
    ($len:expr $(,)?) => {
        $crate::shape::Dim<{ $len }, $crate::shape::Nil>
    };

    // Rank >= 2: peel off the leading axis and recurse on the remaining ones.
    ($head:expr, $($tail:expr),+ $(,)?) => {
        $crate::shape::Dim<{ $head }, $crate::shape!($($tail),+)>
    };
}