scythe-codegen 0.6.9

Polyglot code generation backends for scythe
Documentation
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
use crate::singularize;

/// Embedded (compile-time) default manifest for the PostgreSQL engine names;
/// passed to `load_or_default_manifest` as the fallback TOML by `TypescriptPgBackend::new`.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-pg.toml");
/// Embedded (compile-time) default manifest for the `redshift` engine;
/// passed to `load_or_default_manifest` as the fallback TOML by `TypescriptPgBackend::new`.
const DEFAULT_MANIFEST_REDSHIFT: &str = include_str!("../../manifests/typescript-pg.redshift.toml");

/// Code generation backend emitting TypeScript that talks to the database
/// through the `pg` package's `PoolClient`.
pub struct TypescriptPgBackend {
    /// Manifest (naming rules, type mappings) loaded in `TypescriptPgBackend::new`.
    manifest: BackendManifest,
    /// Selects plain `interface` rows or Zod schemas; starts as
    /// `TsRowType::default()` and can be changed via the `row_type` option.
    row_type: TsRowType,
}

impl TypescriptPgBackend {
    /// Builds a backend for the given database `engine`.
    ///
    /// `"postgresql"`, `"postgres"` and `"pg"` select the embedded PostgreSQL
    /// manifest; `"redshift"` selects the embedded Redshift manifest. The chosen
    /// TOML is handed to `load_or_default_manifest` together with the override
    /// path. The row style starts at `TsRowType::default()`.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` for any other engine name, and
    /// propagates any error from loading the manifest.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let fallback_toml = if matches!(engine, "postgresql" | "postgres" | "pg") {
            DEFAULT_MANIFEST_TOML
        } else if engine == "redshift" {
            DEFAULT_MANIFEST_REDSHIFT
        } else {
            let message =
                format!("typescript-pg only supports PostgreSQL/Redshift, got engine '{engine}'");
            return Err(ScytheError::new(ErrorCode::InternalError, message));
        };

        let override_path = "backends/typescript-pg/manifest.toml";
        let manifest = super::load_or_default_manifest(override_path, fallback_toml)?;
        Ok(Self {
            row_type: TsRowType::default(),
            manifest,
        })
    }
}

impl CodegenBackend for TypescriptPgBackend {
    fn name(&self) -> &str {
        "typescript-pg"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn supported_engines(&self) -> &[&str] {
        &["postgresql", "redshift"]
    }

    fn file_header(&self) -> String {
        let mut header =
            "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
                .to_string();
        if self.row_type == TsRowType::Zod {
            header.push_str("import { z } from \"zod\";\n");
        }
        header
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        if self.row_type == TsRowType::Zod {
            return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
        }
        let mut out = String::new();
        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
        let _ = writeln!(out, "export interface {} {{", struct_name);
        for col in columns {
            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // Build parameter list
        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        let _sep = if param_list.is_empty() { "" } else { ", " };

        // Clean SQL — pg uses $1, $2 positional params natively
        let sql = super::clean_sql_with_optional(
            &analyzed.sql,
            &analyzed.optional_params,
            &analyzed.params,
        );

        // Build array of param values
        let _param_array: String = if params.is_empty() {
            String::new()
        } else {
            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
            format!(", [{}]", args.join(", "))
        };

        // Build inline params string for line-length checking
        let inline_params = if params.is_empty() {
            "client: PoolClient".to_string()
        } else {
            format!("client: PoolClient, {}", param_list)
        };

        // Helper: write a typed query call (with generic type annotation).
        // Biome always breaks `client.query<T>(...)` to multi-line.
        let write_typed_query = |out: &mut String,
                                 prefix: &str,
                                 type_name: &str,
                                 sql: &str,
                                 params: &[ResolvedParam]| {
            let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
            let _ = writeln!(out, "\t\t`{}`,", sql);
            if !params.is_empty() {
                let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                let _ = writeln!(out, "\t\t[{}],", args.join(", "));
            }
            let _ = writeln!(out, "\t);");
        };

        // Helper: write an untyped query call. Inline if short, multi-line if long.
        let write_untyped_query =
            |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
                let param_str = if params.is_empty() {
                    String::new()
                } else {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    format!(", [{}]", args.join(", "))
                };
                let oneliner = format!("{}client.query(`{}`{});", prefix, sql, param_str);
                // Use tab width of 4 for line length estimation
                let estimated_len = oneliner.replace('\t', "    ").len();
                if estimated_len <= 80 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "{}client.query(", prefix);
                    let _ = writeln!(out, "\t\t`{}`,", sql);
                    if !params.is_empty() {
                        let args: Vec<String> =
                            params.iter().map(|p| p.field_name.clone()).collect();
                        let _ = writeln!(out, "\t\t[{}],", args.join(", "));
                    }
                    let _ = writeln!(out, "\t);");
                }
            };

