Skip to main content

baracuda_types/
device_repr.rs

//! The `DeviceRepr` marker trait: types whose memory layout is stable and
//! ABI-compatible, so their values can be placed in GPU-visible memory.
//!
//! `DeviceRepr` is implemented for the primitive numeric types plus `bool`
//! and `char`, for arrays of `DeviceRepr`, for tuples of `DeviceRepr` up to
//! arity 12, and for the numeric helpers in [`crate::numeric`]. Rust does
//! not guarantee a C-compatible layout for tuples, so prefer a `#[repr(C)]`
//! struct or an array for data a kernel reads field by field. User-defined
//! `#[repr(C)]` structs can derive the trait via
//! `#[derive(baracuda_types::DeviceRepr)]` (requires the `derive` feature).
9
10use crate::numeric::{BFloat16, Complex32, Complex64, Half};
11
/// Marker for types whose in-memory representation may be handed to a CUDA
/// kernel or stored in a buffer that the GPU can see.
///
/// The `Copy + 'static` bound means any `T: DeviceRepr` can be duplicated
/// bitwise and contains no non-`'static` borrows.
///
/// # Safety
///
/// An `unsafe impl` of this trait promises all of the following:
///
/// 1. Size and alignment are fixed and known at compile time.
/// 2. The type is [`Copy`]; it holds no references, no pointers into
///    host-only memory, and no non-trivial destructor.
/// 3. Every bit pattern that Rust may read back after the value has been
///    copied out of device memory is a valid value of the type — or else
///    the code handling the type writes it before reading it and tolerates
///    transient garbage in between.
/// 4. The layout is `#[repr(C)]`, `#[repr(transparent)]`, or a primitive.
pub unsafe trait DeviceRepr
where
    Self: Copy + 'static,
{
}
27
// Emits one empty `unsafe impl DeviceRepr` per listed type. Takes a
// comma-separated list of types; an empty list and a trailing comma are
// both accepted. Meant for built-in primitive types only.
macro_rules! impl_device_repr_primitive {
    ($($prim:ty),* $(,)?) => {$(
        // SAFETY: a primitive is `Copy`, `Sized`, and has a fixed ABI layout.
        unsafe impl DeviceRepr for $prim {}
    )*};
}
36
37impl_device_repr_primitive!(
38    u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, bool, char,
39);
40
41// SAFETY: wrappers in `crate::numeric` are `#[repr(transparent)]` / `#[repr(C)]`.
42unsafe impl DeviceRepr for Half {}
43unsafe impl DeviceRepr for BFloat16 {}
44unsafe impl DeviceRepr for Complex32 {}
45unsafe impl DeviceRepr for Complex64 {}
46
47// SAFETY: a fixed-size array of `DeviceRepr` elements has a well-defined
48// C-compatible layout (contiguous, same alignment as T).
49unsafe impl<T: DeviceRepr, const N: usize> DeviceRepr for [T; N] {}
50
// Implements `DeviceRepr` for a tuple whose element types are all
// `DeviceRepr`. The macro takes the tuple's type-parameter names and is
// invoked once per arity.
macro_rules! impl_device_repr_tuple {
    ($($t:ident),+) => {
        // SAFETY (partial): every element type is `DeviceRepr`, so the tuple
        // is `Copy + 'static`, has a fixed size, and holds no host-only
        // pointers (invariants 1-3).
        //
        // Invariant 4 is NOT met. Tuples use the default `repr(Rust)` layout,
        // so field order and padding are unspecified and may differ from a
        // C struct with the same fields. Nothing here restricts the element
        // types, so heterogeneous tuples such as `(u32, i32, f64)` are
        // accepted too. Device code must not assume any particular field
        // layout for a tuple. Prefer a `#[repr(C)]` struct or `[T; N]` for
        // data a kernel reads field by field.
        //
        // These impls exist so that `(T,)` and other simple tuples stay
        // ergonomic in tests.
        // NOTE(review): consider deprecating these impls in a breaking
        // release, since they do not satisfy the trait's own contract.
        unsafe impl<$($t: DeviceRepr),+> DeviceRepr for ($($t,)+) {}
    };
}
61
62impl_device_repr_tuple!(A);
63impl_device_repr_tuple!(A, B);
64impl_device_repr_tuple!(A, B, C);
65impl_device_repr_tuple!(A, B, C, D);
66impl_device_repr_tuple!(A, B, C, D, E);
67impl_device_repr_tuple!(A, B, C, D, E, F);
68impl_device_repr_tuple!(A, B, C, D, E, F, G);
69impl_device_repr_tuple!(A, B, C, D, E, F, G, H);
70impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I);
71impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J);
72impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J, K);
73impl_device_repr_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T: DeviceRepr`; the check happens at type-check time.
    fn assert_device_repr<T: DeviceRepr>() {}

    #[test]
    fn primitives_are_device_repr() {
        // Every type passed to `impl_device_repr_primitive!` in this module.
        assert_device_repr::<u8>();
        assert_device_repr::<u16>();
        assert_device_repr::<u32>();
        assert_device_repr::<u64>();
        assert_device_repr::<u128>();
        assert_device_repr::<usize>();
        assert_device_repr::<i8>();
        assert_device_repr::<i16>();
        assert_device_repr::<i32>();
        assert_device_repr::<i64>();
        assert_device_repr::<i128>();
        assert_device_repr::<isize>();
        assert_device_repr::<f32>();
        assert_device_repr::<f64>();
        assert_device_repr::<bool>();
        assert_device_repr::<char>();
    }

    #[test]
    fn numeric_wrappers_are_device_repr() {
        assert_device_repr::<Half>();
        assert_device_repr::<BFloat16>();
        assert_device_repr::<Complex32>();
        assert_device_repr::<Complex64>();
    }

    #[test]
    fn arrays_and_tuples_are_device_repr() {
        assert_device_repr::<[f32; 4]>();
        assert_device_repr::<[u8; 256]>();
        assert_device_repr::<(f32, f32)>();
        assert_device_repr::<(u32, i32, f64)>();
    }

    #[test]
    fn tuple_arity_bounds_are_device_repr() {
        // The module docs promise tuples up to arity 12; pin both ends.
        assert_device_repr::<(u8,)>();
        assert_device_repr::<(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, bool, char)>();
    }

    #[test]
    fn nested_aggregates_are_device_repr() {
        // Zero-length arrays, nested arrays, and aggregates of wrappers.
        assert_device_repr::<[u8; 0]>();
        assert_device_repr::<[[f32; 4]; 4]>();
        assert_device_repr::<[Complex32; 8]>();
        assert_device_repr::<([Half; 2], BFloat16)>();
    }
}