postcard_bindgen_core/code_gen/python/mod.rs

1mod des;
2mod general;
3mod generateable;
4mod ser;
5mod type_checks;
6
7use core::borrow::Borrow;
8
9use convert_case::{Case, Casing};
10use des::{gen_des_functions, gen_deserialize_func, gen_deserializer_code};
11use genco::{lang::python::Python, quote, quote_in, tokens::FormatInto};
12use general::gen_util;
13use generateable::{gen_basic_typings, gen_typings};
14use ser::{gen_ser_functions, gen_serialize_func, gen_serializer_code};
15use type_checks::gen_type_checks;
16
17use crate::{
18    code_gen::import_registry::ImportMode, path::PathBuf, registry::ContainerCollection, Exports,
19};
20
21use super::{
22    import_registry::{ImportItem, Package},
23    utils::{IfBranchedTemplate, TokensBranchedIterExt, TokensIterExt},
24};
25
/// Python keyword joining two conditions with a logical OR.
const PYTHON_LOGIC_OR: &str = "or";
/// Python keyword joining two conditions with a logical AND.
const PYTHON_LOGIC_AND: &str = "and";
/// Name of the root variable a default [`VariablePath`] starts from.
const PYTHON_OBJECT_VARIABLE: &str = "v";
29
30type Tokens = genco::lang::python::Tokens;
31
32type VariablePath = super::variable_path::VariablePath<Python>;
33type VariableAccess = super::variable_path::VariableAccess;
34type FieldAccessor<'a> = super::field_accessor::FieldAccessor<'a>;
35type AvailableCheck = super::available_check::AvailableCheck<Python>;
36type ImportRegistry = super::import_registry::ImportRegistry;
37type ExportFile = crate::ExportFile<Python>;
38type FunctionArg = super::function::FunctionArg<Python>;
39type Function = super::function::Function<Python>;
40
/// Settings for Python bindings generation.
///
/// This enables or disables serialization, deserialization, runtime type checks
/// and the module structure of the generated Python package.
/// Less code will be generated if an option is off.
///
/// By default, deserialization and the module structure are enabled, while
/// serialization and runtime type checks are disabled (see [`Default`]).
/// Serialization can be enabled by using [`GenerationSettings::serialization()`].
/// Deserialization can be disabled with [`GenerationSettings::deserialization()`].
/// To enable all at once use [`GenerationSettings::enable_all()`].
#[derive(Debug, Clone)]
pub struct GenerationSettings {
    /// Generate the `serializer` and `ser` modules and export `serialize`.
    ser: bool,
    /// Generate the `deserializer` and `des` modules and export `deserialize`.
    des: bool,
    /// Generate the `runtime_checks` module; also forwarded to the serialize function generator.
    runtime_type_checks: bool,
    /// Keep the rust module structure instead of flattening all types.
    module_structure: bool,
}

impl GenerationSettings {
    /// Constructs [`GenerationSettings`] with every option enabled.
    #[must_use]
    pub fn enable_all() -> Self {
        Self {
            ser: true,
            des: true,
            runtime_type_checks: true,
            module_structure: true,
        }
    }

    /// Enables or disables generation of serialization code.
    #[must_use]
    pub fn serialization(mut self, enabled: bool) -> Self {
        self.ser = enabled;
        self
    }

    /// Enables or disables generation of deserialization code.
    #[must_use]
    pub fn deserialization(mut self, enabled: bool) -> Self {
        self.des = enabled;
        self
    }

    /// Enables or disables generation of runtime type checks.
    ///
    /// Disabling this should lead to a speed increase at serialization.
    #[must_use]
    pub fn runtime_type_checks(mut self, enabled: bool) -> Self {
        self.runtime_type_checks = enabled;
        self
    }

    /// Enables or disables generation of the rust module structure.
    ///
    /// When enabled, types are generated in the same module structure as in
    /// rust: root level types are placed directly in the generated `types`
    /// package, while types nested in rust modules are placed in subpackages
    /// (e.g. `types.<mod_name>.<type_name>`). This avoids name clashes between
    /// equally named types from different modules.
    ///
    /// When disabled, the module structure is flattened and all types are
    /// generated at the root of `types`.
    #[must_use]
    pub fn module_structure(mut self, enabled: bool) -> Self {
        self.module_structure = enabled;
        self
    }
}

