use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Field};
use crate::{
dbn_attr::{
find_dbn_serialize_attr, get_sorted_fields, is_hidden, C_CHAR_ATTR, FIXED_PRICE_ATTR,
UNIX_NANOS_ATTR,
},
utils::crate_name,
};
pub fn derive_csv_macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input as DeriveInput);
if let Data::Struct(data_struct) = data {
if let syn::Fields::Named(fields) = data_struct.fields {
let crate_name = crate_name();
let fields = match get_sorted_fields(fields) {
Ok(fields) => fields,
Err(ts) => {
return ts.into_compile_error().into();
}
};
let serialize_header_iter = fields.iter().map(write_csv_header_token_stream);
let serialize_fields = fields
.iter()
.map(write_csv_field_token_stream)
.collect::<syn::Result<Vec<_>>>()
.unwrap_or_else(|e| vec![syn::Error::to_compile_error(&e)]);
return quote! {
impl #crate_name::encode::csv::serialize::CsvSerialize for #ident {
fn serialize_header<W: ::std::io::Write>(writer: &mut ::csv::Writer<W>) -> ::csv::Result<()> {
use #crate_name::encode::csv::serialize::WriteField;
#(#serialize_header_iter)*
Ok(())
}
fn serialize_to<W: ::std::io::Write, const PRETTY_PX: bool, const PRETTY_TS: bool>(
&self,
writer: &mut ::csv::Writer<W>
) -> ::csv::Result<()> {
use #crate_name::encode::csv::serialize::WriteField;
#(#serialize_fields)*
Ok(())
}
}
}
.into();
}
}
syn::Error::new(ident.span(), "Can only derive CsvSerialize for structs")
.into_compile_error()
.into()
}
/// Expands a derive of `JsonSerialize` for a struct with named fields.
///
/// The generated impl contains a single `to_json` function whose body is the
/// concatenation of [`write_json_field_token_stream`] for each field, in the
/// order returned by `get_sorted_fields`.
///
/// Any input that is not a struct with named fields produces a compile error
/// spanned at the type's identifier. Errors from `get_sorted_fields` or from
/// generating a field's tokens are likewise surfaced as compile errors.
pub fn derive_json_macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let DeriveInput { ident, data, .. } = parse_macro_input!(input as DeriveInput);
    if let Data::Struct(data_struct) = data {
        if let syn::Fields::Named(fields) = data_struct.fields {
            let crate_name = crate_name();
            let fields = match get_sorted_fields(fields) {
                Ok(fields) => fields,
                Err(ts) => {
                    return ts.into_compile_error().into();
                }
            };
            // Collecting into a `Result` stops at the first failing field; that
            // error is then emitted inside the body of `to_json`.
            let serialize_fields = fields
                .iter()
                .map(write_json_field_token_stream)
                .collect::<syn::Result<Vec<_>>>()
                .unwrap_or_else(|e| vec![syn::Error::to_compile_error(&e)]);
            return quote! {
                // The trait path goes through `#crate_name`, like every other
                // path in this expansion: a literal `crate::` would resolve
                // against whichever crate invokes the derive.
                impl #crate_name::encode::json::serialize::JsonSerialize for #ident {
                    fn to_json<J: #crate_name::json_writer::JsonWriter, const PRETTY_PX: bool, const PRETTY_TS: bool>(
                        &self,
                        writer: &mut #crate_name::json_writer::JsonObjectWriter<J>,
                    ) {
                        use #crate_name::encode::json::serialize::WriteField;

                        #(#serialize_fields)*
                    }
                }
            }
            .into();
        }
    }
    syn::Error::new(ident.span(), "Can only derive JsonSerialize for structs")
        .into_compile_error()
        .into()
}
/// Builds the statement that writes the CSV header entry for `field`.
///
/// Hidden fields (per `is_hidden`) produce an empty token stream. Otherwise
/// the result calls `write_header` on the field's type, passing the
/// stringified field name.
///
/// # Panics
/// Panics if `field` has no identifier.
fn write_csv_header_token_stream(field: &Field) -> TokenStream {
    let name = field.ident.as_ref().unwrap();
    if !is_hidden(field) {
        let ty = &field.ty;
        quote! {
            <#ty>::write_header(writer, stringify!(#name))?;
        }
    } else {
        TokenStream::new()
    }
}
fn write_csv_field_token_stream(field: &Field) -> syn::Result<TokenStream> {
let ident = field.ident.as_ref().unwrap();
if is_hidden(field) {
return Ok(quote! {});
}
if let Some(dbn_attr_id) = find_dbn_serialize_attr(field)? {
if dbn_attr_id == UNIX_NANOS_ATTR {
Ok(quote! {
crate::encode::csv::serialize::write_ts_field::<_, PRETTY_TS>(writer, self.#ident)?;
})
} else if dbn_attr_id == FIXED_PRICE_ATTR {
Ok(quote! {
crate::encode::csv::serialize::write_px_field::<_, PRETTY_PX>(writer, self.#ident)?;
})
} else if dbn_attr_id == C_CHAR_ATTR {
Ok(quote! {
crate::encode::csv::serialize::write_c_char_field(writer, self.#ident)?;
})
} else {
Err(syn::Error::new(
dbn_attr_id.span(),
format!("Invalid attr `{dbn_attr_id}` passed to `#[dbn]`"),
))
}
} else {
Ok(quote! {
self.#ident.write_field::<_, PRETTY_PX, PRETTY_TS>(writer)?;
})
}
}
fn write_json_field_token_stream(field: &Field) -> syn::Result<TokenStream> {
let ident = field.ident.as_ref().unwrap();
if is_hidden(field) {
return Ok(quote! {});
}
if let Some(dbn_attr_id) = find_dbn_serialize_attr(field)? {
if dbn_attr_id == UNIX_NANOS_ATTR {
Ok(quote! {
crate::encode::json::serialize::write_ts_field::<_, PRETTY_TS>(writer, stringify!(#ident), self.#ident);
})
} else if dbn_attr_id == FIXED_PRICE_ATTR {
Ok(quote! {
crate::encode::json::serialize::write_px_field::<_, PRETTY_PX>(writer, stringify!(#ident), self.#ident);
})
} else if dbn_attr_id == C_CHAR_ATTR {
Ok(quote! {
crate::encode::json::serialize::write_c_char_field(writer, stringify!(#ident), self.#ident);
})
} else {
Err(syn::Error::new(
dbn_attr_id.span(),
format!("Invalid attr `{dbn_attr_id}` passed to `#[dbn]`"),
))
}
} else {
Ok(quote! {
self.#ident.write_field::<_, PRETTY_PX, PRETTY_TS>(writer, stringify!(#ident));
})
}
}
#[cfg(test)]
mod tests {
    use syn::FieldsNamed;

    use super::*;

    /// Parses `input` as a braced list holding exactly one named field and
    /// asserts that both the CSV and JSON field generators emit no tokens for it.
    fn assert_field_is_skipped(input: TokenStream) {
        let parsed = syn::parse2::<FieldsNamed>(input).unwrap();
        assert_eq!(parsed.named.len(), 1);
        let only_field = parsed.named.first().unwrap();
        let csv_tokens = write_csv_field_token_stream(only_field).unwrap();
        let json_tokens = write_json_field_token_stream(only_field).unwrap();
        assert!(csv_tokens.is_empty());
        assert!(json_tokens.is_empty());
    }

    /// A field annotated with `#[dbn(skip)]` generates no serialization code.
    #[test]
    fn skip_field() {
        assert_field_is_skipped(quote!({
            #[dbn(skip)]
            pub b: bool,
        }));
    }

    /// A field whose name starts with an underscore generates no serialization code.
    #[test]
    fn skip_underscore_field() {
        assert_field_is_skipped(quote!({
            pub _a: bool,
        }));
    }
}