use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, Data, DeriveInput, Fields, Result};
pub(crate) fn derive(input: DeriveInput) -> Result<TokenStream> {
let name = &input.ident;
let fields = validate_struct(&input)?;
let repr = validate_repr(&input, fields)?;
let channels = generate_channels_constant(fields, &repr)?;
let expanded = quote! {
unsafe impl ::fovea::pixel::PlainChannel for #name {}
unsafe impl ::fovea::pixel::PlainPixel for #name {
const CHANNELS: &'static [usize] = #channels;
}
const _: () = { let _ = <#name as ::fovea::pixel::PlainChannel>::_ASSERT_SIZE; };
const _: () = { let _ = <#name as ::fovea::pixel::PlainPixel>::_ASSERT_CHANNELS; };
};
Ok(expanded)
}
/// Builds the token expression used as `PlainPixel::CHANNELS`.
///
/// The result is a `&[...]` slice literal holding one
/// `<FieldType as ::fovea::pixel::PlainChannel>::SIZE` entry per channel:
/// - for `Repr::C`, one entry per field, in declaration order;
/// - for `Repr::Transparent`, an entry for the first field only.
///
/// Callers are expected to have run `validate_repr`, which requires exactly one field for
/// `Repr::Transparent`. This function never returns `Err`; the `Result` return type only
/// lets callers use `?`.
///
/// # Panics
/// Panics with "Expected exactly one field" if `repr` is `Repr::Transparent` and `fields`
/// is empty.
fn generate_channels_constant(fields: &Fields, repr: &Repr) -> Result<TokenStream> {
    // Pick which fields contribute a channel; a single-element repetition below renders
    // exactly like a hand-written one-element slice.
    let channel_fields: Vec<_> = match repr {
        Repr::C => fields.iter().collect(),
        Repr::Transparent => vec![fields.iter().next().expect("Expected exactly one field")],
    };
    let sizes = channel_fields.into_iter().map(|field| {
        let field_ty = &field.ty;
        quote! { <#field_ty as ::fovea::pixel::PlainChannel>::SIZE }
    });
    Ok(quote! { &[#(#sizes),*] })
}
fn validate_struct(input: &DeriveInput) -> Result<&Fields> {
match &input.data {
Data::Struct(data) => Ok(&data.fields),
Data::Enum(_) => Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel can only be derived for structs, not enums",
)),
Data::Union(_) => Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel can only be derived for structs, not unions",
)),
}
}
/// Layout of a struct deriving `PlainPixel`, as selected by its `#[repr(...)]` attribute
/// and accepted by `validate_repr`.
///
/// Derives `Debug`/`PartialEq` so results such as `Result<Repr>` can be compared and
/// unwrapped with the standard helpers (`assert_eq!`, `unwrap_err`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Repr {
    /// `#[repr(C)]`: every field contributes a channel, in declaration order.
    C,
    /// `#[repr(transparent)]`: the struct has exactly one field, which is its only channel.
    Transparent,
}
fn validate_repr(input: &DeriveInput, fields: &Fields) -> Result<Repr> {
let has_repr_c = has_repr_c(&input.attrs);
let has_repr_transparent = has_repr_transparent(&input.attrs);
match (has_repr_c, has_repr_transparent) {
(true, false) => {
if fields.is_empty() {
Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel requires at least one field",
))
} else {
Ok(Repr::C)
}
}
(false, true) => {
if fields.len() != 1 {
Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel requires exactly one field when #[repr(transparent)] is used",
))
} else {
Ok(Repr::Transparent)
}
}
(true, true) => Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel cannot be derived for structs with both #[repr(C)] and #[repr(transparent)] attributes",
)),
(false, false) => Err(syn::Error::new_spanned(
&input.ident,
"PlainPixel requires #[repr(C)] or #[repr(transparent)] attribute",
)),
}
}
/// Returns `true` if any `#[repr(...)]` attribute in `attrs` lists the bare option `C`
/// (e.g. `#[repr(C)]` or `#[repr(C, align(8))]`).
fn has_repr_c(attrs: &[Attribute]) -> bool {
    repr_lists_option(attrs, "C")
}

