nuidl_lib/codegen/
rust.rs

1use crate::codegen::{File, FnParam, Function, Interface, InterfaceDefn, Toplevel};
2use crate::parser::Ctype;
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::quote;
5use std::io;
6use std::io::Write;
7use std::path::Path;
8use syn::token;
9
10pub fn write_rust_module<W: Write>(mut w: W, f: &File, source: &Path) -> io::Result<()> {
11    let tokens = rust_module(f);
12    let st = syn::parse2(tokens).unwrap();
13    let text = prettyplease::unparse(&st);
14    writeln!(w, "// Auto-generated by nuidl from {}.", source.display())?;
15    writeln!(w, "{}", text)?;
16
17    Ok(())
18}
19
20pub fn rust_module(f: &File) -> TokenStream {
21    fn make_imports(
22        pub_t: TokenStream,
23        prefix: TokenStream,
24        f: &File,
25    ) -> impl Iterator<Item = TokenStream> + '_ {
26        f.imports
27            .iter()
28            .filter_map(|i| i.rust_name.as_ref())
29            .map(move |name| quote! { #pub_t use #prefix::#name::scope::*; })
30    }
31
32    let imports = make_imports(quote! {}, quote! { super }, f);
33    let pub_imports = make_imports(quote! { pub }, quote! { super::super }, f);
34
35    let elems = f.elems.iter().map(|elem| match elem {
36        Toplevel::CppQuote(_) => quote! {},
37        Toplevel::Interface(itf) => make_interface(itf),
38    });
39
40    quote! {
41        use ::nucomcore::idl::prelude::*;
42
43        #[doc(hidden)]
44        pub mod scope {
45            pub use ::nucomcore::idl::prelude::*;
46            pub use super::*;
47
48            #(#pub_imports)*
49        }
50
51        #(#imports)*
52        #(#elems)*
53    }
54}
55
/// Data about a single interface that is shared by the per-interface generator functions.
struct InterfaceInfo<'a> {
    /// The interface that code is being generated for.
    itf: &'a Interface,
    /// The definition of `itf`.
    defn: &'a InterfaceDefn,
    /// Every function of the interface, including the functions of its base interfaces (base
    /// interfaces first), each paired with the interface that declares it.
    all_fns: Vec<(&'a Function, &'a Interface)>,
}
61
62fn make_interface(itf: &Interface) -> TokenStream {
63    let Some(defn) = &itf.defn else {
64        println!(
65            "forward-declaration for interface {} not supported when generating Rust",
66            itf.name
67        );
68        return quote! {};
69    };
70
71    let mut all_fns = Vec::new();
72
73    fn collect_fns<'a>(fns: &mut Vec<(&'a Function, &'a Interface)>, itf: &'a Interface) {
74        let defn = itf.defn.as_ref().unwrap();
75
76        if let Some(base) = &defn.base {
77            collect_fns(fns, &base);
78        }
79
80        fns.extend(defn.fns.iter().map(|v| (v, itf)));
81    }
82
83    collect_fns(&mut all_fns, itf);
84
85    let info = InterfaceInfo { itf, defn, all_fns };
86
87    let name = make_ident(&itf.name);
88    let vtbl_name = make_ident(&format!("{}Vtbl", name));
89    let impl_name = make_ident(&format!("{}Impl", name));
90    let safe_ext_name = make_ident(&format!("{}Ext", name));
91    let unsafe_ext_name = make_ident(&format!("{}UnsafeExt", name));
92    let base_impl = defn
93        .base
94        .as_deref()
95        .map(|base| Ident::new(&format!("{}Impl", base.name), Span::call_site()))
96        .map(|ident| quote! { : #ident })
97        .unwrap_or_default();
98    let ffi_functions = make_ffi_functions(&info);
99    let impl_functions = make_impl_functions(&info);
100    let safe_ext_functions = make_ext_functions(&info, ExtVariant::Safe);
101    let unsafe_ext_functions = make_ext_functions(&info, ExtVariant::Unsafe);
102    let dispatch = make_dispatch(&info);
103    let (d1, d2, d3, &d4) = defn.uuid.unwrap().as_fields();
104
105    let hierarchy = {
106        let mut hierarchy = Vec::new();
107        let mut cur = Some(itf);
108
109        while let Some(c) = cur {
110            let base_name = make_ident(&c.name);
111            let imp = quote! { unsafe impl ::nucomcore::interface::ComHierarchy<#base_name> for #name {} };
112            hierarchy.push(imp);
113            cur = c.defn.as_ref().and_then(|v| v.base.as_ref()).map(|v| &**v);
114        }
115
116        hierarchy
117    };
118
119    let itf_doc = {
120        let mut extends = String::new();
121
122        if let Some(base) = &defn.base {
123            extends = format!(" Extends [`{}`].", base.name);
124        }
125
126        format!(
127            r#"
128            Interface `{name}`.{extends}
129
130            To implement this interface, see [`{impl_name}`].
131            To call methods of this interface, see [`{safe_ext_name}`] and [`{unsafe_ext_name}`].
132            "#
133        )
134    };
135
136    let vtable_doc = format!(
137        r#"
138        Vtable for interface [`{name}`].
139
140        This is the table of function pointers corresponding to an object's
141        implementation of this interface. To generate this for a type
142        implementing trait [`{impl_name}`], see [`{vtbl_name}::dispatch`].
143        "#
144    );
145
146    let impl_doc = format!(
147        r#"
148        Trait implementing methods of interface [`{name}`].
149
150        To make a struct implementing this trait into a COM object, add a
151        reference to the vtable as the first field of the struct. This reference
152        can be generated using [`{vtbl_name}::dispatch`].
153        "#
154    );
155
156    let safe_ext_doc = format!(
157        r#"
158        Extension trait allowing calling the methods of the interface [`{name}`]
159        on COM pointers such as [`ComRef<T>`](nucomcore::interface::ComRef) or
160        [`ComPtr<T>`](nucomcore::interface::ComPtr).
161
162        To call these methods, the COM pointer needs to be
163        [`Safe`](nucomcore::interface::Safe), e.g. `ComPtr<{name}, _, Safe>`.
164        "#
165    );
166
167    let unsafe_ext_doc = format!(
168        r#"
169        Extension trait allowing calling the methods of the interface [`{name}`]
170        on COM pointers such as [`ComRef<T>`](nucomcore::interface::ComRef) or
171        [`ComPtr<T>`](nucomcore::interface::ComPtr).
172
173        To call these methods, the COM pointer needs to be
174        [`Unsafe`](nucomcore::interface::Unsafe), e.g.
175        `ComPtr<{name}, _, Unsafe>`.
176        "#
177    );
178
179    quote! {
180        #[doc = #itf_doc]
181        #[repr(C)]
182        pub struct #name {
183            pub vtbl: *const #vtbl_name,
184            _data: ()
185        }
186
187        #[doc = #vtable_doc]
188        #[repr(C)]
189        #[allow(non_snake_case)]
190        pub struct #vtbl_name {
191            #ffi_functions
192        }
193
194        impl ::nucomcore::Identify for #name {
195            const GUID: ::nucomcore::GUID = ::nucomcore::GUID(#d1, #d2, #d3, [#(#d4),*]);
196        }
197
198        impl ::nucomcore::interface::ComInterface for #name {
199            type Vtbl = #vtbl_name;
200            type Mt = ::nucomcore::interface::Apartment;
201        }
202
203        #(#hierarchy)*
204
205        #[allow(non_snake_case)]
206        impl #vtbl_name {
207            #dispatch
208        }
209
210        #[doc = #impl_doc]
211        #[allow(non_snake_case)]
212        pub trait #impl_name #base_impl {
213            #impl_functions
214        }
215
216        #[doc = #safe_ext_doc]
217        #[allow(non_snake_case)]
218        pub trait #safe_ext_name: ::nucomcore::interface::ComDeref
219            where
220                <Self as ::nucomcore::interface::ComDeref>::Target: ::nucomcore::interface::ComHierarchy<#name>,
221        {
222            #safe_ext_functions
223        }
224
225        impl <T> #safe_ext_name for T
226            where
227                T: ::nucomcore::interface::ComDeref<Safety = ::nucomcore::interface::Safe>,
228                <T as ::nucomcore::interface::ComDeref>::Target: ::nucomcore::interface::ComHierarchy<#name>,
229        {}
230
231        #[doc = #unsafe_ext_doc]
232        #[allow(non_snake_case)]
233        pub trait #unsafe_ext_name: ::nucomcore::interface::ComDeref
234            where
235                <Self as ::nucomcore::interface::ComDeref>::Target: ::nucomcore::interface::ComHierarchy<#name>,
236        {
237            #unsafe_ext_functions
238        }
239
240        impl <T> #unsafe_ext_name for T
241            where
242                T: ::nucomcore::interface::ComDeref<Safety = ::nucomcore::interface::Unsafe>,
243                <T as ::nucomcore::interface::ComDeref>::Target: ::nucomcore::interface::ComHierarchy<#name>,
244        {}
245    }
246}
247
248fn make_ffi_functions(info: &InterfaceInfo) -> TokenStream {
249    let fns = info.all_fns.iter().map(|(f, src)| make_ffi_function(info.itf, f, src));
250
251    quote! {
252        #(#fns,)*
253    }
254}
255
256fn make_ffi_function(itf: &Interface, f: &Function, source: &Interface) -> TokenStream {
257    fn make_fn_param(p: &FnParam) -> TokenStream {
258        let name = p.name.iter().map(|n| Ident::new(n, Span::call_site()));
259        let ty = make_ctype(&p.ty);
260
261        quote! {
262            #(#name:)* #ty
263        }
264    }
265
266    let name = Ident::new(&f.name, Span::call_site());
267    let itf_name = Ident::new(&itf.name, Span::call_site());
268    let args = f.params.iter().map(|par| make_fn_param(par));
269    let ret = make_ctype(&f.ret);
270
271    let fn_doc = {
272        let mut inherits = String::new();
273
274        if source != itf {
275            inherits = format!(" Inherited from [`{}`].", source.name);
276        }
277
278        format!(
279            r#"
280            Function `{itf_name}::{name}`.{inherits}
281            "#
282        )
283    };
284
285    quote! {
286        #[doc = #fn_doc]
287        pub #name: unsafe extern "system" fn(this: *mut #itf_name #(, #args)*) -> #ret
288    }
289}
290
291fn make_impl_functions(info: &InterfaceInfo) -> TokenStream {
292    let fns = info.defn.fns.iter().map(|f| make_impl_function(f));
293
294    quote! {
295        #(#fns;)*
296    }
297}
298
299fn make_impl_function(f: &Function) -> TokenStream {
300    fn make_fn_param(p: &FnParam) -> TokenStream {
301        let name = p.name.iter().map(|n| Ident::new(n, Span::call_site()));
302        let ty = make_ctype(&p.ty);
303
304        quote! {
305            #(#name:)* #ty
306        }
307    }
308
309    let name = make_ident(&f.name);
310    let args = f.params.iter().map(|par| make_fn_param(par));
311    let ret = make_ctype(&f.ret);
312
313    quote! {
314        unsafe extern "system" fn #name(this: *mut Self #(, #args)*) -> #ret
315    }
316}
317
/// Selects which kind of extension trait methods `make_ext_function` generates.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ExtVariant {
    /// Methods are generated as plain `fn`s (except `Release`, which is always `unsafe`).
    Safe,
    /// Methods are generated as `unsafe fn`s.
    Unsafe,
}
323
324fn make_ext_functions(info: &InterfaceInfo, variant: ExtVariant) -> TokenStream {
325    let fns = info
326        .defn
327        .fns
328        .iter()
329        .map(|f| make_ext_function(info, f, variant));
330
331    quote! {
332        #(#fns)*
333    }
334}
335
336fn make_ext_function(info: &InterfaceInfo, f: &Function, mut variant: ExtVariant) -> TokenStream {
337    // TODO: mark this in IDL as "always unsafe" instead of hardcoding here
338    if f.name == "Release" {
339        variant = ExtVariant::Unsafe;
340    }
341
342    let itf_name = make_ident(&info.itf.name);
343    let name = make_ident(&f.name);
344    let ret = make_ctype(&f.ret);
345    let args_in = make_params_in(f);
346    let args_out = make_params_out(f);
347    let unsafe_t = match variant {
348        ExtVariant::Safe => None,
349        ExtVariant::Unsafe => Some(token::Unsafe::default()),
350    };
351
352    quote! {
353        #[inline]
354        #unsafe_t fn #name(&self, #args_in) -> #ret {
355            let ptr = ::nucomcore::interface::ComDeref::com_deref(self).upcast::<#itf_name>();
356            let ptr = ptr.into_raw();
357            unsafe { ((*ptr.as_ref().vtbl).#name)(ptr.as_ptr(), #args_out) }
358        }
359    }
360}
361
362fn make_param_name(idx: usize, p: &FnParam) -> TokenStream {
363    let name = p
364        .name
365        .as_deref()
366        .cloned()
367        .unwrap_or_else(|| format!("param_{}", idx));
368    let name = Ident::new(&name, Span::call_site());
369
370    quote! {
371        #name
372    }
373}
374
375fn make_params_in(f: &Function) -> TokenStream {
376    let args = f
377        .params
378        .iter()
379        .enumerate()
380        .map(|(idx, p)| make_param_name(idx, p));
381
382    let arg_types = f.params.iter().map(|p| make_ctype(&p.ty));
383
384    quote! {
385        #(#args: #arg_types,)*
386    }
387}
388
389fn make_params_out(f: &Function) -> TokenStream {
390    let args = f
391        .params
392        .iter()
393        .enumerate()
394        .map(|(idx, p)| make_param_name(idx, p));
395
396    quote! {
397        #(#args),*
398    }
399}
400
401fn make_dispatch(info: &InterfaceInfo) -> TokenStream {
402    fn make_function_wrapper(info: &InterfaceInfo, f: &Function) -> TokenStream {
403        let name = make_ident(&f.name);
404        let itf_impl = make_ident(&format!("{}Impl", info.itf.name));
405        let itf = make_ident(&info.itf.name);
406        let ret = make_ctype(&f.ret);
407        let params_in = make_params_in(f);
408        let params_out = make_params_out(f);
409
410        quote! {
411            unsafe extern "system" fn #name<const OFFSET: usize, T: #itf_impl>(this: *mut #itf, #params_in) -> #ret {
412                let this = (this as *mut *mut ()).offset(-(OFFSET as isize)) as *mut T;
413                T::#name(this, #params_out)
414            }
415        }
416    }
417
418    let name = make_ident(&info.itf.name);
419    let vtbl_name = make_ident(&format!("{}Vtbl", name));
420    let impl_name = make_ident(&format!("{}Impl", name));
421
422    let wrappers = info.all_fns.iter().map(|(f, _)| make_function_wrapper(info, f));
423
424    let vtable_elems = info.all_fns.iter().map(|(f, _)| {
425        let name = make_ident(&f.name);
426        quote! { #name: #name::<OFFSET, T> }
427    });
428
429    let dispatch_doc = format!(
430        r#"
431        Creates a [`{name}`] vtable for `T` with the specified `OFFSET`.
432
433        COM interface methods in a vtable receive the pointer to the vtable as
434        their first parameter. This pointer may not necessarily correspond with
435        the pointer to the actual `T` struct (for example, when it contains more
436        than one vtable). This is why the functions in the vtable returned by
437        this function may need to offset the pointer to get the real address for
438        `self` to call the [`{impl_name}`] functions with.
439
440        This is what the `OFFSET` parameter is for: it specifies the offset of
441        the generated vtable in struct `T` in multiples of the platform pointer
442        size. Effectively, this means for a struct with multiple vtables, the
443        first field would have `OFFSET = 0`, the second `OFFSET = 1`, and so on.
444        This is why, when using this function, vtable pointers must only be
445        preceded by other vtable pointers in the struct definition.
446
447        Also see [`nucomcore::decl_class_vtable`] for generating a class vtable
448        without having to manually specify the offset.
449        "#
450    );
451
452    quote! {
453        #[doc = #dispatch_doc]
454        pub const fn dispatch<const OFFSET: usize, T: #impl_name>() -> &'static #vtbl_name {
455            #(#wrappers)*
456
457            & #vtbl_name {
458                #(#vtable_elems,)*
459            }
460        }
461    }
462}
463
464fn make_ctype(ty: &Ctype) -> TokenStream {
465    if ty.is_void() {
466        return quote! {
467            ()
468        };
469    }
470
471    let ptrs = [ty.is_const]
472        .into_iter()
473        .chain(ty.indirection.iter().map(|v| v.is_const))
474        .rev()
475        .skip(1)
476        .map(|is_const| {
477            if is_const {
478                quote! { *const }
479            } else {
480                quote! { *mut }
481            }
482        });
483
484    let inner = Ident::new(&ty.typename, Span::call_site());
485
486    quote! {
487        #(#ptrs)* #inner
488    }
489}
490
491fn make_ident(id: &str) -> Ident {
492    Ident::new(id, Span::call_site())
493}
494
#[cfg(test)]
mod test {
    use super::make_ctype;
    use crate::parser::{Ctype, Ptr};
    use quote::quote;

    /// Checks the translation of a doubly indirect C pointer type.
    #[test]
    fn ptr() {
        // `char const *const *` should become `*const *const char`: the pointer levels appear
        // in reverse order and the mutability of the outer pointer is ignored.
        let inner_ptr = Ptr { is_const: true };
        let outer_ptr = Ptr { is_const: false };

        let ty = Ctype {
            typename: String::from("char"),
            is_const: true,
            indirection: vec![inner_ptr, outer_ptr],
        };

        let actual = make_ctype(&ty).to_string();
        let expected = quote! { *const *const char }.to_string();

        assert_eq!(expected, actual);
    }
}