jsonto/generation/
python.rs

1use linked_hash_map::LinkedHashMap;
2use std::collections::HashSet;
3
4use crate::options::{ImportStyle, Options, StringTransform};
5use crate::shape::{self, Shape};
6use crate::word_case::to_singular;
7use crate::word_case::{kebab_case, lower_camel_case, snake_case, type_case};
8
/// A Python name that generated code may need to import.
///
/// Variant declaration order matters: the derived `Ord` is what the collected imports are
/// sorted by, so the `typing` names are emitted before the `pydantic` names.
#[derive(Debug, PartialEq, PartialOrd, Ord, Eq, Hash, Clone, Copy)]
enum Import {
    Any,
    Optional,
    BaseModel,
    Field,
}

impl Import {
    /// Returns the `(module, identifier)` pair this import refers to.
    fn pair(&self) -> (&'static str, &'static str) {
        match self {
            Import::Any => ("typing", "Any"),
            Import::Optional => ("typing", "Optional"),
            Import::BaseModel => ("pydantic", "BaseModel"),
            Import::Field => ("pydantic", "Field"),
        }
    }

    /// The module that provides the identifier, e.g. `"typing"`.
    fn module(&self) -> &'static str {
        self.pair().0
    }

    /// The bare identifier, e.g. `"Optional"`.
    fn identifier(&self) -> &'static str {
        self.pair().1
    }

    /// The module-qualified path, e.g. `"typing.Optional"`.
    fn qualified(&self) -> String {
        let (module, identifier) = self.pair();
        format!("{}.{}", module, identifier)
    }
}
37
38struct Ctxt {
39    options: Options,
40    type_names: HashSet<String>,
41    imports: HashSet<Import>,
42    created_classes: Vec<(Shape, Ident)>,
43}
44
/// A Python identifier or type expression (class name, field name, `list[int]`, ...).
pub type Ident = String;

