1use linked_hash_map::LinkedHashMap;
2use std::collections::HashSet;
3
4use crate::options::{ImportStyle, Options, StringTransform};
5use crate::shape::{self, Shape};
6use crate::word_case::to_singular;
7use crate::word_case::{kebab_case, lower_camel_case, snake_case, type_case};
8
/// A Python name that generated code may need to import.
///
/// The derived `Ord` follows declaration order. Import rendering sorts the
/// collected imports with it, so the variant order here determines the order
/// of the emitted import lines.
#[derive(Debug, PartialEq, PartialOrd, Ord, Eq, Hash, Clone, Copy)]
enum Import {
    Any,
    Optional,
    BaseModel,
    Field,
}

impl Import {
    /// Returns the `(module, identifier)` pair, e.g. `("typing", "Any")`.
    fn pair(&self) -> (&'static str, &'static str) {
        match self {
            Import::Any => ("typing", "Any"),
            Import::Optional => ("typing", "Optional"),
            Import::BaseModel => ("pydantic", "BaseModel"),
            Import::Field => ("pydantic", "Field"),
        }
    }

    /// Returns the module the name is imported from, e.g. `typing`.
    fn module(&self) -> &'static str {
        self.pair().0
    }

    /// Returns the bare identifier, e.g. `Any`.
    fn identifier(&self) -> &'static str {
        self.pair().1
    }

    /// Returns the module-qualified path, e.g. `typing.Any`.
    fn qualified(&self) -> String {
        let (module, identifier) = self.pair();
        format!("{}.{}", module, identifier)
    }
}
37
38struct Ctxt {
39 options: Options,
40 type_names: HashSet<String>,
41 imports: HashSet<Import>,
42 created_classes: Vec<(Shape, Ident)>,
43}
44
45pub type Ident = String;
46pub type Code = String;
47
48pub fn to(name: &str, shape: &Shape, options: Options) -> Code {
49 let mut ctxt = Ctxt {
50 options,
51 type_names: HashSet::new(),
52 imports: HashSet::new(),
53 created_classes: Vec::new(),
54 };
55
56 if !matches!(shape, Shape::Struct { .. }) {
57 ctxt.type_names.insert(name.to_string());
59 }
60
61 let (ident, code) = type_from_shape(&mut ctxt, name, shape);
62 let mut code = code.unwrap_or_default();
63
64 if !ctxt.imports.is_empty() {
65 let mut imports: Vec<_> = ctxt.imports.drain().collect();
66 imports.sort();
67 let mut import_code = String::new();
68 match ctxt.options.import_style {
69 ImportStyle::AssumeExisting => {}
70 ImportStyle::AddImports => {
71 for import in imports {
72 let (module, identifier) = import.pair();
73 import_code += &format!("from {} import {}\n", module, identifier);
74 }
75 }
76 ImportStyle::QualifiedPaths => {
77 let mut seen = HashSet::new();
78 for import in imports {
79 let module = import.module();
80 if seen.insert(module) {
81 import_code += &format!("import {}\n", module);
82 }
83 }
84 }
85 }
86 if !import_code.is_empty(){
87 import_code += "\n\n";
88 code = import_code + &code;
89 }
90 }
91
92 if ident != name {
93 if !code.is_empty() {
94 code += "\n\n";
95 }
96 code += &format!("{} = {}", name, ident);
97 }
98 code
99}
100
101fn type_from_shape(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
102 use crate::shape::Shape::*;
103 match shape {
104 Null | Any | Bottom => (import(ctxt, Import::Any), None),
105 Bool => ("bool".into(), None),
106 StringT => ("str".into(), None),
107 Integer => ("int".into(), None),
108 Floating => ("float".into(), None),
109 Tuple(shapes, _n) => {
110 let folded = shape::fold_shapes(shapes.clone());
111 if folded == Any && shapes.iter().any(|s| s != &Any) {
112 generate_tuple_type(ctxt, path, shapes)
113 } else {
114 generate_vec_type(ctxt, path, &folded)
115 }
116 }
117 VecT { elem_type: e } => generate_vec_type(ctxt, path, e),
118 Struct { fields } => generate_data_class(ctxt, path, fields, shape),
119 MapT { val_type: v } => generate_map_type(ctxt, path, v),
120 Opaque(t) => (t.clone(), None),
121 Optional(e) => {
122 let (inner, defs) = type_from_shape(ctxt, path, e);
123 if ctxt.options.use_default_for_missing_fields {
124 (inner, defs)
125 } else {
126 let optional = import(ctxt, Import::Optional);
127 (format!("{}[{}]", optional, inner), defs)
128 }
129 }
130 }
131}
132
133fn generate_vec_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
134 let singular = to_singular(path);
135 let (inner, defs) = type_from_shape(ctxt, &singular, shape);
136 (format!("list[{}]", inner), defs)
137}
138
139fn generate_map_type(ctxt: &mut Ctxt, path: &str, shape: &Shape) -> (Ident, Option<Code>) {
140 let singular = to_singular(path);
141 let (inner, defs) = type_from_shape(ctxt, &singular, shape);
142 (format!("dict[str, {}]", inner), defs)
143}
144
145fn generate_tuple_type(ctxt: &mut Ctxt, path: &str, shapes: &[Shape]) -> (Ident, Option<Code>) {
146 let mut types = Vec::new();
147 let mut defs = Vec::new();
148
149 for shape in shapes {
150 let (typ, def) = type_from_shape(ctxt, path, shape);
151 types.push(typ);
152 if let Some(code) = def {
153 if !code.is_empty() {
154 defs.push(code)
155 }
156 }
157 }
158
159 (
160 format!("tuple[{}]", types.join(", ")),
161 Some(defs.join("\n\n")),
162 )
163}
164
165fn field_name(name: &str, used_names: &HashSet<String>) -> Ident {
166 type_or_field_name(name, used_names, "field", snake_case)
167}
168
169fn type_name(name: &str, used_names: &HashSet<String>) -> Ident {
170 type_or_field_name(name, used_names, "GeneratedType", type_case)
171}
172
/// Python's reserved (hard) keywords, which can never be used as identifiers.
///
/// This matches Python's `keyword.kwlist`. Soft keywords such as `match`,
/// `case` and `type` are valid identifiers and are intentionally absent.
#[rustfmt::skip]
const PYTHON_KEYWORDS: &[&str] = &[
    // Constants
    "False", "None", "True",
    // Boolean, membership and identity operators
    "and", "in", "is", "not", "or",
    // Control flow
    "break", "continue", "elif", "else", "for", "if", "pass", "return", "while",
    // Exceptions and context managers
    "assert", "except", "finally", "raise", "try", "with",
    // Definitions and scoping
    "class", "def", "del", "global", "lambda", "nonlocal",
    // Modules
    "as", "from", "import",
    // Coroutines and generators
    "async", "await", "yield",
];
182
183fn type_or_field_name(
184 name: &str,
185 used_names: &HashSet<String>,
186 default_name: &str,
187 case_fn: fn(&str) -> String,
188) -> Ident {
189 let name = name.trim();
190 let mut output_name = case_fn(name);
191 if PYTHON_KEYWORDS.contains(&&*output_name) {
192 output_name.push_str("_field");
193 }
194 if output_name.is_empty() {
195 output_name.push_str(default_name);
196 }
197 if let Some(c) = output_name.chars().next() {
198 if c.is_ascii_digit() {
199 output_name = String::from("n") + &output_name;
200 }
201 }
202 if !used_names.contains(&output_name) {
203 return output_name;
204 }
205 for n in 2.. {
206 let temp = format!("{}{}", output_name, n);
207 if !used_names.contains(&temp) {
208 return temp;
209 }
210 }
211 unreachable!()
212}
213
214fn import(ctxt: &mut Ctxt, import: Import) -> String {
215 ctxt.imports.insert(import);
216 match ctxt.options.import_style {
217 ImportStyle::QualifiedPaths => import.qualified(),
218 _ => import.identifier().into(),
219 }
220}
221
222fn generate_data_class(
223 ctxt: &mut Ctxt,
224 path: &str,
225 field_shapes: &LinkedHashMap<String, Shape>,
226 containing_shape: &Shape,
227) -> (Ident, Option<Code>) {
228 for (created_for_shape, ident) in ctxt.created_classes.iter() {
229 if created_for_shape.is_acceptable_substitution_for(containing_shape) {
230 return (ident.into(), None);
231 }
232 }
233
234 let type_name = type_name(path, &ctxt.type_names);
235 ctxt.type_names.insert(type_name.clone());
236 ctxt.created_classes
237 .push((containing_shape.clone(), type_name.clone()));
238
239 let mut field_names = HashSet::new();
240 let mut defs = Vec::new();
241
242 let fields: Vec<Code> = field_shapes
243 .iter()
244 .map(|(name, typ)| {
245 let field_name = field_name(name, &field_names);
246 field_names.insert(field_name.clone());
247
248 let (field_type, child_defs) = type_from_shape(ctxt, name, typ);
249
250 if let Some(code) = child_defs {
251 if !code.is_empty() {
252 defs.push(code);
253 }
254 }
255
256 let mut field_code = String::new();
257 let transformed = apply_transform(ctxt, &field_name, name);
258 if transformed != field_name {
259 field_code += &format!(" = {}(alias=\"{}\")", import(ctxt, Import::Field), transformed)
260 }
261
262 format!(" {}: {}{}", field_name, field_type, field_code)
263 })
264 .collect();
265
266 let mut code = String::new();
267
268 code += &format!(
269 "class {}({}):\n",
270 type_name,
271 import(ctxt, Import::BaseModel)
272 );
273
274 if fields.is_empty() {
275 code += " pass\n";
276 } else {
277 code += &fields.join("\n");
278 code += "\n";
279 }
280
281 if !defs.is_empty() {
282 let mut d = defs.join("\n\n");
283 d += "\n\n";
284 d += &code;
285 code = d;
286 }
287
288 (type_name, Some(code))
289}
290
291fn apply_transform(ctxt: &Ctxt, field_name: &str, name: &str) -> String {
292 match ctxt.options.property_name_format {
293 Some(StringTransform::LowerCase) => field_name.to_ascii_lowercase(),
294 Some(StringTransform::PascalCase) => type_case(field_name),
295 Some(StringTransform::SnakeCase) => snake_case(field_name),
296 Some(StringTransform::KebabCase) => kebab_case(field_name),
297 Some(StringTransform::UpperCase) => field_name.to_ascii_uppercase(),
298 Some(StringTransform::CamelCase) => lower_camel_case(field_name),
299 Some(StringTransform::ScreamingSnakeCase) => snake_case(field_name).to_ascii_uppercase(),
300 Some(StringTransform::ScreamingKebabCase) => kebab_case(field_name).to_ascii_uppercase(),
301 None => name.to_string(),
302 }
303}
304
#[cfg(test)]
mod python_codegen_tests {
    use super::*;

    #[test]
    fn field_names_test() {
        fn field_name_test(from: &str, to: &str) {
            assert_eq!(
                field_name(from, &HashSet::new()),
                to.to_string(),
                r#"From "{}" to "{}""#,
                from,
                to
            );
        }

        field_name_test("valid", "valid");
        field_name_test("1", "n1");
        field_name_test("+1", "n1");
        field_name_test("", "field");
        field_name_test("def", "def_field");
        field_name_test("class", "class_field");
        // Surrounding whitespace is trimmed before case conversion.
        field_name_test("  valid  ", "valid");
    }

    /// Taken names get the smallest free numeric suffix starting at 2.
    #[test]
    fn field_names_are_deduplicated() {
        let mut used = HashSet::new();
        used.insert("valid".to_string());
        assert_eq!(field_name("valid", &used), "valid2");

        used.insert("valid2".to_string());
        assert_eq!(field_name("valid", &used), "valid3");
    }
}