1use std::sync::OnceLock;
2
3use crate::{
4 models::{
5 CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6 UnionModel, UnionType,
7 },
8 Result,
9};
10
bitflags::bitflags! {
    /// Optional `use` declarations the generated file needs.
    ///
    /// Set by the struct/composition generators while they map field types,
    /// then read once by `generate_models` to decide which imports to emit.
    struct RequiredUses: u8 {
        /// A field uses `Uuid` (emits `use uuid::Uuid;`).
        const UUID = 0b00000001;
        /// A field uses `DateTime<Utc>` (emits chrono `DateTime` and `Utc`).
        const DATETIME = 0b00000010;
        /// A field uses `NaiveDate` (emits chrono `NaiveDate`).
        const DATE = 0b00000100;
    }

    /// Selects which wrapper types `generate_models` emits in addition to the
    /// schema models, which are always generated.
    pub struct GenerateMode: u8 {
        /// Schema models only. This flag has no bits set, so every value
        /// "contains" it; it exists only to name the empty selection.
        const MODELS = 0;
        /// Also emit one wrapper struct per request body.
        const REQUESTS = 1 << 0;
        /// Also emit one wrapper struct per response.
        const RESPONSES = 1 << 1;
        /// Emit both request and response wrappers.
        const ALL = Self::REQUESTS.bits() | Self::RESPONSES.bits();
    }
}
33
34impl Default for GenerateMode {
35 fn default() -> Self {
36 Self::ALL
37 }
38}
39
/// Banner text, built on first use and shared by every generated file.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the `//!` banner that opens every generated file.
///
/// The banner names the generator crate and its version, read from the Cargo
/// environment at compile time. It falls back to `openapi-model-generator` and
/// `unknown` when built outside Cargo. The text is formatted once and an
/// owned copy is handed out on each call so callers can append to it.
fn create_header() -> String {
    let banner = HDR.get_or_init(|| {
        let tool = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!("\n//!\n//! Generated from an OAS specification by {tool}(v{version})\n//!\n\n")
    });
    banner.to_owned()
}
57
/// Words that cannot be used verbatim as identifiers in generated code.
///
/// Includes the strict keywords (among them the 2018-edition `async`, `await`
/// and `dyn`) and the keywords reserved for future use (`try`, `gen`, ...).
/// Weak keywords such as `union` are valid identifiers and are not listed.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved keywords.
    "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];

