prgpu-macro 0.1.12

Procedural macros for prgpu — GPU struct layout / ABI codegen
Documentation
use crate::layout::StructLayout;
use crate::parse::GpuStructConfig;
use crate::types::GpuType;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Ident, ItemStruct};

/// Rewrites `item_struct` in place for GPU ABI use and returns the generated tokens.
///
/// The struct is mutated as follows:
/// - any existing `repr` attribute is replaced by `#[repr(C, align(N))]`, where `N` is
///   `layout.struct_align`;
/// - all `derive` attributes are replaced by a single one built by `build_derive_tokens`;
/// - single-segment `gpu_nested` field attributes are removed;
/// - fields resolved as `GpuType::Bool` have their type rewritten to `u32`;
/// - explicit padding fields are injected when `config.bytemuck` or `config.pad` is set and
///   no field resolved to `GpuType::GpuStruct`.
///
/// The returned stream contains the rewritten struct, an inherent `impl` holding `SIZE`,
/// `ALIGN`, optional `<FIELD>_OFFSET` constants and bool accessors, and a `const _` block of
/// layout assertions.
///
/// # Parameters
/// - `item_struct`: the parsed struct; modified in place.
/// - `config`: macro options (`bytemuck`, `pad`, `debug_layout`, `emit_offsets` are read here).
/// - `layout`: precomputed size, alignment, per-field offsets and tail padding.
/// - `resolved_fields`: `(field name, resolved GPU type, span)` for each field.
pub fn generate(
    item_struct: &mut ItemStruct,
    config: &GpuStructConfig,
    layout: &StructLayout,
    resolved_fields: &[(Ident, GpuType, proc_macro2::Span)],
) -> TokenStream2 {
    let struct_ident = item_struct.ident.clone();

    let size_val = layout.struct_size;
    let align_val = layout.struct_align;

    // A nested GPU struct's size is not known here, so fixed-size reasoning is disabled below.
    let has_unknown_size_fields = resolved_fields
        .iter()
        .any(|(_, gpu_type, _)| matches!(gpu_type, GpuType::GpuStruct { .. }));

    // Any user-written `repr` is dropped in favour of the computed one.
    item_struct
        .attrs
        .retain(|attr| !attr.path().is_ident("repr"));

    // Build an unsuffixed literal: interpolating the `usize` directly would emit `16usize`.
    let align_val_token = syn::LitInt::new(&align_val.to_string(), proc_macro2::Span::call_site());
    let repr_attr: syn::Attribute = syn::parse_quote!(#[repr(C, align(#align_val_token))]);
    item_struct.attrs.push(repr_attr);

    // Pod requires every byte to be defined, so explicit padding is mandatory when bytemuck is on.
    let needs_explicit_padding = config.bytemuck || config.pad;

    // Unknown-size nested structs can introduce implicit padding; skip Pod in that case.
    let can_derive_pod = config.bytemuck && !has_unknown_size_fields;

    // The derive list must be computed before the original `derive` attributes are removed,
    // because it is built from them.
    let derive_tokens = build_derive_tokens(item_struct, can_derive_pod);
    let derive_attr: syn::Attribute = syn::parse_quote!(#[derive(#(#derive_tokens),*)]);
    item_struct
        .attrs
        .retain(|attr| !attr.path().is_ident("derive"));
    item_struct.attrs.push(derive_attr);

    if let syn::Fields::Named(fields) = &mut item_struct.fields {
        for field in &mut fields.named {
            // `gpu_nested` is a marker for this macro only; it must not reach the compiler.
            field
                .attrs
                .retain(|attr| !(attr.path().segments.len() == 1 && attr.path().segments[0].ident == "gpu_nested"));

            // Bool fields are stored as `u32`; accessors are generated by `generate_bool_helpers`.
            if let Some(ident) = &field.ident {
                if let Some((_, gpu_type, _)) = resolved_fields
                    .iter()
                    .find(|(name, _, _)| name == ident)
                {
                    if matches!(gpu_type, GpuType::Bool) {
                        field.ty = syn::parse_quote!(u32);
                    }
                }
            }
        }
    }

    if needs_explicit_padding && !has_unknown_size_fields {
        let padding_gaps: Vec<(usize, usize)> = compute_padding_gaps(layout);
        let tail_padding = layout.tail_padding;

        inject_padding_fields(item_struct, &padding_gaps, tail_padding);
    }

    // With nested structs the compiler is the source of truth for size and alignment.
    let size_align_tokens = if has_unknown_size_fields {
        quote! {
            pub const SIZE: usize = core::mem::size_of::<Self>();
            pub const ALIGN: usize = core::mem::align_of::<Self>();
        }
    } else {
        quote! {
            pub const SIZE: usize = #size_val;
            pub const ALIGN: usize = #align_val;
        }
    };

    let offset_constants = if config.debug_layout || config.emit_offsets {
        let offset_consts: Vec<TokenStream2> = layout
            .fields
            .iter()
            .map(|field_layout| {
                let field_name = &field_layout.name;
                let offset_val = field_layout.offset;
                // A raw identifier renders as `r#name`; the `#` would make `Ident::new` panic,
                // so the prefix is stripped before building the constant name.
                let rendered_name = field_name.to_string();
                let bare_name = rendered_name.strip_prefix("r#").unwrap_or(&rendered_name);
                let const_ident = syn::Ident::new(
                    &format!("{}_OFFSET", bare_name.to_uppercase()),
                    proc_macro2::Span::call_site(),
                );
                quote! {
                    #[doc = concat!("Byte offset of field `", stringify!(#field_name), "`")]
                    pub const #const_ident: usize = #offset_val;
                }
            })
            .collect();
        quote! { #(#offset_consts)* }
    } else {
        quote! {}
    };

    let bool_helpers = generate_bool_helpers(resolved_fields);

    let assertions = generate_assertions(&struct_ident, layout, config, has_unknown_size_fields);

    let struct_tokens = quote! { #item_struct };

    let impl_tokens = quote! {
        impl #struct_ident {
            #size_align_tokens
            #offset_constants
            #bool_helpers
        }
    };

    quote! {
        #struct_tokens
        #impl_tokens
        #assertions
    }
}

/// Lists the inter-field padding implied by `layout`.
///
/// Each entry is `(field_index, gap_bytes)`: the field at `field_index` starts `gap_bytes`
/// past the end of the previous field (or past offset 0 for the first field). Fields that
/// start at or before the running end produce no entry.
fn compute_padding_gaps(layout: &StructLayout) -> Vec<(usize, usize)> {
    // End offset of the most recently visited field.
    let mut cursor = 0usize;

    layout
        .fields
        .iter()
        .enumerate()
        .filter_map(|(index, field)| {
            // `checked_sub` is `None` when the field starts before the cursor; a zero-byte
            // result is not a gap either.
            let gap = field.offset.checked_sub(cursor).filter(|&bytes| bytes > 0);
            cursor = field.offset + field.size;
            gap.map(|bytes| (index, bytes))
        })
        .collect()
}

/// Rebuilds the named field list with explicit `[u8; N]` padding fields.
///
/// For every `(field_index, gap_bytes)` in `padding_gaps`, a hidden `_prgpu_pad_<k>` field is
/// placed immediately before the field at `field_index`; `k` counts up from 0 in insertion
/// order. Gaps are consumed front to back, so an entry is only used once the entries before
/// it have been used. When `tail_padding` is non-zero a final `_prgpu_pad_tail` field is
/// appended. Structs without named fields are left untouched.
fn inject_padding_fields(
    item_struct: &mut ItemStruct,
    padding_gaps: &[(usize, usize)],
    tail_padding: usize,
) {
    let named = match &mut item_struct.fields {
        syn::Fields::Named(fields) => &mut fields.named,
        _ => return,
    };

    // Builds one hidden byte-array field with the given name and length.
    let make_pad = |label: String, byte_len: usize| -> syn::Field {
        let pad_ident = syn::Ident::new(&label, proc_macro2::Span::call_site());
        syn::parse_quote!(
            #[doc(hidden)]
            #pad_ident: [u8; #byte_len]
        )
    };

    // Move the current fields out so the list can be rebuilt in order.
    let existing = std::mem::take(named);
    let mut pending = padding_gaps.iter().peekable();
    let mut next_pad = 0usize;

    for (position, field) in existing.into_iter().enumerate() {
        while let Some(&(_, gap_bytes)) = pending.next_if(|&&(at, _)| at == position) {
            named.push(make_pad(format!("_prgpu_pad_{next_pad}"), gap_bytes));
            next_pad += 1;
        }

        named.push(field);
    }

    if tail_padding > 0 {
        named.push(make_pad("_prgpu_pad_tail".to_string(), tail_padding));
    }
}

/// Collects every entry listed in the struct's `#[derive(...)]` attributes.
///
/// Entries are parsed as full paths, so `Clone` and `serde::Serialize` are both accepted.
/// An attribute whose arguments cannot be parsed as a comma-separated path list contributes
/// nothing.
fn parse_derive_paths(item_struct: &ItemStruct) -> Vec<syn::Path> {
    item_struct
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("derive"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .flatten()
        .collect()
}

/// Builds the complete derive list for the rewritten struct.
///
/// The caller replaces all existing `derive` attributes with the returned list, so every
/// derive that should survive must appear in it:
/// - `Clone`, `Copy` and `Debug` are added unless the user already listed them;
/// - `bytemuck::Pod` and `bytemuck::Zeroable` are added when `can_derive_pod` is true and the
///   user did not already list them;
/// - every user-written derive is kept with its original path, except `Pod` / `Zeroable`
///   when `can_derive_pod` is false.
///
/// Derives are matched by the last path segment, so `bytemuck::Pod` counts as `Pod`.
fn build_derive_tokens(item_struct: &ItemStruct, can_derive_pod: bool) -> Vec<proc_macro2::TokenStream> {
    let mut derives = Vec::new();

    let always = ["Clone", "Copy", "Debug"];
    for d in &always {
        if !has_derive(item_struct, d) {
            let ident = syn::Ident::new(d, proc_macro2::Span::call_site());
            derives.push(quote! { #ident });
        }
    }

    if can_derive_pod {
        if !has_derive(item_struct, "Pod") {
            derives.push(quote! { bytemuck::Pod });
        }
        if !has_derive(item_struct, "Zeroable") {
            derives.push(quote! { bytemuck::Zeroable });
        }
    }

    // Carry over what the user wrote. Managed derives that were skipped above because the
    // user already listed them are re-emitted here, so they are neither lost nor duplicated.
    for path in parse_derive_paths(item_struct) {
        let is_pod_family = matches!(
            path.segments.last(),
            Some(segment) if segment.ident == "Pod" || segment.ident == "Zeroable"
        );
        if can_derive_pod || !is_pod_family {
            derives.push(quote! { #path });
        }
    }

    derives
}

/// Reports whether any `#[derive(...)]` on the struct lists `name`.
///
/// The comparison uses the last path segment, so `has_derive(s, "Pod")` is true for both
/// `#[derive(Pod)]` and `#[derive(bytemuck::Pod)]`.
fn has_derive(item_struct: &ItemStruct, name: &str) -> bool {
    parse_derive_paths(item_struct)
        .iter()
        .any(|path| matches!(path.segments.last(), Some(segment) if segment.ident == name))
}

fn generate_bool_helpers(
    resolved_fields: &[(Ident, GpuType, proc_macro2::Span)],
) -> TokenStream2 {
    let helpers: Vec<TokenStream2> = resolved_fields
        .iter()
        .filter(|(_, gpu_type, _)| matches!(gpu_type, GpuType::Bool))
        .map(|(name, _, _)| {
            let helper_name = syn::Ident::new(
                &format!("{}_bool", name),
                proc_macro2::Span::call_site(),
            );
            quote! {
                #[doc = concat!("Returns `", stringify!(#name), "` as `bool` (mapped from u32 for GPU ABI).")]
                #[inline]
                pub fn #helper_name(&self) -> bool {
                    self.#name != 0
                }
            }
        })
        .collect();

    quote! { #(#helpers)* }
}

/// Emits a `const _: () = { ... };` block of compile-time layout checks for `struct_ident`.
///
/// - When `has_unknown_size_fields` is true, the struct's `SIZE` / `ALIGN` constants are
///   compared against `size_of` / `align_of`.
/// - Otherwise `size_of` / `align_of` are compared against `layout.struct_size` and
///   `layout.struct_align`.
/// - When `config.debug_layout` or `config.emit_offsets` is set, one `offset_of!` check is
///   added per entry in `layout.fields`.
fn generate_assertions(
    struct_ident: &Ident,
    layout: &StructLayout,
    config: &GpuStructConfig,
    has_unknown_size_fields: bool,
) -> TokenStream2 {
    let expected_size = layout.struct_size;
    let expected_align = layout.struct_align;

    let mut checks: Vec<TokenStream2> = if has_unknown_size_fields {
        vec![
            quote! {
                assert!(<#struct_ident>::SIZE == core::mem::size_of::<#struct_ident>());
            },
            quote! {
                assert!(<#struct_ident>::ALIGN == core::mem::align_of::<#struct_ident>());
            },
        ]
    } else {
        vec![
            quote! {
                assert!(core::mem::size_of::<#struct_ident>() == #expected_size);
            },
            quote! {
                assert!(core::mem::align_of::<#struct_ident>() == #expected_align);
            },
        ]
    };

    let wants_offset_checks = config.debug_layout || config.emit_offsets;
    if wants_offset_checks {
        checks.extend(layout.fields.iter().map(|field| {
            let member = &field.name;
            let offset = field.offset;
            quote! {
                assert!(core::mem::offset_of!(#struct_ident, #member) == #offset);
            }
        }));
    }

    quote! {
        const _: () = {
            #(#checks)*
        };
    }
}