Skip to main content

openapi_model_generator/
generator.rs

1use std::sync::OnceLock;
2
3use crate::{
4    models::{
5        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6        UnionModel, UnionType,
7    },
8    Result,
9};
10
11bitflags::bitflags! {
12    struct RequiredUses: u8 {
13        const UUID = 0b00000001;
14        const DATETIME = 0b00000010;
15        const DATE = 0b00000100;
16    }
17}
18
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the banner text that opens every generated file.
///
/// The tool name and version come from `option_env!` at compile time, with fallbacks
/// when Cargo did not provide them. The banner is built once, cached in `HDR`, and each
/// call returns an owned copy.
fn create_header() -> String {
    let cached = HDR.get_or_init(|| {
        let tool = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!("\n//!\n//! Generated from an OAS specification by {tool}(v{version})\n//!\n\n")
    });
    cached.to_owned()
}
36
/// Words that cannot be used verbatim as Rust identifiers.
///
/// This covers strict keywords (including the 2018-edition `async`, `await` and `dyn`)
/// and reserved words. Matching is done case-insensitively by [`is_reserved_word`].
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
];

/// Sentinel response name; responses carrying it are skipped during generation.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Sentinel request name; request models carrying it are skipped during generation.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` if `string_to_check` equals one of [`RUST_RESERVED_KEYWORDS`],
/// ignoring ASCII case.
///
/// Because case is ignored, PascalCase names such as `Type` or `Match` also count as
/// reserved. The comparison is done without allocating a lowercased copy.
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(string_to_check))
}
51
/// Renders `description` as `///` doc-comment lines, each prefixed with `indent`.
///
/// The description is trimmed as a whole, and each of its lines is trimmed. Interior
/// blank lines become a bare `///`, so the generated code has no trailing whitespace.
/// When `description` is `None` or whitespace-only, a single `/// {fallback_str}` line
/// is emitted instead. If `fallback_str` is also empty, nothing is emitted.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    let mut output = String::new();

    // A whitespace-only description would otherwise yield only empty doc lines and
    // hide the fallback text.
    match description.as_deref().filter(|desc| !desc.trim().is_empty()) {
        Some(desc) => {
            for line in desc.trim().lines() {
                let line = line.trim();
                output.push_str(indent);
                if line.is_empty() {
                    output.push_str("///\n");
                } else {
                    output.push_str("/// ");
                    output.push_str(line);
                    output.push('\n');
                }
            }
        }
        None if !fallback_str.is_empty() => {
            output.push_str(indent);
            output.push_str("/// ");
            output.push_str(fallback_str);
            output.push('\n');
        }
        None => {}
    }

    output
}
68
/// Converts an OpenAPI property name into a snake_case Rust field identifier.
///
/// * ASCII uppercase letters are lowercased and, except at the very start, preceded by `_`.
/// * Every character that is not ASCII alphanumeric becomes `_`.
/// * Runs of underscores collapse into a single `_`. A doubled `__` would trigger the
///   `non_snake_case` lint in the generated code.
/// * `self`, `crate` and `super` get a trailing `_`. Unlike other keywords they cannot be
///   written as raw identifiers (`r#crate` is rejected by the compiler).
/// * A leading digit is prefixed with `_` so the result is a valid identifier.
///
/// Other keywords (e.g. `type`) are returned unchanged; callers escape those with `r#`.
fn to_snake_case(name: &str) -> String {
    // Appends a `_` unless the output already ends with one, which keeps underscore
    // runs collapsed (e.g. `a___b` and `foo_Bar` each yield a single separator).
    fn push_separator(out: &mut String) {
        if !out.ends_with('_') {
            out.push('_');
        }
    }

    let mut snake = String::with_capacity(name.len() + 4);
    for (i, c) in name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i != 0 {
                push_separator(&mut snake);
            }
            snake.push(c.to_ascii_lowercase());
        } else if c.is_ascii_alphanumeric() {
            snake.push(c);
        } else {
            push_separator(&mut snake);
        }
    }

    // These path keywords are not permitted as raw identifiers, so they must be renamed.
    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
