Skip to main content

scythe_codegen/
lib.rs

1pub mod backend_trait;
2pub mod backends;
3pub mod overrides;
4pub mod resolve;
5pub mod validation;
6
7pub use backend_trait::{
8    CodegenBackend, RbsEnumInfo, RbsGenerationContext, RbsQueryInfo, ResolvedColumn, ResolvedParam,
9};
10pub use backends::get_backend;
11pub use overrides::TypeOverride;
12
13use scythe_backend::manifest::BackendManifest;
14use scythe_backend::naming::{row_struct_name, to_pascal_case};
15
16use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
17use scythe_core::catalog::Catalog;
18use scythe_core::errors::ScytheError;
19use scythe_core::parser::QueryCommand;
20
21// ---------------------------------------------------------------------------
22// Output types
23// ---------------------------------------------------------------------------
24
/// Source text generated for a single analyzed query.
///
/// Each field holds one category of generated code; `None` means nothing of that
/// kind was produced. See `generate_with_backend_and_overrides` for exactly when
/// each field is filled.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GeneratedCode {
    /// The query function.
    pub query_fn: Option<String>,
    /// Row struct, produced for result-returning queries without a source table.
    pub row_struct: Option<String>,
    /// Model struct for result-returning queries with a source table. Composite type
    /// definitions are also appended here (or stored here alone when there is no model).
    pub model_struct: Option<String>,
    /// Enum type definitions referenced by the query's columns and parameters.
    pub enum_def: Option<String>,
}
32
33// ---------------------------------------------------------------------------
34// Utility (shared across backends)
35// ---------------------------------------------------------------------------
36
/// Heuristically singularize an English (typically table) name.
///
/// Suffix rules, tried in order:
/// 1. `-ies` → `-y` (`categories` → `category`).
/// 2. `-es` after `s`, `sh`, `ch`, `x` or `z` → drop `es` (`addresses` → `address`,
///    `wishes` → `wish`, `batches` → `batch`, `boxes` → `box`).
/// 3. A trailing `s` not preceded by another `s` is dropped (`users` → `user`;
///    `boss` is unchanged).
///
/// This is a suffix heuristic with no dictionary, so irregular or ambiguous words
/// are mangled (`status` → `statu`, `houses` → `hous`). A non-empty name is never
/// reduced to the empty string: `"s"` is returned as-is.
pub fn singularize(name: &str) -> String {
    // Stem endings after which the plural is formed with `-es` rather than `-s`.
    const ES_STEM_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if let Some(stem) = name
        .strip_suffix("es")
        .filter(|stem| ES_STEM_ENDINGS.iter().any(|ending| stem.ends_with(ending)))
    {
        return stem.to_string();
    }

    match name.strip_suffix('s') {
        // Keep `ss` endings (`boss`) and never strip a name down to nothing.
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem.to_string(),
        _ => name.to_string(),
    }
}
55
56// ---------------------------------------------------------------------------
57// Manifest helpers
58// ---------------------------------------------------------------------------
59
60/// Get the manifest for a backend. Defaults to PostgreSQL engine.
61pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
62    let backend = get_backend(backend_name, "postgresql")?;
63    Ok(backend.manifest().clone())
64}
65
66/// Determine the struct name for a query (model struct or row struct).
67fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
68    if let Some(ref table_name) = analyzed.source_table {
69        let singular = singularize(table_name);
70        to_pascal_case(&singular).into_owned()
71    } else {
72        row_struct_name(&analyzed.name, &manifest.naming)
73    }
74}
75
76// ---------------------------------------------------------------------------
77// Public API
78// ---------------------------------------------------------------------------
79
80/// Generate code using a specific backend.
81pub fn generate_with_backend(
82    analyzed: &AnalyzedQuery,
83    backend: &dyn CodegenBackend,
84) -> Result<GeneratedCode, ScytheError> {
85    generate_with_backend_and_overrides(analyzed, backend, &[])
86}
87
88/// Generate code using a specific backend with type overrides.
89pub fn generate_with_backend_and_overrides(
90    analyzed: &AnalyzedQuery,
91    backend: &dyn CodegenBackend,
92    overrides: &[TypeOverride],
93) -> Result<GeneratedCode, ScytheError> {
94    let manifest = backend.manifest();
95    let source_table = analyzed.source_table.as_deref().unwrap_or("");
96    let columns = resolve::resolve_columns(&analyzed.columns, manifest, overrides, source_table)?;
97    let params = resolve::resolve_params(&analyzed.params, manifest, overrides, source_table)?;
98
99    let mut result = GeneratedCode::default();
100
101    // Generate enum definitions for any enum-typed columns
102    // Use the backend-specific enum generation for proper derives
103    let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
104    if !enum_def.is_empty() {
105        result.enum_def = Some(enum_def);
106    }
107
108    // Generate row/model struct for :one, :opt, and :many commands (not :batch)
109    let needs_row_struct = matches!(
110        analyzed.command,
111        QueryCommand::One | QueryCommand::Opt | QueryCommand::Many | QueryCommand::Grouped
112    );
113    if needs_row_struct && !analyzed.columns.is_empty() {
114        if let Some(ref table_name) = analyzed.source_table {
115            result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
116        } else {
117            result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
118        }
119    }
120
121    // Generate composite type definitions
122    if !analyzed.composites.is_empty() {
123        let mut comp_defs = String::new();
124        for (i, comp) in analyzed.composites.iter().enumerate() {
125            if i > 0 {
126                comp_defs.push_str("\n\n");
127            }
128            comp_defs.push_str(&backend.generate_composite_def(comp)?);
129        }
130        if !comp_defs.is_empty() {
131            if let Some(ref mut existing) = result.model_struct {
132                existing.push_str("\n\n");
133                existing.push_str(&comp_defs);
134            } else {
135                result.model_struct = Some(comp_defs);
136            }
137        }
138    }
139
140    // Generate query function
141    let struct_name = determine_struct_name(analyzed, manifest);
142
143    // For :grouped, delegate to the backend as :many for now.
144    // Full grouped codegen (parent + child structs, grouping logic) will come in a later phase.
145    if analyzed.command == QueryCommand::Grouped {
146        let many_proxy = AnalyzedQuery {
147            name: analyzed.name.clone(),
148            command: QueryCommand::Many,
149            sql: analyzed.sql.clone(),
150            columns: analyzed.columns.clone(),
151            params: analyzed.params.clone(),
152            deprecated: analyzed.deprecated.clone(),
153            source_table: analyzed.source_table.clone(),
154            composites: analyzed.composites.clone(),
155            enums: analyzed.enums.clone(),
156            optional_params: analyzed.optional_params.clone(),
157            group_by: analyzed.group_by.clone(),
158        };
159        result.query_fn =
160            Some(backend.generate_query_fn(&many_proxy, &struct_name, &columns, &params)?);
161    } else {
162        result.query_fn =
163            Some(backend.generate_query_fn(analyzed, &struct_name, &columns, &params)?);
164    }
165
166    Ok(result)
167}
168
169/// Generate enum definitions via the backend trait.
170fn generate_enum_defs_via_backend(
171    analyzed: &AnalyzedQuery,
172    backend: &dyn CodegenBackend,
173) -> Result<String, ScytheError> {
174    use ahash::AHashSet;
175    use std::fmt::Write;
176
177    let mut out = String::new();
178    let mut seen_enums: AHashSet<String> = AHashSet::new();
179
180    let enum_sources: Vec<&str> = analyzed
181        .columns
182        .iter()
183        .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
184        .chain(
185            analyzed
186                .params
187                .iter()
188                .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
189        )
190        .collect();
191
192    for sql_name in enum_sources {
193        if !seen_enums.insert(sql_name.to_string()) {
194            continue;
195        }
196
197        if !out.is_empty() {
198            let _ = writeln!(out);
199        }
200
201        if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
202            out.push_str(&backend.generate_enum_def(enum_info)?);
203        } else {
204            // Generate a stub enum with no variants (for enum types referenced but
205            // not fully defined in the query's EnumInfo list).
206            let stub_info = EnumInfo {
207                sql_name: sql_name.to_string(),
208                values: vec![],
209            };
210            out.push_str(&backend.generate_enum_def(&stub_info)?);
211        }
212    }
213
214    Ok(out)
215}
216
217/// Backward-compatible: generate code using the default sqlx backend.
218pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
219    let backend = get_backend("rust-sqlx", "postgresql")?;
220    generate_with_backend(analyzed, &*backend)
221}
222
223/// Stub for catalog-level codegen. Returns default for now.
224pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
225    Ok(GeneratedCode::default())
226}
227
228/// Generate a single enum definition using a specific backend.
229pub fn generate_single_enum_def_with_backend(
230    enum_info: &EnumInfo,
231    backend: &dyn CodegenBackend,
232) -> Result<String, ScytheError> {
233    backend.generate_enum_def(enum_info)
234}
235
236/// Backward-compatible: generate a single enum definition (sqlx backend).
237/// Uses the manifest directly for backward compatibility with existing callers.
238pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
239    // Reproduce the old behavior exactly using the sqlx backend's logic
240    use scythe_backend::naming::{enum_type_name, enum_variant_name};
241    use std::fmt::Write;
242
243    let mut out = String::with_capacity(256);
244    let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
245
246    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
247    let _ = writeln!(
248        out,
249        "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
250        enum_info.sql_name
251    );
252    let _ = writeln!(out, "pub enum {type_name} {{");
253
254    for value in &enum_info.values {
255        let variant = enum_variant_name(value, &manifest.naming);
256        let _ = writeln!(out, "    {variant},");
257    }
258
259    let _ = write!(out, "}}");
260    out
261}
262
263/// Backward-compatible: load the default sqlx manifest.
264pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
265    let b = backends::sqlx::SqlxBackend::new("postgresql")?;
266    Ok(b.manifest().clone())
267}
268
269// ---------------------------------------------------------------------------
270// Tests
271// ---------------------------------------------------------------------------
272
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Build an `AnalyzedQuery` with no source table, composites, enums, optional
    /// params or grouping.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
            optional_params: Vec::new(),
            group_by: None,
        }
    }

    /// Build an analyzed result column.
    fn col(name: &str, neutral_type: &str, nullable: bool) -> AnalyzedColumn {
        AnalyzedColumn {
            name: name.to_string(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// The non-nullable `int32` parameter `id` at position 1, shared by several tests.
    fn id_param() -> AnalyzedParam {
        AnalyzedParam {
            name: "id".to_string(),
            neutral_type: "int32".to_string(),
            nullable: false,
            position: 1,
        }
    }

    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                col("id", "int32", false),
                col("name", "string", false),
                col("email", "string", true),
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![col("id", "int32", false), col("name", "string", false)],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                col("id", "int32", false),
                col("status", "enum::user_status", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
        assert!(result.model_struct.is_none());
        assert!(result.enum_def.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    #[test]
    fn test_singularize_no_change() {
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    #[test]
    fn test_singularize_known_limitation_lone_trailing_s() {
        // The suffix heuristic has no dictionary, so a singular word ending in a lone
        // `s` loses it. Pinned here so any change to this behavior is deliberate.
        assert_eq!(singularize("status"), "statu");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![col("id", "int32", false), col("name", "string", false)],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        // Should NOT contain sqlx
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::GenericClient"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        // Should NOT contain sqlx
        assert!(!def.contains("sqlx"));
    }
}