1use std::sync::OnceLock;
2
3use crate::{
4 models::{
5 CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6 UnionModel, UnionType,
7 },
8 Result,
9};
10
11bitflags::bitflags! {
12 struct RequiredUses: u8 {
13 const UUID = 0b00000001;
14 const DATETIME = 0b00000010;
15 const DATE = 0b00000100;
16 }
17}
18
/// Lazily-built banner shared by every generated file.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the `//!` banner placed at the top of each generated file.
///
/// The banner names the generator crate and its version, falling back to
/// `openapi-model-generator` / `unknown` when Cargo did not provide them at
/// compile time. It is formatted once, cached in `HDR`, and every call hands
/// back an owned copy.
fn create_header() -> String {
    let banner = HDR.get_or_init(|| {
        let tool = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!("\n//!\n//! Generated from an OAS specification by {tool}(v{version})\n//!\n\n")
    });
    banner.to_owned()
}
36
/// Words that cannot appear bare as a Rust field or variant name.
///
/// Covers the strict keywords (including the 2018-edition additions `async`,
/// `await` and `dyn`) and the reserved words (including `gen`, reserved from
/// edition 2024). Generated field names that match are emitted as raw
/// identifiers (`r#...`); matching enum variants get a `Value` suffix.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name for which `generate_response_model` emits nothing.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name for which `generate_request_model` emits nothing.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` when `string_to_check` collides with a Rust keyword.
///
/// The comparison is case-insensitive (the input is lowercased first), so
/// PascalCase names such as `Type` or `Match` are also reported even though
/// only `Self` is a real conflict in that casing; callers rename defensively.
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
}
51
/// Renders an optional description as `///` doc-comment lines.
///
/// * `description` — when `Some`, every line is trimmed and emitted as its
///   own doc line; blank lines become a bare `///` (no trailing whitespace).
/// * `fallback_str` — emitted verbatim as a single doc line when
///   `description` is `None`; nothing is emitted if it is empty.
/// * `indent` — prefix placed before every `///`.
///
/// Returns the rendered lines, each terminated by `\n` (possibly empty).
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    let lines: Vec<&str> = match description {
        // `str::lines` only splits on `\n` / `\r\n`. A lone `\r` left inside a
        // doc comment is a hard rustc error ("bare CR not allowed in
        // doc-comment"), so interior carriage returns are treated as breaks.
        Some(desc) => desc
            .lines()
            .flat_map(|line| line.trim().split('\r'))
            .map(str::trim)
            .collect(),
        None if !fallback_str.is_empty() => vec![fallback_str],
        None => Vec::new(),
    };

    let mut output = String::new();
    for line in lines {
        if line.is_empty() {
            output.push_str(&format!("{indent}///\n"));
        } else {
            output.push_str(&format!("{indent}/// {line}\n"));
        }
    }
    output
}
68
/// Converts an OpenAPI property name into a Rust field identifier.
///
/// * Every character that is not ASCII alphanumeric becomes `_`.
/// * Each ASCII uppercase letter is lowercased and, unless it is the first
///   character, preceded by `_` (acronyms split per letter: `userID` →
///   `user_i_d`).
/// * Runs of `_` collapse to a single `_`.
/// * `self`, `crate` and `super` get a trailing `_`, because those keywords
///   cannot be written as raw identifiers (`r#self` is rejected by rustc).
/// * A leading digit is prefixed with `_`.
///
/// Other keywords (e.g. `type`) are returned unchanged; callers turn them
/// into raw identifiers.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 1);

    for (i, c) in name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i != 0 && !snake.ends_with('_') {
                snake.push('_');
            }
            snake.push(c.to_ascii_lowercase());
        } else if c.is_ascii_alphanumeric() {
            snake.push(c);
        } else if !snake.ends_with('_') {
            // Separator or non-ASCII character; never emit `__`, which
            // rustc's `non_snake_case` lint rejects.
            snake.push('_');
        }
    }

    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
