use std::convert::identity;
use forestrie_builder::{Builder, IgnoreAsciiCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
Error, Expr, Lit, Token,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
};
/// Parsed input shared by the matching macros: a key expression followed by a
/// comma-separated list of arms.
struct Match {
    /// Expression being matched; `match_impl` splices it into an
    /// `AsRef::<[u8]>::as_ref(..)` call.
    key: Expr,
    /// The `pattern => value` arms, in the order they were written.
    arms: Punctuated<Arm, Token![,]>,
}
impl Parse for Match {
fn parse(input: ParseStream) -> syn::Result<Self> {
let key = input.parse()?;
let _: Token![,] = input.parse()?;
let arms = input.parse_terminated(Arm::parse, Token![,])?;
Ok(Self { key, arms })
}
}
/// A single `pattern => value` arm of the macro input.
struct Arm {
    /// Pattern on the left of `=>`: a literal key or the `_` fallback.
    key: Key,
    /// Expression on the right of `=>`.
    val: Expr,
}
impl Parse for Arm {
fn parse(input: ParseStream) -> syn::Result<Self> {
let key = input.parse()?;
let _: Token![=>] = input.parse()?;
let val = input.parse()?;
Ok(Self { key, val })
}
}
/// The pattern side of an arm, together with where it appeared in the source.
struct Key {
    /// Span of the pattern token; `match_impl` attaches arm-specific errors to it.
    span: Span,
    /// What the pattern was: literal bytes or the `_` fallback.
    kind: KeyKind,
}
/// The two kinds of pattern an arm may use.
enum KeyKind {
    /// Bytes taken from a string, byte-string or C-string literal
    /// (C-string bytes include the trailing NUL).
    Bytes(Vec<u8>),
    /// The `_` catch-all pattern.
    Fallback,
}
impl Parse for Key {
fn parse(input: ParseStream) -> syn::Result<Self> {
if let Ok(fallback) = input.parse::<Token![_]>() {
Ok(Self {
span: fallback.span(),
kind: KeyKind::Fallback,
})
} else {
let lit: Lit = input.parse()?;
let span = lit.span();
let bytes = match lit {
Lit::Str(s) => s.value().into_bytes(),
Lit::ByteStr(s) => s.value(),
Lit::CStr(s) => s.value().into_bytes_with_nul(),
_ => {
return Err(Error::new(
span,
"currently only string literals, byte string literals and C-string literals are supported",
));
}
};
Ok(Self {
span,
kind: KeyKind::Bytes(bytes),
})
}
}
}
fn match_impl<K: forestrie_builder::Key>(
data: Match,
map_key: fn(u8) -> K,
) -> syn::Result<TokenStream> {
if data.arms.is_empty() {
return Err(Error::new(Span::call_site(), "no match arms"));
}
let mut builder = Builder::new();
let mut fallback = None;
for arm in data.arms {
if let Some(_fallback) = fallback {
return Err(Error::new(arm.key.span, "unreachable match arm"));
};
match arm.key.kind {
KeyKind::Bytes(b) => {
if builder
.insert(b.into_iter().map(map_key), arm.val)
.is_some()
{
return Err(Error::new(arm.key.span, "duplicate match pattern"));
}
}
KeyKind::Fallback => fallback = Some(arm.val),
};
}
let fallback =
fallback.ok_or_else(|| Error::new(Span::call_site(), "missing fallback case `_`"))?;
let key = data.key;
Ok(builder.build(
quote! { ::std::convert::AsRef::<[u8]>::as_ref(#key) },
fallback,
))
}
#[proc_macro]
pub fn match_exact(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let data: Match = parse_macro_input!(input);
match match_impl(data, identity) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro]
pub fn match_ignore_ascii_case(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let data: Match = parse_macro_input!(input);
match match_impl(data, IgnoreAsciiCase::new) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}