use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Meta};
/// Entry point for `#[derive(DeviceRepr)]`.
///
/// Parses the annotated item and delegates to `expand_device_repr`. Any validation
/// error is converted into compile-error tokens via `syn::Error::to_compile_error`, so
/// the user sees a spanned diagnostic instead of a macro panic.
#[proc_macro_derive(DeviceRepr)]
pub fn derive_device_repr(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    expand_device_repr(ast)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn expand_device_repr(input: DeriveInput) -> syn::Result<TokenStream2> {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
ensure_repr_c_or_transparent(&input)?;
let field_types = collect_field_types(&input)?;
let mut where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
where_token: Default::default(),
predicates: syn::punctuated::Punctuated::new(),
});
for ty in &field_types {
where_clause
.predicates
.push(syn::parse_quote!(#ty: ::baracuda_types::DeviceRepr));
}
Ok(quote! {
unsafe impl #impl_generics ::baracuda_types::DeviceRepr for #name #ty_generics #where_clause {}
})
}
fn ensure_repr_c_or_transparent(input: &DeriveInput) -> syn::Result<()> {
let mut has_required_repr = false;
for attr in &input.attrs {
if !attr.path().is_ident("repr") {
continue;
}
if let Meta::List(list) = &attr.meta {
for tok in list.tokens.clone() {
if let proc_macro2::TokenTree::Ident(id) = tok {
let s = id.to_string();
if s == "C" || s == "transparent" {
has_required_repr = true;
}
}
}
}
}
if has_required_repr {
Ok(())
} else {
Err(syn::Error::new_spanned(
&input.ident,
"#[derive(DeviceRepr)] requires #[repr(C)] or #[repr(transparent)] on the type",
))
}
}
fn collect_field_types(input: &DeriveInput) -> syn::Result<Vec<syn::Type>> {
match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(named) => Ok(named.named.iter().map(|f| f.ty.clone()).collect()),
Fields::Unnamed(unnamed) => Ok(unnamed.unnamed.iter().map(|f| f.ty.clone()).collect()),
Fields::Unit => Ok(Vec::new()),
},
Data::Enum(_) => Err(syn::Error::new_spanned(
&input.ident,
"#[derive(DeviceRepr)] on enums is not supported; use a #[repr(C)] struct instead",
)),
Data::Union(_) => Err(syn::Error::new_spanned(
&input.ident,
"#[derive(DeviceRepr)] on unions is not supported",
)),
}
}
#[cfg(test)]
mod tests {
    // These tests call the helpers directly on `parse_quote!`-built `DeriveInput`s,
    // so they run as ordinary unit tests without a compiler-hosted proc-macro context.
    use super::*;
    use syn::parse_quote;

