1use crate::numeric::{BFloat16, Complex32, Complex64, Half};
11
/// Marker trait for types whose values may be moved to or from a device by
/// copying their raw bytes.
///
/// The trait has no methods; it only serves as a bound for generic code.
/// The `Copy` supertrait means a value has no `Drop` glue and is duplicated
/// bitwise, and `'static` rules out non-`'static` borrows. Those bounds alone
/// still admit pointer-carrying types such as `&'static T`, whose host
/// addresses are presumably not meaningful device-side, which is why
/// implementing this trait is `unsafe`.
///
/// # Safety
///
/// NOTE(review): the consumers of this trait are not visible from this
/// module, so the exact contract should be confirmed against them. Read
/// conservatively, an implementor asserts that:
///
/// * the type's host-side size, alignment and field layout match what the
///   device-side code expects for the corresponding type; and
/// * if bytes are ever copied back from the device, every bit pattern the
///   device can produce is a valid value of the type (this does not hold in
///   general for types with restricted validity, such as `bool` or `char`).
pub unsafe trait DeviceRepr: Copy + 'static {}
27
28macro_rules! impl_device_repr_primitive {
29 ($($t:ty),* $(,)?) => {
30 $(
31 unsafe impl DeviceRepr for $t {}
33 )*
34 };
35}
36
37impl_device_repr_primitive!(
38 u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, bool, char,
39);
40
41unsafe impl DeviceRepr for Half {}
43unsafe impl DeviceRepr for BFloat16 {}
44unsafe impl DeviceRepr for Complex32 {}
45unsafe impl DeviceRepr for Complex64 {}
46
47unsafe impl<T: DeviceRepr, const N: usize> DeviceRepr for [T; N] {}
50
51macro_rules! impl_device_repr_tuple {
52 ($($t:ident),+) => {
53 unsafe impl<$($t: DeviceRepr),+> DeviceRepr for ($($t,)+) {}
59 };
60}
61
62impl_device_repr_tuple!(A);
63impl_device_repr_tuple!(A, B);
64impl_device_repr_tuple!(A, B, C);
65impl_device_repr_tuple!(A, B, C, D);
66impl_device_repr_tuple!(A, B, C, D, E);
67impl_device_repr_tuple!(A, B, C, D, E, F);
68impl_device_repr_tuple!(A, B, C, D, E, F, G);
69impl_device_repr_tuple!(A, B, C, D, E, F, G, H);
70impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I);
71impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J);
72impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J, K);
73impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T: DeviceRepr`. The check happens entirely at
    /// type-check time, so the body is empty.
    fn assert_device_repr<T: DeviceRepr>() {}

    /// Calls `assert_device_repr::<T>()` once for each listed type.
    macro_rules! assert_all_device_repr {
        ($($t:ty),* $(,)?) => {
            $(assert_device_repr::<$t>();)*
        };
    }

    #[test]
    fn primitives_are_device_repr() {
        // The full primitive list, not a sample, so that dropping a type from
        // the impl invocation is caught here.
        assert_all_device_repr!(
            u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, bool, char,
        );
    }

    #[test]
    fn numeric_wrappers_are_device_repr() {
        assert_all_device_repr!(Half, BFloat16, Complex32, Complex64);
    }

    #[test]
    fn arrays_and_tuples_are_device_repr() {
        // Arrays: common shapes, zero length, nesting, and a wrapper element type.
        assert_all_device_repr!([f32; 4], [u8; 256], [f64; 0], [[f32; 4]; 4], [Half; 8]);

        // Tuples: both ends of the implemented arity range (1 and 12) and nesting.
        assert_all_device_repr!(
            (u8,),
            (f32, f32),
            (u32, i32, f64),
            (u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, bool, char),
            ([f32; 3], (u8, Complex32)),
        );
    }
}