cpp_oop_macros 0.1.9

Helper crate for the 'cpp_oop' crate.
Documentation
use convert_case::{Case, Casing};
use proc_macro2::TokenTree;
use quote::format_ident;
use quote::quote;
use syn::{
    parse_quote, punctuated::Punctuated, token::Comma, Abi, BareFnArg, Field, Fields, FnArg, Ident,
    ImplItem, ImplItemFn, Item, ItemImpl, ItemStruct, Meta, ReturnType, Type, Visibility,
};

use crate::utils::{add_repr_c, get_attr, has_attr, remove_attr};
use crate::TokenStream2;

/// Properties read off the main struct by [`MainStructProperties::parse`].
pub struct MainStructProperties {
    /// `true` when the struct carried a `default` attribute.
    pub default: bool,
    /// Identifiers listed in the struct's `base(...)` attribute; empty when the
    /// attribute is absent.
    pub base_names: Vec<Ident>,
    /// Visibility of the struct.
    pub vis: Visibility,
    /// Identifier of the struct.
    pub name: Ident,
    /// `{name}MethodsInner`.
    pub inner_name: Ident,
    /// `{name}Methods`.
    pub outer_name: Ident,
    /// `{NAME}_VTBL`, where `NAME` is the upper-cased struct name.
    pub default_vtable_name: Ident,
    /// Identifiers of the struct's fields, in declaration order.
    pub field_names: Vec<Ident>,
}

impl MainStructProperties {
    /// Collects the macro-relevant properties of `main_struct`.
    ///
    /// The struct's attributes are passed through `add_repr_c`, and the helper
    /// attributes `default` and `base(...)` are removed from the struct once they
    /// have been read.
    ///
    /// # Panics
    ///
    /// Panics when the `base` attribute is not a meta list, when a token at an
    /// even position inside `base(...)` is not an identifier, or when the struct
    /// has a field without a name.
    pub fn parse(main_struct: &mut ItemStruct) -> Self {
        let attrs = &mut main_struct.attrs;
        add_repr_c(attrs);

        // `default` is a pure marker: remember it and strip it.
        let default = has_attr(attrs, "default");
        if default {
            remove_attr(attrs, "default");
        }

        let base_names = match get_attr(attrs, "base") {
            Some(base) => {
                remove_attr(attrs, "base");
                let Meta::List(list) = base.meta else {
                    panic!("expected meta list");
                };
                // Every second token is skipped: those are the separators
                // between the base identifiers.
                let mut bases = Vec::new();
                for token in list.tokens.into_iter().step_by(2) {
                    let TokenTree::Ident(ident) = token else {
                        panic!("expected ident in bases definition")
                    };
                    bases.push(ident);
                }
                bases
            }
            None => Vec::new(),
        };

        let name = &main_struct.ident;
        let vis = main_struct.vis.clone();
        let inner_name = format_ident!("{name}MethodsInner");
        let outer_name = format_ident!("{name}Methods");
        let default_vtable_name = format_ident!("{}_VTBL", name.to_string().to_uppercase());

        let field_names = main_struct
            .fields
            .iter()
            .map(|field| field.ident.clone().expect("expected named field"))
            .collect();

        MainStructProperties {
            default,
            base_names,
            vis,
            name: name.clone(),
            inner_name,
            outer_name,
            default_vtable_name,
            field_names,
        }
    }
}

/// Properties of a single function field of the layout struct, as produced by
/// [`LayoutStructProps::parse`].
pub struct LayoutFnProps {
    /// The field itself, with the helper attributes removed and an ABI set on
    /// its bare-fn type.
    pub fun: Field,
    /// `true` when the field carried a `virt` attribute.
    pub virt: bool,
    /// Tokens found inside the field's `generics(...)` attribute; empty when the
    /// attribute is absent.
    pub generics: TokenStream2,
    /// ABI of the bare-fn type; `extern "C"` when the field did not specify one.
    pub abi: Abi,
    /// Identifier of the field.
    pub fn_name: Ident,
    /// `{fn_name}_inner_virt` for `virt` fields, `{fn_name}_inner` otherwise.
    pub fn_name_inner: Ident,
    /// The arguments of the bare-fn type, with a `this: &Type` argument turned
    /// into an unnamed argument whose type's first path segment is the
    /// `self_ident` passed to `parse`.
    pub inputs: Punctuated<BareFnArg, Comma>,
    /// The arguments of the bare-fn type exactly as written.
    pub inputs_inner: Punctuated<BareFnArg, Comma>,
    /// Names of the arguments in `inputs`; unnamed arguments are represented by
    /// the `self_ident` passed to `parse`.
    pub input_names: Punctuated<Ident, Comma>,
    /// Return type of the bare-fn type.
    pub output: ReturnType,
}

