1use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17 let output_dir = &config.output_structs_dir;
18 fs::create_dir_all(output_dir)?;
19
20 let mut mod_content = String::new();
22 mod_content.push_str("// Generated model structs\n\n");
23
24 for table in tables {
25 let file_name = heck::AsSnakeCase(&table.name).to_string();
26 mod_content.push_str(&format!("mod {};\n", file_name));
27 mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28 }
29
30 mod_content.push('\n');
32 mod_content.push_str(&generate_shared_pagination_types());
33
34 let mod_path = output_dir.join("mod.rs");
35 fs::write(&mod_path, mod_content)?;
36 super::format_file(&mod_path);
37
38 for table in tables {
40 generate_struct_file(table, output_dir)?;
41 }
42
43 Ok(())
44}
45
/// Returns the source text of the pagination helpers shared by every
/// generated table module: the `SortDirection` enum and the generic
/// `PaginatedResult<T>` container. The caller appends it verbatim to `mod.rs`.
///
/// The generated `PaginatedResult::new` computes `total_pages` with integer
/// ceiling division. It reports 0 pages when `page_size` (or `total_count`) is
/// not positive. Previously it divided in `f64`, so a zero `page_size`
/// saturated to `i32::MAX` pages and left `has_next` permanently true. The
/// generated code avoids `TryFrom` so it compiles under any Rust edition.
fn generate_shared_pagination_types() -> String {
    r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    /// Builds a page of results; `total_pages` is 0 when `page_size` is not positive.
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        // Integer ceiling division: exact for any i64 count and never divides by zero.
        let total_pages = if page_size > 0 && total_count > 0 {
            let size = i64::from(page_size);
            let pages = total_count / size + i64::from(total_count % size != 0);
            pages.min(i64::from(i32::MAX)) as i32
        } else {
            0
        };
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#
    .to_string()
}
97
98fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
100 let struct_name = to_struct_name(&table.name);
101 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
102 debug!("Generating struct {} -> {}", struct_name, file_name);
103
104 let mut code = String::new();
105
106 let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
108 for col in &table.columns {
109 if col.is_enum() {
110 enum_columns.push(col);
111 }
112 }
113
114 code.push_str("use serde::{Deserialize, Serialize};\n");
116
117 code.push('\n');
118
119 for col in &enum_columns {
121 if let Some(values) = &col.enum_values {
122 code.push_str(&generate_enum(&table.name, col, values));
123 code.push('\n');
124 }
125 }
126
127 code.push_str(&format!("/// Database table: `{}`\n", table.name));
129 if let Some(comment) = &table.comment {
130 if !comment.is_empty() {
131 code.push_str(&format!("///\n/// {}\n", comment));
132 }
133 }
134
135 code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
137 code.push_str(&format!("pub struct {} {{\n", struct_name));
138
139 for col in &table.columns {
141 let field_name = escape_field_name(&col.name);
142 let rust_type = TypeResolver::resolve(col, &table.name);
143
144 code.push_str(&format!(" /// Column: `{}`", col.name));
146
147 let index_info = get_index_info(table, &col.name);
149 if !index_info.is_empty() {
150 code.push_str(&format!(" ({})", index_info.join(", ")));
151 }
152
153 if let Some(comment) = &col.comment {
154 if !comment.is_empty() {
155 code.push_str(&format!(" - {}", comment));
156 }
157 }
158 code.push('\n');
159
160 let mut attrs = Vec::new();
162
163 if field_name != col.name {
165 attrs.push(format!("rename = \"{}\"", col.name));
166 }
167
168 if col.is_auto_increment {
170 attrs.push("skip_insert".to_string());
171 }
172
173 if !attrs.is_empty() {
174 code.push_str(&format!(" #[rdbi({})]\n", attrs.join(", ")));
175 }
176
177 if field_name != col.name {
180 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", col.name));
181 }
182
183 code.push_str(&format!(
184 " pub {}: {},\n",
185 field_name,
186 rust_type.to_type_string()
187 ));
188 }
189
190 code.push_str("}\n");
191
192 code.push('\n');
194 code.push_str(&generate_sort_by_enum(table));
195
196 let file_path = output_dir.join(&file_name);
197 fs::write(&file_path, code)?;
198 super::format_file(&file_path);
199 Ok(())
200}
201
202fn generate_sort_by_enum(table: &TableMetadata) -> String {
204 let struct_name = to_struct_name(&table.name);
205 let enum_name = format!("{}SortBy", struct_name);
206
207 let mut code = String::new();
208
209 code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
210 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
211 code.push_str(&format!("pub enum {} {{\n", enum_name));
212
213 for col in &table.columns {
214 let variant = heck::AsPascalCase(&col.name).to_string();
215 code.push_str(&format!(" {},\n", variant));
216 }
217
218 code.push_str("}\n\n");
219
220 code.push_str(&format!("impl {} {{\n", enum_name));
222 code.push_str(" pub fn as_sql(&self) -> &'static str {\n");
223 code.push_str(" match self {\n");
224
225 for col in &table.columns {
226 let variant = heck::AsPascalCase(&col.name).to_string();
227 code.push_str(&format!(
228 " Self::{} => \"`{}`\",\n",
229 variant, col.name
230 ));
231 }
232
233 code.push_str(" }\n");
234 code.push_str(" }\n");
235 code.push_str("}\n");
236
237 code
238}
239
240fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
242 let enum_name = to_enum_name(table_name, &column.name);
243 let mut code = String::new();
244
245 code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
247
248 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
250 code.push_str(&format!("pub enum {} {{\n", enum_name));
251
252 let mut used_variants: HashSet<String> = HashSet::new();
254 let mut variant_mappings: Vec<(String, String)> = Vec::new();
255
256 for value in values {
257 let variant = to_enum_variant(value);
258
259 let final_variant = if used_variants.contains(&variant) {
261 let mut counter = 2;
262 loop {
263 let new_variant = format!("{}{}", variant, counter);
264 if !used_variants.contains(&new_variant) {
265 break new_variant;
266 }
267 counter += 1;
268 }
269 } else {
270 variant
271 };
272
273 used_variants.insert(final_variant.clone());
274
275 let clean_value = value.trim_matches('\'').trim_matches('"');
277
278 if final_variant != clean_value {
280 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", clean_value));
281 }
282
283 code.push_str(&format!(" {},\n", final_variant));
284 variant_mappings.push((final_variant, clean_value.to_string()));
285 }
286
287 code.push_str("}\n\n");
288
289 code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
291 code.push_str(" fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
292 code.push_str(" match value {\n");
293 code.push_str(" rdbi::Value::String(s) => match s.as_str() {\n");
294 for (variant, db_value) in &variant_mappings {
295 code.push_str(&format!(
296 " \"{}\" => Ok(Self::{}),\n",
297 db_value, variant
298 ));
299 }
300 code.push_str(&format!(
301 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
302 enum_name
303 ));
304 code.push_str(" },\n");
305 code.push_str(&format!(
306 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
307 enum_name
308 ));
309 code.push_str(" }\n");
310 code.push_str(" }\n");
311 code.push_str("}\n\n");
312
313 code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
315 code.push_str(" fn to_value(&self) -> rdbi::Value {\n");
316 code.push_str(" rdbi::Value::String(match self {\n");
317 for (variant, db_value) in &variant_mappings {
318 code.push_str(&format!(
319 " Self::{} => \"{}\".to_string(),\n",
320 variant, db_value
321 ));
322 }
323 code.push_str(" })\n");
324 code.push_str(" }\n");
325 code.push_str("}\n");
326
327 code
328}
329
330fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
332 let mut info = Vec::new();
333
334 if let Some(pk) = &table.primary_key {
336 if pk.columns.contains(&column_name.to_string()) {
337 info.push("PRIMARY KEY".to_string());
338 }
339 }
340
341 for index in &table.indexes {
343 if index.columns.contains(&column_name.to_string()) {
344 let label = if index.unique {
345 format!("UNIQUE: {}", index.name)
346 } else {
347 format!("INDEX: {}", index.name)
348 };
349 info.push(label);
350 }
351 }
352
353 info
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Fixture: a NOT NULL, signed, non-auto-increment column with no default,
    /// no enum values and no comment.
    fn column(name: &str, data_type: &str) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: false,
            enum_values: None,
            comment: None,
        }
    }

    /// Fixture: an `ENUM` column with the given allowed values.
    fn enum_column(name: &str, values: &[&str]) -> ColumnMetadata {
        ColumnMetadata {
            enum_values: Some(values.iter().map(|v| v.to_string()).collect()),
            ..column(name, "ENUM")
        }
    }

    /// Fixture: a `users` table with an auto-increment primary key `id`, a
    /// uniquely indexed `username`, and an ENUM `status` column.
    fn make_table() -> TableMetadata {
        let id = ColumnMetadata {
            is_auto_increment: true,
            ..column("id", "BIGINT")
        };

        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                id,
                column("username", "VARCHAR(255)"),
                enum_column("status", &["ACTIVE", "INACTIVE", "PENDING"]),
            ],
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: vec!["username".to_string()],
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();

        let id_info = get_index_info(&table, "id");
        assert!(id_info.iter().any(|label| label == "PRIMARY KEY"));

        let username_info = get_index_info(&table, "username");
        assert!(username_info.iter().any(|label| label.contains("UNIQUE")));
    }

    #[test]
    fn test_generate_enum() {
        let col = enum_column("status", &["ACTIVE", "INACTIVE"]);
        let values = col.enum_values.as_ref().unwrap();

        let code = generate_enum("users", &col, values);

        let expected = [
            "pub enum UsersStatus",
            "Active",
            "Inactive",
            "impl rdbi::FromValue for UsersStatus",
            "impl rdbi::ToValue for UsersStatus",
        ];
        for needle in expected.iter() {
            assert!(code.contains(*needle), "missing `{}` in:\n{}", needle, code);
        }
    }
}