use crate::numeric::{BFloat16, Complex32, Complex64, Half};
/// Marker for types whose values may be represented in device memory.
///
/// Implementations in this module cover the built-in integer and float types,
/// `bool`, `char`, the `Half`/`BFloat16`/`Complex32`/`Complex64` types from
/// `crate::numeric`, fixed-size arrays `[T; N]` of implementors, and tuples of
/// one to twelve implementors. The trait has no methods, so its only effect is
/// as a bound on generic code.
///
/// # Safety
///
/// The compiler only enforces the `Copy + 'static` supertraits; every other
/// requirement is a promise made by the implementor. NOTE(review): the code
/// that relies on this trait is not in this file, so the exact contract should
/// be confirmed there and stated here. Presumably an implementor must:
///
/// - have a size and field layout that the device side agrees with (a
///   primitive, or a `#[repr(C)]` / `#[repr(transparent)]` type whose fields
///   are themselves `DeviceRepr`);
/// - not contain pointers or references into host memory. Note that
///   `&'static T` satisfies `Copy + 'static`, so the bounds alone do not rule
///   this out;
/// - tolerate every bit pattern the device may write, if values are ever read
///   back from device memory.
pub unsafe trait DeviceRepr: Copy + 'static {}
// Emits one empty `unsafe impl DeviceRepr` per listed type; a trailing comma
// after the last type is accepted.
macro_rules! impl_device_repr_primitive {
    ($($ty:ty),* $(,)?) => {
        $(
            // SAFETY: see the per-group notes at the invocation below.
            unsafe impl DeviceRepr for $ty {}
        )*
    };
}

// Built-in scalar types. Integers and floats have a fixed size and every bit
// pattern is a valid value.
// NOTE(review): `usize`/`isize` take their width from the *host* target;
// confirm it matches what the device side expects.
// NOTE(review): `bool` (only 0 and 1 are valid) and `char` (only Unicode
// scalar values are valid) have invalid bit patterns. If this trait is ever
// used to read values back from device memory, these two impls can produce UB.
// Confirm against the trait's consumers.
impl_device_repr_primitive!(
    // unsigned integers
    u8, u16, u32, u64, u128, usize,
    // signed integers
    i8, i16, i32, i64, i128, isize,
    // floating point
    f32, f64,
    // types with restricted bit patterns
    bool, char,
);
// Numeric wrapper types imported from `crate::numeric`.
// SAFETY: NOTE(review): soundness depends on how these types are defined in
// `crate::numeric`, which is not visible here. Confirm that each has a fixed,
// device-compatible layout (e.g. `#[repr(C)]` or `#[repr(transparent)]` over
// plain numeric fields) and holds no host pointers.
unsafe impl DeviceRepr for BFloat16 {}
unsafe impl DeviceRepr for Half {}

unsafe impl DeviceRepr for Complex64 {}
unsafe impl DeviceRepr for Complex32 {}
// SAFETY: `[Elem; LEN]` is laid out as `LEN` consecutive `Elem` values with no
// padding between them (element `i` sits at byte offset
// `i * size_of::<Elem>()`). It therefore adds no layout of its own beyond
// what `Elem` already promises.
unsafe impl<Elem, const LEN: usize> DeviceRepr for [Elem; LEN]
where
    Elem: DeviceRepr,
{
}
// Implements `DeviceRepr` for the tuple of all given type parameters, then
// recurses on the tail. One call with N names therefore covers every arity
// from N down to 1.
//
// NOTE(review): tuples use the default `repr(Rust)` layout. Their field order
// and padding are unspecified and need not match a C struct with the same
// fields, and any padding bytes are uninitialized. If device code reads these
// as C structs, or values are copied as raw bytes, these impls may be unsound.
// Confirm against the trait's consumers.
macro_rules! impl_device_repr_tuple {
    () => {};
    ($first:ident $(, $rest:ident)*) => {
        unsafe impl<$first: DeviceRepr $(, $rest: DeviceRepr)*> DeviceRepr
            for ($first, $($rest,)*)
        {
        }
        impl_device_repr_tuple!($($rest),*);
    };
}

// Arities 1 through 12.
impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: this only type-checks when `T: DeviceRepr`.
    fn assert_device_repr<T: DeviceRepr>() {}

    #[test]
    fn primitives_are_device_repr() {
        // Every type passed to `impl_device_repr_primitive!`.
        assert_device_repr::<u8>();
        assert_device_repr::<u16>();
        assert_device_repr::<u32>();
        assert_device_repr::<u64>();
        assert_device_repr::<u128>();
        assert_device_repr::<usize>();
        assert_device_repr::<i8>();
        assert_device_repr::<i16>();
        assert_device_repr::<i32>();
        assert_device_repr::<i64>();
        assert_device_repr::<i128>();
        assert_device_repr::<isize>();
        assert_device_repr::<f32>();
        assert_device_repr::<f64>();
        assert_device_repr::<bool>();
        assert_device_repr::<char>();
    }

    #[test]
    fn numeric_wrappers_are_device_repr() {
        assert_device_repr::<Half>();
        assert_device_repr::<BFloat16>();
        assert_device_repr::<Complex32>();
        assert_device_repr::<Complex64>();
    }

    #[test]
    fn arrays_and_tuples_are_device_repr() {
        assert_device_repr::<[f32; 4]>();
        assert_device_repr::<[u8; 256]>();
        assert_device_repr::<(f32, f32)>();
        assert_device_repr::<(u32, i32, f64)>();
    }

    #[test]
    fn tuple_arity_bounds_are_device_repr() {
        // Smallest and largest arities generated by `impl_device_repr_tuple!`.
        assert_device_repr::<(u8,)>();
        assert_device_repr::<(u8, u8, u8, u8, u8, u8, u8, u8, u8, u8, u8, u8)>();
    }

    #[test]
    fn composites_nest() {
        // Empty arrays, arrays of arrays, arrays of tuples, and tuples holding
        // arrays and wrapper types all resolve through the generic impls.
        assert_device_repr::<[f32; 0]>();
        assert_device_repr::<[[f32; 3]; 3]>();
        assert_device_repr::<[(f32, u32); 8]>();
        assert_device_repr::<([Complex32; 2], Half)>();
    }
}