1use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17 let output_dir = &config.output_structs_dir;
18 fs::create_dir_all(output_dir)?;
19
20 let mut mod_content = String::new();
22 mod_content.push_str("// Generated model structs\n\n");
23
24 for table in tables {
25 let file_name = heck::AsSnakeCase(&table.name).to_string();
26 mod_content.push_str(&format!("mod {};\n", file_name));
27 mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28 }
29
30 mod_content.push('\n');
32 mod_content.push_str(&generate_shared_pagination_types());
33
34 let mod_path = output_dir.join("mod.rs");
35 fs::write(&mod_path, mod_content)?;
36
37 for table in tables {
39 generate_struct_file(table, output_dir)?;
40 }
41
42 Ok(())
43}
44
/// Returns the Rust source for the pagination helpers shared by all models.
///
/// The emitted code defines `SortDirection` (with `as_sql`) and
/// `PaginatedResult<T>` (with `new`).
///
/// The emitted `PaginatedResult::new` derives `total_pages` with integer
/// ceiling division. A non-positive `page_size` or `total_count` yields zero
/// pages, and a page count that does not fit in `i32` is clamped to
/// `i32::MAX`. This avoids the float division, which reported `i32::MAX`
/// pages for `page_size == 0` and lost precision for very large counts.
fn generate_shared_pagination_types() -> String {
    r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        // Integer ceiling division; empty results and non-positive page
        // sizes produce zero pages instead of dividing by zero.
        let total_pages = if total_count > 0 && page_size > 0 {
            let size = i64::from(page_size);
            let pages = total_count / size + i64::from(total_count % size != 0);
            pages.min(i64::from(i32::MAX)) as i32
        } else {
            0
        };
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#
    .to_string()
}
96
97fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
99 let struct_name = to_struct_name(&table.name);
100 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
101 debug!("Generating struct {} -> {}", struct_name, file_name);
102
103 let mut code = String::new();
104
105 let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
107 for col in &table.columns {
108 if col.is_enum() {
109 enum_columns.push(col);
110 }
111 }
112
113 code.push_str("use serde::{Deserialize, Serialize};\n");
115
116 code.push('\n');
117
118 for col in &enum_columns {
120 if let Some(values) = &col.enum_values {
121 code.push_str(&generate_enum(&table.name, col, values));
122 code.push('\n');
123 }
124 }
125
126 code.push_str(&format!("/// Database table: `{}`\n", table.name));
128 if let Some(comment) = &table.comment {
129 if !comment.is_empty() {
130 code.push_str(&format!("///\n/// {}\n", comment));
131 }
132 }
133
134 code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
136 code.push_str(&format!("pub struct {} {{\n", struct_name));
137
138 for col in &table.columns {
140 let field_name = escape_field_name(&col.name);
141 let rust_type = TypeResolver::resolve(col, &table.name);
142
143 code.push_str(&format!(" /// Column: `{}`", col.name));
145
146 let index_info = get_index_info(table, &col.name);
148 if !index_info.is_empty() {
149 code.push_str(&format!(" ({})", index_info.join(", ")));
150 }
151
152 if let Some(comment) = &col.comment {
153 if !comment.is_empty() {
154 code.push_str(&format!(" - {}", comment));
155 }
156 }
157 code.push('\n');
158
159 let mut attrs = Vec::new();
161
162 if field_name != col.name {
164 attrs.push(format!("rename = \"{}\"", col.name));
165 }
166
167 if col.is_auto_increment {
169 attrs.push("skip_insert".to_string());
170 }
171
172 if !attrs.is_empty() {
173 code.push_str(&format!(" #[rdbi({})]\n", attrs.join(", ")));
174 }
175
176 if field_name != col.name {
179 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", col.name));
180 }
181
182 code.push_str(&format!(
183 " pub {}: {},\n",
184 field_name,
185 rust_type.to_type_string()
186 ));
187 }
188
189 code.push_str("}\n");
190
191 code.push('\n');
193 code.push_str(&generate_sort_by_enum(table));
194
195 let file_path = output_dir.join(&file_name);
196 fs::write(&file_path, code)?;
197 Ok(())
198}
199
200fn generate_sort_by_enum(table: &TableMetadata) -> String {
202 let struct_name = to_struct_name(&table.name);
203 let enum_name = format!("{}SortBy", struct_name);
204
205 let mut code = String::new();
206
207 code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
208 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
209 code.push_str(&format!("pub enum {} {{\n", enum_name));
210
211 for col in &table.columns {
212 let variant = heck::AsPascalCase(&col.name).to_string();
213 code.push_str(&format!(" {},\n", variant));
214 }
215
216 code.push_str("}\n\n");
217
218 code.push_str(&format!("impl {} {{\n", enum_name));
220 code.push_str(" pub fn as_sql(&self) -> &'static str {\n");
221 code.push_str(" match self {\n");
222
223 for col in &table.columns {
224 let variant = heck::AsPascalCase(&col.name).to_string();
225 code.push_str(&format!(
226 " Self::{} => \"`{}`\",\n",
227 variant, col.name
228 ));
229 }
230
231 code.push_str(" }\n");
232 code.push_str(" }\n");
233 code.push_str("}\n");
234
235 code
236}
237
238fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
240 let enum_name = to_enum_name(table_name, &column.name);
241 let mut code = String::new();
242
243 code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
245
246 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
248 code.push_str(&format!("pub enum {} {{\n", enum_name));
249
250 let mut used_variants: HashSet<String> = HashSet::new();
252 let mut variant_mappings: Vec<(String, String)> = Vec::new();
253
254 for value in values {
255 let variant = to_enum_variant(value);
256
257 let final_variant = if used_variants.contains(&variant) {
259 let mut counter = 2;
260 loop {
261 let new_variant = format!("{}{}", variant, counter);
262 if !used_variants.contains(&new_variant) {
263 break new_variant;
264 }
265 counter += 1;
266 }
267 } else {
268 variant
269 };
270
271 used_variants.insert(final_variant.clone());
272
273 let clean_value = value.trim_matches('\'').trim_matches('"');
275
276 if final_variant != clean_value {
278 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", clean_value));
279 }
280
281 code.push_str(&format!(" {},\n", final_variant));
282 variant_mappings.push((final_variant, clean_value.to_string()));
283 }
284
285 code.push_str("}\n\n");
286
287 code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
289 code.push_str(" fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
290 code.push_str(" match value {\n");
291 code.push_str(" rdbi::Value::String(s) => match s.as_str() {\n");
292 for (variant, db_value) in &variant_mappings {
293 code.push_str(&format!(
294 " \"{}\" => Ok(Self::{}),\n",
295 db_value, variant
296 ));
297 }
298 code.push_str(&format!(
299 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
300 enum_name
301 ));
302 code.push_str(" },\n");
303 code.push_str(&format!(
304 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
305 enum_name
306 ));
307 code.push_str(" }\n");
308 code.push_str(" }\n");
309 code.push_str("}\n\n");
310
311 code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
313 code.push_str(" fn to_value(&self) -> rdbi::Value {\n");
314 code.push_str(" rdbi::Value::String(match self {\n");
315 for (variant, db_value) in &variant_mappings {
316 code.push_str(&format!(
317 " Self::{} => \"{}\".to_string(),\n",
318 variant, db_value
319 ));
320 }
321 code.push_str(" })\n");
322 code.push_str(" }\n");
323 code.push_str("}\n");
324
325 code
326}
327
328fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
330 let mut info = Vec::new();
331
332 if let Some(pk) = &table.primary_key {
334 if pk.columns.contains(&column_name.to_string()) {
335 info.push("PRIMARY KEY".to_string());
336 }
337 }
338
339 for index in &table.indexes {
341 if index.columns.contains(&column_name.to_string()) {
342 let label = if index.unique {
343 format!("UNIQUE: {}", index.name)
344 } else {
345 format!("INDEX: {}", index.name)
346 };
347 info.push(label);
348 }
349 }
350
351 info
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// Builds a non-nullable, signed column with no default value and no comment.
    fn column(
        name: &str,
        data_type: &str,
        is_auto_increment: bool,
        enum_values: Option<Vec<String>>,
    ) -> ColumnMetadata {
        ColumnMetadata {
            name: name.to_string(),
            data_type: data_type.to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment,
            is_unsigned: false,
            enum_values,
            comment: None,
        }
    }

    /// Converts string literals into the owned values used by enum columns.
    fn owned(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| v.to_string()).collect()
    }

    /// A `users` table with an auto-increment primary key, a uniquely indexed
    /// `username`, and an enum `status` column.
    fn make_table() -> TableMetadata {
        let columns = vec![
            column("id", "BIGINT", true, None),
            column("username", "VARCHAR(255)", false, None),
            column(
                "status",
                "ENUM",
                false,
                Some(owned(&["ACTIVE", "INACTIVE", "PENDING"])),
            ),
        ];

        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns,
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: vec!["username".to_string()],
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();

        let id_info = get_index_info(&table, "id");
        assert!(id_info.contains(&"PRIMARY KEY".to_string()));

        let username_info = get_index_info(&table, "username");
        assert!(username_info.iter().any(|i| i.contains("UNIQUE")));
    }

    #[test]
    fn test_generate_enum() {
        let col = column("status", "ENUM", false, Some(owned(&["ACTIVE", "INACTIVE"])));

        let code = generate_enum("users", &col, col.enum_values.as_ref().unwrap());

        let expected_fragments = [
            "pub enum UsersStatus",
            "Active",
            "Inactive",
            "impl rdbi::FromValue for UsersStatus",
            "impl rdbi::ToValue for UsersStatus",
        ];
        for fragment in expected_fragments.iter() {
            assert!(code.contains(fragment));
        }
    }
}