derive-com-wrapper 0.1.0

Procedural derive macro for the `com-wrapper` crate, useful for types which are just a safe wrapper around a `wio::com::ComPtr`.
Documentation
use proc_macro2::TokenStream;
use syn::spanned::Spanned;
use syn::{
    Attribute, Data, DeriveInput, Fields, GenericArgument, Ident, Index, Member, Meta, NestedMeta,
    PathArguments, Type,
};

/// Expands `#[derive(ComWrapper)]` for `input`.
///
/// Produces the `::com_wrapper::ComWrapper` impl for the struct, followed by any extra
/// impls (`Send`, `Sync`, `Debug`) requested through a `#[com(...)]` attribute.
///
/// # Errors
/// Returns a message if `input` is not a struct, if its single `ComPtr<T>` field cannot
/// be found, or if the `#[com(...)]` attribute is malformed.
pub fn expand_com_wrapper(input: &DeriveInput) -> Result<TokenStream, String> {
    let name = &input.ident;

    let struct_data = match &input.data {
        Data::Struct(data) => data,
        _ => return Err("ComWrapper can only wrap a `struct`".into()),
    };

    // The field shape is validated before the attribute, so field errors win if both are bad.
    let (member, interface_ty, ptr_ty) = get_comptr_member(&struct_data.fields)?;
    let options = parse_attr(&input.attrs)?;

    let trait_impl = wrapper_impl(name, &member, interface_ty, ptr_ty);
    let extra_impls = meta_impl(name, &options);

    Ok(quote!(#trait_impl #extra_impls))
}

/// Generates `impl ::com_wrapper::ComWrapper for #wrap`.
///
/// * `wrap`   – the name of the wrapper struct.
/// * `member` – the field (named or positional) that holds the `ComPtr`.
/// * `itype`  – the COM interface type `T` inside `ComPtr<T>`.
/// * `ctype`  – the full field type as the user wrote it (e.g. `ComPtr<T>` or
///   `wio::com::ComPtr<T>`).
///
/// `ComPtr` methods are invoked through the qualified path `<#ctype>::…` rather than the bare
/// name `ComPtr`. The field type is accepted whenever its last path segment is `ComPtr`, so the
/// user may have written a fully qualified path without importing `ComPtr`. Referring to the
/// field's own type keeps the expansion compiling in that case too.
fn wrapper_impl(wrap: &Ident, member: &Member, itype: &Type, ctype: &Type) -> TokenStream {
    let ptr_wrap = create_wrapping(wrap, member);
    quote! {
        impl ::com_wrapper::ComWrapper for #wrap {
            type Interface = #itype;
            #[inline]
            unsafe fn get_raw(&self) -> *mut #itype {
                <#ctype>::as_raw(&self.#member)
            }
            #[inline]
            unsafe fn from_raw(ptr: *mut #itype) -> Self {
                <Self as ::com_wrapper::ComWrapper>::from_ptr(<#ctype>::from_raw(ptr))
            }
            #[inline]
            unsafe fn into_raw(self) -> *mut #itype {
                <#ctype>::into_raw(<Self as ::com_wrapper::ComWrapper>::into_ptr(self))
            }
            #[inline]
            unsafe fn from_ptr(ptr: #ctype) -> Self {
                #ptr_wrap
            }
            #[inline]
            unsafe fn into_ptr(self) -> #ctype {
                self.#member
            }
        }
    }
}

/// Generates a `Debug` impl for `wrap` that prints `Name(<raw interface pointer>)`.
///
/// The pointer is obtained through `ComWrapper::get_raw`, so the generated impl relies on
/// the `ComWrapper` impl emitted for the same type.
fn dbg_impl(wrap: &Ident) -> TokenStream {
    let type_name = wrap.to_string();
    let raw_pointer = quote!(unsafe { <Self as ::com_wrapper::ComWrapper>::get_raw(self) });

    quote! {
        impl ::std::fmt::Debug for #wrap {
            fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                fmt.debug_tuple(#type_name)
                    .field(&#raw_pointer)
                    .finish()
            }
        }
    }
}

