dbn-macros 0.59.0

Procedural macros for the `dbn` crate.
Documentation
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    spanned::Spanned,
    ExprPath, ItemStruct, Token,
};

use crate::dbn_attr::{find_dbn_attr_args, get_sorted_fields, INDEX_TS_ATTR};

/// Implementation of the record attribute macro.
///
/// `attr` must be a non-empty, comma-separated list of paths (e.g. `rtype::MBO`) naming the
/// rtypes the annotated record matches; `input` must be a struct. The struct is emitted
/// unchanged, followed by:
/// - `Record` (with `raw_index_ts` generated from the field marked `index_ts`, if any),
/// - `RecordMut`,
/// - `HasRType` (matching against the given rtypes),
/// - `AsRef<[u8]>` (a byte view of the whole struct),
/// - `PartialOrd` (ordering by `raw_index_ts`, `None` if either side is `UNDEF_TIMESTAMP`),
/// - a `Debug` impl and a `python`-feature-gated repr impl.
///
/// Errors (empty argument list, unparsable attribute/struct, problems found by
/// `get_raw_index_ts`) are reported as `compile_error!` output rather than panics.
pub fn attribute_macro_impl(
    attr: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let args = parse_macro_input!(attr as Args);
    if args.args.is_empty() {
        return syn::Error::new(
            args.span,
            "Need to specify at least one rtype to match against",
        )
        .into_compile_error()
        .into();
    }
    let input_struct = parse_macro_input!(input as ItemStruct);
    let record_type = &input_struct.ident;
    // An `index_ts` error is emitted as `compile_error!` inside the `Record` impl, so the
    // other impls below are still generated.
    let raw_index_ts = get_raw_index_ts(&input_struct).unwrap_or_else(|e| e.into_compile_error());
    let rtypes = args.args.iter();
    let crate_name = crate::utils::crate_name();
    let impl_debug = crate::debug::record_debug_impl(&input_struct);
    let impl_py_repr = get_py_repr_impl(&input_struct);
    // Standard-library items in the expansion are fully qualified (`::std::...`) so they
    // can't be shadowed by user items named `std`, `Option`, `AsRef`, `matches`, etc.
    quote! {
        #input_struct

        impl #crate_name::record::Record for #record_type {
            fn header(&self) -> &#crate_name::record::RecordHeader {
                &self.hd
            }
            #raw_index_ts
        }

        impl #crate_name::record::RecordMut for #record_type {
            fn header_mut(&mut self) -> &mut #crate_name::record::RecordHeader {
                &mut self.hd
            }
        }

        impl #crate_name::record::HasRType for #record_type {
            #[allow(deprecated)]
            fn has_rtype(rtype: u8) -> bool {
                ::std::matches!(rtype, #(#rtypes)|*)
            }
        }

        impl ::std::convert::AsRef<[u8]> for #record_type {
            fn as_ref(&self) -> &[u8] {
                // SAFETY: `self` is a valid, aligned reference covering exactly
                // `size_of::<Self>()` bytes, and the returned slice borrows `self`.
                // NOTE(review): viewing the struct as initialized `u8`s assumes the record
                // layout contains no padding bytes — confirm for every annotated type.
                unsafe {
                    ::std::slice::from_raw_parts(
                        self as *const #record_type as *const u8,
                        ::std::mem::size_of::<#record_type>(),
                    )
                }
            }
        }

        impl ::std::cmp::PartialOrd for #record_type {
            fn partial_cmp(
                &self,
                other: &Self,
            ) -> ::std::option::Option<::std::cmp::Ordering> {
                use #crate_name::record::Record;
                let lhs = self.raw_index_ts();
                let rhs = other.raw_index_ts();
                // Records without a defined index timestamp are unordered.
                if lhs == #crate_name::UNDEF_TIMESTAMP || rhs == #crate_name::UNDEF_TIMESTAMP {
                    ::std::option::Option::None
                } else {
                    ::std::option::Option::Some(lhs.cmp(&rhs))
                }
            }
        }

        #impl_debug

        #impl_py_repr
    }
    .into()
}

/// Builds the Python `repr` implementation for a record struct.
///
/// Only structs with named fields get an impl; any other shape yields an empty token
/// stream. The fields are ordered by `get_sorted_fields` before being handed to
/// `crate::py_repr::py_repr_impl`. The result is wrapped in `#[cfg(feature = "python")]`.
/// A failure from `get_sorted_fields` is returned as `compile_error!` tokens.
fn get_py_repr_impl(input_struct: &ItemStruct) -> TokenStream {
    let named = match &input_struct.fields {
        syn::Fields::Named(named) => named,
        _ => return quote!(),
    };
    get_sorted_fields(named.clone())
        .map(|ordered| {
            let repr_impl = crate::py_repr::py_repr_impl(&input_struct.ident, &ordered);
            quote! {
                #[cfg(feature = "python")]
                #repr_impl
            }
        })
        .unwrap_or_else(|err| err.into_compile_error())
}

pub(crate) struct Args {
    args: Vec<ExprPath>,
    span: Span,
}

impl Parse for Args {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let args = Punctuated::<ExprPath, Token![,]>::parse_terminated(input)?;
        Ok(Args {
            args: args.into_iter().collect(),
            span: input.span(),
        })
    }
}

fn get_raw_index_ts(input_struct: &ItemStruct) -> syn::Result<TokenStream> {
    let mut index_ts_fields = Vec::new();
    for field in input_struct.fields.iter() {
        if find_dbn_attr_args(field)?
            .iter()
            .any(|id| id == INDEX_TS_ATTR)
        {
            index_ts_fields.push(field.ident.as_ref().unwrap())
        }
    }
    match index_ts_fields.len() {
        0 => Ok(quote!()),
        1 => {
            let index_ts = index_ts_fields[0];
            Ok(quote!(
                fn raw_index_ts(&self) -> u64 {
                    self.#index_ts
                }
            ))
        }
        _ => Err(syn::Error::new(
            input_struct.span(),
            "Only one field can be marked index_ts",
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` into [`Args`], panicking if they are not a valid argument list.
    fn parse_tokens(tokens: TokenStream) -> Args {
        syn::parse2::<Args>(tokens).unwrap()
    }

    #[test]
    fn parse_args_single() {
        let parsed = parse_tokens(quote!(rtype::MBO));
        assert_eq!(parsed.args.len(), 1);
    }

    #[test]
    fn parse_args_multiple() {
        let parsed = parse_tokens(quote!(rtype::MBO, rtype::OHLC));
        assert_eq!(parsed.args.len(), 2);
    }

    #[test]
    fn parse_args_empty() {
        // An empty list parses successfully; rejecting it is the macro entry point's job.
        let parsed = parse_tokens(quote!());
        assert!(parsed.args.is_empty());
    }
}