lolibaso-macros 0.1.0

lolibaso 的过程宏 crate
Documentation
use convert_case::{Case, Casing};
use quote::quote;
use quote::quote_spanned;
use syn::spanned::Spanned;
use syn::{Ident, parse::Parse};

/// Parsed form of an error-enum definition, ready to be expanded into a set
/// of `BizError` constants.
///
/// Built by the [`Parse`] implementation from an `enum` item and consumed by
/// [`ErrEnum::expand`].
pub struct ErrEnum {
    /// Visibility of the input enum, re-emitted on the generated enum.
    vis: syn::Visibility,
    /// Name of the input enum.
    ident: Ident,
    /// Value of the mandatory `#[base_biz_code = N]` attribute; the variant at
    /// zero-based position `i` gets business code `base_biz_code + i + 1`.
    base_biz_code: u32,
    /// Value of the optional `#[default_http_status = N]` attribute, used for
    /// variants without their own `#[http_status = N]`. When both are absent
    /// the expansion uses 400.
    default_http_status: Option<u16>,
    /// Type named by the optional `#[err_path(Type)]` attribute. When absent
    /// the expansion uses `lolibaso::http::error::BizError`.
    err_path: Option<syn::TypePath>,
    /// Variants in declaration order.
    variants: Vec<ErrVariant>,
}

/// One variant of the input enum together with its per-variant settings.
struct ErrVariant {
    /// Variant name; becomes the name of the generated associated constant.
    ident: Ident,
    /// Trimmed text of the first `doc` attribute on the variant, if any.
    /// When absent, a description is derived from the variant name.
    desc: Option<String>,
    /// Value of the variant's `#[http_status = N]` attribute, if any.
    http_status: Option<u16>,
}

impl Parse for ErrEnum {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let input = syn::DeriveInput::parse(input)?;
        let data = match input.data {
            syn::Data::Enum(data_enum) => data_enum,
            _ => {
                return Err(syn::Error::new_spanned(input, "expected enum"));
            }
        };
        let mut variants = vec![];
        for variant in data.variants {
            if variant.discriminant.is_some() {
                return Err(syn::Error::new_spanned(
                    variant,
                    "enum variant must not have a discriminant",
                ));
            }
            let ident = variant.ident;
            let mut desc = None;
            let mut http_status = None;
            for attr in &variant.attrs {
                let Some(path) = attr.path().get_ident() else {
                    continue;
                };
                match &*path.to_string() {
                    "doc" => {
                        if desc.is_some() {
                            continue;
                        }
                        desc = Some(parse_variant_doc(attr)?);
                    }
                    "http_status" => {
                        http_status = Some(parse_http_status(attr)?);
                    }
                    _ => {
                        return Err(syn::Error::new_spanned(
                            attr,
                            format!("Unknown attribute: {}. Expected: `http_status`", path),
                        ));
                    }
                }
            }

            variants.push(ErrVariant {
                ident,
                desc,
                http_status,
            });
        }

        let mut err_path = None;
        let mut default_http_status = None;
        let mut base_biz_code = None;
        for attr in &input.attrs {
            let Some(path) = attr.path().get_ident() else {
                continue;
            };
            match &*path.to_string() {
                "base_biz_code" => {
                    base_biz_code = Some(parse_base_biz_code(attr)?);
                }
                "err_path" => {
                    err_path = Some(parse_err_path(attr)?);
                }
                "default_http_status" => {
                    default_http_status = Some(parse_http_status(attr)?);
                }
                path => {
                    return Err(syn::Error::new_spanned(
                        attr,
                        format!(
                            "unknown attribute: {}. Expected: `base_biz_code`, `err_path`, `default_http_status`",
                            path
                        ),
                    ));
                }
            }
        }

        let base_biz_code = match base_biz_code {
            Some(v) => v,
            None => {
                return Err(syn::Error::new_spanned(
                    input.ident,
                    "`base_biz_code` must be set. Use #[base_biz_code = ...] attribute",
                ));
            }
        };

        Ok(ErrEnum {
            ident: input.ident,
            base_biz_code,
            default_http_status,
            err_path,
            variants,
            vis: input.vis,
        })
    }
}