/// Returns `true` if any `#[repr(...)]` attribute in `attrs` lists the bare option
/// `transparent`.
fn has_repr_transparent(attrs: &[Attribute]) -> bool {
    repr_lists_option(attrs, "transparent")
}

/// Returns `true` if some `#[repr(...)]` attribute contains `option` as a bare path item.
///
/// The attribute's argument list is parsed as comma-separated `syn::Meta` items, and each
/// item is compared by exact identifier. The previous substring search over the
/// stringified tokens matched any token text containing `option`. Items with arguments,
/// such as `align(8)` or `packed(2)`, are parsed but never match.
///
/// `#[repr]` attributes whose arguments do not parse as a meta list are skipped. rustc
/// itself rejects malformed `repr` attributes.
fn repr_lists_option(attrs: &[Attribute], option: &str) -> bool {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("repr"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .flatten()
        .any(|item| matches!(&item, syn::Meta::Path(path) if path.is_ident(option)))
}
/// Unit tests for the `PlainPixel` derive.
///
/// The tests cover struct-shape validation, `#[repr]` detection and field-count rules,
/// and `CHANNELS` token generation. They also exercise the full `derive` expansion,
/// including its rejection paths. Generated code is checked by substring on its
/// `to_string()` form, not compiled.
#[cfg(test)]
mod tests {
use super::*;
use syn::DeriveInput;
// --- validate_struct: accepted and rejected data shapes ---
#[test]
fn test_validate_struct_with_named_fields() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct TestPixel {
r: u8,
g: u8,
b: u8,
}
};
assert!(validate_struct(&input).is_ok());
}
#[test]
fn test_validate_struct_with_tuple_fields() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct TestPixel(u8, u8, u8);
};
assert!(validate_struct(&input).is_ok());
}
#[test]
fn test_validate_struct_rejects_enum() {
let input: DeriveInput = syn::parse_quote! {
enum BadPixel {
Red,
}
};
let result = validate_struct(&input);
assert!(result.is_err());
assert!(result.err().unwrap().to_string().contains("not enums"));
}
#[test]
fn test_validate_struct_rejects_union() {
let input: DeriveInput = syn::parse_quote! {
union BadPixel {
a: u8,
b: u16,
}
};
let result = validate_struct(&input);
assert!(result.is_err());
assert!(result.err().unwrap().to_string().contains("not unions"));
}
// --- validate_repr: layout selection and field-count requirements ---
#[test]
fn test_validate_repr_c() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct TestPixel {
r: u8,
}
};
let fields = validate_struct(&input).unwrap();
assert!(matches!(validate_repr(&input, fields), Ok(Repr::C)));
}
#[test]
fn test_validate_repr_transparent() {
let input: DeriveInput = syn::parse_quote! {
#[repr(transparent)]
struct TestPixel {
value: u8,
}
};
let fields = validate_struct(&input).unwrap();
assert!(matches!(
validate_repr(&input, fields),
Ok(Repr::Transparent)
));
}
#[test]
fn test_validate_repr_missing() {
let input: DeriveInput = syn::parse_quote! {
struct TestPixel {
r: u8,
}
};
let fields = validate_struct(&input).unwrap();
let result = validate_repr(&input, fields);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("requires #[repr(C)]")
);
}
#[test]
fn test_validate_repr_c_no_fields() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct TestPixel {}
};
let fields = validate_struct(&input).unwrap();
let result = validate_repr(&input, fields);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("at least one field")
);
}
// Two separate `#[repr]` attributes, one per layout, must be rejected as conflicting.
#[test]
fn test_validate_repr_both_c_and_transparent() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
#[repr(transparent)]
struct TestPixel {
value: u8,
}
};
let fields = validate_struct(&input).unwrap();
let result = validate_repr(&input, fields);
assert!(result.is_err());
assert!(result.err().unwrap().to_string().contains("both"));
}
#[test]
fn test_validate_repr_transparent_multiple_fields() {
let input: DeriveInput = syn::parse_quote! {
#[repr(transparent)]
struct TestPixel {
a: u8,
b: u8,
}
};
let fields = validate_struct(&input).unwrap();
let result = validate_repr(&input, fields);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("exactly one field")
);
}
// --- generate_channels_constant: checks only that the expected trait/const names appear ---
#[test]
fn test_generate_channels_repr_c() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct TestPixel {
r: u8,
g: u8,
b: u8,
}
};
let fields = validate_struct(&input).unwrap();
let tokens = generate_channels_constant(fields, &Repr::C).unwrap();
let output = tokens.to_string();
assert!(output.contains("PlainChannel"));
assert!(output.contains("SIZE"));
}
#[test]
fn test_generate_channels_repr_transparent() {
let input: DeriveInput = syn::parse_quote! {
#[repr(transparent)]
struct TestPixel {
value: u8,
}
};
let fields = validate_struct(&input).unwrap();
let tokens = generate_channels_constant(fields, &Repr::Transparent).unwrap();
let output = tokens.to_string();
assert!(output.contains("PlainChannel"));
assert!(output.contains("SIZE"));
}
// --- derive: full expansion, including the forced evaluation of the assertion constants ---
#[test]
fn test_derive_repr_c_full() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct Rgb8 {
r: u8,
g: u8,
b: u8,
}
};
let tokens = derive(input).unwrap();
let output = tokens.to_string();
assert!(output.contains("PlainPixel"));
assert!(output.contains("CHANNELS"));
assert!(
output.contains("_ASSERT_SIZE"),
"must force-evaluate _ASSERT_SIZE"
);
assert!(
output.contains("_ASSERT_CHANNELS"),
"must force-evaluate _ASSERT_CHANNELS"
);
}
#[test]
fn test_derive_repr_transparent_full() {
let input: DeriveInput = syn::parse_quote! {
#[repr(transparent)]
struct Wrapped {
value: u8,
}
};
let tokens = derive(input).unwrap();
let output = tokens.to_string();
assert!(output.contains("PlainPixel"));
assert!(output.contains("CHANNELS"));
assert!(
output.contains("_ASSERT_SIZE"),
"must force-evaluate _ASSERT_SIZE"
);
assert!(
output.contains("_ASSERT_CHANNELS"),
"must force-evaluate _ASSERT_CHANNELS"
);
}
// --- has_repr_c / has_repr_transparent: attribute detection ---
#[test]
fn test_has_repr_c_true() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct Foo { x: u8 }
};
assert!(has_repr_c(&input.attrs));
assert!(!has_repr_transparent(&input.attrs));
}
#[test]
fn test_has_repr_transparent_true() {
let input: DeriveInput = syn::parse_quote! {
#[repr(transparent)]
struct Foo { x: u8 }
};
assert!(!has_repr_c(&input.attrs));
assert!(has_repr_transparent(&input.attrs));
}
#[test]
fn test_has_repr_neither() {
let input: DeriveInput = syn::parse_quote! {
struct Foo { x: u8 }
};
assert!(!has_repr_c(&input.attrs));
assert!(!has_repr_transparent(&input.attrs));
}
// Attributes other than `repr` must be ignored.
#[test]
fn test_has_repr_with_non_repr_attr() {
let input: DeriveInput = syn::parse_quote! {
#[derive(Clone)]
struct Foo { x: u8 }
};
assert!(!has_repr_c(&input.attrs));
assert!(!has_repr_transparent(&input.attrs));
}
// --- derive: error propagation from the validation steps ---
#[test]
fn test_derive_rejects_enum() {
let input: DeriveInput = syn::parse_quote! {
enum BadPixel { Red }
};
let result = derive(input);
assert!(result.is_err());
assert!(result.err().unwrap().to_string().contains("not enums"));
}
#[test]
fn test_derive_rejects_union() {
let input: DeriveInput = syn::parse_quote! {
union BadPixel { a: u8, b: u16 }
};
let result = derive(input);
assert!(result.is_err());
assert!(result.err().unwrap().to_string().contains("not unions"));
}
#[test]
fn test_derive_rejects_missing_repr() {
let input: DeriveInput = syn::parse_quote! {
struct TestPixel { r: u8 }
};
let result = derive(input);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("requires #[repr(C)]")
);
}
#[test]
fn test_derive_rejects_repr_c_no_fields() {
let input: DeriveInput = syn::parse_quote! {
#[repr(C)]
struct Empty {}
};
let result = derive(input);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("at least one field")
);
}
}