Skip to main content

openapi_model_generator/
generator.rs

1use std::sync::OnceLock;
2
3use crate::{
4    models::{
5        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6        UnionModel, UnionType,
7    },
8    Result,
9};
10
/// Cache for the `//!` banner written at the top of every generated file.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the generator banner (crate name and version) as an owned string.
///
/// The banner is formatted once on first use and cached in [`HDR`]; each call
/// hands back a fresh copy so callers can keep appending to it.
fn create_header() -> String {
    let banner = HDR.get_or_init(|| {
        let crate_name = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let crate_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!(
            "\n//!\n//! Generated from an OAS specification by {crate_name}(v{crate_version})\n//!\n\n"
        )
    });
    banner.to_owned()
}
28
/// Words that cannot be used verbatim as a Rust field or variant name.
///
/// Contains the strict keywords (including the 2018-edition `async`, `await`
/// and `dyn`) and the words reserved for future use (including `try` and the
/// 2024-edition `gen`). A field named e.g. `async` would otherwise be emitted
/// as `pub async: T`, which does not compile.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name; responses carrying it produce no generated code.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name; requests carrying it produce no generated code.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` when `string_to_check`, lowercased, is in [`RUST_RESERVED_KEYWORDS`].
///
/// The comparison is case-insensitive, so PascalCase names such as `Type` match too.
fn is_reserved_word(string_to_check: &str) -> bool {
    let lowered = string_to_check.to_lowercase();
    RUST_RESERVED_KEYWORDS.contains(&lowered.as_str())
}
43
/// Renders `description` as `///` doc-comment lines, each prefixed by `indent`.
///
/// When `description` is `None`, empty, or whitespace-only, `fallback_str` is
/// used instead. If that is empty as well, nothing is emitted. Every line is
/// trimmed. Interior blank lines become a bare `///` with no trailing space.
/// Leading and trailing blank lines are dropped.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    // A present-but-blank description previously produced no docs at all;
    // treat it like a missing one so the fallback still applies.
    let text = match description.as_deref().map(str::trim) {
        Some(desc) if !desc.is_empty() => desc,
        _ => fallback_str,
    };

    let mut output = String::new();
    for line in text.lines() {
        let line = line.trim();
        if line.is_empty() {
            output.push_str(&format!("{indent}///\n"));
        } else {
            output.push_str(&format!("{indent}/// {line}\n"));
        }
    }

    output
}
60
/// Converts an arbitrary schema property name into a snake_case Rust identifier.
///
/// - Every non-ASCII-alphanumeric character becomes `_`.
/// - An uppercase letter (other than the first character) starts a new `_` segment.
/// - Runs of `_` collapse to a single `_`.
/// - `self`, `super` and `crate` get a trailing `_`, because they cannot be written
///   as raw identifiers (`r#self` etc. are rejected by the compiler).
/// - A name with no usable characters becomes `unnamed`.
/// - A leading digit gets an `_` prefix.
///
/// Callers still apply `r#` for other keywords.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);

    for (i, c) in name.chars().enumerate() {
        let c = if c.is_ascii_alphanumeric() { c } else { '_' };

        if c.is_ascii_uppercase() {
            if i != 0 && !snake.ends_with('_') {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else if c == '_' {
            // Collapse every run of separators, not just pairs.
            if !snake.ends_with('_') {
                snake.push('_');
            }
        } else {
            snake.push(c);
        }
    }

    if matches!(snake.as_str(), "self" | "super" | "crate") {
        snake.push('_');
    }

    // `pub _: T` and `pub : T` are not valid field declarations.
    if snake.is_empty() || snake == "_" {
        return String::from("unnamed");
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
96
/// Reports whether any user-supplied attribute (after trimming) is a `#[derive(...)]`.
///
/// Generators use this to decide whether to emit their own default derive line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .map(|attr| attr.trim())
        .any(|attr| attr.starts_with("#[derive("))
}
107
/// Reports whether any user-supplied attribute (after trimming) is a `#[serde(...)]`.
///
/// Unions use this to decide whether to emit their default `#[serde(untagged)]`.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    match custom_attrs {
        Some(attrs) => attrs
            .iter()
            .any(|attr| attr.trim().starts_with("#[serde(")),
        None => false,
    }
}
116
/// Emits the `x-rust-attrs` entries verbatim, one per line.
///
/// Returns an empty string when no custom attributes are present.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let Some(attrs) = custom_attrs else {
        return String::new();
    };

    let mut rendered = String::new();
    for attr in attrs {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
128
129pub fn generate_models(
130    models: &[ModelType],
131    requests: &[RequestModel],
132    responses: &[ResponseModel],
133) -> Result<String> {
134    // First, generate all model code to determine which imports are needed
135    let mut models_code = String::new();
136
137    for model_type in models {
138        match model_type {
139            ModelType::Struct(model) => {
140                models_code.push_str(&generate_model(model)?);
141            }
142            ModelType::Union(union) => {
143                models_code.push_str(&generate_union(union)?);
144            }
145            ModelType::Composition(comp) => {
146                models_code.push_str(&generate_composition(comp)?);
147            }
148            ModelType::Enum(enum_model) => {
149                models_code.push_str(&generate_enum(enum_model)?);
150            }
151            ModelType::TypeAlias(type_alias) => {
152                models_code.push_str(&generate_type_alias(type_alias)?);
153            }
154        }
155    }
156
157    for request in requests {
158        models_code.push_str(&generate_request_model(request)?);
159    }
160
161    for response in responses {
162        models_code.push_str(&generate_response_model(response)?);
163    }
164
165    // Determine which imports are actually needed
166    let needs_uuid = models_code.contains("Uuid");
167    let needs_datetime = models_code.contains("DateTime<Utc>");
168    let needs_date = models_code.contains("NaiveDate");
169
170    // Build final output with only necessary imports
171    let mut output = create_header();
172    output.push_str("use serde::{Serialize, Deserialize};\n");
173
174    if needs_uuid {
175        output.push_str("use uuid::Uuid;\n");
176    }
177
178    if needs_datetime || needs_date {
179        output.push_str("use chrono::{");
180        let mut chrono_imports = Vec::new();
181        if needs_datetime {
182            chrono_imports.push("DateTime");
183        }
184        if needs_date {
185            chrono_imports.push("NaiveDate");
186        }
187        if needs_datetime {
188            chrono_imports.push("Utc");
189        }
190        output.push_str(&chrono_imports.join(", "));
191        output.push_str("};\n");
192    }
193
194    output.push('\n');
195    output.push_str(&models_code);
196
197    Ok(output)
198}
199
200fn generate_model(model: &Model) -> Result<String> {
201    let mut output = String::new();
202
203    output.push_str(&generate_description_docs(
204        &model.description,
205        &model.name,
206        "",
207    ));
208
209    output.push_str(&generate_custom_attrs(&model.custom_attrs));
210
211    // Only add default derive if custom_attrs doesn't already contain a derive directive
212    if !has_custom_derive(&model.custom_attrs) {
213        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
214    }
215
216    output.push_str(&format!("pub struct {} {{\n", model.name));
217
218    for field in &model.fields {
219        let field_type = match field.field_type.as_str() {
220            "String" => "String",
221            "f64" => "f64",
222            "i64" => "i64",
223            "bool" => "bool",
224            "DateTime" => "DateTime<Utc>",
225            "Date" => "NaiveDate",
226            "Uuid" => "Uuid",
227            _ => &field.field_type,
228        };
229
230        let mut lowercased_name = to_snake_case(field.name.as_str());
231        if is_reserved_word(&lowercased_name) {
232            lowercased_name = format!("r#{lowercased_name}")
233        }
234
235        // Add field description if present
236        output.push_str(&generate_description_docs(&field.description, "", "    "));
237
238        // Only add serde rename if the Rust field name differs from the original field name
239        if lowercased_name != field.name {
240            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
241        }
242
243        if field.should_flatten() {
244            output.push_str("    #[serde(flatten)]\n");
245        }
246
247        // If field references an array, wrap it in Vec<>
248        if field.is_array_ref {
249            if field.is_required && !field.is_nullable {
250                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
251            } else {
252                output.push_str(&format!(
253                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
254                ));
255            }
256        } else if field.is_required && !field.is_nullable {
257            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
258        } else {
259            output.push_str(&format!(
260                "    pub {lowercased_name}: Option<{field_type}>,\n",
261            ));
262        }
263    }
264
265    output.push_str("}\n\n");
266    Ok(output)
267}
268
269fn generate_request_model(request: &RequestModel) -> Result<String> {
270    let mut output = String::new();
271    tracing::info!("Generating request model");
272    tracing::info!("{:#?}", request);
273
274    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
275        return Ok(String::new());
276    }
277
278    output.push_str(&format!("/// {}\n", request.name));
279    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
280    output.push_str(&format!("pub struct {} {{\n", request.name));
281    output.push_str(&format!("    pub body: {},\n", request.schema));
282    output.push_str("}\n");
283    Ok(output)
284}
285
286fn generate_response_model(response: &ResponseModel) -> Result<String> {
287    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
288        return Ok(String::new());
289    }
290
291    let type_name = format!("{}{}", response.name, response.status_code);
292
293    let mut output = String::new();
294
295    output.push_str(&generate_description_docs(
296        &response.description,
297        &type_name,
298        "",
299    ));
300
301    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
302    output.push_str(&format!("pub struct {type_name} {{\n"));
303    output.push_str(&format!("    pub body: {},\n", response.schema));
304    output.push_str("}\n");
305
306    Ok(output)
307}
308
309fn generate_union(union: &UnionModel) -> Result<String> {
310    let mut output = String::new();
311
312    output.push_str(&format!(
313        "/// {} ({})\n",
314        union.name,
315        match union.union_type {
316            UnionType::OneOf => "oneOf",
317            UnionType::AnyOf => "anyOf",
318        }
319    ));
320    output.push_str(&generate_custom_attrs(&union.custom_attrs));
321
322    // Only add default derive if custom_attrs doesn't already contain a derive
323    if !has_custom_derive(&union.custom_attrs) {
324        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
325    }
326
327    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
328    if !has_custom_serde(&union.custom_attrs) {
329        output.push_str("#[serde(untagged)]\n");
330    }
331
332    output.push_str(&format!("pub enum {} {{\n", union.name));
333
334    for variant in &union.variants {
335        match &variant.primitive_type {
336            Some(t) => output.push_str(&format!("    {}({}),\n", variant.name, t)),
337            None => output.push_str(&format!("    {}({}),\n", variant.name, variant.name)),
338        }
339    }
340
341    output.push_str("}\n");
342    Ok(output)
343}
344
345fn generate_composition(comp: &CompositionModel) -> Result<String> {
346    let mut output = String::new();
347
348    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
349    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
350
351    // Only add default derive if custom_attrs doesn't already contain a derive
352    if !has_custom_derive(&comp.custom_attrs) {
353        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
354    }
355
356    output.push_str(&format!("pub struct {} {{\n", comp.name));
357
358    for field in &comp.all_fields {
359        let field_type = match field.field_type.as_str() {
360            "String" => "String",
361            "f64" => "f64",
362            "i64" => "i64",
363            "bool" => "bool",
364            "DateTime" => "DateTime<Utc>",
365            "Date" => "NaiveDate",
366            "Uuid" => "Uuid",
367            _ => &field.field_type,
368        };
369
370        let mut lowercased_name = to_snake_case(field.name.as_str());
371        if is_reserved_word(&lowercased_name) {
372            lowercased_name = format!("r#{lowercased_name}");
373        }
374
375        // Only add serde rename if the Rust field name differs from the original field name
376        if lowercased_name != field.name {
377            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
378        }
379
380        // If field references an array, wrap it in Vec<>
381        if field.is_array_ref {
382            if field.is_required && !field.is_nullable {
383                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
384            } else {
385                output.push_str(&format!(
386                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
387                ));
388            }
389        } else if field.is_required && !field.is_nullable {
390            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
391        } else {
392            output.push_str(&format!(
393                "    pub {lowercased_name}: Option<{field_type}>,\n",
394            ));
395        }
396    }
397
398    output.push_str("}\n");
399    Ok(output)
400}
401
402fn generate_enum(enum_model: &EnumModel) -> Result<String> {
403    let mut output = String::new();
404
405    output.push_str(&generate_description_docs(
406        &enum_model.description,
407        &enum_model.name,
408        "",
409    ));
410
411    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
412
413    // Only add default derive if custom_attrs doesn't already contain a derive
414    if !has_custom_derive(&enum_model.custom_attrs) {
415        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
416    }
417
418    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
419
420    for (i, variant) in enum_model.variants.iter().enumerate() {
421        let original = variant.clone();
422
423        let mut rust_name = crate::parser::to_pascal_case(variant);
424
425        let serde_rename = if is_reserved_word(&rust_name) {
426            rust_name.push_str("Value");
427            Some(original)
428        } else if rust_name != original {
429            Some(original)
430        } else {
431            None
432        };
433
434        if let Some(rename) = serde_rename {
435            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
436        }
437
438        if i + 1 == enum_model.variants.len() {
439            output.push_str(&format!("    {rust_name}\n"));
440        } else {
441            output.push_str(&format!("    {rust_name},\n"));
442        }
443    }
444
445    output.push_str("}\n");
446    Ok(output)
447}
448
449fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
450    let mut output = String::new();
451
452    output.push_str(&generate_description_docs(
453        &type_alias.description,
454        &type_alias.name,
455        "",
456    ));
457
458    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
459    output.push_str(&format!(
460        "pub type {} = {};\n\n",
461        type_alias.name, type_alias.target_type
462    ));
463
464    Ok(output)
465}
466
467pub fn generate_rust_code(models: &[Model]) -> Result<String> {
468    let mut code = create_header();
469
470    code.push_str("use serde::{Serialize, Deserialize};\n");
471    code.push_str("use uuid::Uuid;\n");
472    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
473
474    for model in models {
475        code.push_str(&format!("/// {}\n", model.name));
476        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
477        code.push_str(&format!("pub struct {} {{\n", model.name));
478
479        for field in &model.fields {
480            let field_type = match field.field_type.as_str() {
481                "String" => "String",
482                "f64" => "f64",
483                "i64" => "i64",
484                "bool" => "bool",
485                "DateTime" => "DateTime<Utc>",
486                "Date" => "NaiveDate",
487                "Uuid" => "Uuid",
488                _ => &field.field_type,
489            };
490
491            let mut lowercased_name = to_snake_case(field.name.as_str());
492            if is_reserved_word(&lowercased_name) {
493                lowercased_name = format!("r#{lowercased_name}")
494            }
495
496            // Only add serde rename if the Rust field name differs from the original field name
497            if lowercased_name != field.name {
498                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
499            }
500
501            if field.is_required {
502                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
503            } else {
504                code.push_str(&format!(
505                    "    pub {lowercased_name}: Option<{field_type}>,\n",
506                ));
507            }
508        }
509
510        code.push_str("}\n\n");
511    }
512
513    Ok(code)
514}
515
516pub fn generate_lib() -> Result<String> {
517    let mut code = create_header();
518    code.push_str("pub mod models;\n");
519
520    Ok(code)
521}