Skip to main content

baracuda_types_derive/
lib.rs

1//! Proc-macros for `baracuda-types`: `#[derive(DeviceRepr)]`.
2//!
3//! `#[derive(KernelArg)]` is deliberately *not* provided: `KernelArg` is
4//! already implemented for `&T where T: DeviceRepr` via a blanket impl, so
5//! deriving `DeviceRepr` is sufficient for a type to be usable as a
6//! kernel argument via `&my_value`.
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::quote;
11use syn::{parse_macro_input, Data, DeriveInput, Fields, Meta};
12
13/// `#[derive(DeviceRepr)]` — implement `baracuda_types::DeviceRepr` for a
14/// `#[repr(C)]` or `#[repr(transparent)]` struct whose fields are all
15/// `DeviceRepr`. Enums and unions are rejected (use a `#[repr(C)]` struct).
16#[proc_macro_derive(DeviceRepr)]
17pub fn derive_device_repr(input: TokenStream) -> TokenStream {
18    let input = parse_macro_input!(input as DeriveInput);
19    match expand_device_repr(input) {
20        Ok(ts) => ts.into(),
21        Err(e) => e.to_compile_error().into(),
22    }
23}
24
25fn expand_device_repr(input: DeriveInput) -> syn::Result<TokenStream2> {
26    let name = &input.ident;
27    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
28
29    ensure_repr_c_or_transparent(&input)?;
30    let field_types = collect_field_types(&input)?;
31
32    // Augment the where-clause so every field type must also be DeviceRepr.
33    // This means if a user forgets to implement the trait on an inner field,
34    // the compile error points here rather than at a remote launch site.
35    let mut where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
36        where_token: Default::default(),
37        predicates: syn::punctuated::Punctuated::new(),
38    });
39    for ty in &field_types {
40        where_clause
41            .predicates
42            .push(syn::parse_quote!(#ty: ::baracuda_types::DeviceRepr));
43    }
44
45    Ok(quote! {
46        // SAFETY: the `#[repr(C)]` / `#[repr(transparent)]` attribute is
47        // enforced by the derive, and every field is required by the
48        // where-clause to implement DeviceRepr.
49        unsafe impl #impl_generics ::baracuda_types::DeviceRepr for #name #ty_generics #where_clause {}
50    })
51}
52
53fn ensure_repr_c_or_transparent(input: &DeriveInput) -> syn::Result<()> {
54    let mut has_required_repr = false;
55    for attr in &input.attrs {
56        if !attr.path().is_ident("repr") {
57            continue;
58        }
59        if let Meta::List(list) = &attr.meta {
60            for tok in list.tokens.clone() {
61                if let proc_macro2::TokenTree::Ident(id) = tok {
62                    let s = id.to_string();
63                    if s == "C" || s == "transparent" {
64                        has_required_repr = true;
65                    }
66                }
67            }
68        }
69    }
70    if has_required_repr {
71        Ok(())
72    } else {
73        Err(syn::Error::new_spanned(
74            &input.ident,
75            "#[derive(DeviceRepr)] requires #[repr(C)] or #[repr(transparent)] on the type",
76        ))
77    }
78}
79
80fn collect_field_types(input: &DeriveInput) -> syn::Result<Vec<syn::Type>> {
81    match &input.data {
82        Data::Struct(data) => match &data.fields {
83            Fields::Named(named) => Ok(named.named.iter().map(|f| f.ty.clone()).collect()),
84            Fields::Unnamed(unnamed) => Ok(unnamed.unnamed.iter().map(|f| f.ty.clone()).collect()),
85            Fields::Unit => Ok(Vec::new()),
86        },
87        Data::Enum(_) => Err(syn::Error::new_spanned(
88            &input.ident,
89            "#[derive(DeviceRepr)] on enums is not supported; use a #[repr(C)] struct instead",
90        )),
91        Data::Union(_) => Err(syn::Error::new_spanned(
92            &input.ident,
93            "#[derive(DeviceRepr)] on unions is not supported",
94        )),
95    }
96}
97
98// ---------------------------------------------------------------------------
99// Unit tests
100// ---------------------------------------------------------------------------
101//
102// Proc-macro crates (`proc-macro = true`) cannot exercise their macro via
103// doc-tests because the crate isn't usable as a normal `extern crate` in
104// the doc-test build environment. The integration test file
105// `tests/derive_device_repr.rs` drives the *positive* path end-to-end via
106// the parent `baracuda-types` crate (dev-dep). Here we cover the
107// *rejection* paths by exercising the helper functions directly against
108// parsed `syn::DeriveInput` trees — no compiler driver round-trip needed.
#[cfg(test)]
mod tests {
    // These tests drive the helper functions directly against parsed
    // `syn::DeriveInput` trees; no compiler round-trip is involved.
    use super::*;
    use syn::parse_quote;

    #[test]
    fn accepts_repr_c_struct() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct S { a: u32, b: f32 }
        };
        ensure_repr_c_or_transparent(&input).expect("repr(C) must be accepted");
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 2);
    }

    #[test]
    fn accepts_repr_transparent_newtype() {
        let input: DeriveInput = parse_quote! {
            #[repr(transparent)]
            struct N(u64);
        };
        ensure_repr_c_or_transparent(&input).expect("repr(transparent) must be accepted");
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 1);
    }

    #[test]
    fn accepts_repr_c_with_align() {
        // Extra repr arguments next to `C` must not affect acceptance.
        let input: DeriveInput = parse_quote! {
            #[repr(C, align(16))]
            struct A { x: f32 }
        };
        ensure_repr_c_or_transparent(&input).expect("repr(C, align(N)) must still pass");
    }

    #[test]
    fn rejects_missing_repr() {
        let input: DeriveInput = parse_quote! {
            struct S { a: u32 }
        };
        let err = ensure_repr_c_or_transparent(&input).expect_err("missing repr must error");
        let msg = err.to_string();
        assert!(
            msg.contains("repr(C)") || msg.contains("repr(transparent)"),
            "error should mention required reprs: {msg}"
        );
    }

    #[test]
    fn rejects_repr_rust() {
        // An explicit `#[repr(Rust)]` spells out the default representation,
        // which is neither `C` nor `transparent` — it must be rejected just
        // like a missing attribute.
        let input: DeriveInput = parse_quote! {
            #[repr(Rust)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input).expect_err("explicit repr(Rust) must error");
    }

    #[test]
    fn rejects_repr_packed_without_c() {
        // `#[repr(packed)]` on its own (no `C`) does not name an accepted
        // layout and must be rejected.
        let input: DeriveInput = parse_quote! {
            #[repr(packed)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input)
            .expect_err("repr(packed) alone (no C) must error");
    }

    #[test]
    fn rejects_repr_int_only() {
        // A bare `#[repr(u32)]` is fine on enums but not what DeviceRepr wants.
        let input: DeriveInput = parse_quote! {
            #[repr(u32)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input).expect_err("repr(u32) alone must error");
    }

    #[test]
    fn rejects_enum_even_with_repr_c() {
        // Even with `#[repr(C)]`, an enum is not a valid DeviceRepr shape
        // (we want it to live in a `#[repr(C)]` struct field). The
        // `collect_field_types` helper enforces this.
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            enum E { A, B }
        };
        // ensure_repr_c_or_transparent permits the attribute…
        ensure_repr_c_or_transparent(&input).expect("repr(C) attr alone passes that check");
        // …but the data-shape check on the enum body rejects.
        let err = collect_field_types(&input).expect_err("enum body must be rejected");
        assert!(err.to_string().contains("enums"), "msg: {}", err);
    }

    #[test]
    fn rejects_union() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            union U { a: u32, b: f32 }
        };
        let err = collect_field_types(&input).expect_err("union body must be rejected");
        assert!(err.to_string().contains("unions"), "msg: {}", err);
    }

    #[test]
    fn unit_struct_collects_zero_fields() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct Empty;
        };
        let fields = collect_field_types(&input).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn tuple_struct_collects_positional_fields() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct T(f32, u32, i16);
        };
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 3);
    }

    #[test]
    fn end_to_end_expand_emits_unsafe_impl() {
        // Smoke-check the top-level expander: produces a TokenStream
        // that mentions both `unsafe impl` and `DeviceRepr`, and
        // includes a where-clause predicate per field type.
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct S { a: u32, b: f32 }
        };
        let ts = expand_device_repr(input).expect("valid input must expand cleanly");
        let s = ts.to_string();
        assert!(s.contains("unsafe impl"), "missing unsafe impl: {s}");
        assert!(s.contains("DeviceRepr"), "missing trait name: {s}");
        // The struct body is not re-emitted, so the field types can only
        // show up through the generated where-clause:
        assert!(s.contains("where"), "missing where-clause: {s}");
        assert!(s.contains("u32"), "missing field type in where-clause: {s}");
        assert!(s.contains("f32"), "missing field type in where-clause: {s}");
    }

    #[test]
    fn end_to_end_expand_rejects_enum() {
        let input: DeriveInput = parse_quote! {
            enum E { A, B }
        };
        // Both checks fire here (no repr AND enum body); either is a
        // legitimate rejection — we only care that the expander errors.
        expand_device_repr(input).expect_err("enum without repr must not expand");
    }
}