1use crate::error::Result;
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use tracing::debug;
8
9use crate::config::CodegenConfig;
10use crate::parser::{ColumnMetadata, TableMetadata};
11
12use super::naming::{escape_field_name, to_enum_name, to_enum_variant, to_struct_name};
13use super::type_resolver::TypeResolver;
14
15pub fn generate_structs(tables: &[TableMetadata], config: &CodegenConfig) -> Result<()> {
17 let output_dir = &config.output_structs_dir;
18 fs::create_dir_all(output_dir)?;
19
20 let mut mod_content = String::new();
22 mod_content.push_str("// Generated model structs\n\n");
23
24 for table in tables {
25 let file_name = heck::AsSnakeCase(&table.name).to_string();
26 mod_content.push_str(&format!("mod {};\n", file_name));
27 mod_content.push_str(&format!("pub use {}::*;\n", file_name));
28 }
29
30 mod_content.push('\n');
32 mod_content.push_str(&generate_shared_pagination_types());
33
34 fs::write(output_dir.join("mod.rs"), mod_content)?;
35
36 for table in tables {
38 generate_struct_file(table, output_dir)?;
39 }
40
41 Ok(())
42}
43
/// Returns the Rust source for the pagination helpers shared by all generated
/// models. `generate_structs` appends this text to the generated `mod.rs`.
///
/// The emitted code defines:
/// - `SortDirection` (`Asc`/`Desc`), with `as_sql` returning `"ASC"`/`"DESC"`;
/// - `PaginatedResult<T>`, whose `new` derives `total_pages` and `has_next`.
///
/// `total_pages` is computed with exact integer ceiling division. An earlier
/// `f64`-based version turned a zero `page_size` into `i32::MAX` pages,
/// leaving `has_next` permanently true, and lost precision above 2^53 rows.
fn generate_shared_pagination_types() -> String {
    r#"/// Sort direction for pagination
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Asc,
    Desc,
}

impl SortDirection {
    pub fn as_sql(&self) -> &'static str {
        match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        }
    }
}

/// Paginated result container
#[derive(Debug, Clone)]
pub struct PaginatedResult<T> {
    pub items: Vec<T>,
    pub total_count: i64,
    pub current_page: i32,
    pub total_pages: i32,
    pub page_size: i32,
    pub has_next: bool,
}

