// cpp_inherit/lib.rs

1use std::ops::Deref;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{Fields, ImplItem, Path};
6
7mod method_helpers;
8use method_helpers::{filter_overrides, make_extern_c, remove_override_attr};
9
10mod parsers;
11use parsers::{InheritImplAttr, NamedField};
12
13mod vtable;
14use vtable::generate_vtable_const;
15
16#[proc_macro_attribute]
17pub fn inherit_from(attr: TokenStream, item: TokenStream) -> TokenStream {
18    let mut struct_def = syn::parse_macro_input!(item as syn::ItemStruct);
19    let ty = syn::parse_macro_input!(attr as syn::Type);
20
21    let fields = match struct_def.fields {
22        Fields::Named(ref mut fields) => &mut fields.named,
23        Fields::Unit => {
24            struct_def.fields = Fields::Named(syn::parse_quote!({}));
25            if let Fields::Named(ref mut fields) = struct_def.fields {
26                &mut fields.named
27            } else {
28                unreachable!()
29            }
30        }
31        _ => panic!("Tuple-type structs cannot inherit from classes"),
32    };
33
34    let base_field: NamedField = syn::parse_quote!(
35        _base: #ty
36    );
37
38    fields.insert(0, base_field.0);
39
40    let struct_name = &struct_def.ident;
41
42    struct_def.attrs.push(syn::parse_quote! {
43        #[repr(C)]
44    });
45
46    quote!(
47        #struct_def
48
49        impl ::core::ops::Deref for #struct_name {
50            type Target = #ty;
51
52            fn deref(&self) -> &Self::Target {
53                &self._base
54            }
55        }
56
57        impl ::core::ops::DerefMut for #struct_name {
58            fn deref_mut(&mut self) -> &mut Self::Target {
59                &mut self._base
60            }
61        }
62    )
63    .into()
64}
65
66fn into_path_segment(ident: &&syn::Ident) -> syn::PathSegment {
67    (*ident).clone().into()
68}
69
70#[proc_macro_attribute]
71pub fn inherit_from_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
72    let mut impl_block = syn::parse_macro_input!(item as syn::ItemImpl);
73    let InheritImplAttr { class, header, .. } = syn::parse_macro_input!(attr as InheritImplAttr);
74
75    let header = header.value();
76
77    // List of methods with #[overridden] attrbiute
78    let mut override_items = impl_block
79        .items
80        .iter_mut()
81        .filter_map(|item| {
82            if let ImplItem::Method(ref mut method) = item {
83                filter_overrides(method)
84            } else {
85                None
86            }
87        })
88        .collect::<Vec<_>>();
89
90    // Make all override methods `extern "C"`
91    override_items.iter_mut().for_each(make_extern_c);
92
93    // Remove fake overridden attributes
94    override_items.iter_mut().for_each(remove_override_attr);
95
96    let vtable_info = vtable::get_vtable_info(&header, &class.to_string());
97
98    // List of method override names
99    let override_list = override_items
100        .into_iter()
101        .map(|method| method.sig.ident.clone())
102        .collect::<Vec<_>>();
103
104    let type_ident = match *impl_block.self_ty {
105        syn::Type::Path(ref path) => path.path.get_ident().expect("Class type must be an ident"),
106        _ => panic!("Class type must be an ident"), // Error about how class type must be ident
107    };
108
109    match vtable_info.get(&class.to_string()) {
110        Some(base_type_vtable) => {
111            // Generate a vtable before overrides
112            let base_vtable: Vec<Option<Path>> = vec![None; base_type_vtable.len()];
113
114            let mut vtable = base_vtable;
115
116            // Apply each override to the base vtable
117            for o in override_list {
118                match base_type_vtable.binary_search_by_key(&&o.to_string(), |entry| &entry.name) {
119                    Ok(index) => {
120                        vtable[index] = Some(Path {
121                            leading_colon: None,
122                            //          $class::$method
123                            segments: [&type_ident, &o].iter().map(into_path_segment).collect(),
124                        });
125                    }
126                    Err(..) => panic!("Cannot override a virtual method that doesn't exist in the original vtable"),
127                }
128            }
129
130            let mut bindings_to_gen = vec![];
131
132            let vtable = vtable
133                .into_iter()
134                .enumerate()
135                .map(|(i, x)| {
136                    x.unwrap_or_else(|| {
137                        bindings_to_gen.push(base_type_vtable[i].default.deref());
138
139                        vtable::get_binding_symbol(&base_type_vtable[i].default).into()
140                    })
141                })
142                .collect();
143
144            let self_type = &impl_block.self_ty;
145
146            let vtable_const = generate_vtable_const(vtable, self_type);
147
148            let bindings = bindings_to_gen.into_iter().map(vtable::generate_binding);
149
150            quote!(
151                #impl_block
152
153                #vtable_const
154
155                #(
156                    #bindings
157                )*
158            )
159            .into()
160        }
161        None => panic!("Class does not exist in header"), // add compiler error for class not existing in header
162    }
163}