openapi_model_generator/
generator.rs

1use crate::{
2    models::{
3        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
4        UnionModel, UnionType,
5    },
6    Result,
7};
8
/// Words that cannot be emitted verbatim as identifiers in generated Rust code.
///
/// Covers the strict keywords (including the 2018-edition `async`, `await` and `dyn`) and the
/// reserved words. `gen` is reserved starting with the 2024 edition; escaping it is harmless on
/// older editions because `r#gen` is a valid raw identifier everywhere.
///
/// `is_reserved_word` lowercases its input before searching this list, so the `"Self"` entry
/// is effectively covered by `"self"`.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "gen",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
];
16
/// Sentinel response name: `generate_response_model` emits no code for responses carrying it.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Sentinel request name: `generate_request_model` emits no code for requests carrying it.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";
19
20fn is_reserved_word(string_to_check: &str) -> bool {
21    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
22}
23
/// Converts an OpenAPI property name into a `snake_case` Rust field identifier.
///
/// * Every character that is not ASCII alphanumeric becomes `_`.
/// * Each ASCII uppercase letter is lowercased and, unless it is the first character,
///   preceded by `_` (so `HTTPCode` becomes `h_t_t_p_code`).
/// * Runs of consecutive underscores collapse into a single `_` (`a___b` -> `a_b`).
/// * `self`, `super` and `crate` get a trailing `_`: unlike other keywords they cannot be
///   written as raw identifiers (`r#self`, `r#super`, `r#crate` are rejected by rustc).
/// * A leading digit is prefixed with `_` so the result is a valid identifier.
///
/// Escaping of the remaining keywords (with `r#`) is left to the caller.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);

    for (i, c) in name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            // Word boundary; skip the separator if one was just emitted.
            if i != 0 && !snake.ends_with('_') {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else if c.is_ascii_alphanumeric() {
            snake.push(c);
        } else if !snake.ends_with('_') {
            // Any other character is a separator; never emit two underscores in a row.
            snake.push('_');
        }
    }

    if matches!(snake.as_str(), "self" | "super" | "crate") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
59
/// Returns `true` if any custom attribute, ignoring surrounding whitespace, starts with
/// `prefix`. `None` means "no custom attributes" and always yields `false`.
fn has_custom_attr_with_prefix(custom_attrs: &Option<Vec<String>>, prefix: &str) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.trim().starts_with(prefix))
}

/// Checks whether the `x-rust-attrs` list already supplies its own `#[derive(...)]`; callers
/// use this to decide whether to emit their default derive line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    has_custom_attr_with_prefix(custom_attrs, "#[derive(")
}