impl Default for GenerationSettings {
    /// Enables deserialization and the module structure only.
    fn default() -> Self {
        Self {
            ser: false,
            des: true,
            runtime_type_checks: false,
            module_structure: true,
        }
    }
}
110        }
111    }
112}
113
114pub fn generate(
115    mut containers: ContainerCollection,
116    gen_settings: impl Borrow<GenerationSettings>,
117    generate_package_name: String,
118) -> Exports<Python> {
119    let generate_package_name = generate_package_name.to_case(Case::Snake);
120    let gen_settings = gen_settings.borrow();
121
122    if !gen_settings.module_structure {
123        containers.flatten();
124    }
125
126    let mut files = Vec::new();
127
128    files.push(ExportFile {
129        content_type: "util".to_owned(),
130        content: gen_util(),
131    });
132
133    files.push(ExportFile {
134        content_type: "basic_types".to_owned(),
135        content: gen_basic_typings(),
136    });
137
138    files.extend(gen_typings(&containers, generate_package_name.clone()));
139
140    if gen_settings.runtime_type_checks {
141        let type_checks = gen_type_checks(containers.all_containers());
142
143        let type_checks = quote! {
144            from .util import *
145            from .types import *
146
147            $type_checks
148        };
149
150        files.push(ExportFile {
151            content_type: "runtime_checks".to_owned(),
152            content: type_checks,
153        });
154    }
155
156    if gen_settings.ser {
157        let serializer_code = gen_serializer_code();
158        let ser_code = quote! {
159            from typing import Union
160
161            from .types import *
162            from .util import *
163            from .serializer import Serializer
164
165            $(gen_ser_functions(containers.all_containers()))
166
167            $(gen_serialize_func(containers.all_containers(), gen_settings.runtime_type_checks))
168        };
169
170        files.push(ExportFile {
171            content_type: "serializer".to_owned(),
172            content: serializer_code,
173        });
174
175        files.push(ExportFile {
176            content_type: "ser".to_owned(),
177            content: ser_code,
178        });
179    }
180
181    if gen_settings.des {
182        let deserializer_code = gen_deserializer_code();
183        let des_code = quote! {
184            from typing import TypeVar, Type, cast, Tuple
185
186            from .types import *
187            from .util import *
188            from .deserializer import Deserializer
189
190            $(gen_des_functions(containers.all_containers()))
191
192            $(gen_deserialize_func(containers.all_containers()))
193        };
194
195        files.push(ExportFile {
196            content_type: "deserializer".to_owned(),
197            content: deserializer_code,
198        });
199
200        files.push(ExportFile {
201            content_type: "des".to_owned(),
202            content: des_code,
203        });
204    }
205
206    let mut import_registry = ImportRegistry::new(generate_package_name);
207    import_registry.push(Package::Relative("types".into()), ImportItem::All);
208    import_registry.push(Package::Relative("basic_types".into()), ImportItem::All);
209
210    if gen_settings.des {
211        import_registry.push(
212            Package::Relative("des".into()),
213            ImportItem::Single("deserialize".into()),
214        );
215    }
216
217    if gen_settings.ser {
218        import_registry.push(
219            Package::Relative("ser".into()),
220            ImportItem::Single("serialize".into()),
221        );
222    }
223
224    files.push(ExportFile {
225        content_type: "__init__".to_owned(),
226        content: quote!($import_registry),
227    });
228
229    Exports { files }
230}
231
232impl<I, F> TokensIterExt<Python, F> for I
233where
234    I: Iterator<Item = F>,
235    F: FormatInto<Python>,
236{
237    const LOGICAL_AND: &'static str = PYTHON_LOGIC_AND;
238    const LOGICAL_OR: &'static str = PYTHON_LOGIC_OR;
239}
240
241pub(super) struct BranchedTemplate;
242
243impl IfBranchedTemplate<Python> for BranchedTemplate {
244    const IF_BRANCH: &'static str = "if";
245    const IF_ELSE_BRANCH: &'static str = "elif";
246    const ELSE_BRANCH: &'static str = "else";
247
248    fn push_condition(tokens: &mut Tokens, condition: impl FormatInto<Python>) {
249        tokens.append(condition)
250    }
251
252    fn push_condition_block(tokens: &mut Tokens, body: impl FormatInto<Python>) {
253        tokens.append(":");
254        tokens.indent();
255        tokens.append(body);
256        tokens.unindent();
257    }
258}
259
260impl<I> TokensBranchedIterExt<Python> for I
261where
262    I: Iterator<Item = (Option<Tokens>, Tokens)>,
263{
264    type Template = BranchedTemplate;
265}
266
267impl FormatInto<Python> for FieldAccessor<'_> {
268    fn format_into(self, tokens: &mut Tokens) {
269        quote_in! { *tokens =>
270            $(match self {
271                Self::Array | Self::None => (),
272                Self::Object(n) => $n = $[' '],
273            })
274        }
275    }
276}
277
278impl FormatInto<Python> for VariablePath {
279    fn format_into(self, tokens: &mut genco::Tokens<Python>) {
280        quote_in! { *tokens =>
281            $(self.start_variable)
282        }
283        self.parts
284            .into_iter()
285            .for_each(|part| part.format_into(tokens))
286    }
287}
288
289impl Default for VariablePath {
290    fn default() -> Self {
291        Self::new(PYTHON_OBJECT_VARIABLE.to_owned())
292    }
293}
294
295impl FormatInto<Python> for VariableAccess {
296    fn format_into(self, tokens: &mut genco::Tokens<Python>) {
297        quote_in! { *tokens =>
298            $(match self {
299                Self::Indexed(index) => [$index],
300                Self::Field(name) => .$name,
301            })
302        }
303    }
304}
305
306impl FormatInto<Python> for AvailableCheck {
307    fn format_into(self, tokens: &mut Tokens) {
308        quote_in! { *tokens =>
309            $(match self {
310                AvailableCheck::Object(..) => (),
311                AvailableCheck::None => ()
312            })
313        }
314    }
315}
316
317impl FormatInto<Python> for ImportRegistry {
318    fn format_into(self, tokens: &mut Tokens) {
319        let (base_path, items) = self.into_items_sorted();
320        for (package, imports) in items {
321            let joiner = ".";
322            let package = match package {
323                Package::Relative(path) => format!(".{}", path.into_path(joiner)),
324                Package::Extern(path) => path.into_path(joiner).to_string(),
325                Package::Intern(mut path) => {
326                    if !path.is_empty() {
327                        path.push_front(base_path.as_str());
328                        path.into_path(joiner).to_string()
329                    } else {
330                        PathBuf::new()
331                            .join(base_path.as_str())
332                            .into_path(joiner)
333                            .to_string()
334                    }
335                }
336                Package::Root => base_path.to_owned(),
337            };
338
339            quote_in!(*tokens=> from $(package) import);
340            tokens.space();
341
342            match imports {
343                ImportMode::All => quote_in!(*tokens=> *),
344                ImportMode::Single(items) => {
345                    let items = items.iter().map(|i| {
346                        if let Some(alias) = &i.alias {
347                            quote!($(&i.name) as $alias)
348                        } else {
349                            quote!($(&i.name))
350                        }
351                    });
352                    quote_in!(*tokens=> $(for part in items join (, ) => $part))
353                }
354            }
355
356            tokens.push();
357        }
358    }
359}
360
361impl FormatInto<Python> for FunctionArg {
362    fn format_into(self, tokens: &mut Tokens) {
363        if let Some(r#type) = self.r#type {
364            quote_in! { *tokens =>
365                $(self.name): $r#type
366            }
367        } else {
368            quote_in! { *tokens =>
369                $(self.name)
370            }
371        }
372    }
373}
374
375impl FormatInto<Python> for Function {
376    fn format_into(self, tokens: &mut Tokens) {
377        let doc_string = self.doc_string.map(|doc_string| {
378            let mut tokens = Tokens::new();
379
380            tokens.append("\"\"\"");
381            tokens.append(
382                doc_string
383                    .lines()
384                    .enumerate()
385                    .map(|(i, f)| {
386                        if i > 0 {
387                            format!("    {}", f)
388                        } else {
389                            f.to_string()
390                        }
391                    })
392                    .collect::<Vec<_>>()
393                    .join("\n"),
394            );
395            tokens.append("\"\"\"");
396            tokens
397        });
398
399        let return_type = self.return_type.map(|r| quote!($(" ")-> $r));
400        quote_in! { *tokens =>
401            def $(self.name)($(for arg in self.args join (, ) => $arg))$return_type:
402                $(doc_string)
403                $(self.body)
404        }
405    }
406}
407
408#[cfg(test)]
409mod test {
410    use genco::tokens::FormatInto;
411
412    use super::Tokens;
413
414    #[test]
415    fn test_import_registry_format() {
416        use super::{ImportItem, ImportRegistry, Package};
417
418        let mut import_registry = ImportRegistry::new("package".to_owned());
419        import_registry.push(
420            Package::Relative("basic_types".into()),
421            ImportItem::Aliased {
422                item_name: "A".into(),
423                alias: "A__A".into(),
424            },
425        );
426        import_registry.push(
427            Package::Intern("des".into()),
428            ImportItem::Single("deserialize".into()),
429        );
430        import_registry.push(
431            Package::Extern("ser".into()),
432            ImportItem::Single("serialize".into()),
433        );
434        import_registry.push(Package::Relative("types".into()), ImportItem::All);
435
436        let mut tokens = Tokens::new();
437        import_registry.format_into(&mut tokens);
438
439        assert_eq!(
440            tokens.to_file_string().unwrap(),
441            format!(
442                r#"from ser import serialize
443from package.des import deserialize
444from .basic_types import A as A__A
445from .types import *
446"#
447            )
448        );
449    }
450}