1pub mod backend_trait;
2pub mod backends;
3pub mod overrides;
4pub mod resolve;
5pub mod validation;
6
7pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
8pub use backends::get_backend;
9pub use overrides::TypeOverride;
10
11use scythe_backend::manifest::BackendManifest;
12use scythe_backend::naming::{row_struct_name, to_pascal_case};
13
14use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
15use scythe_core::catalog::Catalog;
16use scythe_core::errors::ScytheError;
17use scythe_core::parser::QueryCommand;
18
/// Source-text fragments produced for a single analyzed query.
///
/// Every field is optional; `Default` yields a value with all fields `None`.
#[derive(Debug, Default)]
pub struct GeneratedCode {
    /// Text of the generated query function, as returned by the backend's
    /// `generate_query_fn`.
    pub query_fn: Option<String>,
    /// Text of the per-query row struct. Filled only for `One`/`Many` queries
    /// that return columns and have no `source_table`.
    pub row_struct: Option<String>,
    /// Text of the table-backed model struct, filled when the query has a
    /// `source_table`. Composite type definitions are appended here as well
    /// (or stored here on their own when no model struct was produced).
    pub model_struct: Option<String>,
    /// Text of the enum definitions referenced by the query's columns and
    /// params; left `None` when no enum text was generated.
    pub enum_def: Option<String>,
}
30
/// Turns a (typically plural) table name into a singular form using simple
/// suffix heuristics.
///
/// Rules, checked in order:
/// 1. `…ies` becomes `…y`.
/// 2. Names ending in `sses`, `shes`, `ches`, `xes`, `zes` or `ses` lose the
///    trailing `es`.
/// 3. Names ending in a single `s` (but not `ss`) lose that `s`.
/// 4. Anything else is returned unchanged.
///
/// These are purely textual rules, so irregular words are not recognised
/// (for example `status` becomes `statu`).
pub(crate) fn singularize(name: &str) -> String {
    // Endings for which the plural marker is the two-byte ASCII suffix "es".
    const ES_PLURAL_ENDINGS: [&str; 6] = ["sses", "shes", "ches", "xes", "zes", "ses"];

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if ES_PLURAL_ENDINGS.iter().any(|ending| name.ends_with(ending)) {
        // Every ending above finishes with "es", so the strip always succeeds.
        return name.strip_suffix("es").unwrap_or(name).to_owned();
    }

    match name.strip_suffix('s') {
        // A stem that itself ends in 's' means the name ended in "ss": keep it.
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => name.to_owned(),
    }
}
53
54pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
60 let backend = get_backend(backend_name, "postgresql")?;
61 Ok(backend.manifest().clone())
62}
63
64fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
66 if let Some(ref table_name) = analyzed.source_table {
67 let singular = singularize(table_name);
68 to_pascal_case(&singular).into_owned()
69 } else {
70 row_struct_name(&analyzed.name, &manifest.naming)
71 }
72}
73
74pub fn generate_with_backend(
80 analyzed: &AnalyzedQuery,
81 backend: &dyn CodegenBackend,
82) -> Result<GeneratedCode, ScytheError> {
83 generate_with_backend_and_overrides(analyzed, backend, &[])
84}
85
86pub fn generate_with_backend_and_overrides(
88 analyzed: &AnalyzedQuery,
89 backend: &dyn CodegenBackend,
90 overrides: &[TypeOverride],
91) -> Result<GeneratedCode, ScytheError> {
92 let manifest = backend.manifest();
93 let source_table = analyzed.source_table.as_deref().unwrap_or("");
94 let columns = resolve::resolve_columns(&analyzed.columns, manifest, overrides, source_table)?;
95 let params = resolve::resolve_params(&analyzed.params, manifest, overrides, source_table)?;
96
97 let mut result = GeneratedCode::default();
98
99 let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
102 if !enum_def.is_empty() {
103 result.enum_def = Some(enum_def);
104 }
105
106 let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
108 if needs_row_struct && !analyzed.columns.is_empty() {
109 if let Some(ref table_name) = analyzed.source_table {
110 result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
111 } else {
112 result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
113 }
114 }
115
116 if !analyzed.composites.is_empty() {
118 let mut comp_defs = String::new();
119 for (i, comp) in analyzed.composites.iter().enumerate() {
120 if i > 0 {
121 comp_defs.push_str("\n\n");
122 }
123 comp_defs.push_str(&backend.generate_composite_def(comp)?);
124 }
125 if !comp_defs.is_empty() {
126 if let Some(ref mut existing) = result.model_struct {
127 existing.push_str("\n\n");
128 existing.push_str(&comp_defs);
129 } else {
130 result.model_struct = Some(comp_defs);
131 }
132 }
133 }
134
135 let struct_name = determine_struct_name(analyzed, manifest);
137 result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, ¶ms)?);
138
139 Ok(result)
140}
141
142fn generate_enum_defs_via_backend(
144 analyzed: &AnalyzedQuery,
145 backend: &dyn CodegenBackend,
146) -> Result<String, ScytheError> {
147 use ahash::AHashSet;
148 use std::fmt::Write;
149
150 let mut out = String::new();
151 let mut seen_enums: AHashSet<String> = AHashSet::new();
152
153 let enum_sources: Vec<&str> = analyzed
154 .columns
155 .iter()
156 .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
157 .chain(
158 analyzed
159 .params
160 .iter()
161 .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
162 )
163 .collect();
164
165 for sql_name in enum_sources {
166 if !seen_enums.insert(sql_name.to_string()) {
167 continue;
168 }
169
170 if !out.is_empty() {
171 let _ = writeln!(out);
172 }
173
174 if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
175 out.push_str(&backend.generate_enum_def(enum_info)?);
176 } else {
177 let stub_info = EnumInfo {
180 sql_name: sql_name.to_string(),
181 values: vec![],
182 };
183 out.push_str(&backend.generate_enum_def(&stub_info)?);
184 }
185 }
186
187 Ok(out)
188}
189
190pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
192 let backend = get_backend("rust-sqlx", "postgresql")?;
193 generate_with_backend(analyzed, &*backend)
194}
195
196pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
198 Ok(GeneratedCode::default())
199}
200
201pub fn generate_single_enum_def_with_backend(
203 enum_info: &EnumInfo,
204 backend: &dyn CodegenBackend,
205) -> Result<String, ScytheError> {
206 backend.generate_enum_def(enum_info)
207}
208
/// Renders a Rust enum definition with sqlx attributes for `enum_info`,
/// without going through a `CodegenBackend`.
///
/// The output consists of a `#[derive(...)]` line, a `#[sqlx(...)]` line
/// carrying the enum's SQL name, the `pub enum <TypeName> {` header, one line
/// per value, and a closing `}` with no trailing newline. Type and variant
/// names are derived with the naming rules in `manifest.naming`.
pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
    use scythe_backend::naming::{enum_type_name, enum_variant_name};
    use std::fmt::Write;

    let mut out = String::with_capacity(256);
    let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);

    // The results of `writeln!`/`write!` into the `String` are discarded throughout.
    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
    let _ = writeln!(
        out,
        "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
        enum_info.sql_name
    );
    let _ = writeln!(out, "pub enum {type_name} {{");

    // One line per SQL enum value, in the order given by `enum_info.values`.
    for value in &enum_info.values {
        let variant = enum_variant_name(value, &manifest.naming);
        let _ = writeln!(out, " {variant},");
    }

    // Closing brace is written without a newline.
    let _ = write!(out, "}}");
    out
}
235
236pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
238 let b = backends::sqlx::SqlxBackend::new("postgresql")?;
239 Ok(b.manifest().clone())
240}
241
// Unit tests for the code-generation entry points and the `singularize` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Builds an `AnalyzedQuery` fixture with the given name, command, SQL,
    /// columns and params. `source_table` is left `None` and the composite,
    /// enum and optional-param lists are left empty.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
            optional_params: Vec::new(),
        }
    }

    /// A `Many` query without a source table yields a `<Name>Row` struct and a
    /// function that fetches all rows.
    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: true,
                },
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        // The nullable column is expected to be wrapped in `Option`.
        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    /// A `One` query with a single parameter yields a function taking that
    /// parameter and fetching exactly one row.
    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    /// An `Exec` query produces no row struct and a function returning `()`.
    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    /// A column typed `enum::<sql_name>` triggers an enum definition even
    /// though the fixture's `enums` list is empty, and the row struct refers
    /// to the generated enum type.
    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "status".to_string(),
                    neutral_type: "enum::user_status".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    /// `generate_from_catalog` currently returns an empty `GeneratedCode`.
    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
    }

    /// Plain trailing `s` is removed.
    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    /// `…ies` is rewritten to `…y`.
    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    /// `…sses` loses its trailing `es`.
    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    /// Words ending in `ss` are left untouched. Note that `status` is NOT
    /// unchanged: the single-`s` rule strips it to `statu`. The assertion pins
    /// that current heuristic output rather than a linguistically correct one.
    #[test]
    fn test_singularize_no_change() {
        assert_eq!(singularize("status"), "statu");
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    /// `…ches`, `…xes` and `…shes` lose their trailing `es`.
    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    /// The `tokio-postgres` backend emits code that references
    /// `tokio_postgres` types and never mentions `sqlx`.
    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    /// The `tokio-postgres` backend's enum output includes `Display` and
    /// `FromStr` impls and no `sqlx` references.
    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        assert!(!def.contains("sqlx"));
    }
}