use proc_macro2::TokenStream;
use quote::quote;
use syn::DeriveInput;
use crate::gpu_struct_parse::{GpuStructAttrs, GpuStructDef};
use crate::layout::{FieldInput, LayoutField, compute_std140_layout};
/// Entry point for the GPU-struct macro.
///
/// Parses `input` with the macro `attrs`, computes the std140 layout, and
/// returns the struct definition followed by its `Default`, `WgslType`,
/// `WgslStruct` and `GpuData` impls, in that order.
///
/// # Errors
/// Propagates any error from `GpuStructDef::from_input` or from the layout
/// computation.
pub fn generate(attrs: GpuStructAttrs, input: DeriveInput) -> syn::Result<TokenStream> {
    let parsed = GpuStructDef::from_input(attrs, input)?;
    let fields_layout = build_layout(&parsed)?;

    // Order matters only for readability of the expansion; items are independent.
    let pieces: Vec<TokenStream> = vec![
        gen_struct(&parsed, &fields_layout),
        gen_default(&parsed, &fields_layout),
        gen_wgsl_type(&parsed, &fields_layout),
        gen_wgsl_struct(&parsed, &fields_layout),
        gen_gpu_data(&parsed),
    ];

    Ok(quote! { #(#pieces)* })
}
/// Builds the std140 field layout for `def`.
///
/// Each parsed field's name, type and optional default expression is copied
/// into a `FieldInput`. The list is handed to `compute_std140_layout` together
/// with the struct's `dynamic_offset` flag.
///
/// # Errors
/// Returns whatever error `compute_std140_layout` reports.
fn build_layout(def: &GpuStructDef) -> syn::Result<Vec<LayoutField>> {
    let mut field_inputs: Vec<FieldInput> = Vec::new();
    for source_field in def.fields.iter() {
        field_inputs.push(FieldInput {
            name: source_field.name.clone(),
            ty: source_field.ty.clone(),
            default_expr: source_field.default_expr.clone(),
        });
    }
    compute_std140_layout(&field_inputs, def.dynamic_offset)
}
/// Returns `true` if `field` should appear in generated WGSL source.
///
/// Padding fields are excluded, as are fields whose identifier starts with
/// `__`.
fn wgsl_visible(field: &LayoutField) -> bool {
    if field.is_padding {
        return false;
    }
    let ident_text = field.name.to_string();
    !ident_text.starts_with("__")
}
/// Emits the `#[repr(C)]` Rust struct definition with fields in layout order.
///
/// The struct derives `Clone, Copy, Debug, PartialEq` and bytemuck's
/// `Pod`/`Zeroable`, and carries the original struct docs.
/// - Padding fields are emitted as `pub` and `#[doc(hidden)]`.
/// - Fields that match a user-declared field by name keep that field's
///   visibility and doc attributes.
/// - Any other layout field is emitted as plain `pub`.
fn gen_struct(def: &GpuStructDef, layout: &[LayoutField]) -> TokenStream {
    let struct_vis = &def.vis;
    let struct_ident = &def.name;
    let doc_attrs = &def.struct_docs;

    let field_tokens: Vec<TokenStream> = layout
        .iter()
        .map(|entry| {
            let ident = &entry.name;
            let ty = &entry.ty;

            if entry.is_padding {
                return quote! {
                    #[doc(hidden)]
                    pub #ident: #ty,
                };
            }

            // Look up the user's declaration to carry over its visibility and docs.
            match def.fields.iter().find(|declared| declared.name == *ident) {
                Some(declared) => {
                    let declared_vis = &declared.vis;
                    let declared_docs = &declared.docs;
                    quote! {
                        #(#declared_docs)*
                        #declared_vis #ident: #ty,
                    }
                }
                None => quote! { pub #ident: #ty, },
            }
        })
        .collect();

    quote! {
        #[repr(C)]
        #[derive(Clone, Copy, Debug, PartialEq, bytemuck::Pod, bytemuck::Zeroable)]
        #(#doc_attrs)*
        #struct_vis struct #struct_ident {
            #(#field_tokens)*
        }
    }
}
/// Emits `impl Default` for the generated struct.
///
/// - Padding fields are zero-filled via `bytemuck::Zeroable::zeroed()`. This is
///   always available: the struct derives `bytemuck::Zeroable`, which requires
///   every field type to implement it.
/// - Non-padding fields with a `default_expr` use that expression verbatim.
/// - All other fields use `<T as Default>::default()`.
fn gen_default(def: &GpuStructDef, layout: &[LayoutField]) -> TokenStream {
    let name = &def.name;
    let field_defaults = layout.iter().map(|f| {
        let fname = &f.name;
        let fty = &f.ty;
        if f.is_padding {
            // Why not `Default`: std only implements `Default` for arrays of
            // length <= 32, so `Default` would fail to compile if the layout ever
            // emits a longer padding array. `zeroed()` yields the same all-zero
            // value for every padding type and has no such limit.
            quote! { #fname: <#fty as bytemuck::Zeroable>::zeroed(), }
        } else if let Some(expr) = &f.default_expr {
            quote! { #fname: #expr, }
        } else {
            // User fields keep `Default` so nested types with custom (non-zero)
            // defaults are honored.
            quote! { #fname: <#fty as Default>::default(), }
        }
    });
    quote! {
        impl Default for #name {
            fn default() -> Self {
                Self {
                    #(#field_defaults)*
                }
            }
        }
    }
}
/// Emits `impl WgslType` for the generated struct.
///
/// `wgsl_type_name` returns the Rust struct name. `collect_wgsl_defs` appends
/// this struct's WGSL definition to `defs` at most once, keyed by name in
/// `inserted`. Before appending, it recursively collects the definitions of
/// its WGSL-visible field types, so dependencies are pushed before the
/// struct itself.
fn gen_wgsl_type(def: &GpuStructDef, layout: &[LayoutField]) -> TokenStream {
    let cr = &def.crate_path;
    let name = &def.name;
    let name_str = name.to_string();
    let collect_fields = layout.iter().filter(|f| wgsl_visible(f)).map(|f| {
        let ty = &f.ty;
        quote! {
            <#ty as #cr::uniforms::WgslType>::collect_wgsl_defs(defs, inserted);
        }
    });
    let body_fields = layout.iter().filter(|f| wgsl_visible(f)).map(|f| {
        let field_name_str = f.name.to_string();
        let ty = &f.ty;
        quote! {
            let _ = std::fmt::Write::write_fmt(
                &mut code,
                format_args!(
                    " {}: {},\n",
                    #field_name_str,
                    <#ty as #cr::uniforms::WgslType>::wgsl_type_name(),
                ),
            );
        }
    });
    quote! {
        impl #cr::uniforms::WgslType for #name {
            fn wgsl_type_name() -> std::borrow::Cow<'static, str> {
                #name_str.into()
            }
            fn collect_wgsl_defs(
                defs: &mut Vec<String>,
                inserted: &mut std::collections::HashSet<String>,
            ) {
                let my_name = #name_str;
                // Already emitted: its dependencies were collected at that time,
                // so skip re-walking the whole field subtree. Without this check
                // shared nested types are traversed once per reference.
                if inserted.contains(my_name) {
                    return;
                }
                #(#collect_fields)*
                let mut code = format!("struct {} {{\n", my_name);
                #(#body_fields)*
                code.push_str("};\n");
                defs.push(code);
                inserted.insert(my_name.to_string());
            }
        }
    }
}
/// Emits `impl WgslStruct` for the generated struct.
///
/// The generated `wgsl_struct_def(struct_name)` collects the WGSL definitions
/// of all WGSL-visible field types. It then appends this struct's definition
/// under the caller-supplied `struct_name` and joins everything with `"\n"`.
fn gen_wgsl_struct(def: &GpuStructDef, layout: &[LayoutField]) -> TokenStream {
    let cr = &def.crate_path;
    let name = &def.name;
    let collect_deps = layout.iter().filter(|f| wgsl_visible(f)).map(|f| {
        let ty = &f.ty;
        quote! {
            <#ty as #cr::uniforms::WgslType>::collect_wgsl_defs(&mut defs, &mut inserted);
        }
    });
    let body_fields = layout.iter().filter(|f| wgsl_visible(f)).map(|f| {
        let field_name_str = f.name.to_string();
        let ty = &f.ty;
        quote! {
            let _ = std::fmt::Write::write_fmt(
                &mut code,
                format_args!(
                    " {}: {},\n",
                    #field_name_str,
                    <#ty as #cr::uniforms::WgslType>::wgsl_type_name(),
                ),
            );
        }
    });
    quote! {
        impl #cr::uniforms::WgslStruct for #name {
            fn wgsl_struct_def(struct_name: &str) -> String {
                // Types are spelled out because, when no field is WGSL-visible,
                // `inserted` is never passed anywhere and its element type could
                // not be inferred (E0282); the allow silences the unused lints
                // in that same case.
                let mut defs: Vec<String> = Vec::new();
                #[allow(unused_mut, unused_variables)]
                let mut inserted: std::collections::HashSet<String> =
                    std::collections::HashSet::new();
                #(#collect_deps)*
                let mut code = format!("struct {} {{\n", struct_name);
                #(#body_fields)*
                code.push_str("};\n");
                defs.push(code);
                defs.join("\n")
            }
        }
    }
}
/// Emits `impl GpuData` for the generated struct.
///
/// `as_bytes` views the value through `bytemuck::bytes_of`, and `byte_size`
/// reports `std::mem::size_of::<Self>()`.
fn gen_gpu_data(def: &GpuStructDef) -> TokenStream {
    let (crate_root, ty_ident) = (&def.crate_path, &def.name);

    let as_bytes_fn = quote! {
        fn as_bytes(&self) -> &[u8] {
            bytemuck::bytes_of(self)
        }
    };
    let byte_size_fn = quote! {
        fn byte_size(&self) -> usize {
            std::mem::size_of::<Self>()
        }
    };

    quote! {
        impl #crate_root::buffer::GpuData for #ty_ident {
            #as_bytes_fn
            #byte_size_fn
        }
    }
}