1use proc_macro::TokenStream as TokenStreamRaw;
2use quote::{format_ident, quote, quote_spanned};
3use syn::{
4 parse::Parse, parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg,
5 Ident, Index, ItemStruct, ItemTrait,
6};
7
8#[proc_macro_attribute]
9pub fn vtable(attr: TokenStreamRaw, item: TokenStreamRaw) -> TokenStreamRaw {
10 if !attr.is_empty() {
11 return (quote! { compile_error!("vtable attribute has no args"); }).into();
12 }
13
14 let item = parse_macro_input!(item as ItemTrait);
15
16 for item in item.items.iter() {
17 let m = match item {
18 syn::TraitItem::Method(m) => m,
19 _ => {
20 return (quote_spanned! { item.span() => compile_error!("only methods allowed in trait"); })
21 .into()
22 }
23 };
24 let r = match m.sig.receiver() {
25 Some(FnArg::Receiver(r)) => { r },
26 Some(v) => {return (quote_spanned! { v.span() => compile_error!("only self receivers allowed"); }).into()}
27 None => {return (quote_spanned! { m.sig.span() => compile_error!("expected receiver"); }).into()},
28 };
29 if r.reference.is_none() {
30 return (quote_spanned! { r.span() => compile_error!("should be reference type"); })
31 .into();
32 }
33
34 for arg in m.sig.inputs.iter().skip(1) {
35 let arg = match arg {
36 FnArg::Typed(arg) => arg,
37 _ => unreachable!(),
38 };
39
40 match arg.pat.as_ref() {
41 syn::Pat::Ident(_) => {},
42 pat => return (quote_spanned! { pat.span() => compile_error!("only ident patterns allowed"); }).into(),
43 }
44 }
45 }
46
47 let methods = item
48 .items
49 .iter()
50 .filter_map(|i| match i {
51 syn::TraitItem::Method(m) => Some(m.sig.clone()),
52 _ => unreachable!(),
53 })
54 .map(|sig| {
55 let r = match sig.receiver() {
56 Some(FnArg::Receiver(r)) => r.clone(),
57 _ => unreachable!(),
58 };
59 (r, sig)
60 })
61 .collect::<Vec<_>>();
62
63 let name = &item.ident;
64 let vtable_name = format_ident!("{}Vtable", name);
65 let vtable_impl_name = format_ident!("unsafe_impl_{}Vtable", name);
66
67 let vtable_members = methods.iter().map(|(r, sig)| {
68 let name = &sig.ident;
69 let this = if r.mutability.is_some() {
70 quote! {core::pin::Pin<&mut ::cppvtbl::VtableRef<Self>>}
71 } else {
72 quote! {&::cppvtbl::VtableRef<Self>}
73 };
74 let inputs = sig.inputs.iter().skip(1);
75 let output = &sig.output;
76 quote! {
77 pub #name: unsafe extern "C" fn(#this, #(#inputs,)*) #output
78 }
79 });
80 let macro_vtable_fields = methods.iter().map(|(r, sig)| {
81 let meth = &sig.ident;
82 let this = if r.mutability.is_some() {
83 quote! {core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>>}
84 } else {
85 quote! {&::cppvtbl::VtableRef<#vtable_name>}
86 };
87 let get_top = if r.mutability.is_some() {
88 quote! {
89 let top: &mut ::cppvtbl::WithVtables<$this> = core::mem::transmute((core::pin::Pin::get_unchecked_mut(this) as *mut _ as *mut usize).offset($offset))
90 }
91 } else {
92 quote! {
93 let top: &::cppvtbl::WithVtables<$this> = core::mem::transmute((this as *const _ as *const usize).offset($offset))
94 }
95 };
96 let inputs = sig.inputs.iter().skip(1);
97 let output = &sig.output;
98 let args = sig
99 .inputs
100 .iter()
101 .skip(1)
102 .map(|a| match a {
103 FnArg::Receiver(_) => unreachable!(),
104 FnArg::Typed(t) => t,
105 })
106 .map(|a| match a.pat.as_ref() {
107 syn::Pat::Ident(i) => &i.ident,
108 _ => unreachable!(),
109 });
110 quote! {
111 #meth: {
112 unsafe extern "C" fn #meth(this: #this, #(#inputs,)*) #output {
113 #get_top;
114 <$this as #name>::#meth(top, #(#args,)*)
115 }
116
117 #meth
118 }
119 }
120 });
121 let impl_members = methods.iter().map(|(r, sig)| {
122 let name = &sig.ident;
123 let args = sig
124 .inputs
125 .iter()
126 .skip(1)
127 .map(|a| match a {
128 FnArg::Receiver(_) => unreachable!(),
129 FnArg::Typed(t) => t,
130 })
131 .map(|a| match a.pat.as_ref() {
132 syn::Pat::Ident(i) => &i.ident,
133 _ => unreachable!(),
134 });
135 let self_v = if r.mutability.is_some() {
136 quote! {
137 core::pin::Pin::new_unchecked(self)
140 }
141 } else {
142 quote! {
143 self
144 }
145 };
146 quote! {
147 #sig {
148 unsafe { (self.table().#name)(#self_v, #(#args,)*) }
149 }
150 }
151 });
152 let impl_mut_members = methods.iter().map(|(r, sig)| {
153 let name = &sig.ident;
154 let args = sig
155 .inputs
156 .iter()
157 .skip(1)
158 .map(|a| match a {
159 FnArg::Receiver(_) => unreachable!(),
160 FnArg::Typed(t) => t,
161 })
162 .map(|a| match a.pat.as_ref() {
163 syn::Pat::Ident(i) => &i.ident,
164 _ => unreachable!(),
165 });
166 if r.mutability.is_some() {
167 quote! {
168 #sig {
169 let pin = unsafe { core::pin::Pin::get_unchecked_mut(core::pin::Pin::as_mut(self)) };
171 unsafe { (pin.table().#name)(core::pin::Pin::new_unchecked(pin), #(#args,)*) }
172 }
173 }
174 } else {
175 quote! {
176 #sig {
177 let pin = core::pin::Pin::as_ref(self);
178 unsafe { (pin.table().#name)(&pin, #(#args,)*) }
179 }
180 }
181 }
182 });
183
184 (quote! {
185 #item
186
187 #[repr(C)]
188 pub struct #vtable_name {
189 #(#vtable_members,)*
190 }
191 #[allow(non_upper_case_globals, dead_code)]
192 #[macro_export]
193 macro_rules! #vtable_impl_name {
194 ($impl:ident, $this:ty, $offset:expr) => {
195 #[allow(non_upper_case_globals)]
196 const $impl: &'static #vtable_name = &#vtable_name {
197 #(#macro_vtable_fields,)*
198 };
199 }
200 }
201 impl #name for ::cppvtbl::VtableRef<#vtable_name> {
202 #(#impl_members)*
203 }
204 impl #name for core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>> {
205 #(#impl_mut_members)*
206 }
207
208 })
209 .into()
210}
211
212struct VtablesInput {
213 tables: Punctuated<Ident, Comma>,
214}
215impl Parse for VtablesInput {
216 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
217 Ok(Self {
218 tables: input.parse_terminated(Ident::parse)?,
219 })
220 }
221}
222
223#[proc_macro_attribute]
224pub fn impl_vtables(attr: TokenStreamRaw, item: TokenStreamRaw) -> TokenStreamRaw {
225 let input = parse_macro_input!(attr as VtablesInput);
226 let this = parse_macro_input!(item as ItemStruct);
227
228 let (impl_generics, ty_generics, where_clause) = this.generics.split_for_impl();
229
230 let this_name = &this.ident;
231
232 let impl_macro_calls = input.tables.iter().enumerate().map(|(i, name)| {
233 let macro_name = format_ident!("unsafe_impl_{}Vtable", name);
234 let const_name = format_ident!("{}VtableFor{}", name, this_name);
235 let i = i as isize;
236 quote! {
237 #macro_name!(#const_name, #this_name, -#i)
238 }
239 });
240 let type_tables = input.tables.iter().map(|name| {
241 let vtable_name = format_ident!("{}Vtable", name);
242 quote! {
243 ::cppvtbl::VtableRef<#vtable_name>
244 }
245 });
246 let impl_tables = input.tables.iter().map(|name| {
247 let const_name = format_ident!("{}VtableFor{}", name, this_name);
248 quote! {
249 unsafe { ::cppvtbl::VtableRef::new(#const_name) }
250 }
251 });
252 let has_vtable = input.tables.iter().enumerate().map(|(i, name)| {
253 let vtable_name = format_ident!("{}Vtable", name);
254 let index = Index::from(i);
255 quote! {
256 impl #impl_generics ::cppvtbl::HasVtable<#vtable_name> for #this_name #ty_generics #where_clause {
257 fn get(from: &::cppvtbl::WithVtables<Self>) -> &::cppvtbl::VtableRef<#vtable_name> {
258 &from.vtables().#index
259 }
260 fn get_mut(from: &mut ::cppvtbl::WithVtables<Self>) -> core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>> {
261 unsafe { core::pin::Pin::new_unchecked(&mut (&mut *from.vtables_mut()).#index) }
262 }
263 }
264 }
265 });
266
267 (quote! {
268 #this
269 #(#impl_macro_calls;)*
270 unsafe impl #impl_generics ::cppvtbl::HasVtables for #this_name #ty_generics #where_clause {
271 type Tables = (#(#type_tables,)*);
272 const TABLES: Self::Tables = (#(#impl_tables,)*);
273 }
274 #(#has_vtable)*
275 })
276 .into()
277}