use std::ops::Bound;
use proc_macro2::{Ident, Span, TokenStream};
use quote::ToTokens as _;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Attribute, Error, Expr, Result, Token};
pub use self::input::*;
use crate::util::error_sink::ErrorSink;
use crate::util::interval::Interval;
mod input;
pub mod helper_attrs;
/// Parsing of the comma-separated argument list accepted by the macro.
///
/// Recognized arguments are `start = N`, `end = N`, `enum`, `base` and
/// `rename(SRC => Ident)`, where `SRC` is an integer literal, `enum` or `base`.
/// Unrecognized arguments are captured up to the next top-level `,` and reported as
/// errors, so the remaining arguments are still parsed.
pub mod args {
    use quote::ToTokens;
    use syn::LitInt;
    use syn::spanned::Spanned;

    use super::*;

    /// The `SRC` selector of a `rename(SRC => Ident)` argument.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    pub enum RenameSource {
        /// `rename(N => ...)` with an integer literal `N`.
        Version(usize),
        /// `rename(enum => ...)`.
        Enum,
        /// `rename(base => ...)`.
        Base,
    }

    /// A single validated macro argument.
    #[derive(Debug, Clone)]
    pub enum MacroArg {
        /// `start = N`
        Start(usize),
        /// `end = N`
        End(usize),
        /// `enum`
        Enum,
        /// `base`
        Base,
        /// `rename(SRC => Ident)`
        Rename(RenameSource, Ident),
    }

    /// The `start` / `end` keyword of a `start = N` / `end = N` argument.
    enum StartEnd {
        Start(kw::start),
        End(kw::end),
    }

    impl StartEnd {
        /// Returns the [`MacroArg`] constructor matching this keyword.
        fn variant(self) -> fn(usize) -> MacroArg {
            match self {
                Self::Start(_) => MacroArg::Start,
                Self::End(_) => MacroArg::End,
            }
        }
    }

    /// Unvalidated syntax of the `SRC` part of `rename(SRC => Ident)`.
    enum ParsedRenameSource {
        Version(LitInt),
        Enum(Token![enum]),
        Base(kw::base),
    }

    /// Unvalidated syntax of one macro argument; every token is kept so that
    /// `to_tokens` (and therefore `span()`) covers the whole argument.
    enum ParsedMacroArg {
        StartEnd(StartEnd, Token![=], LitInt),
        Enum(Token![enum]),
        Base(kw::base),
        /// `rename(SRC => Ident)`; the parenthesis token is stored so it can be re-emitted.
        Rename(
            kw::rename,
            syn::token::Paren,
            ParsedRenameSource,
            Token![=>],
            Ident,
        ),
        /// Tokens that did not start a known argument, plus the lookahead error that
        /// lists what was expected instead.
        Invalid(TokenStream, Error),
    }

    impl ParsedMacroArg {
        /// Converts the parsed syntax into a [`MacroArg`].
        ///
        /// # Errors
        /// Fails if an integer literal cannot be parsed as `usize`, or if this is an
        /// `Invalid` argument.
        fn try_into_arg(self) -> Result<MacroArg> {
            Ok(match self {
                ParsedMacroArg::StartEnd(kind, _, lit) => kind.variant()(lit.base10_parse()?),
                ParsedMacroArg::Enum(_) => MacroArg::Enum,
                ParsedMacroArg::Base(_) => MacroArg::Base,
                ParsedMacroArg::Rename(_, _, src, _, ident) => MacroArg::Rename(
                    match src {
                        ParsedRenameSource::Version(lit) => {
                            RenameSource::Version(lit.base10_parse()?)
                        }
                        ParsedRenameSource::Enum(_) => RenameSource::Enum,
                        ParsedRenameSource::Base(_) => RenameSource::Base,
                    },
                    ident,
                ),
                ParsedMacroArg::Invalid(t, e) => {
                    return Err(Error::new_spanned(t, format!("invalid argument; {e}")));
                }
            })
        }
    }

    /// Custom keywords recognized in the argument list.
    mod kw {
        syn::custom_keyword!(start);
        syn::custom_keyword!(end);
        syn::custom_keyword!(base);
        syn::custom_keyword!(rename);
    }

    impl Parse for StartEnd {
        fn parse(input: ParseStream) -> Result<Self> {
            input
                .parse::<kw::start>()
                .map(Self::Start)
                .or_else(|_| input.parse::<kw::end>().map(Self::End))
        }
    }

    impl ToTokens for StartEnd {
        fn to_tokens(&self, tokens: &mut TokenStream) {
            match self {
                StartEnd::Start(t) => t.to_tokens(tokens),
                StartEnd::End(t) => t.to_tokens(tokens),
            }
        }
    }

    impl Parse for ParsedRenameSource {
        fn parse(input: ParseStream) -> Result<Self> {
            let lookahead = input.lookahead1();
            if lookahead.peek(Token![enum]) {
                input.parse().map(Self::Enum)
            } else if lookahead.peek(kw::base) {
                input.parse().map(Self::Base)
            } else if lookahead.peek(LitInt) {
                input.parse().map(Self::Version)
            } else {
                Err(lookahead.error())
            }
        }
    }

    impl ToTokens for ParsedRenameSource {
        fn to_tokens(&self, tokens: &mut TokenStream) {
            match self {
                ParsedRenameSource::Version(lit) => lit.to_tokens(tokens),
                ParsedRenameSource::Enum(kw) => kw.to_tokens(tokens),
                ParsedRenameSource::Base(kw) => kw.to_tokens(tokens),
            }
        }
    }

    impl Parse for ParsedMacroArg {
        fn parse(input: ParseStream) -> Result<Self> {
            let lookahead = input.lookahead1();
            if lookahead.peek(kw::start) || lookahead.peek(kw::end) {
                let kind = input.parse()?;
                Ok(Self::StartEnd(kind, input.parse()?, input.parse()?))
            } else if lookahead.peek(Token![enum]) {
                input.parse().map(Self::Enum)
            } else if lookahead.peek(kw::base) {
                input.parse().map(Self::Base)
            } else if lookahead.peek(kw::rename) {
                let rename_kw = input.parse()?;
                let inner;
                let paren = syn::parenthesized!(inner in input);
                Ok(Self::Rename(
                    rename_kw,
                    paren,
                    inner.parse()?,
                    inner.parse()?,
                    inner.parse()?,
                ))
            } else {
                // Not a recognized argument: consume every token tree up to (but not
                // including) the next `,` so later arguments can still be parsed. Delimited
                // groups are single token trees, so commas nested inside them are skipped.
                let mut ts = TokenStream::new();
                input
                    .step(|cursor| {
                        let mut rest = *cursor;
                        ts.extend(
                            std::iter::from_fn(|| {
                                let (tt, next) = rest.token_tree()?;
                                match &tt {
                                    proc_macro2::TokenTree::Punct(p) if p.as_char() == ',' => {
                                        return None;
                                    }
                                    _ => rest = next,
                                }
                                Some(tt)
                            })
                            .fuse(),
                        );
                        Ok(((), rest))
                    })
                    // The step closure above always returns `Ok`.
                    .unwrap_or_else(|_| unreachable!());
                Ok(Self::Invalid(ts, lookahead.error()))
            }
        }
    }

    impl ToTokens for ParsedMacroArg {
        fn to_tokens(&self, tokens: &mut TokenStream) {
            match self {
                ParsedMacroArg::StartEnd(t1, t2, t3) => {
                    t1.to_tokens(tokens);
                    t2.to_tokens(tokens);
                    t3.to_tokens(tokens);
                }
                ParsedMacroArg::Enum(t) => t.to_tokens(tokens),
                ParsedMacroArg::Base(t) => t.to_tokens(tokens),
                ParsedMacroArg::Rename(rename_kw, paren, src, arrow, ident) => {
                    rename_kw.to_tokens(tokens);
                    // Re-emit the parentheses so the output re-parses as `rename(...)` and
                    // the joined span reaches the closing `)`.
                    paren.surround(tokens, |inner| {
                        src.to_tokens(inner);
                        arrow.to_tokens(inner);
                        ident.to_tokens(inner);
                    });
                }
                ParsedMacroArg::Invalid(ts, _) => ts.to_tokens(tokens),
            }
        }
    }

    /// Parses a comma-separated macro argument list into validated [`MacroArg`]s,
    /// each paired with the span of the tokens it came from.
    ///
    /// If the token stream cannot be parsed as a list at all, the error is handed to
    /// `errs.eat_err` and an empty `Vec` is returned. Otherwise each argument is validated
    /// on its own; a failing argument's error is handed to `errs.eat_err` and that
    /// argument is omitted from the result.
    pub fn parse_macro_args(ts: TokenStream, errs: &mut ErrorSink) -> Vec<(Span, MacroArg)> {
        let parsed = syn::parse2::<PunctTerminated<ParsedMacroArg, Token![,]>>(ts);
        let Some(PunctTerminated(args)) = errs.eat_err(parsed) else {
            return Vec::new();
        };
        args.into_iter()
            .filter_map(|arg| {
                let span = arg.span();
                errs.eat_err(arg.try_into_arg()).map(|arg| (span, arg))
            })
            .collect()
    }
}
/// Pushes a single error into `errs`, spanning all of `attrs`, when `attrs` is non-empty.
///
/// Used where the grammar syntactically allows attributes that this macro does not support.
fn ensure_no_attrs(attrs: &[Attribute], errs: &mut ErrorSink) {
    if attrs.is_empty() {
        return;
    }
    let attr_tokens: TokenStream = attrs.iter().map(|attr| attr.to_token_stream()).collect();
    errs.push(Error::new_spanned(
        attr_tokens,
        "attributes are not supported here",
    ));
}
fn lit_usize(expr: &Expr, errs: &mut ErrorSink) -> Result<(Span, usize)> {
match expr {
syn::Expr::Lit(syn::ExprLit {
attrs,
lit: syn::Lit::Int(lit_int),
}) => {
ensure_no_attrs(attrs, errs);
lit_int
.base10_parse::<usize>()
.map(|res| (lit_int.span(), res))
}
_ => Err(Error::new_spanned(expr, "expected integer literal")),
}
}
impl Interval {
fn try_from_expr(
expr: syn::Expr,
validate_bounds: Option<(usize, Option<usize>)>,
) -> Result<Self> {
let mut errs = ErrorSink::new();
let validate_ver = |span, ver, errs: &mut ErrorSink| {
if let Some((min, max)) = validate_bounds {
if ver < min {
errs.push(syn::Error::new(
span,
format!("version {ver} is below minimum version ({min})"),
));
} else if max.is_some_and(|max| ver > max) {
errs.push(syn::Error::new(
span,
format!(
"version {ver} is above maximum version ({})",
max.unwrap_or_else(|| unreachable!())
),
));
}
}
};
match expr {
syn::Expr::Range(er) => {
let syn::ExprRange {
ref attrs,
start,
limits,
end,
} = er;
ensure_no_attrs(attrs, &mut errs);
let start = match start
.as_deref()
.and_then(errs.wrap_err_once_1ary(lit_usize))
{
Some((span, start)) => {
validate_ver(span, start, &mut errs);
start
}
None => validate_bounds.map(|(min, _)| min).unwrap_or_default(),
};
let end = match end.as_deref().and_then(errs.wrap_err_once_1ary(lit_usize)) {
None => Bound::Unbounded,
Some((end_span, end)) => {
validate_ver(end_span, end, &mut errs);
match limits {
syn::RangeLimits::Closed(_) => {
if end < start {
errs.push(Error::new(
end_span,
format!("expected at least {start}"),
));
}
Bound::Included(end)
}
syn::RangeLimits::HalfOpen(_) => {
if end <= start {
errs.push(Error::new(
end_span,
format!("expected a number greater than {start}"),
));
}
Bound::Excluded(end)
}
}
}
};
errs.finish_with(|| Self { start, end })
}
expr => {
let res = lit_usize(&expr, &mut errs)
.map_err(|_| Error::new_spanned(expr, "expected usize literal or range"))
.map(|(span, ver)| {
validate_ver(span, ver, &mut errs);
ver
});
errs.combine_into(res).map(|ver| Self {
start: ver,
end: Bound::Included(ver),
})
}
}
}
}
/// Two-stage parsing: a purely syntactic [`Parse`] pass, followed by a conversion
/// step that may consult caller-supplied context.
pub trait StagedParse: Sized {
    /// Context handed to [`StagedParse::stage2`].
    type Cx;
    /// Syntax tree produced by the first (syntactic) stage.
    type Stage1: Parse;
    /// Converts the stage-1 syntax into `Self`, using `cx`.
    fn stage2(from: Self::Stage1, cx: Self::Cx) -> Result<Self>;
}
/// Parses `tokens` as `T::Stage1`, then converts the result with [`StagedParse::stage2`].
///
/// # Errors
/// Returns the stage-1 parse error if there is one (stage 2 is then not run); otherwise
/// returns whatever `stage2` returns.
pub fn staged_parse2<T: StagedParse>(tokens: TokenStream, cx: T::Cx) -> Result<T> {
    syn::parse2::<T::Stage1>(tokens).and_then(|syntax| T::stage2(syntax, cx))
}
impl StagedParse for Interval {
type Stage1 = Expr;
type Cx = Option<(usize, Option<usize>)>;
#[inline]
fn stage2(from: Self::Stage1, cx: Self::Cx) -> Result<Self> {
Self::try_from_expr(from, cx)
}
}
/// An [`Interval`] paired with the span of the expression it was parsed from.
#[cfg_attr(test, derive(Debug))]
pub struct SpannedInterval(pub Span, pub Interval);