/// Name used for a response that has no usable name; such responses are skipped.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Name used for a request that has no usable name; such requests are skipped.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` if `string_to_check` collides with a Rust keyword.
///
/// The input is lowercased before the lookup. As a result, capitalised forms
/// such as `Type` or `Self` are reported as reserved too, which is what gives
/// enum variants like `Type` their `Value` suffix. Callers escape reserved
/// field names with `r#`. Note that `self`, `Self`, `crate` and `super` cannot
/// be raw identifiers, so snake-case conversion must already have renamed them.
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
}
72
/// Renders `///` doc-comment lines, each prefixed with `indent`.
///
/// With a description, every line of it becomes one trimmed doc line, so an
/// empty description yields nothing. Without one, a single line holding
/// `fallback_str` is produced, or nothing at all if `fallback_str` is empty.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    match description {
        Some(text) => text
            .lines()
            .map(|line| format!("{indent}/// {}\n", line.trim()))
            .collect(),
        None if fallback_str.is_empty() => String::new(),
        None => format!("{indent}/// {fallback_str}\n"),
    }
}
89
/// Converts an OpenAPI property name into a snake_case Rust field name.
///
/// * Every character that is not ASCII alphanumeric becomes `_`.
/// * Each ASCII uppercase letter is lowercased and, except at the very start,
///   preceded by `_`. So `petId` becomes `pet_id`, and acronyms are split per
///   letter (`HTTPCode` becomes `h_t_t_p_code`).
/// * Runs of underscores, however they arise, collapse to a single `_`.
/// * `self`, `crate` and `super` get a trailing `_`. Unlike other keywords they
///   cannot be escaped as raw identifiers (`r#crate` is rejected by rustc).
/// * A leading digit is prefixed with `_` so the result is a valid identifier.
///
/// Other keywords are returned unchanged; callers escape them with `r#`.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 1);

    for (i, c) in name.chars().enumerate() {
        let c = if c.is_ascii_alphanumeric() { c } else { '_' };
        if c.is_ascii_uppercase() {
            // Word boundary, unless one was just emitted.
            if i != 0 && !snake.ends_with('_') {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else if c == '_' {
            // Collapse consecutive separators, e.g. `Foo__Bar` -> `foo_bar`.
            if !snake.ends_with('_') {
                snake.push('_');
            }
        } else {
            snake.push(c);
        }
    }

    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
125
/// Reports whether any user-supplied attribute is a `#[derive(...)]`.
///
/// When it is, the generators skip their default derive list so the user's
/// list is the only one on the type. Leading and trailing whitespace on each
/// attribute is ignored.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    const DERIVE_PREFIX: &str = "#[derive(";
    custom_attrs
        .as_deref()
        .unwrap_or_default()
        .iter()
        .any(|attr| attr.trim().starts_with(DERIVE_PREFIX))
}
136
/// Reports whether any user-supplied attribute is a container-level
/// `#[serde(...)]`. When it is, default serde attributes such as
/// `#[serde(untagged)]` are not emitted.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    let Some(attrs) = custom_attrs else {
        return false;
    };
    attrs.iter().any(|attr| attr.trim().starts_with("#[serde("))
}
145
/// Emits each user-supplied attribute verbatim on its own line; an absent or
/// empty list yields an empty string.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
157
158pub fn generate_models(
159 models: &[ModelType],
160 requests: &[RequestModel],
161 responses: &[ResponseModel],
162 mode: GenerateMode,
163) -> Result<String> {
164 let mut models_code = String::new();
166 let mut required_uses = RequiredUses::empty();
167 let mut needs_validator = false;
168
169 for model_type in models {
170 match model_type {
171 ModelType::Struct(model) => {
172 models_code.push_str(&generate_model(
173 model,
174 &mut required_uses,
175 &mut needs_validator,
176 )?);
177 }
178 ModelType::Union(union) => {
179 models_code.push_str(&generate_union(union)?);
180 }
181 ModelType::Composition(comp) => {
182 models_code.push_str(&generate_composition(comp, &mut required_uses)?);
183 }
184 ModelType::Enum(enum_model) => {
185 models_code.push_str(&generate_enum(enum_model)?);
186 }
187 ModelType::TypeAlias(type_alias) => {
188 models_code.push_str(&generate_type_alias(type_alias)?);
189 }
190 }
191 }
192
193 if mode.contains(GenerateMode::REQUESTS) {
194 for request in requests {
195 models_code.push_str(&generate_request_model(request)?);
196 }
197 }
198
199 if mode.contains(GenerateMode::RESPONSES) {
200 for response in responses {
201 models_code.push_str(&generate_response_model(response)?);
202 }
203 }
204
205 let needs_uuid = required_uses.contains(RequiredUses::UUID);
207 let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
208 let needs_date = required_uses.contains(RequiredUses::DATE);
209
210 let mut output = create_header();
212 output.push_str("use serde::{Serialize, Deserialize};\n");
213
214 if needs_uuid {
215 output.push_str("use uuid::Uuid;\n");
216 }
217
218 if needs_validator {
219 output.push_str("use validator::Validator;\n");
220 }
221
222 if needs_datetime || needs_date {
223 output.push_str("use chrono::{");
224 let mut chrono_imports = Vec::new();
225 if needs_datetime {
226 chrono_imports.push("DateTime");
227 }
228 if needs_date {
229 chrono_imports.push("NaiveDate");
230 }
231 if needs_datetime {
232 chrono_imports.push("Utc");
233 }
234 output.push_str(&chrono_imports.join(", "));
235 output.push_str("};\n");
236 }
237
238 output.push('\n');
239 output.push_str(&models_code);
240
241 Ok(output)
242}
243
244fn generate_validator_attrs(rules: &crate::models::ValidationRules, field_type: &str) -> String {
246 let mut attrs = String::new();
247
248 match field_type {
249 "String" | "str" | "Option<String>" | "Option<str>" => {
250 let mut length_attrs = Vec::new();
251 if let Some(min) = rules.min_length {
252 length_attrs.push(format!("min = {}", min));
253 }
254 if let Some(max) = rules.max_length {
255 length_attrs.push(format!("max = {}", max));
256 }
257 if !length_attrs.is_empty() {
258 attrs.push_str(&format!(
259 " #[validate(length({}))]\n",
260 length_attrs.join(", ")
261 ));
262 }
263
264 if rules.email {
265 attrs.push_str(" #[validate(email)]\n");
266 }
267
268 if rules.url {
269 attrs.push_str(" #[validate(url)]\n");
270 }
271
272 if let Some(pattern) = &rules.pattern {
273 attrs.push_str(&format!(" #[regex(pattern = r\"{}\")]\n", pattern));
274 }
275 }
276 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64"
277 | "Option<i8>" | "Option<i16>" | "Option<i32>" | "Option<i64>" | "Option<u8>"
278 | "Option<u16>" | "Option<u32>" | "Option<u64>" | "Option<f32>" | "Option<f64>" => {
279 let mut range_attrs = Vec::new();
280 if let Some(min) = rules.minimum {
281 range_attrs.push(format!("min = {}", min));
282 }
283 if let Some(max) = rules.maximum {
284 range_attrs.push(format!("max = {}", max));
285 }
286 if rules.exclusive_minimum || rules.exclusive_maximum {
287 range_attrs.push("exclusive = true".to_string());
288 }
289 if !range_attrs.is_empty() {
290 attrs.push_str(&format!(
291 " #[validate(range({}))]\n",
292 range_attrs.join(", ")
293 ));
294 }
295 }
296 _ if field_type.contains("Vec<") => {
297 let mut length_attrs = Vec::new();
298 if let Some(min) = rules.min_items {
299 length_attrs.push(format!("min = {}", min));
300 }
301 if let Some(max) = rules.max_items {
302 length_attrs.push(format!("max = {}", max));
303 }
304 if !length_attrs.is_empty() {
305 attrs.push_str(&format!(
306 " #[validate(length({}))]\n",
307 length_attrs.join(", ")
308 ));
309 }
310 }
311 _ => {}
312 }
313
314 attrs
315}
316
317fn generate_model(
318 model: &Model,
319 required_uses: &mut RequiredUses,
320 needs_validator: &mut bool,
321) -> Result<String> {
322 let mut output = String::new();
323
324 output.push_str(&generate_description_docs(
325 &model.description,
326 &model.name,
327 "",
328 ));
329
330 output.push_str(&generate_custom_attrs(&model.custom_attrs));
331
332 let has_validation = model.fields.iter().any(|f| f.validation_rules.is_some());
334
335 if has_validation {
337 *needs_validator = true;
338 }
339
340 if !has_custom_derive(&model.custom_attrs) {
342 if has_validation {
343 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize, Validator)]\n");
344 } else {
345 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
346 }
347 }
348
349 output.push_str(&format!("pub struct {} {{\n", model.name));
350
351 for field in &model.fields {
352 let field_type = match field.field_type.as_str() {
353 "DateTime" | "DateTime<Utc>" => {
354 *required_uses |= RequiredUses::DATETIME;
355 "DateTime<Utc>"
356 }
357 "Date" => {
358 *required_uses |= RequiredUses::DATE;
359 "NaiveDate"
360 }
361 "Uuid" => {
362 *required_uses |= RequiredUses::UUID;
363 "Uuid"
364 }
365 _ => &field.field_type,
366 };
367
368 let mut lowercased_name = to_snake_case(field.name.as_str());
369 if is_reserved_word(&lowercased_name) {
370 lowercased_name = format!("r#{lowercased_name}")
371 }
372
373 output.push_str(&generate_description_docs(&field.description, "", " "));
375
376 if let Some(attrs) = &field.custom_attrs {
378 for attr in attrs {
379 output.push_str(&format!(" {attr}\n"));
380 }
381 }
382
383 let is_optional = !field.is_required || field.is_nullable;
385
386 let base_type = if field.is_array_ref {
387 format!("Vec<{field_type}>")
388 } else {
389 field_type.to_string()
390 };
391
392 let full_field_type = if is_optional {
393 format!("Option<{base_type}>")
394 } else {
395 base_type
396 };
397
398 if let Some(rules) = &field.validation_rules {
400 output.push_str(&generate_validator_attrs(rules, &full_field_type));
401 }
402
403 if lowercased_name != field.name {
405 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
406 }
407
408 if field.should_flatten() {
409 output.push_str(" #[serde(flatten)]\n");
410 }
411
412 output.push_str(&format!(" pub {lowercased_name}: {full_field_type},\n"));
413 }
414
415 output.push_str("}\n\n");
416 Ok(output)
417}
418
419fn generate_request_model(request: &RequestModel) -> Result<String> {
420 let mut output = String::new();
421 tracing::info!("Generating request model");
422 tracing::info!("{:#?}", request);
423
424 if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
425 return Ok(String::new());
426 }
427
428 output.push_str(&format!("/// {}\n", request.name));
429 output.push_str("#[derive(Debug, Clone, Serialize)]\n");
430 output.push_str(&format!("pub struct {} {{\n", request.name));
431 output.push_str(&format!(" pub body: {},\n", request.schema));
432 output.push_str("}\n");
433 Ok(output)
434}
435
436fn generate_response_model(response: &ResponseModel) -> Result<String> {
437 if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
438 return Ok(String::new());
439 }
440
441 let type_name = format!("{}{}", response.name, response.status_code);
442
443 let mut output = String::new();
444
445 output.push_str(&generate_description_docs(
446 &response.description,
447 &type_name,
448 "",
449 ));
450
451 output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
452 output.push_str(&format!("pub struct {type_name} {{\n"));
453 output.push_str(&format!(" pub body: {},\n", response.schema));
454 output.push_str("}\n");
455
456 Ok(output)
457}
458
459fn generate_union(union: &UnionModel) -> Result<String> {
460 let mut output = String::new();
461
462 output.push_str(&format!(
463 "/// {} ({})\n",
464 union.name,
465 match union.union_type {
466 UnionType::OneOf => "oneOf",
467 UnionType::AnyOf => "anyOf",
468 }
469 ));
470 output.push_str(&generate_custom_attrs(&union.custom_attrs));
471
472 if !has_custom_derive(&union.custom_attrs) {
474 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
475 }
476
477 if !has_custom_serde(&union.custom_attrs) {
479 output.push_str("#[serde(untagged)]\n");
480 }
481
482 output.push_str(&format!("pub enum {} {{\n", union.name));
483
484 for variant in &union.variants {
485 match &variant.primitive_type {
486 Some(t) => output.push_str(&format!(" {}({}),\n", variant.name, t)),
487 None => output.push_str(&format!(" {}({}),\n", variant.name, variant.name)),
488 }
489 }
490
491 output.push_str("}\n");
492 Ok(output)
493}
494
495fn generate_composition(
496 comp: &CompositionModel,
497 required_uses: &mut RequiredUses,
498) -> Result<String> {
499 let mut output = String::new();
500
501 output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
502 output.push_str(&generate_custom_attrs(&comp.custom_attrs));
503
504 if !has_custom_derive(&comp.custom_attrs) {
506 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
507 }
508
509 output.push_str(&format!("pub struct {} {{\n", comp.name));
510
511 for field in &comp.all_fields {
512 let field_type = match field.field_type.as_str() {
513 "String" => "String",
514 "f64" => "f64",
515 "i64" => "i64",
516 "bool" => "bool",
517 "DateTime" => {
518 *required_uses |= RequiredUses::DATETIME;
519 "DateTime<Utc>"
520 }
521 "Date" => {
522 *required_uses |= RequiredUses::DATE;
523 "NaiveDate"
524 }
525 "Uuid" => {
526 *required_uses |= RequiredUses::UUID;
527 "Uuid"
528 }
529 _ => &field.field_type,
530 };
531
532 let mut lowercased_name = to_snake_case(field.name.as_str());
533 if is_reserved_word(&lowercased_name) {
534 lowercased_name = format!("r#{lowercased_name}");
535 }
536
537 if lowercased_name != field.name {
539 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
540 }
541
542 if let Some(attrs) = &field.custom_attrs {
544 for attr in attrs {
545 output.push_str(&format!(" {attr}\n"));
546 }
547 }
548
549 if field.is_array_ref {
551 if field.is_required && !field.is_nullable {
552 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
553 } else {
554 output.push_str(&format!(
555 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
556 ));
557 }
558 } else if field.is_required && !field.is_nullable {
559 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
560 } else {
561 output.push_str(&format!(
562 " pub {lowercased_name}: Option<{field_type}>,\n",
563 ));
564 }
565 }
566
567 output.push_str("}\n");
568 Ok(output)
569}
570
571fn generate_enum(enum_model: &EnumModel) -> Result<String> {
572 let mut output = String::new();
573
574 output.push_str(&generate_description_docs(
575 &enum_model.description,
576 &enum_model.name,
577 "",
578 ));
579
580 output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
581
582 if !has_custom_derive(&enum_model.custom_attrs) {
584 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
585 }
586
587 output.push_str(&format!("pub enum {} {{\n", enum_model.name));
588
589 for (i, variant) in enum_model.variants.iter().enumerate() {
590 let original = variant.clone();
591
592 let mut rust_name = crate::parser::to_pascal_case(variant);
593
594 let serde_rename = if is_reserved_word(&rust_name) {
595 rust_name.push_str("Value");
596 Some(original)
597 } else if rust_name != original {
598 Some(original)
599 } else {
600 None
601 };
602
603 if let Some(rename) = serde_rename {
604 output.push_str(&format!(" #[serde(rename = \"{rename}\")]\n"));
605 }
606
607 if i + 1 == enum_model.variants.len() {
608 output.push_str(&format!(" {rust_name}\n"));
609 } else {
610 output.push_str(&format!(" {rust_name},\n"));
611 }
612 }
613
614 output.push_str("}\n");
615 Ok(output)
616}
617
618fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
619 let mut output = String::new();
620
621 output.push_str(&generate_description_docs(
622 &type_alias.description,
623 &type_alias.name,
624 "",
625 ));
626
627 output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
628 output.push_str(&format!(
629 "pub type {} = {};\n\n",
630 type_alias.name, type_alias.target_type
631 ));
632
633 Ok(output)
634}
635
636pub fn generate_rust_code(models: &[Model]) -> Result<String> {
637 let mut code = create_header();
638
639 code.push_str("use serde::{Serialize, Deserialize};\n");
640 code.push_str("use uuid::Uuid;\n");
641 code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
642
643 for model in models {
644 code.push_str(&format!("/// {}\n", model.name));
645 code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
646 code.push_str(&format!("pub struct {} {{\n", model.name));
647
648 for field in &model.fields {
649 let field_type = match field.field_type.as_str() {
650 "String" => "String",
651 "f64" => "f64",
652 "i64" => "i64",
653 "bool" => "bool",
654 "DateTime" => "DateTime<Utc>",
655 "Date" => "NaiveDate",
656 "Uuid" => "Uuid",
657 _ => &field.field_type,
658 };
659
660 let mut lowercased_name = to_snake_case(field.name.as_str());
661 if is_reserved_word(&lowercased_name) {
662 lowercased_name = format!("r#{lowercased_name}")
663 }
664
665 if lowercased_name != field.name {
667 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
668 }
669
670 if field.is_required {
671 code.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
672 } else {
673 code.push_str(&format!(
674 " pub {lowercased_name}: Option<{field_type}>,\n",
675 ));
676 }
677 }
678
679 code.push_str("}\n\n");
680 }
681
682 Ok(code)
683}
684
685pub fn generate_lib() -> Result<String> {
686 let mut code = create_header();
687 code.push_str("pub mod models;\n");
688
689 Ok(code)
690}