    #[test]
    fn accepts_repr_c_struct() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct S { a: u32, b: f32 }
        };
        ensure_repr_c_or_transparent(&input).expect("repr(C) must be accepted");
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 2);
    }

    #[test]
    fn accepts_repr_transparent_newtype() {
        let input: DeriveInput = parse_quote! {
            #[repr(transparent)]
            struct N(u64);
        };
        ensure_repr_c_or_transparent(&input).expect("repr(transparent) must be accepted");
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 1);
    }

    #[test]
    fn accepts_repr_c_with_align() {
        let input: DeriveInput = parse_quote! {
            #[repr(C, align(16))]
            struct A { x: f32 }
        };
        ensure_repr_c_or_transparent(&input).expect("repr(C, align(N)) must still pass");
    }

    #[test]
    fn accepts_repr_split_across_attributes() {
        let input: DeriveInput = parse_quote! {
            #[repr(align(8))]
            #[repr(C)]
            struct S { a: u64 }
        };
        ensure_repr_c_or_transparent(&input)
            .expect("repr(C) in a later repr attribute must be found");
    }

    #[test]
    fn rejects_missing_repr() {
        let input: DeriveInput = parse_quote! {
            struct S { a: u32 }
        };
        let err = ensure_repr_c_or_transparent(&input).expect_err("missing repr must error");
        let msg = err.to_string();
        assert!(
            msg.contains("repr(C)") || msg.contains("repr(transparent)"),
            "error should mention required reprs: {msg}"
        );
    }

    #[test]
    fn rejects_repr_rust() {
        let input: DeriveInput = parse_quote! {
            #[repr(packed)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input)
            .expect_err("repr(packed) alone (no C) must error");
    }

    #[test]
    fn rejects_repr_int_only() {
        let input: DeriveInput = parse_quote! {
            #[repr(u32)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input).expect_err("repr(u32) alone must error");
    }

    #[test]
    fn rejects_repr_align_only() {
        let input: DeriveInput = parse_quote! {
            #[repr(align(16))]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input).expect_err("repr(align(N)) alone must error");
    }

    #[test]
    fn ignores_non_repr_attributes_mentioning_c() {
        let input: DeriveInput = parse_quote! {
            #[my_attr(C, transparent)]
            struct S { a: u32 }
        };
        ensure_repr_c_or_transparent(&input)
            .expect_err("only #[repr(...)] attributes may satisfy the check");
    }

    #[test]
    fn rejects_enum_even_with_repr_c() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            enum E { A, B }
        };
        ensure_repr_c_or_transparent(&input).expect("repr(C) attr alone passes that check");
        let err = collect_field_types(&input).expect_err("enum body must be rejected");
        assert!(err.to_string().contains("enums"), "msg: {}", err);
    }

    #[test]
    fn rejects_union() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            union U { a: u32, b: f32 }
        };
        let err = collect_field_types(&input).expect_err("union body must be rejected");
        assert!(err.to_string().contains("unions"), "msg: {}", err);
    }

    #[test]
    fn unit_struct_collects_zero_fields() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct Empty;
        };
        let fields = collect_field_types(&input).unwrap();
        assert!(fields.is_empty());
    }

    #[test]
    fn tuple_struct_collects_positional_fields() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct T(f32, u32, i16);
        };
        let fields = collect_field_types(&input).unwrap();
        assert_eq!(fields.len(), 3);
    }

    #[test]
    fn end_to_end_expand_emits_unsafe_impl() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct S { a: u32, b: f32 }
        };
        let ts = expand_device_repr(input).expect("valid input must expand cleanly");
        let s = ts.to_string();
        assert!(s.contains("unsafe impl"), "missing unsafe impl: {s}");
        assert!(s.contains("DeviceRepr"), "missing trait name: {s}");
        assert!(s.contains("u32"), "missing field type in where-clause: {s}");
        assert!(s.contains("f32"), "missing field type in where-clause: {s}");
    }

    #[test]
    fn end_to_end_expand_merges_existing_where_clause() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct G<T> where T: Copy { a: T, b: u32 }
        };
        let s = expand_device_repr(input)
            .expect("generic struct must expand")
            .to_string();
        assert!(s.contains("Copy"), "existing where-clause bound was dropped: {s}");
        // One occurrence for the implemented trait, plus one bound per field.
        assert_eq!(s.matches("DeviceRepr").count(), 3, "unexpected bounds: {s}");
    }

    #[test]
    fn end_to_end_unit_struct_emits_no_where_clause() {
        let input: DeriveInput = parse_quote! {
            #[repr(C)]
            struct Empty;
        };
        let s = expand_device_repr(input)
            .expect("unit struct must expand")
            .to_string();
        assert!(!s.contains("where"), "unexpected where-clause: {s}");
        assert_eq!(s.matches("DeviceRepr").count(), 1, "unexpected bounds: {s}");
    }

    #[test]
    fn end_to_end_expand_rejects_enum() {
        let input: DeriveInput = parse_quote! {
            enum E { A, B }
        };
        expand_device_repr(input).expect_err("enum without repr must not expand");
    }
}