/// Checks whether the `x-rust-attrs` list already supplies a `#[serde(...)]` attribute.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    has_custom_attr_with_prefix(custom_attrs, "#[serde(")
}
79
/// Renders the `x-rust-attrs` entries verbatim, one per line, each terminated by `\n`.
/// Returns an empty string when there are no custom attributes.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
91
92pub fn generate_models(
93    models: &[ModelType],
94    requests: &[RequestModel],
95    responses: &[ResponseModel],
96) -> Result<String> {
97    // First, generate all model code to determine which imports are needed
98    let mut models_code = String::new();
99
100    for model_type in models {
101        match model_type {
102            ModelType::Struct(model) => {
103                models_code.push_str(&generate_model(model)?);
104            }
105            ModelType::Union(union) => {
106                models_code.push_str(&generate_union(union)?);
107            }
108            ModelType::Composition(comp) => {
109                models_code.push_str(&generate_composition(comp)?);
110            }
111            ModelType::Enum(enum_model) => {
112                models_code.push_str(&generate_enum(enum_model)?);
113            }
114            ModelType::TypeAlias(type_alias) => {
115                models_code.push_str(&generate_type_alias(type_alias)?);
116            }
117        }
118    }
119
120    for request in requests {
121        models_code.push_str(&generate_request_model(request)?);
122    }
123
124    for response in responses {
125        models_code.push_str(&generate_response_model(response)?);
126    }
127
128    // Determine which imports are actually needed
129    let needs_uuid = models_code.contains("Uuid");
130    let needs_datetime = models_code.contains("DateTime<Utc>");
131    let needs_date = models_code.contains("NaiveDate");
132
133    // Build final output with only necessary imports
134    let mut output = String::new();
135    output.push_str("use serde::{Serialize, Deserialize};\n");
136
137    if needs_uuid {
138        output.push_str("use uuid::Uuid;\n");
139    }
140
141    if needs_datetime || needs_date {
142        output.push_str("use chrono::{");
143        let mut chrono_imports = Vec::new();
144        if needs_datetime {
145            chrono_imports.push("DateTime");
146        }
147        if needs_date {
148            chrono_imports.push("NaiveDate");
149        }
150        if needs_datetime {
151            chrono_imports.push("Utc");
152        }
153        output.push_str(&chrono_imports.join(", "));
154        output.push_str("};\n");
155    }
156
157    output.push('\n');
158    output.push_str(&models_code);
159
160    Ok(output)
161}
162
163fn generate_model(model: &Model) -> Result<String> {
164    let mut output = String::new();
165
166    if !model.name.is_empty() {
167        output.push_str(&format!("/// {}\n", model.name));
168    }
169
170    output.push_str(&generate_custom_attrs(&model.custom_attrs));
171
172    // Only add default derive if custom_attrs doesn't already contain a derive directive
173    if !has_custom_derive(&model.custom_attrs) {
174        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
175    }
176
177    output.push_str(&format!("pub struct {} {{\n", model.name));
178
179    for field in &model.fields {
180        let field_type = match field.field_type.as_str() {
181            "String" => "String",
182            "f64" => "f64",
183            "i64" => "i64",
184            "bool" => "bool",
185            "DateTime" => "DateTime<Utc>",
186            "Date" => "NaiveDate",
187            "Uuid" => "Uuid",
188            _ => &field.field_type,
189        };
190
191        let mut lowercased_name = to_snake_case(field.name.as_str());
192        if is_reserved_word(&lowercased_name) {
193            lowercased_name = format!("r#{lowercased_name}")
194        }
195
196        // Only add serde rename if the Rust field name differs from the original field name
197        if lowercased_name != field.name {
198            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
199        }
200
201        if field.should_flatten() {
202            output.push_str("    #[serde(flatten)]\n");
203        }
204
205        if field.is_required && !field.is_nullable {
206            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
207        } else {
208            output.push_str(&format!(
209                "    pub {lowercased_name}: Option<{field_type}>,\n",
210            ));
211        }
212    }
213
214    output.push_str("}\n\n");
215    Ok(output)
216}
217
218fn generate_request_model(request: &RequestModel) -> Result<String> {
219    let mut output = String::new();
220    tracing::info!("Generating request model");
221    tracing::info!("{:#?}", request);
222
223    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
224        return Ok(String::new());
225    }
226
227    output.push_str(&format!("/// {}\n", request.name));
228    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
229    output.push_str(&format!("pub struct {} {{\n", request.name));
230    output.push_str(&format!("    pub body: {},\n", request.schema));
231    output.push_str("}\n");
232    Ok(output)
233}
234
235fn generate_response_model(response: &ResponseModel) -> Result<String> {
236    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
237        return Ok(String::new());
238    }
239
240    let type_name = format!("{}{}", response.name, response.status_code);
241
242    let mut output = String::new();
243
244    if let Some(desc) = &response.description {
245        for line in desc.lines() {
246            output.push_str(&format!("/// {}\n", line.trim()));
247        }
248    } else {
249        output.push_str(&format!("/// {type_name}\n"));
250    }
251
252    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
253    output.push_str(&format!("pub struct {type_name} {{\n"));
254    output.push_str(&format!("    pub body: {},\n", response.schema));
255    output.push_str("}\n");
256
257    Ok(output)
258}
259
260fn generate_union(union: &UnionModel) -> Result<String> {
261    let mut output = String::new();
262
263    output.push_str(&format!(
264        "/// {} ({})\n",
265        union.name,
266        match union.union_type {
267            UnionType::OneOf => "oneOf",
268            UnionType::AnyOf => "anyOf",
269        }
270    ));
271    output.push_str(&generate_custom_attrs(&union.custom_attrs));
272
273    // Only add default derive if custom_attrs doesn't already contain a derive
274    if !has_custom_derive(&union.custom_attrs) {
275        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
276    }
277
278    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
279    if !has_custom_serde(&union.custom_attrs) {
280        output.push_str("#[serde(untagged)]\n");
281    }
282
283    output.push_str(&format!("pub enum {} {{\n", union.name));
284
285    for variant in &union.variants {
286        output.push_str(&format!("    {}({}),\n", variant.name, variant.name));
287    }
288
289    output.push_str("}\n");
290    Ok(output)
291}
292
293fn generate_composition(comp: &CompositionModel) -> Result<String> {
294    let mut output = String::new();
295
296    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
297    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
298
299    // Only add default derive if custom_attrs doesn't already contain a derive
300    if !has_custom_derive(&comp.custom_attrs) {
301        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
302    }
303
304    output.push_str(&format!("pub struct {} {{\n", comp.name));
305
306    for field in &comp.all_fields {
307        let field_type = match field.field_type.as_str() {
308            "String" => "String",
309            "f64" => "f64",
310            "i64" => "i64",
311            "bool" => "bool",
312            "DateTime" => "DateTime<Utc>",
313            "Date" => "NaiveDate",
314            "Uuid" => "Uuid",
315            _ => &field.field_type,
316        };
317
318        let mut lowercased_name = to_snake_case(field.name.as_str());
319        if is_reserved_word(&lowercased_name) {
320            lowercased_name = format!("r#{lowercased_name}");
321        }
322
323        // Only add serde rename if the Rust field name differs from the original field name
324        if lowercased_name != field.name {
325            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
326        }
327
328        if field.is_required && !field.is_nullable {
329            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n"));
330        } else {
331            output.push_str(&format!(
332                "    pub {lowercased_name}: Option<{field_type}>,\n"
333            ));
334        }
335    }
336
337    output.push_str("}\n");
338    Ok(output)
339}
340
341fn generate_enum(enum_model: &EnumModel) -> Result<String> {
342    let mut output = String::new();
343
344    if let Some(description) = &enum_model.description {
345        output.push_str(&format!("/// {description}\n"));
346    } else {
347        output.push_str(&format!("/// {}\n", enum_model.name));
348    }
349
350    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
351
352    // Only add default derive if custom_attrs doesn't already contain a derive
353    if !has_custom_derive(&enum_model.custom_attrs) {
354        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
355    }
356
357    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
358
359    for (i, variant) in enum_model.variants.iter().enumerate() {
360        let original = variant.clone();
361
362        let mut chars = variant.chars();
363        let first_char = chars.next().unwrap().to_ascii_uppercase();
364        let rest: String = chars.collect();
365        let mut rust_name = format!("{first_char}{rest}");
366
367        let serde_rename = if is_reserved_word(&rust_name) {
368            rust_name.push_str("Value");
369            Some(original)
370        } else if rust_name != original {
371            Some(original)
372        } else {
373            None
374        };
375
376        if let Some(rename) = serde_rename {
377            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
378        }
379
380        if i + 1 == enum_model.variants.len() {
381            output.push_str(&format!("    {rust_name}\n"));
382        } else {
383            output.push_str(&format!("    {rust_name},\n"));
384        }
385    }
386
387    output.push_str("}\n");
388    Ok(output)
389}
390
391fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
392    let mut output = String::new();
393
394    if let Some(description) = &type_alias.description {
395        output.push_str(&format!("/// {description}\n"));
396    } else {
397        output.push_str(&format!("/// {}\n", type_alias.name));
398    }
399
400    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
401    output.push_str(&format!(
402        "pub type {} = {};\n\n",
403        type_alias.name, type_alias.target_type
404    ));
405
406    Ok(output)
407}
408
409pub fn generate_rust_code(models: &[Model]) -> Result<String> {
410    let mut code = String::new();
411
412    code.push_str("use serde::{Serialize, Deserialize};\n");
413    code.push_str("use uuid::Uuid;\n");
414    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
415
416    for model in models {
417        code.push_str(&format!("/// {}\n", model.name));
418        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
419        code.push_str(&format!("pub struct {} {{\n", model.name));
420
421        for field in &model.fields {
422            let field_type = match field.field_type.as_str() {
423                "String" => "String",
424                "f64" => "f64",
425                "i64" => "i64",
426                "bool" => "bool",
427                "DateTime" => "DateTime<Utc>",
428                "Date" => "NaiveDate",
429                "Uuid" => "Uuid",
430                _ => &field.field_type,
431            };
432
433            let mut lowercased_name = to_snake_case(field.name.as_str());
434            if is_reserved_word(&lowercased_name) {
435                lowercased_name = format!("r#{lowercased_name}")
436            }
437
438            // Only add serde rename if the Rust field name differs from the original field name
439            if lowercased_name != field.name {
440                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
441            }
442
443            if field.is_required {
444                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
445            } else {
446                code.push_str(&format!(
447                    "    pub {lowercased_name}: Option<{field_type}>,\n",
448                ));
449            }
450        }
451
452        code.push_str("}\n\n");
453    }
454
455    Ok(code)
456}
457
458pub fn generate_lib() -> Result<String> {
459    let mut code = String::new();
460    code.push_str("pub mod models;\n");
461
462    Ok(code)
463}