openapi_model_generator/
generator.rs

1use crate::{
2    models::{
3        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, UnionModel,
4        UnionType,
5    },
6    Result,
7};
8
/// Rust keywords that cannot be emitted verbatim as identifiers in generated code.
///
/// Lookups go through `is_reserved_word`, which lowercases the candidate before comparing,
/// so the capitalised `Self` entry is effectively matched through `self`.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords (all editions).
    "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", "for",
    "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return",
    "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where",
    "while",
    // Strict keywords introduced by the 2018 edition; without them a property named e.g.
    // `async` would generate `pub async: T`, which does not compile.
    "async", "await", "dyn",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof",
    "unsized", "virtual", "yield",
];
16
/// Sentinel request-model name: a request carrying it is skipped by the generator.
/// NOTE(review): presumably assigned upstream when no name could be derived — confirm.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";
/// Sentinel response-model name: a response carrying it is skipped by the generator.
/// NOTE(review): presumably assigned upstream when no name could be derived — confirm.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
19
20fn is_reserved_word(string_to_check: &str) -> bool {
21    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
22}
23
24pub fn generate_models(
25    models: &[ModelType],
26    requests: &[RequestModel],
27    responses: &[ResponseModel],
28) -> Result<String> {
29    let mut output = String::new();
30
31    output.push_str("use serde::{Serialize, Deserialize};\n");
32    output.push_str("use uuid::Uuid;\n");
33    output.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
34
35    for model_type in models {
36        match model_type {
37            ModelType::Struct(model) => {
38                output.push_str(&generate_model(model)?);
39                output.push('\n');
40            }
41            ModelType::Union(union) => {
42                output.push_str(&generate_union(union)?);
43                output.push('\n');
44            }
45            ModelType::Composition(comp) => {
46                output.push_str(&generate_composition(comp)?);
47                output.push('\n');
48            }
49            ModelType::Enum(enum_model) => {
50                output.push_str(&generate_enum(enum_model)?);
51                output.push('\n');
52            }
53        }
54    }
55
56    for request in requests {
57        output.push_str(&generate_request_model(request)?);
58        output.push('\n');
59    }
60
61    for response in responses {
62        output.push_str(&generate_response_model(response)?);
63        output.push('\n');
64    }
65
66    Ok(output)
67}
68
69fn generate_model(model: &Model) -> Result<String> {
70    let mut output = String::new();
71
72    if !model.name.is_empty() {
73        output.push_str(&format!("/// {}\n", model.name));
74    }
75
76    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
77    output.push_str(&format!("pub struct {} {{\n", model.name));
78
79    for field in &model.fields {
80        let field_type = match field.field_type.as_str() {
81            "String" => "String",
82            "f64" => "f64",
83            "i64" => "i64",
84            "bool" => "bool",
85            "DateTime" => "DateTime<Utc>",
86            "Date" => "NaiveDate",
87            "Uuid" => "Uuid",
88            _ => &field.field_type,
89        };
90
91        let mut lowercased_name = field.name.to_lowercase();
92        if is_reserved_word(&lowercased_name) {
93            lowercased_name = format!("r#{lowercased_name}")
94        }
95
96        // Only add serde rename if the Rust field name differs from the original field name
97        if lowercased_name != field.name {
98            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
99        }
100
101        if field.is_required && !field.is_nullable {
102            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
103        } else {
104            output.push_str(&format!(
105                "    pub {lowercased_name}: Option<{field_type}>,\n",
106            ));
107        }
108    }
109
110    output.push_str("}\n\n");
111    Ok(output)
112}
113
114fn generate_request_model(request: &RequestModel) -> Result<String> {
115    let mut output = String::new();
116    tracing::info!("Generating request model");
117    tracing::info!("{:#?}", request);
118
119    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
120        return Ok(String::new());
121    }
122
123    output.push_str(&format!("/// {}\n", request.name));
124    output.push_str("#[derive(Debug, Serialize)]\n");
125    output.push_str(&format!("pub struct {} {{\n", request.name));
126    output.push_str("    pub content_type: String,\n");
127    output.push_str(&format!("    pub body: {},\n", request.schema));
128    output.push_str("}\n");
129    Ok(output)
130}
131
132fn generate_response_model(response: &ResponseModel) -> Result<String> {
133    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
134        return Ok(String::new());
135    }
136
137    let type_name = format!("{}{}", response.name, response.status_code);
138
139    let mut output = String::new();
140
141    if let Some(desc) = &response.description {
142        for line in desc.lines() {
143            output.push_str(&format!("/// {}\n", line.trim()));
144        }
145    } else {
146        output.push_str(&format!("/// {type_name}\n"));
147    }
148
149    output.push_str("#[derive(Debug, Deserialize)]\n");
150    output.push_str(&format!("pub struct {type_name} {{\n"));
151    output.push_str(&format!("    pub body: {},\n", response.schema));
152    output.push_str("}\n");
153
154    Ok(output)
155}
156
157fn generate_union(union: &UnionModel) -> Result<String> {
158    let mut output = String::new();
159
160    output.push_str(&format!(
161        "/// {} ({})\n",
162        union.name,
163        match union.union_type {
164            UnionType::OneOf => "oneOf",
165            UnionType::AnyOf => "anyOf",
166        }
167    ));
168    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
169    output.push_str("#[serde(untagged)]\n");
170    output.push_str(&format!("pub enum {} {{\n", union.name));
171
172    for variant in &union.variants {
173        output.push_str(&format!("    {}({}),\n", variant.name, variant.name));
174    }
175
176    output.push_str("}\n");
177    Ok(output)
178}
179
180fn generate_composition(comp: &CompositionModel) -> Result<String> {
181    let mut output = String::new();
182
183    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
184    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
185    output.push_str(&format!("pub struct {} {{\n", comp.name));
186
187    for field in &comp.all_fields {
188        let field_type = match field.field_type.as_str() {
189            "String" => "String",
190            "f64" => "f64",
191            "i64" => "i64",
192            "bool" => "bool",
193            "DateTime" => "DateTime<Utc>",
194            "Date" => "NaiveDate",
195            "Uuid" => "Uuid",
196            _ => &field.field_type,
197        };
198
199        let mut lowercased_name = field.name.to_lowercase();
200        if is_reserved_word(&lowercased_name) {
201            lowercased_name = format!("r#{lowercased_name}");
202        }
203
204        // Only add serde rename if the Rust field name differs from the original field name
205        if lowercased_name != field.name {
206            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
207        }
208
209        if field.is_required && !field.is_nullable {
210            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n"));
211        } else {
212            output.push_str(&format!(
213                "    pub {lowercased_name}: Option<{field_type}>,\n"
214            ));
215        }
216    }
217
218    output.push_str("}\n");
219    Ok(output)
220}
221
222fn generate_enum(enum_model: &EnumModel) -> Result<String> {
223    let mut output = String::new();
224
225    if let Some(description) = &enum_model.description {
226        output.push_str(&format!("/// {description}\n"));
227    } else {
228        output.push_str(&format!("/// {}\n", enum_model.name));
229    }
230
231    output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
232    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
233
234    for (i, variant) in enum_model.variants.iter().enumerate() {
235        let original = variant.clone();
236
237        let mut chars = variant.chars();
238        let first_char = chars.next().unwrap().to_ascii_uppercase();
239        let rest: String = chars.collect();
240        let mut rust_name = format!("{first_char}{rest}");
241
242        let serde_rename = if is_reserved_word(&rust_name) {
243            rust_name.push_str("Value");
244            Some(original)
245        } else if rust_name != original {
246            Some(original)
247        } else {
248            None
249        };
250
251        if let Some(rename) = serde_rename {
252            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
253        }
254
255        if i + 1 == enum_model.variants.len() {
256            output.push_str(&format!("    {rust_name}\n"));
257        } else {
258            output.push_str(&format!("    {rust_name},\n"));
259        }
260    }
261
262    output.push_str("}\n");
263    Ok(output)
264}
265
266pub fn generate_rust_code(models: &[Model]) -> Result<String> {
267    let mut code = String::new();
268
269    code.push_str("use serde::{Serialize, Deserialize};\n");
270    code.push_str("use uuid::Uuid;\n");
271    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
272
273    for model in models {
274        code.push_str(&format!("/// {}\n", model.name));
275        code.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
276        code.push_str(&format!("pub struct {} {{\n", model.name));
277
278        for field in &model.fields {
279            let field_type = match field.field_type.as_str() {
280                "String" => "String",
281                "f64" => "f64",
282                "i64" => "i64",
283                "bool" => "bool",
284                "DateTime" => "DateTime<Utc>",
285                "Date" => "NaiveDate",
286                "Uuid" => "Uuid",
287                _ => &field.field_type,
288            };
289
290            let mut lowercased_name = field.name.to_lowercase();
291            if is_reserved_word(&lowercased_name) {
292                lowercased_name = format!("r#{lowercased_name}")
293            }
294
295            // Only add serde rename if the Rust field name differs from the original field name
296            if lowercased_name != field.name {
297                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
298            }
299
300            if field.is_required {
301                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
302            } else {
303                code.push_str(&format!(
304                    "    pub {lowercased_name}: Option<{field_type}>,\n",
305                ));
306            }
307        }
308
309        code.push_str("}\n\n");
310    }
311
312    Ok(code)
313}
314
315pub fn generate_lib() -> Result<String> {
316    let mut code = String::new();
317    code.push_str("pub mod models;\n");
318
319    Ok(code)
320}