forestrie 0.3.1

Quick string matches
Documentation
use std::convert::identity;

use forestrie_builder::{Builder, IgnoreAsciiCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
	Error, Expr, Lit, Token, braced,
	parse::{Parse, ParseStream},
	parse_macro_input,
	punctuated::Punctuated,
	spanned::Spanned,
};

/// Parsed macro input: `match <key> { <arms> }`.
struct Match {
	/// The scrutinee expression. The generated code passes it to
	/// `::core::convert::AsRef::<[u8]>::as_ref`.
	key: Expr,
	/// The match arms, in source order.
	arms: Punctuated<Arm, Token![,]>,
}

impl Parse for Match {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		let _: Token![match] = input.parse()?;
		let key = Expr::parse_without_eager_brace(input)?;
		let content;
		braced!(content in input);
		let arms = Punctuated::parse_terminated(&content)?;
		Ok(Self { key, arms })
	}
}

/// One `pattern => value` arm of the macro input.
struct Arm {
	/// One or more alternative keys joined by `|`.
	pattern: Pattern,
	/// Result expression for this arm. `match_impl` uses it as the value for every
	/// literal key of the pattern, or as the fallback when the pattern contains `_`.
	val: Expr,
}

impl Parse for Arm {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		let pattern = input.parse()?;
		let _: Token![=>] = input.parse()?;
		let val = input.parse()?;
		Ok(Self { pattern, val })
	}
}

/// The left-hand side of an arm: a non-empty, `|`-separated list of keys.
struct Pattern {
	/// The alternatives, in source order (guaranteed non-empty by the parser).
	keys: Punctuated<Key, Token![|]>,
}

impl Parse for Pattern {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		let keys = Punctuated::parse_separated_nonempty(input)?;
		Ok(Self { keys })
	}
}

/// A single alternative inside a pattern.
struct Key {
	/// Source span of the key, used to point "unreachable" / "duplicate" errors at it.
	span: Span,
	/// What the key matches.
	kind: KeyKind,
}

/// The two kinds of key accepted in a pattern.
enum KeyKind {
	/// A literal key as raw bytes. String literals contribute their UTF-8 bytes, byte
	/// strings their bytes, and C-string literals their bytes including the trailing NUL.
	Bytes(Vec<u8>),
	/// The `_` wildcard.
	Fallback,
}

impl Parse for Key {
	fn parse(input: ParseStream) -> syn::Result<Self> {
		if let Ok(fallback) = input.parse::<Token![_]>() {
			Ok(Self {
				span: fallback.span(),
				kind: KeyKind::Fallback,
			})
		} else {
			let lit: Lit = input.parse()?;
			let span = lit.span();

			let bytes = match lit {
				Lit::Str(s) => s.value().into_bytes(),
				Lit::ByteStr(s) => s.value(),
				Lit::CStr(s) => s.value().into_bytes_with_nul(),
				_ => {
					return Err(Error::new(
						span,
						"currently only string, byte string and C-string literals are supported",
					));
				}
			};

			Ok(Self {
				span,
				kind: KeyKind::Bytes(bytes),
			})
		}
	}
}

/// Lowers a parsed [`Match`] into the tokens produced by `Builder::build`.
///
/// Every byte of every literal key is passed through `map_key` before insertion, so the
/// caller chooses the comparison mode (`identity` for exact, `IgnoreAsciiCase::new` for
/// ASCII case-insensitive).
///
/// The key expression is handed to `::core::convert::AsRef::<[u8]>::as_ref`, so it must
/// evaluate to a reference to an `AsRef<[u8]>` type (for example `&str` or `&[u8]`).
///
/// # Errors
/// - The match has no arms.
/// - Any key appears after the `_` fallback, since it could never be selected.
/// - Two keys collide after mapping (the builder's `insert` reports an existing entry).
/// - No `_` fallback is present.
fn match_impl<K: forestrie_builder::Key>(
	data: Match,
	map_key: fn(u8) -> K,
) -> syn::Result<TokenStream> {
	if data.arms.is_empty() {
		return Err(Error::new(Span::call_site(), "no match arms"));
	}

	let mut builder = Builder::new();
	let mut fallback = None;
	for arm in &data.arms {
		for key in &arm.pattern.keys {
			// Once `_` has been seen, every later key is dead code.
			if fallback.is_some() {
				return Err(Error::new(key.span, "unreachable match arm"));
			}

			match &key.kind {
				KeyKind::Bytes(bytes) => {
					let mapped = bytes.iter().copied().map(map_key);
					if builder.insert(mapped, &arm.val).is_some() {
						return Err(Error::new(key.span, "duplicate match pattern"));
					}
				}
				KeyKind::Fallback => fallback = Some(&arm.val),
			}
		}
	}

	let fallback =
		fallback.ok_or_else(|| Error::new(Span::call_site(), "missing fallback case `_`"))?;

	let key = data.key;
	Ok(builder.build(
		quote! { ::core::convert::AsRef::<[u8]>::as_ref(#key) },
		fallback,
	))
}

/// Case-sensitive match
///
/// ```
/// #[derive(Debug, PartialEq)]
/// enum Command {
/// 	Up,
/// 	Down,
/// 	Forward,
/// 	Rotate,
/// }
///
/// impl Command {
/// 	fn parse(s: &str) -> Option<Command> {
/// 		forestrie::exact! {
/// 			match s {
/// 				"up" => Some(Self::Up),
/// 				"down" => Some(Self::Down),
/// 				"forward" | "go" => Some(Self::Forward),
/// 				"rotate" => Some(Self::Rotate),
/// 				_ => None,
/// 			}
/// 		}
/// 	}
/// }
///
/// assert_eq!(Command::parse("up"), Some(Command::Up));
/// assert_eq!(Command::parse("go"), Some(Command::Forward));
/// assert_eq!(Command::parse("Rotate"), None); // check out `ignore_ascii_case` for this
/// ```
#[proc_macro]
pub fn exact(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let data: Match = parse_macro_input!(input);
	match match_impl(data, identity) {
		Ok(tokens) => tokens.into(),
		Err(e) => e.to_compile_error().into(),
	}
}

/// ASCII case-insensitive match, similar to [`str::eq_ignore_ascii_case`]
///
/// ```
/// #[derive(Debug, PartialEq)]
/// enum Modifier {
/// 	Escape,
/// 	Meta,
/// 	Alt,
/// 	Control,
/// 	Shift,
/// }
///
/// impl Modifier {
/// 	fn parse(s: &str) -> Option<Modifier> {
/// 		forestrie::ignore_ascii_case! {
/// 			match s {
/// 				"esc" => Some(Self::Escape),
/// 				"meta" | "super" | "logo" => Some(Self::Meta),
/// 				"alt" => Some(Self::Alt),
/// 				"ctrl" => Some(Self::Control),
/// 				"shift" => Some(Self::Shift),
/// 				_ => None,
/// 			}
/// 		}
/// 	}
/// }
///
/// assert_eq!(Modifier::parse("esc"), Some(Modifier::Escape));
/// assert_eq!(Modifier::parse("SUPER"), Some(Modifier::Meta));
/// assert_eq!(Modifier::parse("Ctrl"), Some(Modifier::Control));
/// assert_eq!(Modifier::parse("tab"), None);
/// ```
#[proc_macro]
pub fn ignore_ascii_case(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	let data: Match = parse_macro_input!(input);
	match match_impl(data, IgnoreAsciiCase::new) {
		Ok(tokens) => tokens.into(),
		Err(e) => e.to_compile_error().into(),
	}
}