1use std::sync::OnceLock;
2
3use crate::{
4 models::{
5 CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6 UnionModel, UnionType,
7 },
8 Result,
9};
10
11bitflags::bitflags! {
12 struct RequiredUses: u8 {
13 const UUID = 0b00000001;
14 const DATETIME = 0b00000010;
15 const DATE = 0b00000100;
16 }
17
18 pub struct GenerateMode: u8 {
23 const MODELS = 0;
25 const REQUESTS = 1 << 0;
27 const RESPONSES = 1 << 1;
29 const ALL = Self::REQUESTS.bits() | Self::RESPONSES.bits();
31 }
32}
33
34impl Default for GenerateMode {
35 fn default() -> Self {
36 Self::ALL
37 }
38}
39
/// Cache for the rendered file header; filled on the first `create_header` call.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the `//!` banner placed at the top of every generated file.
///
/// The banner names the generator and its version, taken from the
/// `CARGO_PKG_NAME` / `CARGO_PKG_VERSION` compile-time variables with fixed
/// fallbacks when they are not set. The text is built once and cloned on
/// each call.
fn create_header() -> String {
    let cached = HDR.get_or_init(|| {
        let generator = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!(
            "\n//!\n//! Generated from an OAS specification by {generator}(v{version})\n//!\n\n"
        )
    });
    cached.to_owned()
}
57
/// Words that cannot be used verbatim as identifiers in generated Rust code.
///
/// The first group holds the strict keywords, including `async`, `await` and
/// `dyn`, which are strict from the 2018 edition on. The second group holds the
/// keywords reserved for future use.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name; `generate_response_model` emits nothing for it.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name; `generate_request_model` emits nothing for it.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` when `string_to_check` collides with a Rust keyword.
///
/// The input is lowercased before the lookup, so the check is case-insensitive:
/// `"Type"` and `"TYPE"` are reported as reserved just like `"type"`.
fn is_reserved_word(string_to_check: &str) -> bool {
    let lowered = string_to_check.to_lowercase();
    RUST_RESERVED_KEYWORDS
        .iter()
        .any(|keyword| *keyword == lowered)
}
72
/// Renders `///` doc-comment lines for a generated item.
///
/// * With a description: one doc line per description line, each trimmed and
///   prefixed with `indent`.
/// * Without a description: a single doc line holding `fallback_str`, or
///   nothing when `fallback_str` is empty.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    match description {
        Some(text) => text
            .lines()
            .map(|raw| format!("{indent}/// {}\n", raw.trim()))
            .collect(),
        None if fallback_str.is_empty() => String::new(),
        None => format!("{indent}/// {fallback_str}\n"),
    }
}
89
/// Converts an arbitrary schema property name into a snake_case Rust identifier.
///
/// Steps, in order:
/// 1. Every character that is not ASCII alphanumeric becomes `_`.
/// 2. Every ASCII uppercase letter is lowercased and, unless it is the first
///    character, preceded by `_`.
/// 3. `"__"` is replaced by `"_"` in a single pass (runs of three or more
///    underscores are shortened, not fully collapsed).
/// 4. `self`, `crate` and `super` get a trailing `_`. These words cannot be
///    written as raw identifiers (`r#self`, `r#crate`, `r#super` are rejected by
///    the compiler), so the `r#` prefix that callers apply to reserved words
///    would produce code that does not compile.
/// 5. A leading ASCII digit is prefixed with `_`.
fn to_snake_case(name: &str) -> String {
    let mut raw = String::with_capacity(name.len() + 4);
    for (position, ch) in name.chars().enumerate() {
        if ch.is_ascii_uppercase() {
            if position != 0 {
                raw.push('_');
            }
            raw.push(ch.to_ascii_lowercase());
        } else if ch.is_ascii_alphanumeric() {
            raw.push(ch);
        } else {
            raw.push('_');
        }
    }

    let mut snake = raw.replace("__", "_");

    // Keywords that have no raw-identifier form must be renamed, not prefixed.
    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|first: char| first.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
125
/// Returns `true` when any custom attribute, after trimming, starts with `#[derive(`.
///
/// Callers use this to skip emitting their own default `#[derive(...)]` line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.trim().starts_with("#[derive("))
}
136
/// Returns `true` when any custom attribute, after trimming, starts with `#[serde(`.
///
/// `generate_union` uses this to skip its default `#[serde(untagged)]` line.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    let attrs: &[String] = custom_attrs.as_deref().unwrap_or_default();
    attrs
        .iter()
        .map(|attr| attr.trim())
        .any(|trimmed| trimmed.starts_with("#[serde("))
}
145
/// Renders an `impl std::fmt::Display` block for `name` whose `fmt` body is `body`.
///
/// Returns an empty string when any custom attribute contains the text
/// `Display`, on the assumption that the attribute already provides the impl.
fn generate_display_impl(name: &str, custom_attrs: &Option<Vec<String>>, body: &str) -> String {
    let already_provided = custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.contains("Display"));

    if already_provided {
        String::new()
    } else {
        format!(
            "impl std::fmt::Display for {name} {{\n fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{\n{body} }}\n}}\n"
        )
    }
}
166
/// Joins the custom attributes into one string, one attribute per line.
///
/// Each attribute is followed by a newline; `None` or an empty list yields an
/// empty string.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
178
179pub fn generate_models(
180 models: &[ModelType],
181 requests: &[RequestModel],
182 responses: &[ResponseModel],
183 mode: GenerateMode,
184 display: bool,
185) -> Result<String> {
186 let mut models_code = String::new();
188 let mut required_uses = RequiredUses::empty();
189 let mut needs_validator = false;
190
191 for model_type in models {
192 match model_type {
193 ModelType::Struct(model) => {
194 models_code.push_str(&generate_model(
195 model,
196 &mut required_uses,
197 &mut needs_validator,
198 display,
199 )?);
200 }
201 ModelType::Union(union) => {
202 models_code.push_str(&generate_union(union, display)?);
203 }
204 ModelType::Composition(comp) => {
205 models_code.push_str(&generate_composition(comp, &mut required_uses, display)?);
206 }
207 ModelType::Enum(enum_model) => {
208 models_code.push_str(&generate_enum(enum_model, display)?);
209 }
210 ModelType::TypeAlias(type_alias) => {
211 models_code.push_str(&generate_type_alias(type_alias)?);
212 }
213 }
214 }
215
216 if mode.contains(GenerateMode::REQUESTS) {
217 for request in requests {
218 models_code.push_str(&generate_request_model(request)?);
219 }
220 }
221
222 if mode.contains(GenerateMode::RESPONSES) {
223 for response in responses {
224 models_code.push_str(&generate_response_model(response)?);
225 }
226 }
227
228 let needs_uuid = required_uses.contains(RequiredUses::UUID);
230 let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
231 let needs_date = required_uses.contains(RequiredUses::DATE);
232
233 let mut output = create_header();
235 output.push_str("use serde::{Serialize, Deserialize};\n");
236
237 if needs_uuid {
238 output.push_str("use uuid::Uuid;\n");
239 }
240
241 if needs_validator {
242 output.push_str("use validator::Validate;\n");
243 }
244
245 if needs_datetime || needs_date {
246 output.push_str("use chrono::{");
247 let mut chrono_imports = Vec::new();
248 if needs_datetime {
249 chrono_imports.push("DateTime");
250 }
251 if needs_date {
252 chrono_imports.push("NaiveDate");
253 }
254 if needs_datetime {
255 chrono_imports.push("Utc");
256 }
257 output.push_str(&chrono_imports.join(", "));
258 output.push_str("};\n");
259 }
260
261 output.push('\n');
262 output.push_str(&models_code);
263
264 Ok(output)
265}
266
267fn generate_validator_attrs(rules: &crate::models::ValidationRules, field_type: &str) -> String {
269 let mut attrs = String::new();
270
271 match field_type {
272 "String" | "str" | "Option<String>" | "Option<str>" => {
273 let mut length_attrs = Vec::new();
274 if let Some(min) = rules.min_length {
275 length_attrs.push(format!("min = {}", min));
276 }
277 if let Some(max) = rules.max_length {
278 length_attrs.push(format!("max = {}", max));
279 }
280 if !length_attrs.is_empty() {
281 attrs.push_str(&format!(
282 " #[validate(length({}))]\n",
283 length_attrs.join(", ")
284 ));
285 }
286
287 if rules.email {
288 attrs.push_str(" #[validate(email)]\n");
289 }
290
291 if rules.url {
292 attrs.push_str(" #[validate(url)]\n");
293 }
294
295 if let Some(pattern) = &rules.pattern {
296 attrs.push_str(&format!(" #[regex(pattern = r\"{}\")]\n", pattern));
297 }
298 }
299 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64"
300 | "Option<i8>" | "Option<i16>" | "Option<i32>" | "Option<i64>" | "Option<u8>"
301 | "Option<u16>" | "Option<u32>" | "Option<u64>" | "Option<f32>" | "Option<f64>" => {
302 let mut range_attrs = Vec::new();
303 if let Some(min) = rules.minimum {
304 range_attrs.push(format!("min = {}", min));
305 }
306 if let Some(max) = rules.maximum {
307 range_attrs.push(format!("max = {}", max));
308 }
309 if rules.exclusive_minimum || rules.exclusive_maximum {
310 range_attrs.push("exclusive = true".to_string());
311 }
312 if !range_attrs.is_empty() {
313 attrs.push_str(&format!(
314 " #[validate(range({}))]\n",
315 range_attrs.join(", ")
316 ));
317 }
318 }
319 _ if field_type.contains("Vec<") => {
320 let mut length_attrs = Vec::new();
321 if let Some(min) = rules.min_items {
322 length_attrs.push(format!("min = {}", min));
323 }
324 if let Some(max) = rules.max_items {
325 length_attrs.push(format!("max = {}", max));
326 }
327 if !length_attrs.is_empty() {
328 attrs.push_str(&format!(
329 " #[validate(length({}))]\n",
330 length_attrs.join(", ")
331 ));
332 }
333 }
334 _ => {}
335 }
336
337 attrs
338}
339
340fn generate_model(
341 model: &Model,
342 required_uses: &mut RequiredUses,
343 needs_validator: &mut bool,
344 display: bool,
345) -> Result<String> {
346 let mut output = String::new();
347
348 output.push_str(&generate_description_docs(
349 &model.description,
350 &model.name,
351 "",
352 ));
353
354 output.push_str(&generate_custom_attrs(&model.custom_attrs));
355
356 struct FieldOutput {
360 body: String,
361 needs_validate: bool,
362 }
363 let mut field_outputs: Vec<FieldOutput> = Vec::with_capacity(model.fields.len());
364
365 for field in &model.fields {
366 let field_type = match field.field_type.as_str() {
367 "DateTime" | "DateTime<Utc>" => {
368 *required_uses |= RequiredUses::DATETIME;
369 "DateTime<Utc>"
370 }
371 "Date" => {
372 *required_uses |= RequiredUses::DATE;
373 "NaiveDate"
374 }
375 "Uuid" => {
376 *required_uses |= RequiredUses::UUID;
377 "Uuid"
378 }
379 _ => &field.field_type,
380 };
381
382 let mut lowercased_name = to_snake_case(field.name.as_str());
383 if is_reserved_word(&lowercased_name) {
384 lowercased_name = format!("r#{lowercased_name}")
385 }
386
387 let is_optional = !field.is_required || field.is_nullable;
388 let base_type = if field.is_array_ref {
389 format!("Vec<{field_type}>")
390 } else {
391 field_type.to_string()
392 };
393 let full_field_type = if is_optional {
394 format!("Option<{base_type}>")
395 } else {
396 base_type
397 };
398
399 let mut field_body = String::new();
400 field_body.push_str(&generate_description_docs(&field.description, "", " "));
401
402 if let Some(attrs) = &field.custom_attrs {
403 for attr in attrs {
404 field_body.push_str(&format!(" {attr}\n"));
405 }
406 }
407
408 let mut needs_validate = false;
409 if let Some(rules) = &field.validation_rules {
410 let attrs = generate_validator_attrs(rules, &full_field_type);
411 if !attrs.is_empty() {
412 needs_validate = true;
413 field_body.push_str(&attrs);
414 }
415 }
416
417 if lowercased_name != field.name {
418 field_body.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
419 }
420 if field.should_flatten() {
421 field_body.push_str(" #[serde(flatten)]\n");
422 }
423 field_body.push_str(&format!(" pub {lowercased_name}: {full_field_type},\n"));
424
425 field_outputs.push(FieldOutput {
426 body: field_body,
427 needs_validate,
428 });
429 }
430
431 let any_validate_attrs = field_outputs.iter().any(|f| f.needs_validate);
432
433 if !has_custom_derive(&model.custom_attrs) {
434 if any_validate_attrs {
435 *needs_validator = true;
436 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize, Validate)]\n");
437 } else {
438 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
439 }
440 }
441
442 output.push_str(&format!("pub struct {} {{\n", model.name));
443 for fo in field_outputs {
444 output.push_str(&fo.body);
445 }
446
447 output.push_str("}\n");
448 if display {
449 output.push_str(&generate_display_impl(
450 &model.name,
451 &model.custom_attrs,
452 " write!(f, \"{:?}\", self)\n",
453 ));
454 }
455 output.push('\n');
456 Ok(output)
457}
458
459fn generate_request_model(request: &RequestModel) -> Result<String> {
460 let mut output = String::new();
461
462 if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
463 return Ok(String::new());
464 }
465
466 output.push_str(&format!("/// {}\n", request.name));
467 output.push_str("#[derive(Debug, Clone, Serialize)]\n");
468 output.push_str(&format!("pub struct {} {{\n", request.name));
469 output.push_str(&format!(" pub body: {},\n", request.schema));
470 output.push_str("}\n");
471 Ok(output)
472}
473
474fn generate_response_model(response: &ResponseModel) -> Result<String> {
475 if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
476 return Ok(String::new());
477 }
478
479 let type_name = format!("{}{}", response.name, response.status_code);
480
481 let mut output = String::new();
482
483 output.push_str(&generate_description_docs(
484 &response.description,
485 &type_name,
486 "",
487 ));
488
489 output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
490 output.push_str(&format!("pub struct {type_name} {{\n"));
491 output.push_str(&format!(" pub body: {},\n", response.schema));
492 output.push_str("}\n");
493
494 Ok(output)
495}
496
497fn generate_union(union: &UnionModel, display: bool) -> Result<String> {
498 let mut output = String::new();
499
500 output.push_str(&format!(
501 "/// {} ({})\n",
502 union.name,
503 match union.union_type {
504 UnionType::OneOf => "oneOf",
505 UnionType::AnyOf => "anyOf",
506 }
507 ));
508 output.push_str(&generate_custom_attrs(&union.custom_attrs));
509
510 if !has_custom_derive(&union.custom_attrs) {
512 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
513 }
514
515 if !has_custom_serde(&union.custom_attrs) {
517 output.push_str("#[serde(untagged)]\n");
518 }
519
520 output.push_str(&format!("pub enum {} {{\n", union.name));
521
522 for variant in &union.variants {
523 match &variant.primitive_type {
524 Some(t) => output.push_str(&format!(" {}({}),\n", variant.name, t)),
525 None => output.push_str(&format!(" {}({}),\n", variant.name, variant.name)),
526 }
527 }
528
529 output.push_str("}\n");
530
531 if display {
532 let match_arms = union
533 .variants
534 .iter()
535 .map(|v| {
536 format!(
537 " Self::{}(inner) => write!(f, \"{{}}\", inner),\n",
538 v.name
539 )
540 })
541 .collect::<String>();
542 output.push_str(&generate_display_impl(
543 &union.name,
544 &union.custom_attrs,
545 &format!(" match self {{\n{match_arms} }}\n"),
546 ));
547 }
548
549 Ok(output)
550}
551
552fn generate_composition(
553 comp: &CompositionModel,
554 required_uses: &mut RequiredUses,
555 display: bool,
556) -> Result<String> {
557 let mut output = String::new();
558
559 output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
560 output.push_str(&generate_custom_attrs(&comp.custom_attrs));
561
562 if !has_custom_derive(&comp.custom_attrs) {
564 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
565 }
566
567 output.push_str(&format!("pub struct {} {{\n", comp.name));
568
569 for field in &comp.all_fields {
570 let field_type = match field.field_type.as_str() {
571 "String" => "String",
572 "f64" => "f64",
573 "i64" => "i64",
574 "bool" => "bool",
575 "DateTime" => {
576 *required_uses |= RequiredUses::DATETIME;
577 "DateTime<Utc>"
578 }
579 "Date" => {
580 *required_uses |= RequiredUses::DATE;
581 "NaiveDate"
582 }
583 "Uuid" => {
584 *required_uses |= RequiredUses::UUID;
585 "Uuid"
586 }
587 _ => &field.field_type,
588 };
589
590 let mut lowercased_name = to_snake_case(field.name.as_str());
591 if is_reserved_word(&lowercased_name) {
592 lowercased_name = format!("r#{lowercased_name}");
593 }
594
595 if lowercased_name != field.name {
597 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
598 }
599
600 if let Some(attrs) = &field.custom_attrs {
602 for attr in attrs {
603 output.push_str(&format!(" {attr}\n"));
604 }
605 }
606
607 if field.is_array_ref {
609 if field.is_required && !field.is_nullable {
610 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
611 } else {
612 output.push_str(&format!(
613 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
614 ));
615 }
616 } else if field.is_required && !field.is_nullable {
617 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
618 } else {
619 output.push_str(&format!(
620 " pub {lowercased_name}: Option<{field_type}>,\n",
621 ));
622 }
623 }
624
625 output.push_str("}\n");
626 if display {
627 output.push_str(&generate_display_impl(
628 &comp.name,
629 &comp.custom_attrs,
630 " write!(f, \"{:?}\", self)\n",
631 ));
632 }
633 Ok(output)
634}
635
636fn generate_enum(enum_model: &EnumModel, display: bool) -> Result<String> {
637 let mut output = String::new();
638
639 output.push_str(&generate_description_docs(
640 &enum_model.description,
641 &enum_model.name,
642 "",
643 ));
644
645 output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
646
647 if !has_custom_derive(&enum_model.custom_attrs) {
649 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
650 }
651
652 output.push_str(&format!("pub enum {} {{\n", enum_model.name));
653
654 let mut variant_display: Vec<(String, String)> = Vec::new();
656
657 for (i, variant) in enum_model.variants.iter().enumerate() {
658 let original = variant.clone();
659
660 let mut rust_name = crate::parser::to_pascal_case(variant);
661
662 let serde_rename = if is_reserved_word(&rust_name) {
663 rust_name.push_str("Value");
664 Some(original.clone())
665 } else if rust_name != original {
666 Some(original.clone())
667 } else {
668 None
669 };
670
671 let display_value = serde_rename
672 .as_deref()
673 .unwrap_or(&original)
674 .replace('\\', "\\\\")
675 .replace('"', "\\\"");
676 variant_display.push((rust_name.clone(), display_value));
677
678 if let Some(rename) = serde_rename {
679 output.push_str(&format!(" #[serde(rename = \"{rename}\")]\n"));
680 }
681
682 if i + 1 == enum_model.variants.len() {
683 output.push_str(&format!(" {rust_name}\n"));
684 } else {
685 output.push_str(&format!(" {rust_name},\n"));
686 }
687 }
688
689 output.push_str("}\n");
690
691 if display {
692 let match_arms = variant_display
693 .iter()
694 .map(|(rust_name, display_value)| {
695 format!(" Self::{rust_name} => write!(f, \"{display_value}\"),\n")
696 })
697 .collect::<String>();
698 output.push_str(&generate_display_impl(
699 &enum_model.name,
700 &enum_model.custom_attrs,
701 &format!(" match self {{\n{match_arms} }}\n"),
702 ));
703 }
704
705 Ok(output)
706}
707
708fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
709 let mut output = String::new();
710
711 output.push_str(&generate_description_docs(
712 &type_alias.description,
713 &type_alias.name,
714 "",
715 ));
716
717 output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
718 output.push_str(&format!(
719 "pub type {} = {};\n\n",
720 type_alias.name, type_alias.target_type
721 ));
722
723 Ok(output)
724}
725
726pub fn generate_rust_code(models: &[Model]) -> Result<String> {
727 let mut code = create_header();
728
729 code.push_str("use serde::{Serialize, Deserialize};\n");
730 code.push_str("use uuid::Uuid;\n");
731 code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
732
733 for model in models {
734 code.push_str(&format!("/// {}\n", model.name));
735 code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
736 code.push_str(&format!("pub struct {} {{\n", model.name));
737
738 for field in &model.fields {
739 let field_type = match field.field_type.as_str() {
740 "String" => "String",
741 "f64" => "f64",
742 "i64" => "i64",
743 "bool" => "bool",
744 "DateTime" => "DateTime<Utc>",
745 "Date" => "NaiveDate",
746 "Uuid" => "Uuid",
747 _ => &field.field_type,
748 };
749
750 let mut lowercased_name = to_snake_case(field.name.as_str());
751 if is_reserved_word(&lowercased_name) {
752 lowercased_name = format!("r#{lowercased_name}")
753 }
754
755 if lowercased_name != field.name {
757 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
758 }
759
760 if field.is_required {
761 code.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
762 } else {
763 code.push_str(&format!(
764 " pub {lowercased_name}: Option<{field_type}>,\n",
765 ));
766 }
767 }
768
769 code.push_str("}\n\n");
770 }
771
772 Ok(code)
773}
774
775pub fn generate_lib() -> Result<String> {
776 let mut code = create_header();
777 code.push_str("pub mod models;\n");
778
779 Ok(code)
780}
781
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::EnumModel;

    /// Builds a `TestEnum` model with the given variants and no description
    /// or custom attributes.
    fn make_enum(variants: Vec<&str>) -> EnumModel {
        let variants = variants.iter().map(|value| value.to_string()).collect();
        EnumModel {
            name: String::from("TestEnum"),
            description: None,
            variants,
            custom_attrs: None,
        }
    }

    /// Quotes and backslashes in enum values must be escaped in the `Display` impl.
    #[test]
    fn test_enum_display_escapes_quotes_and_backslashes() {
        let model = make_enum(vec!["normal", r#"with"quote"#, r"with\backslash"]);
        let output = generate_enum(&model, true).expect("generate_enum failed");

        let expectations = [
            (
                r#"write!(f, "with\"quote")"#,
                "double quote should be escaped in Display impl",
            ),
            (
                r#"write!(f, "with\\backslash")"#,
                "backslash should be escaped in Display impl",
            ),
            (
                r#"write!(f, "normal")"#,
                "plain value should be unmodified",
            ),
        ];
        for (fragment, reason) in expectations.iter() {
            assert!(output.contains(fragment), "{reason}:\n{output}");
        }
    }

    /// No `Display` impl is produced when the `display` flag is off.
    #[test]
    fn test_enum_no_display_when_flag_off() {
        let output =
            generate_enum(&make_enum(vec!["foo", "bar"]), false).expect("generate_enum failed");
        let has_display_impl = output.contains("impl std::fmt::Display");
        assert!(
            !has_display_impl,
            "Display impl should not be generated when display=false:\n{output}"
        );
    }

    /// A custom attribute mentioning `Display` suppresses the generated impl.
    #[test]
    fn test_enum_no_display_when_custom_attrs_has_display() {
        let model = EnumModel {
            custom_attrs: Some(vec![String::from(
                "#[derive(derive_more::Display, Debug, Clone)]",
            )]),
            ..make_enum(vec!["foo"])
        };
        let output = generate_enum(&model, true).expect("generate_enum failed");
        let has_display_impl = output.contains("impl std::fmt::Display");
        assert!(
            !has_display_impl,
            "Display impl should be skipped when x-rust-attrs already has Display:\n{output}"
        );
    }
}