Skip to main content

plusplus_macros/
lib.rs

1use proc_macro_crate::{FoundCrate, crate_name};
2use proc_macro2::{Span, TokenStream};
3use quote::{format_ident, quote};
4use syn::{AngleBracketedGenericArguments, CapturedParam, GenericArgument, GenericParam, Lifetime, LifetimeParam, Path, PathArguments, ReturnType, Signature, Token, TypeParamBound};
5use syn::parse::{Parse, ParseStream};
6use syn::token::Brace;
7use syn::{
8    Attribute, Field, FnArg, Ident, ImplItemFn, Pat, Type, Visibility, braced,
9    parse_macro_input, parse_quote,
10};
11
12mod kw {
13    syn::custom_keyword!(class);
14}
15
16#[derive(Debug, Clone)]
17struct OverrideItem {
18    _override_token: Token![override],
19    override_class: Type,
20    _brace_token: Brace,
21    items: Vec<ImplItemFn>,
22}
23
24fn correct_priv_vis(vis: Visibility) -> Visibility {
25    match vis {
26        Visibility::Public(p) => Visibility::Public(p),
27        Visibility::Restricted(mut restricted) => {
28            if restricted.path.segments.get(0) != Some(&parse_quote!(crate)) {
29                restricted.path.segments.insert(0, parse_quote!(super));
30            }
31            Visibility::Restricted(restricted)
32        }
33        Visibility::Inherited => {
34            parse_quote!(pub(super))
35        }
36    }
37}
38
39impl Parse for OverrideItem {
40    fn parse(input: ParseStream) -> syn::Result<Self> {
41        let content;
42        let override_token = input.parse()?;
43        let override_class = input.parse()?;
44        let brace_token = braced!(content in input);
45        let mut items = Vec::new();
46        while !content.is_empty() {
47            let mut item: ImplItemFn = content.parse()?;
48            item.vis = correct_priv_vis(item.vis);
49            items.push(item);
50        }
51
52        Ok(OverrideItem {
53            _override_token: override_token,
54            override_class,
55            _brace_token: brace_token,
56            items,
57        })
58    }
59}
60
61#[derive(Debug, Clone)]
62enum ClassItem {
63    Field {
64        field: Field,
65        _semi_token: Token![;],
66    },
67    ImplItemFn(ImplItemFn),
68    OverrideItem(OverrideItem),
69}
70
71impl Parse for ClassItem {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        let lookahead = input.lookahead1();
74        if lookahead.peek(Token![override]) {
75            return Ok(ClassItem::OverrideItem(OverrideItem::parse(input)?));
76        }
77
78        let begin = input.fork();
79        let _attrs = input.call(Attribute::parse_outer)?;
80        let _vis: Visibility = begin.parse()?;
81        let lookahead = begin.lookahead1();
82        if lookahead.peek(Token![fn]) || lookahead.peek(Token![unsafe]) || lookahead.peek(Token![async]) {
83            return Ok(ClassItem::ImplItemFn(ImplItemFn::parse(input)?));
84        }
85
86        let mut field = Field::parse_named(input)?;
87        field.vis = correct_priv_vis(field.vis);
88        let semi_token = input.parse()?;
89        Ok(ClassItem::Field {
90            field,
91            _semi_token: semi_token,
92        })
93    }
94}
95
96#[derive(Debug, Clone)]
97struct SuperclassInput {
98    _colon_token: Token![:],
99    ty: Type,
100}
101
102impl Parse for SuperclassInput {
103    fn parse(input: ParseStream) -> syn::Result<Self> {
104        Ok(SuperclassInput {
105            _colon_token: input.parse()?,
106            ty: input.parse()?,
107        })
108    }
109}
110
111#[derive(Debug, Clone)]
112struct ClassInput {
113    vis: Visibility,
114    _class_token: kw::class,
115    ident: Ident,
116    superclass: Option<SuperclassInput>,
117    _brace_token: Brace,
118    items: Vec<ClassItem>,
119}
120
121impl Parse for ClassInput {
122    fn parse(input: ParseStream) -> syn::Result<Self> {
123        let content;
124        let vis = input.parse()?;
125        let class_token = input.parse()?;
126        let ident = input.parse()?;
127
128        let lookahead = input.lookahead1();
129        let superclass = if lookahead.peek(Token![:]) {
130            Some(input.parse()?)
131        } else {
132            None
133        };
134
135        let brace_token = braced!(content in input);
136        let mut items = Vec::new();
137        while !content.is_empty() {
138            items.push(content.parse()?);
139        }
140        Ok(ClassInput {
141            vis,
142            _class_token: class_token,
143            ident,
144            superclass,
145            _brace_token: brace_token,
146            items,
147        })
148    }
149}
150
151#[derive(Debug, Clone)]
152struct CrateAlias {
153    _crate_token: Token![crate],
154    _as_token: Token![as],
155    ident: Ident,
156    _semi_token: Token![;],
157}
158
159impl Parse for CrateAlias {
160    fn parse(input: ParseStream) -> syn::Result<Self> {
161        Ok(CrateAlias {
162            _crate_token: input.parse()?,
163            _as_token: input.parse()?,
164            ident: input.parse()?,
165            _semi_token: input.parse()?,
166        })
167    }
168}
169
170struct ClassInputs {
171    crate_alias: Option<CrateAlias>,
172    inputs: Vec<ClassInput>,
173}
174
175impl Parse for ClassInputs {
176    fn parse(input: ParseStream) -> syn::Result<Self> {
177        let lookahead = input.lookahead1();
178        let crate_alias = if lookahead.peek(Token![crate]) {
179            Some(input.parse()?)
180        } else {
181            None
182        };
183
184        let mut inputs = Vec::new();
185        while !input.is_empty() {
186            inputs.push(input.parse()?);
187        }
188
189        Ok(ClassInputs {
190            crate_alias,
191            inputs,
192        })
193    }
194}
195
196fn plusplus() -> proc_macro2::TokenStream {
197    let found_crate = crate_name("plusplus").expect("plusplus is present in `Cargo.toml`");
198    match found_crate {
199        FoundCrate::Itself => quote!(crate),
200        FoundCrate::Name(name) => {
201            let ident = Ident::new(&name, Span::call_site());
202            quote!( #ident )
203        }
204    }
205}
206
207fn cast_class_ptr(
208    plusplus: &proc_macro2::TokenStream,
209    from: &Type,
210    to: &Type,
211    expr: impl Into<proc_macro2::TokenStream>,
212) -> proc_macro2::TokenStream {
213    let expr = expr.into();
214    quote! {{
215        let t: &#from = #expr;
216        let self_size = std::mem::size_of_val(t);
217        let target_size = std::mem::size_of::<#to<#plusplus::InConstruction>>();
218        assert!(self_size >= target_size);
219        let array_size = self_size - target_size;
220        let target_ptr = std::ptr::slice_from_raw_parts(t as *const #from as *const u8, array_size);
221        let target_ref = &*(target_ptr as *const #to);
222        assert_eq!(self_size, std::mem::size_of_val(target_ref));
223        target_ref
224    }}
225}
226
227fn cast_class_ptr_mut(
228    plusplus: &TokenStream,
229    from: &Type,
230    to: &Type,
231    expr: impl Into<TokenStream>,
232) -> TokenStream {
233    let expr = expr.into();
234    quote! {{
235        let t: &mut #from = #expr;
236        let self_size = std::mem::size_of_val(t);
237        let target_size = std::mem::size_of::<#to<#plusplus::InConstruction>>();
238        assert!(self_size >= target_size);
239        let array_size = self_size - target_size;
240        let target_ptr = std::ptr::slice_from_raw_parts_mut(t as *mut #from as *mut u8, array_size);
241        let target_ref = &mut *(target_ptr as *mut #to);
242        assert_eq!(self_size, std::mem::size_of_val(target_ref));
243        target_ref
244    }}
245}
246
247fn set_arg_blank_lifetime(arg: &mut FnArg, lifetime: &Lifetime, lifetime_set: &mut bool) {
248    match arg {
249        FnArg::Receiver(rx) => {
250            if let Some((_, lt)) = &mut rx.reference {
251                maybe_set_lifetime_opt(lt, lifetime, lifetime_set);
252            }
253        }
254        FnArg::Typed(ty) => {
255            set_blank_type_lifetimes(&mut ty.ty, lifetime, lifetime_set);
256        }
257    }
258}
259
260fn set_blank_type_lifetimes(ty: &mut Type, lifetime: &Lifetime, lifetime_set: &mut bool) {
261    match ty {
262        Type::Array(arr) => set_blank_type_lifetimes(&mut arr.elem, lifetime, lifetime_set),
263        Type::BareFn(_) => {}
264        Type::Group(group) => set_blank_type_lifetimes(&mut group.elem, lifetime, lifetime_set),
265        Type::ImplTrait(_) => {} // technically we should implement this but it doesn't come up in our use
266        Type::Infer(_) => {}
267        Type::Macro(_) => {}
268        Type::Never(_) => {}
269        Type::Paren(paren) => set_blank_type_lifetimes(&mut paren.elem, lifetime, lifetime_set),
270        Type::Path(path) => {
271            if let Some(qself) = &mut path.qself {
272                set_blank_type_lifetimes(&mut qself.ty, lifetime, lifetime_set);
273            }
274            set_blank_path_lifetimes(&mut path.path, lifetime, lifetime_set);
275        }
276        Type::Ptr(ptr) => set_blank_type_lifetimes(&mut ptr.elem, lifetime, lifetime_set),
277        Type::Reference(refer) => {
278            maybe_set_lifetime_opt(&mut refer.lifetime, lifetime, lifetime_set);
279            set_blank_type_lifetimes(&mut refer.elem, lifetime, lifetime_set)
280        }
281        Type::Slice(slice) => set_blank_type_lifetimes(&mut slice.elem, lifetime, lifetime_set),
282        Type::TraitObject(obj) => set_blank_type_param_bounds(&mut obj.bounds, lifetime, lifetime_set),
283        Type::Tuple(tup) => {
284            for ty in &mut tup.elems {
285                set_blank_type_lifetimes(ty, lifetime, lifetime_set);
286            }
287        }
288        Type::Verbatim(_) => {}
289        _ => ()
290    }
291}
292
293fn set_blank_path_lifetimes(path: &mut Path, lifetime: &Lifetime, lifetime_set: &mut bool) {
294    for path_args in path.segments.iter_mut().map(|seg| &mut seg.arguments) {
295        match path_args {
296            PathArguments::None => {}
297            PathArguments::AngleBracketed(angle_args) => set_blank_angle_bracket_lifetimes(angle_args, lifetime, lifetime_set),
298            PathArguments::Parenthesized(paren_args) => {
299                for ty in &mut paren_args.inputs {
300                    set_blank_type_lifetimes(ty, lifetime, lifetime_set);
301                }
302                if let ReturnType::Type(_, ty) = &mut paren_args.output {
303                    set_blank_type_lifetimes(ty, lifetime, lifetime_set);
304                }
305            }
306        }
307    }
308}
309
310fn set_blank_angle_bracket_lifetimes(angle: &mut AngleBracketedGenericArguments, lifetime: &Lifetime, lifetime_set: &mut bool) {
311    for arg in &mut angle.args {
312        set_blank_generic_lifetimes(arg, lifetime, lifetime_set);
313    }
314}
315
316fn maybe_set_lifetime(set: &mut Lifetime, to: &Lifetime, lifetime_set: &mut bool) {
317    if set.ident == format_ident!("_") {
318        *set = to.clone();
319        *lifetime_set = true;
320    }
321}
322fn maybe_set_lifetime_opt(set: &mut Option<Lifetime>, to: &Lifetime, lifetime_set: &mut bool) {
323    match set {
324        Some(lt) => maybe_set_lifetime(lt, to, lifetime_set),
325        None => {
326            *set = Some(to.clone());
327            *lifetime_set = true;
328        }
329    }
330}
331
332fn set_blank_type_param_bounds<'a>(bounds: impl IntoIterator<Item=&'a mut TypeParamBound>, lifetime: &Lifetime, lifetime_set: &mut bool) {
333    for bound in bounds {
334        match bound {
335            TypeParamBound::Trait(trait_bound) => {set_blank_path_lifetimes(&mut trait_bound.path, lifetime, lifetime_set)}
336            TypeParamBound::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
337            TypeParamBound::PreciseCapture(cap) => {
338                for p in &mut cap.params {
339                    match p {
340                        CapturedParam::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
341                        CapturedParam::Ident(_) => {}
342                        _ => ()
343                    }
344                }
345            }
346            TypeParamBound::Verbatim(_) => {}
347            _ => (),
348        }
349    }
350}
351fn set_blank_generic_lifetimes(arg: &mut GenericArgument, lifetime: &Lifetime, lifetime_set: &mut bool) {
352    match arg {
353        GenericArgument::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
354        GenericArgument::Type(ty) => set_blank_type_lifetimes(ty, lifetime, lifetime_set),
355        GenericArgument::Const(_) => {}
356        GenericArgument::AssocType(ty) => {
357            if let Some(args) = &mut ty.generics {
358                set_blank_angle_bracket_lifetimes(args, lifetime, lifetime_set);
359            }
360            set_blank_type_lifetimes(&mut ty.ty, lifetime, lifetime_set);
361        }
362        GenericArgument::AssocConst(_) => {}
363        GenericArgument::Constraint(constraint) => {
364            if let Some(args) = &mut constraint.generics {
365                set_blank_angle_bracket_lifetimes(args, lifetime, lifetime_set);
366            }
367
368            set_blank_type_param_bounds(&mut constraint.bounds, lifetime, lifetime_set);
369        },
370        _ => (),
371    }
372}
373
374#[derive(Debug, Clone)]
375struct FuncInfo {
376    func: ImplItemFn,
377    name: Ident,
378    vtbl_name: Ident,
379    lifetime_bounds: Option<TokenStream>,
380    vtbl_sig: TokenStream,
381    args: Vec<Box<Pat>>,
382    mut_self: bool,
383}
384
385fn get_func_sig(class_name: &Type, f: &ImplItemFn) -> FuncInfo {
386    let func_name = &f.sig.ident;
387    let is_async = f.sig.asyncness.is_some();
388
389    let mut inputs = f.sig.inputs.clone();
390    let mut mut_self = false;
391    if let Some(FnArg::Receiver(rx)) = &f.sig.inputs.get(0) {
392        let receiver = &mut inputs[0];
393        if let Some((_, rx_lifetime)) = &rx.reference {
394
395            if rx.mutability.is_some() {
396                *receiver = FnArg::Typed(parse_quote!(this: & #rx_lifetime mut #class_name));
397            } else {
398                *receiver = FnArg::Typed(parse_quote!(this: & #rx_lifetime #class_name));
399            }
400        }
401        mut_self = rx.mutability.is_some();
402    };
403
404    let mut lifetimes: Vec<LifetimeParam> = Vec::new();
405    let future_fallback_lt: Lifetime = parse_quote!('rpp_future);
406    let mut using_future_lifetime = false;
407    let mut vtbl_inputs = inputs.clone();
408    // eprintln!("inputs = {:#?}", vtbl_inputs);
409    if is_async {
410        for arg in &mut vtbl_inputs {
411            set_arg_blank_lifetime(arg, &future_fallback_lt, &mut using_future_lifetime);
412        }
413        if using_future_lifetime {
414            lifetimes.push(parse_quote!(#future_fallback_lt));
415        }
416    }
417    for generic in &f.sig.generics.params {
418        match generic {
419            GenericParam::Lifetime(lt) => {
420                let mut lt = lt.clone();
421                if using_future_lifetime {
422                    lt.bounds.push(future_fallback_lt.clone());
423                }
424                lifetimes.push(lt)
425                // if using_future_lifetime {
426                //     lifetimes.push(parse_quote!(#lt: #future_fallback_lt));
427                // } else {
428                //     lifetimes.push(parse_quote!(#lt));
429                // }
430            }
431            GenericParam::Type(_) => {}
432            GenericParam::Const(_) => {}
433        }
434    }
435
436    let lifetime_bounds = if lifetimes.len() > 0 {
437        Some(quote!(<#(#lifetimes),*>))
438    } else {
439        None
440    };
441
442    let mut output = f.sig.output.clone();
443    if is_async {
444        let future_output = match output {
445            ReturnType::Default => quote!(()),
446            ReturnType::Type(_, ty) => quote!(#ty),
447        };
448        let future_lifetimes = if using_future_lifetime {
449            future_fallback_lt
450        } else {
451            assert!(lifetimes.len() > 0);
452            lifetimes[0].lifetime.clone()
453        };
454        output = parse_quote!(-> std::pin::Pin<Box<dyn #future_lifetimes + Future<Output=#future_output>>>);
455    }
456    let vtbl_name = format_ident!("fn_{func_name}");
457    let vtbl_sig = quote! {
458        (#vtbl_inputs) #output
459    };
460
461    let func_args = inputs
462        .into_iter()
463        .skip(1)
464        .map(|arg| match arg {
465            FnArg::Receiver(_) => unreachable!(),
466            FnArg::Typed(arg) => arg.pat,
467        })
468        .collect::<Vec<_>>();
469
470    FuncInfo {
471        func: f.clone(),
472        name: func_name.clone(),
473        vtbl_name,
474        lifetime_bounds,
475        vtbl_sig,
476        args: func_args,
477        mut_self,
478    }
479}
480
481struct ClassData {
482    plusplus: TokenStream,
483    class_name: Ident,
484    class_type: Type,
485    class_mod_name: Ident,
486    vtbl_ident: Ident,
487    class_vis: Visibility,
488    fields: Vec<Field>,
489    constructors: Vec<FuncInfo>,
490    member_funcs: Vec<FuncInfo>,
491    overrides: Vec<OverrideItem>,
492    override_funcs: Vec<FuncInfo>,
493    superclass_type: Option<Type>,
494    mod_superclass_type: Option<Type>,
495}
496
497impl ClassData {
498    fn from_input(input: ClassInput, crate_alias: Option<&CrateAlias>) -> ClassData {
499        let ClassInput {
500            vis: class_vis,
501            _class_token: _,
502            ident: class_name,
503            superclass,
504            _brace_token: _,
505            items: class_items,
506        } = input;
507
508        let mut fields = Vec::new();
509        let mut constructors = Vec::new();
510        let mut member_funcs = Vec::new();
511        let mut overrides = Vec::new();
512        let mut override_funcs = Vec::new();
513
514        let class_type: Type = parse_quote!(#class_name);
515
516        for item in class_items {
517            match item {
518                ClassItem::Field { field, .. } => fields.push(field),
519                ClassItem::ImplItemFn(func) => {
520                    if let Some(FnArg::Receiver(_)) = func.sig.inputs.get(0) {
521                        member_funcs.push(get_func_sig(&class_type, &func))
522                    } else {
523                        constructors.push(get_func_sig(&class_type, &func));
524                    }
525                }
526                ClassItem::OverrideItem(override_item) => {
527                    override_funcs.extend(override_item.items.iter().map(|f| get_func_sig(&class_type, &f)));
528                    overrides.push(override_item);
529                }
530            }
531        }
532
533        let class_mod_name = format_ident!("plusplus__class_{}", class_name.to_string().to_lowercase());
534
535        let plusplus = if let Some(alias) = crate_alias {
536            let alias = &alias.ident;
537            quote!(#alias)
538        } else {
539            plusplus()
540        };
541
542        let superclass_type = superclass.map(|sc| sc.ty);
543        let mod_superclass_type = match superclass_type.clone() {
544            Some(Type::Path(mut type_path)) => {
545                if type_path.path.segments.get(0) != Some(&parse_quote!(crate)) {
546                    type_path.path.segments.insert(0, parse_quote!(super));
547                }
548                Some(Type::Path(type_path))
549            }
550            ty => ty
551        };
552
553        ClassData {
554            plusplus,
555            vtbl_ident: format_ident!("{}Vtbl", class_name),
556            class_vis,
557            class_name,
558            class_type,
559            class_mod_name,
560            fields,
561            constructors,
562            member_funcs,
563            overrides,
564            override_funcs,
565            superclass_type,
566            mod_superclass_type,
567        }
568    }
569
570    fn has_superclass(&self) -> bool {
571        self.superclass_type.is_some()
572    }
573
574    fn gen_mod_vtbl_struct(&self) -> TokenStream {
575        // create vtbl
576        let vtbl_ident = &self.vtbl_ident;
577        let class_name = &self.class_name;
578
579        let mut vtbl_func_names = Vec::new();
580        let mut vtbl_sigs = Vec::new();
581        let mut my_func_names = Vec::new();
582        let mut async_func_impls = Vec::new();
583        let mut func_setters = Vec::new();
584        let mut vtbl_fors = Vec::new();
585        let mut vtbl_unsafes = Vec::new();
586
587        for f in self.member_funcs.iter() {
588            let FuncInfo {
589                func: ImplItemFn{ sig: Signature {asyncness, unsafety, .. }, .. },
590                name: func_name,
591                vtbl_name: vtbl_func_name,
592                lifetime_bounds,
593                vtbl_sig,
594                args,
595                mut_self: _,
596            } = f;
597            let my_func_name = format_ident!("my_{func_name}");
598            let vtbl_for = lifetime_bounds.as_ref().map(|bounds| quote!(for #bounds));
599            vtbl_fors.push(vtbl_for.clone());
600            vtbl_unsafes.push(unsafety);
601
602            if asyncness.is_some() {
603                async_func_impls.push(quote!{
604                    // fn #my_func_name #func_sig {
605                    let #my_func_name: #vtbl_for #unsafety fn #vtbl_sig = |this, #(#args)*| #unsafety {
606                        Box::pin(#class_name::#my_func_name(this, #(#args)*))
607                    };
608                });
609                func_setters.push(quote!{
610                    #vtbl_func_name: #my_func_name,
611                })
612            } else {
613                func_setters.push(quote!{
614                     #vtbl_func_name: #class_name::#my_func_name,
615                })
616            }
617
618            my_func_names.push(my_func_name);
619            vtbl_func_names.push(vtbl_func_name);
620            vtbl_sigs.push(vtbl_sig);
621        }
622
623        let vtbl_drop_field: Option<_>;
624        let vtbl_drop_func: Option<_>;
625        let vtbl_drop_set: Option<_>;
626        if !self.has_superclass() {
627            vtbl_drop_field = Some(quote! {
628                pub manually_drop: unsafe fn(*mut #class_name),
629            });
630            vtbl_drop_func = Some(quote! {
631                unsafe fn manually_drop(this: *mut #class_name) {
632                    unsafe{ std::ptr::drop_in_place(this) }
633                }
634            });
635            vtbl_drop_set = Some(quote! {
636                manually_drop,
637            });
638        } else {
639            vtbl_drop_field = None;
640            vtbl_drop_func = None;
641            vtbl_drop_set = None;
642        };
643
644        quote! {
645            #[doc(hidden)]
646            pub struct #vtbl_ident {
647                #vtbl_drop_field
648                #(pub #vtbl_func_names: #vtbl_fors #vtbl_unsafes fn #vtbl_sigs,)*
649            }
650
651            impl #vtbl_ident {
652                const BASE: Self = {
653                    #vtbl_drop_func
654                    #(#async_func_impls)*
655
656                    Self {
657                        #vtbl_drop_set
658                        #(#func_setters)*
659                    }
660                };
661            }
662        }
663    }
664
665    fn gen_fn_set_vtbls(&self) -> TokenStream {
666        let plusplus = &self.plusplus;
667        let class_name = &self.class_name;
668
669        let set_vtbls = self.overrides.iter().map(|ovr| {
670            let ovr_class = &ovr.override_class;
671            let ovr_class = match ovr_class.clone() {
672                Type::Path(mut type_path) => {
673                    type_path.path.segments.insert(0, parse_quote!(super));
674                    Type::Path(type_path)
675                }
676                ty => ty
677            };
678
679            let mut ol_func_names = Vec::new();
680            let mut ol_func_sigs = Vec::new();
681            let mut ol_func_self_call_impls = Vec::new();
682            let mut ol_lifetime_bounds = Vec::new();
683            for f in &ovr.items {
684                let FuncInfo {
685                    func: ImplItemFn {sig: Signature{ asyncness, unsafety, .. }, ..},
686                    name: func_name,
687                    vtbl_name,
688                    lifetime_bounds,
689                    vtbl_sig: func_sig,
690                    args: func_args,
691                    mut_self,
692                } = get_func_sig(&ovr_class, f);
693                ol_func_names.push(vtbl_name);
694                ol_func_sigs.push(func_sig);
695                ol_lifetime_bounds.push(lifetime_bounds);
696
697                let func_name = format_ident!("my_{}", func_name);
698                let make_this = if mut_self {
699                    let cast_mut = cast_class_ptr_mut(
700                        &plusplus,
701                        &ovr_class,
702                        &self.class_type,
703                        quote! {this},
704                    );
705                    quote! {
706                        let this: &mut #class_name = unsafe{ #cast_mut };
707                    }
708                } else {
709                    let cast = cast_class_ptr(&plusplus, &ovr_class, &self.class_type, quote! {this});
710                    quote! {
711                        let this: &#class_name = unsafe{ #cast };
712                    }
713                };
714                let self_call = if asyncness.is_some() {
715                    quote! {
716                        #make_this
717                        #unsafety { Box::pin(this.#func_name(#(#func_args,)*)) }
718                    }
719                } else {
720                    quote! {
721                        #make_this
722                        #unsafety { this.#func_name(#(#func_args,)*) }
723                    }
724                };
725                ol_func_self_call_impls.push(self_call);
726            }
727
728            quote! {{
729                let this: &mut #ovr_class = &mut *(unsafe{ self.to_constructed() });
730                #(
731                    fn #ol_func_names #ol_lifetime_bounds #ol_func_sigs {
732                        #ol_func_self_call_impls
733                    }
734                    unsafe{ this.plusplus__vtbl_mut().#ol_func_names = #ol_func_names };
735                )*
736            }}
737        });
738
739        let root_type: Type = parse_quote!(<#class_name as #plusplus::Class>::RootClass);
740        let cast_root_to_self =
741            cast_class_ptr_mut(&plusplus, &root_type, &self.class_type, quote!(ref_mut));
742        let root_type = &root_type;
743
744        let set_subclass = self.has_superclass().then(|| quote!{
745            unsafe{ self.superclass.plusplus__set_subclass(<#class_name as #plusplus::Class>::TYPE_ID) };
746
747            {
748                unsafe fn manually_drop(this: *mut #root_type) {
749                    let ref_mut = unsafe{ &mut *this };
750                    let this = unsafe{ #cast_root_to_self };
751                    unsafe{ std::ptr::drop_in_place(this) };
752                }
753                let root_vtbl = unsafe{ <#class_name as #plusplus::Class>::root_class_mut(self.to_constructed()).plusplus__vtbl_mut() };
754                root_vtbl.manually_drop = manually_drop;
755            }
756        });
757
758        quote! {
759            fn plusplus__set_vtbls(&mut self) {
760                #set_subclass
761                #(#set_vtbls)*
762            }
763        }
764    }
765
766    fn gen_mod_class_struct(&self) -> TokenStream {
767        let plusplus = &self.plusplus;
768
769        let superclass_field = self.mod_superclass_type.as_ref().map(|sc_type| {
770            quote! {
771                superclass: #sc_type<#plusplus::InConstruction>,
772            }
773        });
774        let superclass_field = superclass_field.as_ref();
775        let superclass_bound = self.mod_superclass_type.as_ref().map(|sc_ident| quote!{
776            where #sc_ident: #plusplus::Class
777        });
778
779        let class_struct_vis = correct_priv_vis(self.class_vis.clone());
780        let vtbl_ident = &self.vtbl_ident;
781        let class_name = &self.class_name;
782        let fields = &self.fields;
783        let init_superclass_field = superclass_field.map(|f| quote!(pub #f));
784        let init_fields = self.fields.iter().cloned().map(|f| Field {
785            vis: Visibility::Public(parse_quote!(pub)),
786            ..f
787        });
788
789        quote! {
790            #[repr(C)]
791            #class_struct_vis struct #class_name<C: ?Sized + #plusplus::ClassMemory = #plusplus::Constructed>
792                #superclass_bound
793            {
794                #superclass_field
795                vtbl: #vtbl_ident,
796                // we store a reference to a type ID instead of a type ID as a size optimization;
797                // an Option<TypeId> is 24 bytes, while an Option<&TypeId> is 8 bytes
798                subclass_id: Option<&'static std::any::TypeId>,
799                #(#fields,)*
800                memory: C,
801            }
802
803            pub struct PlusPlus__InitClass {
804                #init_superclass_field
805                #(#init_fields,)*
806            }
807        }
808    }
809
810    fn gen_superclass_casters(&self) -> Option<TokenStream> {
811        let Some(sc_type) = self.superclass_type.as_ref() else {
812            return None;
813        };
814        let class_name = &self.class_name;
815
816        let deref_upcast = cast_class_ptr(&self.plusplus, &self.class_type, sc_type, quote! {self});
817        let deref_upcast_mut = cast_class_ptr_mut(
818            &self.plusplus,
819            &self.class_type,
820            sc_type,
821            quote! {self},
822        );
823        let ref_downcast = cast_class_ptr(&self.plusplus, sc_type, &self.class_type, quote! {self});
824        let ref_downcast_mut = cast_class_ptr_mut(
825            &self.plusplus,
826            sc_type,
827            &self.class_type,
828            quote! {self},
829        );
830
831        let plusplus = &self.plusplus;
832        Some(quote! {
833            impl std::ops::Deref for #class_name {
834                type Target = #sc_type;
835
836                fn deref(&self) -> &Self::Target {
837                    unsafe { #deref_upcast }
838                }
839            }
840
841            impl std::ops::DerefMut for #class_name {
842                fn deref_mut(&mut self) -> &mut Self::Target {
843                    unsafe { #deref_upcast_mut }
844                }
845            }
846
847            impl<'a> #plusplus::Downcast<#class_name> for &'a #sc_type {
848                type Wrapped = &'a #class_name;
849                fn downcast(self) -> Result<&'a #class_name, Self> {
850                    use #plusplus::Class;
851                    let subclass_type_id = <#class_name as #plusplus::Class>::TYPE_ID;
852                    if self.subclass_id() == Some(subclass_type_id) {
853                        Ok(unsafe{ #ref_downcast })
854                    } else {
855                        Err(self)
856                    }
857                }
858            }
859
860            impl<'a> #plusplus::Downcast<#class_name> for &'a mut #sc_type {
861                type Wrapped = &'a mut #class_name;
862                fn downcast(self) -> Result<&'a mut #class_name, Self> {
863                    use #plusplus::Class;
864                    let subclass_type_id = <#class_name as #plusplus::Class>::TYPE_ID;
865                    if self.subclass_id() == Some(subclass_type_id) {
866                        Ok(unsafe{ #ref_downcast_mut })
867                    } else {
868                        Err(self)
869                    }
870                }
871            }
872        })
873    }
874
875    fn gen_mod_impl_class_trait(&self) -> TokenStream {
876        let plusplus = &self.plusplus;
877        let class_name = &self.class_name;
878        let root_class = self.mod_superclass_type
879            .as_ref()
880            .map(|sc_ident| quote! { <#sc_ident as #plusplus::Class>::RootClass })
881            .unwrap_or_else(|| quote! { #class_name });
882
883        quote! {
884            unsafe impl #plusplus::Class for #class_name {
885                const TYPE_ID: &'static std::any::TypeId = &std::any::TypeId::of::<#class_name>();
886
887                type RootClass = #root_class;
888
889                fn subclass_id(&self) -> Option<&'static std::any::TypeId> {
890                    self.subclass_id
891                }
892
893                fn root_class(&self) -> &Self::RootClass {
894                    self
895                }
896
897                fn root_class_mut(&mut self) -> &mut Self::RootClass {
898                    self
899                }
900
901                unsafe fn manually_drop(slot: &mut std::mem::ManuallyDrop<Self>) {
902                    let as_root_class = slot.root_class_mut();
903                    let manual_drop_fn = unsafe{ as_root_class.plusplus__vtbl_mut().manually_drop };
904                    unsafe{ manual_drop_fn(as_root_class); }
905                }
906            }
907        }
908    }
909
910    fn gen_mod_class_impl(&self) -> TokenStream {
911        let class_name = &self.class_name;
912        let vtbl_ident = &self.vtbl_ident;
913
914        let mut call_vtbl_impls = Vec::new();
915        for f in self.member_funcs.iter() {
916            let FuncInfo {
917                func: ImplItemFn{ vis, sig, ..},
918                name: _,
919                vtbl_name,
920                lifetime_bounds: _,
921                vtbl_sig: _,
922                args: func_args,
923                mut_self: _,
924            } = f;
925
926            let vis = correct_priv_vis(vis.clone());
927            let do_await = sig.asyncness.is_some().then(|| quote!(.await));
928            let call_vtbl = quote! {
929                #vis #sig {
930                    (self.vtbl.#vtbl_name)(self, #(#func_args,)*) #do_await
931                }
932            };
933            call_vtbl_impls.push(call_vtbl);
934        }
935
936        let mut my_func_impls = Vec::new();
937        let mut super_func_impls = Vec::new();
938
939        for (f, is_override) in self.member_funcs
940            .iter().cloned()
941            .map(|f| (f, false))
942            .chain(self.override_funcs.iter().cloned().map(|f| (f, true)))
943        {
944            let FuncInfo {
945                func: ImplItemFn {
946                    attrs: _,
947                    vis,
948                    defaultness: _,
949                    sig,
950                    block,
951                },
952                name: func_name,
953                vtbl_name: _,
954                lifetime_bounds: _,
955                vtbl_sig: _,
956                args: func_args,
957                mut_self,
958            } = f;
959
960            let my_impl_name = format_ident!("my_{}", func_name);
961            let mut my_impl_sig = sig.clone();
962            my_impl_sig.ident = my_impl_name.clone();
963            my_func_impls.push(quote! {
964                #vis #my_impl_sig {
965                    #block
966                }
967            });
968
969            if is_override {
970                let super_impl_name = format_ident!("super_{}", func_name);
971                let mut super_impl_sig = sig.clone();
972                super_impl_sig.ident = super_impl_name.clone();
973                let get_super = if mut_self {
974                    quote! {
975                        self.plusplus__super_mut()
976                    }
977                } else {
978                    quote! {
979                        self.plusplus__super_ref()
980                    }
981                };
982                let super_impl_block = if sig.asyncness.is_some() {
983                    quote!{ #get_super.#my_impl_name(#(#func_args,)*).await }
984                } else {
985                    quote!{ #get_super.#my_impl_name(#(#func_args,)*) }
986                };
987
988                super_func_impls.push(quote! {
989                    #vis #super_impl_sig {
990                        #super_impl_block
991                    }
992                });
993            }
994        }
995
996        let superclass_getters = self.mod_superclass_type.as_ref().map(|sc_ident| {
997            quote! {
998                fn plusplus__super_ref(&self) -> &#sc_ident {
999                    self
1000                }
1001
1002                fn plusplus__super_mut(&mut self) -> &mut #sc_ident {
1003                    self
1004                }
1005            }
1006        });
1007
1008        quote!{
1009            impl #class_name {
1010                #(#call_vtbl_impls)*
1011                #(#super_func_impls)*
1012
1013                #superclass_getters
1014
1015                #[doc(hidden)]
1016                pub unsafe fn plusplus__vtbl_mut(&mut self) -> &mut #vtbl_ident {
1017                    &mut self.vtbl
1018                }
1019            }
1020        }
1021    }
1022
1023    fn gen_mod_in_construction_class_impl(&self) -> TokenStream {
1024        let plusplus = &self.plusplus;
1025        let class_name = &self.class_name;
1026        let vtbl_ident = &self.vtbl_ident;
1027
1028        let set_vtbl_func = self.gen_fn_set_vtbls();
1029        let class_vis = correct_priv_vis(self.class_vis.clone());
1030
1031        let superclass_field = self.has_superclass().then(|| quote!{superclass: init.superclass,});
1032        let fields = self.fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
1033
1034        quote!{
1035            impl #class_name<#plusplus::InConstruction> {
1036                #set_vtbl_func
1037
1038                pub(super) fn plusplus__new_from_init(init: PlusPlus__InitClass) -> Self {
1039                    let mut this = Self {
1040                        vtbl: #vtbl_ident::BASE,
1041                        memory: #plusplus::InConstruction::default(),
1042                        subclass_id: None,
1043                        #superclass_field
1044                        #(#fields: init.#fields,)*
1045                    };
1046
1047                    this.plusplus__set_vtbls();
1048
1049                    this
1050                }
1051
1052                #[doc(hidden)]
1053                pub unsafe fn plusplus__set_subclass(&mut self, subclass_id: &'static std::any::TypeId) {
1054                    self.subclass_id = Some(subclass_id);
1055                }
1056
1057                /// Unsafe because caller must guarantee that vtbl doesn't contain any
1058                /// subclass methods
1059                pub unsafe fn to_constructed(&mut self) -> &mut #class_name {
1060                    unsafe{ &mut *(std::ptr::slice_from_raw_parts_mut::<u8>(self as *mut _ as *mut u8, 0) as *mut #class_name) }
1061                }
1062
1063                /// Finish constructing this by moving it to the heap placing it in a `ClassBox`.
1064                ///
1065                /// Downcasting, upcasting, and deref coersions will work properly after calling this!
1066                #class_vis fn finish(self: #class_name<#plusplus::InConstruction>) -> #plusplus::ClassBox<#class_name> {
1067                    let boxed = Box::new(self);
1068                    let leaked = Box::leak(boxed);
1069                    let constructed = unsafe{ leaked.to_constructed() };
1070                    unsafe{ #plusplus::ClassBox::from_raw(constructed) }
1071                }
1072            }
1073        }
1074    }
1075
1076    fn gen_init_class_macro(&self) -> TokenStream {
1077        let plusplus = &self.plusplus;
1078        let class_name = &self.class_name;
1079        let class_mod_name = &self.class_mod_name;
1080
1081        quote! {
1082            macro_rules! init_class {
1083                ($($tt:tt)*) => {{
1084                    #class_name::<#plusplus::InConstruction>::plusplus__new_from_init(#class_mod_name::PlusPlus__InitClass {
1085                        $($tt)*
1086                    })
1087                }}
1088            }
1089        }
1090    }
1091
1092    fn gen_class_impl(&self) -> TokenStream {
1093        let class_name = &self.class_name;
1094        let mut my_func_impls = Vec::new();
1095        for f in self.member_funcs
1096            .iter().cloned()
1097            .chain(self.override_funcs.iter().cloned())
1098        {
1099            let FuncInfo {
1100                func: ImplItemFn {
1101                    attrs: _,
1102                    vis,
1103                    defaultness: _,
1104                    sig,
1105                    block,
1106                },
1107                name: func_name,
1108                vtbl_name: _,
1109                lifetime_bounds: _,
1110                vtbl_sig: _,
1111                args: _,
1112                mut_self: _,
1113            } = f;
1114
1115            let my_impl_name = format_ident!("my_{}", func_name);
1116            let mut my_impl_sig = sig.clone();
1117            my_impl_sig.ident = my_impl_name.clone();
1118            my_func_impls.push(quote! {
1119                #vis #my_impl_sig {
1120                    #block
1121                }
1122            });
1123        }
1124
1125        let init_class_macro = self.gen_init_class_macro();
1126        let mut constructor_impls = Vec::new();
1127        for c in self.constructors.iter() {
1128            let ImplItemFn {
1129                attrs,
1130                vis,
1131                defaultness,
1132                sig,
1133                block,
1134            } = &c.func;
1135
1136            let vis = correct_priv_vis(vis.clone());
1137            constructor_impls.push(quote! {
1138                #(#attrs)* #vis #defaultness #sig {
1139                    #init_class_macro
1140                    #block
1141                }
1142            });
1143        }
1144
1145        quote!{
1146            impl #class_name {
1147                #(#constructor_impls)*
1148                #(#my_func_impls)*
1149            }
1150        }
1151    }
1152
1153    fn gen_class(&self) -> TokenStream {
1154        let class_vis = &self.class_vis;
1155        let class_name = &self.class_name;
1156        let class_mod_name = &self.class_mod_name;
1157        let mod_vtbl_struct = self.gen_mod_vtbl_struct();
1158        let mod_class_struct = self.gen_mod_class_struct();
1159        let mod_impl_class_trait = self.gen_mod_impl_class_trait();
1160        let superclass_cast = self.gen_superclass_casters();
1161        let mod_class_impl = self.gen_mod_class_impl();
1162        let mod_in_construction_class_impl = self.gen_mod_in_construction_class_impl();
1163        let class_impl = self.gen_class_impl();
1164
1165        quote! {
1166            #class_vis use #class_mod_name::#class_name;
1167            mod #class_mod_name {
1168                use super::*;
1169                #mod_vtbl_struct
1170                #mod_class_struct
1171
1172                #mod_class_impl
1173
1174                #mod_impl_class_trait
1175
1176                #mod_in_construction_class_impl
1177            }
1178
1179            #class_impl
1180
1181            #superclass_cast
1182        }
1183    }
1184}
1185
1186/// The whole point.
1187#[proc_macro]
1188pub fn class(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
1189    let inputs = parse_macro_input!(tokens as ClassInputs);
1190    let class_data = inputs.inputs.into_iter().map(|input| ClassData::from_input(input, inputs.crate_alias.as_ref()).gen_class());
1191
1192    let output = quote!{
1193        #(#class_data)*
1194    };
1195    output.into()
1196}