use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, spanned::Spanned};
#[proc_macro_derive(FromRow, attributes(from_row))]
pub fn derive_from_row(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let strict = input.attrs.iter().any(|attr| {
if !attr.path().is_ident("from_row") {
return false;
}
match &attr.meta {
Meta::List(list) => list.tokens.to_string().contains("strict"),
_ => false,
}
});
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => panic!("FromRow only supports structs with named fields"),
},
_ => panic!("FromRow only supports structs"),
};
let field_names: Vec<_> = fields
.iter()
.map(|f| f.ident.as_ref().expect("named fields always have idents"))
.collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_name_strs: Vec<_> = field_names.iter().map(|n| n.to_string()).collect();
let uninit_decls = field_names
.iter()
.zip(field_types.iter())
.map(|(name, ty)| {
quote! {
let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
}
});
let set_flag_names: Vec<_> = field_names
.iter()
.map(|n| syn::Ident::new(&format!("{}_set", n), n.span()))
.collect();
let set_flag_decls = set_flag_names.iter().map(|flag| {
quote! { let mut #flag = false; }
});
let match_arms_text = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
quote! {
#name_str => {
let __val: #ty = match __value {
None => ::zero_postgres::conversion::FromWireValue::from_null()?,
Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_text(__field.type_oid(), __bytes)?,
};
#name.write(__val);
#flag = true;
}
}
});
let match_arms_binary = field_names.iter().zip(field_types.iter()).zip(set_flag_names.iter()).zip(field_name_strs.iter()).map(|(((name, ty), flag), name_str)| {
quote! {
#name_str => {
let __val: #ty = match __value {
None => ::zero_postgres::conversion::FromWireValue::from_null()?,
Some(__bytes) => ::zero_postgres::conversion::FromWireValue::from_binary(__field.type_oid(), __bytes)?,
};
#name.write(__val);
#flag = true;
}
}
});
let fallback_arm = if strict {
quote! {
__unknown => {
return Err(::zero_postgres::Error::Decode(format!("unknown column: {}", __unknown)));
}
}
} else {
quote! {
_ => {
}
}
};
let init_checks = field_names
.iter()
.zip(set_flag_names.iter())
.zip(field_name_strs.iter())
.map(|((_name, flag), name_str)| {
quote! {
if !#flag {
return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
}
}
});
let field_inits = field_names.iter().map(|name| {
quote! {
#name: unsafe { #name.assume_init() }
}
});
let uninit_decls_text = uninit_decls.clone();
let set_flag_decls_text = set_flag_decls.clone();
let init_checks_text = init_checks.clone();
let field_inits_text = field_inits.clone();
let uninit_decls_binary = field_names
.iter()
.zip(field_types.iter())
.map(|(name, ty)| {
quote! {
let mut #name: ::core::mem::MaybeUninit<#ty> = ::core::mem::MaybeUninit::uninit();
}
});
let set_flag_decls_binary = set_flag_names.iter().map(|flag| {
quote! { let mut #flag = false; }
});
let init_checks_binary = field_names
.iter()
.zip(set_flag_names.iter())
.zip(field_name_strs.iter())
.map(|((_name, flag), name_str)| {
quote! {
if !#flag {
return Err(::zero_postgres::Error::Decode(format!("missing column: {}", #name_str)));
}
}
});
let field_inits_binary = field_names.iter().map(|name| {
quote! {
#name: unsafe { #name.assume_init() }
}
});
let expanded = quote! {
impl #impl_generics ::zero_postgres::conversion::FromRow<'_> for #name #ty_generics #where_clause {
fn from_row_text(
__cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
__row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
) -> ::zero_postgres::Result<Self> {
#(#uninit_decls_text)*
#(#set_flag_decls_text)*
let mut __values = __row.iter();
for __field in __cols.iter() {
let __value = __values.next().flatten();
let __col_name = __field.name;
match __col_name {
#(#match_arms_text)*
#fallback_arm
}
}
#(#init_checks_text)*
Ok(Self {
#(#field_inits_text),*
})
}
fn from_row_binary(
__cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
__row: ::zero_postgres::protocol::backend::query::DataRow<'_>,
) -> ::zero_postgres::Result<Self> {
#(#uninit_decls_binary)*
#(#set_flag_decls_binary)*
let mut __values = __row.iter();
for __field in __cols.iter() {
let __value = __values.next().flatten();
let __col_name = __field.name;
match __col_name {
#(#match_arms_binary)*
#fallback_arm
}
}
#(#init_checks_binary)*
Ok(Self {
#(#field_inits_binary),*
})
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(RefFromRow)]
pub fn derive_ref_from_row(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let has_repr_c_packed = input.attrs.iter().any(|attr| {
if !attr.path().is_ident("repr") {
return false;
}
let tokens = match &attr.meta {
Meta::List(list) => list.tokens.to_string(),
_ => return false,
};
tokens.contains("C") && tokens.contains("packed")
});
if !has_repr_c_packed {
return syn::Error::new(
input.ident.span(),
"RefFromRow requires #[repr(C, packed)] on the struct",
)
.to_compile_error()
.into();
}
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return syn::Error::new(
input.ident.span(),
"RefFromRow only supports structs with named fields",
)
.to_compile_error()
.into();
}
},
_ => {
return syn::Error::new(input.ident.span(), "RefFromRow only supports structs")
.to_compile_error()
.into();
}
};
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let wire_size_checks = field_types.iter().map(|ty| {
quote! {
const _: () = {
fn __assert_fixed_wire_size<T: ::zero_postgres::conversion::ref_row::FixedWireSize>() {}
fn __check() { __assert_fixed_wire_size::<#ty>(); }
};
}
});
let wire_size_sum = field_types.iter().map(|ty| {
quote! { <#ty as ::zero_postgres::conversion::ref_row::FixedWireSize>::WIRE_SIZE }
});
let expanded = quote! {
#(#wire_size_checks)*
unsafe impl ::zerocopy::KnownLayout for #name {}
unsafe impl ::zerocopy::Immutable for #name {}
unsafe impl ::zerocopy::FromBytes for #name {}
impl<'a> ::zero_postgres::conversion::ref_row::RefFromRow<'a> for #name {
fn ref_from_row_binary(
_cols: &[::zero_postgres::protocol::backend::query::FieldDescription],
row: ::zero_postgres::protocol::backend::query::DataRow<'a>,
) -> ::zero_postgres::Result<&'a Self> {
const EXPECTED_SIZE: usize = 0 #(+ #wire_size_sum)*;
let data = row.raw_data();
if data.len() < EXPECTED_SIZE {
return Err(::zero_postgres::Error::Decode(
format!(
"Row data too small: expected {} bytes, got {}",
EXPECTED_SIZE,
data.len()
)
));
}
::zerocopy::FromBytes::ref_from_bytes(&data[..EXPECTED_SIZE])
.map_err(|e| ::zero_postgres::Error::Decode(
format!("RefFromRow zerocopy error: {:?}", e)
))
}
}
};
TokenStream::from(expanded)
}