easy_deref 0.1.0

Derive macros for the `Deref` and `DerefMut` traits
Documentation
#![doc = include_str!("../README.md")]

use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, Error, Fields, parse_macro_input};

#[proc_macro_derive(Deref, attributes(deref))]
pub fn deref_derive(input: TokenStream) -> TokenStream {
    match impl_deref(parse_macro_input!(input as syn::DeriveInput), false) {
        Ok(code) => code.into(),
        Err(err) => err.into_compile_error().into(),
    }
}

#[proc_macro_derive(DerefMut, attributes(deref))]
pub fn deref_mut_derive(input: TokenStream) -> TokenStream {
    match impl_deref(parse_macro_input!(input as syn::DeriveInput), true) {
        Ok(code) => code.into(),
        Err(err) => err.into_compile_error().into(),
    }
}

/// Builds the `Deref` impl for `input`, or the `DerefMut` impl when `impl_mut`
/// is `true`.
///
/// The target field is picked by [`find_deref_field`]: the sole field of a
/// single-field struct, otherwise the field marked with `#[deref]`.
///
/// # Errors
///
/// Returns a [`syn::Error`] spanned over the whole item when `input` is not a
/// struct, or is a unit struct / a struct without any fields, and forwards any
/// error produced while selecting the target field.
fn impl_deref(input: syn::DeriveInput, impl_mut: bool) -> syn::Result<proc_macro2::TokenStream> {
    // Only used to name the derive in diagnostics.
    let trait_name = if impl_mut { "DerefMut" } else { "Deref" };

    let stct = match &input.data {
        Data::Struct(stct) => stct,
        _ => {
            return Err(Error::new_spanned(
                &input,
                format!("`easy_deref::{trait_name}` does not support enums or unions"),
            ));
        }
    };

    // The guards send `struct S {}` and `struct S();` to the fallback arm
    // together with unit structs: there is no field to dereference to.
    let (target_ty, idx) = match &stct.fields {
        Fields::Named(fields) if !fields.named.is_empty() => {
            find_deref_field(&fields.named, true)?
        }
        Fields::Unnamed(fields) if !fields.unnamed.is_empty() => {
            find_deref_field(&fields.unnamed, false)?
        }
        _ => {
            return Err(Error::new_spanned(
                &input,
                format!("`easy_deref::{trait_name}` does not support unit or empty structs"),
            ));
        }
    };

    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let expanded = if impl_mut {
        // `DerefMut` has no associated type of its own; `Self::Target` comes
        // from the type's `Deref` impl.
        quote! {
            impl #impl_generics ::core::ops::DerefMut
            for #name #ty_generics #where_clause
            {
                fn deref_mut(&mut self) -> &mut Self::Target {
                    &mut self.#idx
                }
            }
        }
    } else {
        quote! {
            impl #impl_generics ::core::ops::Deref
            for #name #ty_generics #where_clause
            {
                type Target = #target_ty;

                fn deref(&self) -> &Self::Target {
                    &self.#idx
                }
            }
        }
    };

    Ok(expanded)
}

/// How the selected field is addressed in the generated `self.<field>`
/// expression.
enum StructIndex {
    /// Field of a struct with named fields, addressed by its identifier.
    Named(syn::Ident),
    /// Field of a tuple struct, addressed by its position.
    Unnamed(syn::Index),
}

impl quote::ToTokens for StructIndex {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        match self {
            Self::Named(ident) => ident.to_tokens(tokens),
            Self::Unnamed(index) => index.to_tokens(tokens),
        }
    }
}

fn find_deref_field(
    fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
    named: bool,
) -> syn::Result<(syn::Type, StructIndex)> {
    if fields.len() == 1 {
        let field = fields.first().unwrap();

        let idx = if named {
            StructIndex::Named(field.ident.clone().unwrap())
        } else {
            StructIndex::Unnamed(syn::Index::from(0))
        };

        return Ok((field.ty.clone(), idx));
    }

    let mut deref_field = None;
    let mut idx = None;

    for (i, field) in fields.iter().enumerate() {
        for attr in &field.attrs {
            let meta = &attr.meta;

            match meta {
                syn::Meta::Path(path) => {
                    if path.is_ident("deref") {
                        if deref_field.is_some() {
                            return Err(Error::new_spanned(
                                field,
                                "more than one field has the `#[deref]` attribute",
                            ));
                        }

                        deref_field = Some(field);

                        idx = if named {
                            Some(StructIndex::Named(field.ident.clone().unwrap()))
                        } else {
                            Some(StructIndex::Unnamed(syn::Index::from(i)))
                        };
                    }
                }
                _ => continue,
            }
        }
    }

    Ok((
        deref_field
            .ok_or(Error::new_spanned(
                fields,
                "no field has the `#[deref]` attribute",
            ))?
            .ty
            .clone(),
        idx.unwrap(),
    ))
}