Skip to main content

openapi_model_generator/
generator.rs

1use std::sync::OnceLock;
2
3use crate::{
4    models::{
5        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6        UnionModel, UnionType,
7    },
8    Result,
9};
10
bitflags::bitflags! {
    /// Optional third-party imports the generated file needs.
    ///
    /// Starts out empty in `generate_models` and is filled in by the per-model
    /// generators whenever they emit a type from one of these crates.
    struct RequiredUses: u8 {
        /// A field uses `Uuid` (emits `use uuid::Uuid;`).
        const UUID = 0b00000001;
        /// A field uses `DateTime<Utc>` (emits chrono's `DateTime` and `Utc`).
        const DATETIME = 0b00000010;
        /// A field uses `NaiveDate` (emits chrono's `NaiveDate`).
        const DATE = 0b00000100;
    }

    /// Choose what type of structs you want to generate:
    ///  - Models (always generated)
    ///  - Requests (optional)
    ///  - Responses (optional)
    pub struct GenerateMode: u8 {
        /// Models are always included in the output; this flag has the value
        /// `0`, so it carries no bit of its own.
        const MODELS = 0;
        /// Additionally include request structs in the output.
        const REQUESTS = 1 << 0;
        /// Additionally include response structs in the output.
        const RESPONSES = 1 << 1;
        /// Output all possible structs: models, request and response structs.
        const ALL = Self::REQUESTS.bits() | Self::RESPONSES.bits();
    }
}
33
34impl Default for GenerateMode {
35    fn default() -> Self {
36        Self::ALL
37    }
38}
39
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the banner written at the top of every generated file.
///
/// The text is built on first use from the compile-time `CARGO_PKG_NAME` and
/// `CARGO_PKG_VERSION` variables (with fallbacks when they are absent) and
/// cached in `HDR`. Each call returns an owned copy so callers can append to it.
fn create_header() -> String {
    let cached = HDR.get_or_init(|| {
        let pkg_name = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let pkg_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!(
            "\n//!\n//! Generated from an OAS specification by {}(v{})\n//!\n\n",
            pkg_name, pkg_version
        )
    });
    cached.to_owned()
}
57
/// Rust keywords that cannot be used verbatim as a field or variant name.
///
/// Covers the strict keywords (including the 2018-edition `async`, `await` and
/// `dyn`) and the reserved keywords (`gen` is reserved from the 2024 edition).
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved keywords: unused by the language today but still rejected as identifiers.
    "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name; responses carrying it (or an empty name) are