impl<T> PaginatedResult<T> {
    /// Builds a page descriptor.
    ///
    /// A non-positive `page_size` or `total_count` yields zero pages; a page
    /// count that does not fit in `i32` saturates to `i32::MAX`.
    pub fn new(
        items: Vec<T>,
        total_count: i64,
        current_page: i32,
        page_size: i32,
    ) -> Self {
        // Integer ceiling division: exact for every i64 count and
        // well-defined when the page size is zero.
        let total_pages = if page_size <= 0 || total_count <= 0 {
            0
        } else {
            let size = i64::from(page_size);
            let pages = total_count / size + i64::from(total_count % size != 0);
            pages.min(i64::from(i32::MAX)) as i32
        };
        let has_next = current_page < total_pages;
        Self {
            items,
            total_count,
            current_page,
            total_pages,
            page_size,
            has_next,
        }
    }
}
"#
    .to_string()
}
95
96fn generate_struct_file(table: &TableMetadata, output_dir: &Path) -> Result<()> {
98 let struct_name = to_struct_name(&table.name);
99 let file_name = format!("{}.rs", heck::AsSnakeCase(&table.name));
100 debug!("Generating struct {} -> {}", struct_name, file_name);
101
102 let mut code = String::new();
103
104 let mut enum_columns: Vec<&ColumnMetadata> = Vec::new();
106 for col in &table.columns {
107 if col.is_enum() {
108 enum_columns.push(col);
109 }
110 }
111
112 code.push_str("use serde::{Deserialize, Serialize};\n");
114
115 code.push('\n');
116
117 for col in &enum_columns {
119 if let Some(values) = &col.enum_values {
120 code.push_str(&generate_enum(&table.name, col, values));
121 code.push('\n');
122 }
123 }
124
125 code.push_str(&format!("/// Database table: `{}`\n", table.name));
127 if let Some(comment) = &table.comment {
128 if !comment.is_empty() {
129 code.push_str(&format!("///\n/// {}\n", comment));
130 }
131 }
132
133 code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, rdbi::FromRow, rdbi::ToParams)]\n");
135 code.push_str(&format!("pub struct {} {{\n", struct_name));
136
137 for col in &table.columns {
139 let field_name = escape_field_name(&col.name);
140 let rust_type = TypeResolver::resolve(col, &table.name);
141
142 code.push_str(&format!(" /// Column: `{}`", col.name));
144
145 let index_info = get_index_info(table, &col.name);
147 if !index_info.is_empty() {
148 code.push_str(&format!(" ({})", index_info.join(", ")));
149 }
150
151 if let Some(comment) = &col.comment {
152 if !comment.is_empty() {
153 code.push_str(&format!(" - {}", comment));
154 }
155 }
156 code.push('\n');
157
158 let mut attrs = Vec::new();
160
161 if field_name != col.name {
163 attrs.push(format!("rename = \"{}\"", col.name));
164 }
165
166 if col.is_auto_increment {
168 attrs.push("skip_insert".to_string());
169 }
170
171 if !attrs.is_empty() {
172 code.push_str(&format!(" #[rdbi({})]\n", attrs.join(", ")));
173 }
174
175 if field_name != col.name {
178 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", col.name));
179 }
180
181 code.push_str(&format!(
182 " pub {}: {},\n",
183 field_name,
184 rust_type.to_type_string()
185 ));
186 }
187
188 code.push_str("}\n");
189
190 code.push('\n');
192 code.push_str(&generate_sort_by_enum(table));
193
194 fs::write(output_dir.join(&file_name), code)?;
195 Ok(())
196}
197
198fn generate_sort_by_enum(table: &TableMetadata) -> String {
200 let struct_name = to_struct_name(&table.name);
201 let enum_name = format!("{}SortBy", struct_name);
202
203 let mut code = String::new();
204
205 code.push_str(&format!("/// Sort columns for `{}`\n", table.name));
206 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n");
207 code.push_str(&format!("pub enum {} {{\n", enum_name));
208
209 for col in &table.columns {
210 let variant = heck::AsPascalCase(&col.name).to_string();
211 code.push_str(&format!(" {},\n", variant));
212 }
213
214 code.push_str("}\n\n");
215
216 code.push_str(&format!("impl {} {{\n", enum_name));
218 code.push_str(" pub fn as_sql(&self) -> &'static str {\n");
219 code.push_str(" match self {\n");
220
221 for col in &table.columns {
222 let variant = heck::AsPascalCase(&col.name).to_string();
223 code.push_str(&format!(
224 " Self::{} => \"`{}`\",\n",
225 variant, col.name
226 ));
227 }
228
229 code.push_str(" }\n");
230 code.push_str(" }\n");
231 code.push_str("}\n");
232
233 code
234}
235
236fn generate_enum(table_name: &str, column: &ColumnMetadata, values: &[String]) -> String {
238 let enum_name = to_enum_name(table_name, &column.name);
239 let mut code = String::new();
240
241 code.push_str(&format!("/// Enum for `{}.{}`\n", table_name, column.name));
243
244 code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\n");
246 code.push_str(&format!("pub enum {} {{\n", enum_name));
247
248 let mut used_variants: HashSet<String> = HashSet::new();
250 let mut variant_mappings: Vec<(String, String)> = Vec::new();
251
252 for value in values {
253 let variant = to_enum_variant(value);
254
255 let final_variant = if used_variants.contains(&variant) {
257 let mut counter = 2;
258 loop {
259 let new_variant = format!("{}{}", variant, counter);
260 if !used_variants.contains(&new_variant) {
261 break new_variant;
262 }
263 counter += 1;
264 }
265 } else {
266 variant
267 };
268
269 used_variants.insert(final_variant.clone());
270
271 let clean_value = value.trim_matches('\'').trim_matches('"');
273
274 if final_variant != clean_value {
276 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", clean_value));
277 }
278
279 code.push_str(&format!(" {},\n", final_variant));
280 variant_mappings.push((final_variant, clean_value.to_string()));
281 }
282
283 code.push_str("}\n\n");
284
285 code.push_str(&format!("impl rdbi::FromValue for {} {{\n", enum_name));
287 code.push_str(" fn from_value(value: rdbi::Value) -> rdbi::Result<Self> {\n");
288 code.push_str(" match value {\n");
289 code.push_str(" rdbi::Value::String(s) => match s.as_str() {\n");
290 for (variant, db_value) in &variant_mappings {
291 code.push_str(&format!(
292 " \"{}\" => Ok(Self::{}),\n",
293 db_value, variant
294 ));
295 }
296 code.push_str(&format!(
297 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: s }}),\n",
298 enum_name
299 ));
300 code.push_str(" },\n");
301 code.push_str(&format!(
302 " _ => Err(rdbi::Error::TypeConversion {{ expected: \"{}\", actual: value.type_name().to_string() }}),\n",
303 enum_name
304 ));
305 code.push_str(" }\n");
306 code.push_str(" }\n");
307 code.push_str("}\n\n");
308
309 code.push_str(&format!("impl rdbi::ToValue for {} {{\n", enum_name));
311 code.push_str(" fn to_value(&self) -> rdbi::Value {\n");
312 code.push_str(" rdbi::Value::String(match self {\n");
313 for (variant, db_value) in &variant_mappings {
314 code.push_str(&format!(
315 " Self::{} => \"{}\".to_string(),\n",
316 variant, db_value
317 ));
318 }
319 code.push_str(" })\n");
320 code.push_str(" }\n");
321 code.push_str("}\n");
322
323 code
324}
325
326fn get_index_info(table: &TableMetadata, column_name: &str) -> Vec<String> {
328 let mut info = Vec::new();
329
330 if let Some(pk) = &table.primary_key {
332 if pk.columns.contains(&column_name.to_string()) {
333 info.push("PRIMARY KEY".to_string());
334 }
335 }
336
337 for index in &table.indexes {
339 if index.columns.contains(&column_name.to_string()) {
340 let label = if index.unique {
341 format!("UNIQUE: {}", index.name)
342 } else {
343 format!("INDEX: {}", index.name)
344 };
345 info.push(label);
346 }
347 }
348
349 info
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{IndexMetadata, PrimaryKey};

    /// A `users` table with an auto-increment PK, a uniquely indexed
    /// `username`, and an ENUM `status` column.
    fn make_table() -> TableMetadata {
        TableMetadata {
            name: "users".to_string(),
            comment: None,
            columns: vec![
                ColumnMetadata {
                    name: "id".to_string(),
                    data_type: "BIGINT".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: true,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "username".to_string(),
                    data_type: "VARCHAR(255)".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: None,
                    comment: None,
                },
                ColumnMetadata {
                    name: "status".to_string(),
                    data_type: "ENUM".to_string(),
                    nullable: false,
                    default_value: None,
                    is_auto_increment: false,
                    is_unsigned: false,
                    enum_values: Some(vec![
                        "ACTIVE".to_string(),
                        "INACTIVE".to_string(),
                        "PENDING".to_string(),
                    ]),
                    comment: None,
                },
            ],
            indexes: vec![IndexMetadata {
                name: "idx_username".to_string(),
                columns: vec!["username".to_string()],
                unique: true,
            }],
            foreign_keys: vec![],
            primary_key: Some(PrimaryKey {
                columns: vec!["id".to_string()],
            }),
        }
    }

    #[test]
    fn test_get_index_info() {
        let table = make_table();

        let info = get_index_info(&table, "id");
        assert!(info.contains(&"PRIMARY KEY".to_string()));
        assert_eq!(info, vec!["PRIMARY KEY".to_string()]);

        let info = get_index_info(&table, "username");
        assert!(info.iter().any(|i| i.contains("UNIQUE")));
        assert_eq!(info, vec!["UNIQUE: idx_username".to_string()]);

        // A column outside every key and index gets no annotations.
        assert!(get_index_info(&table, "status").is_empty());
    }

    #[test]
    fn test_generate_enum() {
        let col = ColumnMetadata {
            name: "status".to_string(),
            data_type: "ENUM".to_string(),
            nullable: false,
            default_value: None,
            is_auto_increment: false,
            is_unsigned: false,
            enum_values: Some(vec!["ACTIVE".to_string(), "INACTIVE".to_string()]),
            comment: None,
        };

        let code = generate_enum("users", &col, col.enum_values.as_ref().unwrap());
        assert!(code.contains("pub enum UsersStatus"));
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
        // Variant names differ from the DB spelling, so serde must rename.
        assert!(code.contains("#[serde(rename = \"ACTIVE\")]"));
        assert!(code.contains("\"ACTIVE\" => Ok(Self::Active)"));
        assert!(code.contains("impl rdbi::FromValue for UsersStatus"));
        assert!(code.contains("impl rdbi::ToValue for UsersStatus"));
    }

    #[test]
    fn test_generate_sort_by_enum() {
        let code = generate_sort_by_enum(&make_table());
        assert!(code.contains("SortBy"));
        assert!(code.contains("pub fn as_sql(&self) -> &'static str"));
        assert!(code.contains("Self::Id => \"`id`\""));
        assert!(code.contains("Self::Username => \"`username`\""));
        assert!(code.contains("Self::Status => \"`status`\""));
    }

    #[test]
    fn test_generate_shared_pagination_types() {
        let code = generate_shared_pagination_types();
        assert!(code.contains("pub enum SortDirection"));
        assert!(code.contains("pub struct PaginatedResult<T>"));
        assert!(code.contains("pub fn new("));
    }
}