openapi_model_generator/
generator.rs

1use crate::{
2    models::{
3        CompositionModel, EnumModel, Model, ModelType, RequestModel, ResponseModel, UnionModel,
4        UnionType,
5    },
6    Result,
7};
8
/// Rust keywords (strict and reserved) that cannot be used as plain identifiers in
/// generated code.
///
/// [`is_reserved_word`] lowercases its input before looking it up here, so the `Self`
/// entry never matches on its own; `self` covers it.
const RUST_RESERVED_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", "for",
    "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return",
    "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use", "where",
    "while",
    // Strict keywords since the 2018 edition; generated fields with these names
    // would otherwise fail to compile.
    "async", "await", "dyn",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof",
    "unsized", "virtual", "yield",
    // Reserved since the 2024 edition.
    "gen",
];

/// Placeholder response name for which `generate_response_model` emits no struct.
const EMPTY_RESPONSE_NAME: &str = "UnknownResponse";
/// Placeholder request name for which `generate_request_model` emits no struct.
const EMPTY_REQUEST_NAME: &str = "UnknownRequest";

/// Returns `true` when `string_to_check`, compared case-insensitively (via
/// `to_lowercase`), is one of [`RUST_RESERVED_KEYWORDS`].
fn is_reserved_word(string_to_check: &str) -> bool {
    RUST_RESERVED_KEYWORDS.contains(&string_to_check.to_lowercase().as_str())
}
23
24pub fn generate_models(
25    models: &[ModelType],
26    requests: &[RequestModel],
27    responses: &[ResponseModel],
28) -> Result<String> {
29    let mut output = String::new();
30
31    output.push_str("use serde::{Serialize, Deserialize};\n");
32    output.push_str("use uuid::Uuid;\n");
33    output.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
34
35    for model_type in models {
36        match model_type {
37            ModelType::Struct(model) => {
38                output.push_str(&generate_model(model)?);
39                output.push('\n');
40            }
41            ModelType::Union(union) => {
42                output.push_str(&generate_union(union)?);
43                output.push('\n');
44            }
45            ModelType::Composition(comp) => {
46                output.push_str(&generate_composition(comp)?);
47                output.push('\n');
48            }
49            ModelType::Enum(enum_model) => {
50                output.push_str(&generate_enum(enum_model)?);
51                output.push('\n');
52            }
53        }
54    }
55
56    for request in requests {
57        output.push_str(&generate_request_model(request)?);
58        output.push('\n');
59    }
60
61    for response in responses {
62        output.push_str(&generate_response_model(response)?);
63        output.push('\n');
64    }
65
66    Ok(output)
67}
68
69fn generate_model(model: &Model) -> Result<String> {
70    let mut output = String::new();
71
72    if !model.name.is_empty() {
73        output.push_str(&format!("/// {}\n", model.name));
74    }
75
76    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
77    output.push_str(&format!("pub struct {} {{\n", model.name));
78
79    for field in &model.fields {
80        let field_type = match field.field_type.as_str() {
81            "String" => "String",
82            "f64" => "f64",
83            "i64" => "i64",
84            "bool" => "bool",
85            "DateTime" => "DateTime<Utc>",
86            "Date" => "NaiveDate",
87            "Uuid" => "Uuid",
88            _ => &field.field_type,
89        };
90
91        let mut lowercased_name = field.name.to_lowercase();
92        if is_reserved_word(&lowercased_name) {
93            lowercased_name = format!("r#{lowercased_name}")
94        }
95
96        // Only add serde rename if the Rust field name differs from the original field name
97        if lowercased_name != field.name {
98            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
99        }
100
101        if field.is_required && !field.is_nullable {
102            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
103        } else {
104            output.push_str(&format!(
105                "    pub {lowercased_name}: Option<{field_type}>,\n",
106            ));
107        }
108    }
109
110    output.push_str("}\n\n");
111    Ok(output)
112}
113
114fn generate_request_model(request: &RequestModel) -> Result<String> {
115    let mut output = String::new();
116    tracing::info!("Generating request model");
117    tracing::info!("{:#?}", request);
118
119    if request.name.is_empty() || request.name == EMPTY_REQUEST_NAME {
120        return Ok(String::new());
121    }
122
123    output.push_str(&format!("/// {}\n", request.name));
124    output.push_str("#[derive(Debug, Serialize)]\n");
125    output.push_str(&format!("pub struct {} {{\n", request.name));
126    output.push_str("    pub content_type: String,\n");
127    output.push_str(&format!("    pub body: {},\n", request.schema));
128    output.push_str("}\n");
129    Ok(output)
130}
131
132fn generate_response_model(response: &ResponseModel) -> Result<String> {
133    let mut output = String::new();
134
135    if response.name.is_empty() || response.name == EMPTY_RESPONSE_NAME {
136        return Ok(String::new());
137    }
138
139    output.push_str(&format!("/// {}\n", response.name));
140    output.push_str("#[derive(Debug, Deserialize)]\n");
141    output.push_str(&format!("pub struct {} {{\n", response.name));
142    output.push_str("    pub status_code: String,\n");
143    output.push_str("    pub content_type: String,\n");
144    output.push_str(&format!("    pub body: {},\n", response.schema));
145    if let Some(desc) = &response.description {
146        output.push_str(&format!("    /// {desc}\n"));
147    }
148    output.push_str("}\n");
149    Ok(output)
150}
151
152fn generate_union(union: &UnionModel) -> Result<String> {
153    let mut output = String::new();
154
155    output.push_str(&format!(
156        "/// {} ({})\n",
157        union.name,
158        match union.union_type {
159            UnionType::OneOf => "oneOf",
160            UnionType::AnyOf => "anyOf",
161        }
162    ));
163    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
164    output.push_str("#[serde(untagged)]\n");
165    output.push_str(&format!("pub enum {} {{\n", union.name));
166
167    for variant in &union.variants {
168        output.push_str(&format!("    {}({}),\n", variant.name, variant.name));
169    }
170
171    output.push_str("}\n");
172    Ok(output)
173}
174
175fn generate_composition(comp: &CompositionModel) -> Result<String> {
176    let mut output = String::new();
177
178    output.push_str(&format!("/// {} (allOf composition)\n", comp.name));
179    output.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
180    output.push_str(&format!("pub struct {} {{\n", comp.name));
181
182    for field in &comp.all_fields {
183        let field_type = match field.field_type.as_str() {
184            "String" => "String",
185            "f64" => "f64",
186            "i64" => "i64",
187            "bool" => "bool",
188            "DateTime" => "DateTime<Utc>",
189            "Date" => "NaiveDate",
190            "Uuid" => "Uuid",
191            _ => &field.field_type,
192        };
193
194        let mut lowercased_name = field.name.to_lowercase();
195        if is_reserved_word(&lowercased_name) {
196            lowercased_name = format!("r#{lowercased_name}");
197        }
198
199        // Only add serde rename if the Rust field name differs from the original field name
200        if lowercased_name != field.name {
201            output.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
202        }
203
204        if field.is_required && !field.is_nullable {
205            output.push_str(&format!("    pub {lowercased_name}: {field_type},\n"));
206        } else {
207            output.push_str(&format!(
208                "    pub {lowercased_name}: Option<{field_type}>,\n"
209            ));
210        }
211    }
212
213    output.push_str("}\n");
214    Ok(output)
215}
216
217fn generate_enum(enum_model: &EnumModel) -> Result<String> {
218    let mut output = String::new();
219
220    if let Some(description) = &enum_model.description {
221        output.push_str(&format!("/// {description}\n"));
222    } else {
223        output.push_str(&format!("/// {}\n", enum_model.name));
224    }
225
226    output.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
227    output.push_str(&format!("pub enum {} {{\n", enum_model.name));
228
229    for (i, variant) in enum_model.variants.iter().enumerate() {
230        let original = variant.clone();
231
232        let mut chars = variant.chars();
233        let first_char = chars.next().unwrap().to_ascii_uppercase();
234        let rest: String = chars.collect();
235        let mut rust_name = format!("{first_char}{rest}");
236
237        let serde_rename = if is_reserved_word(&rust_name) {
238            rust_name.push_str("Value");
239            Some(original)
240        } else if rust_name != original {
241            Some(original)
242        } else {
243            None
244        };
245
246        if let Some(rename) = serde_rename {
247            output.push_str(&format!("    #[serde(rename = \"{rename}\")]\n"));
248        }
249
250        if i + 1 == enum_model.variants.len() {
251            output.push_str(&format!("    {rust_name}\n"));
252        } else {
253            output.push_str(&format!("    {rust_name},\n"));
254        }
255    }
256
257    output.push_str("}\n");
258    Ok(output)
259}
260
261pub fn generate_rust_code(models: &[Model]) -> Result<String> {
262    let mut code = String::new();
263
264    code.push_str("use serde::{Serialize, Deserialize};\n");
265    code.push_str("use uuid::Uuid;\n");
266    code.push_str("use chrono::{DateTime, NaiveDate, Utc};\n\n");
267
268    for model in models {
269        code.push_str(&format!("/// {}\n", model.name));
270        code.push_str("#[derive(Debug, Serialize, Deserialize)]\n");
271        code.push_str(&format!("pub struct {} {{\n", model.name));
272
273        for field in &model.fields {
274            let field_type = match field.field_type.as_str() {
275                "String" => "String",
276                "f64" => "f64",
277                "i64" => "i64",
278                "bool" => "bool",
279                "DateTime" => "DateTime<Utc>",
280                "Date" => "NaiveDate",
281                "Uuid" => "Uuid",
282                _ => &field.field_type,
283            };
284
285            let mut lowercased_name = field.name.to_lowercase();
286            if is_reserved_word(&lowercased_name) {
287                lowercased_name = format!("r#{lowercased_name}")
288            }
289
290            // Only add serde rename if the Rust field name differs from the original field name
291            if lowercased_name != field.name {
292                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", field.name));
293            }
294
295            if field.is_required {
296                code.push_str(&format!("    pub {lowercased_name}: {field_type},\n",));
297            } else {
298                code.push_str(&format!(
299                    "    pub {lowercased_name}: Option<{field_type}>,\n",
300                ));
301            }
302        }
303
304        code.push_str("}\n\n");
305    }
306
307    Ok(code)
308}
309
310pub fn generate_lib() -> Result<String> {
311    let mut code = String::new();
312    code.push_str("pub mod models;\n");
313
314    Ok(code)
315}