json_typegen_shared/generation/
rust.rs

1use inflector::Inflector;
2use lazy_static::lazy_static;
3use linked_hash_map::LinkedHashMap;
4use std::collections::HashSet;
5use unindent::unindent;
6
7use crate::generation::serde_case::RenameRule;
8use crate::options::{ImportStyle, Options, StringTransform};
9use crate::shape::{self, Shape};
10use crate::util::{snake_case, type_case};
11
12pub struct Ctxt {
13    options: Options,
14    type_names: HashSet<String>,
15    imports: HashSet<String>,
16}
17
/// A name (or type expression) that appears in the generated Rust code.
pub type Ident = String;

/// A fragment of generated Rust source code.
pub type Code = String;
20
21pub fn rust_program(name: &str, shape: &Shape, options: Options) -> Code {
22    let (type_name, defs) = rust_types(name, &shape, options);
23
24    let var_name = snake_case(&type_name);
25
26    let main = unindent(&format!(
27        r#"
28        fn main() {{
29            let {var_name} = {type_name}::default();
30            let serialized = serde_json::to_string(&{var_name}).unwrap();
31            println!("serialized = {{}}", serialized);
32            let deserialized: {type_name} = serde_json::from_str(&serialized).unwrap();
33            println!("deserialized = {{:?}}", deserialized);
34        }}
35        "#,
36        var_name = var_name,
37        type_name = type_name
38    ));
39
40    match defs {
41        Some(code) => code + "\n\n" + &main,
42        None => main,
43    }
44}
45
46pub fn rust_types(name: &str, shape: &Shape, options: Options) -> (Ident, Option<Code>) {
47    let mut ctxt = Ctxt {
48        options,
49        type_names: HashSet::new(),
50        imports: HashSet::new(),
51    };
52
53    if ctxt.options.import_style != ImportStyle::QualifiedPaths {
54        ctxt.options.derives = ctxt.options.derives
55            .clone()
56            .split(',')
57            .map(|s| import(&mut ctxt, s.trim()))
58            .collect::<Vec<_>>()
59            .join(", ");
60    };
61
62    if !matches!(shape, Shape::Struct { .. }) {
63        // reserve the requested name
64        ctxt.type_names.insert(name.to_string());
65    }
66
67    let (ident, code) = type_from_shape(&mut ctxt, name, shape);
68    let mut code = code.unwrap_or(String::new());
69
70    if ident != name {
71        code = format!(
72            "{} type {} = {};\n\n{}",
73            ctxt.options.type_visibility, name, ident, code
74        );
75    }
76
77    if !ctxt.imports.is_empty() {
78        let mut imports: Vec<_> = ctxt.imports.drain().collect();
79        imports.sort();
80        let mut import_code = String::new();
81        for import in imports {
82            import_code += "use ";
83            import_code += &import;
84            import_code += ";\n";
85        }
86        import_code += "\n";
87        code = import_code + &code;
88    }
89
90    (name.to_string(), Some(code))
91}
92
93fn type_from_shape(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
94    use crate::shape::Shape::*;
95    match shape {
96        Null | Any | Bottom => (import(ctxt, "serde_json::Value"), None),
97        Bool => ("bool".into(), None),
98        StringT => ("String".into(), None),
99        Integer => ("i64".into(), None),
100        Floating => ("f64".into(), None),
101        Tuple(shapes, _n) => {
102            let folded = shape::fold_shapes(shapes.clone());
103            if folded == Any && shapes.iter().any(|s| s != &Any) {
104                generate_tuple_type(ctxt, path, &shapes)
105            } else {
106                generate_vec_type(ctxt, path, &folded)
107            }
108        }
109        VecT { elem_type: e } => generate_vec_type(ctxt, path, &e),
110        Struct { fields: map } => generate_struct_from_field_shapes(ctxt, path, &map),
111        MapT { val_type: v } => generate_map_type(ctxt, path, &v),
112        Opaque(t) => (t.clone(), None),
113        Optional(e) => {
114            let (inner, defs) = type_from_shape(ctxt, path, &e);
115            if ctxt.options.use_default_for_missing_fields {
116                (inner, defs)
117            } else {
118                (format!("Option<{}>", inner), defs)
119            }
120        }
121    }
122}
123
124fn generate_vec_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
125    let singular = path.to_singular();
126    let (inner, defs) = type_from_shape(ctxt, &singular, shape);
127    (format!("Vec<{}>", inner), defs)
128}
129
130fn generate_map_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
131    let singular = path.to_singular();
132    let (inner, defs) = type_from_shape(ctxt, &singular, shape);
133    (
134        format!("{}<String, {}>", import(ctxt, "std::collections::HashMap"), inner),
135        defs,
136    )
137}
138
139fn generate_tuple_type(ctxt: &mut Ctxt, path: &str, shapes: &[Shape]) -> (Ident, Option<Code>) {
140    let mut types = Vec::new();
141    let mut defs = Vec::new();
142
143    for shape in shapes {
144        let (typ, def) = type_from_shape(ctxt, path, shape);
145        types.push(typ);
146        if let Some(code) = def {
147            defs.push(code)
148        }
149    }
150
151    (format!("({})", types.join(", ")), Some(defs.join("\n\n")))
152}
153
154fn field_name(name: &str, used_names: &HashSet<String>) -> Ident {
155    type_or_field_name(name, used_names, "field", snake_case)
156}
157
158fn type_name(name: &str, used_names: &HashSet<String>) -> Ident {
159    type_or_field_name(name, used_names, "GeneratedType", type_case)
160}
161
/// Words that must not be emitted verbatim as identifiers in generated code.
///
/// Besides the keywords of the original list this includes `dyn` (a strict keyword since
/// the 2018 edition) and `gen` (reserved since the 2024 edition), so that JSON keys with
/// those names do not produce code that fails to compile on current editions.
const RUST_KEYWORDS_ARR: &[&str] = &[
    "abstract", "alignof", "as", "become", "box", "break", "const", "continue", "crate", "do",
    "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop",
    "macro", "match", "mod", "move", "mut", "offsetof", "override", "priv", "proc", "pub", "pure",
    "ref", "return", "Self", "self", "sizeof", "static", "struct", "super", "trait", "true",
    "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield", "async",
    "await", "try", "dyn", "gen",
];
170
171lazy_static! {
172    static ref RUST_KEYWORDS: HashSet<&'static str> = RUST_KEYWORDS_ARR.iter().cloned().collect();
173}
174
175fn type_or_field_name(
176    name: &str,
177    used_names: &HashSet<String>,
178    default_name: &str,
179    case_fn: fn(&str) -> String,
180) -> Ident {
181    let name = name.trim();
182    let mut output_name = case_fn(name);
183    if RUST_KEYWORDS.contains::<str>(&output_name) {
184        output_name.push_str("_field");
185    }
186    if output_name == "" {
187        output_name.push_str(default_name);
188    }
189    if let Some(c) = output_name.chars().next() {
190        if c.is_ascii() && c.is_numeric() {
191            output_name = String::from("n") + &output_name;
192        }
193    }
194    if !used_names.contains(&output_name) {
195        return output_name;
196    }
197    for n in 2.. {
198        let temp = format!("{}{}", output_name, n);
199        if !used_names.contains(&temp) {
200            return temp;
201        }
202    }
203    unreachable!()
204}
205
206fn collapse_option_vec<'a>(ctxt: &mut Ctxt, typ: &'a Shape) -> (bool, &'a Shape) {
207    if !(ctxt.options.allow_option_vec || ctxt.options.use_default_for_missing_fields) {
208        if let Shape::Optional(inner) = typ {
209            if let Shape::VecT { .. } = **inner {
210                return (true, &**inner);
211            }
212        }
213    }
214    (false, typ)
215}
216
217fn import(ctxt: &mut Ctxt, qualified: &str) -> String {
218    if !qualified.contains("::") {
219        return qualified.into()
220    }
221    match ctxt.options.import_style {
222        ImportStyle::AddImports => {
223            ctxt.imports.insert(qualified.into());
224            qualified.rsplit("::").next().unwrap().into()
225        }
226        ImportStyle::AssumeExisting => qualified.rsplit("::").next().unwrap().into(),
227        ImportStyle::QualifiedPaths => qualified.into(),
228    }
229}
230
231fn generate_struct_from_field_shapes(
232    ctxt: &mut Ctxt,
233    path: &str,
234    map: &LinkedHashMap<String, Shape>,
235) -> (Ident, Option<Code>) {
236    let type_name = type_name(path, &ctxt.type_names);
237    ctxt.type_names.insert(type_name.clone());
238    let visibility = ctxt.options.type_visibility.clone();
239    let field_visibility = match ctxt.options.field_visibility {
240        None => visibility.clone(),
241        Some(ref v) => v.clone(),
242    };
243
244    let mut field_names = HashSet::new();
245    let mut defs = Vec::new();
246
247    let fields: Vec<Code> = map
248        .iter()
249        .map(|(name, typ)| {
250            let field_name = field_name(name, &field_names);
251            field_names.insert(field_name.clone());
252
253            let needs_rename = if let Some(ref transform) = ctxt.options.property_name_format {
254                &to_rename_rule(transform).apply_to_field(&field_name) != name
255            } else {
256                &field_name != name
257            };
258            let mut field_code = String::new();
259            if needs_rename {
260                field_code += &format!("    #[serde(rename = \"{}\")]\n", name)
261            }
262
263            let (is_collapsed, collapsed) = collapse_option_vec(ctxt, typ);
264            if is_collapsed {
265                field_code += "    #[serde(default)]\n";
266            }
267
268            let (field_type, child_defs) = type_from_shape(ctxt, name, collapsed);
269
270            if let Some(code) = child_defs {
271                defs.push(code);
272            }
273
274            field_code += "    ";
275            if field_visibility != "" {
276                field_code += &field_visibility;
277                field_code += " ";
278            }
279
280            format!("{}{}: {},", field_code, field_name, field_type)
281        })
282        .collect();
283
284    let mut code = format!("#[derive({})]\n", ctxt.options.derives);
285
286    if ctxt.options.deny_unknown_fields {
287        code += "#[serde(deny_unknown_fields)]\n";
288    }
289
290    if ctxt.options.use_default_for_missing_fields {
291        code += "#[serde(default)]\n";
292    }
293
294    if let Some(ref transform) = ctxt.options.property_name_format {
295        if *transform != StringTransform::SnakeCase {
296            code += &format!("#[serde(rename_all = \"{}\")]\n", serde_name(transform))
297        }
298    }
299
300    if visibility != "" {
301        code += &visibility;
302        code += " ";
303    }
304
305    code += &format!("struct {} {{\n", type_name);
306
307    if !fields.is_empty() {
308        code += &fields.join("\n");
309        code += "\n";
310    }
311    if ctxt.options.collect_additional {
312        code += &format!(
313            "    #[serde(flatten)]\n    additional_fields: {}<String, {}>,\n",
314            import(ctxt, "std::collections::HashMap"),
315            import(ctxt, "serde_json::Value"),
316        )
317    }
318    code += "}";
319
320    if !defs.is_empty() {
321        code += "\n\n";
322        code += &defs.join("\n\n");
323    }
324
325    (type_name, Some(code))
326}
327
328fn to_rename_rule(transform: &StringTransform) -> RenameRule {
329    match transform {
330        StringTransform::LowerCase => RenameRule::LowerCase,
331        StringTransform::UpperCase => RenameRule::UPPERCASE,
332        StringTransform::PascalCase => RenameRule::PascalCase,
333        StringTransform::CamelCase => RenameRule::CamelCase,
334        StringTransform::SnakeCase => RenameRule::SnakeCase,
335        StringTransform::ScreamingSnakeCase => RenameRule::ScreamingSnakeCase,
336        StringTransform::KebabCase => RenameRule::KebabCase,
337        StringTransform::ScreamingKebabCase => RenameRule::ScreamingKebabCase,
338    }
339}
340
341fn serde_name(transform: &StringTransform) -> &'static str {
342    match transform {
343        StringTransform::LowerCase => "lowercase",
344        StringTransform::UpperCase => "UPPERCASE",
345        StringTransform::PascalCase => "PascalCase",
346        StringTransform::CamelCase => "camelCase",
347        StringTransform::SnakeCase => "snake_case",
348        StringTransform::ScreamingSnakeCase => "SCREAMING_SNAKE_CASE",
349        StringTransform::KebabCase => "kebab-case",
350        StringTransform::ScreamingKebabCase => "SCREAMING-KEBAB-CASE",
351    }
352}