// openapi_model_generator/generator.rs

1use std::sync::OnceLock;
2
3use crate::{
4    models::{
5        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6        UnionModel, UnionType,
7    },
8    Result,
9};
10
11bitflags::bitflags! {
12    struct RequiredUses: u8 {
13        const UUID = 0b00000001;
14        const DATETIME = 0b00000010;
15        const DATE = 0b00000100;
16    }
17
18    /// Choose what type of structs you want to generate:
19    ///  - Models (generated always)
20    ///  - Requests (optional)
21    ///  - Responses (optional)
22    pub struct GenerateMode: u8 {
23        /// Models will be always include to output
24        const MODELS = 0;
25        /// Additional includes request structs to output
26        const REQUESTS = 1 << 0;
27        /// Additional includes response structs to output
28        const RESPONSES = 1 << 1;
29        /// Outputs all possible structs: models, request and response structs
30        const ALL = Self::REQUESTS.bits() | Self::RESPONSES.bits();
31    }
32}
33
34impl Default for GenerateMode {
35    fn default() -> Self {
36        Self::ALL
37    }
38}
39
/// Lazily built banner shared by every generated file.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the `//!` banner placed at the top of generated files. It names the
/// generator crate and its version, falling back to placeholder values when the
/// Cargo environment variables were not set at compile time.
fn create_header() -> String {
    let banner = HDR.get_or_init(|| {
        let pkg_name = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let pkg_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!(
            r#"
//!
//! Generated from an OAS specification by {pkg_name}(v{pkg_version})
//!

"#
        )
    });
    banner.to_owned()
}
57
/// Words that cannot be used verbatim as Rust identifiers: strict keywords
/// (including the 2018-edition `async`, `await` and `dyn`) and keywords
/// reserved for future use.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "abstract", "become", "box", "do", "final", "gen",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name; no response struct is emitted for it.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name; no request struct is emitted for it.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` if `string_to_check`, lowercased, appears in
/// [`RUST_RESERVED_KEYWORDS`].
///
/// Because the input is lowercased first, PascalCase names such as `Type`
/// or `Async` also count as reserved. Note that `crate`, `self`, `super` and
/// `Self` are reported as reserved but cannot be escaped with `r#`.
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
}
72
/// Renders `description` as `///` doc-comment lines.
///
/// Each line is prefixed with `indent` and has its surrounding whitespace
/// trimmed. Without a description, a single line containing `fallback_str` is
/// produced instead. If that fallback is empty as well, the result is empty.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    match description {
        Some(text) => text
            .lines()
            .map(|line| format!("{indent}/// {}\n", line.trim()))
            .collect(),
        None if fallback_str.is_empty() => String::new(),
        None => format!("{indent}/// {fallback_str}\n"),
    }
}
89
/// Converts an OpenAPI property name into a snake_case Rust field name.
///
/// The conversion runs in these steps:
/// 1. Every non-alphanumeric ASCII character becomes `_`.
/// 2. Every uppercase ASCII letter is lowercased and, except at position 0,
///    preceded by `_`.
/// 3. `__` pairs are collapsed once.
/// 4. The path keywords `self`, `super` and `crate` get a trailing `_`. Rust
///    rejects them as raw identifiers (`r#self`, `r#super`, `r#crate`), so
///    callers cannot escape them with `r#` the way they do other keywords.
/// 5. A leading digit is prefixed with `_` so the result is a valid
///    identifier start.
fn to_snake_case(name: &str) -> String {
    let cleaned: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect();

    let mut snake = String::with_capacity(cleaned.len() + 4);
    for (i, c) in cleaned.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i != 0 {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else {
            snake.push(c);
        }
    }
    let mut snake = snake.replace("__", "_");

    // `r#self`, `r#super` and `r#crate` do not compile, so rename these instead.
    if matches!(snake.as_str(), "self" | "super" | "crate") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
125
/// Returns `true` when any `x-rust-attrs` entry, ignoring surrounding
/// whitespace, is a `#[derive(...)]` attribute.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.trim().starts_with("#[derive("))
}
136
/// Returns `true` when any `x-rust-attrs` entry, ignoring surrounding
/// whitespace, is a `#[serde(...)]` attribute.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    match custom_attrs {
        Some(attrs) => attrs
            .iter()
            .map(|attr| attr.trim())
            .any(|attr| attr.starts_with("#[serde(")),
        None => false,
    }
}
145
/// Builds an `impl std::fmt::Display for <name>` block whose `fmt` body is
/// `body`. `body` must already be indented and newline-terminated.
///
/// Structs pass a `{:?}` (Debug) write; enums and unions pass a `match`.
/// Returns an empty string when any `x-rust-attrs` entry contains the text
/// `Display`, so a user-supplied derive such as `derive_more::Display` is not
/// duplicated.
fn generate_display_impl(name: &str, custom_attrs: &Option<Vec<String>>, body: &str) -> String {
    // TODO: matching the substring "Display" is a loose heuristic. It is
    // case-sensitive, so a lone `#[display(...)]` attribute is not detected.
    // It can also false-positive on unrelated attributes that contain the word.
    // A dedicated spec extension (e.g. `x-rust-display: false`) that explicitly
    // opts a type out of Display generation would be more reliable than
    // inferring intent from x-rust-attrs content.
    let display_supplied = match custom_attrs {
        Some(attrs) => attrs.iter().any(|attr| attr.contains("Display")),
        None => false,
    };
    if display_supplied {
        return String::new();
    }

    let mut impl_block = format!("impl std::fmt::Display for {name} {{\n");
    impl_block
        .push_str("    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n");
    impl_block.push_str(body);
    impl_block.push_str("    }\n}\n");
    impl_block
}
166
/// Emits each `x-rust-attrs` entry verbatim on its own line, without extra
/// indentation. Yields an empty string when there are none.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
178
179pub fn generate_models(
180    models: &[ModelType],
181    requests: &[RequestModel],
182    responses: &[ResponseModel],
183    mode: GenerateMode,
184    display: bool,
185) -> Result<String> {
186    // First, generate all model code to determine which imports are needed
187    let mut models_code = String::new();
188    let mut required_uses = RequiredUses::empty();
189    let mut needs_validator = false;
190
191    for model_type in models {
192        match model_type {
193            ModelType::Struct(model) => {
194                models_code.push_str(&generate_model(
195                    model,
196                    &mut required_uses,
197                    &mut needs_validator,
198                    display,
199                )?);
200            }
201            ModelType::Union(union) => {
202                models_code.push_str(&generate_union(union, display)?);
203            }
204            ModelType::Composition(comp) => {
205                models_code.push_str(&generate_composition(comp, &mut required_uses, display)?);
206            }
207            ModelType::Enum(enum_model) => {
208                models_code.push_str(&generate_enum(enum_model, display)?);
209            }
210            ModelType::TypeAlias(type_alias) => {
211                models_code.push_str(&generate_type_alias(type_alias)?);
212            }
213        }
214    }
215
216    if mode.contains(GenerateMode::REQUESTS) {
217        for request in requests {
218            models_code.push_str(&generate_request_model(request)?);
219        }
220    }
221
222    if mode.contains(GenerateMode::RESPONSES) {
223        for response in responses {
224            models_code.push_str(&generate_response_model(response)?);
225        }
226    }
227
228    // Determine which imports are actually needed
229    let needs_uuid = required_uses.contains(RequiredUses::UUID);
230    let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
231    let needs_date = required_uses.contains(RequiredUses::DATE);
232
233    // Build final output with only necessary imports
234    let mut output = create_header();
235    output.push_str("use serde::{Serialize, Deserialize};\n");
236
237    if needs_uuid {
238        output.push_str("use uuid::Uuid;\n");
239    }
240
241    if needs_validator {
242        output.push_str("use validator::Validate;\n");
243    }
244
245    if needs_datetime || needs_date {
246        output.push_str("use chrono::{");
247        let mut chrono_imports = Vec::new();
248        if needs_datetime {
249            chrono_imports.push("DateTime");
250        }
251        if needs_date {
252            chrono_imports.push("NaiveDate");
253        }
254        if needs_datetime {
255            chrono_imports.push("Utc");
256        }
257        output.push_str(&chrono_imports.join(", "));
258        output.push_str("};\n");
259    }
260
261    output.push('\n');
262    output.push_str(&models_code);
263
264    Ok(output)
265}
266
267/// Generate validator attributes based on validation rules
268fn generate_validator_attrs(rules: &crate::models::ValidationRules, field_type: &str) -> String {
269    let mut attrs = String::new();
270
271    match field_type {
272        "String" | "str" | "Option<String>" | "Option<str>" => {
273            let mut length_attrs = Vec::new();
274            if let Some(min) = rules.min_length {
275                length_attrs.push(format!("min = {}", min));
276            }
277            if let Some(max) = rules.max_length {
278                length_attrs.push(format!("max = {}", max));
279            }
280            if !length_attrs.is_empty() {
281                attrs.push_str(&format!(
282                    "    #[validate(length({}))]\n",
283                    length_attrs.join(", ")
284                ));
285            }
286
287            if rules.email {
288                attrs.push_str("    #[validate(email)]\n");
289            }
290
291            if rules.url {
292                attrs.push_str("    #[validate(url)]\n");
293            }
294
295            if let Some(pattern) = &rules.pattern {
296                attrs.push_str(&format!("    #[regex(pattern = r\"{}\")]\n", pattern));
297            }
298        }
299        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64"
300        | "Option<i8>" | "Option<i16>" | "Option<i32>" | "Option<i64>" | "Option<u8>"
301        | "Option<u16>" | "Option<u32>" | "Option<u64>" | "Option<f32>" | "Option<f64>" => {
302            let mut range_attrs = Vec::new();
303            if let Some(min) = rules.minimum {
304                range_attrs.push(format!("min = {}", min));
305            }
306            if let Some(max) = rules.maximum {
307                range_attrs.push(format!("max = {}", max));
308            }
309            if rules.exclusive_minimum || rules.exclusive_maximum {
310                range_attrs.push("exclusive = true".to_string());
311            }
312            if !range_attrs.is_empty() {
313                attrs.push_str(&format!(
314                    "    #[validate(range({}))]\n",
315                    range_attrs.join(", ")
316                ));
317            }
318        }
319        _ if field_type.contains("Vec<") => {
320            let mut length_attrs = Vec::new();
321            if let Some(min) = rules.min_items {
322                length_attrs.push(format!("min = {}", min));
323            }
324            if let Some(max) = rules.max_items {
325                length_attrs.push(format!("max = {}", max));
326            }
327            if !length_attrs.is_empty() {
328                attrs.push_str(&format!(
329                    "    #[validate(length({}))]\n",
330                    length_attrs.join(", ")
331                ));
332            }
333        }
334        _ => {}
335    }
336
337    attrs
338}
339
340fn generate_model(
341    model: &Model,
342    required_uses: &mut RequiredUses,
343    needs_validator: &mut bool,
344    display: bool,
345) -> Result<String> {
346    let mut output = String::new();
347
348    output.push_str(&generate_description_docs(
349        &model.description,
350        &model.name,
351        "",
352    ));
353
354    output.push_str(&generate_custom_attrs(&model.custom_attrs));
355
356    // First pass over fields: resolve types and generate field bodies, tracking
357    // whether any #[validate(...)] attrs are needed. This lets us emit the correct
358    // derive line once without fragile byte-range patching.
359    struct FieldOutput {
360        body: String,
361        needs_validate: bool,
362    }
363    let mut field_outputs: Vec<FieldOutput> = Vec::with_capacity(model.fields.len());
364
365    for field in &model.fields {
366        let field_type = match field.field_type.as_str() {
367            "DateTime" | "DateTime<Utc>" => {
368                *required_uses |= RequiredUses::DATETIME;
369                "DateTime<Utc>"
370            }
371            "Date" => {
372                *required_uses |= RequiredUses::DATE;
373                "NaiveDate"
374            }
375            "Uuid" => {
376                *required_uses |= RequiredUses::UUID;
377                "Uuid"
378            }
379            _ => &field.field_type,
380        };
381
382        let mut lowercased_name = to_snake_case(field.name.as_str());
383        if is_reserved_word(&lowercased_name) {
384            lowercased_name = format!("r#{lowercased_name}")
385        }
386
387        let is_optional = !field.is_required || field.is_nullable;
388        let base_type = if field.is_array_ref {
389            format!("Vec<{field_type}>")
390        } else {
391            field_type.to_string()
392        };
393        let full_field_type = if is_optional {
394            format!("Option<{base_type}>")
395        } else {
396            base_type
397        };
398
399        let mut field_body = String::new();
400        field_body.push_str(&generate_description_docs(&field.description, "", "    "));
401
402        if let Some(attrs) = &field.custom_attrs {
403            for attr in attrs {
404                field_body.push_str(&format!("    {attr}\n"));
405            }
406        }
407
408        let mut needs_validate = false;
409        if let Some(rules) = &field.validation_rules {
410            let attrs = generate_validator_attrs(rules, &full_field_type);
411            if !attrs.is_empty() {
412                needs_validate = true;
413                field_body.push_str(&attrs);
414            }
415        }
416
417        if lowercased_name != field.name {
418            field_body.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
419        }
420        if field.should_flatten() {
421            field_body.push_str("    #[serde(flatten)]\n");
422        }
423        field_body.push_str(&format!("    pub {lowercased_name}: {full_field_type},\n"));
424
425        field_outputs.push(FieldOutput {
426            body: field_body,
427            needs_validate,
428        });
429    }
430
431    let any_validate_attrs = field_outputs.iter().any(|f| f.needs_validate);
432
433    if !has_custom_derive(&model.custom_attrs) {
434        if any_validate_attrs {
435            *needs_validator = true;
436            output.push_str("#[derive(Debug, Clone, Serialize, Deserialize, Validate)]\n");
437        } else {
438            output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
439        }
440    }
441
442    output.push_str(&format!("pub struct {} {{\n", model.name));
443    for fo in field_outputs {
444        output.push_str(&fo.body);
445    }
446
447    output.push_str("}\n");
448    if display {
449        output.push_str(&generate_display_impl(
450            &model.name,
451            &model.custom_attrs,
452            "        write!(f, \"{:?}\", self)\n",
453        ));
454    }
455    output.push('\n');
456    Ok(output)
457}
458
459fn generate_request_model(request: &RequestModel) -> Result<String> {
460    let mut output = String::new();
461
462    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
463        return Ok(String::new());
464    }
465
466    output.push_str(&format!("/// {}\n", request.name));
467    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
468    output.push_str(&format!("pub struct {} {{\n", request.name));
469    output.push_str(&format!("    pub body: {},\n", request.schema));
470    output.push_str("}\n");
471    Ok(output)
472}
473
474fn generate_response_model(response: &ResponseModel) -> Result<String> {
475    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
476        return Ok(String::new());
477    }
478
479    let type_name = format!("{}{}", response.name, response.status_code);
480
481    let mut output = String::new();
482
483    output.push_str(&generate_description_docs(
484        &response.description,
485        &type_name,
486        "",
487    ));
488
489    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
490    output.push_str(&format!("pub struct {type_name} {{\n"));
491    output.push_str(&format!("    pub body: {},\n", response.schema));
492    output.push_str("}\n");
493
494    Ok(output)
495}
496
497fn generate_union(union: &UnionModel, display: bool) -> Result<String> {
498    let mut output = String::new();
499
500    output.push_str(&format!(
501        "/// {} ({})\n",
502        union.name,
503        match union.union_type {
504            UnionType::OneOf => "oneOf",
505            UnionType::AnyOf => "anyOf",
506        }
507    ));
508    output.push_str(&generate_custom_attrs(&union.custom_attrs));
509
510    // Only add default derive if custom_attrs doesn't already contain a derive
511    if !has_custom_derive(&union.custom_attrs) {
512        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
513    }
514
515    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
516    if !has_custom_serde(&union.custom_attrs) {
517        output.push_str("#[serde(untagged)]\n");
518    }
519
520    output.push_str(&format!("pub enum {} {{\n", union.name));
521
522    for variant in &union.variants {
523        match &variant.primitive_type {
524            Some(t) => output.push_str(&format!("    {}({}),\n", variant.name, t)),
525            None => output.push_str(&format!("    {}({}),\n", variant.name, variant.name)),
526        }
527    }
528
529    output.push_str("}\n");
530
531    if display {
532        let match_arms = union
533            .variants
534            .iter()
535            .map(|v| {
536                format!(
537                    "            Self::{}(inner) => write!(f, \"{{}}\", inner),\n",
538                    v.name
539                )
540            })
541            .collect::<String>();
542        output.push_str(&generate_display_impl(
543            &union.name,
544            &union.custom_attrs,
545            &format!("        match self {{\n{match_arms}        }}\n"),
546        ));
547    }
548
549    Ok(output)
550}
551
552fn generate_composition(
553    comp: &CompositionModel,
554    required_uses: &mut RequiredUses,
555    display: bool,
556) -> Result<String> {
557    let mut output = String::new();
558
559    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
560    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
561
562    // Only add default derive if custom_attrs doesn't already contain a derive
563    if !has_custom_derive(&comp.custom_attrs) {
564        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
565    }
566
567    output.push_str(&format!("pub struct {} {{\n", comp.name));
568
569    for field in &comp.all_fields {
570        let field_type = match field.field_type.as_str() {
571            "String" => "String",
572            "f64" => "f64",
573            "i64" => "i64",
574            "bool" => "bool",
575            "DateTime" => {
576                *required_uses |= RequiredUses::DATETIME;
577                "DateTime<Utc>"
578            }
579            "Date" => {
580                *required_uses |= RequiredUses::DATE;
581                "NaiveDate"
582            }
583            "Uuid" => {
584                *required_uses |= RequiredUses::UUID;
585                "Uuid"
586            }
587            _ => &field.field_type,
588        };
589
590        let mut lowercased_name = to_snake_case(field.name.as_str());
591        if is_reserved_word(&lowercased_name) {
592            lowercased_name = format!("r#{lowercased_name}");
593        }
594
595        // Only add serde rename if the Rust field name differs from the original field name
596        if lowercased_name != field.name {
597            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
598        }
599
600        // Field-level custom attributes (e.g. #[serde(rename = "...")])
601        if let Some(attrs) = &field.custom_attrs {
602            for attr in attrs {
603                output.push_str(&format!("    {attr}\n"));
604            }
605        }
606
607        // If field references an array, wrap it in Vec<>
608        if field.is_array_ref {
609            if field.is_required && !field.is_nullable {
610                output.push_str(&format!("    pub {lowercased_name}: Vec<{field_type}>,\n",));
611            } else {
612                output.push_str(&format!(
613                    "    pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
614                ));
615            }
616        } else if field.is_required && !field.is_nullable {
617            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
618        } else {
619            output.push_str(&format!(
620                "    pub {lowercased_name}: Option<{field_type}>,\n",
621            ));
622        }
623    }
624
625    output.push_str("}\n");
626    if display {
627        output.push_str(&generate_display_impl(
628            &comp.name,
629            &comp.custom_attrs,
630            "        write!(f, \"{:?}\", self)\n",
631        ));
632    }
633    Ok(output)
634}
635
636fn generate_enum(enum_model: &EnumModel, display: bool) -> Result<String> {
637    let mut output = String::new();
638
639    output.push_str(&generate_description_docs(
640        &enum_model.description,
641        &enum_model.name,
642        "",
643    ));
644
645    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
646
647    // Only add default derive if custom_attrs doesn't already contain a derive
648    if !has_custom_derive(&enum_model.custom_attrs) {
649        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
650    }
651
652    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
653
654    // Collect (rust_name, display_value) pairs for the Display impl below.
655    let mut variant_display: Vec<(String, String)> = Vec::new();
656
657    for (i, variant) in enum_model.variants.iter().enumerate() {
658        let original = variant.clone();
659
660        let mut rust_name = crate::parser::to_pascal_case(variant);
661
662        let serde_rename = if is_reserved_word(&rust_name) {
663            rust_name.push_str("Value");
664            Some(original.clone())
665        } else if rust_name != original {
666            Some(original.clone())
667        } else {
668            None
669        };
670
671        let display_value = serde_rename
672            .as_deref()
673            .unwrap_or(&original)
674            .replace('\\', "\\\\")
675            .replace('"', "\\\"");
676        variant_display.push((rust_name.clone(), display_value));
677
678        if let Some(rename) = serde_rename {
679            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
680        }
681
682        if i + 1 == enum_model.variants.len() {
683            output.push_str(&format!("    {rust_name}\n"));
684        } else {
685            output.push_str(&format!("    {rust_name},\n"));
686        }
687    }
688
689    output.push_str("}\n");
690
691    if display {
692        let match_arms = variant_display
693            .iter()
694            .map(|(rust_name, display_value)| {
695                format!("            Self::{rust_name} => write!(f, \"{display_value}\"),\n")
696            })
697            .collect::<String>();
698        output.push_str(&generate_display_impl(
699            &enum_model.name,
700            &enum_model.custom_attrs,
701            &format!("        match self {{\n{match_arms}        }}\n"),
702        ));
703    }
704
705    Ok(output)
706}
707
708fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
709    let mut output = String::new();
710
711    output.push_str(&generate_description_docs(
712        &type_alias.description,
713        &type_alias.name,
714        "",
715    ));
716
717    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
718    output.push_str(&format!(
719        "pub type {} = {};\n\n",
720        type_alias.name, type_alias.target_type
721    ));
722
723    Ok(output)
724}
725
726pub fn generate_rust_code(models: &[Model]) -> Result<String> {
727    let mut code = create_header();
728
729    code.push_str("use serde::{Serialize, Deserialize};\n");
730    code.push_str("use uuid::Uuid;\n");
731    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
732
733    for model in models {
734        code.push_str(&format!("/// {}\n", model.name));
735        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
736        code.push_str(&format!("pub struct {} {{\n", model.name));
737
738        for field in &model.fields {
739            let field_type = match field.field_type.as_str() {
740                "String" => "String",
741                "f64" => "f64",
742                "i64" => "i64",
743                "bool" => "bool",
744                "DateTime" => "DateTime<Utc>",
745                "Date" => "NaiveDate",
746                "Uuid" => "Uuid",
747                _ => &field.field_type,
748            };
749
750            let mut lowercased_name = to_snake_case(field.name.as_str());
751            if is_reserved_word(&lowercased_name) {
752                lowercased_name = format!("r#{lowercased_name}")
753            }
754
755            // Only add serde rename if the Rust field name differs from the original field name
756            if lowercased_name != field.name {
757                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
758            }
759
760            if field.is_required {
761                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
762            } else {
763                code.push_str(&format!(
764                    "    pub {lowercased_name}: Option<{field_type}>,\n",
765                ));
766            }
767        }
768
769        code.push_str("}\n\n");
770    }
771
772    Ok(code)
773}
774
775pub fn generate_lib() -> Result<String> {
776    let mut code = create_header();
777    code.push_str("pub mod models;\n");
778
779    Ok(code)
780}
781
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::EnumModel;

    /// Builds an enum model named `TestEnum` from raw values, with no
    /// description and no custom attributes.
    fn enum_with(values: &[&str]) -> EnumModel {
        EnumModel {
            name: String::from("TestEnum"),
            description: None,
            variants: values.iter().map(|value| value.to_string()).collect(),
            custom_attrs: None,
        }
    }

    #[test]
    fn test_enum_display_escapes_quotes_and_backslashes() {
        // `"` and `\` in enum values must be escaped inside the generated
        // `write!` string literal; plain values pass through unchanged.
        let model = enum_with(&["normal", r#"with"quote"#, r"with\backslash"]);
        let output = generate_enum(&model, true).expect("generate_enum failed");

        let expected = [
            (
                r#"write!(f, "with\"quote")"#,
                "double quote should be escaped in Display impl",
            ),
            (
                r#"write!(f, "with\\backslash")"#,
                "backslash should be escaped in Display impl",
            ),
            (r#"write!(f, "normal")"#, "plain value should be unmodified"),
        ];
        for (needle, reason) in expected {
            assert!(output.contains(needle), "{reason}:\n{output}");
        }
    }

    #[test]
    fn test_enum_no_display_when_flag_off() {
        let output =
            generate_enum(&enum_with(&["foo", "bar"]), false).expect("generate_enum failed");
        assert!(
            !output.contains("impl std::fmt::Display"),
            "Display impl should not be generated when display=false:\n{output}"
        );
    }

    #[test]
    fn test_enum_no_display_when_custom_attrs_has_display() {
        let model = EnumModel {
            custom_attrs: Some(vec![
                "#[derive(derive_more::Display, Debug, Clone)]".to_string(),
            ]),
            ..enum_with(&["foo"])
        };
        let output = generate_enum(&model, true).expect("generate_enum failed");
        assert!(
            !output.contains("impl std::fmt::Display"),
            "Display impl should be skipped when x-rust-attrs already has Display:\n{output}"
        );
    }
}
838}