// burn_std/tensor/index_conversion.rs

//! # Common Index Coercions
//!
//! This module contains common index coercions (conversions of integer values
//! into sizes and signed indices) used to implement various indexing operations.

6use super::indexing::IndexWrap;
7use core::fmt::Debug;
8
/// Integer-like values that can be turned into a `usize` size.
///
/// Implementors decide what happens for values with no `usize`
/// representation; the built-in integer implementations in this module panic.
pub trait AsSize: Debug + Copy + Sized {
    /// Returns `self` as a `usize` size.
    fn as_size(self) -> usize;
}
14
15impl<T> AsSize for &T
16where
17    T: AsSize,
18{
19    fn as_size(self) -> usize {
20        (*self).as_size()
21    }
22}
23
24macro_rules! gen_as_size {
25    ($ty:ty) => {
26        impl AsSize for $ty {
27            fn as_size(self) -> usize {
28                self.try_into()
29                    .unwrap_or_else(|_| panic!(
30                        "Unable to convert value to usize: {}_{}",
31                        self,
32                        stringify!($ty)))
33            }
34        }
35    };
36    ($($ty:ty),*) => {$(gen_as_size!($ty);)*};
37}
38
39gen_as_size!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8);
40
41/// Helper trait for implementing indexing with support for negative indices.
42///
43/// # Example
44/// ```rust
45/// use burn_std::AsIndex;
46///
47/// fn example<I: AsIndex, const D: usize>(dim: I, size: usize) -> isize {
48///    let dim: usize = dim.expect_dim_index(D);
49///    unimplemented!()
50/// }
51/// ```
52pub trait AsIndex: Debug + Copy + Sized {
53    /// Converts into an `isize` index.
54    fn as_index(self) -> isize;
55
56    /// Short-form [`IndexWrap::expect_index(idx, size)`].
57    fn expect_elem_index(self, size: usize) -> usize {
58        IndexWrap::expect_elem(self, size)
59    }
60
61    /// Short-form [`IndexWrap::expect_dim(idx, size)`].
62    fn expect_dim_index(self, size: usize) -> usize {
63        IndexWrap::expect_dim(self, size)
64    }
65}
66
67impl<T> AsIndex for &T
68where
69    T: AsIndex,
70{
71    fn as_index(self) -> isize {
72        (*self).as_index()
73    }
74}
75
76macro_rules! gen_as_index {
77    ($ty:ty) => {
78        impl AsIndex for $ty {
79            fn as_index(self) -> isize {
80                self as isize
81            }
82        }
83    };
84    ($($ty:ty),*) => {$(gen_as_index!($ty);)*};
85}
86
87gen_as_index!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8);
88
// Unit tests for the `AsSize` / `AsIndex` integer conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_as_size() {
        assert_eq!(1_usize.as_size(), 1_usize);
        assert_eq!(1_isize.as_size(), 1_usize);
        assert_eq!(1_i64.as_size(), 1_usize);
        assert_eq!(1_u64.as_size(), 1_usize);
        assert_eq!(1_i32.as_size(), 1_usize);
        assert_eq!(1_u32.as_size(), 1_usize);
        assert_eq!(1_i16.as_size(), 1_usize);
        assert_eq!(1_u16.as_size(), 1_usize);
        assert_eq!(1_i8.as_size(), 1_usize);
        assert_eq!(1_u8.as_size(), 1_usize);

        assert_eq!((&1_usize).as_size(), 1_usize);
    }

    #[test]
    #[should_panic(expected = "Unable to convert value to usize: -1_isize")]
    fn test_as_size_isize_panic() {
        (-1_isize).as_size();
    }

    #[test]
    #[should_panic(expected = "Unable to convert value to usize: -1_i64")]
    fn test_as_size_i64_panic() {
        (-1_i64).as_size();
    }

    #[test]
    #[should_panic(expected = "Unable to convert value to usize: -1_i32")]
    fn test_as_size_i32_panic() {
        (-1_i32).as_size();
    }

    #[test]
    #[should_panic(expected = "Unable to convert value to usize: -1_i16")]
    fn test_as_size_i16_panic() {
        (-1_i16).as_size();
    }

    #[test]
    #[should_panic(expected = "Unable to convert value to usize: -1_i8")]
    fn test_as_size_i8_panic() {
        (-1_i8).as_size();
    }

    #[test]
    fn test_as_index() {
        assert_eq!(1_usize.as_index(), 1_isize);
        assert_eq!(1_isize.as_index(), 1_isize);
        assert_eq!(1_i64.as_index(), 1_isize);
        assert_eq!(1_u64.as_index(), 1_isize);
        assert_eq!(1_i32.as_index(), 1_isize);
        assert_eq!(1_u32.as_index(), 1_isize);
        assert_eq!(1_i16.as_index(), 1_isize);
        assert_eq!(1_u16.as_index(), 1_isize);
        assert_eq!(1_i8.as_index(), 1_isize);
        assert_eq!(1_u8.as_index(), 1_isize);

        assert_eq!((&1_usize).as_index(), 1_isize);

        // The literals must be parenthesized: `-1_isize.as_index()` parses as
        // `-(1_isize.as_index())`, which would only re-test a positive input.
        assert_eq!((-1_isize).as_index(), -1_isize);
        assert_eq!((-1_i64).as_index(), -1_isize);
        assert_eq!((-1_i32).as_index(), -1_isize);
        assert_eq!((-1_i16).as_index(), -1_isize);
        assert_eq!((-1_i8).as_index(), -1_isize);

        assert_eq!((&-1_isize).as_index(), -1_isize);
    }
}