flapigen/python/mod.rs

1use crate::typemap::ty::RustType;
2use crate::{
3    error::Result,
4    extension::{ClassExtHandlers, MethodExtHandlers},
5    source_registry::SourceId,
6    typemap::{
7        ast::{ForeignTypeName, GenericTypeConv},
8        ty::ForeignTypeS,
9        TypeConvCode,
10    },
11    types::{
12        ForeignClassInfo, ForeignEnumInfo, ForeignInterface, ForeignMethod, ItemToExpand,
13        MethodVariant, SelfTypeVariant,
14    },
15    DiagnosticError, LanguageGenerator, PythonConfig, SourceCode, TypeMap,
16};
17use crate::{extension::ExtHandlers, typemap::ast};
18use heck::ToSnakeCase;
19use proc_macro2::{Span, TokenStream};
20use quote::quote;
21use quote::ToTokens;
22use std::ops::Deref;
23use syn::parse_quote;
24use syn::{Ident, Type};
25
/// Name of the marker trait attached to Rust types exported to Python as enums.
///
/// `generate_enum` registers every exported enum type as implementing this trait, and the
/// argument/return conversions look it up to decide that the value crosses the boundary as `u32`.
const ENUM_TRAIT_NAME: &str = "SwigForeignEnum";
27
28impl LanguageGenerator for PythonConfig {
29    fn expand_items(
30        &self,
31        conv_map: &mut TypeMap,
32        _pointer_target_width: usize,
33        _code: &[SourceCode],
34        items: Vec<ItemToExpand>,
35        _remove_not_generated_files: bool,
36        ext_handlers: ExtHandlers,
37    ) -> Result<Vec<TokenStream>> {
38        for item in &items {
39            if let ItemToExpand::Class(ref fclass) = item {
40                self.register_class(conv_map, fclass)?;
41            }
42        }
43        let mut code = Vec::with_capacity(items.len());
44        let mut module_initialization = Vec::with_capacity(items.len());
45        for item in items {
46            let (class_code, initialization) = match item {
47                ItemToExpand::Class(fclass) => self.generate_class(
48                    conv_map,
49                    &fclass,
50                    ext_handlers.class_ext_handlers,
51                    ext_handlers.method_ext_handlers,
52                )?,
53                ItemToExpand::Enum(fenum) => self.generate_enum(conv_map, &fenum)?,
54                ItemToExpand::Interface(finterface) => {
55                    self.generate_interface(conv_map, &finterface)?
56                }
57            };
58            code.push(class_code);
59            module_initialization.push(initialization);
60        }
61        code.push(self.generate_module_initialization(&module_initialization)?);
62        Ok(code)
63    }
64}
65
66impl PythonConfig {
67    fn register_class(&self, conv_map: &mut TypeMap, class: &ForeignClassInfo) -> Result<()> {
68        if let Some(ref self_desc) = class.self_desc {
69            conv_map.find_or_alloc_rust_type(&self_desc.self_type, class.src_id);
70        }
71        Ok(())
72    }
73
74    /// Generate class code and module initialization code for this class.
75    fn generate_class(
76        &self,
77        conv_map: &mut TypeMap,
78        class: &ForeignClassInfo,
79        class_ext_handlers: &ClassExtHandlers,
80        method_ext_handlers: &MethodExtHandlers,
81    ) -> Result<(TokenStream, TokenStream)> {
82        if !class_ext_handlers.is_empty() || !method_ext_handlers.is_empty() {
83            return Err(DiagnosticError::new(
84                class.src_id,
85                class.span(),
86                format!(
87                    "class {}: has attributes, this is not supported for python",
88                    class.name
89                ),
90            ));
91        }
92
93        let class_name = &class.name;
94        let wrapper_mod_name =
95            parse::<Ident>(&py_wrapper_mod_name(&class_name.to_string()), class.src_id)?;
96        let (rust_instance_field, rust_instance_getter) =
97            generate_rust_instance_field_and_methods(class, conv_map)?;
98        let methods_code = class
99            .methods
100            .iter()
101            .map(|m| generate_method_code(class, m, conv_map))
102            .collect::<Result<Vec<_>>>()?;
103        let mut doc_comments = class.doc_comments.clone();
104        if let Some(constructor) = class
105            .methods
106            .iter()
107            .find(|m| m.variant == MethodVariant::Constructor)
108        {
109            // Python API doesn't allow to add docstring to the special methods (slots),
110            // including __new__ and __init__.
111            // The convention is, to document the constructor in class's docstring.
112            doc_comments.push("".to_owned());
113            doc_comments.extend_from_slice(&constructor.doc_comments);
114        }
115        let docstring = doc_comments.as_slice().join("\n");
116        let class_code = quote! {
117            mod #wrapper_mod_name {
118                use super::*;
119                #[allow(unused)]
120                cpython::py_class!(pub class #class_name |py| {
121                    static __doc__  = #docstring;
122
123                    #rust_instance_field
124
125                    #( #methods_code )*
126                });
127
128                #rust_instance_getter
129            }
130        };
131
132        let module_initialization_code = quote! {
133            {
134                m.add_class::<#wrapper_mod_name::#class_name>(py)?;
135            }
136        };
137        Ok((class_code, module_initialization_code))
138    }
139
140    fn generate_enum(
141        &self,
142        conv_map: &mut TypeMap,
143        enum_info: &ForeignEnumInfo,
144    ) -> Result<(TokenStream, TokenStream)> {
145        let enum_name = &enum_info.name;
146        let wrapper_mod_name = parse::<Ident>(
147            &py_wrapper_mod_name(&enum_name.to_string()),
148            enum_info.src_id,
149        )?;
150        let foreign_variants = enum_info.items.iter().map(|item| &item.name);
151        let rust_variants = enum_info
152            .items
153            .iter()
154            .map(|item| &item.rust_name)
155            .collect::<Vec<_>>();
156        let rust_variants_ref_1 = &rust_variants;
157        let rust_variants_ref_2 = &rust_variants;
158        let enum_name_str = enum_name.to_string();
159        let docstring = enum_info.doc_comments.as_slice().join("\n");
160        let class_code = quote! {
161            mod #wrapper_mod_name {
162                cpython::py_class!(pub class #enum_name |py| {
163                    static __doc__  = #docstring;
164                    #( static #foreign_variants = super::#rust_variants_ref_1 as u32; )*
165                });
166
167                pub fn from_u32(py: cpython::Python, value: u32) -> cpython::PyResult<super::#enum_name> {
168                    #( if value == super::#rust_variants_ref_1 as u32 { return Ok(super::#rust_variants_ref_2); } )*
169                    Err(cpython::PyErr::new::<cpython::exc::ValueError, _>(
170                        py, format!("{} is not valid value for enum {}", value, #enum_name_str)
171                    ))
172                }
173            }
174        };
175        let enum_ti: Type =
176            ast::parse_ty_with_given_span(&enum_name.to_string(), enum_info.name.span())
177                .map_err(|err| DiagnosticError::from_syn_err(enum_info.src_id, err))?;
178        conv_map.find_or_alloc_rust_type_that_implements(
179            &enum_ti,
180            &[ENUM_TRAIT_NAME],
181            enum_info.src_id,
182        );
183        let enum_ftype = ForeignTypeS {
184            name: ForeignTypeName::new(
185                enum_info.name.to_string(),
186                (enum_info.src_id, enum_info.name.span()),
187            ),
188            provided_by_module: vec![],
189            into_from_rust: None,
190            from_into_rust: None,
191        };
192        conv_map.alloc_foreign_type(enum_ftype)?;
193
194        let module_initialization_code = quote! {
195            {
196                m.add_class::<#wrapper_mod_name::#enum_name>(py)?;
197            }
198        };
199        Ok((class_code, module_initialization_code))
200    }
201
202    fn generate_interface(
203        &self,
204        _conv_map: &mut TypeMap,
205        _interface: &ForeignInterface,
206    ) -> Result<(TokenStream, TokenStream)> {
207        unimplemented!("Interfaces are currently unsupported for Python.")
208    }
209
210    fn generate_module_initialization(
211        &self,
212        module_initialization_code: &[TokenStream],
213    ) -> Result<TokenStream> {
214        let module_name = parse::<syn::Ident>(&self.module_name, SourceId::none())?;
215        let module_init =
216            parse::<syn::Ident>(&format!("init{}", &self.module_name), SourceId::none())?;
217        let module_py_init =
218            parse::<syn::Ident>(&format!("PyInit_{}", &self.module_name), SourceId::none())?;
219        let registration_code = quote! {
220            mod py_error {
221                cpython::py_exception!(#module_name, Error);
222            }
223
224            cpython::py_module_initializer!(#module_name, #module_init, #module_py_init, |py, m| {
225                m.add(py, "Error", py_error::Error::type_object(py))?;
226                #(#module_initialization_code)*
227                Ok(())
228            });
229        };
230        Ok(registration_code)
231    }
232}
233
234fn generate_rust_instance_field_and_methods(
235    class: &ForeignClassInfo,
236    conv_map: &mut TypeMap,
237) -> Result<(TokenStream, TokenStream)> {
238    if let Some(ref self_desc) = class.self_desc {
239        let rust_self_type = &self_desc.self_type;
240        let storage_smart_pointer = storage_smart_pointer_for_class(class, conv_map)?;
241        if storage_smart_pointer.inner_ty.normalized_name
242            != conv_map
243                .find_or_alloc_rust_type(rust_self_type, class.src_id)
244                .normalized_name
245        {
246            return Err(DiagnosticError::new(
247                class.src_id,
248                class.span(),
249                "Self type and (inner) type returned from constructor doesn't match",
250            ));
251        }
252        let storage_type = wrap_type_for_class(rust_self_type, storage_smart_pointer.pointer_type);
253        let storage_type_ref = &storage_type;
254        let class_name = &class.name;
255        Ok((
256            quote! {
257                data rust_instance: #storage_type_ref;
258            },
259            // For some reason, rust-cpython generates private `rust_instance` getter method.
260            // As a workaround, we add public function in the same module, that gets `rust_instance`.
261            // The same goes for `create_instance
262            quote! {
263                pub fn rust_instance<'a>(class: &'a #class_name, py: cpython::Python<'a>) -> &'a #storage_type_ref {
264                    class.rust_instance(py)
265                }
266
267                pub fn create_instance(py: cpython::Python, instance: #storage_type_ref) -> cpython::PyResult<#class_name> {
268                    #class_name::create_instance(py, instance)
269                }
270            },
271        ))
272    } else if !has_any_methods(class) {
273        Ok((TokenStream::new(), TokenStream::new()))
274    } else {
275        Err(DiagnosticError::new(
276            class.src_id,
277            class.span(),
278            format!(
279                "Class {} has non-static methods, but no self_type",
280                class.name
281            ),
282        ))
283    }
284}
285
286fn generate_method_code(
287    class: &ForeignClassInfo,
288    method: &ForeignMethod,
289    conv_map: &mut TypeMap,
290) -> Result<TokenStream> {
291    if method.is_dummy_constructor() {
292        return Ok(TokenStream::new());
293    }
294    let method_name = method_name(method, class.src_id)?;
295    let method_rust_path = &method.rust_id;
296    let skip_args_count = if let MethodVariant::Method(_) = method.variant {
297        1
298    } else {
299        0
300    };
301    let (args_list, mut args_conversions): (Vec<_>, Vec<_>) = method
302        .fn_decl
303        .inputs
304        .iter()
305        .skip(skip_args_count)
306        .map(|a| {
307            let named_arg = a
308                .as_named_arg()
309                .map_err(|err| DiagnosticError::from_syn_err(class.src_id, err))?;
310            let (arg_type, arg_conversion) = generate_conversion_for_argument(
311                &conv_map.find_or_alloc_rust_type(&named_arg.ty, class.src_id),
312                method.span(),
313                class.src_id,
314                conv_map,
315                &named_arg.name,
316                true,
317            )?;
318            Ok(((&named_arg.name, arg_type), arg_conversion))
319        })
320        .collect::<Result<Vec<_>>>()?
321        .into_iter()
322        .unzip();
323    if let Some(self_conversion) = self_type_conversion(class, method, conv_map)? {
324        args_conversions.insert(0, self_conversion);
325    }
326    let mut args_list_tokens = args_list
327        .into_iter()
328        .map(|(name, t)| {
329            parse(
330                &format!("{}: {}", name, t.into_token_stream()),
331                class.src_id,
332            )
333        })
334        .collect::<std::result::Result<Vec<TokenStream>, _>>()?;
335    if let MethodVariant::Method(_) = method.variant {
336        args_list_tokens.insert(0, parse("&self", class.src_id)?);
337    } else if method.variant == MethodVariant::Constructor {
338        args_list_tokens.insert(0, parse("_cls", class.src_id)?);
339    }
340    let attribute = if method.variant == MethodVariant::StaticMethod {
341        parse("@staticmethod", class.src_id)?
342    } else {
343        TokenStream::new()
344    };
345    let (return_type, rust_call_with_return_conversion) = generate_conversion_for_return(
346        &conv_map
347            .find_or_alloc_rust_type(&extract_return_type(&method.fn_decl.output), class.src_id),
348        method.span(),
349        class.src_id,
350        conv_map,
351        quote! {
352            #method_rust_path(#( #args_conversions ),*)
353        },
354    )?;
355    let docstring = if !method_name.to_string().starts_with("__") {
356        parse::<TokenStream>(
357            &("/// ".to_owned() + &method.doc_comments.as_slice().join("\n/// ")),
358            class.src_id,
359        )?
360    } else {
361        // Python API doesn't support defining docstrings on the special methods (slots)
362        quote! {}
363    };
364    Ok(quote! {
365        #docstring #attribute def #method_name(
366            #( #args_list_tokens ),*
367        ) -> cpython::PyResult<#return_type> {
368            #[allow(unused)]
369            use super::*;
370            Ok(#rust_call_with_return_conversion)
371        }
372    })
373}
374
375fn standard_method_name(method: &ForeignMethod, src_id: SourceId) -> Result<syn::Ident> {
376    Ok(method
377        .name_alias
378        .as_ref()
379        .or_else(|| method.rust_id.segments.last().map(|p| &p.ident))
380        .ok_or_else(|| DiagnosticError::new(src_id, method.span(), "Method has no name"))?
381        .clone())
382}
383
384fn method_name(method: &ForeignMethod, src_id: SourceId) -> Result<syn::Ident> {
385    if method.variant == MethodVariant::Constructor {
386        parse("__new__", src_id)
387    } else {
388        let name = standard_method_name(method, src_id)?;
389        let name_str = name.to_string();
390        match name_str.as_ref() {
391            "to_string" => parse("__repr__", src_id),
392            _ => Ok(name),
393        }
394    }
395}
396
397fn self_type_conversion(
398    class: &ForeignClassInfo,
399    method: &ForeignMethod,
400    conv_map: &mut TypeMap,
401) -> Result<Option<TokenStream>> {
402    if let MethodVariant::Method(self_variant) = method.variant {
403        let self_type = &class
404            .self_desc
405            .as_ref()
406            .ok_or_else(|| {
407                DiagnosticError::new(
408                    class.src_id,
409                    class.span(),
410                    "Class have non-static methods, but no self_type",
411                )
412            })?
413            .self_type;
414        let self_type_ty = match self_variant {
415            SelfTypeVariant::Rptr => parse_type! {&#self_type},
416            SelfTypeVariant::RptrMut => parse_type! {&mut #self_type},
417            _ => parse_type! {#self_type},
418        };
419        Ok(Some(
420            generate_conversion_for_argument(
421                &conv_map.find_or_alloc_rust_type(&self_type_ty, class.src_id),
422                method.span(),
423                class.src_id,
424                conv_map,
425                "self",
426                true,
427            )?
428            .1,
429        ))
430    } else {
431        Ok(None)
432    }
433}
434
435fn has_any_methods(class: &ForeignClassInfo) -> bool {
436    class
437        .methods
438        .iter()
439        .any(|m| matches!(m.variant, MethodVariant::Method(_)))
440}
441
442fn generate_conversion_for_argument(
443    rust_type: &RustType,
444    method_span: Span,
445    src_id: SourceId,
446    conv_map: &mut TypeMap,
447    arg_name: &str,
448    reference_allowed: bool,
449) -> Result<(Type, TokenStream)> {
450    let arg_name_ident: TokenStream = parse(arg_name, src_id)?;
451    if is_cpython_supported_type(rust_type) {
452        Ok((rust_type.ty.clone(), arg_name_ident))
453    } else if let Some((ty, conversion)) = if_exported_class_generate_argument_conversion(
454        rust_type,
455        conv_map,
456        &arg_name_ident,
457        method_span,
458        src_id,
459        reference_allowed,
460    )? {
461        Ok((ty, conversion))
462    } else if rust_type
463        .implements
464        .contains_path(&parse(ENUM_TRAIT_NAME, src_id)?)
465    {
466        let enum_py_mod: Ident = parse(&py_wrapper_mod_name(&rust_type.normalized_name), src_id)?;
467        Ok((
468            parse_type!(u32),
469            quote! {
470                super::#enum_py_mod::from_u32(py, #arg_name_ident)?
471            },
472        ))
473    } else if let Some(inner) = ast::if_option_return_some_type(rust_type) {
474        let (inner_py_type, inner_conversion) = generate_conversion_for_argument(
475            &conv_map.find_or_alloc_rust_type(&inner, src_id),
476            method_span,
477            src_id,
478            conv_map,
479            "inner",
480            false,
481        )?;
482        Ok((
483            parse_type!(Option<#inner_py_type>),
484            quote! {
485                match #arg_name_ident {
486                    Some(inner) => Some(#inner_conversion),
487                    None => None,
488                }
489            },
490        ))
491    } else if let Some(inner) = if_type_slice_return_elem_type(&rust_type.ty, false) {
492        let (inner_py_type, inner_conversion) = generate_conversion_for_argument(
493            &conv_map.find_or_alloc_rust_type(inner, src_id),
494            method_span,
495            src_id,
496            conv_map,
497            "inner",
498            false,
499        )?;
500        Ok((
501            parse_type!(Vec<#inner_py_type>),
502            quote! {
503                &#arg_name_ident.into_iter().map(|inner| Ok(#inner_conversion)).collect::<cpython::PyResult<Vec<_>>>()?
504            },
505        ))
506    } else if let Some(inner) = if_vec_return_elem_type(rust_type) {
507        let (inner_py_type, inner_conversion) = generate_conversion_for_argument(
508            &conv_map.find_or_alloc_rust_type(&inner, src_id),
509            method_span,
510            src_id,
511            conv_map,
512            "inner",
513            false,
514        )?;
515        Ok((
516            parse_type!(Vec<#inner_py_type>),
517            quote! {
518                #arg_name_ident.into_iter().map(|inner| Ok(#inner_conversion)).collect::<cpython::PyResult<Vec<_>>>()?
519            },
520        ))
521    } else if let Type::Reference(ref inner) = rust_type.ty {
522        if inner.mutability.is_some() {
523            return Err(DiagnosticError::new(
524                src_id,
525                method_span,
526                "mutable reference is only supported for exported class types",
527            ));
528        }
529        let (inner_py_type, inner_conversion) = generate_conversion_for_argument(
530            &conv_map.find_or_alloc_rust_type(inner.elem.deref(), src_id),
531            method_span,
532            src_id,
533            conv_map,
534            arg_name,
535            false,
536        )?;
537        Ok((
538            parse_type!(#inner_py_type),
539            quote! {
540                &#inner_conversion
541            },
542        ))
543    } else {
544        Err(DiagnosticError::new(
545            src_id,
546            method_span,
547            format!("Unsupported argument type: {rust_type}"),
548        ))
549    }
550}
551
552fn generate_conversion_for_return(
553    rust_type: &RustType,
554    method_span: Span,
555    src_id: SourceId,
556    conv_map: &mut TypeMap,
557    rust_call: TokenStream,
558) -> Result<(Type, TokenStream)> {
559    if rust_type.ty == parse_type! { () } {
560        Ok((
561            parse_type!(cpython::PyObject),
562            quote! {
563                {#rust_call; py.None()}
564            },
565        ))
566    } else if is_cpython_supported_type(rust_type) {
567        Ok((rust_type.ty.clone(), rust_call))
568    } else if let Some((ty, conversion)) = if_exported_class_generate_return_conversion(
569        rust_type,
570        conv_map,
571        &rust_call,
572        method_span,
573        src_id,
574    )? {
575        Ok((ty, conversion))
576    } else if rust_type
577        .implements
578        .contains_path(&parse(ENUM_TRAIT_NAME, src_id)?)
579    {
580        Ok((
581            parse_type!(u32),
582            quote! {
583                #rust_call as u32
584            },
585        ))
586    } else if let Some(inner) = ast::if_option_return_some_type(rust_type) {
587        let (inner_py_type, inner_conversion) = generate_conversion_for_return(
588            &conv_map.find_or_alloc_rust_type(&inner, src_id),
589            method_span,
590            src_id,
591            conv_map,
592            quote! {inner},
593        )?;
594        Ok((
595            parse_type!(Option<#inner_py_type>),
596            quote! {
597                match #rust_call {
598                    Some(inner) => Some(#inner_conversion),
599                    None => None
600                }
601            },
602        ))
603    } else if let Some(inner) = if_type_slice_return_elem_type(&rust_type.ty, false) {
604        let (inner_py_type, inner_conversion) = generate_conversion_for_return(
605            &conv_map.find_or_alloc_rust_type(inner, src_id),
606            method_span,
607            src_id,
608            conv_map,
609            quote! {inner},
610        )?;
611        Ok((
612            parse_type!(Vec<#inner_py_type>),
613            quote! {
614                #rust_call.iter().cloned().map(|inner| Ok(#inner_conversion)).collect::<cpython::PyResult<Vec<_>>>()?
615            },
616        ))
617    } else if let Some(inner) = if_vec_return_elem_type(rust_type) {
618        let (inner_py_type, inner_conversion) = generate_conversion_for_return(
619            &conv_map.find_or_alloc_rust_type(&inner, src_id),
620            method_span,
621            src_id,
622            conv_map,
623            quote! {inner},
624        )?;
625        Ok((
626            parse_type!(Vec<#inner_py_type>),
627            quote! {
628                #rust_call.into_iter().map(|inner| Ok(#inner_conversion)).collect::<cpython::PyResult<Vec<_>>>()?
629            },
630        ))
631    } else if let Some((inner_ok, _inner_err)) = ast::if_result_return_ok_err_types(rust_type) {
632        let (inner_py_type, inner_conversion) = generate_conversion_for_return(
633            &conv_map.find_or_alloc_rust_type(&inner_ok, src_id),
634            method_span,
635            src_id,
636            conv_map,
637            quote! {ok_inner},
638        )?;
639        Ok((
640            parse_type!(#inner_py_type),
641            quote! {
642                match #rust_call {
643                    Ok(ok_inner) => #inner_conversion,
644                    Err(err_inner) => return Err(cpython::PyErr::new::<super::py_error::Error, _>(
645                        py,
646                        swig_collect_error_message(&err_inner)
647                    )),
648                }
649            },
650        ))
651    } else if let Type::Reference(ref inner) = rust_type.ty {
652        generate_conversion_for_return(
653            &conv_map.find_or_alloc_rust_type(inner.elem.deref(), src_id),
654            method_span,
655            src_id,
656            conv_map,
657            quote! {(#rust_call).clone()},
658        )
659    } else if let Type::Tuple(ref tuple) = rust_type.ty {
660        let (types, conversions): (Vec<_>, Vec<_>) = tuple
661            .elems
662            .iter()
663            .enumerate()
664            .map(|(i, ty)| {
665                let i_ident = syn::Index {
666                    index: i as u32,
667                    span: Span::call_site(),
668                };
669                generate_conversion_for_return(
670                    &conv_map.find_or_alloc_rust_type(ty, src_id),
671                    method_span,
672                    src_id,
673                    conv_map,
674                    quote! {tuple.#i_ident},
675                )
676            })
677            .collect::<Result<Vec<_>>>()?
678            .into_iter()
679            .unzip();
680        Ok((
681            parse_type! {( #( #types, )* )},
682            quote! {
683                {
684                    let tuple = #rust_call;
685                    (
686                        #( #conversions, )*
687                    )
688                }
689            },
690        ))
691    } else {
692        Err(DiagnosticError::new(
693            src_id,
694            method_span,
695            format!("Unsupported return type: {rust_type}"),
696        ))
697    }
698}
699
700fn is_cpython_supported_type(rust_type: &RustType) -> bool {
701    let primitive_types = [
702        "bool", "i8", "i16", "i32", "i64", "isize", "u8", "u16", "u32", "u64", "usize", "f32",
703        "f64", "String", "& str",
704    ];
705    primitive_types.contains(&rust_type.normalized_name.as_str())
706}
707
708fn extract_return_type(syn_return_type: &syn::ReturnType) -> Type {
709    match syn_return_type {
710        syn::ReturnType::Default => {
711            parse_type! { () }
712        }
713        syn::ReturnType::Type(_, ref ty) => ty.deref().clone(),
714    }
715}
716
717fn py_wrapper_mod_name(type_name: &str) -> String {
718    format!("py_{}", type_name.to_snake_case())
719}
720
/// Storage strategy for the Rust object wrapped by a generated Python class.
///
/// `rust_cpython` provides access only to a non-mutable reference of the wrapped Rust object.
/// What's more, `rust_cpython` requires the object to be `Send + 'static`, because the Python VM
/// can move it between threads without any control from Rust.
/// As a result, we need to wrap the object in `Mutex` to provide mutability.
/// By default, `Mutex` is used.
/// This can be overridden by the smart pointer returned from the constructor.
/// The following smart pointers are supported:
/// - `Arc<Mutex<T>>`: the wrapped Rust object is mutable and can be shared between Rust and Python,
/// - `Mutex<T>`: the wrapped Rust object is mutable, but only Python owns it,
/// - `Arc<T>`: the wrapped Rust object is immutable and can be shared between Rust and Python,
/// - `Box<T>`: the wrapped Rust object is immutable and only Python owns it.
///
/// Note that `Rc` is NOT supported, because it is not `Send`.
/// `RefCell` theoretically could be supported, but the generated Python API would be thread
/// unsafe (it is `Send`, but not `Sync`), so it is intentionally omitted.
///
/// `PointerType::None` means the type is not wrapped in any of the recognised smart pointers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PointerType {
    ArcMutex,
    Mutex,
    Arc,
    Box,
    None,
}

impl PointerType {
    /// Returns `true` for the `Arc`-based storages, whose object can be shared between
    /// Rust and Python.
    fn is_shared(self) -> bool {
        matches!(self, PointerType::Arc | PointerType::ArcMutex)
    }
}
749
750struct SmartPointerInfo {
751    pointer_type: PointerType,
752    inner_ty: RustType,
753}
754
755impl SmartPointerInfo {
756    fn new(pointer_type: PointerType, inner_ty: RustType) -> SmartPointerInfo {
757        SmartPointerInfo {
758            pointer_type,
759            inner_ty,
760        }
761    }
762}
763
764fn smart_pointer(
765    rust_type: &RustType,
766    conv_map: &mut TypeMap,
767    src_id: SourceId,
768) -> SmartPointerInfo {
769    if let Some(inner_ty) = ast::check_if_smart_pointer_return_inner_type(rust_type, "Arc") {
770        let rust_inner_ty = conv_map.find_or_alloc_rust_type(&inner_ty, src_id);
771        if let Some(inner_inner_ty) =
772            ast::check_if_smart_pointer_return_inner_type(&rust_inner_ty, "Mutex")
773        {
774            SmartPointerInfo::new(
775                PointerType::ArcMutex,
776                conv_map.find_or_alloc_rust_type(&inner_inner_ty, src_id),
777            )
778        } else {
779            SmartPointerInfo::new(PointerType::Arc, rust_inner_ty)
780        }
781    } else if let Some(inner_ty) = ast::check_if_smart_pointer_return_inner_type(rust_type, "Mutex")
782    {
783        SmartPointerInfo::new(
784            PointerType::Mutex,
785            conv_map.find_or_alloc_rust_type(&inner_ty, src_id),
786        )
787    } else if let Some(inner_ty) = ast::check_if_smart_pointer_return_inner_type(rust_type, "Box") {
788        SmartPointerInfo::new(
789            PointerType::Box,
790            conv_map.find_or_alloc_rust_type(&inner_ty, src_id),
791        )
792    } else {
793        SmartPointerInfo::new(PointerType::None, rust_type.clone())
794    }
795}
796
/// Whether a Rust type is a reference and, if so, which kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Reference {
    /// Not a reference.
    None,
    /// Shared reference `&T`.
    Ref,
    /// Mutable reference `&mut T`.
    MutRef,
}
803
804fn get_reference_info_and_inner_type(
805    rust_type: &RustType,
806    conv_map: &mut TypeMap,
807    src_id: SourceId,
808) -> (Reference, RustType) {
809    if let Type::Reference(ref reference) = rust_type.ty {
810        if reference.mutability.is_some() {
811            (
812                Reference::MutRef,
813                conv_map.find_or_alloc_rust_type(&reference.elem, src_id),
814            )
815        } else {
816            (
817                Reference::Ref,
818                conv_map.find_or_alloc_rust_type(&reference.elem, src_id),
819            )
820        }
821    } else {
822        (Reference::None, rust_type.clone())
823    }
824}
825
826fn if_exported_class_generate_return_conversion(
827    rust_type: &RustType,
828    conv_map: &mut TypeMap,
829    rust_call: &TokenStream,
830    method_span: Span,
831    src_id: SourceId,
832) -> Result<Option<(Type, TokenStream)>> {
833    let (reference_type, rust_type_unref) =
834        get_reference_info_and_inner_type(rust_type, conv_map, src_id);
835    let smart_pointer_info = smart_pointer(&rust_type_unref, conv_map, src_id);
836    let class = match conv_map
837        .find_foreigner_class_with_such_this_type(&smart_pointer_info.inner_ty.ty, |_, ft| {
838            ft.self_desc.as_ref().map(|x| x.self_type.clone())
839        }) {
840        Some(fc) => fc.clone(),
841        None => return Ok(None),
842    };
843    let class_smart_pointer = storage_smart_pointer_for_class(&class, conv_map)?;
844    let rust_call_with_deref = if reference_type != Reference::None {
845        if smart_pointer_info.pointer_type == PointerType::Mutex {
846            return Err(DiagnosticError::new(
847                src_id,
848                method_span,
849               "Returning a rust object into python by reference is not safe, so the clone of the object needs to be make.\
850However, `Mutex` doesn't implement `Clone`, so it can't be returned by reference."
851            ));
852        } else if class.clone_derived()
853            || class.copy_derived()
854            || smart_pointer_info.pointer_type.is_shared()
855        {
856            quote! {
857                ((#rust_call).clone())
858            }
859        } else {
860            return Err(DiagnosticError::new(
861                src_id,
862                method_span,
863                "Returning a rust object into python by reference is not safe, so the clone of the object needs to be make. \
864Thus, the returned type must marked with `#[derive(Clone)]` or `#[derive(Copy)]` inside its `foreigner_class` macro."
865            ));
866        }
867    } else {
868        rust_call.clone()
869    };
870    let class_name = &class.name;
871    let py_mod: Ident = parse(&py_wrapper_mod_name(&class_name.to_string()), src_id)?;
872    let rust_call_with_wrapper = match class_smart_pointer.pointer_type {
873        PointerType::Mutex => generate_wrapper_constructor_for_mutex(
874            &class,
875            &smart_pointer_info,
876            rust_call_with_deref,
877            method_span,
878            src_id,
879        )?,
880        PointerType::ArcMutex => generate_wrapper_constructor_for_arc_mutex(
881            &class,
882            &smart_pointer_info,
883            rust_call_with_deref,
884            method_span,
885            src_id,
886        )?,
887        PointerType::Arc => generate_wrapper_constructor_for_arc(
888            &class,
889            &smart_pointer_info,
890            rust_call_with_deref,
891            method_span,
892            src_id,
893        )?,
894        PointerType::Box => generate_wrapper_constructor_for_box(
895            &class,
896            &smart_pointer_info,
897            rust_call_with_deref,
898            method_span,
899            src_id,
900        )?,
901        _ => unreachable!("`PointerType::None` as class storage pointer"),
902    };
903    let conversion = quote! {
904        super::#py_mod::create_instance(py, #rust_call_with_wrapper)?
905    };
906    Ok(Some((parse_type!(super::#py_mod::#class_name), conversion)))
907}
908
909fn generate_wrapper_constructor_for_mutex(
910    class: &ForeignClassInfo,
911    returned_smart_pointer: &SmartPointerInfo,
912    rust_call: TokenStream,
913    method_span: Span,
914    src_id: SourceId,
915) -> Result<TokenStream> {
916    match returned_smart_pointer.pointer_type {
917        PointerType::Mutex => Ok(rust_call),
918        PointerType::None => Ok(quote! {std::sync::Mutex::new(#rust_call)}),
919        _ => Err(DiagnosticError::new(
920            src_id,
921            method_span,
922            format!(
923                "Unsupported conversion for smart pointer. \
924Foreigner class {} is stored as `Mutex` and can be returned eiter as `Mutex` or bare type",
925                class.name
926            ),
927        )),
928    }
929}
930
931fn generate_wrapper_constructor_for_arc_mutex(
932    class: &ForeignClassInfo,
933    returned_smart_pointer: &SmartPointerInfo,
934    rust_call: TokenStream,
935    method_span: Span,
936    src_id: SourceId,
937) -> Result<TokenStream> {
938    match returned_smart_pointer.pointer_type {
939        PointerType::ArcMutex => Ok(rust_call),
940        _ => Err(DiagnosticError::new(
941            src_id,
942            method_span,
943            format!(
944                "Unsupported conversion for smart pointer. \
945Foreigner class {} is stored as `Arc<Mutex<T>>`, so it is intended for sharing between Rust and Python.\
946Thus, it must always be returned from Rust literally by `Arc<Mutex<T>>` \
947(or reference to `Arc<Mutex<T>>`) for any sharing to occur.",
948                class.name
949            ),
950        ))
951    }
952}
953
954fn generate_wrapper_constructor_for_arc(
955    class: &ForeignClassInfo,
956    returned_smart_pointer: &SmartPointerInfo,
957    rust_call: TokenStream,
958    method_span: Span,
959    src_id: SourceId,
960) -> Result<TokenStream> {
961    match returned_smart_pointer.pointer_type {
962        PointerType::Arc => Ok(rust_call),
963        _ => Err(DiagnosticError::new(
964            src_id,
965            method_span,
966            format!(
967                "Unsupported conversion for smart pointer. \
968Foreigner class {} is stored as `Arc<T>`, so it is intended for sharing between Rust and Python.\
969Thus, it must always be returned Rust literally by `Arc<T>` \
970(or reference to `Arc<T>`) for any sharing to occur.",
971                class.name
972            ),
973        )),
974    }
975}
976
977fn generate_wrapper_constructor_for_box(
978    class: &ForeignClassInfo,
979    returned_smart_pointer: &SmartPointerInfo,
980    rust_call: TokenStream,
981    method_span: Span,
982    src_id: SourceId,
983) -> Result<TokenStream> {
984    match returned_smart_pointer.pointer_type {
985        PointerType::None => Ok(rust_call),
986        PointerType::Box => Ok(quote! {(*#rust_call)}),
987        _ => Err(DiagnosticError::new(
988            src_id,
989            method_span,
990            format!(
991                "Unsupported conversion for smart pointer. \
992Foreigner class {} is stored as `Box`, and can be returned eiter as `Box` or bare type",
993                class.name
994            ),
995        )),
996    }
997}
998
999fn if_exported_class_generate_argument_conversion(
1000    rust_type: &RustType,
1001    conv_map: &mut TypeMap,
1002    arg_name_ident: &TokenStream,
1003    method_span: Span,
1004    src_id: SourceId,
1005    reference_allowed: bool,
1006) -> Result<Option<(Type, TokenStream)>> {
1007    let (reference_type, rust_type_unref) =
1008        get_reference_info_and_inner_type(rust_type, conv_map, src_id);
1009    let smart_pointer_info = smart_pointer(&rust_type_unref, conv_map, src_id);
1010    let class = match conv_map
1011        .find_foreigner_class_with_such_this_type(&smart_pointer_info.inner_ty.ty, |_, ft| {
1012            ft.self_desc.as_ref().map(|x| x.self_type.clone())
1013        }) {
1014        Some(fc) => fc.clone(),
1015        None => return Ok(None),
1016    };
1017    let class_smart_pointer = storage_smart_pointer_for_class(&class, conv_map)?;
1018    let class_name = class.name.to_string();
1019    let py_mod_str = py_wrapper_mod_name(&class_name);
1020    let py_mod: Ident = parse(&py_mod_str, src_id)?;
1021    let py_type: Type = if reference_allowed {
1022        parse(&format!("&super::{}::{}", &py_mod_str, &class_name), src_id)?
1023    } else {
1024        parse(&format!("super::{}::{}", &py_mod_str, &class_name), src_id)?
1025    };
1026
1027    let rust_instance_code = if reference_allowed {
1028        quote! {
1029            super::#py_mod::rust_instance(#arg_name_ident, py)
1030        }
1031    } else {
1032        quote! {
1033            super::#py_mod::rust_instance(&#arg_name_ident, py)
1034        }
1035    };
1036    let deref_code = match class_smart_pointer.pointer_type {
1037        PointerType::ArcMutex => generate_deref_for_arc_mutex(
1038            &class,
1039            smart_pointer_info.pointer_type,
1040            reference_type,
1041            rust_instance_code,
1042            method_span,
1043            src_id,
1044        )?,
1045        PointerType::Arc => generate_deref_for_arc(
1046            &class,
1047            smart_pointer_info.pointer_type,
1048            reference_type,
1049            rust_instance_code,
1050            method_span,
1051            src_id,
1052        )?,
1053        PointerType::Mutex => generate_deref_for_mutex(
1054            &class,
1055            smart_pointer_info.pointer_type,
1056            reference_type,
1057            rust_instance_code,
1058            method_span,
1059            src_id,
1060        )?,
1061        PointerType::Box => generate_deref_for_box(
1062            &class,
1063            smart_pointer_info.pointer_type,
1064            reference_type,
1065            rust_instance_code,
1066            method_span,
1067            src_id,
1068        )?,
1069        _ => unreachable!("Class stored as None"),
1070    };
1071
1072    Ok(Some((py_type, deref_code)))
1073}
1074
1075fn generate_deref_for_mutex(
1076    class: &ForeignClassInfo,
1077    arg_smart_pointer: PointerType,
1078    arg_reference: Reference,
1079    rust_instance_code: TokenStream,
1080    method_span: Span,
1081    src_id: SourceId,
1082) -> Result<TokenStream> {
1083    match arg_smart_pointer {
1084        PointerType::None => match arg_reference {
1085            Reference::MutRef => Ok(quote!{(&mut *#rust_instance_code.lock().unwrap())}),
1086            Reference::Ref => Ok(quote!{(&*#rust_instance_code.lock().unwrap())}),
1087            Reference::None => append_clone_if_supported(class, quote!{*#rust_instance_code.lock().unwrap()}, method_span),
1088        },
1089        PointerType::Mutex => match arg_reference {
1090            Reference::Ref => Ok(rust_instance_code),
1091            _ => Err(DiagnosticError::new(
1092                src_id,
1093                method_span,
1094                "Mutex can be passed to function only by const reference `Mutex`",
1095            ))
1096        }
1097        _ => Err(DiagnosticError::new(
1098            src_id,
1099            method_span,
1100            format!(
1101                "Unsupported conversion for smart pointer. \
1102Foreigner class {} is stored as `Mutex` and can be passed to function either as `Mutex` reference or bare type",
1103                class.name
1104            ),
1105        ))
1106    }
1107}
1108
1109fn generate_deref_for_arc_mutex(
1110    class: &ForeignClassInfo,
1111    arg_smart_pointer: PointerType,
1112    arg_reference: Reference,
1113    rust_instance_code: TokenStream,
1114    method_span: Span,
1115    src_id: SourceId,
1116) -> Result<TokenStream> {
1117    match arg_smart_pointer {
1118        PointerType::None => match arg_reference {
1119            Reference::MutRef => Ok(quote!{(&mut *#rust_instance_code.lock().unwrap())}),
1120            Reference::Ref => Ok(quote!{(&*#rust_instance_code.lock().unwrap())}),
1121            Reference::None => append_clone_if_supported(class, quote!{*#rust_instance_code.lock().unwrap()}, method_span),
1122        },
1123        PointerType::ArcMutex => match arg_reference {
1124            Reference::Ref => Ok(rust_instance_code),
1125            Reference::None => Ok(quote!{#rust_instance_code.clone()}),
1126            _ => Err(DiagnosticError::new(
1127                src_id,
1128                method_span,
1129                "Arc<Mutex<T>> can't be passed to function by mut reference. It doesn't make sense anyway.",
1130            ))
1131        }
1132        _ => Err(DiagnosticError::new(
1133            src_id,
1134            method_span,
1135            format!(
1136                "Unsupported conversion for smart pointer. \
1137Foreigner class {} is stored as `Arc<Mutex<T>>` and can be passed to function either as `Arc<Mutex<T>>` or bare type",
1138                class.name
1139            ),
1140        ))
1141    }
1142}
1143
1144fn generate_deref_for_arc(
1145    class: &ForeignClassInfo,
1146    arg_smart_pointer: PointerType,
1147    arg_reference: Reference,
1148    rust_instance_code: TokenStream,
1149    method_span: Span,
1150    src_id: SourceId,
1151) -> Result<TokenStream> {
1152    match arg_smart_pointer {
1153        PointerType::None => match arg_reference {
1154            Reference::Ref => Ok(quote!{(&*#rust_instance_code)}),
1155            Reference::None => append_clone_if_supported(class, quote!{*#rust_instance_code}, method_span),
1156            Reference::MutRef => Err(DiagnosticError::new(
1157                src_id,
1158                method_span,
1159                "Object is stored in `Arc`, so it is immutable. If you need mutability, use Arc<Mutex<T>> for constructor type",
1160            ))
1161        },
1162        PointerType::Arc => match arg_reference {
1163            Reference::Ref => Ok(rust_instance_code),
1164            Reference::None => Ok(quote!{#rust_instance_code.clone()}),
1165            _ => Err(DiagnosticError::new(
1166                src_id,
1167                method_span,
1168                "`Arc` can't be passed to function by mut reference. It's immutable",
1169            ))
1170        }
1171        _ => Err(DiagnosticError::new(
1172            src_id,
1173            method_span,
1174            format!(
1175                "Unsupported conversion for smart pointer. \
1176Foreigner class {} is stored as `Arc` and can be passed to function either as `Arc` or a bare type",
1177                class.name
1178            ),
1179        ))
1180    }
1181}
1182
1183fn generate_deref_for_box(
1184    class: &ForeignClassInfo,
1185    arg_smart_pointer: PointerType,
1186    arg_reference: Reference,
1187    rust_instance_code: TokenStream,
1188    method_span: Span,
1189    src_id: SourceId,
1190) -> Result<TokenStream> {
1191    match arg_smart_pointer {
1192        PointerType::None => match arg_reference {
1193            Reference::Ref => Ok(rust_instance_code),
1194            Reference::None => append_clone_if_supported(class, rust_instance_code, method_span),
1195            Reference::MutRef => Err(DiagnosticError::new(
1196                src_id,
1197                method_span,
1198                "Object is stored in `Box`, so it is immutable. If you need mutability, use `Mutex` for constructor type",
1199            ))
1200        },
1201        _ => Err(DiagnosticError::new(
1202            src_id,
1203            method_span,
1204            format!(
1205                "Unsupported conversion for smart pointer. \
1206Foreigner class {} is stored as `Box` and can be passed to function anly as a bare type",
1207                class.name
1208            ),
1209        ))
1210    }
1211}
1212
1213fn append_clone_if_supported(
1214    class: &ForeignClassInfo,
1215    rust_instance_code: TokenStream,
1216    method_span: Span,
1217) -> Result<TokenStream> {
1218    if class.clone_derived() || class.copy_derived() {
1219        Ok(quote!((#rust_instance_code).clone()))
1220    } else {
1221        Err(DiagnosticError::new(
1222            class.src_id,
1223            method_span,
1224            "Passing object by value requires that it is marked with `#[derive(Clone)]` or `#[derive(Copy)]`\
1225inside its `foreigner_class` macro."
1226        ))
1227    }
1228}
1229
1230fn storage_smart_pointer_for_class(
1231    class: &ForeignClassInfo,
1232    conv_map: &mut TypeMap,
1233) -> Result<SmartPointerInfo> {
1234    if let Some(ref self_desc) = class.self_desc {
1235        let constructor_ret_rust_type =
1236            conv_map.find_or_alloc_rust_type(&self_desc.constructor_ret_type, class.src_id);
1237        let pointer = smart_pointer(&constructor_ret_rust_type, conv_map, class.src_id);
1238        match pointer.pointer_type {
1239            // Default wrapper type for storage is `Mutex`.
1240            PointerType::None => Ok(SmartPointerInfo::new(PointerType::Mutex, pointer.inner_ty)),
1241            _ => Ok(pointer),
1242        }
1243    } else {
1244        Err(DiagnosticError::new(
1245            class.src_id,
1246            class.span(),
1247            "Class doesn't define a type returned from constructor, nor self_type, but is not static"
1248        ))
1249    }
1250}
1251
1252fn wrap_type_for_class(self_type: &Type, storage_pointer: PointerType) -> TokenStream {
1253    match storage_pointer {
1254        PointerType::ArcMutex => quote! {std::sync::Arc<std::sync::Mutex<super::#self_type>>},
1255        PointerType::Arc => quote! {std::sync::Arc<super::#self_type>},
1256        PointerType::Mutex => quote! {std::sync::Mutex<super::#self_type>},
1257        PointerType::Box => quote! {super::#self_type},
1258        PointerType::None => unreachable!("None pointer for object storage"),
1259    }
1260}
1261
1262fn parse<T: syn::parse::Parse>(ident_str: &str, src_id: SourceId) -> Result<T> {
1263    syn::parse_str::<T>(ident_str).map_err(|err| DiagnosticError::from_syn_err(src_id, err))
1264}
1265
1266fn if_type_slice_return_elem_type(ty: &Type, accept_mutbl_slice: bool) -> Option<&Type> {
1267    if let syn::Type::Reference(syn::TypeReference {
1268        ref elem,
1269        mutability,
1270        ..
1271    }) = ty
1272    {
1273        if mutability.is_some() && !accept_mutbl_slice {
1274            return None;
1275        }
1276        if let syn::Type::Slice(syn::TypeSlice { ref elem, .. }) = **elem {
1277            Some(elem)
1278        } else {
1279            None
1280        }
1281    } else {
1282        None
1283    }
1284}
1285
1286fn if_vec_return_elem_type(ty: &RustType) -> Option<Type> {
1287    let from_ty: Type = parse_quote! { Vec<T> };
1288    let to_ty: Type = parse_quote! { T };
1289    let generic_params: syn::Generics = parse_quote! { <T> };
1290
1291    GenericTypeConv::new(from_ty, to_ty, generic_params, TypeConvCode::invalid())
1292        .is_conv_possible(ty, None, |_| None)
1293        .map(|x| x.to_ty)
1294}