/// Generates the optional impls requested by the `#[com(...)]` attribute.
///
/// The output is emitted in a fixed order (`Send`, then `Sync`, then `Debug`). Each impl
/// appears only when its flag in `meta` is set.
fn meta_impl(wrap: &Ident, meta: &AttrInfo) -> TokenStream {
    // A disabled option is `None`, which `quote!` interpolates as no tokens at all.
    let send_impl = if meta.send {
        Some(quote!(unsafe impl Send for #wrap {}))
    } else {
        None
    };
    let sync_impl = if meta.sync {
        Some(quote!(unsafe impl Sync for #wrap {}))
    } else {
        None
    };
    let debug_impl = if meta.debug { Some(dbg_impl(wrap)) } else { None };

    quote!(#send_impl #sync_impl #debug_impl)
}

/// Builds the expression that constructs `wrap` from a local variable named `ptr`.
///
/// Yields `Wrap { field: ptr }` for a struct with a named field, and `Wrap(ptr)` for a
/// tuple struct.
fn create_wrapping(wrap: &Ident, member: &Member) -> TokenStream {
    if let Member::Named(field_name) = member {
        return quote!(#wrap { #field_name: ptr });
    }

    // Tuple struct: the single positional field is initialised directly.
    quote!(#wrap(ptr))
}

/// Locates the single `ComPtr<T>` field of a wrapper struct.
///
/// Returns the field's accessor (`Member::Named` or `Member::Unnamed(0)`), the interface
/// type `T`, and the full field type.
///
/// # Errors
/// Fails if the struct is a unit struct, does not have exactly one field, or if that
/// field is not a `ComPtr` with one type argument (see `extract_comptr_ty`).
fn get_comptr_member(fields: &Fields) -> Result<(Member, &Type, &Type), String> {
    const ONE_MEMBER: &str = "A ComWrapper struct must have exactly 1 member, a ComPtr.";

    let (member, field) = match fields {
        Fields::Unit => return Err("A ComWrapper struct must have a ComPtr member.".into()),
        Fields::Named(named) if named.named.len() != 1 => return Err(ONE_MEMBER.into()),
        Fields::Unnamed(unnamed) if unnamed.unnamed.len() != 1 => return Err(ONE_MEMBER.into()),
        Fields::Named(named) => {
            let only = &named.named[0];
            // Fields of a braced struct always carry an identifier.
            let field_name = only.ident.clone().unwrap();
            (Member::Named(field_name), only)
        }
        Fields::Unnamed(unnamed) => {
            let only = &unnamed.unnamed[0];
            let position = Index {
                index: 0,
                span: only.span(),
            };
            (Member::Unnamed(position), only)
        }
    };

    extract_comptr_ty(&field.ty).map(|interface_ty| (member, interface_ty, &field.ty))
}

/// Given a field type of the form `…::ComPtr<T>`, returns `T`.
///
/// Only the last path segment is checked for the name `ComPtr`, so both `ComPtr<T>` and
/// qualified forms such as `wio::com::ComPtr<T>` are accepted.
///
/// # Errors
/// Fails if `ty` is not a path ending in `ComPtr`, or if that `ComPtr` does not take
/// exactly one angle-bracketed type argument.
fn extract_comptr_ty(ty: &Type) -> Result<&Type, String> {
    let not_comptr = || String::from("A ComWrapper struct must have a ComPtr member.");
    let bad_generics = || String::from("Invalid generic arguments to ComPtr.");

    let path = match ty {
        Type::Path(typath) => &typath.path,
        _ => return Err(not_comptr()),
    };

    let last_pair = path.segments.last().ok_or_else(not_comptr)?;
    let segment = *last_pair.value();
    if segment.ident != "ComPtr" {
        return Err(not_comptr());
    }

    let generic_args = match &segment.arguments {
        PathArguments::AngleBracketed(bracketed) => &bracketed.args,
        _ => return Err(bad_generics()),
    };
    if generic_args.len() != 1 {
        return Err(bad_generics());
    }

    match &generic_args[0] {
        GenericArgument::Type(interface_ty) => Ok(interface_ty),
        _ => Err(bad_generics()),
    }
}

/// Options collected from `#[com(...)]`; every flag defaults to `false`.
#[derive(Default)]
struct AttrInfo {
    /// `send` was given: emit `unsafe impl Send` for the wrapper.
    send: bool,
    /// `sync` was given: emit `unsafe impl Sync` for the wrapper.
    sync: bool,
    /// `debug` was given: emit a `Debug` impl printing the raw interface pointer.
    debug: bool,
}

fn parse_attr(attrs: &[Attribute]) -> Result<AttrInfo, String> {
    let com_attr = match attrs.iter().filter(is_com_attr).nth(0) {
        Some(attr) => attr,
        None => return Ok(Default::default()),
    };

    let meta = com_attr.parse_meta().map_err(|e| e.to_string())?;

    let attrs = match &meta {
        Meta::List(list) => &list.nested,
        _ => return Err("Invalid parameters to the `com` attribute".into()),
    };

    let mut send = false;
    let mut sync = false;
    let mut debug = false;
    for attr in attrs.iter() {
        let ident = match attr {
            NestedMeta::Meta(Meta::Word(ident)) => ident,
            _ => return Err("Invalid parameters to the `com` attribute".into()),
        };

        if ident == "send" {
            if send {
                return Err("Duplicate parameters to the `com` attribute".into());
            }
            send = true;
        } else if ident == "sync" {
            if sync {
                return Err("Duplicate parameters to the `com` attribute".into());
            }
            sync = true;
        } else if ident == "debug" {
            if debug {
                return Err("Duplicate parameters to the `com` attribute".into());
            }
            debug = true;
        } else {
            return Err("Invalid parameters to the `com` attribute".into());
        }
    }

    Ok(AttrInfo { send, sync, debug })
}

/// Returns `true` for an attribute whose path is exactly the single identifier `com`.
///
/// Takes `&&Attribute` so it can be passed directly to `Iterator::filter` over `&Attribute`.
fn is_com_attr(attr: &&Attribute) -> bool {
    let segments = &attr.path.segments;
    if segments.len() != 1 {
        return false;
    }
    segments[0].ident == "com"
}