        // Helper: write function signature, inline or multi-line based on length
        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
            let oneliner = format!(
                "export async function {}({}): {} {{",
                name, params_inline, ret
            );
            if oneliner.len() <= 80 {
                let _ = writeln!(out, "{}", oneliner);
            } else {
                let mut parts = vec!["\tclient: PoolClient".to_string()];
                for p in params {
                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
                }
                let _ = writeln!(out, "export async function {}(", name);
                for part in &parts {
                    let _ = writeln!(out, "{},", part);
                }
                let _ = writeln!(out, "): {} {{", ret);
            }
        };

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
                let ret = format!("Promise<{} | null>", struct_name);
                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
                write_typed_query(
                    &mut out,
                    "\tconst { rows } = await ",
                    struct_name,
                    &sql,
                    params,
                );
                let _ = writeln!(out, "\treturn rows[0] ?? null;");
                let _ = write!(out, "}}");
            }
            QueryCommand::Many => {
                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
                let ret = format!("Promise<{}[]>", struct_name);
                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
                write_typed_query(
                    &mut out,
                    "\tconst { rows } = await ",
                    struct_name,
                    &sql,
                    params,
                );
                let _ = writeln!(out, "\treturn rows;");
                let _ = write!(out, "}}");
            }
            QueryCommand::Batch => {
                let batch_fn_name = format!("{}Batch", func_name);
                // Build params interface
                if params.len() > 1 {
                    let params_type_name = format!("{}BatchParams", struct_name);
                    let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
                    let _ = writeln!(out, "export interface {} {{", params_type_name);
                    for p in params {
                        let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
                    }
                    let _ = writeln!(out, "}}");
                    let _ = writeln!(out);
                    let _ = writeln!(
                        out,
                        "/** Execute {} for each item in the batch within a transaction. */",
                        analyzed.name
                    );
                    let batch_params = format!("client: PoolClient, items: {}[]", params_type_name);
                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
                    let _ = writeln!(out, "\ttry {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
                    let _ = writeln!(out, "\t\t\tawait client.query(");
                    let _ = writeln!(out, "\t\t\t\t`{}`,", sql);
                    let args: Vec<String> = params
                        .iter()
                        .map(|p| format!("item.{}", p.field_name))
                        .collect();
                    let _ = writeln!(out, "\t\t\t\t[{}],", args.join(", "));
                    let _ = writeln!(out, "\t\t\t);");
                    let _ = writeln!(out, "\t\t}}");
                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
                    let _ = writeln!(out, "\t}} catch (error) {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
                    let _ = writeln!(out, "\t\tthrow error;");
                    let _ = writeln!(out, "\t}}");
                    let _ = write!(out, "}}");
                } else if params.len() == 1 {
                    let _ = writeln!(
                        out,
                        "/** Execute {} for each item in the batch within a transaction. */",
                        analyzed.name
                    );
                    let batch_params =
                        format!("client: PoolClient, items: {}[]", params[0].full_type);
                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
                    let _ = writeln!(out, "\ttry {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
                    let _ = writeln!(out, "\t\t\tawait client.query(`{}`, [item]);", sql);
                    let _ = writeln!(out, "\t\t}}");
                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
                    let _ = writeln!(out, "\t}} catch (error) {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
                    let _ = writeln!(out, "\t\tthrow error;");
                    let _ = writeln!(out, "\t}}");
                    let _ = write!(out, "}}");
                } else {
                    let _ = writeln!(
                        out,
                        "/** Execute {} for each item in the batch within a transaction. */",
                        analyzed.name
                    );
                    write_fn_sig(
                        &mut out,
                        &batch_fn_name,
                        "client: PoolClient, count: number",
                        "Promise<void>",
                    );
                    let _ = writeln!(out, "\ttry {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
                    let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
                    let _ = writeln!(out, "\t\t\tawait client.query(`{}`);", sql);
                    let _ = writeln!(out, "\t\t}}");
                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
                    let _ = writeln!(out, "\t}} catch (error) {{");
                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
                    let _ = writeln!(out, "\t\tthrow error;");
                    let _ = writeln!(out, "\t}}");
                    let _ = write!(out, "}}");
                }
            }
            QueryCommand::Exec => {
                let _ = writeln!(out, "/** Execute a query returning no rows. */");
                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
                write_untyped_query(&mut out, "\tawait ", &sql, params);
                let _ = write!(out, "}}");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "/** Execute a query and return the number of affected rows. */"
                );
                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
                write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
                let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
                let _ = write!(out, "}}");
            }
            QueryCommand::Grouped => {
                unreachable!("Grouped is rewritten to Many before codegen")
            }
        }

        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        if self.row_type == TsRowType::Zod {
            return Ok(generate_zod_enum(&type_name, &enum_info.values));
        }
        let mut out = String::new();
        let values_name = format!("{}Values", type_name);
        let _ = writeln!(out, "export const {} = {{", values_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "\t{}: \"{}\",", variant, value);
        }
        let _ = writeln!(out, "}} as const;");
        let _ = writeln!(out);
        let _ = write!(
            out,
            "export type {} = typeof {}[keyof typeof {}];",
            type_name, values_name, values_name
        );
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
        let _ = writeln!(out, "export interface {} {{", name);
        if composite.fields.is_empty() {
            // empty interface
        } else {
            for field in &composite.fields {
                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .map_err(|e| {
                        ScytheError::new(
                            ErrorCode::InternalError,
                            format!("composite field type error: {}", e),
                        )
                    })?;
                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
            }
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn apply_options(
        &mut self,
        options: &std::collections::HashMap<String, String>,
    ) -> Result<(), ScytheError> {
        if let Some(value) = options.get("row_type") {
            self.row_type = TsRowType::from_option(value)?;
        }
        Ok(())
    }
}