use std::convert::identity;
use forestrie_builder::{Builder, IgnoreAsciiCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
Error, Expr, Lit, Token, braced,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
};
/// Whole macro input: `match <scrutinee> { <arm>, <arm>, ... }`.
struct Match {
    /// The scrutinee expression whose bytes are compared against the patterns.
    key: Expr,
    /// The comma-separated arms (a trailing comma is accepted).
    arms: Punctuated<Arm, Token![,]>,
}

impl Parse for Match {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        input.parse::<Token![match]>()?;
        // Struct-literal expressions are not allowed here, so the `{` that opens the
        // arm list is not swallowed as part of the scrutinee.
        let scrutinee = Expr::parse_without_eager_brace(input)?;
        let body;
        braced!(body in input);
        Ok(Self {
            key: scrutinee,
            arms: Punctuated::parse_terminated(&body)?,
        })
    }
}
/// A single arm: `<pattern> => <value>`.
struct Arm {
    /// The `|`-separated keys that select this arm.
    pattern: Pattern,
    /// Expression produced when one of the keys matches.
    val: Expr,
}

impl Parse for Arm {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let pattern: Pattern = input.parse()?;
        input.parse::<Token![=>]>()?;
        Ok(Self {
            pattern,
            val: input.parse()?,
        })
    }
}
/// A non-empty list of keys separated by `|`, e.g. `"a" | b"b" | _`.
struct Pattern {
    keys: Punctuated<Key, Token![|]>,
}

impl Parse for Pattern {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // At least one key is required; a leading or trailing `|` is rejected.
        Punctuated::parse_separated_nonempty(input).map(|keys| Self { keys })
    }
}
struct Key {
span: Span,
kind: KeyKind,
}
enum KeyKind {
Bytes(Vec<u8>),
Fallback,
}
impl Parse for Key {
fn parse(input: ParseStream) -> syn::Result<Self> {
if let Ok(fallback) = input.parse::<Token![_]>() {
Ok(Self {
span: fallback.span(),
kind: KeyKind::Fallback,
})
} else {
let lit: Lit = input.parse()?;
let span = lit.span();
let bytes = match lit {
Lit::Str(s) => s.value().into_bytes(),
Lit::ByteStr(s) => s.value(),
Lit::CStr(s) => s.value().into_bytes_with_nul(),
_ => {
return Err(Error::new(
span,
"currently only string, byte string and C-string literals are supported",
));
}
};
Ok(Self {
span,
kind: KeyKind::Bytes(bytes),
})
}
}
}
/// Shared implementation of the `exact!` and `ignore_ascii_case!` macros.
///
/// Each byte of every literal pattern is mapped through `map_key`. The resulting key
/// sequence is inserted into a [`Builder`], with the arm's expression as the value.
/// The builder then generates the code, using the `_` arm's expression as the
/// fallback.
///
/// The scrutinee is emitted as `::core::convert::AsRef::<[u8]>::as_ref(<key>)`.
/// `as_ref` takes `&self`, so the scrutinee must be a *reference* to a type that
/// implements `AsRef<[u8]>`, for example `&str` or `&[u8]`.
///
/// # Errors
///
/// Returns a compile error when:
/// - the match has no arms (`no match arms`);
/// - any pattern appears after the `_` arm (`unreachable match arm`);
/// - two patterns produce the same mapped key sequence (`duplicate match pattern`);
/// - there is no `_` arm (``missing fallback case `_` ``).
fn match_impl<K: forestrie_builder::Key>(
    data: Match,
    map_key: fn(u8) -> K,
) -> syn::Result<TokenStream> {
    if data.arms.is_empty() {
        return Err(Error::new(Span::call_site(), "no match arms"));
    }
    let mut builder = Builder::new();
    let mut fallback = None;
    for arm in &data.arms {
        for key in &arm.pattern.keys {
            // `_` matches everything, so any key after it can never be selected.
            if fallback.is_some() {
                return Err(Error::new(key.span, "unreachable match arm"));
            }
            match &key.kind {
                KeyKind::Bytes(bytes) => {
                    // `insert` returning `Some` means this mapped key sequence was
                    // already present.
                    if builder
                        .insert(bytes.iter().copied().map(map_key), &arm.val)
                        .is_some()
                    {
                        return Err(Error::new(key.span, "duplicate match pattern"));
                    }
                }
                KeyKind::Fallback => fallback = Some(&arm.val),
            }
        }
    }
    let fallback =
        fallback.ok_or_else(|| Error::new(Span::call_site(), "missing fallback case `_`"))?;
    let key = data.key;
    Ok(builder.build(
        quote! { ::core::convert::AsRef::<[u8]>::as_ref(#key) },
        fallback,
    ))
}
#[proc_macro]
pub fn exact(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let data: Match = parse_macro_input!(input);
match match_impl(data, identity) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro]
pub fn ignore_ascii_case(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let data: Match = parse_macro_input!(input);
match match_impl(data, IgnoreAsciiCase::new) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}