use core::{fmt, fmt::Display, mem, str::FromStr};
use std::{
borrow::Cow,
collections::{btree_map::Entry, BTreeMap},
rc::Rc,
sync::Arc,
};
use proc_macro2::{Span, TokenStream};
use syn::{
ext::IdentExt as _,
parse::{Parse, ParseStream, Parser},
Attribute, Ident, LitBool, LitStr, Token,
};
/// A registry of named argument handlers.
///
/// It parses a comma-separated list of `key <handler-specific syntax>` entries,
/// dispatching each key to the handler registered with [`Attrs::once`],
/// [`Attrs::many`] or [`Attrs::alias`]. Keys with no registration go to the
/// optional [`Attrs::fallback`] handler.
#[derive(Default)]
pub struct Attrs<'a> {
    /// Registered keys and the action taken when each is encountered.
    map: BTreeMap<Ident, Attr<'a>>,
    /// Handler for keys that are not present in `map`, if one was set.
    #[expect(clippy::type_complexity)]
    fallback: Option<Box<dyn 'a + FnMut(&Ident, ParseStream<'_>) -> syn::Result<()>>>,
}
impl fmt::Debug for Attrs<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The fallback closure is not `Debug`, so only report whether one is set.
        let fallback_summary = if self.fallback.is_some() {
            "Some(..)"
        } else {
            "None"
        };
        let mut builder = f.debug_struct("Attrs");
        builder.field("map", &self.map);
        builder.field("fallback", &fallback_summary);
        builder.finish()
    }
}
impl<'a> Attrs<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn contains<Q>(&self, key: &Q) -> bool
where
Q: ?Sized,
Ident: PartialEq<Q>,
{
self.map.keys().any(|it| it == key)
}
#[track_caller]
pub fn once<K, F>(&mut self, key: K, f: F) -> &mut Self
where
K: UnwrapIdent,
F: 'a + FnOnce(ParseStream<'_>) -> syn::Result<()>,
{
self.insert(key, Attr::Once(Once::Some(Box::new(f))))
}
#[track_caller]
pub fn many<K, F>(&mut self, key: K, f: F) -> &mut Self
where
K: UnwrapIdent,
F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
{
self.insert(key, Attr::Many(Box::new(f)))
}
pub fn fallback<F>(&mut self, f: F) -> &mut Self
where
F: 'a + FnMut(&Ident, ParseStream<'_>) -> syn::Result<()>,
{
self.fallback = Some(Box::new(f));
self
}
#[track_caller]
pub fn alias<A, K>(&mut self, alias: A, key: K) -> &mut Self
where
A: UnwrapIdent,
K: UnwrapIdent,
{
let key = key.unwrap_ident();
assert!(
self.contains(&key),
"`{key}` is not registered (aliases may only be registered after their destination)"
);
self.insert(alias, Attr::AliasFor(key.unwrap_ident()))
}
pub fn parse_attrs<Q>(&mut self, path: &Q, attrs: &[Attribute]) -> syn::Result<()>
where
Q: ?Sized,
Ident: PartialEq<Q>,
{
for attr in attrs {
if attr.path().is_ident(path) {
attr.parse_args_with(&mut *self)?
}
}
Ok(())
}
pub fn extract_from<Q>(&mut self, path: &Q, attrs: &mut Vec<Attribute>) -> syn::Result<()>
where
Q: ?Sized,
Ident: PartialEq<Q>,
{
let mut e = None;
attrs.retain(|attr| match attr.path().is_ident(path) {
true => {
match (e.as_mut(), attr.parse_args_with(&mut *self)) {
(_, Ok(())) => {}
(None, Err(e2)) => e = Some(e2),
(Some(e1), Err(e2)) => e1.combine(e2),
}
false }
false => true, });
e.map(Err).unwrap_or(Ok(()))
}
fn _parse(&mut self, input: ParseStream<'_>) -> syn::Result<()> {
let msg = Phrase {
many: "Expected one of",
one: "Expected",
none: match &self.fallback {
Some(_) => "No explicit arguments specified",
None => "No arguments accepted",
},
conjunction: "or",
iter: self
.map
.iter()
.filter_map(|(k, v)| match v {
Attr::AliasFor(_) => None,
Attr::Once(_) | Attr::Many(_) => Some(k.clone()),
})
.collect::<Vec<_>>(),
};
loop {
if input.is_empty() {
break;
}
match input.call(Ident::parse_any) {
Ok(it) => {
let mut key = it.unraw();
loop {
break match (self.map.get_mut(&key), &mut self.fallback) {
(Some(attr), _) => match attr {
Attr::AliasFor(redirect) => {
key = redirect.clone();
continue;
}
Attr::Once(once) => {
match mem::replace(once, Once::Already(it.span())) {
Once::Some(f) => f(input)?,
Once::Already(already) => {
let mut e =
syn::Error::new(it.span(), "Duplicate argument");
e.combine(syn::Error::new(
already,
"Already used here",
));
return Err(e);
}
}
}
Attr::Many(f) => f(input)?,
},
(None, Some(fallback)) => match fallback(&key, input) {
Ok(()) => {}
Err(mut e) => {
e.combine(syn::Error::new(e.span(), msg));
return Err(e);
}
},
(None, None) => return Err(syn::Error::new(it.span(), msg)),
};
}
}
Err(mut e) => {
e.combine(syn::Error::new(e.span(), msg));
return Err(e);
}
}
if input.is_empty() {
break;
}
input.parse::<Token![,]>()?;
}
Ok(())
}
#[track_caller]
fn insert(&mut self, key: impl UnwrapIdent, val: Attr<'a>) -> &mut Self {
match self.map.entry(key.unwrap_ident()) {
Entry::Vacant(it) => it.insert(val),
Entry::Occupied(it) => panic!("duplicate entry for key `{}`", it.key()),
};
self
}
fn into_parser(mut self) -> impl FnMut(ParseStream<'_>) -> syn::Result<()> + use<'a> {
move |input| self._parse(input)
}
fn as_parser(&mut self) -> impl FnMut(ParseStream<'_>) -> syn::Result<()> + use<'_, 'a> {
|input| self._parse(input)
}
}
/// The action [`Attrs`] takes when it encounters a registered key.
enum Attr<'a> {
    /// This key is another name for the contained key; lookup is redirected there.
    AliasFor(Ident),
    /// The key may appear at most once.
    Once(Once<'a>),
    /// The key may appear any number of times; the handler runs on each occurrence.
    Many(Box<dyn 'a + FnMut(ParseStream<'_>) -> syn::Result<()>>),
}
impl fmt::Debug for Attr<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Handler closures are opaque, so they are elided.
            Self::Many(_) => f.debug_tuple("Many").finish_non_exhaustive(),
            Self::Once(state) => f.debug_tuple("Once").field(state).finish(),
            Self::AliasFor(target) => f.debug_tuple("AliasFor").field(target).finish(),
        }
    }
}
/// State of a key registered with [`Attrs::once`].
enum Once<'a> {
    /// Not yet seen; holds the handler to run on the first occurrence.
    Some(Box<dyn 'a + FnOnce(ParseStream<'_>) -> syn::Result<()>>),
    /// Already seen at this span; a further occurrence is an error.
    Already(Span),
}
impl fmt::Debug for Once<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Already(first_use) => f.debug_tuple("Already").field(first_use).finish(),
            Self::Some(_) => f.debug_tuple("Some").finish_non_exhaustive(),
        }
    }
}
impl Parser for &mut Attrs<'_> {
type Output = ();
fn parse2(self, tokens: TokenStream) -> syn::Result<Self::Output> {
self.as_parser().parse2(tokens)
}
}
impl Parser for Attrs<'_> {
type Output = ();
fn parse2(self, tokens: TokenStream) -> syn::Result<Self::Output> {
self.into_parser().parse2(tokens)
}
}
/// End-to-end check that `Attrs` dispatches each key to its handler, covering
/// `=`-values, `=`/paren values, bare flags, parenthesized values and a
/// repeatable key.
#[test]
fn test() {
    use quote::quote;
    use syn::{punctuated::Punctuated, *};
    strum_lite::strum! {
        #[derive(PartialEq, Debug)]
        enum Casing {
            Pascal = "PascalCase",
            Snake = "snake_case",
        }
    }
    let mut casing = Casing::Snake;
    let mut vis = Visibility::Inherited;
    let mut opt_pred = None::<WherePredicate>;
    let mut use_unsafe = false;
    let mut aliases = vec![];
    Attrs::new()
        .once("rename_all", with::eq(on::from_str(&mut casing)))
        .once("vis", with::peq(on::parse(&mut vis)))
        // `flag::free` is the non-deprecated spelling of `set::flag`.
        .once("use_unsafe", flag::free(&mut use_unsafe))
        .once("where", with::paren(set::parse(&mut opt_pred)))
        .many(
            "alias",
            with::paren(|input| {
                aliases.extend(Punctuated::<LitStr, Token![,]>::parse_separated_nonempty(
                    input,
                )?);
                Ok(())
            }),
        )
        .parse2(quote! {
            rename_all = "PascalCase",
            vis = pub,
            use_unsafe,
            where(T: Ord),
            alias("hello", "world"),
            alias("goodbye")
        })
        .unwrap();
    assert_eq!(casing, Casing::Pascal);
    assert!(matches!(vis, Visibility::Public(_)));
    assert!(opt_pred.is_some());
    assert!(use_unsafe);
    // `alias` is a `many` key: both occurrences contribute.
    assert_eq!(aliases.len(), 3);
}
/// Conversion into an owned [`Ident`], used by [`Attrs`] to name its keys.
///
/// The string-based implementations go through [`Ident::new`] with
/// [`Span::call_site`]. They may panic if the string is not a valid
/// identifier, hence the `unwrap_` name and `#[track_caller]`.
pub trait UnwrapIdent {
    /// Produces an owned [`Ident`] for this value.
    #[track_caller]
    fn unwrap_ident(&self) -> Ident;
}
impl UnwrapIdent for str {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        let span = Span::call_site();
        Ident::new(self, span)
    }
}
impl UnwrapIdent for String {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        self.as_str().unwrap_ident()
    }
}
impl UnwrapIdent for Cow<'_, str> {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        (**self).unwrap_ident()
    }
}
impl UnwrapIdent for Ident {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        Ident::clone(self)
    }
}
impl<T: UnwrapIdent + ?Sized> UnwrapIdent for &T {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        T::unwrap_ident(*self)
    }
}
impl<T: UnwrapIdent + ?Sized> UnwrapIdent for Box<T> {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        T::unwrap_ident(&**self)
    }
}
impl<T: UnwrapIdent + ?Sized> UnwrapIdent for Rc<T> {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        T::unwrap_ident(&**self)
    }
}
impl<T: UnwrapIdent + ?Sized> UnwrapIdent for Arc<T> {
    #[track_caller]
    fn unwrap_ident(&self) -> Ident {
        T::unwrap_ident(&**self)
    }
}
/// Wrappers that consume some leading syntax (an `=` or a delimited group)
/// before running an inner handler on what follows, or on the group's contents.
pub mod with {
    use syn::{
        braced, bracketed, parenthesized,
        parse::{discouraged::AnyDelimiter, ParseStream},
        token, Token,
    };
    /// `= <f>`: requires an `=` token, then runs `f` on the remaining input.
    pub fn eq<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            let _eq_token: Token![=] = input.parse()?;
            f(input)
        }
    }
    /// `( <f> )`: requires a parenthesized group and runs `f` on its contents.
    pub fn paren<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            let inner;
            parenthesized!(inner in input);
            f(&inner)
        }
    }
    /// `[ <f> ]`: requires a bracketed group and runs `f` on its contents.
    pub fn bracket<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            let inner;
            bracketed!(inner in input);
            f(&inner)
        }
    }
    /// `{ <f> }`: requires a braced group and runs `f` on its contents.
    pub fn brace<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            let inner;
            braced!(inner in input);
            f(&inner)
        }
    }
    /// Any group accepted by `parse_any_delimiter`, with `f` run on its contents.
    pub fn delim<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            let (_delimiter, _span, inner) = input.parse_any_delimiter()?;
            f(&inner)
        }
    }
    /// `= <f>` or `( <f> )`: accepts either form and errors if neither is present.
    pub fn peq<'a, F>(mut f: F) -> impl 'a + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        F: 'a + FnMut(ParseStream<'_>) -> syn::Result<()>,
    {
        move |input| {
            // The next token cannot be both a `(..)` group and `=`, so check order is irrelevant.
            if input.peek(token::Paren) {
                let inner;
                parenthesized!(inner in input);
                return f(&inner);
            }
            if input.peek(Token![=]) {
                let _eq_token: Token![=] = input.parse()?;
                return f(input);
            }
            Err(input.error("Expected a `=` or `(..)`"))
        }
    }
}
/// Handlers that store the parsed value into an `Option` as `Some(..)`.
///
/// Each returned closure overwrites `dst` every time it runs.
pub mod set {
    use super::*;
    /// Old name for [`flag::free`].
    #[deprecated = "use `flag::free` instead"]
    pub use flag::free as flag;
    /// Parses a `bool` literal into `dst`.
    #[deprecated = "Use `set::lit` instead"]
    pub fn bool(dst: &mut Option<bool>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |input| parse::set::lit(dst, input)
    }
    /// Parses `T` through its [`Parse`] impl into `dst`.
    pub fn parse<T>(dst: &mut Option<T>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::set::parse(dst, input)
    }
    /// Parses a string literal and converts its contents with [`FromStr`].
    pub fn from_str<T>(dst: &mut Option<T>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: FromStr,
        T::Err: Display,
    {
        |input| parse::set::from_str(dst, input)
    }
    /// Parses a string literal, then parses its contents as `T`.
    pub fn parse_str<T>(dst: &mut Option<T>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::set::parse_str(dst, input)
    }
    /// Parses `T` either directly or from inside a string literal.
    pub fn maybe_str<T>(dst: &mut Option<T>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::set::maybe_str(dst, input)
    }
    /// Parses a single literal of type `T` (see [`Lit`]).
    pub fn lit<T>(dst: &mut Option<T>) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Lit,
    {
        |input| parse::set::lit(dst, input)
    }
}
/// Handlers that overwrite `dst` in place with the parsed value.
pub mod on {
    use super::*;
    /// Parses a `bool` literal into `dst`.
    #[deprecated = "Use `on::lit` instead"]
    pub fn bool(dst: &mut bool) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |input| parse::lit(dst, input)
    }
    /// Parses `T` through its [`Parse`] impl into `dst`.
    pub fn parse<T>(dst: &mut T) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::parse(dst, input)
    }
    /// Parses a string literal and converts its contents with [`FromStr`].
    pub fn from_str<T>(dst: &mut T) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: FromStr,
        T::Err: Display,
    {
        |input| parse::from_str(dst, input)
    }
    /// Parses a string literal, then parses its contents as `T`.
    pub fn parse_str<T>(dst: &mut T) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::parse_str(dst, input)
    }
    /// Parses `T` either directly or from inside a string literal.
    pub fn maybe_str<T>(dst: &mut T) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Parse,
    {
        |input| parse::maybe_str(dst, input)
    }
    /// Parses a single literal of type `T` (see [`Lit`]).
    pub fn lit<T>(dst: &mut T) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()>
    where
        T: Lit,
    {
        |input| parse::lit(dst, input)
    }
}
/// Handlers for boolean switches that may be written as a bare key.
pub mod flag {
    use syn::token;
    use super::*;
    /// Bare key: sets `dst` to `true` without consuming any input.
    pub fn free(dst: &mut bool) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |_input| {
            *dst = true;
            Ok(())
        }
    }
    /// Either a bare key (sets `true`) or `key = <bool literal>`.
    pub fn or_eq(dst: &mut bool) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |input| {
            if input.peek(Token![=]) {
                with::eq(on::lit(dst))(input)
            } else {
                free(dst)(input)
            }
        }
    }
    /// Either a bare key (sets `true`) or `key(<bool literal>)`.
    pub fn or_paren(dst: &mut bool) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |input| {
            if input.peek(token::Paren) {
                with::paren(on::lit(dst))(input)
            } else {
                free(dst)(input)
            }
        }
    }
    /// A bare key (sets `true`), `key = <bool literal>`, or `key(<bool literal>)`.
    pub fn or_peq(dst: &mut bool) -> impl '_ + FnMut(ParseStream<'_>) -> syn::Result<()> {
        |input| {
            let has_value = input.peek(Token![=]) || input.peek(token::Paren);
            if has_value {
                with::peq(on::lit(dst))(input)
            } else {
                free(dst)(input)
            }
        }
    }
}
/// Functions that parse a value from `input` and write it into `dst`.
///
/// The [`set`](self::set) submodule holds `Option`-valued variants that store `Some(..)`.
pub mod parse {
    use super::*;
    /// Parses a `bool` literal into `dst`.
    #[deprecated = "Use `parse::lit` instead"]
    pub fn bool(dst: &mut bool, input: ParseStream<'_>) -> syn::Result<()> {
        *dst = input.parse::<LitBool>()?.value;
        Ok(())
    }
    /// Parses `T` through its [`Parse`] impl.
    pub fn parse<T: Parse>(dst: &mut T, input: ParseStream<'_>) -> syn::Result<()> {
        *dst = input.parse()?;
        Ok(())
    }
    /// Parses a string literal and converts its contents with [`FromStr`].
    ///
    /// A conversion error is reported at the string literal's span.
    pub fn from_str<T: FromStr>(dst: &mut T, input: ParseStream<'_>) -> syn::Result<()>
    where
        T::Err: Display,
    {
        let lit_str = input.parse::<LitStr>()?;
        match lit_str.value().parse() {
            Ok(it) => {
                *dst = it;
                Ok(())
            }
            Err(e) => Err(syn::Error::new(lit_str.span(), e)),
        }
    }
    /// Parses a string literal, then parses its entire contents as `T`.
    ///
    /// Spans in the result, and in any error, point at the string literal.
    /// This matches [`set::parse_str`] and [`maybe_str`].
    pub fn parse_str<T: Parse>(dst: &mut T, input: ParseStream<'_>) -> syn::Result<()> {
        // `LitStr::parse` re-spans tokens to the literal. `Parser::parse_str`
        // would use `Span::call_site()` and point errors at the whole invocation.
        *dst = input.parse::<LitStr>()?.parse()?;
        Ok(())
    }
    /// Parses `T` from inside a string literal if one is next, otherwise directly.
    pub fn maybe_str<T: Parse>(dst: &mut T, input: ParseStream<'_>) -> syn::Result<()> {
        *dst = match input.peek(LitStr) {
            true => input.parse::<LitStr>()?.parse()?,
            false => input.parse()?,
        };
        Ok(())
    }
    /// Parses a single literal of type `T` (see [`Lit`]).
    pub fn lit<T: Lit>(dst: &mut T, input: ParseStream<'_>) -> syn::Result<()> {
        *dst = Lit::parse(input)?;
        Ok(())
    }
    /// Collects every token tree up to, but not including, the next top-level `,`.
    ///
    /// Commas nested inside delimited groups belong to their group and do not
    /// stop collection. The terminating comma is left in `input`, so the caller
    /// (e.g. the `Attrs` argument loop) can consume it as a separator.
    pub fn until_comma(input: ParseStream<'_>) -> syn::Result<TokenStream> {
        input.step(|cursor| {
            let mut tokens = TokenStream::new();
            let mut rest = *cursor;
            while let Some((tt, next)) = rest.token_tree() {
                if matches!(&tt, proc_macro2::TokenTree::Punct(p) if p.as_char() == ',') {
                    // Stop *before* the comma; do not advance `rest` past it.
                    break;
                }
                tokens.extend([tt]);
                rest = next;
            }
            Ok((tokens, rest))
        })
    }
    /// `Option`-valued counterparts of the parent module's functions; each stores `Some(..)`.
    pub mod set {
        use super::*;
        /// Parses a `bool` literal into `dst`.
        #[deprecated = "Use `parse::set::lit` instead"]
        pub fn bool(dst: &mut Option<bool>, input: ParseStream<'_>) -> syn::Result<()> {
            *dst = Some(input.parse::<LitBool>()?.value);
            Ok(())
        }
        /// Parses `T` through its [`Parse`] impl.
        pub fn parse<T: Parse>(dst: &mut Option<T>, input: ParseStream<'_>) -> syn::Result<()> {
            *dst = Some(input.parse()?);
            Ok(())
        }
        /// Parses a string literal and converts its contents with [`FromStr`].
        ///
        /// A conversion error is reported at the string literal's span.
        pub fn from_str<T: FromStr>(dst: &mut Option<T>, input: ParseStream<'_>) -> syn::Result<()>
        where
            T::Err: Display,
        {
            let lit_str = input.parse::<LitStr>()?;
            match lit_str.value().parse() {
                Ok(it) => {
                    *dst = Some(it);
                    Ok(())
                }
                Err(e) => Err(syn::Error::new(lit_str.span(), e)),
            }
        }
        /// Parses a string literal, then parses its contents as `T` with spans on the literal.
        pub fn parse_str<T: Parse>(dst: &mut Option<T>, input: ParseStream<'_>) -> syn::Result<()> {
            *dst = Some(input.parse::<LitStr>()?.parse()?);
            Ok(())
        }
        /// Parses `T` from inside a string literal if one is next, otherwise directly.
        pub fn maybe_str<T: Parse>(dst: &mut Option<T>, input: ParseStream<'_>) -> syn::Result<()> {
            *dst = Some(match input.peek(LitStr) {
                true => input.parse::<LitStr>()?.parse()?,
                false => input.parse()?,
            });
            Ok(())
        }
        /// Parses a single literal of type `T` (see [`Lit`]).
        pub fn lit<T: Lit>(dst: &mut Option<T>, input: ParseStream<'_>) -> syn::Result<()> {
            *dst = Some(Lit::parse(input)?);
            Ok(())
        }
    }
}
/// Types that can be parsed from a single Rust literal token.
///
/// The supertrait lives in the private `sealed` module, so new implementations
/// can only be added where that module is nameable.
pub trait Lit: Sized + sealed::Sealed {
    /// Parses one literal of this type from `input`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self>;
}
mod sealed {
    /// Marker supertrait of [`Lit`](super::Lit) that restricts who can implement it.
    pub trait Sealed {}
    /// Implements [`Sealed`] for each listed type.
    macro_rules! sealed {
        ($($target:ty),* $(,)?) => {
            $(
                impl Sealed for $target {}
            )*
        };
    }
    sealed! {
        bool,
        char,
        String,
        Vec<u8>,
        std::ffi::CString,
        f32, f64,
        i8, i16, i32, i64, i128, isize,
        u8, u16, u32, u64, u128, usize,
    }
}
macro_rules! num {
($($via:ty {
$($ty:ty),* $(,)?
} )*) => {
$(
$(
impl Lit for $ty {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let lit = input.parse::<$via>()?;
match lit.suffix() {
"" | stringify!($ty) => lit.base10_parse(),
_ => Err(syn::Error::new(
lit.span(),
concat!("Expected suffix `", stringify!($ty), "`"),
)),
}
}
}
)*
)*
};
}
num! {
syn::LitInt {
u16, u32, u64, u128, usize,
i8, i16, i32, i64, i128, isize,
}
syn::LitFloat {
f32, f64
}
}
impl Lit for u8 {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
match input.parse::<syn::Lit>()? {
syn::Lit::Byte(it) => Ok(it.value()),
syn::Lit::Int(it) => match it.suffix() {
"" | "u8" => it.base10_parse(),
_ => Err(syn::Error::new(it.span(), "Expected suffix `u8`")),
},
other => Err(syn::Error::new(
other.span(),
"Expected a u8 or byte literal",
)),
}
}
}
impl Lit for bool {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(input.parse::<syn::LitBool>()?.value())
}
}
impl Lit for String {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(input.parse::<syn::LitStr>()?.value())
}
}
impl Lit for Vec<u8> {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(input.parse::<syn::LitByteStr>()?.value())
}
}
impl Lit for char {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(input.parse::<syn::LitChar>()?.value())
}
}
impl Lit for std::ffi::CString {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
Ok(input.parse::<syn::LitCStr>()?.value())
}
}
/// A human-readable list such as "Expected one of `a`, `b` or `c`".
///
/// `iter` is cloned on every `Display` call, so it must be re-iterable (e.g. a `Vec`).
#[derive(Clone, Copy)]
struct Phrase<'a, I> {
    /// Prefix used when there are two or more items.
    pub many: &'a str,
    /// Prefix used when there is exactly one item.
    pub one: &'a str,
    /// Entire text used when there are no items.
    pub none: &'a str,
    /// Word placed before the final item, e.g. "or".
    pub conjunction: &'a str,
    /// The items, each rendered inside backticks.
    pub iter: I,
}
impl<I: Clone + IntoIterator> fmt::Display for Phrase<'_, I>
where
    I::Item: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut items = self.iter.clone().into_iter().peekable();
        let Some(head) = items.next() else {
            return f.write_str(self.none);
        };
        if items.peek().is_none() {
            return write!(f, "{} `{head}`", self.one);
        }
        write!(f, "{} `{head}`", self.many)?;
        while let Some(item) = items.next() {
            // The last item is introduced by the conjunction; the others by a comma.
            if items.peek().is_some() {
                write!(f, ", `{item}`")?;
            } else {
                write!(f, " {} `{item}`", self.conjunction)?;
            }
        }
        Ok(())
    }
}