forestrie 0.2.0

Quick string matches
Documentation
use std::convert::identity;

use forestrie_builder::{Builder, IgnoreAsciiCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
	Error, Expr, Lit, Token,
	parse::{Parse, ParseStream},
	parse_macro_input,
	punctuated::Punctuated,
	spanned::Spanned,
};

/// Whole macro input: `scrutinee, pattern => value, pattern => value, ...`.
struct Match {
	/// The expression being matched; it is later passed to
	/// `AsRef::<[u8]>::as_ref` by the expansion.
	key: Expr,
	/// Comma-separated arms (a trailing comma is accepted).
	arms: Punctuated<Arm, Token![,]>,
}

impl Parse for Match {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		// Scrutinee first, then the mandatory separating comma.
		let key: Expr = input.parse()?;
		input.parse::<Token![,]>()?;

		// Everything that remains is the arm list.
		Ok(Self {
			key,
			arms: input.parse_terminated(Arm::parse, Token![,])?,
		})
	}
}

/// A single `pattern => value` arm.
struct Arm {
	/// The left-hand side: a literal pattern or `_`.
	key: Key,
	/// The right-hand side expression associated with the pattern.
	val: Expr,
}

impl Parse for Arm {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		let key: Key = input.parse()?;
		// The `=>` token carries no information; it only has to be present.
		input.parse::<Token![=>]>()?;
		let val: Expr = input.parse()?;
		Ok(Arm { key, val })
	}
}

struct Key {
	span: Span,
	kind: KeyKind,
}

enum KeyKind {
	Bytes(Vec<u8>),
	Fallback,
}

impl Parse for Key {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		if let Ok(fallback) = input.parse::<Token![_]>() {
			Ok(Self {
				span: fallback.span(),
				kind: KeyKind::Fallback,
			})
		} else {
			let lit: Lit = input.parse()?;
			let span = lit.span();

			let bytes = match lit {
				Lit::Str(s) => s.value().into_bytes(),
				Lit::ByteStr(s) => s.value(),
				Lit::CStr(s) => s.value().into_bytes_with_nul(),
				_ => {
					return Err(Error::new(
						span,
						"currently only string literals, byte string literals and C-string literals are supported",
					));
				}
			};

			Ok(Self {
				span,
				kind: KeyKind::Bytes(bytes),
			})
		}
	}
}

fn match_impl<K: forestrie_builder::Key>(
	data: Match,
	map_key: fn(u8) -> K,
) -> syn::Result<TokenStream> {
	if data.arms.is_empty() {
		return Err(Error::new(Span::call_site(), "no match arms"));
	}

	let mut builder = Builder::new();
	let mut fallback = None;
	for arm in data.arms {
		if let Some(_fallback) = fallback {
			return Err(Error::new(arm.key.span, "unreachable match arm"));
		};

		match arm.key.kind {
			KeyKind::Bytes(b) => {
				if builder
					.insert(b.into_iter().map(map_key), arm.val)
					.is_some()
				{
					return Err(Error::new(arm.key.span, "duplicate match pattern"));
				}
			}
			KeyKind::Fallback => fallback = Some(arm.val),
		};
	}

	let fallback =
		fallback.ok_or_else(|| Error::new(Span::call_site(), "missing fallback case `_`"))?;

	let key = data.key;
	Ok(builder.build(
		quote! { ::std::convert::AsRef::<[u8]>::as_ref(#key) },
		fallback,
	))
}

#[proc_macro]
pub fn match_exact(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let data: Match = parse_macro_input!(input);
	match match_impl(data, identity) {
		Ok(tokens) => tokens.into(),
		Err(e) => e.to_compile_error().into(),
	}
}

#[proc_macro]
pub fn match_ignore_ascii_case(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let data: Match = parse_macro_input!(input);
	match match_impl(data, IgnoreAsciiCase::new) {
		Ok(tokens) => tokens.into(),
		Err(e) => e.to_compile_error().into(),
	}
}