openapi_model_generator/
generator.rs

1use crate::{
2    models::{
3        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, TypeAliasModel,
4        UnionModel, UnionType,
5    },
6    Result,
7};
8
/// Words that must not be emitted verbatim as identifiers in generated code.
///
/// Contains the strict keywords (including `async`, `await` and `dyn`, which
/// are keywords from the 2018 edition onwards) followed by the keywords that
/// Rust reserves for future use.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
    "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
    "use", "where", "while", "abstract", "become", "box", "do", "final", "macro", "override",
    "priv", "try", "typeof", "unsized", "virtual", "yield",
];

/// Response name for which `generate_response_model` emits no code.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Request name for which `generate_request_model` emits no code.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";
19
20fn is_reserved_word(string_to_check: &str) -> bool {
21    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
22}
23
/// Returns `true` when at least one custom attribute, after trimming
/// surrounding whitespace, starts with `#[derive(`.
///
/// `None` and an empty list both yield `false`.
fn has_custom_derive(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|candidate| candidate.trim().starts_with("#[derive("))
}
34
/// Returns `true` when at least one custom attribute, after trimming
/// surrounding whitespace, starts with `#[serde(`.
///
/// `None` and an empty list both yield `false`.
fn has_custom_serde(custom_attrs: &Option<Vec<String>>) -> bool {
    custom_attrs
        .iter()
        .flatten()
        .any(|candidate| candidate.trim().starts_with("#[serde("))
}
43
/// Renders the custom attributes (from `x-rust-attrs`) one per line.
///
/// Each attribute is written verbatim and followed by a newline; `None`
/// produces an empty string.
fn generate_custom_attrs(custom_attrs: &Option<Vec<String>>) -> String {
    let mut rendered = String::new();
    for attr in custom_attrs.iter().flatten() {
        rendered.push_str(attr);
        rendered.push('\n');
    }
    rendered
}
55
56pub fn generate_models(
57    models: &[ModelType],
58    requests: &[RequestModel],
59    responses: &[ResponseModel],
60) -> Result<String> {
61    // First, generate all model code to determine which imports are needed
62    let mut models_code = String::new();
63
64    for model_type in models {
65        match model_type {
66            ModelType::Struct(model) => {
67                models_code.push_str(&generate_model(model)?);
68            }
69            ModelType::Union(union) => {
70                models_code.push_str(&generate_union(union)?);
71            }
72            ModelType::Composition(comp) => {
73                models_code.push_str(&generate_composition(comp)?);
74            }
75            ModelType::Enum(enum_model) => {
76                models_code.push_str(&generate_enum(enum_model)?);
77            }
78            ModelType::TypeAlias(type_alias) => {
79                models_code.push_str(&generate_type_alias(type_alias)?);
80            }
81        }
82    }
83
84    for request in requests {
85        models_code.push_str(&generate_request_model(request)?);
86    }
87
88    for response in responses {
89        models_code.push_str(&generate_response_model(response)?);
90    }
91
92    // Determine which imports are actually needed
93    let needs_uuid = models_code.contains("Uuid");
94    let needs_datetime = models_code.contains("DateTime<Utc>");
95    let needs_date = models_code.contains("NaiveDate");
96
97    // Build final output with only necessary imports
98    let mut output = String::new();
99    output.push_str("use serde::{Serialize, Deserialize};\n");
100
101    if needs_uuid {
102        output.push_str("use uuid::Uuid;\n");
103    }
104
105    if needs_datetime || needs_date {
106        output.push_str("use chrono::{");
107        let mut chrono_imports = Vec::new();
108        if needs_datetime {
109            chrono_imports.push("DateTime");
110        }
111        if needs_date {
112            chrono_imports.push("NaiveDate");
113        }
114        if needs_datetime {
115            chrono_imports.push("Utc");
116        }
117        output.push_str(&chrono_imports.join(", "));
118        output.push_str("};\n");
119    }
120
121    output.push('\n');
122    output.push_str(&models_code);
123
124    Ok(output)
125}
126
127fn generate_model(model: &Model) -> Result<String> {
128    let mut output = String::new();
129
130    if !model.name.is_empty() {
131        output.push_str(&format!("/// {}\n", model.name));
132    }
133
134    output.push_str(&generate_custom_attrs(&model.custom_attrs));
135
136    // Only add default derive if custom_attrs doesn't already contain a derive directive
137    if !has_custom_derive(&model.custom_attrs) {
138        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
139    }
140
141    output.push_str(&format!("pub struct {} {{\n", model.name));
142
143    for field in &model.fields {
144        let field_type = match field.field_type.as_str() {
145            "String" => "String",
146            "f64" => "f64",
147            "i64" => "i64",
148            "bool" => "bool",
149            "DateTime" => "DateTime<Utc>",
150            "Date" => "NaiveDate",
151            "Uuid" => "Uuid",
152            _ => &field.field_type,
153        };
154
155        let mut lowercased_name = field.name.to_lowercase();
156        if is_reserved_word(&lowercased_name) {
157            lowercased_name = format!("r#{lowercased_name}")
158        }
159
160        // Only add serde rename if the Rust field name differs from the original field name
161        if lowercased_name != field.name {
162            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
163        }
164
165        if field.is_required && !field.is_nullable {
166            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
167        } else {
168            output.push_str(&format!(
169                "    pub {lowercased_name}: Option<{field_type}>,\n",
170            ));
171        }
172    }
173
174    output.push_str("}\n\n");
175    Ok(output)
176}
177
178fn generate_request_model(request: &RequestModel) -> Result<String> {
179    let mut output = String::new();
180    tracing::info!("Generating request model");
181    tracing::info!("{:#?}", request);
182
183    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
184        return Ok(String::new());
185    }
186
187    output.push_str(&format!("/// {}\n", request.name));
188    output.push_str("#[derive(Debug, Clone, Serialize)]\n");
189    output.push_str(&format!("pub struct {} {{\n", request.name));
190    output.push_str(&format!("    pub body: {},\n", request.schema));
191    output.push_str("}\n");
192    Ok(output)
193}
194
195fn generate_response_model(response: &ResponseModel) -> Result<String> {
196    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
197        return Ok(String::new());
198    }
199
200    let type_name = format!("{}{}", response.name, response.status_code);
201
202    let mut output = String::new();
203
204    if let Some(desc) = &response.description {
205        for line in desc.lines() {
206            output.push_str(&format!("/// {}\n", line.trim()));
207        }
208    } else {
209        output.push_str(&format!("/// {type_name}\n"));
210    }
211
212    output.push_str("#[derive(Debug, Clone, Deserialize)]\n");
213    output.push_str(&format!("pub struct {type_name} {{\n"));
214    output.push_str(&format!("    pub body: {},\n", response.schema));
215    output.push_str("}\n");
216
217    Ok(output)
218}
219
220fn generate_union(union: &UnionModel) -> Result<String> {
221    let mut output = String::new();
222
223    output.push_str(&format!(
224        "/// {} ({})\n",
225        union.name,
226        match union.union_type {
227            UnionType::OneOf => "oneOf",
228            UnionType::AnyOf => "anyOf",
229        }
230    ));
231    output.push_str(&generate_custom_attrs(&union.custom_attrs));
232
233    // Only add default derive if custom_attrs doesn't already contain a derive
234    if !has_custom_derive(&union.custom_attrs) {
235        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
236    }
237
238    // Only add default serde(untagged) if custom_attrs doesn't already contain a serde attribute
239    if !has_custom_serde(&union.custom_attrs) {
240        output.push_str("#[serde(untagged)]\n");
241    }
242
243    output.push_str(&format!("pub enum {} {{\n", union.name));
244
245    for variant in &union.variants {
246        output.push_str(&format!("    {}({}),\n", variant.name, variant.name));
247    }
248
249    output.push_str("}\n");
250    Ok(output)
251}
252
253fn generate_composition(comp: &CompositionModel) -> Result<String> {
254    let mut output = String::new();
255
256    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
257    output.push_str(&generate_custom_attrs(&comp.custom_attrs));
258
259    // Only add default derive if custom_attrs doesn't already contain a derive
260    if !has_custom_derive(&comp.custom_attrs) {
261        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
262    }
263
264    output.push_str(&format!("pub struct {} {{\n", comp.name));
265
266    for field in &comp.all_fields {
267        let field_type = match field.field_type.as_str() {
268            "String" => "String",
269            "f64" => "f64",
270            "i64" => "i64",
271            "bool" => "bool",
272            "DateTime" => "DateTime<Utc>",
273            "Date" => "NaiveDate",
274            "Uuid" => "Uuid",
275            _ => &field.field_type,
276        };
277
278        let mut lowercased_name = field.name.to_lowercase();
279        if is_reserved_word(&lowercased_name) {
280            lowercased_name = format!("r#{lowercased_name}");
281        }
282
283        // Only add serde rename if the Rust field name differs from the original field name
284        if lowercased_name != field.name {
285            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
286        }
287
288        if field.is_required && !field.is_nullable {
289            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n"));
290        } else {
291            output.push_str(&format!(
292                "    pub {lowercased_name}: Option<{field_type}>,\n"
293            ));
294        }
295    }
296
297    output.push_str("}\n");
298    Ok(output)
299}
300
301fn generate_enum(enum_model: &EnumModel) -> Result<String> {
302    let mut output = String::new();
303
304    if let Some(description) = &enum_model.description {
305        output.push_str(&format!("/// {description}\n"));
306    } else {
307        output.push_str(&format!("/// {}\n", enum_model.name));
308    }
309
310    output.push_str(&generate_custom_attrs(&enum_model.custom_attrs));
311
312    // Only add default derive if custom_attrs doesn't already contain a derive
313    if !has_custom_derive(&enum_model.custom_attrs) {
314        output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
315    }
316
317    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
318
319    for (i, variant) in enum_model.variants.iter().enumerate() {
320        let original = variant.clone();
321
322        let mut chars = variant.chars();
323        let first_char = chars.next().unwrap().to_ascii_uppercase();
324        let rest: String = chars.collect();
325        let mut rust_name = format!("{first_char}{rest}");
326
327        let serde_rename = if is_reserved_word(&rust_name) {
328            rust_name.push_str("Value");
329            Some(original)
330        } else if rust_name != original {
331            Some(original)
332        } else {
333            None
334        };
335
336        if let Some(rename) = serde_rename {
337            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
338        }
339
340        if i + 1 == enum_model.variants.len() {
341            output.push_str(&format!("    {rust_name}\n"));
342        } else {
343            output.push_str(&format!("    {rust_name},\n"));
344        }
345    }
346
347    output.push_str("}\n");
348    Ok(output)
349}
350
351fn generate_type_alias(type_alias: &TypeAliasModel) -> Result<String> {
352    let mut output = String::new();
353
354    if let Some(description) = &type_alias.description {
355        output.push_str(&format!("/// {description}\n"));
356    } else {
357        output.push_str(&format!("/// {}\n", type_alias.name));
358    }
359
360    output.push_str(&generate_custom_attrs(&type_alias.custom_attrs));
361    output.push_str(&format!(
362        "pub type {} = {};\n\n",
363        type_alias.name, type_alias.target_type
364    ));
365
366    Ok(output)
367}
368
369pub fn generate_rust_code(models: &[Model]) -> Result<String> {
370    let mut code = String::new();
371
372    code.push_str("use serde::{Serialize, Deserialize};\n");
373    code.push_str("use uuid::Uuid;\n");
374    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
375
376    for model in models {
377        code.push_str(&format!("/// {}\n", model.name));
378        code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
379        code.push_str(&format!("pub struct {} {{\n", model.name));
380
381        for field in &model.fields {
382            let field_type = match field.field_type.as_str() {
383                "String" => "String",
384                "f64" => "f64",
385                "i64" => "i64",
386                "bool" => "bool",
387                "DateTime" => "DateTime<Utc>",
388                "Date" => "NaiveDate",
389                "Uuid" => "Uuid",
390                _ => &field.field_type,
391            };
392
393            let mut lowercased_name = field.name.to_lowercase();
394            if is_reserved_word(&lowercased_name) {
395                lowercased_name = format!("r#{lowercased_name}")
396            }
397
398            // Only add serde rename if the Rust field name differs from the original field name
399            if lowercased_name != field.name {
400                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
401            }
402
403            if field.is_required {
404                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
405            } else {
406                code.push_str(&format!(
407                    "    pub {lowercased_name}: Option<{field_type}>,\n",
408                ));
409            }
410        }
411
412        code.push_str("}\n\n");
413    }
414
415    Ok(code)
416}
417
418pub fn generate_lib() -> Result<String> {
419    let mut code = String::new();
420    code.push_str("pub mod models;\n");
421
422    Ok(code)
423}