use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{Ident, LitInt, LitStr, Token, parse_macro_input};
/// Number of arguments a generated SQL-function macro accepts.
///
/// Parsed from the identifier `Any`, a single integer literal, or several integer
/// literals separated by `|`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArgCount {
    /// Exactly this many arguments.
    Exact(usize),
    /// Any number of comma-separated arguments, including none.
    Any,
    /// Any one of the listed argument counts.
    Multiple(Vec<usize>),
}
impl Parse for ArgCount {
    /// Parses either the identifier `Any` or one or more integer literals separated by `|`.
    ///
    /// A single distinct count yields [`ArgCount::Exact`]; several distinct counts yield
    /// [`ArgCount::Multiple`] in the order first written. Repeated counts are dropped so each
    /// count produces exactly one arm in the generated macro.
    ///
    /// # Errors
    /// Returns an error when an identifier other than `Any` is given, when a literal does not
    /// fit in `usize`, when a `|` is not followed by an integer literal, or when the input
    /// starts with neither an identifier nor an integer literal.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(syn::Ident) {
            let ident: Ident = input.parse()?;
            if ident == "Any" {
                Ok(ArgCount::Any)
            } else {
                Err(syn::Error::new(ident.span(), "Expected 'Any' or a number"))
            }
        } else if lookahead.peek(LitInt) {
            let mut values: Vec<usize> = Vec::new();
            loop {
                let lit: LitInt = input.parse()?;
                let value = lit.base10_parse::<usize>()?;
                // A repeated count would otherwise emit a duplicate, unreachable macro arm
                // and an error message like "expects 2 or 2 argument(s)".
                if !values.contains(&value) {
                    values.push(value);
                }
                if !input.peek(Token![|]) {
                    break;
                }
                input.parse::<Token![|]>()?;
            }
            // At least one literal was parsed above, so `values` is never empty here.
            if values.len() == 1 {
                Ok(ArgCount::Exact(values[0]))
            } else {
                Ok(ArgCount::Multiple(values))
            }
        } else {
            Err(lookahead.error())
        }
    }
}
/// Parsed macro input of the form `StructName; "sql name"; arg_count`.
struct SqlFunctionInput {
    /// Identifier whose lowercased form becomes the generated macro's name.
    struct_name: Ident,
    /// Contents of the SQL-name string literal; the generated macro expands to it.
    sql_name: String,
    /// Argument counts the generated macro accepts.
    arg_count: ArgCount,
}
impl Parse for SqlFunctionInput {
    /// Reads `StructName; "sql name"; arg_count`, optionally followed by a single `;`.
    ///
    /// If any tokens remain after the argument count, the next one must be `;`.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let struct_name = input.parse::<Ident>()?;
        let _: Token![;] = input.parse()?;
        let sql_name = input.parse::<LitStr>()?.value();
        let _: Token![;] = input.parse()?;
        let arg_count = input.parse::<ArgCount>()?;

        // Tolerate one trailing semicolon after the argument count.
        let has_trailing_tokens = !input.is_empty();
        if has_trailing_tokens {
            let _: Token![;] = input.parse()?;
        }

        Ok(Self {
            struct_name,
            sql_name,
            arg_count,
        })
    }
}
pub fn custom_sql_function_impl(input: TokenStream) -> TokenStream {
let SqlFunctionInput {
struct_name,
sql_name,
arg_count,
} = parse_macro_input!(input as SqlFunctionInput);
let function_name = struct_name.to_string().to_lowercase();
let macro_name = Ident::new(&function_name, struct_name.span());
let sql_name_lit = LitStr::new(&sql_name, struct_name.span());
let arg_error_msg = match &arg_count {
ArgCount::Exact(count) => format!("{} expects exactly {} argument(s)", sql_name, count),
ArgCount::Multiple(values) => format!(
"{} expects {} argument(s)",
sql_name,
values
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(" or ")
),
ArgCount::Any => String::new(),
};
let build_arm_for_count = |count: usize| {
if count == 0 {
quote! {
() => ( #sql_name_lit );
}
} else {
let arg_idents: Vec<Ident> = (0..count)
.map(|idx| Ident::new(&format!("arg{}", idx), struct_name.span()))
.collect();
let arg_patterns = arg_idents.iter().map(|ident| quote! { $ #ident : expr });
quote! {
(#(#arg_patterns),* $(,)?) => ( #sql_name_lit );
}
}
};
let macro_arms = match &arg_count {
ArgCount::Exact(count) => {
let arm = build_arm_for_count(*count);
quote! {
#arm
($($args:expr),* $(,)?) => ( compile_error!(#arg_error_msg) );
}
}
ArgCount::Multiple(values) => {
let arms = values.iter().map(|count| build_arm_for_count(*count));
quote! {
#(#arms)*
($($args:expr),* $(,)?) => ( compile_error!(#arg_error_msg) );
}
}
ArgCount::Any => {
quote! {
($($args:expr),* $(,)?) => ( #sql_name_lit );
}
}
};
let output = quote! {
#[macro_export]
macro_rules! #macro_name {
#macro_arms
}
};
output.into()
}