intercom_common/
methodinfo.rs

1use crate::prelude::*;
2use proc_macro2::Span;
3use std::rc::Rc;
4use syn::{spanned::Spanned, FnArg, PathArguments, Receiver, ReturnType, Signature, Type};
5
6use crate::ast_converters::*;
7use crate::returnhandlers::{get_return_handler, ReturnHandler};
8use crate::tyhandlers::{get_ty_handler, Direction, ModelTypeSystem, TypeContext, TypeHandler};
9use crate::utils;
10
/// Reasons why `ComMethodInfo::new` rejects a Rust method signature.
#[derive(Debug, PartialEq, Eq)]
pub enum ComMethodInfoError
{
    /// The signature has no parameters at all, so there is no `self` receiver.
    TooFewArguments,
    /// The first parameter of the signature is not a `self` receiver.
    BadSelfArg,
    /// The type or the identifier of the boxed argument could not be extracted.
    BadArg(Box<FnArg>),
    /// No return handler could be resolved for the method's return type.
    BadReturnType,
}
19
/// An argument of a Rust method together with the type handler resolved for it.
#[derive(Clone)]
pub struct RustArg
{
    /// Name of the Rust argument.
    pub name: Ident,

    /// Rust type of the argument.
    pub ty: Type,

    /// Span associated with the argument. `ComMethodInfo::new` passes the span
    /// of the whole argument declaration here.
    pub span: Span,

    /// Type handler. `RustArg::new` resolves it from `ty` and the type system.
    pub handler: Rc<TypeHandler>,
}
35
36impl PartialEq for RustArg
37{
38    fn eq(&self, other: &RustArg) -> bool
39    {
40        self.name == other.name && self.ty == other.ty
41    }
42}
43
44impl ::std::fmt::Debug for RustArg
45{
46    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result
47    {
48        write!(f, "{}: {:?}", self.name, self.ty)
49    }
50}
51
52impl RustArg
53{
54    pub fn new(name: Ident, ty: Type, span: Span, type_system: ModelTypeSystem) -> RustArg
55    {
56        let tyhandler = get_ty_handler(&ty, TypeContext::new(type_system));
57        RustArg {
58            name,
59            ty,
60            span,
61            handler: tyhandler,
62        }
63    }
64}
65
/// An argument of the raw COM method, including its direction.
pub struct ComArg
{
    /// Name of the argument.
    pub name: Ident,

    /// Rust type of the raw COM argument.
    pub ty: Type,

    /// Type handler. The constructors resolve it from `ty` and the type system.
    pub handler: Rc<TypeHandler>,

    /// Rust span that sources this COM argument.
    pub span: Span,

    /// Argument direction. COM uses OUT params while Rust uses return values.
    pub dir: Direction,
}
83
84impl ComArg
85{
86    pub fn new(
87        name: Ident,
88        ty: Type,
89        span: Span,
90        dir: Direction,
91        type_system: ModelTypeSystem,
92    ) -> ComArg
93    {
94        let tyhandler = get_ty_handler(&ty, TypeContext::new(type_system));
95        ComArg {
96            name,
97            ty,
98            dir,
99            span,
100            handler: tyhandler,
101        }
102    }
103
104    pub fn from_rustarg(rustarg: RustArg, dir: Direction, type_system: ModelTypeSystem) -> ComArg
105    {
106        let tyhandler = get_ty_handler(&rustarg.ty, TypeContext::new(type_system));
107        ComArg {
108            name: rustarg.name,
109            ty: rustarg.ty,
110            dir,
111            span: rustarg.span,
112            handler: tyhandler,
113        }
114    }
115}
116
117impl PartialEq for ComArg
118{
119    fn eq(&self, other: &ComArg) -> bool
120    {
121        self.name == other.name && self.ty == other.ty && self.dir == other.dir
122    }
123}
124
125impl ::std::fmt::Debug for ComArg
126{
127    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result
128    {
129        write!(f, "{}: {:?} {:?}", self.name, self.dir, self.ty)
130    }
131}
132
/// COM-oriented description of a Rust method, built from the method signature
/// by `ComMethodInfo::new`.
#[derive(Debug, Clone)]
pub struct ComMethodInfo
{
    /// The display name used in public places that do not require a unique name.
    pub name: Ident,

    /// True if the `self` parameter is not mutable.
    pub is_const: bool,

    /// Rust self argument.
    pub rust_self_arg: Receiver,

    /// Rust return type. This is `()` when the signature declares no return type.
    pub rust_return_ty: Type,

    /// COM retval out parameter type, such as the value of `Result<...>`.
    pub retval_type: Option<Type>,