104
/// Reports whether any user-supplied attribute is a `#[derive(...)]`.
///
/// Generators use this to skip their default derive list when the user
/// provides their own. Leading/trailing whitespace on each attribute is ignored.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .as_ref()
        .is_some_and(|attrs| attrs.iter().any(|attr| attr.trim().starts_with("#[derive(")))
}
115
/// Reports whether any user-supplied attribute is a `#[serde(...)]`.
///
/// `generate_union` uses this to decide whether to add its own
/// `#[serde(untagged)]`. Leading/trailing whitespace on each attribute is ignored.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|attr| attr.trim().starts_with("#[serde("))
}
124
/// Renders user-supplied attributes verbatim, one per line.
///
/// Returns an empty string when there are no custom attributes.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    if let Some(attrs) = custom_attrs {
        for attr in attrs {
            rendered.push_str(attr);
            rendered.push('\n');
        }
    }
    rendered
}
136
137pub fn generate_models(
138 models: &[ModelType],
139 requests: &[RequestModel],
140 responses: &[ResponseModel],
141) -> Result<String> {
142 let mut models_code = String::new();
144 let mut required_uses = RequiredUses::empty();
145
146 for model_type in models {
147 match model_type {
148 ModelType::Struct(model) => {
149 models_code.push_str(&generate_model(model, &mut required_uses)?);
150 }
151 ModelType::Union(union) => {
152 models_code.push_str(&generate_union(union)?);
153 }
154 ModelType::Composition(comp) => {
155 models_code.push_str(&generate_composition(comp, &mut required_uses)?);
156 }
157 ModelType::Enum(enum_model) => {
158 models_code.push_str(&generate_enum(enum_model)?);
159 }
160 ModelType::TypeAlias(type_alias) => {
161 models_code.push_str(&generate_type_alias(type_alias)?);
162 }
163 }
164 }
165
166 for request in requests {
167 models_code.push_str(&generate_request_model(request)?);
168 }
169
170 for response in responses {
171 models_code.push_str(&generate_response_model(response)?);
172 }
173
174 let needs_uuid = required_uses.contains(RequiredUses::UUID);
176 let needs_datetime = required_uses.contains(RequiredUses::DATETIME);
177 let needs_date = required_uses.contains(RequiredUses::DATE);
178
179 let mut output = create_header();
181 output.push_str("use serde::{Serialize, Deserialize};\n");
182
183 if needs_uuid {
184 output.push_str("use uuid::Uuid;\n");
185 }
186
187 if needs_datetime || needs_date {
188 output.push_str("use chrono::{");
189 let mut chrono_imports = Vec::new();
190 if needs_datetime {
191 chrono_imports.push("DateTime");
192 }
193 if needs_date {
194 chrono_imports.push("NaiveDate");
195 }
196 if needs_datetime {
197 chrono_imports.push("Utc");
198 }
199 output.push_str(&chrono_imports.join(", "));
200 output.push_str("};\n");
201 }
202
203 output.push('\n');
204 output.push_str(&models_code);
205
206 Ok(output)
207}
208
209fn generate_model(model: &Model, required_uses: &mut RequiredUses) -> Result<String> {
210 let mut output = String::new();
211
212 output.push_str(&generate_description_docs(
213 &model.description,
214 &model.name,
215 "",
216 ));
217
218 output.push_str(&generate_custom_attrs(&model.custom_attrs));
219
220 if !has_custom_derive(&model.custom_attrs) {
222 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
223 }
224
225 output.push_str(&format!("pub struct {} {{\n", model.name));
226
227 for field in &model.fields {
228 let field_type = match field.field_type.as_str() {
229 "String" => "String",
230 "f64" => "f64",
231 "i64" => "i64",
232 "bool" => "bool",
233 "DateTime" => {
234 *required_uses |= RequiredUses::DATETIME;
235 "DateTime<Utc>"
236 }
237 "Date" => {
238 *required_uses |= RequiredUses::DATE;
239 "NaiveDate"
240 }
241 "Uuid" => {
242 *required_uses |= RequiredUses::UUID;
243 "Uuid"
244 }
245 _ => &field.field_type,
246 };
247
248 let mut lowercased_name = to_snake_case(field.name.as_str());
249 if is_reserved_word(&lowercased_name) {
250 lowercased_name = format!("r#{lowercased_name}")
251 }
252
253 output.push_str(&generate_description_docs(&field.description, "", " "));
255
256 if lowercased_name != field.name {
258 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
259 }
260
261 if field.should_flatten() {
262 output.push_str(" #[serde(flatten)]\n");
263 }
264
265 if field.is_array_ref {
267 if field.is_required && !field.is_nullable {
268 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
269 } else {
270 output.push_str(&format!(
271 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
272 ));
273 }
274 } else if field.is_required && !field.is_nullable {
275 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
276 } else {
277 output.push_str(&format!(
278 " pub {lowercased_name}: Option<{field_type}>,\n",
279 ));
280 }
281 }
282
283 output.push_str("}\n\n");
284 Ok(output)
285}
286
287fn generate_request_model(request: &RequestModel) -> Result<String> {
288 let mut output = String::new();
289 tracing::info!("Generating request model");
290 tracing::info!("{:#?}", request);
291
292 if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
293 return Ok(String::new());
294 }
295
296 output.push_str(&format!("/// {}\n", request.name));
297 output.push_str("#[derive(Debug, Clone, Serialize)]\n");
298 output.push_str(&format!("pub struct {} {{\n", request.name));
299 output.push_str(&format!(" pub body: {},\n", request.schema));
300 output.push_str("}\n");
301 Ok(output)
302}
303
304fn generate_response_model(response: &ResponseModel) -> Result<String> {
305 if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
306 return Ok(String::new());
307 }
308
309 let type_name = format!("{}{}", response.name, response.status_code);
310
311 let mut output = String::new();
312
313 output.push_str(&generate_description_docs(
314 &response.description,
315 &type_name,
316 "",
317 ));
318
319 output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
320 output.push_str(&format!("pub struct {type_name} {{\n"));
321 output.push_str(&format!(" pub body: {},\n", response.schema));
322 output.push_str("}\n");
323
324 Ok(output)
325}
326
327fn generate_union(union: &UnionModel) -> Result<String> {
328 let mut output = String::new();
329
330 output.push_str(&format!(
331 "/// {} ({})\n",
332 union.name,
333 match union.union_type {
334 UnionType::OneOf => "oneOf",
335 UnionType::AnyOf => "anyOf",
336 }
337 ));
338 output.push_str(&generate_custom_attrs(&union.custom_attrs));
339
340 if !has_custom_derive(&union.custom_attrs) {
342 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
343 }
344
345 if !has_custom_serde(&union.custom_attrs) {
347 output.push_str("#[serde(untagged)]\n");
348 }
349
350 output.push_str(&format!("pub enum {} {{\n", union.name));
351
352 for variant in &union.variants {
353 match &variant.primitive_type {
354 Some(t) => output.push_str(&format!(" {}({}),\n", variant.name, t)),
355 None => output.push_str(&format!(" {}({}),\n", variant.name, variant.name)),
356 }
357 }
358
359 output.push_str("}\n");
360 Ok(output)
361}
362
363fn generate_composition(
364 comp: &CompositionModel,
365 required_uses: &mut RequiredUses,
366) -> Result<String> {
367 let mut output = String::new();
368
369 output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
370 output.push_str(&generate_custom_attrs(&comp.custom_attrs));
371
372 if !has_custom_derive(&comp.custom_attrs) {
374 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
375 }
376
377 output.push_str(&format!("pub struct {} {{\n", comp.name));
378
379 for field in &comp.all_fields {
380 let field_type = match field.field_type.as_str() {
381 "String" => "String",
382 "f64" => "f64",
383 "i64" => "i64",
384 "bool" => "bool",
385 "DateTime" => {
386 *required_uses |= RequiredUses::DATETIME;
387 "DateTime<Utc>"
388 }
389 "Date" => {
390 *required_uses |= RequiredUses::DATE;
391 "NaiveDate"
392 }
393 "Uuid" => {
394 *required_uses |= RequiredUses::UUID;
395 "Uuid"
396 }
397 _ => &field.field_type,
398 };
399
400 let mut lowercased_name = to_snake_case(field.name.as_str());
401 if is_reserved_word(&lowercased_name) {
402 lowercased_name = format!("r#{lowercased_name}");
403 }
404
405 if lowercased_name != field.name {
407 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
408 }
409
410 if field.is_array_ref {
412 if field.is_required && !field.is_nullable {
413 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
414 } else {
415 output.push_str(&format!(
416 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
417 ));
418 }
419 } else if field.is_required && !field.is_nullable {
420 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
421 } else {
422 output.push_str(&format!(
423 " pub {lowercased_name}: Option<{field_type}>,\n",
424 ));
425 }
426 }
427
428 output.push_str("}\n");
429 Ok(output)
430}
431
432fn generate_enum(enum_model: &EnumModel) -> Result<String> {
433 let mut output = String::new();
434
435 output.push_str(&generate_description_docs(
436 &enum_model.description,
437 &enum_model.name,
438 "",
439 ));
440
441 output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
442
443 if !has_custom_derive(&enum_model.custom_attrs) {
445 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
446 }
447
448 output.push_str(&format!("pub enum {} {{\n", enum_model.name));
449
450 for (i, variant) in enum_model.variants.iter().enumerate() {
451 let original = variant.clone();
452
453 let mut rust_name = crate::parser::to_pascal_case(variant);
454
455 let serde_rename = if is_reserved_word(&rust_name) {
456 rust_name.push_str("Value");
457 Some(original)
458 } else if rust_name != original {
459 Some(original)
460 } else {
461 None
462 };
463
464 if let Some(rename) = serde_rename {
465 output.push_str(&format!(" #[serde(rename = \"{rename}\")]\n"));
466 }
467
468 if i + 1 == enum_model.variants.len() {
469 output.push_str(&format!(" {rust_name}\n"));
470 } else {
471 output.push_str(&format!(" {rust_name},\n"));
472 }
473 }
474
475 output.push_str("}\n");
476 Ok(output)
477}
478
479fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
480 let mut output = String::new();
481
482 output.push_str(&generate_description_docs(
483 &type_alias.description,
484 &type_alias.name,
485 "",
486 ));
487
488 output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
489 output.push_str(&format!(
490 "pub type {} = {};\n\n",
491 type_alias.name, type_alias.target_type
492 ));
493
494 Ok(output)
495}
496
497pub fn generate_rust_code(models: &[Model]) -> Result<String> {
498 let mut code = create_header();
499
500 code.push_str("use serde::{Serialize, Deserialize};\n");
501 code.push_str("use uuid::Uuid;\n");
502 code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
503
504 for model in models {
505 code.push_str(&format!("/// {}\n", model.name));
506 code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
507 code.push_str(&format!("pub struct {} {{\n", model.name));
508
509 for field in &model.fields {
510 let field_type = match field.field_type.as_str() {
511 "String" => "String",
512 "f64" => "f64",
513 "i64" => "i64",
514 "bool" => "bool",
515 "DateTime" => "DateTime<Utc>",
516 "Date" => "NaiveDate",
517 "Uuid" => "Uuid",
518 _ => &field.field_type,
519 };
520
521 let mut lowercased_name = to_snake_case(field.name.as_str());
522 if is_reserved_word(&lowercased_name) {
523 lowercased_name = format!("r#{lowercased_name}")
524 }
525
526 if lowercased_name != field.name {
528 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
529 }
530
531 if field.is_required {
532 code.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
533 } else {
534 code.push_str(&format!(
535 " pub {lowercased_name}: Option<{field_type}>,\n",
536 ));
537 }
538 }
539
540 code.push_str("}\n\n");
541 }
542
543 Ok(code)
544}
545
546pub fn generate_lib() -> Result<String> {
547 let mut code = create_header();
548 code.push_str("pub mod models;\n");
549
550 Ok(code)
551}