/// A fragment of generated Python source code.
pub type Code = String;
47
48pub fn to(name: &str, shape: &Shape, options: Options) -> Code {
49    let mut ctxt = Ctxt {
50        options,
51        type_names: HashSet::new(),
52        imports: HashSet::new(),
53        created_classes: Vec::new(),
54    };
55
56    if !matches!(shape, Shape::Struct { .. }) {
57        // reserve the requested name
58        ctxt.type_names.insert(name.to_string());
59    }
60
61    let (ident, code) = type_from_shape(&mut ctxt, name, shape);
62    let mut code = code.unwrap_or_default();
63
64    if !ctxt.imports.is_empty() {
65        let mut imports: Vec<_> = ctxt.imports.drain().collect();
66        imports.sort();
67        let mut import_code = String::new();
68        match ctxt.options.import_style {
69            ImportStyle::AssumeExisting => {}
70            ImportStyle::AddImports => {
71                for import in imports {
72                    let (module, identifier) = import.pair();
73                    import_code += &format!("from {} import {}\n", module, identifier);
74                }
75            }
76            ImportStyle::QualifiedPaths => {
77                let mut seen = HashSet::new();
78                for import in imports {
79                    let module = import.module();
80                    if seen.insert(module) {
81                        import_code += &format!("import {}\n", module);
82                    }
83                }
84            }
85        }
86        if !import_code.is_empty(){
87            import_code += "\n\n";
88            code = import_code + &code;
89        }
90    }
91
92    if ident != name {
93        if !code.is_empty() {
94            code += "\n\n";
95        }
96        code += &format!("{} = {}", name, ident);
97    }
98    code
99}
100
101fn type_from_shape(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
102    use crate::shape::Shape::*;
103    match shape {
104        Null | Any | Bottom => (import(ctxt, Import::Any), None),
105        Bool => ("bool".into(), None),
106        StringT => ("str".into(), None),
107        Integer => ("int".into(), None),
108        Floating => ("float".into(), None),
109        Tuple(shapes, _n) => {
110            let folded = shape::fold_shapes(shapes.clone());
111            if folded == Any && shapes.iter().any(|s| s != &Any) {
112                generate_tuple_type(ctxt, path, shapes)
113            } else {
114                generate_vec_type(ctxt, path, &folded)
115            }
116        }
117        VecT { elem_type: e } => generate_vec_type(ctxt, path, e),
118        Struct { fields } => generate_data_class(ctxt, path, fields, shape),
119        MapT { val_type: v } => generate_map_type(ctxt, path, v),
120        Opaque(t) => (t.clone(), None),
121        Optional(e) => {
122            let (inner, defs) = type_from_shape(ctxt, path, e);
123            if ctxt.options.use_default_for_missing_fields {
124                (inner, defs)
125            } else {
126                let optional = import(ctxt, Import::Optional);
127                (format!("{}[{}]", optional, inner), defs)
128            }
129        }
130    }
131}
132
133fn generate_vec_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
134    let singular = to_singular(path);
135    let (inner, defs) = type_from_shape(ctxt, &singular, shape);
136    (format!("list[{}]", inner), defs)
137}
138
139fn generate_map_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
140    let singular = to_singular(path);
141    let (inner, defs) = type_from_shape(ctxt, &singular, shape);
142    (format!("dict[str, {}]", inner), defs)
143}
144
145fn generate_tuple_type(ctxt: &mut Ctxt, path: &str, shapes: &[Shape]) -> (Ident, Option<Code>) {
146    let mut types = Vec::new();
147    let mut defs = Vec::new();
148
149    for shape in shapes {
150        let (typ, def) = type_from_shape(ctxt, path, shape);
151        types.push(typ);
152        if let Some(code) = def {
153            if !code.is_empty() {
154                defs.push(code)
155            }
156        }
157    }
158
159    (
160        format!("tuple[{}]", types.join(", ")),
161        Some(defs.join("\n\n")),
162    )
163}
164
165fn field_name(name: &str, used_names: &HashSet<String>) -> Ident {
166    type_or_field_name(name, used_names, "field", snake_case)
167}
168
169fn type_name(name: &str, used_names: &HashSet<String>) -> Ident {
170    type_or_field_name(name, used_names, "GeneratedType", type_case)
171}
172
173// https://docs.python.org/3/reference/lexical_analysis.html#keywords
174#[rustfmt::skip]
175const PYTHON_KEYWORDS: &[&str] = &[
176    "False", "None", "True",
177    "and", "as", "assert", "async", "await", "break", "class", "continue",
178    "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
179    "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass",
180    "raise", "return", "try", "while", "with", "yield",
181];
182
183fn type_or_field_name(
184    name: &str,
185    used_names: &HashSet<String>,
186    default_name: &str,
187    case_fn: fn(&str) -> String,
188) -> Ident {
189    let name = name.trim();
190    let mut output_name = case_fn(name);
191    if PYTHON_KEYWORDS.contains(&&*output_name) {
192        output_name.push_str("_field");
193    }
194    if output_name.is_empty() {
195        output_name.push_str(default_name);
196    }
197    if let Some(c) = output_name.chars().next() {
198        if c.is_ascii_digit() {
199            output_name = String::from("n") + &output_name;
200        }
201    }
202    if !used_names.contains(&output_name) {
203        return output_name;
204    }
205    for n in 2.. {
206        let temp = format!("{}{}", output_name, n);
207        if !used_names.contains(&temp) {
208            return temp;
209        }
210    }
211    unreachable!()
212}
213
214fn import(ctxt: &mut Ctxt, import: Import) -> String {
215    ctxt.imports.insert(import);
216    match ctxt.options.import_style {
217        ImportStyle::QualifiedPaths => import.qualified(),
218        _ => import.identifier().into(),
219    }
220}
221
222fn generate_data_class(
223    ctxt: &mut Ctxt,
224    path: &str,
225    field_shapes: &LinkedHashMap<String, Shape>,
226    containing_shape: &Shape,
227) -> (Ident, Option<Code>) {
228    for (created_for_shape, ident) in ctxt.created_classes.iter() {
229        if created_for_shape.is_acceptable_substitution_for(containing_shape) {
230            return (ident.into(), None);
231        }
232    }
233
234    let type_name = type_name(path, &ctxt.type_names);
235    ctxt.type_names.insert(type_name.clone());
236    ctxt.created_classes
237        .push((containing_shape.clone(), type_name.clone()));
238
239    let mut field_names = HashSet::new();
240    let mut defs = Vec::new();
241
242    let fields: Vec<Code> = field_shapes
243        .iter()
244        .map(|(name, typ)| {
245            let field_name = field_name(name, &field_names);
246            field_names.insert(field_name.clone());
247
248            let (field_type, child_defs) = type_from_shape(ctxt, name, typ);
249
250            if let Some(code) = child_defs {
251                if !code.is_empty() {
252                    defs.push(code);
253                }
254            }
255
256            let mut field_code = String::new();
257            let transformed = apply_transform(ctxt, &field_name, name);
258            if transformed != field_name {
259                field_code += &format!(" = {}(alias=\"{}\")", import(ctxt, Import::Field), transformed)
260            }
261
262            format!("    {}: {}{}", field_name, field_type, field_code)
263        })
264        .collect();
265
266    let mut code = String::new();
267
268    code += &format!(
269        "class {}({}):\n",
270        type_name,
271        import(ctxt, Import::BaseModel)
272    );
273
274    if fields.is_empty() {
275        code += "    pass\n";
276    } else {
277        code += &fields.join("\n");
278        code += "\n";
279    }
280
281    if !defs.is_empty() {
282        let mut d = defs.join("\n\n");
283        d += "\n\n";
284        d += &code;
285        code = d;
286    }
287
288    (type_name, Some(code))
289}
290
291fn apply_transform(ctxt: &Ctxt, field_name: &str, name: &str) -> String {
292    match ctxt.options.property_name_format {
293        Some(StringTransform::LowerCase) => field_name.to_ascii_lowercase(),
294        Some(StringTransform::PascalCase) => type_case(field_name),
295        Some(StringTransform::SnakeCase) => snake_case(field_name),
296        Some(StringTransform::KebabCase) => kebab_case(field_name),
297        Some(StringTransform::UpperCase) => field_name.to_ascii_uppercase(),
298        Some(StringTransform::CamelCase) => lower_camel_case(field_name),
299        Some(StringTransform::ScreamingSnakeCase) => snake_case(field_name).to_ascii_uppercase(),
300        Some(StringTransform::ScreamingKebabCase) => kebab_case(field_name).to_ascii_uppercase(),
301        None => name.to_string(),
302    }
303}
304
#[cfg(test)]
mod python_codegen_tests {
    use super::*;

    #[test]
    fn field_names_test() {
        // Checks the identifier chosen for `from` when no names are in use yet.
        fn field_name_test(from: &str, to: &str) {
            assert_eq!(
                field_name(from, &HashSet::new()),
                to.to_string(),
                r#"From "{}" to "{}""#,
                from,
                to
            );
        }

        field_name_test("valid", "valid");
        field_name_test("1", "n1");
        field_name_test("+1", "n1");
        field_name_test("", "field");
        field_name_test("def", "def_field");
    }

    #[test]
    fn field_names_avoid_used_names() {
        let mut used = HashSet::new();
        used.insert("valid".to_string());
        assert_eq!(field_name("valid", &used), "valid2");

        used.insert("valid2".to_string());
        assert_eq!(field_name("valid", &used), "valid3");

        // The keyword suffix is applied before de-duplication.
        used.insert("def_field".to_string());
        assert_eq!(field_name("def", &used), "def_field2");
    }
}