1use inflector::Inflector;
2use lazy_static::lazy_static;
3use linked_hash_map::LinkedHashMap;
4use std::collections::HashSet;
5use unindent::unindent;
6
7use crate::generation::serde_case::RenameRule;
8use crate::options::{ImportStyle, Options, StringTransform};
9use crate::shape::{self, Shape};
10use crate::util::{snake_case, type_case};
11
12pub struct Ctxt {
13 options: Options,
14 type_names: HashSet<String>,
15 imports: HashSet<String>,
16}
17
/// A generated Rust identifier or type expression (e.g. `i64`, `Vec<Foo>`).
pub type Ident = String;
/// A fragment of generated Rust source code.
pub type Code = String;
20
21pub fn rust_program(name: &str, shape: &Shape, options: Options) -> Code {
22 let (type_name, defs) = rust_types(name, &shape, options);
23
24 let var_name = snake_case(&type_name);
25
26 let main = unindent(&format!(
27 r#"
28 fn main() {{
29 let {var_name} = {type_name}::default();
30 let serialized = serde_json::to_string(&{var_name}).unwrap();
31 println!("serialized = {{}}", serialized);
32 let deserialized: {type_name} = serde_json::from_str(&serialized).unwrap();
33 println!("deserialized = {{:?}}", deserialized);
34 }}
35 "#,
36 var_name = var_name,
37 type_name = type_name
38 ));
39
40 match defs {
41 Some(code) => code + "\n\n" + &main,
42 None => main,
43 }
44}
45
46pub fn rust_types(name: &str, shape: &Shape, options: Options) -> (Ident, Option<Code>) {
47 let mut ctxt = Ctxt {
48 options,
49 type_names: HashSet::new(),
50 imports: HashSet::new(),
51 };
52
53 if ctxt.options.import_style != ImportStyle::QualifiedPaths {
54 ctxt.options.derives = ctxt.options.derives
55 .clone()
56 .split(',')
57 .map(|s| import(&mut ctxt, s.trim()))
58 .collect::<Vec<_>>()
59 .join(", ");
60 };
61
62 if !matches!(shape, Shape::Struct { .. }) {
63 ctxt.type_names.insert(name.to_string());
65 }
66
67 let (ident, code) = type_from_shape(&mut ctxt, name, shape);
68 let mut code = code.unwrap_or(String::new());
69
70 if ident != name {
71 code = format!(
72 "{} type {} = {};\n\n{}",
73 ctxt.options.type_visibility, name, ident, code
74 );
75 }
76
77 if !ctxt.imports.is_empty() {
78 let mut imports: Vec<_> = ctxt.imports.drain().collect();
79 imports.sort();
80 let mut import_code = String::new();
81 for import in imports {
82 import_code += "use ";
83 import_code += &import;
84 import_code += ";\n";
85 }
86 import_code += "\n";
87 code = import_code + &code;
88 }
89
90 (name.to_string(), Some(code))
91}
92
93fn type_from_shape(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
94 use crate::shape::Shape::*;
95 match shape {
96 Null | Any | Bottom => (import(ctxt, "serde_json::Value"), None),
97 Bool => ("bool".into(), None),
98 StringT => ("String".into(), None),
99 Integer => ("i64".into(), None),
100 Floating => ("f64".into(), None),
101 Tuple(shapes, _n) => {
102 let folded = shape::fold_shapes(shapes.clone());
103 if folded == Any && shapes.iter().any(|s| s != &Any) {
104 generate_tuple_type(ctxt, path, &shapes)
105 } else {
106 generate_vec_type(ctxt, path, &folded)
107 }
108 }
109 VecT { elem_type: e } => generate_vec_type(ctxt, path, &e),
110 Struct { fields: map } => generate_struct_from_field_shapes(ctxt, path, &map),
111 MapT { val_type: v } => generate_map_type(ctxt, path, &v),
112 Opaque(t) => (t.clone(), None),
113 Optional(e) => {
114 let (inner, defs) = type_from_shape(ctxt, path, &e);
115 if ctxt.options.use_default_for_missing_fields {
116 (inner, defs)
117 } else {
118 (format!("Option<{}>", inner), defs)
119 }
120 }
121 }
122}
123
124fn generate_vec_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
125 let singular = path.to_singular();
126 let (inner, defs) = type_from_shape(ctxt, &singular, shape);
127 (format!("Vec<{}>", inner), defs)
128}
129
130fn generate_map_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
131 let singular = path.to_singular();
132 let (inner, defs) = type_from_shape(ctxt, &singular, shape);
133 (
134 format!("{}<String, {}>", import(ctxt, "std::collections::HashMap"), inner),
135 defs,
136 )
137}
138
139fn generate_tuple_type(ctxt: &mut Ctxt, path: &str, shapes: &[Shape]) -> (Ident, Option<Code>) {
140 let mut types = Vec::new();
141 let mut defs = Vec::new();
142
143 for shape in shapes {
144 let (typ, def) = type_from_shape(ctxt, path, shape);
145 types.push(typ);
146 if let Some(code) = def {
147 defs.push(code)
148 }
149 }
150
151 (format!("({})", types.join(", ")), Some(defs.join("\n\n")))
152}
153
154fn field_name(name: &str, used_names: &HashSet<String>) -> Ident {
155 type_or_field_name(name, used_names, "field", snake_case)
156}
157
158fn type_name(name: &str, used_names: &HashSet<String>) -> Ident {
159 type_or_field_name(name, used_names, "GeneratedType", type_case)
160}
161
/// Words that cannot be used as plain identifiers in generated code: strict
/// and reserved keywords across Rust editions, plus a few historically
/// reserved words.
const RUST_KEYWORDS_ARR: &[&str] = &[
    "abstract", "alignof", "as", "become", "box", "break", "const", "continue", "crate", "do",
    "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in", "let", "loop",
    "macro", "match", "mod", "move", "mut", "offsetof", "override", "priv", "proc", "pub", "pure",
    "ref", "return", "Self", "self", "sizeof", "static", "struct", "super", "trait", "true",
    "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield", "async",
    "await", "try",
    // `dyn` is a strict keyword since Rust 2018; `gen` is reserved since Rust 2024.
    "dyn", "gen",
];
170
171lazy_static! {
172 static ref RUST_KEYWORDS: HashSet<&'static str> = RUST_KEYWORDS_ARR.iter().cloned().collect();
173}
174
175fn type_or_field_name(
176 name: &str,
177 used_names: &HashSet<String>,
178 default_name: &str,
179 case_fn: fn(&str) -> String,
180) -> Ident {
181 let name = name.trim();
182 let mut output_name = case_fn(name);
183 if RUST_KEYWORDS.contains::<str>(&output_name) {
184 output_name.push_str("_field");
185 }
186 if output_name == "" {
187 output_name.push_str(default_name);
188 }
189 if let Some(c) = output_name.chars().next() {
190 if c.is_ascii() && c.is_numeric() {
191 output_name = String::from("n") + &output_name;
192 }
193 }
194 if !used_names.contains(&output_name) {
195 return output_name;
196 }
197 for n in 2.. {
198 let temp = format!("{}{}", output_name, n);
199 if !used_names.contains(&temp) {
200 return temp;
201 }
202 }
203 unreachable!()
204}
205
206fn collapse_option_vec<'a>(ctxt: &mut Ctxt, typ: &'a Shape) -> (bool, &'a Shape) {
207 if !(ctxt.options.allow_option_vec || ctxt.options.use_default_for_missing_fields) {
208 if let Shape::Optional(inner) = typ {
209 if let Shape::VecT { .. } = **inner {
210 return (true, &**inner);
211 }
212 }
213 }
214 (false, typ)
215}
216
217fn import(ctxt: &mut Ctxt, qualified: &str) -> String {
218 if !qualified.contains("::") {
219 return qualified.into()
220 }
221 match ctxt.options.import_style {
222 ImportStyle::AddImports => {
223 ctxt.imports.insert(qualified.into());
224 qualified.rsplit("::").next().unwrap().into()
225 }
226 ImportStyle::AssumeExisting => qualified.rsplit("::").next().unwrap().into(),
227 ImportStyle::QualifiedPaths => qualified.into(),
228 }
229}
230
231fn generate_struct_from_field_shapes(
232 ctxt: &mut Ctxt,
233 path: &str,
234 map: &LinkedHashMap<String, Shape>,
235) -> (Ident, Option<Code>) {
236 let type_name = type_name(path, &ctxt.type_names);
237 ctxt.type_names.insert(type_name.clone());
238 let visibility = ctxt.options.type_visibility.clone();
239 let field_visibility = match ctxt.options.field_visibility {
240 None => visibility.clone(),
241 Some(ref v) => v.clone(),
242 };
243
244 let mut field_names = HashSet::new();
245 let mut defs = Vec::new();
246
247 let fields: Vec<Code> = map
248 .iter()
249 .map(|(name, typ)| {
250 let field_name = field_name(name, &field_names);
251 field_names.insert(field_name.clone());
252
253 let needs_rename = if let Some(ref transform) = ctxt.options.property_name_format {
254 &to_rename_rule(transform).apply_to_field(&field_name) != name
255 } else {
256 &field_name != name
257 };
258 let mut field_code = String::new();
259 if needs_rename {
260 field_code += &format!(" #[serde(rename = \"{}\")]\n", name)
261 }
262
263 let (is_collapsed, collapsed) = collapse_option_vec(ctxt, typ);
264 if is_collapsed {
265 field_code += " #[serde(default)]\n";
266 }
267
268 let (field_type, child_defs) = type_from_shape(ctxt, name, collapsed);
269
270 if let Some(code) = child_defs {
271 defs.push(code);
272 }
273
274 field_code += " ";
275 if field_visibility != "" {
276 field_code += &field_visibility;
277 field_code += " ";
278 }
279
280 format!("{}{}: {},", field_code, field_name, field_type)
281 })
282 .collect();
283
284 let mut code = format!("#[derive({})]\n", ctxt.options.derives);
285
286 if ctxt.options.deny_unknown_fields {
287 code += "#[serde(deny_unknown_fields)]\n";
288 }
289
290 if ctxt.options.use_default_for_missing_fields {
291 code += "#[serde(default)]\n";
292 }
293
294 if let Some(ref transform) = ctxt.options.property_name_format {
295 if *transform != StringTransform::SnakeCase {
296 code += &format!("#[serde(rename_all = \"{}\")]\n", serde_name(transform))
297 }
298 }
299
300 if visibility != "" {
301 code += &visibility;
302 code += " ";
303 }
304
305 code += &format!("struct {} {{\n", type_name);
306
307 if !fields.is_empty() {
308 code += &fields.join("\n");
309 code += "\n";
310 }
311 if ctxt.options.collect_additional {
312 code += &format!(
313 " #[serde(flatten)]\n additional_fields: {}<String, {}>,\n",
314 import(ctxt, "std::collections::HashMap"),
315 import(ctxt, "serde_json::Value"),
316 )
317 }
318 code += "}";
319
320 if !defs.is_empty() {
321 code += "\n\n";
322 code += &defs.join("\n\n");
323 }
324
325 (type_name, Some(code))
326}
327
328fn to_rename_rule(transform: &StringTransform) -> RenameRule {
329 match transform {
330 StringTransform::LowerCase => RenameRule::LowerCase,
331 StringTransform::UpperCase => RenameRule::UPPERCASE,
332 StringTransform::PascalCase => RenameRule::PascalCase,
333 StringTransform::CamelCase => RenameRule::CamelCase,
334 StringTransform::SnakeCase => RenameRule::SnakeCase,
335 StringTransform::ScreamingSnakeCase => RenameRule::ScreamingSnakeCase,
336 StringTransform::KebabCase => RenameRule::KebabCase,
337 StringTransform::ScreamingKebabCase => RenameRule::ScreamingKebabCase,
338 }
339}
340
341fn serde_name(transform: &StringTransform) -> &'static str {
342 match transform {
343 StringTransform::LowerCase => "lowercase",
344 StringTransform::UpperCase => "UPPERCASE",
345 StringTransform::PascalCase => "PascalCase",
346 StringTransform::CamelCase => "camelCase",
347 StringTransform::SnakeCase => "snake_case",
348 StringTransform::ScreamingSnakeCase => "SCREAMING_SNAKE_CASE",
349 StringTransform::KebabCase => "kebab-case",
350 StringTransform::ScreamingKebabCase => "SCREAMING-KEBAB-CASE",
351 }
352}