Skip to main content

scythe_codegen/
lib.rs

1pub mod backend_trait;
2pub mod backends;
3pub mod overrides;
4pub mod resolve;
5pub mod validation;
6
7pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
8pub use backends::get_backend;
9pub use overrides::TypeOverride;
10
11use scythe_backend::manifest::BackendManifest;
12use scythe_backend::naming::{row_struct_name, to_pascal_case};
13
14use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
15use scythe_core::catalog::Catalog;
16use scythe_core::errors::ScytheError;
17use scythe_core::parser::QueryCommand;
18
19// ---------------------------------------------------------------------------
20// Output types
21// ---------------------------------------------------------------------------
22
/// Code fragments produced for a single analyzed query.
///
/// Every field is optional; a field is `None` when nothing of that kind was generated
/// (e.g. `GeneratedCode::default()` is the "nothing generated" value).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GeneratedCode {
    /// The generated query function.
    pub query_fn: Option<String>,
    /// Row struct for `:one`/`:many` queries that have no source table.
    pub row_struct: Option<String>,
    /// Model struct for `:one`/`:many` queries with a source table; composite type
    /// definitions are appended here as well.
    pub model_struct: Option<String>,
    /// Enum definitions for enum-typed columns and parameters.
    pub enum_def: Option<String>,
}
30
31// ---------------------------------------------------------------------------
32// Utility (shared across backends)
33// ---------------------------------------------------------------------------
34
/// Heuristically singularize an English table name (e.g. `users` -> `user`).
///
/// Rules, applied in order:
/// - `-ies` -> `-y` (`categories` -> `category`)
/// - `-sses`, `-shes`, `-ches`, `-xes` -> drop `es` (`classes` -> `class`, `boxes` -> `box`)
/// - consonant + `uses` and consonant + `zes` -> drop `es`
///   (`statuses` -> `status`, `buses` -> `bus`, `buzzes` -> `buzz`)
/// - any other trailing `s` not preceded by another `s` -> drop only the `s`
///   (`licenses` -> `license`, `houses` -> `house`, `sizes` -> `size`)
///
/// Suffix matching is case-sensitive (lowercase suffixes only). Irregular plurals are not
/// handled, and an already-singular word ending in a single `s` still loses it
/// (`status` -> `statu`). A lone `"s"` is returned unchanged rather than becoming empty.
pub(crate) fn singularize(name: &str) -> String {
    /// ASCII vowel test used to tell e.g. `houses` (keep the `e`) from `statuses` (drop `es`).
    fn is_vowel(b: u8) -> bool {
        matches!(b.to_ascii_lowercase(), b'a' | b'e' | b'i' | b'o' | b'u')
    }

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if let Some(base) = name.strip_suffix("es") {
        let drop_es = base.ends_with("ss")
            || base.ends_with("sh")
            || base.ends_with("ch")
            || base.ends_with('x')
            || match base.as_bytes() {
                // statuses -> status, buses -> bus; but houses -> house (vowel before `u`).
                [.., prev, b'u', b's'] => !is_vowel(*prev),
                // buzzes -> buzz, waltzes -> waltz; but sizes -> size (vowel before `z`).
                [.., prev, b'z'] => !is_vowel(*prev),
                _ => false,
            };
        if drop_es {
            return base.to_string();
        }
    }

    // Drop a single trailing `s`, except after another `s` (`boss`) or when nothing
    // would remain (which would produce an empty identifier).
    match name.strip_suffix('s') {
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem.to_string(),
        _ => name.to_string(),
    }
}
53
54// ---------------------------------------------------------------------------
55// Manifest helpers
56// ---------------------------------------------------------------------------
57
58/// Get the manifest for a backend. Defaults to PostgreSQL engine.
59pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
60    let backend = get_backend(backend_name, "postgresql")?;
61    Ok(backend.manifest().clone())
62}
63
64/// Determine the struct name for a query (model struct or row struct).
65fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
66    if let Some(ref table_name) = analyzed.source_table {
67        let singular = singularize(table_name);
68        to_pascal_case(&singular).into_owned()
69    } else {
70        row_struct_name(&analyzed.name, &manifest.naming)
71    }
72}
73
74// ---------------------------------------------------------------------------
75// Public API
76// ---------------------------------------------------------------------------
77
78/// Generate code using a specific backend.
79pub fn generate_with_backend(
80    analyzed: &AnalyzedQuery,
81    backend: &dyn CodegenBackend,
82) -> Result<GeneratedCode, ScytheError> {
83    generate_with_backend_and_overrides(analyzed, backend, &[])
84}
85
86/// Generate code using a specific backend with type overrides.
87pub fn generate_with_backend_and_overrides(
88    analyzed: &AnalyzedQuery,
89    backend: &dyn CodegenBackend,
90    overrides: &[TypeOverride],
91) -> Result<GeneratedCode, ScytheError> {
92    let manifest = backend.manifest();
93    let source_table = analyzed.source_table.as_deref().unwrap_or("");
94    let columns = resolve::resolve_columns(&analyzed.columns, manifest, overrides, source_table)?;
95    let params = resolve::resolve_params(&analyzed.params, manifest, overrides, source_table)?;
96
97    let mut result = GeneratedCode::default();
98
99    // Generate enum definitions for any enum-typed columns
100    // Use the backend-specific enum generation for proper derives
101    let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
102    if !enum_def.is_empty() {
103        result.enum_def = Some(enum_def);
104    }
105
106    // Generate row/model struct for :one and :many commands (not :batch)
107    let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
108    if needs_row_struct && !analyzed.columns.is_empty() {
109        if let Some(ref table_name) = analyzed.source_table {
110            result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
111        } else {
112            result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
113        }
114    }
115
116    // Generate composite type definitions
117    if !analyzed.composites.is_empty() {
118        let mut comp_defs = String::new();
119        for (i, comp) in analyzed.composites.iter().enumerate() {
120            if i > 0 {
121                comp_defs.push_str("\n\n");
122            }
123            comp_defs.push_str(&backend.generate_composite_def(comp)?);
124        }
125        if !comp_defs.is_empty() {
126            if let Some(ref mut existing) = result.model_struct {
127                existing.push_str("\n\n");
128                existing.push_str(&comp_defs);
129            } else {
130                result.model_struct = Some(comp_defs);
131            }
132        }
133    }
134
135    // Generate query function
136    let struct_name = determine_struct_name(analyzed, manifest);
137    result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, &params)?);
138
139    Ok(result)
140}
141
142/// Generate enum definitions via the backend trait.
143fn generate_enum_defs_via_backend(
144    analyzed: &AnalyzedQuery,
145    backend: &dyn CodegenBackend,
146) -> Result<String, ScytheError> {
147    use ahash::AHashSet;
148    use std::fmt::Write;
149
150    let mut out = String::new();
151    let mut seen_enums: AHashSet<String> = AHashSet::new();
152
153    let enum_sources: Vec<&str> = analyzed
154        .columns
155        .iter()
156        .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
157        .chain(
158            analyzed
159                .params
160                .iter()
161                .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
162        )
163        .collect();
164
165    for sql_name in enum_sources {
166        if !seen_enums.insert(sql_name.to_string()) {
167            continue;
168        }
169
170        if !out.is_empty() {
171            let _ = writeln!(out);
172        }
173
174        if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
175            out.push_str(&backend.generate_enum_def(enum_info)?);
176        } else {
177            // Generate a stub enum with no variants (for enum types referenced but
178            // not fully defined in the query's EnumInfo list).
179            let stub_info = EnumInfo {
180                sql_name: sql_name.to_string(),
181                values: vec![],
182            };
183            out.push_str(&backend.generate_enum_def(&stub_info)?);
184        }
185    }
186
187    Ok(out)
188}
189
190/// Backward-compatible: generate code using the default sqlx backend.
191pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
192    let backend = get_backend("rust-sqlx", "postgresql")?;
193    generate_with_backend(analyzed, &*backend)
194}
195
196/// Stub for catalog-level codegen. Returns default for now.
197pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
198    Ok(GeneratedCode::default())
199}
200
201/// Generate a single enum definition using a specific backend.
202pub fn generate_single_enum_def_with_backend(
203    enum_info: &EnumInfo,
204    backend: &dyn CodegenBackend,
205) -> Result<String, ScytheError> {
206    backend.generate_enum_def(enum_info)
207}
208
209/// Backward-compatible: generate a single enum definition (sqlx backend).
210/// Uses the manifest directly for backward compatibility with existing callers.
211pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
212    // Reproduce the old behavior exactly using the sqlx backend's logic
213    use scythe_backend::naming::{enum_type_name, enum_variant_name};
214    use std::fmt::Write;
215
216    let mut out = String::with_capacity(256);
217    let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
218
219    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
220    let _ = writeln!(
221        out,
222        "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
223        enum_info.sql_name
224    );
225    let _ = writeln!(out, "pub enum {type_name} {{");
226
227    for value in &enum_info.values {
228        let variant = enum_variant_name(value, &manifest.naming);
229        let _ = writeln!(out, "    {variant},");
230    }
231
232    let _ = write!(out, "}}");
233    out
234}
235
236/// Backward-compatible: load the default sqlx manifest.
237pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
238    let b = backends::sqlx::SqlxBackend::new("postgresql")?;
239    Ok(b.manifest().clone())
240}
241
242// ---------------------------------------------------------------------------
243// Tests
244// ---------------------------------------------------------------------------
245
// Unit tests for the public codegen entry points, the `singularize` helper, and the
// tokio-postgres backend.
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Build an `AnalyzedQuery` from the given parts, with no deprecation, no source table,
    /// and empty composite, enum, and optional-param lists.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
            optional_params: Vec::new(),
        }
    }

    // :many query without a source table -> `<Name>Row` row struct (nullable column becomes
    // Option) and a query fn returning Vec via fetch_all, using the default sqlx backend.
    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: true,
                },
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    // :one query with one parameter -> parameter appears in the signature, a single-row
    // return type, and fetch_one.
    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    // :exec query -> no row struct; the query fn returns () and uses execute.
    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    // An `enum::` column yields an enum definition and a row-struct field of that enum type.
    // `make_query` leaves `enums` empty, so the definition comes from the stub (no-values) path.
    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "status".to_string(),
                    neutral_type: "enum::user_status".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    // generate_from_catalog is a stub that returns an empty GeneratedCode.
    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
    }

    // Plain trailing-`s` plurals.
    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    // `-ies` -> `-y`.
    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    // `-sses` -> drop `es`.
    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    // NOTE: despite the test name, "status" is NOT left unchanged: it loses its trailing
    // `s` ("statu"). This pins the current heuristic behavior; only `-ss` words are kept as-is.
    #[test]
    fn test_singularize_no_change() {
        assert_eq!(singularize("status"), "statu");
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    // `-ches`, `-xes`, `-shes` -> drop `es`.
    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    // The tokio-postgres backend emits tokio_postgres types and no sqlx references.
    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        // Should NOT contain sqlx
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    // tokio-postgres enums get Display/FromStr impls instead of sqlx derives.
    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        // Should NOT contain sqlx
        assert!(!def.contains("sqlx"));
    }
}