Skip to main content

plusplus_macros/
lib.rs

1use proc_macro_crate::{FoundCrate, crate_name};
2use proc_macro2::{Span, TokenStream};
3use quote::{format_ident, quote};
4use syn::{AngleBracketedGenericArguments, CapturedParam, GenericArgument, GenericParam, Lifetime, LifetimeParam, Path, PathArguments, ReturnType, Signature, Token, TypeParamBound};
5use syn::parse::{Parse, ParseStream};
6use syn::token::Brace;
7use syn::{
8    Attribute, Field, FnArg, Ident, ImplItemFn, Pat, Type, Visibility, braced,
9    parse_macro_input, parse_quote,
10};
11
12mod kw {
13    syn::custom_keyword!(class);
14}
15
/// An `override <Type> { ... }` section inside a class body. The methods listed in the
/// braces are installed into the vtable of `override_class` by `plusplus__set_vtbls`.
#[derive(Debug, Clone)]
struct OverrideItem {
    _override_token: Token![override],
    // The class whose vtable entries are overridden, as written at the macro call site.
    override_class: Type,
    _brace_token: Brace,
    // Overriding methods; their visibility is already adjusted by `correct_priv_vis` during parsing.
    items: Vec<ImplItemFn>,
}
23
24fn correct_priv_vis(vis: Visibility) -> Visibility {
25    match vis {
26        Visibility::Public(p) => Visibility::Public(p),
27        Visibility::Restricted(mut restricted) => {
28            if restricted.path.segments.get(0) != Some(&parse_quote!(crate)) {
29                restricted.path.segments.insert(0, parse_quote!(super));
30            }
31            Visibility::Restricted(restricted)
32        }
33        Visibility::Inherited => {
34            parse_quote!(pub(super))
35        }
36    }
37}
38
39impl Parse for OverrideItem {
40    fn parse(input: ParseStream) -> syn::Result<Self> {
41        let content;
42        let override_token = input.parse()?;
43        let override_class = input.parse()?;
44        let brace_token = braced!(content in input);
45        let mut items = Vec::new();
46        while !content.is_empty() {
47            let mut item: ImplItemFn = content.parse()?;
48            item.vis = correct_priv_vis(item.vis);
49            items.push(item);
50        }
51
52        Ok(OverrideItem {
53            _override_token: override_token,
54            override_class,
55            _brace_token: brace_token,
56            items,
57        })
58    }
59}
60
/// One item inside a class body.
#[derive(Debug, Clone)]
enum ClassItem {
    /// A named field terminated by `;` (its visibility is adjusted by `correct_priv_vis`).
    Field {
        field: Field,
        _semi_token: Token![;],
    },
    /// A function: `ClassData::from_input` treats it as a member function when its first
    /// input is a receiver, and as a constructor otherwise.
    ImplItemFn(ImplItemFn),
    /// An `override <Type> { ... }` section.
    OverrideItem(OverrideItem),
}
70
71impl Parse for ClassItem {
72    fn parse(input: ParseStream) -> syn::Result<Self> {
73        let lookahead = input.lookahead1();
74        if lookahead.peek(Token![override]) {
75            return Ok(ClassItem::OverrideItem(OverrideItem::parse(input)?));
76        }
77
78        let begin = input.fork();
79        let _attrs = input.call(Attribute::parse_outer)?;
80        let _vis: Visibility = begin.parse()?;
81        let lookahead = begin.lookahead1();
82        if lookahead.peek(Token![fn]) || lookahead.peek(Token![unsafe]) || lookahead.peek(Token![async]) {
83            return Ok(ClassItem::ImplItemFn(ImplItemFn::parse(input)?));
84        }
85
86        let mut field = Field::parse_named(input)?;
87        field.vis = correct_priv_vis(field.vis);
88        let semi_token = input.parse()?;
89        Ok(ClassItem::Field {
90            field,
91            _semi_token: semi_token,
92        })
93    }
94}
95
/// The optional `: <Superclass>` suffix that follows a class name.
#[derive(Debug, Clone)]
struct SuperclassInput {
    _colon_token: Token![:],
    // Superclass type as written at the macro call site.
    ty: Type,
}
101
102impl Parse for SuperclassInput {
103    fn parse(input: ParseStream) -> syn::Result<Self> {
104        Ok(SuperclassInput {
105            _colon_token: input.parse()?,
106            ty: input.parse()?,
107        })
108    }
109}
110
/// One parsed class definition: `<vis> class <Name> [: <Superclass>] { <item>* }`.
#[derive(Debug, Clone)]
struct ClassInput {
    vis: Visibility,
    _class_token: kw::class,
    ident: Ident,
    superclass: Option<SuperclassInput>,
    _brace_token: Brace,
    items: Vec<ClassItem>,
}
120
121impl Parse for ClassInput {
122    fn parse(input: ParseStream) -> syn::Result<Self> {
123        let content;
124        let vis = input.parse()?;
125        let class_token = input.parse()?;
126        let ident = input.parse()?;
127
128        let lookahead = input.lookahead1();
129        let superclass = if lookahead.peek(Token![:]) {
130            Some(input.parse()?)
131        } else {
132            None
133        };
134
135        let brace_token = braced!(content in input);
136        let mut items = Vec::new();
137        while !content.is_empty() {
138            items.push(content.parse()?);
139        }
140        Ok(ClassInput {
141            vis,
142            _class_token: class_token,
143            ident,
144            superclass,
145            _brace_token: brace_token,
146            items,
147        })
148    }
149}
150
/// An optional leading `crate as <ident>;` in the macro input. When present, `<ident>` is used
/// as the path to the `plusplus` runtime instead of the name found by `proc_macro_crate`.
#[derive(Debug, Clone)]
struct CrateAlias {
    _crate_token: Token![crate],
    _as_token: Token![as],
    ident: Ident,
    _semi_token: Token![;],
}
158
159impl Parse for CrateAlias {
160    fn parse(input: ParseStream) -> syn::Result<Self> {
161        Ok(CrateAlias {
162            _crate_token: input.parse()?,
163            _as_token: input.parse()?,
164            ident: input.parse()?,
165            _semi_token: input.parse()?,
166        })
167    }
168}
169
/// The whole macro input: an optional crate alias followed by any number of class definitions.
struct ClassInputs {
    crate_alias: Option<CrateAlias>,
    inputs: Vec<ClassInput>,
}
174
175impl Parse for ClassInputs {
176    fn parse(input: ParseStream) -> syn::Result<Self> {
177        let lookahead = input.lookahead1();
178        let crate_alias = if lookahead.peek(Token![crate]) {
179            Some(input.parse()?)
180        } else {
181            None
182        };
183
184        let mut inputs = Vec::new();
185        while !input.is_empty() {
186            inputs.push(input.parse()?);
187        }
188
189        Ok(ClassInputs {
190            crate_alias,
191            inputs,
192        })
193    }
194}
195
196fn plusplus() -> proc_macro2::TokenStream {
197    let found_crate = crate_name("plusplus").expect("plusplus is present in `Cargo.toml`");
198    match found_crate {
199        FoundCrate::Itself => quote!(crate),
200        FoundCrate::Name(name) => {
201            let ident = Ident::new(&name, Span::call_site());
202            quote!( #ident )
203        }
204    }
205}
206
207fn cast_class_ptr(
208    plusplus: &proc_macro2::TokenStream,
209    from: &Type,
210    to: &Type,
211    expr: impl Into<proc_macro2::TokenStream>,
212) -> proc_macro2::TokenStream {
213    let expr = expr.into();
214    quote! {{
215        let t: &#from = #expr;
216        let self_size = std::mem::size_of_val(t);
217        let target_size = std::mem::size_of::<#to<#plusplus::InConstruction>>();
218        assert!(self_size >= target_size);
219        let array_size = self_size - target_size;
220        let target_ptr = std::ptr::slice_from_raw_parts(t as *const #from as *const u8, array_size);
221        let target_ref = &*(target_ptr as *const #to);
222        assert_eq!(self_size, std::mem::size_of_val(target_ref));
223        target_ref
224    }}
225}
226
227fn cast_class_ptr_mut(
228    plusplus: &TokenStream,
229    from: &Type,
230    to: &Type,
231    expr: impl Into<TokenStream>,
232) -> TokenStream {
233    let expr = expr.into();
234    quote! {{
235        let t: &mut #from = #expr;
236        let self_size = std::mem::size_of_val(t);
237        let target_size = std::mem::size_of::<#to<#plusplus::InConstruction>>();
238        assert!(self_size >= target_size);
239        let array_size = self_size - target_size;
240        let target_ptr = std::ptr::slice_from_raw_parts_mut(t as *mut #from as *mut u8, array_size);
241        let target_ref = &mut *(target_ptr as *mut #to);
242        assert_eq!(self_size, std::mem::size_of_val(target_ref));
243        target_ref
244    }}
245}
246
247fn set_arg_blank_lifetime(arg: &mut FnArg, lifetime: &Lifetime, lifetime_set: &mut bool) {
248    match arg {
249        FnArg::Receiver(rx) => {
250            if let Some((_, lt)) = &mut rx.reference {
251                maybe_set_lifetime_opt(lt, lifetime, lifetime_set);
252            }
253        }
254        FnArg::Typed(ty) => {
255            set_blank_type_lifetimes(&mut ty.ty, lifetime, lifetime_set);
256        }
257    }
258}
259
260fn set_blank_type_lifetimes(ty: &mut Type, lifetime: &Lifetime, lifetime_set: &mut bool) {
261    match ty {
262        Type::Array(arr) => set_blank_type_lifetimes(&mut arr.elem, lifetime, lifetime_set),
263        Type::BareFn(_) => {}
264        Type::Group(group) => set_blank_type_lifetimes(&mut group.elem, lifetime, lifetime_set),
265        Type::ImplTrait(_) => {} // technically we should implement this but it doesn't come up in our use
266        Type::Infer(_) => {}
267        Type::Macro(_) => {}
268        Type::Never(_) => {}
269        Type::Paren(paren) => set_blank_type_lifetimes(&mut paren.elem, lifetime, lifetime_set),
270        Type::Path(path) => {
271            if let Some(qself) = &mut path.qself {
272                set_blank_type_lifetimes(&mut qself.ty, lifetime, lifetime_set);
273            }
274            set_blank_path_lifetimes(&mut path.path, lifetime, lifetime_set);
275        }
276        Type::Ptr(ptr) => set_blank_type_lifetimes(&mut ptr.elem, lifetime, lifetime_set),
277        Type::Reference(refer) => {
278            maybe_set_lifetime_opt(&mut refer.lifetime, lifetime, lifetime_set);
279            set_blank_type_lifetimes(&mut refer.elem, lifetime, lifetime_set)
280        }
281        Type::Slice(slice) => set_blank_type_lifetimes(&mut slice.elem, lifetime, lifetime_set),
282        Type::TraitObject(obj) => set_blank_type_param_bounds(&mut obj.bounds, lifetime, lifetime_set),
283        Type::Tuple(tup) => {
284            for ty in &mut tup.elems {
285                set_blank_type_lifetimes(ty, lifetime, lifetime_set);
286            }
287        }
288        Type::Verbatim(_) => {}
289        _ => ()
290    }
291}
292
293fn set_blank_path_lifetimes(path: &mut Path, lifetime: &Lifetime, lifetime_set: &mut bool) {
294    for path_args in path.segments.iter_mut().map(|seg| &mut seg.arguments) {
295        match path_args {
296            PathArguments::None => {}
297            PathArguments::AngleBracketed(angle_args) => set_blank_angle_bracket_lifetimes(angle_args, lifetime, lifetime_set),
298            PathArguments::Parenthesized(paren_args) => {
299                for ty in &mut paren_args.inputs {
300                    set_blank_type_lifetimes(ty, lifetime, lifetime_set);
301                }
302                if let ReturnType::Type(_, ty) = &mut paren_args.output {
303                    set_blank_type_lifetimes(ty, lifetime, lifetime_set);
304                }
305            }
306        }
307    }
308}
309
310fn set_blank_angle_bracket_lifetimes(angle: &mut AngleBracketedGenericArguments, lifetime: &Lifetime, lifetime_set: &mut bool) {
311    for arg in &mut angle.args {
312        set_blank_generic_lifetimes(arg, lifetime, lifetime_set);
313    }
314}
315
316fn maybe_set_lifetime(set: &mut Lifetime, to: &Lifetime, lifetime_set: &mut bool) {
317    if set.ident == format_ident!("_") {
318        *set = to.clone();
319        *lifetime_set = true;
320    }
321}
322fn maybe_set_lifetime_opt(set: &mut Option<Lifetime>, to: &Lifetime, lifetime_set: &mut bool) {
323    match set {
324        Some(lt) => maybe_set_lifetime(lt, to, lifetime_set),
325        None => {
326            *set = Some(to.clone());
327            *lifetime_set = true;
328        }
329    }
330}
331
332fn set_blank_type_param_bounds<'a>(bounds: impl IntoIterator<Item=&'a mut TypeParamBound>, lifetime: &Lifetime, lifetime_set: &mut bool) {
333    for bound in bounds {
334        match bound {
335            TypeParamBound::Trait(trait_bound) => {set_blank_path_lifetimes(&mut trait_bound.path, lifetime, lifetime_set)}
336            TypeParamBound::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
337            TypeParamBound::PreciseCapture(cap) => {
338                for p in &mut cap.params {
339                    match p {
340                        CapturedParam::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
341                        CapturedParam::Ident(_) => {}
342                        _ => ()
343                    }
344                }
345            }
346            TypeParamBound::Verbatim(_) => {}
347            _ => (),
348        }
349    }
350}
351fn set_blank_generic_lifetimes(arg: &mut GenericArgument, lifetime: &Lifetime, lifetime_set: &mut bool) {
352    match arg {
353        GenericArgument::Lifetime(lt) => maybe_set_lifetime(lt, lifetime, lifetime_set),
354        GenericArgument::Type(ty) => set_blank_type_lifetimes(ty, lifetime, lifetime_set),
355        GenericArgument::Const(_) => {}
356        GenericArgument::AssocType(ty) => {
357            if let Some(args) = &mut ty.generics {
358                set_blank_angle_bracket_lifetimes(args, lifetime, lifetime_set);
359            }
360            set_blank_type_lifetimes(&mut ty.ty, lifetime, lifetime_set);
361        }
362        GenericArgument::AssocConst(_) => {}
363        GenericArgument::Constraint(constraint) => {
364            if let Some(args) = &mut constraint.generics {
365                set_blank_angle_bracket_lifetimes(args, lifetime, lifetime_set);
366            }
367
368            set_blank_type_param_bounds(&mut constraint.bounds, lifetime, lifetime_set);
369        },
370        _ => (),
371    }
372}
373
/// Per-function data derived by `get_func_sig` and consumed by the code generators.
#[derive(Debug, Clone)]
struct FuncInfo {
    // The function exactly as written in the class body.
    func: ImplItemFn,
    // The function's identifier.
    name: Ident,
    // Name of the matching vtable field: `fn_<name>`.
    vtbl_name: Ident,
    // `<...>` lifetime parameter list for the vtable signature, or `None` when there are none.
    lifetime_bounds: Option<TokenStream>,
    // `(inputs) -> output` with `&self`/`&mut self` rewritten to `this: &Class`/`this: &mut Class`
    // (for async functions the output is a pinned, boxed `dyn Future`).
    vtbl_sig: TokenStream,
    // Argument patterns after the first input (the receiver, for member functions).
    args: Vec<Box<Pat>>,
    // Whether the receiver was declared with `mut` (e.g. `&mut self`).
    mut_self: bool,
}
384
385fn get_func_sig(class_name: &Type, f: &ImplItemFn) -> FuncInfo {
386    let func_name = &f.sig.ident;
387    let is_async = f.sig.asyncness.is_some();
388
389    let mut inputs = f.sig.inputs.clone();
390    let mut mut_self = false;
391    if let Some(FnArg::Receiver(rx)) = &f.sig.inputs.get(0) {
392        let receiver = &mut inputs[0];
393        if let Some((_, rx_lifetime)) = &rx.reference {
394
395            if rx.mutability.is_some() {
396                *receiver = FnArg::Typed(parse_quote!(this: & #rx_lifetime mut #class_name));
397            } else {
398                *receiver = FnArg::Typed(parse_quote!(this: & #rx_lifetime #class_name));
399            }
400        }
401        mut_self = rx.mutability.is_some();
402    };
403
404    let mut lifetimes: Vec<LifetimeParam> = Vec::new();
405    let future_fallback_lt: Lifetime = parse_quote!('rpp_future);
406    let mut using_future_lifetime = false;
407    let mut vtbl_inputs = inputs.clone();
408    // eprintln!("inputs = {:#?}", vtbl_inputs);
409    if is_async {
410        for arg in &mut vtbl_inputs {
411            set_arg_blank_lifetime(arg, &future_fallback_lt, &mut using_future_lifetime);
412        }
413        if using_future_lifetime {
414            lifetimes.push(parse_quote!(#future_fallback_lt));
415        }
416    }
417    for generic in &f.sig.generics.params {
418        match generic {
419            GenericParam::Lifetime(lt) => {
420                let mut lt = lt.clone();
421                if using_future_lifetime {
422                    lt.bounds.push(future_fallback_lt.clone());
423                }
424                lifetimes.push(lt)
425                // if using_future_lifetime {
426                //     lifetimes.push(parse_quote!(#lt: #future_fallback_lt));
427                // } else {
428                //     lifetimes.push(parse_quote!(#lt));
429                // }
430            }
431            GenericParam::Type(_) => {}
432            GenericParam::Const(_) => {}
433        }
434    }
435
436    let lifetime_bounds = if lifetimes.len() > 0 {
437        Some(quote!(<#(#lifetimes),*>))
438    } else {
439        None
440    };
441
442    let mut output = f.sig.output.clone();
443    if is_async {
444        let future_output = match output {
445            ReturnType::Default => quote!(()),
446            ReturnType::Type(_, ty) => quote!(#ty),
447        };
448        let future_lifetimes = if using_future_lifetime {
449            future_fallback_lt
450        } else {
451            assert!(lifetimes.len() > 0);
452            lifetimes[0].lifetime.clone()
453        };
454        output = parse_quote!(-> std::pin::Pin<Box<dyn #future_lifetimes + Future<Output=#future_output>>>);
455    }
456    let vtbl_name = format_ident!("fn_{func_name}");
457    let vtbl_sig = quote! {
458        (#vtbl_inputs) #output
459    };
460
461    let func_args = inputs
462        .into_iter()
463        .skip(1)
464        .map(|arg| match arg {
465            FnArg::Receiver(_) => unreachable!(),
466            FnArg::Typed(arg) => arg.pat,
467        })
468        .collect::<Vec<_>>();
469
470    FuncInfo {
471        func: f.clone(),
472        name: func_name.clone(),
473        vtbl_name,
474        lifetime_bounds,
475        vtbl_sig,
476        args: func_args,
477        mut_self,
478    }
479}
480
/// Everything known about one class, gathered by `ClassData::from_input` for the generators.
struct ClassData {
    // Path tokens for the `plusplus` runtime crate (from `crate as ...;` or `proc_macro_crate`).
    plusplus: TokenStream,
    class_name: Ident,
    // `class_name` as a `Type`.
    class_type: Type,
    // Name of the generated child module: `plusplus__class_<lowercased class name>`.
    class_mod_name: Ident,
    // `<ClassName>Vtbl`.
    vtbl_ident: Ident,
    // Visibility as written at the macro call site (not yet adjusted).
    class_vis: Visibility,
    fields: Vec<Field>,
    // Functions whose first input is not a receiver.
    constructors: Vec<FuncInfo>,
    // Functions whose first input is a receiver.
    member_funcs: Vec<FuncInfo>,
    overrides: Vec<OverrideItem>,
    // Signatures of all override methods, computed against `class_type`.
    override_funcs: Vec<FuncInfo>,
    // Superclass type as written at the call site.
    superclass_type: Option<Type>,
    // Superclass type re-rooted for use inside the generated child module.
    mod_superclass_type: Option<Type>,
}
496
497impl ClassData {
498    fn from_input(input: ClassInput, crate_alias: Option<&CrateAlias>) -> ClassData {
499        let ClassInput {
500            vis: class_vis,
501            _class_token: _,
502            ident: class_name,
503            superclass,
504            _brace_token: _,
505            items: class_items,
506        } = input;
507
508        let mut fields = Vec::new();
509        let mut constructors = Vec::new();
510        let mut member_funcs = Vec::new();
511        let mut overrides = Vec::new();
512        let mut override_funcs = Vec::new();
513
514        let class_type: Type = parse_quote!(#class_name);
515
516        for item in class_items {
517            match item {
518                ClassItem::Field { field, .. } => fields.push(field),
519                ClassItem::ImplItemFn(func) => {
520                    if let Some(FnArg::Receiver(_)) = func.sig.inputs.get(0) {
521                        member_funcs.push(get_func_sig(&class_type, &func))
522                    } else {
523                        constructors.push(get_func_sig(&class_type, &func));
524                    }
525                }
526                ClassItem::OverrideItem(override_item) => {
527                    override_funcs.extend(override_item.items.iter().map(|f| get_func_sig(&class_type, &f)));
528                    overrides.push(override_item);
529                }
530            }
531        }
532
533        let class_mod_name = format_ident!("plusplus__class_{}", class_name.to_string().to_lowercase());
534
535        let plusplus = if let Some(alias) = crate_alias {
536            let alias = &alias.ident;
537            quote!(#alias)
538        } else {
539            plusplus()
540        };
541
542        let superclass_type = superclass.map(|sc| sc.ty);
543        let mod_superclass_type = match superclass_type.clone() {
544            Some(Type::Path(mut type_path)) => {
545                if type_path.path.segments.get(0) != Some(&parse_quote!(crate)) {
546                    type_path.path.segments.insert(0, parse_quote!(super));
547                }
548                Some(Type::Path(type_path))
549            }
550            ty => ty
551        };
552
553        ClassData {
554            plusplus,
555            vtbl_ident: format_ident!("{}Vtbl", class_name),
556            class_vis,
557            class_name,
558            class_type,
559            class_mod_name,
560            fields,
561            constructors,
562            member_funcs,
563            overrides,
564            override_funcs,
565            superclass_type,
566            mod_superclass_type,
567        }
568    }
569
570    fn has_superclass(&self) -> bool {
571        self.superclass_type.is_some()
572    }
573
574    fn gen_mod_vtbl_struct(&self) -> TokenStream {
575        // create vtbl
576        let vtbl_ident = &self.vtbl_ident;
577        let class_name = &self.class_name;
578
579        let mut vtbl_func_names = Vec::new();
580        let mut vtbl_sigs = Vec::new();
581        let mut my_func_names = Vec::new();
582        let mut async_func_impls = Vec::new();
583        let mut func_setters = Vec::new();
584        let mut vtbl_fors = Vec::new();
585        let mut vtbl_unsafes = Vec::new();
586
587        for f in self.member_funcs.iter() {
588            let FuncInfo {
589                func: ImplItemFn{ sig: Signature {asyncness, unsafety, .. }, .. },
590                name: func_name,
591                vtbl_name: vtbl_func_name,
592                lifetime_bounds,
593                vtbl_sig,
594                args,
595                mut_self: _,
596            } = f;
597            let my_func_name = format_ident!("my_{func_name}");
598            let vtbl_for = lifetime_bounds.as_ref().map(|bounds| quote!(for #bounds));
599            vtbl_fors.push(vtbl_for.clone());
600            vtbl_unsafes.push(unsafety);
601
602            if asyncness.is_some() {
603                async_func_impls.push(quote!{
604                    // fn #my_func_name #func_sig {
605                    let #my_func_name: #vtbl_for #unsafety fn #vtbl_sig = |this, #(#args)*| #unsafety {
606                        Box::pin(#class_name::#my_func_name(this, #(#args)*))
607                    };
608                });
609                func_setters.push(quote!{
610                    #vtbl_func_name: #my_func_name,
611                })
612            } else {
613                func_setters.push(quote!{
614                     #vtbl_func_name: #class_name::#my_func_name,
615                })
616            }
617
618            my_func_names.push(my_func_name);
619            vtbl_func_names.push(vtbl_func_name);
620            vtbl_sigs.push(vtbl_sig);
621        }
622
623        let vtbl_drop_field: Option<_>;
624        let vtbl_drop_func: Option<_>;
625        let vtbl_drop_set: Option<_>;
626        if !self.has_superclass() {
627            vtbl_drop_field = Some(quote! {
628                pub manually_drop: unsafe fn(*mut #class_name),
629            });
630            vtbl_drop_func = Some(quote! {
631                unsafe fn manually_drop(this: *mut #class_name) {
632                    unsafe{ std::ptr::drop_in_place(this) }
633                }
634            });
635            vtbl_drop_set = Some(quote! {
636                manually_drop,
637            });
638        } else {
639            vtbl_drop_field = None;
640            vtbl_drop_func = None;
641            vtbl_drop_set = None;
642        };
643
644        quote! {
645            #[doc(hidden)]
646            pub struct #vtbl_ident {
647                #vtbl_drop_field
648                #(pub #vtbl_func_names: #vtbl_fors #vtbl_unsafes fn #vtbl_sigs,)*
649            }
650
651            impl #vtbl_ident {
652                const BASE: Self = {
653                    #vtbl_drop_func
654                    #(#async_func_impls)*
655
656                    Self {
657                        #vtbl_drop_set
658                        #(#func_setters)*
659                    }
660                };
661            }
662        }
663    }
664
665    fn gen_fn_set_vtbls(&self) -> TokenStream {
666        let plusplus = &self.plusplus;
667        let class_name = &self.class_name;
668
669        let set_vtbls = self.overrides.iter().map(|ovr| {
670            let ovr_class = &ovr.override_class;
671            let ovr_class = match ovr_class.clone() {
672                Type::Path(mut type_path) => {
673                    type_path.path.segments.insert(0, parse_quote!(super));
674                    Type::Path(type_path)
675                }
676                ty => ty
677            };
678
679            let mut ol_func_names = Vec::new();
680            let mut ol_func_sigs = Vec::new();
681            let mut ol_func_self_call_impls = Vec::new();
682            let mut ol_lifetime_bounds = Vec::new();
683            for f in &ovr.items {
684                let FuncInfo {
685                    func: ImplItemFn {sig: Signature{ asyncness, unsafety, .. }, ..},
686                    name: func_name,
687                    vtbl_name,
688                    lifetime_bounds,
689                    vtbl_sig: func_sig,
690                    args: func_args,
691                    mut_self,
692                } = get_func_sig(&ovr_class, f);
693                ol_func_names.push(vtbl_name);
694                ol_func_sigs.push(func_sig);
695                ol_lifetime_bounds.push(lifetime_bounds);
696
697                let func_name = format_ident!("my_{}", func_name);
698                let make_this = if mut_self {
699                    let cast_mut = cast_class_ptr_mut(
700                        &plusplus,
701                        &ovr_class,
702                        &self.class_type,
703                        quote! {this},
704                    );
705                    quote! {
706                        let this: &mut #class_name = unsafe{ #cast_mut };
707                    }
708                } else {
709                    let cast = cast_class_ptr(&plusplus, &ovr_class, &self.class_type, quote! {this});
710                    quote! {
711                        let this: &#class_name = unsafe{ #cast };
712                    }
713                };
714                let self_call = if asyncness.is_some() {
715                    quote! {
716                        #make_this
717                        #unsafety { Box::pin(this.#func_name(#(#func_args,)*)) }
718                    }
719                } else {
720                    quote! {
721                        #make_this
722                        #unsafety { this.#func_name(#(#func_args,)*) }
723                    }
724                };
725                ol_func_self_call_impls.push(self_call);
726            }
727
728            quote! {{
729                let this: &mut #ovr_class = &mut *(unsafe{ self.to_constructed() });
730                #(
731                    fn #ol_func_names #ol_lifetime_bounds #ol_func_sigs {
732                        #ol_func_self_call_impls
733                    }
734                    unsafe{ this.plusplus__vtbl_mut().#ol_func_names = #ol_func_names };
735                )*
736            }}
737        });
738
739        let root_type: Type = parse_quote!(<#class_name as #plusplus::Class>::RootClass);
740        let cast_root_to_self =
741            cast_class_ptr_mut(&plusplus, &root_type, &self.class_type, quote!(ref_mut));
742        let root_type = &root_type;
743
744        let set_subclass = self.has_superclass().then(|| quote!{
745            unsafe{ self.superclass.plusplus__set_subclass(<#class_name as #plusplus::Class>::TYPE_ID) };
746
747            {
748                unsafe fn manually_drop(this: *mut #root_type) {
749                    let ref_mut = unsafe{ &mut *this };
750                    let this = unsafe{ #cast_root_to_self };
751                    unsafe{ std::ptr::drop_in_place(this) };
752                }
753                let root_vtbl = unsafe{ <#class_name as #plusplus::Class>::root_class_mut(self.to_constructed()).plusplus__vtbl_mut() };
754                root_vtbl.manually_drop = manually_drop;
755            }
756        });
757
758        quote! {
759            fn plusplus__set_vtbls(&mut self) {
760                #set_subclass
761                #(#set_vtbls)*
762            }
763        }
764    }
765
766    fn gen_mod_class_struct(&self) -> TokenStream {
767        let plusplus = &self.plusplus;
768
769        let superclass_field = self.mod_superclass_type.as_ref().map(|sc_type| {
770            quote! {
771                superclass: #sc_type<#plusplus::InConstruction>,
772            }
773        });
774        let superclass_field = superclass_field.as_ref();
775        let superclass_bound = self.mod_superclass_type.as_ref().map(|sc_ident| quote!{
776            where #sc_ident: #plusplus::Class
777        });
778
779        let class_struct_vis = correct_priv_vis(self.class_vis.clone());
780        let vtbl_ident = &self.vtbl_ident;
781        let class_name = &self.class_name;
782        let fields = &self.fields;
783        let init_superclass_field = superclass_field.map(|f| quote!(pub #f));
784        let init_fields = self.fields.iter().cloned().map(|f| Field {
785            vis: Visibility::Public(parse_quote!(pub)),
786            ..f
787        });
788
789        quote! {
790            #[repr(C)]
791            #class_struct_vis struct #class_name<C: ?Sized + #plusplus::ClassMemory = #plusplus::Constructed>
792                #superclass_bound
793            {
794                #superclass_field
795                vtbl: #vtbl_ident,
796                subclass_id: Option<std::any::TypeId>,
797                #(#fields,)*
798                memory: C,
799            }
800
801            pub struct PlusPlus__InitClass {
802                #init_superclass_field
803                #(#init_fields,)*
804            }
805        }
806    }
807
808    fn gen_superclass_casters(&self) -> Option<TokenStream> {
809        let Some(sc_type) = self.superclass_type.as_ref() else {
810            return None;
811        };
812        let class_name = &self.class_name;
813
814        let deref_upcast = cast_class_ptr(&self.plusplus, &self.class_type, sc_type, quote! {self});
815        let deref_upcast_mut = cast_class_ptr_mut(
816            &self.plusplus,
817            &self.class_type,
818            sc_type,
819            quote! {self},
820        );
821        let ref_downcast = cast_class_ptr(&self.plusplus, sc_type, &self.class_type, quote! {self});
822        let ref_downcast_mut = cast_class_ptr_mut(
823            &self.plusplus,
824            sc_type,
825            &self.class_type,
826            quote! {self},
827        );
828
829        let plusplus = &self.plusplus;
830        Some(quote! {
831            impl std::ops::Deref for #class_name {
832                type Target = #sc_type;
833
834                fn deref(&self) -> &Self::Target {
835                    unsafe { #deref_upcast }
836                }
837            }
838
839            impl std::ops::DerefMut for #class_name {
840                fn deref_mut(&mut self) -> &mut Self::Target {
841                    unsafe { #deref_upcast_mut }
842                }
843            }
844
845            impl<'a> #plusplus::Downcast<#class_name> for &'a #sc_type {
846                type Wrapped = &'a #class_name;
847                fn downcast(self) -> Result<&'a #class_name, Self> {
848                    use #plusplus::Class;
849                    let subclass_type_id = <#class_name as #plusplus::Class>::TYPE_ID;
850                    if self.subclass_id() == Some(subclass_type_id) {
851                        Ok(unsafe{ #ref_downcast })
852                    } else {
853                        Err(self)
854                    }
855                }
856            }
857
858            impl<'a> #plusplus::Downcast<#class_name> for &'a mut #sc_type {
859                type Wrapped = &'a mut #class_name;
860                fn downcast(self) -> Result<&'a mut #class_name, Self> {
861                    use #plusplus::Class;
862                    let subclass_type_id = <#class_name as #plusplus::Class>::TYPE_ID;
863                    if self.subclass_id() == Some(subclass_type_id) {
864                        Ok(unsafe{ #ref_downcast_mut })
865                    } else {
866                        Err(self)
867                    }
868                }
869            }
870        })
871    }
872
873    fn gen_mod_impl_class_trait(&self) -> TokenStream {
874        let plusplus = &self.plusplus;
875        let class_name = &self.class_name;
876        let root_class = self.mod_superclass_type
877            .as_ref()
878            .map(|sc_ident| quote! { <#sc_ident as #plusplus::Class>::RootClass })
879            .unwrap_or_else(|| quote! { #class_name });
880
881        quote! {
882            unsafe impl #plusplus::Class for #class_name {
883                const TYPE_ID: std::any::TypeId = std::any::TypeId::of::<#class_name>();
884
885                type RootClass = #root_class;
886
887                fn subclass_id(&self) -> Option<std::any::TypeId> {
888                    self.subclass_id
889                }
890
891                fn root_class(&self) -> &Self::RootClass {
892                    self
893                }
894
895                fn root_class_mut(&mut self) -> &mut Self::RootClass {
896                    self
897                }
898
899                unsafe fn manually_drop(slot: &mut std::mem::ManuallyDrop<Self>) {
900                    let as_root_class = slot.root_class_mut();
901                    let manual_drop_fn = unsafe{ as_root_class.plusplus__vtbl_mut().manually_drop };
902                    unsafe{ manual_drop_fn(as_root_class); }
903                }
904            }
905        }
906    }
907
908    fn gen_mod_class_impl(&self) -> TokenStream {
909        let class_name = &self.class_name;
910        let vtbl_ident = &self.vtbl_ident;
911
912        let mut call_vtbl_impls = Vec::new();
913        for f in self.member_funcs.iter() {
914            let FuncInfo {
915                func: ImplItemFn{ vis, sig, ..},
916                name: _,
917                vtbl_name,
918                lifetime_bounds: _,
919                vtbl_sig: _,
920                args: func_args,
921                mut_self: _,
922            } = f;
923
924            let vis = correct_priv_vis(vis.clone());
925            let do_await = sig.asyncness.is_some().then(|| quote!(.await));
926            let call_vtbl = quote! {
927                #vis #sig {
928                    (self.vtbl.#vtbl_name)(self, #(#func_args,)*) #do_await
929                }
930            };
931            call_vtbl_impls.push(call_vtbl);
932        }
933
934        let mut my_func_impls = Vec::new();
935        let mut super_func_impls = Vec::new();
936
937        for (f, is_override) in self.member_funcs
938            .iter().cloned()
939            .map(|f| (f, false))
940            .chain(self.override_funcs.iter().cloned().map(|f| (f, true)))
941        {
942            let FuncInfo {
943                func: ImplItemFn {
944                    attrs: _,
945                    vis,
946                    defaultness: _,
947                    sig,
948                    block,
949                },
950                name: func_name,
951                vtbl_name: _,
952                lifetime_bounds: _,
953                vtbl_sig: _,
954                args: func_args,
955                mut_self,
956            } = f;
957
958            let my_impl_name = format_ident!("my_{}", func_name);
959            let mut my_impl_sig = sig.clone();
960            my_impl_sig.ident = my_impl_name.clone();
961            my_func_impls.push(quote! {
962                #vis #my_impl_sig {
963                    #block
964                }
965            });
966
967            if is_override {
968                let super_impl_name = format_ident!("super_{}", func_name);
969                let mut super_impl_sig = sig.clone();
970                super_impl_sig.ident = super_impl_name.clone();
971                let get_super = if mut_self {
972                    quote! {
973                        self.plusplus__super_mut()
974                    }
975                } else {
976                    quote! {
977                        self.plusplus__super_ref()
978                    }
979                };
980                let super_impl_block = if sig.asyncness.is_some() {
981                    quote!{ #get_super.#my_impl_name(#(#func_args)*).await }
982                } else {
983                    quote!{ #get_super.#my_impl_name(#(#func_args)*) }
984                };
985
986                super_func_impls.push(quote! {
987                    #vis #super_impl_sig {
988                        #super_impl_block
989                    }
990                });
991            }
992        }
993
994        let superclass_getters = self.mod_superclass_type.as_ref().map(|sc_ident| {
995            quote! {
996                fn plusplus__super_ref(&self) -> &#sc_ident {
997                    self
998                }
999
1000                fn plusplus__super_mut(&mut self) -> &mut #sc_ident {
1001                    self
1002                }
1003            }
1004        });
1005
1006        quote!{
1007            impl #class_name {
1008                #(#call_vtbl_impls)*
1009                #(#super_func_impls)*
1010
1011                #superclass_getters
1012
1013                #[doc(hidden)]
1014                pub unsafe fn plusplus__vtbl_mut(&mut self) -> &mut #vtbl_ident {
1015                    &mut self.vtbl
1016                }
1017            }
1018        }
1019    }
1020
1021    fn gen_mod_in_construction_class_impl(&self) -> TokenStream {
1022        let plusplus = &self.plusplus;
1023        let class_name = &self.class_name;
1024        let vtbl_ident = &self.vtbl_ident;
1025
1026        let set_vtbl_func = self.gen_fn_set_vtbls();
1027        let class_vis = correct_priv_vis(self.class_vis.clone());
1028
1029        let superclass_field = self.has_superclass().then(|| quote!{superclass: init.superclass,});
1030        let fields = self.fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
1031
1032        quote!{
1033            impl #class_name<#plusplus::InConstruction> {
1034                #set_vtbl_func
1035
1036                pub(super) fn plusplus__new_from_init(init: PlusPlus__InitClass) -> Self {
1037                    let mut this = Self {
1038                        vtbl: #vtbl_ident::BASE,
1039                        memory: #plusplus::InConstruction::default(),
1040                        subclass_id: None,
1041                        #superclass_field
1042                        #(#fields: init.#fields,)*
1043                    };
1044
1045                    this.plusplus__set_vtbls();
1046
1047                    this
1048                }
1049
1050                #[doc(hidden)]
1051                pub unsafe fn plusplus__set_subclass(&mut self, subclass_id: std::any::TypeId) {
1052                    self.subclass_id = Some(subclass_id);
1053                }
1054
1055                /// Unsafe because caller must guarantee that vtbl doesn't contain any
1056                /// subclass methods
1057                pub unsafe fn to_constructed(&mut self) -> &mut #class_name {
1058                    unsafe{ &mut *(std::ptr::slice_from_raw_parts_mut::<u8>(self as *mut _ as *mut u8, 0) as *mut #class_name) }
1059                }
1060
1061                /// Finish constructing this by moving it to the heap placing it in a `ClassBox`.
1062                ///
1063                /// Downcasting, upcasting, and deref coersions will work properly after calling this!
1064                #class_vis fn finish(self: #class_name<#plusplus::InConstruction>) -> #plusplus::ClassBox<#class_name> {
1065                    let boxed = Box::new(self);
1066                    let leaked = Box::leak(boxed);
1067                    let constructed = unsafe{ leaked.to_constructed() };
1068                    unsafe{ #plusplus::ClassBox::from_raw(constructed) }
1069                }
1070            }
1071        }
1072    }
1073
1074    fn gen_init_class_macro(&self) -> TokenStream {
1075        let plusplus = &self.plusplus;
1076        let class_name = &self.class_name;
1077        let class_mod_name = &self.class_mod_name;
1078
1079        quote! {
1080            macro_rules! init_class {
1081                ($($tt:tt)*) => {{
1082                    #class_name::<#plusplus::InConstruction>::plusplus__new_from_init(#class_mod_name::PlusPlus__InitClass {
1083                        $($tt)*
1084                    })
1085                }}
1086            }
1087        }
1088    }
1089
1090    fn gen_class_impl(&self) -> TokenStream {
1091        let class_name = &self.class_name;
1092        let mut my_func_impls = Vec::new();
1093        for f in self.member_funcs
1094            .iter().cloned()
1095            .chain(self.override_funcs.iter().cloned())
1096        {
1097            let FuncInfo {
1098                func: ImplItemFn {
1099                    attrs: _,
1100                    vis,
1101                    defaultness: _,
1102                    sig,
1103                    block,
1104                },
1105                name: func_name,
1106                vtbl_name: _,
1107                lifetime_bounds: _,
1108                vtbl_sig: _,
1109                args: _,
1110                mut_self: _,
1111            } = f;
1112
1113            let my_impl_name = format_ident!("my_{}", func_name);
1114            let mut my_impl_sig = sig.clone();
1115            my_impl_sig.ident = my_impl_name.clone();
1116            my_func_impls.push(quote! {
1117                #vis #my_impl_sig {
1118                    #block
1119                }
1120            });
1121        }
1122
1123        let init_class_macro = self.gen_init_class_macro();
1124        let mut constructor_impls = Vec::new();
1125        for c in self.constructors.iter() {
1126            let ImplItemFn {
1127                attrs,
1128                vis,
1129                defaultness,
1130                sig,
1131                block,
1132            } = &c.func;
1133
1134            let vis = correct_priv_vis(vis.clone());
1135            constructor_impls.push(quote! {
1136                #(#attrs)* #vis #defaultness #sig {
1137                    #init_class_macro
1138                    #block
1139                }
1140            });
1141        }
1142
1143        quote!{
1144            impl #class_name {
1145                #(#constructor_impls)*
1146                #(#my_func_impls)*
1147            }
1148        }
1149    }
1150
1151    fn gen_class(&self) -> TokenStream {
1152        let class_vis = &self.class_vis;
1153        let class_name = &self.class_name;
1154        let class_mod_name = &self.class_mod_name;
1155        let mod_vtbl_struct = self.gen_mod_vtbl_struct();
1156        let mod_class_struct = self.gen_mod_class_struct();
1157        let mod_impl_class_trait = self.gen_mod_impl_class_trait();
1158        let superclass_cast = self.gen_superclass_casters();
1159        let mod_class_impl = self.gen_mod_class_impl();
1160        let mod_in_construction_class_impl = self.gen_mod_in_construction_class_impl();
1161        let class_impl = self.gen_class_impl();
1162
1163        quote! {
1164            #class_vis use #class_mod_name::#class_name;
1165            mod #class_mod_name {
1166                use super::*;
1167                #mod_vtbl_struct
1168                #mod_class_struct
1169
1170                #mod_class_impl
1171
1172                #mod_impl_class_trait
1173
1174                #mod_in_construction_class_impl
1175            }
1176
1177            #class_impl
1178
1179            #superclass_cast
1180        }
1181    }
1182}
1183
1184/// The whole point.
1185#[proc_macro]
1186pub fn class(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
1187    let inputs = parse_macro_input!(tokens as ClassInputs);
1188    let class_data = inputs.inputs.into_iter().map(|input| ClassData::from_input(input, inputs.crate_alias.as_ref()).gen_class());
1189
1190    let output = quote!{
1191        #(#class_data)*
1192    };
1193    output.into()
1194}