/// skipped by `generate_response_model`.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name; requests carrying it (or an empty name) are
/// skipped by `generate_request_model`.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` if `string_to_check` collides with a Rust keyword.
///
/// The input is lowercased before the lookup, so the check is
/// case-insensitive: `"Type"` is reported as reserved just like `"type"`.
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
}
72
/// Renders a description as `///` doc-comment lines prefixed with `indent`.
///
/// Each line of `description` is trimmed and emitted as its own doc line.
/// Blank lines become a bare `///`, so no trailing whitespace is produced.
/// If `description` is `None` or contains only whitespace, `fallback_str` is
/// emitted as a single doc line instead. If the fallback is also empty,
/// nothing is emitted.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    let mut output = String::new();

    // An empty `description: ""` in the spec should not suppress the fallback.
    match description.as_deref().filter(|desc| !desc.trim().is_empty()) {
        Some(desc) => {
            for line in desc.lines().map(str::trim) {
                if line.is_empty() {
                    output.push_str(&format!("{indent}///\n"));
                } else {
                    output.push_str(&format!("{indent}/// {line}\n"));
                }
            }
        }
        None if !fallback_str.is_empty() => {
            output.push_str(&format!("{indent}/// {fallback_str}\n"));
        }
        None => {}
    }

    output
}
89
/// Converts an OpenAPI property name into a Rust `snake_case` field name.
///
/// * Every character that is not ASCII alphanumeric becomes `_`.
/// * An uppercase ASCII letter is lowercased and, unless it is the first
///   character, preceded by `_` (`userId` -> `user_id`, `ID` -> `i_d`).
/// * Runs of `_` are collapsed into a single `_`.
/// * `self`, `crate` and `super` get a trailing `_`. They cannot be written as
///   raw identifiers (`r#self`, `r#crate` and `r#super` are rejected by the
///   compiler), so the callers' `r#` escaping cannot be used for them.
/// * A leading digit is prefixed with `_` so the result is a valid identifier.
///
/// Other keywords (e.g. `type`) are returned unchanged; callers escape those
/// with `r#` after checking `is_reserved_word`.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 2);

    for (i, c) in name.chars().enumerate() {
        let c = if c.is_ascii_alphanumeric() { c } else { '_' };

        if c.is_ascii_uppercase() {
            if i != 0 && !snake.ends_with('_') {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else if c == '_' {
            // Collapse any run of separators into a single underscore.
            if !snake.ends_with('_') {
                snake.push('_');
            }
        } else {
            snake.push(c);
        }
    }

    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
125
/// Returns `true` when the user-supplied attributes (from `x-rust-attrs`)
/// already contain a `#[derive(...)]`. In that case the generator skips its
/// default derive line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .map(|attr| attr.trim())
        .any(|attr| attr.starts_with("#[derive("))
}
136
/// Returns `true` when the user-supplied attributes (from `x-rust-attrs`)
/// already contain a `#[serde(...)]`. In that case the generator skips its
/// default serde attribute.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    match custom_attrs {
        Some(attrs) => attrs
            .iter()
            .any(|attr| attr.trim().starts_with("#[serde(")),
        None => false,
    }
}
145
/// Renders the `x-rust-attrs` attributes verbatim, one per line, with no
/// indentation. Returns an empty string when there are none.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let Some(attrs) = custom_attrs else {
        return String::new();
    };

    let mut rendered = String::new();
    for attr in attrs {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
157
158pub fn generate_models(
159    models: &[ModelType],
160    requests: &[RequestModel],
161    responses: &[ResponseModel],
162    mode: GenerateMode,
163) -> Result<String> {
164    // First, generate all model code to determine which imports are needed
165    let mut models_code = String::new();
166    let mut required_uses = RequiredUses::empty();
167    let mut needs_validator = false;
168
169    for model_type in models {
170        match model_type {
171            ModelType::Struct(model) => {
172                models_code.push_str(&generate_model(
173                    model,
174                    &mut required_uses,
175                    &mut needs_validator,
176                )?);
177            }
178            ModelType::Union(union) => {
179                models_code.push_str(&generate_union(union)?);
180            }
181            ModelType::Composition(comp) => {
182                models_code.push_str(&generate_composition(comp, &mut required_uses)?);
183            }
184            ModelType::Enum(enum_model) => {
185                models_code.push_str(&generate_enum(enum_model)?);
186            }
187            ModelType::TypeAlias(type_alias) => {
188                models_code.push_str(&generate_type_alias(type_alias)?);
189            }
190        }
191    }
192
193    if mode.contains(GenerateMode::REQUESTS) {
194        for request in requests {
195            models_code.push_str(&generate_request_model(request)?);
196        }
197    }
198
199    if mode.contains(GenerateMode::RESPONSES) {
200        for response in responses {
201            models_code.push_str(&generate_response_model(response)?);
202        }
203    }
204
205    // Determine which imports are actually needed
206    let needs_uuid = required_uses.contains(RequiredUses::UUID);
207    let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
208    let needs_date = required_uses.contains(RequiredUses::DATE);
209
210    // Build final output with only necessary imports
211    let mut output = create_header();
212    output.push_str("use serde::{Serialize, Deserialize};\n");
213
214    if needs_uuid {
215        output.push_str("use uuid::Uuid;\n");
216    }
217
218    if needs_validator {
219        output.push_str("use validator::Validator;\n");
220    }
221
222    if needs_datetime || needs_date {
223        output.push_str("use chrono::{");
224        let mut chrono_imports = Vec::new();
225        if needs_datetime {
226            chrono_imports.push("DateTime");
227        }
228        if needs_date {
229            chrono_imports.push("NaiveDate");
230        }
231        if needs_datetime {
232            chrono_imports.push("Utc");
233        }
234        output.push_str(&chrono_imports.join(", "));
235        output.push_str("};\n");
236    }
237
238    output.push('\n');
239    output.push_str(&models_code);
240
241    Ok(output)
242}
243
244/// Generate validator attributes based on validation rules
245fn generate_validator_attrs(rules: &crate::models::ValidationRules, field_type: &str) -> String {
246    let mut attrs = String::new();
247
248    match field_type {
249        "String" | "str" | "Option<String>" | "Option<str>" => {
250            let mut length_attrs = Vec::new();
251            if let Some(min) = rules.min_length {
252                length_attrs.push(format!("min = {}", min));
253            }
254            if let Some(max) = rules.max_length {
255                length_attrs.push(format!("max = {}", max));
256            }
257            if !length_attrs.is_empty() {
258                attrs.push_str(&format!(
259                    "    #[validate(length({}))]\n",
260                    length_attrs.join(", ")
261                ));
262            }
263
264            if rules.email {
265                attrs.push_str("    #[validate(email)]\n");
266            }
267
268            if rules.url {
269                attrs.push_str("    #[validate(url)]\n");
270            }
271
272            if let Some(pattern) = &rules.pattern {
273                attrs.push_str(&format!("    #[regex(pattern = r\"{}\")]\n", pattern));
274            }
275        }
276        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64"
277        | "Option<i8>" | "Option<i16>" | "Option<i32>" | "Option<i64>" | "Option<u8>"
278        | "Option<u16>" | "Option<u32>" | "Option<u64>" | "Option<f32>" | "Option<f64>" => {
279            let mut range_attrs = Vec::new();
280            if let Some(min) = rules.minimum {
281                range_attrs.push(format!("min = {}", min));
282            }
283            if let Some(max) = rules.maximum {
284                range_attrs.push(format!("max = {}", max));
285            }
286            if rules.exclusive_minimum || rules.exclusive_maximum {
287                range_attrs.push("exclusive = true".to_string());
288            }
289            if !range_attrs.is_empty() {
290                attrs.push_str(&format!(
291                    "    #[validate(range({}))]\n",
292                    range_attrs.join(", ")
293                ));
294            }
295        }
296        _ if field_type.contains("Vec<") => {
297            let mut length_attrs = Vec::new();
298            if let Some(min) = rules.min_items {
299                length_attrs.push(format!("min = {}", min));
300            }
301            if let Some(max) = rules.max_items {
302                length_attrs.push(format!("max = {}", max));
303            }
304            if !length_attrs.is_empty() {
305                attrs.push_str(&format!(
306                    "    #[validate(length({}))]\n",
307                    length_attrs.join(", ")
308                ));
309            }
310        }
311        _ => {}
312    }
313
314    attrs
315}
316
317fn generate_model(
318    model: &Model,
319    required_uses: &mut RequiredUses,
320    needs_validator: &mut bool,
321) -> Result<String> {
322    let mut output = String::new();
323
324    output.push_str(&generate_description_docs(
325        &model.description,
326        &model.name,
327        "",
328    ));
329
330    output.push_str(&generate_custom_attrs(&model.custom_attrs));
331
332    // Check if any fields have validation rules
333    let has_validation = model.fields.iter().any(|f| f.validation_rules.is_some());
334
335    // Mark that we need validator import if any field has validation
336    if has_validation {
337        *needs_validator = true;
338    }
339
340    // Only add default derive if custom_attrs doesn't already contain a derive directive
341    if !has_custom_derive(&model.custom_attrs) {
342        if has_validation {
343            output.push_str("#[derive(Debug, Clone, Serialize, Deserialize, Validator)]\n");
344        } else {
345            output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
346        }
347    }
348
349    output.push_str(&format!("pub struct {} {{\n", model.name));
350
351    for field in &model.fields {
352        let field_type = match field.field_type.as_str() {
353            "DateTime" | "DateTime<Utc>" => {
354                *required_uses |= RequiredUses::DATETIME;
355                "DateTime<Utc>"
356            }
357            "Date" => {
358                *required_uses |= RequiredUses::DATE;
359                "NaiveDate"
360            }
361            "Uuid" => {
362                *required_uses |= RequiredUses::UUID;
363                "Uuid"
364            }
365            _ => &field.field_type,
366        };
367
368        let mut lowercased_name = to_snake_case(field.name.as_str());
369        if is_reserved_word(&lowercased_name) {
370            lowercased_name = format!("r#{lowercased_name}")
371        }
372
373        // Add field description if present
374        output.push_str(&generate_description_docs(&field.description, "", "    "));
375
376        // Field-level custom attributes (e.g. #[serde(rename = "...")])
377        if let Some(attrs) = &field.custom_attrs {
378            for attr in attrs {
379                output.push_str(&format!("    {attr}\n"));
380            }
381        }
382
383        // Calculate full field type before generating validator attributes
384        let is_optional = !field.is_required || field.is_nullable;
385
386        let base_type = if field.is_array_ref {
387            format!("Vec<{field_type}>")
388        } else {
389            field_type.to_string()
390        };
391
392        let full_field_type = if is_optional {
393            format!("Option<{base_type}>")
394        } else {
395            base_type
396        };
397
398        // Add validator attributes if the field has validation rules
399        if let Some(rules) = &field.validation_rules {
400            output.push_str(&generate_validator_attrs(rules, &full_field_type));
401        }
402
403        // Only add serde rename if the Rust field name differs from the original field name
404        if lowercased_name != field.name {
405            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
406        }
407
408        if field.should_flatten() {
409            output.push_str("    #[serde(flatten)]\n");
410        }
411
412        output.push_str(&format!("    pub {lowercased_name}: {full_field_type},\n"));
413    }
414
415    output.push_str("}\n\n");
416    Ok(output)
417}
418
419fn generate_request_model(request: &RequestModel) -> Result<String> {
420    let mut output = String::new();
421    tracing::info!("Generating request model");
422    tracing::info!("{:#?}", request);
423
424    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
425        return Ok(String::new());
426    }
427
428    output.push_str(&format!("/// {}\n", request.name));
429    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
430    output.push_str(&format!("pub struct {} {{\n", request.name));
431    output.push_str(&format!("    pub body: {},\n", request.schema));
432    output.push_str("}\n");
433    Ok(output)
434}
435
436fn generate_response_model(response: &ResponseModel) -> Result<String> {
437    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
438        return Ok(String::new());
439    }
440
441    let type_name = format!("{}{}", response.name, response.status_code);
442
443    let mut output = String::new();
444
445    output.push_str(&generate_description_docs(
446        &response.description,
447        &type_name,
448        "",
449    ));
450
451    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
452    output.push_str(&format!("pub struct {type_name} {{\n"));
453    output.push_str(&format!("    pub body: {},\n", response.schema));
454    output.push_str("}\n");
455
456    Ok(output)
457}
458
459fn generate_union(union: &UnionModel) -> Result<String> {
460    let mut output = String::new();
461
462    output.push_str(&format!(
463        "/// {} ({})\n",
464        union.name,
465        match union.union_type {
466            UnionType::OneOf => "oneOf",
467            UnionType::AnyOf => "anyOf",
468        }
469    ));
470    output.push_str(&generate_custom_attrs(&union.custom_attrs));
471
472    // Only add default derive if custom_attrs doesn't already contain a derive
473    if !has_custom_derive(&union.custom_attrs) {
474        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
475    }
476
477    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
478    if !has_custom_serde(&union.custom_attrs) {
479        output.push_str("#[serde(untagged)]\n");
480    }
481
482    output.push_str(&format!("pub enum {} {{\n", union.name));
483
484    for variant in &union.variants {
485        match &variant.primitive_type {
486            Some(t) => output.push_str(&format!("    {}({}),\n", variant.name, t)),
487            None => output.push_str(&format!("    {}({}),\n", variant.name, variant.name)),
488        }
489    }
490
491    output.push_str("}\n");
492    Ok(output)
493}
494
495fn generate_composition(
496    comp: &CompositionModel,
497    required_uses: &mut RequiredUses,
498) -> Result<String> {
499    let mut output = String::new();
500
501    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
502    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
503
504    // Only add default derive if custom_attrs doesn't already contain a derive
505    if !has_custom_derive(&comp.custom_attrs) {
506        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
507    }
508
509    output.push_str(&format!("pub struct {} {{\n", comp.name));
510
511    for field in &comp.all_fields {
512        let field_type = match field.field_type.as_str() {
513            "String" => "String",
514            "f64" => "f64",
515            "i64" => "i64",
516            "bool" => "bool",
517            "DateTime" => {
518                *required_uses |= RequiredUses::DATETIME;
519                "DateTime<Utc>"
520            }
521            "Date" => {
522                *required_uses |= RequiredUses::DATE;
523                "NaiveDate"
524            }
525            "Uuid" => {
526                *required_uses |= RequiredUses::UUID;
527                "Uuid"
528            }
529            _ => &field.field_type,
530        };
531
532        let mut lowercased_name = to_snake_case(field.name.as_str());
533        if is_reserved_word(&lowercased_name) {
534            lowercased_name = format!("r#{lowercased_name}");
535        }
536
537        // Only add serde rename if the Rust field name differs from the original field name
538        if lowercased_name != field.name {
539            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
540        }
541
542        // Field-level custom attributes (e.g. #[serde(rename = "...")])
543        if let Some(attrs) = &field.custom_attrs {
544            for attr in attrs {
545                output.push_str(&format!("    {attr}\n"));
546            }
547        }
548
549        // If field references an array, wrap it in Vec<>
550        if field.is_array_ref {
551            if field.is_required && !field.is_nullable {
552                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
553            } else {
554                output.push_str(&format!(
555                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
556                ));
557            }
558        } else if field.is_required && !field.is_nullable {
559            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
560        } else {
561            output.push_str(&format!(
562                "    pub {lowercased_name}: Option<{field_type}>,\n",
563            ));
564        }
565    }
566
567    output.push_str("}\n");
568    Ok(output)
569}
570
571fn generate_enum(enum_model: &EnumModel) -> Result<String> {
572    let mut output = String::new();
573
574    output.push_str(&generate_description_docs(
575        &enum_model.description,
576        &enum_model.name,
577        "",
578    ));
579
580    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
581
582    // Only add default derive if custom_attrs doesn't already contain a derive
583    if !has_custom_derive(&enum_model.custom_attrs) {
584        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
585    }
586
587    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
588
589    for (i, variant) in enum_model.variants.iter().enumerate() {
590        let original = variant.clone();
591
592        let mut rust_name = crate::parser::to_pascal_case(variant);
593
594        let serde_rename = if is_reserved_word(&rust_name) {
595            rust_name.push_str("Value");
596            Some(original)
597        } else if rust_name != original {
598            Some(original)
599        } else {
600            None
601        };
602
603        if let Some(rename) = serde_rename {
604            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
605        }
606
607        if i + 1 == enum_model.variants.len() {
608            output.push_str(&format!("    {rust_name}\n"));
609        } else {
610            output.push_str(&format!("    {rust_name},\n"));
611        }
612    }
613
614    output.push_str("}\n");
615    Ok(output)
616}
617
618fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
619    let mut output = String::new();
620
621    output.push_str(&generate_description_docs(
622        &type_alias.description,
623        &type_alias.name,
624        "",
625    ));
626
627    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
628    output.push_str(&format!(
629        "pub type {} = {};\n\n",
630        type_alias.name, type_alias.target_type
631    ));
632
633    Ok(output)
634}
635
636pub fn generate_rust_code(models: &[Model]) -> Result<String> {
637    let mut code = create_header();
638
639    code.push_str("use serde::{Serialize, Deserialize};\n");
640    code.push_str("use uuid::Uuid;\n");
641    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
642
643    for model in models {
644        code.push_str(&format!("/// {}\n", model.name));
645        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
646        code.push_str(&format!("pub struct {} {{\n", model.name));
647
648        for field in &model.fields {
649            let field_type = match field.field_type.as_str() {
650                "String" => "String",
651                "f64" => "f64",
652                "i64" => "i64",
653                "bool" => "bool",
654                "DateTime" => "DateTime<Utc>",
655                "Date" => "NaiveDate",
656                "Uuid" => "Uuid",
657                _ => &field.field_type,
658            };
659
660            let mut lowercased_name = to_snake_case(field.name.as_str());
661            if is_reserved_word(&lowercased_name) {
662                lowercased_name = format!("r#{lowercased_name}")
663            }
664
665            // Only add serde rename if the Rust field name differs from the original field name
666            if lowercased_name != field.name {
667                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
668            }
669
670            if field.is_required {
671                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
672            } else {
673                code.push_str(&format!(
674                    "    pub {lowercased_name}: Option<{field_type}>,\n",
675                ));
676            }
677        }
678
679        code.push_str("}\n\n");
680    }
681
682    Ok(code)
683}
684
685pub fn generate_lib() -> Result<String> {
686    let mut code = create_header();
687    code.push_str("pub mod models;\n");
688
689    Ok(code)
690}