rustler_codegen 0.37.4

Compiler plugin for Rustler
Documentation
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};

use heck::ToSnakeCase;
use std::collections::HashMap;
use syn::{self, spanned::Spanned, Field, Fields, FieldsNamed, FieldsUnnamed, Ident, Variant};

use super::context::Context;

/// Generates the expansion for the tagged-enum derive applied to `ast`.
///
/// The output consists of a module (named by `Context::atoms_module_name`)
/// containing one `::rustler::atoms!` entry per distinct snake-cased variant or
/// named-field identifier, followed by the decoder when `ctx.decode()` is true
/// and the encoder when `ctx.encode()` is true.
///
/// Atom entries are emitted sorted by atom string, so the same input always
/// produces the same token stream.
///
/// # Panics
///
/// Panics with "NifTaggedEnum can only be used with enums" when
/// `ctx.variants` is `None`, and with "Named fields must have an ident." when a
/// named field carries no identifier.
pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
    let ctx = Context::from_ast(ast);

    let variants = ctx
        .variants
        .as_ref()
        .expect("NifTaggedEnum can only be used with enums");

    // Keyed by the atom string so an identifier shared by several variants or
    // fields results in a single atom definition.
    let mut atom_entries = variants
        .iter()
        .flat_map(|variant| {
            let mut idents = if let Fields::Named(fields) = &variant.fields {
                fields
                    .named
                    .iter()
                    .map(|field| {
                        field
                            .ident
                            .as_ref()
                            .expect("Named fields must have an ident.")
                    })
                    .collect::<Vec<_>>()
            } else {
                vec![]
            };

            idents.push(&variant.ident);
            idents
        })
        .map(|atom_ident| {
            let atom_str = atom_ident.to_string().to_snake_case();
            let atom_fn = Context::ident_to_atom_fun(atom_ident);
            (atom_str, atom_fn)
        })
        .collect::<HashMap<_, _>>()
        .into_iter()
        .collect::<Vec<_>>();

    // `HashMap` iteration order is arbitrary and differs between runs; sorting
    // by the (unique) atom string keeps the macro output reproducible.
    atom_entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

    let atoms = atom_entries
        .into_iter()
        .map(|(atom_str, atom_fn)| {
            quote! {
                #atom_fn = #atom_str,
            }
        })
        .collect::<Vec<_>>();

    let atom_defs = quote! {
        ::rustler::atoms! {
            #(#atoms)*
        }
    };

    let atoms_module_name = ctx.atoms_module_name(Span::call_site());

    let decoder = if ctx.decode() {
        gen_decoder(&ctx, variants, &atoms_module_name)
    } else {
        quote! {}
    };

    let encoder = if ctx.encode() {
        gen_encoder(&ctx, variants, &atoms_module_name)
    } else {
        quote! {}
    };

    quote! {
        #[allow(non_snake_case)]
        mod #atoms_module_name {
            #atom_defs
        }

        #decoder
        #encoder
    }
}

/// Builds the decoder body and hands it to `encode_decode_templates::decoder`.
///
/// The generated body first tries the term as an atom and compares it against
/// every unit variant. Otherwise it tries the term as a tuple whose first
/// element is an atom and compares that against every tuple-like and
/// struct-like variant. When nothing returns, it raises `invalid_variant`.
fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
    let enum_name = ctx.ident;

    // Unit variants are matched against a bare atom; all others against a
    // tagged tuple, so the two groups are collected separately.
    let mut unit_decoders: Vec<TokenStream> = Vec::new();
    let mut named_unnamed_decoders: Vec<TokenStream> = Vec::new();

    for variant in variants {
        let variant_ident = &variant.ident;
        let atom_fn = Context::ident_to_atom_fun(variant_ident);

        match &variant.fields {
            Fields::Unit => {
                unit_decoders.push(gen_unit_decoder(enum_name, variant_ident, atom_fn));
            }
            Fields::Unnamed(fields) => {
                named_unnamed_decoders.push(gen_unnamed_decoder(
                    enum_name,
                    fields,
                    variant.fields.iter(),
                    variant_ident,
                    atom_fn,
                ));
            }
            Fields::Named(fields) => {
                named_unnamed_decoders.push(gen_named_decoder(
                    enum_name,
                    fields,
                    variant_ident,
                    atom_fn,
                ));
            }
        }
    }

    let body = quote! {
        use #atoms_module_name::*;

        fn try_decode_field<'a, T>(
            term: ::rustler::Term<'a>,
            field: ::rustler::Atom,
            ) -> ::rustler::NifResult<T>
            where
                T: ::rustler::Decoder<'a>,
        {
            use ::rustler::Encoder;
            match ::rustler::Decoder::decode(term.map_get(&field)?) {
                Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
                                "Could not decode field :{:?} on %{{}}",
                                field
                )))),
                Ok(value) => Ok(value),
            }
        }

        if let Ok(unit) = ::rustler::types::atom::Atom::from_term(term) {
            #(#unit_decoders)*
        } else if let Ok(tuple) = ::rustler::types::tuple::get_tuple(term) {
            let name = tuple
                        .get(0)
                        .and_then(|&first| ::rustler::types::atom::Atom::from_term(first).ok())
                        .ok_or(::rustler::Error::RaiseAtom("invalid_variant"))?;
            #(#named_unnamed_decoders)*
        }

        Err(::rustler::Error::RaiseAtom("invalid_variant"))
    };

    super::encode_decode_templates::decoder(ctx, body)
}

