tea_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote};
6use syn::{Data, DeriveInput, FnArg, ItemFn, ReturnType, parse_macro_input, parse_quote};
7
8/// Parses the parameters of a function signature.
9///
10/// This function extracts the patterns from the function's input arguments,
11/// filtering out any `self` parameters.
12///
13/// # Arguments
14///
15/// * `sig` - A reference to the function signature to parse.
16///
17/// # Returns
18///
19/// A vector of boxed patterns representing the function's parameters.
20#[allow(clippy::vec_box)]
21pub(crate) fn parse_params(sig: &syn::Signature) -> Vec<Box<syn::Pat>> {
22    sig.inputs
23        .iter()
24        .filter_map(|arg| {
25            if let FnArg::Typed(pat_type) = arg {
26                Some(pat_type.pat.clone())
27            } else {
28                None
29            }
30        })
31        .collect()
32}
33
34/// Transforms a function to handle optional output.
35///
36/// This function takes a function and modifies it to return an `Option` of its original return type.
37/// It also creates a wrapper function that calls the modified function and unwraps the result.
38///
39/// # Arguments
40///
41/// * `attr` - The attributes applied to the function.
42/// * `func` - The original function to transform.
43///
44/// # Returns
45///
46/// A `TokenStream2` containing the modified function and its wrapper.
47fn no_output_transform(_attr: TokenStream, func: ItemFn) -> TokenStream2 {
48    // let attr: TokenStream2 = attr.into();
49    let mut fn_sig = func.sig;
50    let fn_block = func.block;
51    let fn_attrs = func.attrs;
52    let fn_vis = func.vis;
53
54    let mut new_sig = fn_sig.clone();
55    // change return type of original function
56    let ori_output = new_sig.output.clone();
57    let output_type = match &ori_output {
58        ReturnType::Type(_, ty) => quote! { Option<#ty> },
59        _ => quote! { () },
60    };
61    fn_sig.output = parse_quote! { -> #output_type};
62
63    let ori_func_name = format_ident!("{}_to", fn_sig.ident);
64    fn_sig.ident = ori_func_name.clone();
65    // remove out parameter from new function
66    new_sig.inputs = new_sig
67        .inputs
68        .into_iter()
69        .filter(|arg| {
70            if let FnArg::Typed(pat_type) = arg {
71                if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
72                    return pat_ident.ident != "out";
73                }
74            }
75            true
76        })
77        .collect();
78    let params = parse_params(&new_sig);
79
80    // Filter out #[inline] attribute from fn_attrs
81    let filtered_fn_attrs: Vec<_> = fn_attrs
82        .iter()
83        .filter(|attr| !attr.path().is_ident("inline"))
84        .collect();
85
86    quote! {
87        #(#fn_attrs)*
88        #fn_vis #fn_sig #fn_block
89
90        #[inline]
91        #(#filtered_fn_attrs)*
92        #fn_vis #new_sig {
93            self.#ori_func_name(#(#params,)* None).unwrap()
94        }
95    }
96}
97
98/// Procedural macro to transform a function to handle optional output.
99///
100/// This macro modifies the function to return an `Option` of its original return type
101/// and creates a wrapper function that calls the modified function and unwraps the result.
102#[proc_macro_attribute]
103pub fn no_out(attr: TokenStream, input: TokenStream) -> TokenStream {
104    let input_fn = parse_macro_input!(input as ItemFn);
105    let out = no_output_transform(attr, input_fn);
106    TokenStream::from(out)
107}
108
109/// Procedural macro to derive the `GetDtype` trait for enums.
110///
111/// This macro generates an implementation of the `GetDtype` trait for an enum,
112/// providing a `dtype` method that returns the appropriate `DataType` for each variant.
113#[proc_macro_derive(GetDtype)]
114pub fn derive_get_data_type(input: TokenStream) -> TokenStream {
115    let input = parse_macro_input!(input as DeriveInput);
116    let name = input.ident;
117    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
118
119    let data_type_impls = if let Data::Enum(data_enum) = input.data {
120        data_enum.variants.into_iter().map(|variant| {
121            let ident = &variant.ident;
122            match ident.to_string().as_str() {
123                "DateTimeS" => quote! {Self::#ident(_) => DataType::DateTime(TimeUnit::Second),},
124                "DateTimeMs" => {
125                    quote! {Self::#ident(_) => DataType::DateTime(TimeUnit::Millisecond),}
126                },
127                "DateTimeUs" => {
128                    quote! {Self::#ident(_) => DataType::DateTime(TimeUnit::Microsecond),}
129                },
130                "DateTimeNs" => {
131                    quote! {Self::#ident(_) => DataType::DateTime(TimeUnit::Nanosecond),}
132                },
133                _ => quote! { Self::#ident(_) => DataType::#ident,},
134            }
135        })
136    } else {
137        panic!("GetDtype can only be derived for enums");
138    };
139
140    let expanded = quote! {
141        impl #impl_generics GetDtype for #name #ty_generics #where_clause {
142            fn dtype(&self) -> DataType
143            {
144                match self {
145                    #(#data_type_impls)*
146                }
147            }
148        }
149    };
150
151    TokenStream::from(expanded)
152}