1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use std::ops::Deref;

use proc_macro::TokenStream;
use quote::quote;
use syn::{Fields, ImplItem, Path};

mod method_helpers;
use method_helpers::{filter_overrides, make_extern_c, remove_override_attr};

mod parsers;
use parsers::{InheritImplAttr, NamedField};

mod vtable;
use vtable::generate_vtable_const;

#[proc_macro_attribute]
pub fn inherit_from(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut struct_def = syn::parse_macro_input!(item as syn::ItemStruct);
    let ty = syn::parse_macro_input!(attr as syn::Type);

    let fields = match struct_def.fields {
        Fields::Named(ref mut fields) => &mut fields.named,
        Fields::Unit => {
            struct_def.fields = Fields::Named(syn::parse_quote!({}));
            if let Fields::Named(ref mut fields) = struct_def.fields {
                &mut fields.named
            } else {
                unreachable!()
            }
        }
        _ => panic!("Tuple-type structs cannot inherit from classes"),
    };

    let base_field: NamedField = syn::parse_quote!(
        _base: #ty
    );

    fields.insert(0, base_field.0);

    let struct_name = &struct_def.ident;

    struct_def.attrs.push(syn::parse_quote! {
        #[repr(C)]
    });

    quote!(
        #struct_def

        impl ::core::ops::Deref for #struct_name {
            type Target = #ty;

            fn deref(&self) -> &Self::Target {
                &self._base
            }
        }

        impl ::core::ops::DerefMut for #struct_name {
            fn deref_mut(&mut self) -> &mut Self::Target {
                &mut self._base
            }
        }
    )
    .into()
}

/// Builds a single path segment (no generic arguments) from a borrowed identifier.
///
/// The double reference lets this function be passed straight to `Iterator::map` over
/// `[&Ident; N]::iter()`, whose items are `&&Ident`.
fn into_path_segment(ident: &&syn::Ident) -> syn::PathSegment {
    let owned: syn::Ident = syn::Ident::clone(ident);
    syn::PathSegment::from(owned)
}

#[proc_macro_attribute]
pub fn inherit_from_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut impl_block = syn::parse_macro_input!(item as syn::ItemImpl);
    let InheritImplAttr { class, header, .. } = syn::parse_macro_input!(attr as InheritImplAttr);

    let header = header.value();

    // List of methods with #[overridden] attrbiute
    let mut override_items = impl_block
        .items
        .iter_mut()
        .filter_map(|item| {
            if let ImplItem::Method(ref mut method) = item {
                filter_overrides(method)
            } else {
                None
            }
        })
        .collect::<Vec<_>>();

    // Make all override methods `extern "C"`
    override_items.iter_mut().for_each(make_extern_c);

    // Remove fake overridden attributes
    override_items.iter_mut().for_each(remove_override_attr);

    let vtable_info = vtable::get_vtable_info(&header, &class.to_string());

    // List of method override names
    let override_list = override_items
        .into_iter()
        .map(|method| method.sig.ident.clone())
        .collect::<Vec<_>>();

    let type_ident = match *impl_block.self_ty {
        syn::Type::Path(ref path) => path.path.get_ident().expect("Class type must be an ident"),
        _ => panic!("Class type must be an ident"), // Error about how class type must be ident
    };

    match vtable_info.get(&class.to_string()) {
        Some(base_type_vtable) => {
            // Generate a vtable before overrides
            let base_vtable: Vec<Option<Path>> = vec![None; base_type_vtable.len()];

            let mut vtable = base_vtable;

            // Apply each override to the base vtable
            for o in override_list {
                match base_type_vtable.binary_search_by_key(&&o.to_string(), |entry| &entry.name) {
                    Ok(index) => {
                        vtable[index] = Some(Path {
                            leading_colon: None,
                            //          $class::$method
                            segments: [&type_ident, &o].iter().map(into_path_segment).collect(),
                        });
                    }
                    Err(..) => panic!("Cannot override a virtual method that doesn't exist in the original vtable"),
                }
            }

            let mut bindings_to_gen = vec![];

            let vtable = vtable
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    x.unwrap_or_else(|| {
                        bindings_to_gen.push(base_type_vtable[i].default.deref());

                        vtable::get_binding_symbol(&base_type_vtable[i].default).into()
                    })
                })
                .collect();

            let self_type = &impl_block.self_ty;

            let vtable_const = generate_vtable_const(vtable, self_type);

            let bindings = bindings_to_gen.into_iter().map(vtable::generate_binding);

            quote!(
                #impl_block

                #vtable_const

                #(
                    #bindings
                )*
            )
            .into()
        }
        None => panic!("Class does not exist in header"), // add compiler error for class not existing in header
    }
}