1pub mod backend_trait;
2pub mod backends;
3pub mod overrides;
4pub mod resolve;
5pub mod validation;
6
7pub use backend_trait::{
8 CodegenBackend, RbsEnumInfo, RbsGenerationContext, RbsQueryInfo, ResolvedColumn, ResolvedParam,
9};
10pub use backends::get_backend;
11pub use overrides::TypeOverride;
12
13use scythe_backend::manifest::BackendManifest;
14use scythe_backend::naming::{row_struct_name, to_pascal_case};
15
16use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
17use scythe_core::catalog::Catalog;
18use scythe_core::errors::ScytheError;
19use scythe_core::parser::QueryCommand;
20
/// Source fragments produced for a single analyzed query.
///
/// Each field is `None` when the corresponding fragment was not generated.
/// `Clone`, `PartialEq` and `Eq` are derived so callers can cache results and
/// compare them directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GeneratedCode {
    /// The query function; set on every successful call to
    /// `generate_with_backend_and_overrides`.
    pub query_fn: Option<String>,
    /// Row struct for row-returning queries that have no source table.
    pub row_struct: Option<String>,
    /// Model struct for table-backed row-returning queries and/or the
    /// definitions of composite types used by the query.
    pub model_struct: Option<String>,
    /// Enum type definitions referenced by the query's columns or parameters.
    pub enum_def: Option<String>,
}
32
/// Converts a (usually plural) table name into a singular word using simple
/// English suffix rules.
///
/// Rules, checked in order:
/// 1. A trailing `ies` becomes `y` (`categories` → `category`).
/// 2. A trailing `es` is dropped when what precedes it ends in `s`, `sh`,
///    `ch`, `x` or `z` (`boxes` → `box`, `classes` → `class`).
/// 3. A single trailing `s` is dropped unless it is preceded by another `s`
///    (`users` → `user`, while `boss` stays `boss`).
///
/// Anything else is returned unchanged. This is a heuristic rather than a
/// dictionary lookup: singular words that end in a lone `s` are still trimmed
/// (`status` → `statu`).
pub fn singularize(name: &str) -> String {
    // Stem endings that form their plural with "-es".
    const ES_STEM_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if let Some(stem) = name.strip_suffix("es") {
        if ES_STEM_ENDINGS.iter().any(|ending| stem.ends_with(*ending)) {
            return stem.to_owned();
        }
    }

    match name.strip_suffix('s') {
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => name.to_owned(),
    }
}
55
56pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
62 let backend = get_backend(backend_name, "postgresql")?;
63 Ok(backend.manifest().clone())
64}
65
66fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
68 if let Some(ref table_name) = analyzed.source_table {
69 let singular = singularize(table_name);
70 to_pascal_case(&singular).into_owned()
71 } else {
72 row_struct_name(&analyzed.name, &manifest.naming)
73 }
74}
75
76pub fn generate_with_backend(
82 analyzed: &AnalyzedQuery,
83 backend: &dyn CodegenBackend,
84) -> Result<GeneratedCode, ScytheError> {
85 generate_with_backend_and_overrides(analyzed, backend, &[])
86}
87
88pub fn generate_with_backend_and_overrides(
90 analyzed: &AnalyzedQuery,
91 backend: &dyn CodegenBackend,
92 overrides: &[TypeOverride],
93) -> Result<GeneratedCode, ScytheError> {
94 let manifest = backend.manifest();
95 let source_table = analyzed.source_table.as_deref().unwrap_or("");
96 let columns = resolve::resolve_columns(&analyzed.columns, manifest, overrides, source_table)?;
97 let params = resolve::resolve_params(&analyzed.params, manifest, overrides, source_table)?;
98
99 let mut result = GeneratedCode::default();
100
101 let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
104 if !enum_def.is_empty() {
105 result.enum_def = Some(enum_def);
106 }
107
108 let needs_row_struct = matches!(
110 analyzed.command,
111 QueryCommand::One | QueryCommand::Many | QueryCommand::Grouped
112 );
113 if needs_row_struct && !analyzed.columns.is_empty() {
114 if let Some(ref table_name) = analyzed.source_table {
115 result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
116 } else {
117 result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
118 }
119 }
120
121 if !analyzed.composites.is_empty() {
123 let mut comp_defs = String::new();
124 for (i, comp) in analyzed.composites.iter().enumerate() {
125 if i > 0 {
126 comp_defs.push_str("\n\n");
127 }
128 comp_defs.push_str(&backend.generate_composite_def(comp)?);
129 }
130 if !comp_defs.is_empty() {
131 if let Some(ref mut existing) = result.model_struct {
132 existing.push_str("\n\n");
133 existing.push_str(&comp_defs);
134 } else {
135 result.model_struct = Some(comp_defs);
136 }
137 }
138 }
139
140 let struct_name = determine_struct_name(analyzed, manifest);
142
143 if analyzed.command == QueryCommand::Grouped {
146 let many_proxy = AnalyzedQuery {
147 name: analyzed.name.clone(),
148 command: QueryCommand::Many,
149 sql: analyzed.sql.clone(),
150 columns: analyzed.columns.clone(),
151 params: analyzed.params.clone(),
152 deprecated: analyzed.deprecated.clone(),
153 source_table: analyzed.source_table.clone(),
154 composites: analyzed.composites.clone(),
155 enums: analyzed.enums.clone(),
156 optional_params: analyzed.optional_params.clone(),
157 group_by: analyzed.group_by.clone(),
158 };
159 result.query_fn =
160 Some(backend.generate_query_fn(&many_proxy, &struct_name, &columns, ¶ms)?);
161 } else {
162 result.query_fn =
163 Some(backend.generate_query_fn(analyzed, &struct_name, &columns, ¶ms)?);
164 }
165
166 Ok(result)
167}
168
169fn generate_enum_defs_via_backend(
171 analyzed: &AnalyzedQuery,
172 backend: &dyn CodegenBackend,
173) -> Result<String, ScytheError> {
174 use ahash::AHashSet;
175 use std::fmt::Write;
176
177 let mut out = String::new();
178 let mut seen_enums: AHashSet<String> = AHashSet::new();
179
180 let enum_sources: Vec<&str> = analyzed
181 .columns
182 .iter()
183 .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
184 .chain(
185 analyzed
186 .params
187 .iter()
188 .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
189 )
190 .collect();
191
192 for sql_name in enum_sources {
193 if !seen_enums.insert(sql_name.to_string()) {
194 continue;
195 }
196
197 if !out.is_empty() {
198 let _ = writeln!(out);
199 }
200
201 if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
202 out.push_str(&backend.generate_enum_def(enum_info)?);
203 } else {
204 let stub_info = EnumInfo {
207 sql_name: sql_name.to_string(),
208 values: vec![],
209 };
210 out.push_str(&backend.generate_enum_def(&stub_info)?);
211 }
212 }
213
214 Ok(out)
215}
216
217pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
219 let backend = get_backend("rust-sqlx", "postgresql")?;
220 generate_with_backend(analyzed, &*backend)
221}
222
223pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
225 Ok(GeneratedCode::default())
226}
227
228pub fn generate_single_enum_def_with_backend(
230 enum_info: &EnumInfo,
231 backend: &dyn CodegenBackend,
232) -> Result<String, ScytheError> {
233 backend.generate_enum_def(enum_info)
234}
235
236pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
239 use scythe_backend::naming::{enum_type_name, enum_variant_name};
241 use std::fmt::Write;
242
243 let mut out = String::with_capacity(256);
244 let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
245
246 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
247 let _ = writeln!(
248 out,
249 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
250 enum_info.sql_name
251 );
252 let _ = writeln!(out, "pub enum {type_name} {{");
253
254 for value in &enum_info.values {
255 let variant = enum_variant_name(value, &manifest.naming);
256 let _ = writeln!(out, " {variant},");
257 }
258
259 let _ = write!(out, "}}");
260 out
261}
262
263pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
265 let b = backends::sqlx::SqlxBackend::new("postgresql")?;
266 Ok(b.manifest().clone())
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Builds an `AnalyzedQuery` with no source table, composites, enums,
    /// optional params, deprecation note or grouping.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
            optional_params: Vec::new(),
            group_by: None,
        }
    }

    /// Shorthand for an analyzed result column.
    fn column(name: &str, neutral_type: &str, nullable: bool) -> AnalyzedColumn {
        AnalyzedColumn {
            name: name.to_string(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// The non-nullable `int32` parameter `id` at position 1 used by several tests.
    fn id_param() -> AnalyzedParam {
        AnalyzedParam {
            name: "id".to_string(),
            neutral_type: "int32".to_string(),
            nullable: false,
            position: 1,
        }
    }

    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
                column("email", "string", true),
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![column("id", "int32", false), column("name", "string", false)],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                column("id", "int32", false),
                column("status", "enum::user_status", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
        assert!(result.model_struct.is_none());
        assert!(result.enum_def.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    #[test]
    fn test_singularize_no_change() {
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    /// Known limitation of the suffix heuristic: a singular word ending in a
    /// lone `s` still loses it.
    #[test]
    fn test_singularize_trims_singular_word_ending_in_s() {
        assert_eq!(singularize("status"), "statu");
    }

    #[test]
    fn test_singularize_degenerate_inputs() {
        assert_eq!(singularize(""), "");
        assert_eq!(singularize("s"), "");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![column("id", "int32", false), column("name", "string", false)],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("Box<dyn std::error::Error>"));
        assert!(!query_fn.contains("sqlx"));
    }

    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        assert!(!def.contains("sqlx"));
    }
}