fn require_lit_int(expr: &syn::Expr) -> syn::Result<u32> {
    let span = expr.span();
    let lit = require_literal(expr)?;
    match lit {
        syn::Lit::Int(lit_int) => Ok(lit_int.base10_parse()?),
        _ => Err(syn::Error::new(span, "expected literal integer")),
    }
}

fn require_lit_str(expr: &syn::Expr) -> syn::Result<&syn::LitStr> {
    let span = expr.span();
    let lit = require_literal(expr)?;
    match lit {
        syn::Lit::Str(lit_str) => Ok(lit_str),
        _ => Err(syn::Error::new(span, "expected literal string")),
    }
}

fn require_literal(expr: &syn::Expr) -> syn::Result<&syn::Lit> {
    match expr {
        syn::Expr::Lit(expr_lit) => Ok(&expr_lit.lit),
        _ => {
            return Err(syn::Error::new_spanned(expr, "expected literal"));
        }
    }
}

fn parse_http_status(attr: &syn::Attribute) -> syn::Result<u16> {
    let named = attr.meta.require_name_value()?;
    let lit = require_lit_int(&named.value)?;

    Ok(lit as u16)
}

fn parse_err_path(attr: &syn::Attribute) -> syn::Result<syn::TypePath> {
    let Ok(list) = attr.meta.require_list() else {
        return Err(syn::Error::new_spanned(
            attr,
            "Use `#[err_path(...)]` attribute",
        ));
    };

    let Ok(path) = list.parse_args() else {
        return Err(syn::Error::new_spanned(
            attr,
            "Expected TypePath. Use `#[err_path(TypePath)]` attribute",
        ));
    };

    Ok(path)
}
/// Reads the base business code from a `#[base_biz_code = N]` attribute.
///
/// # Errors
///
/// Fails when the attribute is not in name-value form or the value is not an
/// integer literal that parses as `u32`.
fn parse_base_biz_code(attr: &syn::Attribute) -> syn::Result<u32> {
    attr.meta
        .require_name_value()
        .and_then(|named| require_lit_int(&named.value))
}

/// Extracts the text of a `#[doc = "..."]` attribute with surrounding
/// whitespace trimmed.
///
/// # Errors
///
/// Fails when the attribute is not in name-value form or the value is not a
/// string literal.
fn parse_variant_doc(attr: &syn::Attribute) -> syn::Result<String> {
    let text = require_lit_str(&attr.meta.require_name_value()?.value)?.value();
    Ok(text.trim().to_owned())
}

impl ErrEnum {
    pub fn expand(self) -> syn::Result<proc_macro2::TokenStream> {
        let ident = &self.ident;
        let kinds = self.gen_err_kinds();
        let err_path = match self.err_path {
            Some(p) => quote!(#p),
            None => quote!(lolibaso::http::error::BizError),
        };
        let all_variants = self.variants.iter().map(|v| {
            let variant = &v.ident;
            quote!(&#ident::#variant)
        });
        let vis = &self.vis;
        let stream = quote_spanned! { self.ident.span() =>

            #vis enum #ident {}

            const _: () = {
                use #err_path as BizError;

                #[allow(non_upper_case_globals)]
                impl #ident {
                    #(#kinds)*

                    pub fn all() -> &'static [&'static BizError] {
                       static ALL: &[&BizError] = &[#(#all_variants),*];
                       ALL
                    }
                }
            };

        };

        Ok(stream)
    }

    fn gen_err_kinds(&self) -> Vec<proc_macro2::TokenStream> {
        let mut kinds = vec![];

        for (index, variant) in self.variants.iter().enumerate() {
            let v_ident = &variant.ident;
            let biz_code = self.base_biz_code + (index as u32) + 1;
            let desc = variant.desc.clone().unwrap_or_else(|| {
                let mut desc = v_ident.to_string().to_case(Case::Lower);
                let first_char = desc.chars().next().unwrap().to_uppercase();
                desc.replace_range(0..1, &first_char.to_string());
                desc
            });
            let http_status = variant
                .http_status
                .or(self.default_http_status)
                .unwrap_or(400);
            let kind = quote_spanned! { variant.ident.span() =>
                pub const #v_ident: BizError = BizError::new(#http_status, #biz_code, #desc);
            };
            kinds.push(kind);
        }

        kinds
    }
}