impl StagedParse for SpannedInterval {
    type Stage1 = <Interval as StagedParse>::Stage1;
    type Cx = <Interval as StagedParse>::Cx;

    #[inline]
    fn stage2(from: Self::Stage1, cx: Self::Cx) -> Result<Self> {
        use syn::spanned::Spanned;
        // Capture the span before `from` is moved into the inner conversion.
        let source_span = from.span();
        let interval = <Interval as StagedParse>::stage2(from, cx)?;
        Ok(SpannedInterval(source_span, interval))
    }
}
/// A [`Punctuated`] list parsed with [`Punctuated::parse_terminated`].
///
/// Parsing continues to the end of the input stream, and an optional trailing separator
/// is accepted.
#[derive(Clone)]
#[cfg_attr(test, derive(Debug))]
pub struct PunctTerminated<T, P>(pub Punctuated<T, P>);

impl<T: Parse, P: Parse> Parse for PunctTerminated<T, P> {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let items = Punctuated::parse_terminated(input)?;
        Ok(PunctTerminated(items))
    }
}
impl<T: StagedParse<Cx: Clone>, P: Parse> StagedParse for PunctTerminated<T, P> {
type Stage1 = PunctTerminated<T::Stage1, P>;
type Cx = T::Cx;
fn stage2(from: Self::Stage1, cx: Self::Cx) -> Result<Self> {
use syn::punctuated::Pair;
Ok(Self(
from.0
.into_pairs()
.map(|p| {
let (t, p) = p.into_tuple();
Ok(Pair::new(T::stage2(t, cx.clone())?, p))
})
.collect::<Result<_>>()?,
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `lit_usize` accepts integer literals in any radix and rejects non-literal
    /// expressions as well as literals that do not fit in `usize`.
    #[test]
    fn lit_usize() {
        let lit_usize = ErrorSink::wrap_fn_1ary(super::lit_usize);

        // Accepted: decimal, binary and hexadecimal literals.
        let basic = syn::parse_quote!(10);
        assert!(lit_usize(&basic).is_ok_and(|(_, u)| u == 10));
        let binary = syn::parse_quote!(0b101);
        assert!(lit_usize(&binary).is_ok_and(|(_, u)| u == 0b101));
        let hex = syn::parse_quote!(0xbeef);
        assert!(lit_usize(&hex).is_ok_and(|(_, u)| u == 0xbeef));

        // Rejected: expressions that are not a bare integer literal...
        let err = syn::parse_quote!(nope);
        assert!(lit_usize(&err).is_err());
        let negative = syn::parse_quote!(-1);
        assert!(lit_usize(&negative).is_err());

        // ...and integer literals too large for `usize`.
        let too_big = syn::parse_quote!(99999999999999999999999999);
        assert!(lit_usize(&too_big).is_err());
    }
}