// com_scrape_types/class.rs

1use std::ops::Deref;
2use std::ptr::addr_of;
3use std::sync::Arc;
4
5use super::{ComPtr, ComRef, Guid, Interface};
6
/// Computes the byte offset of field `$field` within `$struct`, as an `isize`.
///
/// The expansion dereferences a raw pointer (only to take a field address through `addr_of!`,
/// never to read) and calls `pointer::offset_from`, so it must be expanded inside an `unsafe`
/// context. `$field` may be a named field or a tuple index such as `0`.
macro_rules! offset_of {
    ($struct:ty, $field:tt) => {{
        // A never-initialized stand-in gives a real allocation to measure the field against.
        let uninit = ::std::mem::MaybeUninit::<$struct>::uninit();
        let start: *const $struct = uninit.as_ptr();

        // `addr_of!` projects to the field without creating a reference to uninitialized data.
        let field_addr = ::std::ptr::addr_of!((*start).$field);

        let field_bytes = field_addr as *const ::std::ffi::c_void;
        let start_bytes = start as *const ::std::ffi::c_void;
        field_bytes.offset_from(start_bytes)
    }};
}
20
21/// Helper functionality used in generated virtual tables for Rust types.
22///
23/// The purpose of this trait is to allow the [`Construct`] implementations generated by
24/// `com-scrape` to avoid hard-coding any particular reference counting logic or class layout, and
25/// to instead allow this logic to be plugged in at the point where the COM object is constructed.
26///
27/// Currently, the only type implementing this trait is [`ComWrapper`], but it would be possible
28/// for additional wrapper types to implement it in the future.
29pub trait Wrapper<C: Class + ?Sized> {
30    /// Given a pointer to an object's header, returns a pointer to the object itself.
31    unsafe fn data_from_header(ptr: *mut Header<C>) -> *mut C;
32
33    /// Given a pointer to an object, returns a pointer to the object's header.
34    unsafe fn header_from_data(ptr: *mut C) -> *mut Header<C>;
35
36    /// Increments the reference count of an object and returns the resulting count.
37    unsafe fn add_ref(ptr: *mut C) -> usize;
38
39    /// Decrements the reference count of an object and returns the resulting count.
40    unsafe fn release(ptr: *mut C) -> usize;
41}
42
43/// Generates the virtual table and base class object for a given class and interface.
44///
45/// The `W` parameter is the wrapper type which provides helper functionality for reference
46/// counting and class layout. The `OFFSET` parameter is the offset, within the header, at which
47/// the constructed base class object will be located.
48///
49/// This trait is used in the implementations of [`MakeHeader`] for tuples of interfaces. Interface
50/// types generated by `com-scrape` will implement this trait.
51///
52/// # Safety
53///
54/// When `I::OBJ` is reinterpreted as `*const I::Vtbl` (see the
55/// [safety documentation](Interface#safety) for [`Interface`]), it must be a valid pointer to an
56/// instance of `I::Vtbl`.
57pub unsafe trait Construct<C: Class, W: Wrapper<C>, const OFFSET: isize>: Interface {
58    /// The generated base class object.
59    const OBJ: Self;
60}
61
62/// A list of COM interfaces implemented by a Rust type.
63///
64/// Provides a header type containing base class objects for each interface in the list, as well as
65/// a method for querying whether an interface is a member of the list and, if so, at what offset
66/// the corresponding base class object is located in the header.
67///
68/// This trait is implemented for tuples of interface types.
69///
70/// # Safety
71///
72/// If `L::query(I::IID)` returns `Some(offset)` for an [`Interface`] `I`, then whenever
73/// `ptr: *mut L::Header` points to a valid instance of `L::Header`,
74/// `(ptr as *mut u8).offset(offset) as *mut I` must point to a valid instance of `I`.
75pub unsafe trait InterfaceList {
76    /// Header type containing a base class object for each of the interfaces in the list.
77    type Header;
78
79    /// If there is an interface in the list whose GUID equals `iid`, or which transitively derives
80    /// from an interface whose GUID equals `iid`, `query` returns the offset of the corresponding
81    /// base class object within `Self::Header`.
82    fn query(iid: &Guid) -> Option<isize>;
83}
84
85/// Generates the object header for a given class and list of interfaces.
86///
87/// This trait is implemented for tuples of interface types, and it is used by [`ComWrapper`] to
88/// construct the object header for a given Rust value.
89///
90/// # Safety
91///
92/// If `L::query(I::IID)` returns `Some(offset)` for an [`Interface`] `I` (see [`InterfaceList`]),
93/// the `I` value located at offset `offset` within `L::Header` must (when reinterpreted as
94/// `*const I::Vtbl`) be a valid pointer to an instance of `I::Vtbl`.
95pub unsafe trait MakeHeader<C, W>: InterfaceList
96where
97    C: Class,
98    W: Wrapper<C>,
99{
100    const HEADER: Self::Header;
101}
102
103/// A Rust type that defines a COM class.
104///
105/// Must be implemented for a type to be used with [`ComWrapper`].
106pub trait Class {
107    /// The list of interfaces implemented by this Rust type.
108    ///
109    /// Should be given as a tuple, e.g.:
110    ///
111    /// ```ignore
112    /// struct MyClass;
113    ///
114    /// impl Class for MyClass {
115    ///     type Interfaces = (ISomeInterface, IAnotherInterface);
116    /// }
117    /// ```
118    type Interfaces: InterfaceList;
119}
120
121/// Convenience alias for getting the object header of a [`Class`].
122pub type Header<C> = <<C as Class>::Interfaces as InterfaceList>::Header;
123
124macro_rules! interface_list {
125    ($header:ident, $($interface:ident $index:tt),*) => {
126        #[repr(C)]
127        pub struct $header<$($interface),*>($($interface),*);
128
129        unsafe impl<$($interface: Interface),*> InterfaceList for ($($interface,)*) {
130            type Header = $header<$($interface),*>;
131
132            fn query(iid: &Guid) -> Option<isize> {
133                $(
134                    if $interface::inherits(iid) {
135                        return Some($index * std::mem::size_of::<*mut ()>() as isize);
136                    }
137                )*
138
139                None
140            }
141        }
142
143        unsafe impl<C, W $(, $interface)*> MakeHeader<C, W> for ($($interface,)*)
144        where
145            C: Class,
146            W: Wrapper<C>,
147            $($interface: Construct<C, W, { $index * std::mem::size_of::<*mut ()>() as isize }>,)*
148        {
149            const HEADER: Self::Header = $header($($interface::OBJ),*);
150        }
151    }
152}
153
154interface_list!(Header1, I0 0);
155interface_list!(Header2, I0 0, I1 1);
156interface_list!(Header3, I0 0, I1 1, I2 2);
157interface_list!(Header4, I0 0, I1 1, I2 2, I3 3);
158interface_list!(Header5, I0 0, I1 1, I2 2, I3 3, I4 4);
159interface_list!(Header6, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5);
160interface_list!(Header7, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6);
161interface_list!(Header8, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7);
162interface_list!(Header9, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8);
163interface_list!(Header10, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9);
164interface_list!(Header11, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10);
165interface_list!(Header12, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11);
166interface_list!(Header13, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12);
167interface_list!(Header14, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12, I13 13);
168interface_list!(Header15, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12, I13 13, I14 14);
169
170#[repr(C)]
171struct ComWrapperInner<C: Class> {
172    header: Header<C>,
173    data: C,
174}
175
176/// A wrapper for constructing a reference-counted COM object from a Rust value.
177///
178/// `ComWrapper` represents an owning reference to the COM object, i.e. it will decrement the
179/// object's reference count when it goes out of scope.
180pub struct ComWrapper<C: Class> {
181    inner: Arc<ComWrapperInner<C>>,
182}
183
184impl<C: Class> Clone for ComWrapper<C> {
185    fn clone(&self) -> ComWrapper<C> {
186        ComWrapper {
187            inner: self.inner.clone(),
188        }
189    }
190}
191
192unsafe impl<C: Class> Send for ComWrapper<C> where C: Send + Sync {}
193unsafe impl<C: Class> Sync for ComWrapper<C> where C: Send + Sync {}
194
195impl<C: Class> Deref for ComWrapper<C> {
196    type Target = C;
197
198    #[inline]
199    fn deref(&self) -> &Self::Target {
200        &self.inner.data
201    }
202}
203
204impl<C: Class> Wrapper<C> for ComWrapper<C> {
205    #[inline]
206    unsafe fn data_from_header(ptr: *mut Header<C>) -> *mut C {
207        (ptr as *mut u8)
208            .offset(-offset_of!(ComWrapperInner<C>, header))
209            .offset(offset_of!(ComWrapperInner<C>, data)) as *mut C
210    }
211
212    #[inline]
213    unsafe fn header_from_data(ptr: *mut C) -> *mut Header<C> {
214        (ptr as *mut u8)
215            .offset(-offset_of!(ComWrapperInner<C>, data))
216            .offset(offset_of!(ComWrapperInner<C>, header)) as *mut Header<C>
217    }
218
219    #[inline]
220    unsafe fn add_ref(ptr: *mut C) -> usize {
221        let wrapper_ptr = (ptr as *mut u8).offset(-offset_of!(ComWrapperInner<C>, data))
222            as *mut ComWrapperInner<C>;
223
224        let arc = Arc::from_raw(wrapper_ptr);
225        let result = Arc::strong_count(&arc) + 1;
226        let _ = Arc::into_raw(arc);
227
228        Arc::increment_strong_count(wrapper_ptr);
229
230        result
231    }
232
233    #[inline]
234    unsafe fn release(ptr: *mut C) -> usize {
235        let wrapper_ptr = (ptr as *mut u8).offset(-offset_of!(ComWrapperInner<C>, data))
236            as *mut ComWrapperInner<C>;
237
238        let arc = Arc::from_raw(wrapper_ptr);
239        let result = Arc::strong_count(&arc) - 1;
240        let _ = Arc::into_raw(arc);
241
242        Arc::decrement_strong_count(wrapper_ptr);
243
244        result
245    }
246}
247
248impl<C: Class> ComWrapper<C> {
249    /// Allocates memory for an object and its header and places `data` into it.
250    #[inline]
251    pub fn new(data: C) -> ComWrapper<C>
252    where
253        C: 'static,
254        C::Interfaces: MakeHeader<C, Self>,
255    {
256        ComWrapper {
257            inner: Arc::new(ComWrapperInner {
258                header: C::Interfaces::HEADER,
259                data,
260            }),
261        }
262    }
263
264    /// If `I` is in `C`'s interface list, returns a [`ComRef<I>`] pointing to the object.
265    ///
266    /// Does not perform any reference counting operations.
267    #[inline]
268    pub fn as_com_ref<'a, I: Interface>(&'a self) -> Option<ComRef<'a, I>> {
269        if let Some(offset) = C::Interfaces::query(&I::IID) {
270            unsafe {
271                let wrapper_ptr = Arc::as_ptr(&self.inner) as *mut ComWrapperInner<C>;
272                let interface_ptr = (wrapper_ptr as *mut u8)
273                    .offset(offset_of!(ComWrapperInner<C>, header))
274                    .offset(offset) as *mut I;
275                Some(ComRef::from_raw_unchecked(interface_ptr))
276            }
277        } else {
278            None
279        }
280    }
281
282    /// If `I` is in `C`'s interface list, returns a [`ComPtr<I>`] pointing to the object.
283    ///
284    /// If a [`ComPtr`] is returned, the object's reference count will be incremented.
285    #[inline]
286    pub fn to_com_ptr<I: Interface>(&self) -> Option<ComPtr<I>> {
287        if let Some(offset) = C::Interfaces::query(&I::IID) {
288            unsafe {
289                let wrapper_ptr = Arc::into_raw(self.inner.clone()) as *mut ComWrapperInner<C>;
290                let interface_ptr = (wrapper_ptr as *mut u8)
291                    .offset(offset_of!(ComWrapperInner<C>, header))
292                    .offset(offset) as *mut I;
293                Some(ComPtr::from_raw_unchecked(interface_ptr))
294            }
295        } else {
296            None
297        }
298    }
299}