/// Properties read off the layout struct by [`LayoutStructProps::parse`].
pub struct LayoutStructProps {
    /// `{name}Layout`, where `name` is the identifier passed to `parse`.
    pub layout_name: Ident,
    /// One entry per field of the layout struct, in declaration order.
    pub layout_fns: Vec<LayoutFnProps>,
    /// `true` when at least one field carried a `virt` attribute.
    pub vtable: bool,
}

impl LayoutStructProps {
    /// Parses the layout struct that accompanies the main struct `name`.
    ///
    /// The struct's attributes are passed through `add_repr_c`. For every field
    /// the helper attributes `generics(...)` and `virt` are read and removed, and
    /// `layout_struct` is rewritten in place so that only the `virt` fields
    /// remain.
    ///
    /// `self_ident` is used in two places: it replaces the first path segment of
    /// the type of a `this: &Type` argument, and it stands in for the name of
    /// unnamed arguments in [`LayoutFnProps::input_names`].
    ///
    /// # Panics
    ///
    /// Panics when
    /// - the struct is not called `{name}Layout`,
    /// - the struct does not have named fields,
    /// - a `generics` attribute is not a meta list,
    /// - a field's type is not a bare-fn type,
    /// - a `this` argument is not a reference to a path type.
    pub fn parse(layout_struct: &mut ItemStruct, name: &Ident, self_ident: &Ident) -> Self {
        add_repr_c(&mut layout_struct.attrs);
        assert!(
            layout_struct.ident == name.to_string() + "Layout",
            "Second struct's name should be first struct's name with 'Layout' at the end"
        );

        let Fields::Named(named) = &mut layout_struct.fields else {
            panic!("expected named field");
        };

        // First pass: strip the helper attributes and snapshot every field
        // together with what the attributes said.
        let layout_fns = named
            .named
            .iter_mut()
            .map(|fun| {
                let generics = get_attr(&fun.attrs, "generics")
                    .map(|generics| {
                        remove_attr(&mut fun.attrs, "generics");
                        let Meta::List(list) = generics.meta.clone() else {
                            panic!("expected meta list");
                        };
                        let tokens = list.tokens;
                        quote! { #tokens }
                    })
                    .unwrap_or_default();
                let virt = has_attr(&fun.attrs, "virt");
                if virt {
                    remove_attr(&mut fun.attrs, "virt");
                }
                (fun.clone(), virt, generics)
            })
            .collect::<Vec<_>>();

        // Only the virtual functions stay in the emitted layout struct. The
        // borrow of `layout_fns` is scoped so the vector can be consumed below.
        {
            let virt_fns = layout_fns
                .iter()
                .filter(|(_, virt, _)| *virt)
                .map(|(fun, _, _)| fun)
                .collect::<Punctuated<_, Comma>>();
            layout_struct.fields = Fields::Named(parse_quote!({ #virt_fns }));
        }

        let vtable = layout_fns.iter().any(|(_, virt, _)| *virt);

        // Second pass: derive the per-function properties.
        let layout_fns = layout_fns
            .into_iter()
            .map(|(mut fun, virt, generics)| {
                let Type::BareFn(bare_fn_ty) = &mut fun.ty else {
                    panic!("expected bare fn type");
                };
                // Functions without an explicit ABI default to `extern "C"`;
                // the default is written back into the stored field as well.
                let abi: Abi = bare_fn_ty
                    .abi
                    .get_or_insert_with(|| parse_quote!(extern "C"))
                    .clone();
                let fn_name = fun.ident.clone().expect("expected named fun");
                let fn_name_inner = if virt {
                    format_ident!("{fn_name}_inner_virt")
                } else {
                    format_ident!("{fn_name}_inner")
                };
                let inputs_inner = bare_fn_ty.inputs.clone();
                let inputs = inputs_inner
                    .iter()
                    .map(|arg| {
                        let mut new_arg = arg.clone();
                        if let Some((name, _)) = &mut new_arg.name {
                            if name == "this" {
                                // correct `this: &Self` to `&self`
                                if let Type::Reference(t_ref) = &mut new_arg.ty {
                                    if let Type::Path(p) = t_ref.elem.as_mut() {
                                        new_arg.name = None;
                                        p.path.segments[0].ident = self_ident.clone();
                                    } else {
                                        panic!("Syntax error");
                                    }
                                } else {
                                    panic!("Syntax error");
                                }
                            }
                        }
                        new_arg
                    })
                    .collect::<Punctuated<_, _>>();
                // Unnamed arguments (the rewritten `this`) are referred to by
                // `self_ident`.
                let input_names = inputs
                    .iter()
                    .map(|arg| {
                        arg.name
                            .as_ref()
                            .map_or_else(|| self_ident.clone(), |(name, _)| name.clone())
                    })
                    .collect();
                let output = bare_fn_ty.output.clone();
                LayoutFnProps {
                    fun,
                    virt,
                    generics,
                    abi,
                    fn_name,
                    fn_name_inner,
                    inputs,
                    inputs_inner,
                    input_names,
                    output,
                }
            })
            .collect();

        LayoutStructProps {
            layout_name: format_ident!("{name}Layout"),
            layout_fns,
            vtable,
        }
    }
}

/// Properties of one base's inner-trait impl, as produced by
/// [`parse_impl_bases_inner`].
pub struct BaseImplProps {
    /// Copy of the `impl` block; functions marked `unused` have that attribute
    /// removed and their body replaced by an `unreachable!` call.
    pub impl_base_inner: ItemImpl,
    /// One entry per item of the `impl` block, in declaration order.
    pub fns: Vec<BaseFnProps>,
    /// Identifier of the base.
    pub base_name: Ident,
    /// `base_{base_name}`, with the base name converted to snake case.
    pub base_field_name: Ident,
    /// `{BASE_NAME}_VTBL`, with the base name upper-cased.
    pub base_vtable_name: Ident,
}

/// Properties of a single function inside a base's inner-trait impl, as
/// produced by [`parse_impl_bases_inner`].
pub struct BaseFnProps {
    /// Copy of the function as stored in the processed `impl` block.
    pub fun: ImplItemFn,
    /// Identifier of the function as written in the `impl` block.
    pub fn_name_inner: Ident,
    /// `fn_name_inner` with its trailing `_inner` / `_inner_virt` part removed.
    pub fn_name: Ident,
    /// Arguments of the function exactly as written.
    pub inputs_inner: Punctuated<FnArg, Comma>,
    /// Arguments of the function, with a leading typed reference argument
    /// replaced by a `&self` / `&mut self` receiver.
    pub inputs: Punctuated<FnArg, Comma>,
    /// Return type of the function.
    pub output: ReturnType,
    /// `true` when the function carried an `unused` attribute.
    pub unused: bool,
    /// `true` when the function's identifier ends in `_virt`.
    pub virt: bool,
}

/// Parses, for every base in `base_names`, the `impl {Base}MethodsInner for
/// {name}` block found in `items`.
///
/// The impl for the base at index `i` is read from `items[i + 3]`.
///
/// Each function of an impl must be named `{fn}_inner` or `{fn}_inner_virt`;
/// the latter marks the function as virtual. Functions carrying an `unused`
/// attribute get the attribute removed and their body replaced by an
/// `unreachable!` call in the returned copy of the impl.
///
/// Returns the parsed impls together with the vtable flag: the incoming
/// `vtable` value, switched to `true` if any parsed function is virtual.
///
/// # Panics
///
/// Panics when
/// - `items` has no entry for a base, or the entry is not an `impl` block,
/// - the impl's self type is not a path starting with `name`,
/// - the impl is not an impl of `{Base}MethodsInner`,
/// - an impl item is not a function,
/// - a function name ends in neither `_inner` nor `_inner_virt`.
pub fn parse_impl_bases_inner(
    items: &[Item],
    name: &Ident,
    base_names: &[Ident],
    mut vtable: bool,
) -> (Vec<BaseImplProps>, bool) {
    let impls = base_names
        .iter()
        .enumerate()
        .map(|(i, base_name)| {
            // NOTE(review): the offset 3 presumably skips the items that precede
            // the base impls in the macro input — confirm against the caller.
            let Item::Impl(mut impl_base_inner) = items
                .get(i + 3)
                .unwrap_or_else(|| panic!("pls impl {base_name}'s inner trait"))
                .clone()
            else {
                panic!("expected base impl")
            };

            let Type::Path(path) = &*impl_base_inner.self_ty else {
                panic!("expected type to be a path")
            };
            assert!(
                path.path
                    .segments
                    .first()
                    .expect("type path has at least one segment")
                    .ident
                    == *name,
                "implemented for the wrong struct"
            );

            let (_, trait_path, _) = impl_base_inner
                .trait_
                .as_ref()
                .expect("expected a trait impl, found an inherent impl");
            assert!(
                trait_path
                    .segments
                    .first()
                    .expect("trait path has at least one segment")
                    .ident
                    == format_ident!("{base_name}MethodsInner"),
                "implemented the wrong trait"
            );

            let fns = impl_base_inner
                .items
                .iter_mut()
                .map(|item| {
                    let ImplItem::Fn(fun) = item else {
                        panic!("expected fn")
                    };

                    let fn_name_inner = fun.sig.ident.clone();
                    let inner_name = fn_name_inner.to_string();
                    // Strip the suffix by value, not by length, so a name that
                    // does not follow the convention is reported instead of
                    // being sliced at an arbitrary position.
                    let (stem, virt) =
                        if let Some(stem) = inner_name.strip_suffix("_inner_virt") {
                            (stem, true)
                        } else if let Some(stem) = inner_name.strip_suffix("_inner") {
                            (stem, false)
                        } else {
                            panic!(
                                "expected method name ending in `_inner` or `_inner_virt`, found `{inner_name}`"
                            )
                        };
                    let fn_name = format_ident!("{}", stem);

                    let inputs_inner = fun.sig.inputs.clone();
                    // A leading typed reference argument becomes the receiver,
                    // keeping its mutability; everything else is copied as is.
                    let inputs: Punctuated<FnArg, Comma> = inputs_inner
                        .iter()
                        .enumerate()
                        .map(|(idx, arg)| match arg {
                            FnArg::Typed(typed) if idx == 0 => match &*typed.ty {
                                Type::Reference(ty) => FnArg::Receiver(if ty.mutability.is_some() {
                                    parse_quote! { &mut self }
                                } else {
                                    parse_quote! { &self }
                                }),
                                _ => arg.clone(),
                            },
                            _ => arg.clone(),
                        })
                        .collect();
                    let output = fun.sig.output.clone();

                    let unused = get_attr(&fun.attrs, "unused").is_some();
                    if unused {
                        remove_attr(&mut fun.attrs, "unused");
                        fun.block = parse_quote! {
                            {
                                unreachable!("fn is unused: pls report bug for cpp_oop")
                            }
                        };
                    }

                    // One virtual function anywhere is enough to need a vtable.
                    vtable |= virt;

                    BaseFnProps {
                        fun: fun.clone(),
                        fn_name_inner,
                        fn_name,
                        inputs_inner,
                        inputs,
                        output,
                        unused,
                        virt,
                    }
                })
                .collect::<Vec<_>>();

            let base_field_name =
                format_ident!("base_{}", base_name.to_string().to_case(Case::Snake));
            let base_vtable_name = format_ident!("{}_VTBL", base_name.to_string().to_uppercase());

            BaseImplProps {
                impl_base_inner,
                fns,
                base_name: base_name.clone(),
                base_field_name,
                base_vtable_name,
            }
        })
        .collect::<Vec<_>>();

    (impls, vtable)
}