/// Builds the encoder body and hands it to `encode_decode_templates::encoder`.
///
/// The generated body is a single `match self` with one arm per variant, in
/// declaration order.
fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident) -> TokenStream {
    let enum_name = ctx.ident;

    let mut variant_defs: Vec<TokenStream> = Vec::with_capacity(variants.len());
    for variant in variants {
        let variant_ident = &variant.ident;
        let atom_fn = Context::ident_to_atom_fun(variant_ident);

        let arm = match &variant.fields {
            Fields::Unit => gen_unit_encoder(enum_name, variant_ident, atom_fn),
            Fields::Unnamed(fields) => {
                gen_unnamed_encoder(enum_name, fields, variant_ident, atom_fn)
            }
            Fields::Named(fields) => gen_named_encoder(enum_name, fields, variant_ident, atom_fn),
        };
        variant_defs.push(arm);
    }

    let body = quote! {
        use #atoms_module_name::*;

        match self {
            #(#variant_defs)*
        }
    };

    super::encode_decode_templates::encoder(ctx, body)
}

/// Emits the check for a unit variant: when the decoded atom `unit` equals the
/// variant's atom, the generated code returns that variant.
fn gen_unit_decoder(enum_name: &Ident, variant_ident: &Ident, atom_fn: Ident) -> TokenStream {
    let variant_path = quote! { #enum_name :: #variant_ident };

    quote! {
        if unit == #atom_fn() {
            return Ok ( #variant_path );
        }
    }
}

/// Emits the check for a tuple-like variant.
///
/// When the tag atom `name` matches, the generated code verifies that the
/// tuple holds exactly one element per field after the tag, then decodes
/// `tuple[1]`, `tuple[2]`, ... into the variant's fields in order.
fn gen_unnamed_decoder<'a>(
    enum_name: &Ident,
    fields: &FieldsUnnamed,
    fields_iter: impl Iterator<Item = &'a Field>,
    variant_ident: &Ident,
    atom_fn: Ident,
) -> TokenStream {
    let len = fields.unnamed.len();

    let mut decoded_field: Vec<TokenStream> = Vec::new();
    for (offset, field) in fields_iter.enumerate() {
        // Element 0 of the tuple is the variant tag, so fields start at 1.
        let position = offset + 1;
        let field_ty = &field.ty;
        decoded_field.push(quote! {
            <#field_ty as ::rustler::Decoder>::decode(tuple[#position]).map_err(|_| ::rustler::Error::RaiseTerm(
                Box::new(format!("Could not decode field on position {}", #position))
            ))?
        });
    }

    let variant_path = quote! { #enum_name :: #variant_ident };

    quote! {
        if name == #atom_fn() {
            if tuple.len() - 1 != #len {
                return Err(::rustler::Error::RaiseTerm(Box::new(format!(
                    "The tuple must have {} elements, but it has {}",
                    #len + 1, tuple.len()
                ))));
            }
            return Ok( #variant_path ( #(#decoded_field),* ) )
        }
    }
}

