use std::{
cmp::Ordering,
collections::HashMap,
fmt,
hash::{Hash, Hasher},
};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
Data, DeriveInput, Error, Expr, ExprLit, Fields, Ident, Lit, LitInt,
};
/// Expands the `AlgId` derive for the enum contained in `item`.
///
/// For an enum `Name`, the expansion (wrapped in an anonymous `const _: ()`
/// block) contains:
///
/// - `InvalidName`, a unit-like error type implementing `Display` and
///   `core::error::Error`;
/// - `Name::to_u16` / `Name::to_be_bytes`, mapping a variant to its ID;
/// - `Name::try_from_u16` / `Name::try_from_be_bytes`, the inverse mapping.
///
/// An ID not claimed by any explicit variant maps to the `#[alg_id(Other)]`
/// variant when one exists (except zero, which `NonZeroU16::new` rejects);
/// without an `Other` variant such IDs are rejected with `InvalidName`.
///
/// # Errors
///
/// Returns an error if `item` does not parse as an [`AlgId`].
pub(crate) fn parse(item: TokenStream) -> syn::Result<TokenStream> {
    let AlgId { name, variants } = syn::parse2(item)?;
    let error = format_ident!("Invalid{name}");
    // NOTE(review): `#error` is declared inside the `const _: ()` block below,
    // so downstream code cannot name it even though `try_from_*` return it —
    // confirm this is intended.
    let error_impl = quote! {
        #[derive(
            ::core::marker::Copy,
            ::core::clone::Clone,
            ::core::fmt::Debug,
            ::core::cmp::Eq,
            ::core::cmp::PartialEq,
        )]
        pub struct #error(());
        impl ::core::fmt::Display for #error {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                ::core::write!(f, "invalid {}", ::core::stringify!(#name))
            }
        }
        impl ::core::error::Error for #error {}
    };
    let base_impl = {
        let to_mappings = variants.iter().map(|v| {
            let Variant { ident, discrim } = v;
            match discrim {
                Discriminant::Id(id) => quote! {
                    #name::#ident => #id
                },
                // Use the annotated variant's own name rather than assuming
                // it is literally called `Other`.
                Discriminant::Other => quote! {
                    #name::#ident(__id) => __id.get()
                },
            }
        });
        let from_mappings = variants.iter().map(|v| {
            let Variant { ident, discrim } = v;
            match discrim {
                Discriminant::Id(id) => quote! {
                    #id => ::core::result::Result::Ok(#name::#ident)
                },
                Discriminant::Other => quote! {
                    __id => match ::core::num::NonZeroU16::new(__id) {
                        ::core::option::Option::Some(__id) => ::core::result::Result::Ok(#name::#ident(__id)),
                        ::core::option::Option::None => ::core::result::Result::Err(#error(())),
                    }
                },
            }
        });
        // Literal arms alone can never make a `match` on `u16` exhaustive, so
        // when there is no `Other` catch-all, unclaimed IDs must be rejected
        // explicitly or the generated code fails to compile.
        let has_other = variants
            .iter()
            .any(|v| matches!(v.discrim, Discriminant::Other));
        let fallback = (!has_other).then(|| {
            quote! {
                _ => ::core::result::Result::Err(#error(())),
            }
        });
        quote! {
            impl #name {
                pub const fn to_u16(self) -> u16 {
                    match self {
                        #(#to_mappings,)*
                    }
                }
                pub const fn to_be_bytes(self) -> [u8; 2] {
                    self.to_u16().to_be_bytes()
                }
                pub const fn try_from_u16(__id: u16) -> ::core::result::Result<Self, #error> {
                    match __id {
                        #(#from_mappings,)*
                        #fallback
                    }
                }
                pub const fn try_from_be_bytes(bytes: [u8; 2]) -> ::core::result::Result<Self, #error> {
                    Self::try_from_u16(u16::from_be_bytes(bytes))
                }
            }
        }
    };
    let block = quote! {
        #[doc(hidden)]
        #[allow(missing_docs, unused_extern_crates)]
        const _: () = {
            #base_impl
            #error_impl
        };
    };
    Ok(block)
}
struct AlgId {
name: Ident,
variants: Vec<Variant>,
}
impl Parse for AlgId {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let span = Span::call_site();
let input = DeriveInput::parse(input)?;
let Data::Enum(data) = input.data else {
return Err(Error::new(span, "input must be an enum"));
};
let mut variants = data
.variants
.into_iter()
.map(Variant::new)
.collect::<syn::Result<Vec<_>>>()?;
if variants.is_empty() {
return Err(Error::new(span, "enum must have at least one variant"));
}
variants.sort();
let mut uniq = HashMap::new();
for v in variants.iter() {
if let Some(dup) = uniq.insert(v.discrim.clone(), v) {
return Err(Error::new(
v.ident.span(),
format!(
"duplicate ID {} for {} and {}",
v.discrim, v.ident, dup.ident
),
));
}
}
Ok(Self {
name: input.ident,
variants,
})
}
}
#[derive(Clone)]
struct Variant {
ident: Ident,
discrim: Discriminant,
}
impl Variant {
fn new(v: syn::Variant) -> syn::Result<Self> {
match v.fields {
Fields::Unit | Fields::Unnamed(_) => {
let discrim = Self::parse_discrim(&v)?;
Ok(Self {
ident: v.ident,
discrim,
})
}
_ => Err(Error::new(
v.ident.span(),
"must be a unit-only enum or else `Other`",
)),
}
}
fn parse_discrim(v: &syn::Variant) -> syn::Result<Discriminant> {
let attrs = v
.attrs
.iter()
.filter(|v| v.path().is_ident("alg_id"))
.collect::<Vec<_>>();
if attrs.len() != 1 {
Err(Error::new(
v.ident.span(),
"must contain exactly one `alg_id` attr",
))
} else {
attrs[0].parse_args::<Discriminant>()
}
}
}
impl Eq for Variant {}
impl PartialEq for Variant {
fn eq(&self, other: &Self) -> bool {
self.discrim == other.discrim
}
}
impl Ord for Variant {
fn cmp(&self, other: &Self) -> Ordering {
Ord::cmp(&self.discrim, &other.discrim)
}
}
impl PartialOrd for Variant {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[derive(Clone, Hash, Eq, PartialEq)]
enum Discriminant {
Id(U16),
Other,
}
impl Discriminant {
fn ord(&self) -> u32 {
match self {
Self::Id(id) => u32::from(id.repr),
Self::Other => u32::MAX,
}
}
}
impl fmt::Display for Discriminant {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Id(id) => write!(f, "{id}"),
Self::Other => write!(f, "Other"),
}
}
}
impl Ord for Discriminant {
fn cmp(&self, other: &Self) -> Ordering {
Ord::cmp(&self.ord(), &other.ord())
}
}
impl PartialOrd for Discriminant {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Parse for Discriminant {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let expr: Expr = input.parse()?;
match expr {
Expr::Lit(ExprLit {
lit: Lit::Int(lit), ..
}) => Ok(Self::Id(U16::new(lit)?)),
Expr::Path(path) if path.path.is_ident("Other") => Ok(Self::Other),
_ => Err(Error::new(input.span(), "invalid attribute")),
}
}
}
#[derive(Clone)]
struct U16 {
repr: u16,
lit: LitInt,
}
impl U16 {
fn new(lit: LitInt) -> syn::Result<Self> {
let repr = lit.base10_parse::<u16>()?;
Ok(Self { repr, lit })
}
}
impl fmt::Display for U16 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.lit)
}
}
impl Eq for U16 {}
impl PartialEq for U16 {
fn eq(&self, other: &Self) -> bool {
self.repr == other.repr
}
}
impl Hash for U16 {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.repr.hash(state)
}
}
impl ToTokens for U16 {
fn to_tokens(&self, tokens: &mut TokenStream) {
let lit = &self.lit;
tokens.extend(quote!(#lit))
}
}