rustler_codegen 0.27.0

Compiler plugin for Rustler
Documentation
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};

use heck::ToSnakeCase;
use syn::{self, spanned::Spanned, Fields, Ident, Variant};

use super::context::Context;

pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
    let ctx = Context::from_ast(ast);

    let variants = ctx
        .variants
        .as_ref()
        .expect("NifUnitEnum can only be used with enums");

    for variant in variants {
        if let Fields::Unit = variant.fields {
        } else {
            return quote_spanned! { variant.span() =>
                compile_error!("NifUnitEnum can only be used with enums containing unit variants.");
            };
        }
    }

    let atoms: Vec<TokenStream> = variants
        .iter()
        .map(|variant| {
            let atom_str = variant.ident.to_string().to_snake_case();
            let atom_fn = Ident::new(&format!("atom_{}", atom_str), Span::call_site());
            quote! {
                #atom_fn = #atom_str,
            }
        })
        .collect();

    let atom_defs = quote! {
        rustler::atoms! {
            #(#atoms)*
        }
    };

    let atoms_module_name = ctx.atoms_module_name(Span::call_site());

    let decoder = if ctx.decode() {
        gen_decoder(&ctx, variants, &atoms_module_name)
    } else {
        quote! {}
    };

    let encoder = if ctx.encode() {
        gen_encoder(&ctx, variants, &atoms_module_name)
    } else {
        quote! {}
    };

    let gen = quote! {
        mod #atoms_module_name {
            #atom_defs
        }

        #decoder
        #encoder
    };

    gen
}

fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
    let enum_name = ctx.ident;

    let variant_defs: Vec<TokenStream> = variants
        .iter()
        .map(|variant| {
            let variant_ident = &variant.ident;
            let atom_str = variant_ident.to_string().to_snake_case();
            let atom_fn = Ident::new(&format!("atom_{}", atom_str), Span::call_site());

            quote! {
                if value == #atom_fn() {
                    return Ok ( #enum_name :: #variant_ident );
                }
            }
        })
        .collect();

    super::encode_decode_templates::decoder(
        ctx,
        quote! {
            use #atoms_module_name::*;

            let value = ::rustler::types::atom::Atom::from_term(term)?;

            #(#variant_defs)*

            Err(::rustler::Error::RaiseAtom("invalid_variant"))
        },
    )
}

fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
    let enum_name = ctx.ident;

    let variant_defs: Vec<TokenStream> = variants
        .iter()
        .map(|variant| {
            let variant_ident = &variant.ident;
            let atom_str = variant_ident.to_string().to_snake_case();
            let atom_fn = Ident::new(&format!("atom_{}", atom_str), Span::call_site());

            quote! {
                #enum_name :: #variant_ident => ::rustler::Encoder::encode(&#atom_fn(), env),
            }
        })
        .collect();

    super::encode_decode_templates::encoder(
        ctx,
        quote! {
            use #atoms_module_name::*;

            match *self {
                #(#variant_defs)*
            }
        },
    )
}