1use std::sync::OnceLock;
2
3use crate::{
4 models::{
5 CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
6 UnionModel, UnionType,
7 },
8 Result,
9};
10
/// Lazily-built banner text shared by every generated file.
static HDR: OnceLock<String> = OnceLock::new();

/// Returns the `//!` banner that names the generator package and its version.
///
/// The package name and version are read at compile time via `option_env!`, falling back to
/// `openapi-model-generator` / `unknown` when Cargo did not provide them. The text is built on
/// the first call, cached in `HDR`, and every call returns an owned copy of it.
fn create_header() -> String {
    let banner = HDR.get_or_init(|| {
        let pkg_name = option_env!("CARGO_PKG_NAME").unwrap_or("openapi-model-generator");
        let pkg_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
        format!(
            "\n//!\n//! Generated from an OAS specification by {pkg_name}(v{pkg_version})\n//!\n\n"
        )
    });
    banner.to_owned()
}
28
/// Words that cannot be used verbatim as generated Rust field or variant names.
///
/// Contains the strict keywords (including the 2018-edition `async`, `await` and `dyn`), the
/// words reserved for future use, and `gen`, which is reserved starting with the 2024 edition.
/// `is_reserved_word` lowercases its input before looking it up here.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", "for",
    "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return",
    "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where",
    "while", "async", "await", "dyn", "abstract", "become", "box", "do", "final", "gen", "macro",
    "override", "priv", "try", "typeof", "unsized", "virtual", "yield",
];