    /// COM return type, such as the error value of `Result<...>`.
    pub return_type: Option<Type>,

    /// Span for the method signature.
    pub signature_span: Span,

    /// Return value handler.
    pub returnhandler: Rc<dyn ReturnHandler>,

    /// Method arguments, excluding the `self` argument.
    pub args: Vec<RustArg>,

    /// True if the Rust method is unsafe.
    pub is_unsafe: bool,

    /// Type system the method info was constructed with.
    pub type_system: ModelTypeSystem,

    /// Is the method infallible, as reported by the return handler.
    pub infallible: bool,
}
172
173impl PartialEq for ComMethodInfo
174{
175    fn eq(&self, other: &ComMethodInfo) -> bool
176    {
177        self.name == other.name
178            && self.is_const == other.is_const
179            && self.rust_self_arg == other.rust_self_arg
180            && self.rust_return_ty == other.rust_return_ty
181            && self.retval_type == other.retval_type
182            && self.return_type == other.return_type
183            && self.args == other.args
184    }
185}
186
187impl ComMethodInfo
188{
189    /// Constructs new COM method info from a Rust method signature.
190    pub fn new(
191        decl: &Signature,
192        type_system: ModelTypeSystem,
193    ) -> Result<ComMethodInfo, ComMethodInfoError>
194    {
195        // Process all the function arguments.
196        // In Rust this includes the 'self' argument and the actual function
197        // arguments. For COM the self is implicit so we'll handle it
198        // separately.
199        let n = decl.ident.clone();
200        let unsafety = decl.unsafety.is_some();
201        let mut iter = decl.inputs.iter();
202        let rust_self_arg = iter.next().ok_or(ComMethodInfoError::TooFewArguments)?;
203
204        let (is_const, rust_self_arg) = match *rust_self_arg {
205            FnArg::Receiver(ref self_arg) => (self_arg.mutability.is_none(), self_arg.clone()),
206            _ => return Err(ComMethodInfoError::BadSelfArg),
207        };
208
209        // Process other arguments.
210        let args = iter
211            .map(|arg| {
212                let ty = arg
213                    .get_ty()
214                    .map_err(|_| ComMethodInfoError::BadArg(Box::new(arg.clone())))?;
215                let ident = arg
216                    .get_ident()
217                    .map_err(|_| ComMethodInfoError::BadArg(Box::new(arg.clone())))?;
218
219                Ok(RustArg::new(ident, ty, arg.span(), type_system))
220            })
221            .collect::<Result<_, _>>()?;
222
223        // Get the output.
224        let rust_return_ty = match decl.output {
225            ReturnType::Default => syn::parse2(quote_spanned!(decl.span() => ())).unwrap(),
226            ReturnType::Type(_, ref ty) => (**ty).clone(),
227        };
228
229        // Resolve the return type and retval type.
230        let (retval_type, return_type, retval_span) = if utils::is_unit(&rust_return_ty) {
231            (None, None, decl.span())
232        } else if let Some((retval, ret)) = try_parse_result(&rust_return_ty) {
233            (Some(retval), Some(ret), decl.output.span())
234        } else {
235            (None, Some(rust_return_ty.clone()), decl.output.span())
236        };
237
238        let returnhandler =
239            get_return_handler(&retval_type, &return_type, retval_span, type_system)
240                .or(Err(ComMethodInfoError::BadReturnType))?;
241        Ok(ComMethodInfo {
242            name: n,
243            infallible: returnhandler.is_infallible(),
244            returnhandler: returnhandler.into(),
245            signature_span: decl.span(),
246            is_const,
247            rust_self_arg,
248            rust_return_ty,
249            retval_type,
250            return_type,
251            args,
252            is_unsafe: unsafety,
253            type_system,
254        })
255    }
256
257    pub fn raw_com_args(&self) -> Vec<ComArg>
258    {
259        let in_args = self
260            .args
261            .iter()
262            .map(|ca| ComArg::from_rustarg(ca.clone(), Direction::In, self.type_system));
263        let out_args = self.returnhandler.com_out_args();
264
265        in_args.chain(out_args).collect()
266    }
267
268    pub fn get_parameters_tokenstream(&self) -> TokenStream
269    {
270        let in_out_args = self.raw_com_args().into_iter().map(|com_arg| {
271            let name = &com_arg.name;
272            let com_ty = &com_arg.handler.com_ty(com_arg.span);
273            let dir = match com_arg.dir {
274                Direction::In => quote!(),
275                Direction::Out | Direction::Retval => quote_spanned!(com_arg.span => *mut ),
276            };
277            quote_spanned!(com_arg.span => #name : #dir #com_ty )
278        });
279        let self_arg = quote_spanned!(self.rust_self_arg.span()=>
280            self_vtable: intercom::raw::RawComPtr);
281        let args = std::iter::once(self_arg).chain(in_out_args);
282        quote!(#(#args),*)
283    }
284}
285
286fn try_parse_result(ty: &Type) -> Option<(Type, Type)>
287{
288    let path = match *ty {
289        Type::Path(ref p) => &p.path,
290        _ => return None,
291    };
292
293    // Ensure the type name contains 'Result'. We don't really have
294    // good ways to ensure it is an actual Result type but at least we can
295    // use this to discount things like Option<>, etc.
296    let last_segment = path.segments.last()?;
297    if !last_segment.ident.to_string().contains("Result") {
298        return None;
299    }
300
301    // Ensure the Result has angle bracket arguments.
302    if let PathArguments::AngleBracketed(ref data) = last_segment.arguments {
303        // The returned types depend on how many arguments the Result has.
304        return Some(match data.args.len() {
305            1 => (data.args[0].get_ty().ok()?, hresult_ty(ty.span())),
306            2 => (data.args[0].get_ty().ok()?, data.args[1].get_ty().ok()?),
307            _ => return None,
308        });
309    }
310
311    // We couldn't find a valid type. Return nothing.
312    None
313}
314
315fn hresult_ty(span: Span) -> Type
316{
317    syn::parse2(quote_spanned!(span => intercom::raw::HRESULT)).unwrap()
318}
319
#[cfg(test)]
mod tests
{
    use syn::Item;

    use super::*;
    use crate::tyhandlers::ModelTypeSystem::*;

    /// A method with only `&self` and no return type has no arguments and
    /// neither a retval type nor a return type.
    #[test]
    fn no_args_or_return_value()
    {
        let info = test_info("fn foo( &self ) {}", Automation);

        assert!(info.is_const);
        assert_eq!(info.name, "foo");
        assert_eq!(info.args.len(), 0);
        assert!(info.retval_type.is_none());
        assert!(info.return_type.is_none());
    }

    /// A plain return type is used as the return type without a retval type.
    #[test]
    fn basic_return_value()
    {
        let info = test_info("fn foo( &self ) -> bool {}", Raw);

        assert!(info.is_const);
        assert_eq!(info.name, "foo");
        assert_eq!(info.args.len(), 0);
        assert!(info.retval_type.is_none());
        assert_eq!(info.return_type, Some(parse_quote!(bool)));
    }

    /// A two argument `Result` is split into the retval and return types.
    #[test]
    fn result_return_value()
    {
        let info = test_info("fn foo( &self ) -> Result<String, f32> {}", Automation);

        assert!(info.is_const);
        assert_eq!(info.name, "foo");
        assert_eq!(info.args.len(), 0);
        assert_eq!(info.retval_type, Some(parse_quote!(String)));
        assert_eq!(info.return_type, Some(parse_quote!(f32)));
    }

    /// A single argument `ComResult` uses `HRESULT` as the return type.
    #[test]
    fn comresult_return_value()
    {
        let info = test_info("fn foo( &self ) -> ComResult<String> {}", Automation);

        assert!(info.is_const);
        assert_eq!(info.name, "foo");
        assert_eq!(info.args.len(), 0);
        assert_eq!(info.retval_type, Some(parse_quote!(String)));
        assert_eq!(info.return_type, Some(parse_quote!(intercom::raw::HRESULT)));
    }

    /// Arguments after `&self` are collected in order with names and types.
    #[test]
    fn basic_arguments()
    {
        let info = test_info("fn foo( &self, a : u32, b : f32 ) {}", Raw);

        assert!(info.is_const);
        assert_eq!(info.name, "foo");
        assert!(info.retval_type.is_none());
        assert!(info.return_type.is_none());

        assert_eq!(info.args.len(), 2);

        assert_eq!(info.args[0].name, Ident::new("a", Span::call_site()));
        assert_eq!(info.args[0].ty, parse_quote!(u32));

        assert_eq!(info.args[1].name, Ident::new("b", Span::call_site()));
        assert_eq!(info.args[1].ty, parse_quote!(f32));
    }

    /// Parses `code` as a function item and builds the method info from its
    /// signature. Panics if the code does not parse, is not a function, or
    /// the method info cannot be constructed.
    fn test_info(code: &str, ts: ModelTypeSystem) -> ComMethodInfo
    {
        let signature = match syn::parse_str::<Item>(code).unwrap() {
            Item::Fn(item_fn) => item_fn.sig,
            _ => panic!("Code isn't function"),
        };

        ComMethodInfo::new(&signature, ts).unwrap()
    }
}