cppvtbl_macros/
lib.rs

1use proc_macro::TokenStream as TokenStreamRaw;
2use quote::{format_ident, quote, quote_spanned};
3use syn::{
4	parse::Parse, parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg,
5	Ident, Index, ItemStruct, ItemTrait,
6};
7
8#[proc_macro_attribute]
9pub fn vtable(attr: TokenStreamRaw, item: TokenStreamRaw) -> TokenStreamRaw {
10	if !attr.is_empty() {
11		return (quote! { compile_error!("vtable attribute has no args"); }).into();
12	}
13
14	let item = parse_macro_input!(item as ItemTrait);
15
16	for item in item.items.iter() {
17		let m = match item {
18			syn::TraitItem::Method(m) => m,
19			_ => {
20				return (quote_spanned! { item.span() => compile_error!("only methods allowed in trait"); })
21					.into()
22			}
23		};
24		let r = match m.sig.receiver() {
25			Some(FnArg::Receiver(r)) => { r },
26			Some(v) => {return (quote_spanned! { v.span() => compile_error!("only self receivers allowed"); }).into()}
27			None => {return (quote_spanned! { m.sig.span() => compile_error!("expected receiver"); }).into()},
28		};
29		if r.reference.is_none() {
30			return (quote_spanned! { r.span() => compile_error!("should be reference type"); })
31				.into();
32		}
33
34		for arg in m.sig.inputs.iter().skip(1) {
35			let arg = match arg {
36				FnArg::Typed(arg) => arg,
37				_ => unreachable!(),
38			};
39
40			match arg.pat.as_ref() {
41				syn::Pat::Ident(_) => {},
42				pat => return (quote_spanned! { pat.span() => compile_error!("only ident patterns allowed"); }).into(),
43			}
44		}
45	}
46
47	let methods = item
48		.items
49		.iter()
50		.filter_map(|i| match i {
51			syn::TraitItem::Method(m) => Some(m.sig.clone()),
52			_ => unreachable!(),
53		})
54		.map(|sig| {
55			let r = match sig.receiver() {
56				Some(FnArg::Receiver(r)) => r.clone(),
57				_ => unreachable!(),
58			};
59			(r, sig)
60		})
61		.collect::<Vec<_>>();
62
63	let name = &item.ident;
64	let vtable_name = format_ident!("{}Vtable", name);
65	let vtable_impl_name = format_ident!("unsafe_impl_{}Vtable", name);
66
67	let vtable_members = methods.iter().map(|(r, sig)| {
68		let name = &sig.ident;
69		let this = if r.mutability.is_some() {
70			quote! {core::pin::Pin<&mut ::cppvtbl::VtableRef<Self>>}
71		} else {
72			quote! {&::cppvtbl::VtableRef<Self>}
73		};
74		let inputs = sig.inputs.iter().skip(1);
75		let output = &sig.output;
76		quote! {
77			pub #name: unsafe extern "C" fn(#this, #(#inputs,)*) #output
78		}
79	});
80	let macro_vtable_fields = methods.iter().map(|(r, sig)| {
81		let meth = &sig.ident;
82		let this = if r.mutability.is_some() {
83			quote! {core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>>}
84		} else {
85			quote! {&::cppvtbl::VtableRef<#vtable_name>}
86		};
87		let get_top = if r.mutability.is_some() {
88			quote! {
89				 let top: &mut ::cppvtbl::WithVtables<$this> = core::mem::transmute((core::pin::Pin::get_unchecked_mut(this) as *mut _ as *mut usize).offset($offset))
90			}
91		} else {
92			quote! {
93				let top: &::cppvtbl::WithVtables<$this> = core::mem::transmute((this as *const _ as *const usize).offset($offset))
94			}
95		};
96		let inputs = sig.inputs.iter().skip(1);
97		let output = &sig.output;
98		let args = sig
99			.inputs
100			.iter()
101			.skip(1)
102			.map(|a| match a {
103				FnArg::Receiver(_) => unreachable!(),
104				FnArg::Typed(t) => t,
105			})
106			.map(|a| match a.pat.as_ref() {
107				syn::Pat::Ident(i) => &i.ident,
108				_ => unreachable!(),
109			});
110		quote! {
111			#meth: {
112				unsafe extern "C" fn #meth(this: #this, #(#inputs,)*) #output {
113					#get_top;
114					<$this as #name>::#meth(top, #(#args,)*)
115				}
116
117				#meth
118			}
119		}
120	});
121	let impl_members = methods.iter().map(|(r, sig)| {
122		let name = &sig.ident;
123		let args = sig
124			.inputs
125			.iter()
126			.skip(1)
127			.map(|a| match a {
128				FnArg::Receiver(_) => unreachable!(),
129				FnArg::Typed(t) => t,
130			})
131			.map(|a| match a.pat.as_ref() {
132				syn::Pat::Ident(i) => &i.ident,
133				_ => unreachable!(),
134			});
135		let self_v = if r.mutability.is_some() {
136			quote! {
137				// Safety: creating mut reference to VtableRef is already unsafe
138				// so we assuming pin is valid
139				core::pin::Pin::new_unchecked(self)
140			}
141		} else {
142			quote! {
143				self
144			}
145		};
146		quote! {
147			#sig {
148				unsafe { (self.table().#name)(#self_v, #(#args,)*) }
149			}
150		}
151	});
152	let impl_mut_members = methods.iter().map(|(r, sig)| {
153		let name = &sig.ident;
154		let args = sig
155			.inputs
156			.iter()
157			.skip(1)
158			.map(|a| match a {
159				FnArg::Receiver(_) => unreachable!(),
160				FnArg::Typed(t) => t,
161			})
162			.map(|a| match a.pat.as_ref() {
163				syn::Pat::Ident(i) => &i.ident,
164				_ => unreachable!(),
165			});
166		if r.mutability.is_some() {
167			quote! {
168				#sig {
169					// Safety: we're not moving pinned value, nor giving inner code access to it
170					let pin = unsafe { core::pin::Pin::get_unchecked_mut(core::pin::Pin::as_mut(self)) };
171					unsafe { (pin.table().#name)(core::pin::Pin::new_unchecked(pin), #(#args,)*) }
172				}
173			}
174		} else {
175			quote! {
176				#sig {
177					let pin = core::pin::Pin::as_ref(self);
178					unsafe { (pin.table().#name)(&pin, #(#args,)*) }
179				}
180			}
181		}
182	});
183
184	(quote! {
185		#item
186
187		#[repr(C)]
188		pub struct #vtable_name {
189			#(#vtable_members,)*
190		}
191		#[allow(non_upper_case_globals, dead_code)]
192		#[macro_export]
193		macro_rules! #vtable_impl_name {
194			($impl:ident, $this:ty, $offset:expr) => {
195				#[allow(non_upper_case_globals)]
196				const $impl: &'static #vtable_name = &#vtable_name {
197					#(#macro_vtable_fields,)*
198				};
199			}
200		}
201		impl #name for ::cppvtbl::VtableRef<#vtable_name> {
202			#(#impl_members)*
203		}
204		impl #name for core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>> {
205			#(#impl_mut_members)*
206		}
207
208	})
209	.into()
210}
211
212struct VtablesInput {
213	tables: Punctuated<Ident, Comma>,
214}
215impl Parse for VtablesInput {
216	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
217		Ok(Self {
218			tables: input.parse_terminated(Ident::parse)?,
219		})
220	}
221}
222
223#[proc_macro_attribute]
224pub fn impl_vtables(attr: TokenStreamRaw, item: TokenStreamRaw) -> TokenStreamRaw {
225	let input = parse_macro_input!(attr as VtablesInput);
226	let this = parse_macro_input!(item as ItemStruct);
227
228	let (impl_generics, ty_generics, where_clause) = this.generics.split_for_impl();
229
230	let this_name = &this.ident;
231
232	let impl_macro_calls = input.tables.iter().enumerate().map(|(i, name)| {
233		let macro_name = format_ident!("unsafe_impl_{}Vtable", name);
234		let const_name = format_ident!("{}VtableFor{}", name, this_name);
235		let i = i as isize;
236		quote! {
237			#macro_name!(#const_name, #this_name, -#i)
238		}
239	});
240	let type_tables = input.tables.iter().map(|name| {
241		let vtable_name = format_ident!("{}Vtable", name);
242		quote! {
243			::cppvtbl::VtableRef<#vtable_name>
244		}
245	});
246	let impl_tables = input.tables.iter().map(|name| {
247		let const_name = format_ident!("{}VtableFor{}", name, this_name);
248		quote! {
249			unsafe { ::cppvtbl::VtableRef::new(#const_name) }
250		}
251	});
252	let has_vtable = input.tables.iter().enumerate().map(|(i, name)| {
253		let vtable_name = format_ident!("{}Vtable", name);
254		let index = Index::from(i);
255		quote! {
256			impl #impl_generics ::cppvtbl::HasVtable<#vtable_name> for #this_name #ty_generics #where_clause {
257				fn get(from: &::cppvtbl::WithVtables<Self>) -> &::cppvtbl::VtableRef<#vtable_name> {
258					&from.vtables().#index
259				}
260				fn get_mut(from: &mut ::cppvtbl::WithVtables<Self>) -> core::pin::Pin<&mut ::cppvtbl::VtableRef<#vtable_name>> {
261					unsafe { core::pin::Pin::new_unchecked(&mut (&mut *from.vtables_mut()).#index) }
262				}
263			}
264		}
265	});
266
267	(quote! {
268		#this
269		#(#impl_macro_calls;)*
270		unsafe impl #impl_generics ::cppvtbl::HasVtables for #this_name #ty_generics #where_clause {
271			type Tables = (#(#type_tables,)*);
272			const TABLES: Self::Tables = (#(#impl_tables,)*);
273		}
274		#(#has_vtable)*
275	})
276	.into()
277}