/// Emits the check for a struct-like variant.
///
/// When the term is a two-element tuple whose tag atom `name` matches, the
/// generated code requires `tuple[1]` to be a map, decodes every named field
/// from it via `try_decode_field`, and returns the populated variant.
fn gen_named_decoder(
    enum_name: &Ident,
    fields: &FieldsNamed,
    variant_ident: &Ident,
    atom_fn: Ident,
) -> TokenStream {
    // Embedded in every per-field error message below.
    let enum_label = enum_name.to_string();

    let mut assignments: Vec<TokenStream> = Vec::new();
    let mut field_defs: Vec<TokenStream> = Vec::new();

    for (index, field) in fields.named.iter().enumerate() {
        let ident = field
            .ident
            .as_ref()
            .expect("Named fields must have an ident.");
        let field_label = ident.to_string();
        let field_atom = Context::field_to_atom_fun(field);
        let binding = Context::escape_ident_with_index(&field_label, index, "map");

        assignments.push(quote_spanned! { field.span() =>
            let #binding = try_decode_field(tuple[1], #field_atom()).map_err(|_|{
                ::rustler::Error::RaiseTerm(Box::new(format!(
                    "Could not decode field '{}' on Enum '{}'",
                    #field_label, #enum_label
                )))
            })?;
        });

        field_defs.push(quote! {
            #ident: #binding
        });
    }

    quote! {
        if tuple.len() == 2 && name == #atom_fn() {
            let len = tuple[1].map_size().map_err(|_| ::rustler::Error::RaiseTerm(Box::new(
                "The second element of the tuple must be a map"
            )))?;
            #(#assignments)*
            return Ok( #enum_name :: #variant_ident { #(#field_defs),* } )
        }
    }
}

/// Emits the `match` arm that encodes a unit variant as its atom.
fn gen_unit_encoder(enum_name: &Ident, variant_ident: &Ident, atom_fn: Ident) -> TokenStream {
    let pattern = quote! { #enum_name :: #variant_ident };
    let encoded = quote! { ::rustler::Encoder::encode(&#atom_fn(), env) };

    quote! {
        #pattern => #encoded,
    }
}

/// Emits the `match` arm that encodes a tuple-like variant as a tuple made of
/// the variant's atom followed by each field, bound as `inner0`, `inner1`, ...
fn gen_unnamed_encoder(
    enum_name: &Ident,
    fields: &FieldsUnnamed,
    variant_ident: &Ident,
    atom_fn: Ident,
) -> TokenStream {
    let inners: Vec<Ident> = (0..fields.unnamed.len())
        .map(|i| Ident::new(&format!("inner{i}"), Span::call_site()))
        .collect();

    let pattern = quote! { #enum_name :: #variant_ident ( #(ref #inners),* ) };
    let tagged_tuple = quote! { (#atom_fn(), #(#inners),*) };

    quote! {
        #pattern => ::rustler::Encoder::encode(&#tagged_tuple, env),
    }
}

/// Emits the `match` arm that encodes a struct-like variant.
///
/// The generated arm destructures every named field, builds a map from each
/// field's atom to its encoded value, and returns a two-element tuple of the
/// variant's atom and that map.
fn gen_named_encoder(
    enum_name: &Ident,
    fields: &FieldsNamed,
    variant_ident: &Ident,
    atom_fn: Ident,
) -> TokenStream {
    let field_count = fields.named.len();
    let mut field_decls: Vec<TokenStream> = Vec::with_capacity(field_count);
    let mut keys: Vec<TokenStream> = Vec::with_capacity(field_count);
    let mut values: Vec<TokenStream> = Vec::with_capacity(field_count);

    for field in fields.named.iter() {
        let field_ident = field
            .ident
            .as_ref()
            .expect("Named fields must have an ident.");
        let field_atom = Context::field_to_atom_fun(field);

        field_decls.push(quote! {
            #field_ident,
        });
        keys.push(quote! { ::rustler::Encoder::encode(&#field_atom(), env) });
        values.push(quote! { ::rustler::Encoder::encode(&#field_ident, env) });
    }

    let tag = quote! { ::rustler::Encoder::encode(&#atom_fn(), env) };

    quote! {
        #enum_name :: #variant_ident{
            #(#field_decls)*
        } => {
            let map = ::rustler::Term::map_from_term_arrays(env, &[#(#keys),*], &[#(#values),*])
                .expect("Failed to create map");
            ::rustler::types::tuple::make_tuple(env, &[#tag, map])
        }
    }
}