/// Placeholder response name for which no response struct is generated.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name for which no request struct is generated.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";
39
40fn is_reserved_word(string_to_check: &str) -> bool {
41 RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
42}
43
/// Renders an optional description as `///` doc-comment lines, each prefixed by `indent`.
///
/// Every line of `description` is trimmed before it is emitted. When `description` is `None`,
/// a single doc line containing `fallback_str` is produced instead, or nothing at all when
/// `fallback_str` is empty.
fn generate_description_docs(
    description: &Option<String>,
    fallback_str: &str,
    indent: &str,
) -> String {
    match description {
        Some(text) => text
            .lines()
            .map(|line| format!("{indent}/// {}\n", line.trim()))
            .collect(),
        None if fallback_str.is_empty() => String::new(),
        None => format!("{indent}/// {fallback_str}\n"),
    }
}
60
/// Converts an arbitrary schema property name into a snake_case Rust identifier.
///
/// The conversion works as follows:
/// - Non-ASCII-alphanumeric characters become `_`.
/// - An ASCII uppercase letter becomes `_` plus its lowercase form, except at the start.
/// - Every run of underscores is then collapsed to a single `_`.
/// - `self`, `crate` and `super` get a trailing `_`. They cannot be written as raw identifiers
///   (`r#self` etc. are rejected by rustc), so the usual `r#` escape is not an option.
/// - A leading digit is prefixed with `_`.
fn to_snake_case(name: &str) -> String {
    let mut raw = String::with_capacity(name.len() * 2);
    for (i, c) in name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i != 0 {
                raw.push('_');
            }
            raw.push(c.to_ascii_lowercase());
        } else if c.is_ascii_alphanumeric() {
            raw.push(c);
        } else {
            raw.push('_');
        }
    }

    // Collapse every run of underscores. A single `replace("__", "_")` pass would leave
    // `__` behind for runs of three or more, which trips rustc's `non_snake_case` lint.
    let mut snake = String::with_capacity(raw.len());
    for c in raw.chars() {
        if c == '_' && snake.ends_with('_') {
            continue;
        }
        snake.push(c);
    }

    if matches!(snake.as_str(), "self" | "crate" | "super") {
        snake.push('_');
    }

    if snake.starts_with(|c: char| c.is_ascii_digit()) {
        snake.insert(0, '_');
    }

    snake
}
96
/// Returns `true` when any user-supplied attribute (ignoring surrounding whitespace) is a
/// `#[derive(...)]`. Callers use this to decide whether to skip their default derive line.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs.as_ref().is_some_and(|attrs| {
        attrs
            .iter()
            .map(|attr| attr.trim())
            .any(|attr| attr.starts_with("#[derive("))
    })
}
107
/// Returns `true` when any user-supplied attribute (ignoring surrounding whitespace) is a
/// `#[serde(...)]`. Callers use this to decide whether to skip their default serde attribute.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    for attr in custom_attrs.iter().flatten() {
        if attr.trim().starts_with("#[serde(") {
            return true;
        }
    }
    false
}
116
/// Emits each user-supplied attribute verbatim on its own line, or an empty string when there
/// are none.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
128
129pub fn generate_models(
130 models: &[ModelType],
131 requests: &[RequestModel],
132 responses: &[ResponseModel],
133) -> Result<String> {
134 let mut models_code = String::new();
136
137 for model_type in models {
138 match model_type {
139 ModelType::Struct(model) => {
140 models_code.push_str(&generate_model(model)?);
141 }
142 ModelType::Union(union) => {
143 models_code.push_str(&generate_union(union)?);
144 }
145 ModelType::Composition(comp) => {
146 models_code.push_str(&generate_composition(comp)?);
147 }
148 ModelType::Enum(enum_model) => {
149 models_code.push_str(&generate_enum(enum_model)?);
150 }
151 ModelType::TypeAlias(type_alias) => {
152 models_code.push_str(&generate_type_alias(type_alias)?);
153 }
154 }
155 }
156
157 for request in requests {
158 models_code.push_str(&generate_request_model(request)?);
159 }
160
161 for response in responses {
162 models_code.push_str(&generate_response_model(response)?);
163 }
164
165 let needs_uuid = models_code.contains("Uuid");
167 let needs_datetime = models_code.contains("DateTime<Utc>");
168 let needs_date = models_code.contains("NaiveDate");
169
170 let mut output = create_header();
172 output.push_str("use serde::{Serialize, Deserialize};\n");
173
174 if needs_uuid {
175 output.push_str("use uuid::Uuid;\n");
176 }
177
178 if needs_datetime || needs_date {
179 output.push_str("use chrono::{");
180 let mut chrono_imports = Vec::new();
181 if needs_datetime {
182 chrono_imports.push("DateTime");
183 }
184 if needs_date {
185 chrono_imports.push("NaiveDate");
186 }
187 if needs_datetime {
188 chrono_imports.push("Utc");
189 }
190 output.push_str(&chrono_imports.join(", "));
191 output.push_str("};\n");
192 }
193
194 output.push('\n');
195 output.push_str(&models_code);
196
197 Ok(output)
198}
199
200fn generate_model(model: &Model) -> Result<String> {
201 let mut output = String::new();
202
203 output.push_str(&generate_description_docs(
204 &model.description,
205 &model.name,
206 "",
207 ));
208
209 output.push_str(&generate_custom_attrs(&model.custom_attrs));
210
211 if !has_custom_derive(&model.custom_attrs) {
213 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
214 }
215
216 output.push_str(&format!("pub struct {} {{\n", model.name));
217
218 for field in &model.fields {
219 let field_type = match field.field_type.as_str() {
220 "String" => "String",
221 "f64" => "f64",
222 "i64" => "i64",
223 "bool" => "bool",
224 "DateTime" => "DateTime<Utc>",
225 "Date" => "NaiveDate",
226 "Uuid" => "Uuid",
227 _ => &field.field_type,
228 };
229
230 let mut lowercased_name = to_snake_case(field.name.as_str());
231 if is_reserved_word(&lowercased_name) {
232 lowercased_name = format!("r#{lowercased_name}")
233 }
234
235 output.push_str(&generate_description_docs(&field.description, "", " "));
237
238 if lowercased_name != field.name {
240 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
241 }
242
243 if field.should_flatten() {
244 output.push_str(" #[serde(flatten)]\n");
245 }
246
247 if field.is_array_ref {
249 if field.is_required && !field.is_nullable {
250 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
251 } else {
252 output.push_str(&format!(
253 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
254 ));
255 }
256 } else if field.is_required && !field.is_nullable {
257 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
258 } else {
259 output.push_str(&format!(
260 " pub {lowercased_name}: Option<{field_type}>,\n",
261 ));
262 }
263 }
264
265 output.push_str("}\n\n");
266 Ok(output)
267}
268
269fn generate_request_model(request: &RequestModel) -> Result<String> {
270 let mut output = String::new();
271 tracing::info!("Generating request model");
272 tracing::info!("{:#?}", request);
273
274 if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
275 return Ok(String::new());
276 }
277
278 output.push_str(&format!("/// {}\n", request.name));
279 output.push_str("#[derive(Debug, Clone, Serialize)]\n");
280 output.push_str(&format!("pub struct {} {{\n", request.name));
281 output.push_str(&format!(" pub body: {},\n", request.schema));
282 output.push_str("}\n");
283 Ok(output)
284}
285
286fn generate_response_model(response: &ResponseModel) -> Result<String> {
287 if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
288 return Ok(String::new());
289 }
290
291 let type_name = format!("{}{}", response.name, response.status_code);
292
293 let mut output = String::new();
294
295 output.push_str(&generate_description_docs(
296 &response.description,
297 &type_name,
298 "",
299 ));
300
301 output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
302 output.push_str(&format!("pub struct {type_name} {{\n"));
303 output.push_str(&format!(" pub body: {},\n", response.schema));
304 output.push_str("}\n");
305
306 Ok(output)
307}
308
309fn generate_union(union: &UnionModel) -> Result<String> {
310 let mut output = String::new();
311
312 output.push_str(&format!(
313 "/// {} ({})\n",
314 union.name,
315 match union.union_type {
316 UnionType::OneOf => "oneOf",
317 UnionType::AnyOf => "anyOf",
318 }
319 ));
320 output.push_str(&generate_custom_attrs(&union.custom_attrs));
321
322 if !has_custom_derive(&union.custom_attrs) {
324 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
325 }
326
327 if !has_custom_serde(&union.custom_attrs) {
329 output.push_str("#[serde(untagged)]\n");
330 }
331
332 output.push_str(&format!("pub enum {} {{\n", union.name));
333
334 for variant in &union.variants {
335 match &variant.primitive_type {
336 Some(t) => output.push_str(&format!(" {}({}),\n", variant.name, t)),
337 None => output.push_str(&format!(" {}({}),\n", variant.name, variant.name)),
338 }
339 }
340
341 output.push_str("}\n");
342 Ok(output)
343}
344
345fn generate_composition(comp: &CompositionModel) -> Result<String> {
346 let mut output = String::new();
347
348 output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
349 output.push_str(&generate_custom_attrs(&comp.custom_attrs));
350
351 if !has_custom_derive(&comp.custom_attrs) {
353 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
354 }
355
356 output.push_str(&format!("pub struct {} {{\n", comp.name));
357
358 for field in &comp.all_fields {
359 let field_type = match field.field_type.as_str() {
360 "String" => "String",
361 "f64" => "f64",
362 "i64" => "i64",
363 "bool" => "bool",
364 "DateTime" => "DateTime<Utc>",
365 "Date" => "NaiveDate",
366 "Uuid" => "Uuid",
367 _ => &field.field_type,
368 };
369
370 let mut lowercased_name = to_snake_case(field.name.as_str());
371 if is_reserved_word(&lowercased_name) {
372 lowercased_name = format!("r#{lowercased_name}");
373 }
374
375 if lowercased_name != field.name {
377 output.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
378 }
379
380 if field.is_array_ref {
382 if field.is_required && !field.is_nullable {
383 output.push_str(&format!(" pub {lowercased_name}: Vec<{field_type}>,\n",));
384 } else {
385 output.push_str(&format!(
386 " pub {lowercased_name}: Option<Vec<{field_type}>>,\n",
387 ));
388 }
389 } else if field.is_required && !field.is_nullable {
390 output.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
391 } else {
392 output.push_str(&format!(
393 " pub {lowercased_name}: Option<{field_type}>,\n",
394 ));
395 }
396 }
397
398 output.push_str("}\n");
399 Ok(output)
400}
401
402fn generate_enum(enum_model: &EnumModel) -> Result<String> {
403 let mut output = String::new();
404
405 output.push_str(&generate_description_docs(
406 &enum_model.description,
407 &enum_model.name,
408 "",
409 ));
410
411 output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
412
413 if !has_custom_derive(&enum_model.custom_attrs) {
415 output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
416 }
417
418 output.push_str(&format!("pub enum {} {{\n", enum_model.name));
419
420 for (i, variant) in enum_model.variants.iter().enumerate() {
421 let original = variant.clone();
422
423 let mut rust_name = crate::parser::to_pascal_case(variant);
424
425 let serde_rename = if is_reserved_word(&rust_name) {
426 rust_name.push_str("Value");
427 Some(original)
428 } else if rust_name != original {
429 Some(original)
430 } else {
431 None
432 };
433
434 if let Some(rename) = serde_rename {
435 output.push_str(&format!(" #[serde(rename = \"{rename}\")]\n"));
436 }
437
438 if i + 1 == enum_model.variants.len() {
439 output.push_str(&format!(" {rust_name}\n"));
440 } else {
441 output.push_str(&format!(" {rust_name},\n"));
442 }
443 }
444
445 output.push_str("}\n");
446 Ok(output)
447}
448
449fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
450 let mut output = String::new();
451
452 output.push_str(&generate_description_docs(
453 &type_alias.description,
454 &type_alias.name,
455 "",
456 ));
457
458 output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
459 output.push_str(&format!(
460 "pub type {} = {};\n\n",
461 type_alias.name, type_alias.target_type
462 ));
463
464 Ok(output)
465}
466
467pub fn generate_rust_code(models: &[Model]) -> Result<String> {
468 let mut code = create_header();
469
470 code.push_str("use serde::{Serialize, Deserialize};\n");
471 code.push_str("use uuid::Uuid;\n");
472 code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
473
474 for model in models {
475 code.push_str(&format!("/// {}\n", model.name));
476 code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
477 code.push_str(&format!("pub struct {} {{\n", model.name));
478
479 for field in &model.fields {
480 let field_type = match field.field_type.as_str() {
481 "String" => "String",
482 "f64" => "f64",
483 "i64" => "i64",
484 "bool" => "bool",
485 "DateTime" => "DateTime<Utc>",
486 "Date" => "NaiveDate",
487 "Uuid" => "Uuid",
488 _ => &field.field_type,
489 };
490
491 let mut lowercased_name = to_snake_case(field.name.as_str());
492 if is_reserved_word(&lowercased_name) {
493 lowercased_name = format!("r#{lowercased_name}")
494 }
495
496 if lowercased_name != field.name {
498 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field.name));
499 }
500
501 if field.is_required {
502 code.push_str(&format!(" pub {lowercased_name}: {field_type},\n",));
503 } else {
504 code.push_str(&format!(
505 " pub {lowercased_name}: Option<{field_type}>,\n",
506 ));
507 }
508 }
509
510 code.push_str("}\n\n");
511 }
512
513 Ok(code)
514}
515
516pub fn generate_lib() -> Result<String> {
517 let mut code = create_header();
518 code.push_str("pub mod models;\n");
519
520 Ok(code)
521}