104
/// Returns `true` when any `x-rust-attrs` entry, ignoring surrounding whitespace, is a
/// `#[derive(...)]` attribute. Callers use this to skip their default derive line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.trim().starts_with("#[derive("))
}
115
/// Returns `true` when any `x-rust-attrs` entry, ignoring surrounding whitespace, is a
/// `#[serde(...)]` attribute. Callers use this to skip their default serde attribute.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    match custom_attrs {
        Some(attrs) => attrs
            .iter()
            .map(|attr| attr.trim())
            .any(|attr| attr.starts_with("#[serde(")),
        None => false,
    }
}
124
/// Renders the `x-rust-attrs` entries verbatim, one per line, with no trimming.
/// Returns an empty string when there are none.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
136
137pub fn generate_models(
138    models: &[ModelType],
139    requests: &[RequestModel],
140    responses: &[ResponseModel],
141) -> Result<String> {
142    // First, generate all model code to determine which imports are needed
143    let mut models_code = String::new();
144    let mut required_uses = RequiredUses::empty();
145
146    for model_type in models {
147        match model_type {
148            ModelType::Struct(model) => {
149                models_code.push_str(&generate_model(model, &mut required_uses)?);
150            }
151            ModelType::Union(union) => {
152                models_code.push_str(&generate_union(union)?);
153            }
154            ModelType::Composition(comp) => {
155                models_code.push_str(&generate_composition(comp, &mut required_uses)?);
156            }
157            ModelType::Enum(enum_model) => {
158                models_code.push_str(&generate_enum(enum_model)?);
159            }
160            ModelType::TypeAlias(type_alias) => {
161                models_code.push_str(&generate_type_alias(type_alias)?);
162            }
163        }
164    }
165
166    for request in requests {
167        models_code.push_str(&generate_request_model(request)?);
168    }
169
170    for response in responses {
171        models_code.push_str(&generate_response_model(response)?);
172    }
173
174    // Determine which imports are actually needed
175    let needs_uuid = required_uses.contains(RequiredUses::UUID);
176    let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
177    let needs_date = required_uses.contains(RequiredUses::DATE);
178
179    // Build final output with only necessary imports
180    let mut output = create_header();
181    output.push_str("use serde::{Serialize, Deserialize};\n");
182
183    if needs_uuid {
184        output.push_str("use uuid::Uuid;\n");
185    }
186
187    if needs_datetime || needs_date {
188        output.push_str("use chrono::{");
189        let mut chrono_imports = Vec::new();
190        if needs_datetime {
191            chrono_imports.push("DateTime");
192        }
193        if needs_date {
194            chrono_imports.push("NaiveDate");
195        }
196        if needs_datetime {
197            chrono_imports.push("Utc");
198        }
199        output.push_str(&chrono_imports.join(", "));
200        output.push_str("};\n");
201    }
202
203    output.push('\n');
204    output.push_str(&models_code);
205
206    Ok(output)
207}
208
209fn generate_model(model: &Model, required_uses: &mut RequiredUses) -> Result<String> {
210    let mut output = String::new();
211
212    output.push_str(&generate_description_docs(
213        &model.description,
214        &model.name,
215        "",
216    ));
217
218    output.push_str(&generate_custom_attrs(&model.custom_attrs));
219
220    // Only add default derive if custom_attrs doesn't already contain a derive directive
221    if !has_custom_derive(&model.custom_attrs) {
222        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
223    }
224
225    output.push_str(&format!("pub struct {} {{\n", model.name));
226
227    for field in &model.fields {
228        let field_type = match field.field_type.as_str() {
229            "String" => "String",
230            "f64" => "f64",
231            "i64" => "i64",
232            "bool" => "bool",
233            "DateTime" => {
234                *required_uses |= RequiredUses::DATETIME;
235                "DateTime<Utc>"
236            }
237            "Date" => {
238                *required_uses |= RequiredUses::DATE;
239                "NaiveDate"
240            }
241            "Uuid" => {
242                *required_uses |= RequiredUses::UUID;
243                "Uuid"
244            }
245            _ => &field.field_type,
246        };
247
248        let mut lowercased_name = to_snake_case(field.name.as_str());
249        if is_reserved_word(&lowercased_name) {
250            lowercased_name = format!("r#{lowercased_name}")
251        }
252
253        // Add field description if present
254        output.push_str(&generate_description_docs(&field.description, "", "    "));
255
256        // Only add serde rename if the Rust field name differs from the original field name
257        if lowercased_name != field.name {
258            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
259        }
260
261        if field.should_flatten() {
262            output.push_str("    #[serde(flatten)]\n");
263        }
264
265        // If field references an array, wrap it in Vec<>
266        if field.is_array_ref {
267            if field.is_required && !field.is_nullable {
268                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
269            } else {
270                output.push_str(&format!(
271                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
272                ));
273            }
274        } else if field.is_required && !field.is_nullable {
275            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
276        } else {
277            output.push_str(&format!(
278                "    pub {lowercased_name}: Option<{field_type}>,\n",
279            ));
280        }
281    }
282
283    output.push_str("}\n\n");
284    Ok(output)
285}
286
287fn generate_request_model(request: &RequestModel) -> Result<String> {
288    let mut output = String::new();
289    tracing::info!("Generating request model");
290    tracing::info!("{:#?}", request);
291
292    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
293        return Ok(String::new());
294    }
295
296    output.push_str(&format!("/// {}\n", request.name));
297    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
298    output.push_str(&format!("pub struct {} {{\n", request.name));
299    output.push_str(&format!("    pub body: {},\n", request.schema));
300    output.push_str("}\n");
301    Ok(output)
302}
303
304fn generate_response_model(response: &ResponseModel) -> Result<String> {
305    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
306        return Ok(String::new());
307    }
308
309    let type_name = format!("{}{}", response.name, response.status_code);
310
311    let mut output = String::new();
312
313    output.push_str(&generate_description_docs(
314        &response.description,
315        &type_name,
316        "",
317    ));
318
319    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
320    output.push_str(&format!("pub struct {type_name} {{\n"));
321    output.push_str(&format!("    pub body: {},\n", response.schema));
322    output.push_str("}\n");
323
324    Ok(output)
325}
326
327fn generate_union(union: &UnionModel) -> Result<String> {
328    let mut output = String::new();
329
330    output.push_str(&format!(
331        "/// {} ({})\n",
332        union.name,
333        match union.union_type {
334            UnionType::OneOf => "oneOf",
335            UnionType::AnyOf => "anyOf",
336        }
337    ));
338    output.push_str(&generate_custom_attrs(&union.custom_attrs));
339
340    // Only add default derive if custom_attrs doesn't already contain a derive
341    if !has_custom_derive(&union.custom_attrs) {
342        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
343    }
344
345    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
346    if !has_custom_serde(&union.custom_attrs) {
347        output.push_str("#[serde(untagged)]\n");
348    }
349
350    output.push_str(&format!("pub enum {} {{\n", union.name));
351
352    for variant in &union.variants {
353        match &variant.primitive_type {
354            Some(t) => output.push_str(&format!("    {}({}),\n", variant.name, t)),
355            None => output.push_str(&format!("    {}({}),\n", variant.name, variant.name)),
356        }
357    }
358
359    output.push_str("}\n");
360    Ok(output)
361}
362
363fn generate_composition(
364    comp: &CompositionModel,
365    required_uses: &mut RequiredUses,
366) -> Result<String> {
367    let mut output = String::new();
368
369    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
370    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
371
372    // Only add default derive if custom_attrs doesn't already contain a derive
373    if !has_custom_derive(&comp.custom_attrs) {
374        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
375    }
376
377    output.push_str(&format!("pub struct {} {{\n", comp.name));
378
379    for field in &comp.all_fields {
380        let field_type = match field.field_type.as_str() {
381            "String" => "String",
382            "f64" => "f64",
383            "i64" => "i64",
384            "bool" => "bool",
385            "DateTime" => {
386                *required_uses |= RequiredUses::DATETIME;
387                "DateTime<Utc>"
388            }
389            "Date" => {
390                *required_uses |= RequiredUses::DATE;
391                "NaiveDate"
392            }
393            "Uuid" => {
394                *required_uses |= RequiredUses::UUID;
395                "Uuid"
396            }
397            _ => &field.field_type,
398        };
399
400        let mut lowercased_name = to_snake_case(field.name.as_str());
401        if is_reserved_word(&lowercased_name) {
402            lowercased_name = format!("r#{lowercased_name}");
403        }
404
405        // Only add serde rename if the Rust field name differs from the original field name
406        if lowercased_name != field.name {
407            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
408        }
409
410        // If field references an array, wrap it in Vec<>
411        if field.is_array_ref {
412            if field.is_required && !field.is_nullable {
413                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
414            } else {
415                output.push_str(&format!(
416                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
417                ));
418            }
419        } else if field.is_required && !field.is_nullable {
420            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
421        } else {
422            output.push_str(&format!(
423                "    pub {lowercased_name}: Option<{field_type}>,\n",
424            ));
425        }
426    }
427
428    output.push_str("}\n");
429    Ok(output)
430}
431
432fn generate_enum(enum_model: &EnumModel) -> Result<String> {
433    let mut output = String::new();
434
435    output.push_str(&generate_description_docs(
436        &enum_model.description,
437        &enum_model.name,
438        "",
439    ));
440
441    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
442
443    // Only add default derive if custom_attrs doesn't already contain a derive
444    if !has_custom_derive(&enum_model.custom_attrs) {
445        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
446    }
447
448    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
449
450    for (i, variant) in enum_model.variants.iter().enumerate() {
451        let original = variant.clone();
452
453        let mut rust_name = crate::parser::to_pascal_case(variant);
454
455        let serde_rename = if is_reserved_word(&rust_name) {
456            rust_name.push_str("Value");
457            Some(original)
458        } else if rust_name != original {
459            Some(original)
460        } else {
461            None
462        };
463
464        if let Some(rename) = serde_rename {
465            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
466        }
467
468        if i + 1 == enum_model.variants.len() {
469            output.push_str(&format!("    {rust_name}\n"));
470        } else {
471            output.push_str(&format!("    {rust_name},\n"));
472        }
473    }
474
475    output.push_str("}\n");
476    Ok(output)
477}
478
479fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
480    let mut output = String::new();
481
482    output.push_str(&generate_description_docs(
483        &type_alias.description,
484        &type_alias.name,
485        "",
486    ));
487
488    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
489    output.push_str(&format!(
490        "pub type {} = {};\n\n",
491        type_alias.name, type_alias.target_type
492    ));
493
494    Ok(output)
495}
496
497pub fn generate_rust_code(models: &[Model]) -> Result<String> {
498    let mut code = create_header();
499
500    code.push_str("use serde::{Serialize, Deserialize};\n");
501    code.push_str("use uuid::Uuid;\n");
502    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
503
504    for model in models {
505        code.push_str(&format!("/// {}\n", model.name));
506        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
507        code.push_str(&format!("pub struct {} {{\n", model.name));
508
509        for field in &model.fields {
510            let field_type = match field.field_type.as_str() {
511                "String" => "String",
512                "f64" => "f64",
513                "i64" => "i64",
514                "bool" => "bool",
515                "DateTime" => "DateTime<Utc>",
516                "Date" => "NaiveDate",
517                "Uuid" => "Uuid",
518                _ => &field.field_type,
519            };
520
521            let mut lowercased_name = to_snake_case(field.name.as_str());
522            if is_reserved_word(&lowercased_name) {
523                lowercased_name = format!("r#{lowercased_name}")
524            }
525
526            // Only add serde rename if the Rust field name differs from the original field name
527            if lowercased_name != field.name {
528                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
529            }
530
531            if field.is_required {
532                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
533            } else {
534                code.push_str(&format!(
535                    "    pub {lowercased_name}: Option<{field_type}>,\n",
536                ));
537            }
538        }
539
540        code.push_str("}\n\n");
541    }
542
543    Ok(code)
544}
545
546pub fn generate_lib() -> Result<String> {
547    let mut code = create_header();
548    code.push_str("pub mod models;\n");
549
550    Ok(code)
551}