py_rs_macros/
lib.rs

1#![macro_use]
2#![deny(unused)]
3
4use std::collections::{HashMap, HashSet};
5
6use proc_macro2::{Ident, TokenStream};
7use quote::{format_ident, quote};
8use syn::{
9    parse_quote, spanned::Spanned, ConstParam, GenericParam, Generics, Item, LifetimeParam, Path,
10    Result, Type, TypeArray, TypeParam, TypeParen, TypePath, TypeReference, TypeSlice, TypeTuple,
11    WhereClause, WherePredicate,
12};
13
14use crate::{deps::Dependencies, utils::format_generics};
15
16#[macro_use]
17mod utils;
18mod attr;
19mod deps;
20mod types;
21
22#[derive(Default, Clone)]
23struct EnumDef {
24    pub variant_names: Vec<String>,
25    pub test_str: TokenStream,
26    pub num_variant_classes: usize,
27}
28
29struct DerivedPY {
30    crate_rename: Path,
31    py_name: String,
32    docs: String,
33    inline: TokenStream,
34    inline_flattened: Option<TokenStream>,
35    dependencies: Dependencies,
36    concrete: HashMap<Ident, Type>,
37    bound: Option<Vec<WherePredicate>>,
38    enum_def: Option<EnumDef>,
39    export: bool,
40    export_to: Option<String>,
41}
42
43impl DerivedPY {
44    fn into_impl(mut self, rust_ty: Ident, generics: Generics) -> TokenStream {
45        let export = self
46            .export
47            .then(|| self.generate_export_test(&rust_ty, &generics));
48
49        let output_path_fn = {
50            let path = match self.export_to.as_deref() {
51                Some(dirname) if dirname.ends_with('/') => {
52                    format!("{}{}.py", dirname, self.py_name)
53                }
54                Some(filename) => filename.to_owned(),
55                None => format!("{}.py", self.py_name),
56            };
57
58            quote! {
59                fn output_path() -> Option<&'static std::path::Path> {
60                    Some(std::path::Path::new(#path))
61                }
62            }
63        };
64
65        let docs = match &*self.docs {
66            "" => None,
67            docs => Some(quote!(const DOCS: Option<&'static str> = Some(#docs);)),
68        };
69
70        let crate_rename = self.crate_rename.clone();
71
72        let ident = self.py_name.clone();
73        let impl_start = generate_impl_block_header(
74            &crate_rename,
75            &rust_ty,
76            &generics,
77            self.bound.as_deref(),
78            &self.dependencies,
79        );
80        let assoc_type = generate_assoc_type(&rust_ty, &crate_rename, &generics, &self.concrete);
81        let name = self.generate_name_fn(&generics);
82        let inline = self.generate_inline_fn();
83        let decl = self.generate_decl_fn(&rust_ty, &generics);
84        let dependencies = &self.dependencies;
85        let generics_fn = self.generate_generics_fn(&generics);
86        let enum_decl = self.generate_variant_classes_decl();
87
88        quote! {
89            #impl_start {
90                #assoc_type
91
92                fn ident() -> String {
93                    #ident.to_owned()
94                }
95                #enum_decl
96                #docs
97                #name
98                #decl
99                #inline
100                #generics_fn
101                #output_path_fn
102
103                fn visit_dependencies(v: &mut impl #crate_rename::TypeVisitor)
104                where
105                    Self: 'static,
106                {
107                    #dependencies
108                }
109            }
110
111            #export
112        }
113    }
114
115    /// Returns an expression which evaluates to the TypeScript name of the type, including generic
116    /// parameters.
117    fn name_with_generics(&self, generics: &Generics) -> TokenStream {
118        let name = &self.py_name;
119        let crate_rename = &self.crate_rename;
120        let mut generics_py_names = generics
121            .type_params()
122            .filter(|ty| !self.concrete.contains_key(&ty.ident))
123            .map(|ty| &ty.ident)
124            .map(|generic| quote!(<#generic as #crate_rename::PY>::name()))
125            .peekable();
126
127        if generics_py_names.peek().is_some() {
128            quote! {
129                format!("{}<{}>", #name, vec![#(#generics_py_names),*].join(", "))
130            }
131        } else {
132            quote!(#name.to_owned())
133        }
134    }
135
136    /// Generate a dummy unit struct for every generic type parameter of this type.
137    /// # Example:
138    /// ```compile_fail
139    /// struct Generic<A, B, const C: usize> { /* ... */ }
140    /// ```
141    /// has two generic type parameters, `A` and `B`. This function will therefore generate
142    /// ```compile_fail
143    /// struct A;
144    /// impl py_rs::PY for A { /* .. */ }
145    ///
146    /// struct B;
147    /// impl py_rs::PY for B { /* .. */ }
148    /// ```
149    fn generate_generic_types(&self, generics: &Generics) -> TokenStream {
150        let crate_rename = &self.crate_rename;
151        let generics = generics
152            .type_params()
153            .filter(|ty| !self.concrete.contains_key(&ty.ident))
154            .map(|ty| ty.ident.clone());
155        let name = quote![<Self as #crate_rename::PY>::name()];
156        quote! {
157            #(
158                #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
159                struct #generics;
160                impl std::fmt::Display for #generics {
161                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162                        write!(f, "{:?}", self)
163                    }
164                }
165                impl #crate_rename::PY for #generics {
166                    type WithoutGenerics = #generics;
167                    fn name() -> String { stringify!(#generics).to_owned() }
168                    fn inline() -> String { panic!("{} cannot be inlined", #name) }
169                    fn inline_flattened() -> String { stringify!(#generics).to_owned() }
170                    fn decl() -> String { panic!("{} cannot be declared", #name) }
171                    fn decl_concrete() -> String { panic!("{} cannot be declared", #name) }
172                    fn variant_classes_decl() -> String {
173                        panic!("{} cannot be declared", #name)
174                    }
175                }
176            )*
177        }
178    }
179
180    /// generate the variant class decl function
181    fn generate_variant_classes_decl(&self) -> TokenStream {
182        if let Some(enum_def) = self.enum_def.clone() {
183            let name = &self.py_name;
184            let variant_text = enum_def
185                .variant_names
186                .iter()
187                .map(|i| format!("{} = \"{}\"", i, i))
188                .collect::<Vec<String>>()
189                .join("\n\t");
190            if enum_def.num_variant_classes > 0 {
191                let variant_classes = enum_def.test_str;
192                return quote! {
193                    fn variant_classes_decl() -> String {
194                        let variant_classes = #variant_classes;
195                        let variants = format!("{}", #variant_text); // TODO get the variants and put them here
196                        let enum_str = format!("class {}Identifier(StrEnum):\n\t{variants}\n\n{variant_classes}", #name);
197                        enum_str
198                    }
199                };
200            }
201
202            quote! {
203                fn variant_classes_decl() -> String {
204                    let variants = format!("{}", #variant_text); // TODO get the variants and put them here
205                    let enum_str = format!("class {}Identifier(StrEnum):\n\t{variants}\n", #name);
206                    enum_str
207                }
208            }
209        } else {
210            return quote! {
211                fn variant_classes_decl() -> String {
212                    String::new()
213                }
214            };
215        }
216    }
217
218    fn generate_export_test(&self, rust_ty: &Ident, generics: &Generics) -> TokenStream {
219        let test_fn = format_ident!(
220            "export_bindings_{}_py",
221            rust_ty.to_string().to_lowercase().replace("r#", "")
222        );
223        let crate_rename = &self.crate_rename;
224        let generic_params = generics
225            .type_params()
226            .map(|ty| match self.concrete.get(&ty.ident) {
227                None => quote! { #crate_rename::Dummy },
228                Some(ty) => quote! { #ty },
229            });
230        let ty = quote!(<#rust_ty<#(#generic_params),*> as #crate_rename::PY>);
231
232        quote! {
233            #[cfg(test)]
234            #[test]
235            fn #test_fn() {
236                #ty::export_all().expect("could not export type");
237            }
238        }
239    }
240
241    fn generate_generics_fn(&self, generics: &Generics) -> TokenStream {
242        let crate_rename = &self.crate_rename;
243        let generics = generics
244            .type_params()
245            .filter(|ty| !self.concrete.contains_key(&ty.ident))
246            .map(|TypeParam { ident, .. }| {
247                quote![
248                    v.visit::<#ident>();
249                    <#ident as #crate_rename::PY>::visit_generics(v);
250                ]
251            });
252        quote! {
253            fn visit_generics(v: &mut impl #crate_rename::TypeVisitor)
254            where
255                Self: 'static,
256            {
257                #(#generics)*
258            }
259        }
260    }
261
262    fn generate_name_fn(&self, generics: &Generics) -> TokenStream {
263        let name = self.name_with_generics(generics);
264        quote! {
265            fn name() -> String {
266                #name
267            }
268        }
269    }
270
271    fn generate_inline_fn(&self) -> TokenStream {
272        let inline = &self.inline;
273        let crate_rename = &self.crate_rename;
274
275        let inline_flattened = self.inline_flattened.as_ref().map_or_else(
276            || {
277                quote! {
278                    fn inline_flattened() -> String {
279                        panic!("{} cannot be flattened", <Self as #crate_rename::PY>::name())
280                    }
281                }
282            },
283            |inline_flattened| {
284                quote! {
285                    fn inline_flattened() -> String {
286                        #inline_flattened
287                    }
288                }
289            },
290        );
291        let inline = quote! {
292            fn inline() -> String {
293                #inline
294            }
295        };
296        quote! {
297            #inline
298            #inline_flattened
299        }
300    }
301
302    /// Generates the `decl()` and `decl_concrete()` methods.
303    /// `decl_concrete()` is simple, and simply defers to `inline()`.
304    /// For `decl()`, however, we need to change out the generic parameters of the type, replacing
305    /// them with the dummy types generated by `generate_generic_types()`.
306    fn generate_decl_fn(&mut self, rust_ty: &Ident, generics: &Generics) -> TokenStream {
307        let name = &self.py_name;
308        let crate_rename = &self.crate_rename;
309        let generic_types = self.generate_generic_types(generics);
310        let py_generics = format_generics(
311            &mut self.dependencies,
312            crate_rename,
313            generics,
314            &self.concrete,
315        );
316
317        use GenericParam as G;
318        // These are the generic parameters we'll be using.
319        let generic_idents = generics.params.iter().filter_map(|p| match p {
320            G::Lifetime(_) => None,
321            G::Type(TypeParam { ident, .. }) => match self.concrete.get(ident) {
322                // Since we named our dummy types the same as the generic parameters, we can just keep
323                // the identifier of the generic parameter - its name is shadowed by the dummy struct.
324                None => Some(quote!(#ident)),
325                // If the type parameter is concrete, we use the type the user provided using
326                // `#[py(concrete)]`
327                Some(concrete) => Some(quote!(#concrete)),
328            },
329            // We keep const parameters as they are, since there's no sensible default value we can
330            // use instead. This might be something to change in the future.
331            G::Const(ConstParam { ident, .. }) => Some(quote!(#ident)),
332        });
333
334        if let Some(_) = self.enum_def.clone() {
335            quote! {
336                    fn decl_concrete() -> String {
337                        format!("{}\n\n{} = {}", <Self as #crate_rename::PY>::variant_classes_decl(), #name, <Self as #crate_rename::PY>::inline())
338                }
339                fn decl() -> String { // TODO we need to handle the case where the type is a enum or a struct differently
340                    #generic_types
341                    let inline = <#rust_ty<#(#generic_idents,)*> as #crate_rename::PY>::inline();
342                    let generics = #py_generics;
343                    format!("{}\n\n{}{generics} = {inline}", <Self as #crate_rename::PY>::variant_classes_decl(), #name)
344                }
345            }
346        } else {
347            let docs = self.docs.clone();
348            quote! {
349                    fn decl_concrete() -> String {
350                        format!("\nclass {}(BaseModel):\n\t{}\n\t{}", #name, #docs, <Self as #crate_rename::PY>::inline())
351                }
352                fn decl() -> String { // TODO we need to handle the case where the type is a enum or a struct differently
353                    #generic_types
354                    let inline = <#rust_ty<#(#generic_idents,)*> as #crate_rename::PY>::inline();
355                    let generics = #py_generics;
356                    format!("\nclass {}{generics}(BaseModel):\n{}\n\t{inline}", #name, #docs)
357                }
358            }
359        }
360    }
361}
362
363fn generate_assoc_type(
364    rust_ty: &Ident,
365    crate_rename: &Path,
366    generics: &Generics,
367    concrete: &HashMap<Ident, Type>,
368) -> TokenStream {
369    use GenericParam as G;
370
371    let generics_params = generics.params.iter().map(|x| match x {
372        G::Type(ty) => match concrete.get(&ty.ident) {
373            None => quote! { #crate_rename::Dummy },
374            Some(ty) => quote! { #ty },
375        },
376        G::Const(ConstParam { ident, .. }) => quote! { #ident },
377        G::Lifetime(LifetimeParam { lifetime, .. }) => quote! { #lifetime },
378    });
379
380    quote! { type WithoutGenerics = #rust_ty<#(#generics_params),*>; }
381    // This error is not actually breaking the build
382}
383
384// generate start of the `impl PY for #ty` block, up to (excluding) the open brace
385fn generate_impl_block_header(
386    crate_rename: &Path,
387    ty: &Ident,
388    generics: &Generics,
389    bounds: Option<&[WherePredicate]>,
390    dependencies: &Dependencies,
391) -> TokenStream {
392    use GenericParam as G;
393
394    let params = generics.params.iter().map(|param| match param {
395        G::Type(TypeParam {
396            ident,
397            colon_token,
398            bounds,
399            ..
400        }) => quote!(#ident #colon_token #bounds),
401        G::Lifetime(LifetimeParam {
402            lifetime,
403            colon_token,
404            bounds,
405            ..
406        }) => quote!(#lifetime #colon_token #bounds),
407        G::Const(ConstParam {
408            const_token,
409            ident,
410            colon_token,
411            ty,
412            ..
413        }) => quote!(#const_token #ident #colon_token #ty),
414    });
415    let type_args = generics.params.iter().map(|param| match param {
416        G::Type(TypeParam { ident, .. }) | G::Const(ConstParam { ident, .. }) => quote!(#ident),
417        G::Lifetime(LifetimeParam { lifetime, .. }) => quote!(#lifetime),
418    });
419
420    let where_bound = match bounds {
421        Some(bounds) => quote! { where #(#bounds),* },
422        None => {
423            let bounds = generate_where_clause(crate_rename, generics, dependencies);
424            quote! { #bounds }
425        }
426    };
427
428    quote!(impl <#(#params),*> #crate_rename::PY for #ty <#(#type_args),*> #where_bound)
429}
430
431fn generate_where_clause(
432    crate_rename: &Path,
433    generics: &Generics,
434    dependencies: &Dependencies,
435) -> WhereClause {
436    let used_types = {
437        let is_type_param = |id: &Ident| generics.type_params().any(|p| &p.ident == id);
438
439        let mut used_types = HashSet::new();
440        for ty in dependencies.used_types() {
441            used_type_params(&mut used_types, ty, is_type_param);
442        }
443        used_types.into_iter()
444    };
445
446    let existing = generics.where_clause.iter().flat_map(|w| &w.predicates);
447    parse_quote! {
448        where #(#existing,)* #(#used_types: #crate_rename::PY),*
449    }
450}
451
452// Extracts all type parameters which are used within the given type.
453// Associated types of a type parameter are extracted as well.
454// Note: This will not extract `I` from `I::Item`, but just `I::Item`!
455fn used_type_params<'ty, 'out>(
456    out: &'out mut HashSet<&'ty Type>,
457    ty: &'ty Type,
458    is_type_param: impl Fn(&'ty Ident) -> bool + Copy + 'out,
459) {
460    use syn::{
461        AngleBracketedGenericArguments as GenericArgs, GenericArgument as G, PathArguments as P,
462    };
463
464    match ty {
465        Type::Array(TypeArray { elem, .. })
466        | Type::Paren(TypeParen { elem, .. })
467        | Type::Reference(TypeReference { elem, .. })
468        | Type::Slice(TypeSlice { elem, .. }) => used_type_params(out, elem, is_type_param),
469        Type::Tuple(TypeTuple { elems, .. }) => elems
470            .iter()
471            .for_each(|elem| used_type_params(out, elem, is_type_param)),
472        Type::Path(TypePath { qself: None, path }) => {
473            let first = path.segments.first().unwrap();
474            if is_type_param(&first.ident) {
475                // The type is either a generic parameter (e.g `T`), or an associated type of that
476                // generic parameter (e.g `I::Item`). Either way, we return it.
477                out.insert(ty);
478                return;
479            }
480
481            let last = path.segments.last().unwrap();
482            if let P::AngleBracketed(GenericArgs { ref args, .. }) = last.arguments {
483                for generic in args {
484                    if let G::Type(ty) = generic {
485                        used_type_params(out, ty, is_type_param);
486                    }
487                }
488            }
489        }
490        _ => (),
491    }
492}
493
494/// Derives [PY](./trait.PY.html) for a struct or enum.
495/// Please take a look at [PY](./trait.PY.html) for documentation.
496#[proc_macro_derive(PY, attributes(py))]
497pub fn python(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
498    match entry(input) {
499        Err(err) => err.to_compile_error(),
500        Ok(result) => result,
501    }
502    .into()
503}
504
505fn entry(input: proc_macro::TokenStream) -> Result<TokenStream> {
506    let input = syn::parse::<Item>(input)?;
507    let (py, ident, generics) = match input {
508        Item::Struct(s) => (types::struct_def(&s)?, s.ident, s.generics),
509        Item::Enum(e) => (types::enum_def(&e)?, e.ident, e.generics),
510        _ => syn_err!(input.span(); "unsupported item"),
511    };
512
513    Ok(py.into_impl(ident, generics))
514}