intercom_common/attributes/
com_library.rs

1use super::common::*;
2use crate::prelude::*;
3
4use crate::idents;
5use crate::model;
6use crate::utils;
7use std::iter::FromIterator;
8
9use syn::spanned::Spanned;
10
11/// Expands the `com_library` macro.
12///
13/// The macro expansion results in the following items:
14///
15/// - `DllGetClassObject` extern function implementation.
16/// - `IntercomListClassObjects` extern function implementation.
17pub fn expand_com_module(
18    arg_tokens: TokenStreamNightly,
19    com_library: bool,
20) -> Result<TokenStreamNightly, model::ParseError>
21{
22    let mut output = vec![];
23    let lib = model::ComLibrary::parse(&lib_name(), arg_tokens.into())?;
24
25    // Create the match-statmeent patterns for each supposedly visible COM class.
26    let mut match_arms = vec![];
27    for struct_path in &lib.coclasses {
28        // Construct the match pattern.
29        let clsid_path = idents::clsid_path(struct_path);
30        match_arms.push(quote_spanned!(struct_path.span() =>
31            #clsid_path =>
32                return Some(intercom::ClassFactory::<#struct_path>::create(riid, pout))
33        ));
34    }
35
36    let try_submodule_class_factory = lib.submodules.iter().map(|submod| {
37        quote!(
38            if let Some(hr) = #submod::__get_module_class_factory(rclsid, riid, pout) {
39                return Some(hr);
40            }
41        )
42    });
43
44    // Implement the __get_module_class_factory function.
45    output.push(quote!(
46        #[allow(dead_code)]
47        #[doc(hidden)]
48        pub unsafe fn __get_module_class_factory(
49            rclsid : intercom::REFCLSID,
50            riid : intercom::REFIID,
51            pout : *mut intercom::raw::RawComPtr
52        ) -> Option<intercom::raw::HRESULT>
53        {
54            // Create new class factory.
55            // Specify a create function that is able to create all the
56            // contained coclasses.
57            match *rclsid {
58                #( #match_arms, )*
59                _ => {},
60            };
61
62            // Try each sub-module
63            #( #try_submodule_class_factory )*
64
65            None
66        }
67    ));
68
69    // Figure the on_load() function to invoke in DllMain.
70    let on_load = if let Some(ref on_load) = &lib.on_load {
71        quote!(#on_load();)
72    } else {
73        quote!()
74    };
75
76    // Implement DllGetClassObject and DllMain.
77    //
78    // This is more or less the only symbolic entry point that the COM
79    // infrastructure uses. The COM client uses this method to acquire
80    // the IClassFactory interfaces that are then used to construct the
81    // actual coclasses.
82    if com_library {
83        let dll_get_class_object = get_dll_get_class_object_function();
84        output.push(dll_get_class_object);
85        output.push(quote!(
86            #[doc(hidden)]
87            static mut __INTERCOM_DLL_INSTANCE: *mut std::os::raw::c_void = 0 as _;
88
89            #[no_mangle]
90            #[allow(non_camel_case_types)]
91            #[deprecated]
92            #[doc(hidden)]
93            pub extern "system" fn DllMain(
94                dll_instance: *mut std::os::raw::c_void,
95                reason: u32,
96                _reserved: *mut std::os::raw::c_void,
97            ) -> bool
98            {
99                match reason {
100                    // DLL_PROCESS_ATTACH
101                    1 => unsafe {
102                        __INTERCOM_DLL_INSTANCE = dll_instance;
103                        #on_load
104                    },
105                    _ => {}
106                }
107                true
108            }
109        ));
110    }
111
112    // Implement get_intercom_typelib()
113    output.push(create_gather_module_types(&lib));
114    if com_library {
115        output.push(create_get_typelib_function(&lib));
116    }
117
118    // Implement the global DLL entry points
119    if com_library {
120        // DllListClassObjects returns all CLSIDs implemented in the crate.
121        let list_class_objects = get_intercom_list_class_objects_function();
122        output.push(list_class_objects);
123
124        // DllListClassObjects returns all CLSIDs implemented in the crate.
125        let dll_register_server = get_register_server_function(&lib);
126        output.push(dll_register_server);
127    }
128
129    Ok(TokenStream::from_iter(output.into_iter()).into())
130}
131
132fn get_dll_get_class_object_function() -> TokenStream
133{
134    quote!(
135        #[no_mangle]
136        #[allow(non_snake_case)]
137        #[allow(dead_code)]
138        #[doc(hidden)]
139        pub unsafe extern "system" fn DllGetClassObject(
140            rclsid: intercom::REFCLSID,
141            riid: intercom::REFIID,
142            pout: *mut intercom::raw::RawComPtr,
143        ) -> intercom::raw::HRESULT
144        {
145            // Delegate to the module implementation.
146            if let Some(hr) = __get_module_class_factory(rclsid, riid, pout) {
147                return hr;
148            }
149
150            // Try intercom built in types.
151            if let Some(hr) = intercom::__get_module_class_factory(rclsid, riid, pout) {
152                return hr;
153            }
154
155            intercom::raw::E_CLASSNOTAVAILABLE
156        }
157    )
158}
159
160fn create_gather_module_types(lib: &model::ComLibrary) -> TokenStream
161{
162    let create_class_typeinfo = lib.coclasses.iter().map(|path| {
163        quote!(
164            <#path as intercom::attributes::ComClassTypeInfo>::gather_type_info()
165        )
166    });
167    let create_interface_typeinfo = lib.interfaces.iter().map(|path| {
168        quote!(
169            <dyn #path as intercom::attributes::ComInterfaceTypeInfo>::gather_type_info()
170        )
171    });
172    let gather_submodule_types = lib
173        .submodules
174        .iter()
175        .map(|path| quote!( #path::__gather_module_types()));
176    quote!(
177        pub fn __gather_module_types() -> Vec<intercom::typelib::TypeInfo>
178        {
179            vec![
180                #( #create_class_typeinfo, )*
181                #( #gather_submodule_types, )*
182                #( #create_interface_typeinfo, )*
183            ]
184            .into_iter()
185            .flatten()
186            .collect()
187        }
188    )
189}
190
191fn create_get_typelib_function(lib: &model::ComLibrary) -> TokenStream
192{
193    let lib_name = lib_name();
194    let libid = utils::get_guid_tokens(&lib.libid, Span::call_site());
195    quote!(
196        #[no_mangle]
197        pub unsafe extern "system" fn IntercomTypeLib(
198            type_system: intercom::type_system::TypeSystemName,
199            out: *mut intercom::raw::RawComPtr,
200        ) -> intercom::raw::HRESULT
201        {
202            let mut tlib = intercom::ComBox::new(intercom::typelib::TypeLib::__new(
203                    #lib_name.into(),
204                    #libid,
205                    "0.1".into(),
206                    intercom::__gather_module_types()
207                        .into_iter().chain(__gather_module_types())
208                        .collect()
209            ));
210            let rc = intercom::ComRc::<intercom::typelib::IIntercomTypeLib>::from( &tlib );
211            let itf = intercom::ComRc::detach(rc);
212            *out = type_system.get_ptr(&itf);
213
214            intercom::raw::S_OK
215        }
216    )
217}
218
219fn get_intercom_list_class_objects_function() -> TokenStream
220{
221    quote!(
222        #[no_mangle]
223        #[allow(non_snake_case)]
224        #[allow(dead_code)]
225        #[doc(hidden)]
226        pub unsafe extern "system" fn IntercomListClassObjects(
227            pcount: *mut usize,
228            pclsids: *mut *const intercom::CLSID,
229        ) -> intercom::raw::HRESULT
230        {
231            // Do not crash due to invalid parameters.
232            if pcount.is_null() {
233                return intercom::raw::E_POINTER;
234            }
235            if pclsids.is_null() {
236                return intercom::raw::E_POINTER;
237            }
238
239            // Store the available CLSID in a static variable so that we can
240            // pass them as-is to the caller.
241            static mut AVAILABLE_CLASSES: Option<Vec<intercom::CLSID>> = None;
242            static INIT_AVAILABLE_CLASSES: std::sync::Once = std::sync::Once::new();
243            INIT_AVAILABLE_CLASSES.call_once(|| unsafe {
244                AVAILABLE_CLASSES = Some(
245                    __gather_module_types()
246                        .into_iter()
247                        .chain(intercom::__gather_module_types())
248                        .filter_map(|ty| match ty {
249                            intercom::typelib::TypeInfo::Class(cls) => Some(cls.clsid.clone()),
250                            _ => None,
251                        })
252                        .collect(),
253                );
254            });
255
256            // com_struct will drop here and decrement the referenc ecount.
257            // This is okay, as the query_interface incremented it, leaving
258            // it at two at this point.
259            let available_classes = AVAILABLE_CLASSES
260                .as_ref()
261                .expect("AVAILABLE_CLASSES was not initialized");
262            *pcount = available_classes.len();
263            *pclsids = available_classes.as_ptr();
264
265            intercom::raw::S_OK
266        }
267    )
268}
269
270fn get_register_server_function(lib: &model::ComLibrary) -> TokenStream
271{
272    // We'll need token streams to for the hook functions to use in the quote macros. If the user
273    // did not specify hook functions to use, we'll define the token streams as empty.
274    let on_register = if let Some(ref on_register) = &lib.on_register {
275        quote!(if let Err(hr) = #on_register() { return hr; })
276    } else {
277        quote!()
278    };
279    let on_unregister = if let Some(ref on_unregister) = &lib.on_unregister {
280        quote!(if let Err(hr) = #on_unregister() { return hr; })
281    } else {
282        quote!()
283    };
284
285    let lib_name = lib_name();
286    let libid = utils::get_guid_tokens(&lib.libid, Span::call_site());
287    quote!(
288        #[no_mangle]
289        #[allow(non_snake_case)]
290        #[allow(dead_code)]
291        #[doc(hidden)]
292        pub unsafe extern "system" fn DllRegisterServer() -> intercom::raw::HRESULT
293        {
294            let mut tlib = intercom::typelib::TypeLib::__new(
295                    #lib_name.into(),
296                    #libid,
297                    "0.1".into(),
298                    intercom::__gather_module_types()
299                        .into_iter().chain(__gather_module_types())
300                        .collect()
301            );
302
303            if let Err(hr) = intercom::registry::register(__INTERCOM_DLL_INSTANCE, tlib) {
304                return hr;
305            }
306
307            #on_register
308
309            intercom::raw::S_OK
310        }
311
312        #[no_mangle]
313        #[allow(non_snake_case)]
314        #[allow(dead_code)]
315        #[doc(hidden)]
316        pub unsafe extern "system" fn DllUnregisterServer() -> intercom::raw::HRESULT
317        {
318            let mut tlib = intercom::typelib::TypeLib::__new(
319                    #lib_name.into(),
320                    #libid,
321                    "0.1".into(),
322                    intercom::__gather_module_types()
323                        .into_iter().chain(__gather_module_types())
324                        .collect()
325            );
326
327            if let Err(hr) = intercom::registry::unregister(__INTERCOM_DLL_INSTANCE, tlib) {
328                return hr;
329            }
330
331            #on_unregister
332
333            intercom::raw::S_OK
334        }
335    )
336}