scythe-codegen 0.6.9

Polyglot code generation backends for scythe
Documentation
use std::fmt::Write;

use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
};
use scythe_backend::types::resolve_type;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};

/// Code generator that emits Go source built on the standard `database/sql` package.
///
/// One instance is bound to a single database engine, chosen when it is constructed.
pub struct GoDatabaseSqlBackend {
    /// Naming conventions and type mappings loaded for the selected engine.
    manifest: BackendManifest,
    /// Engine identifier exactly as passed to the constructor (for example `"mysql"` or
    /// `"sqlite3"`); consulted when engine-specific output is needed.
    engine: String,
}

impl GoDatabaseSqlBackend {
    /// Creates a Go `database/sql` backend for the given database `engine`.
    ///
    /// Accepted engines are `mysql`, `mariadb`, `mssql`, `sqlite` (also spelled `sqlite3`)
    /// and `duckdb`. Each engine has its own embedded default manifest. That manifest is
    /// handed to `super::load_or_default_manifest` together with the path
    /// `backends/go-database-sql/manifest.toml`. Presumably a file at that path overrides the
    /// embedded default; TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorCode::InternalError`] error when `engine` is not one of the accepted
    /// engines, and propagates any error from loading the manifest.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let manifest_toml = match engine {
            "mysql" => include_str!("../../manifests/go-database-sql.mysql.toml"),
            "mariadb" => include_str!("../../manifests/go-database-sql.mariadb.toml"),
            "mssql" => include_str!("../../manifests/go-database-sql.mssql.toml"),
            "sqlite" | "sqlite3" => include_str!("../../manifests/go-database-sql.sqlite.toml"),
            "duckdb" => include_str!("../../manifests/go-database-sql.duckdb.toml"),
            _ => {
                // Keep this list in sync with the arms above (MariaDB was previously missing).
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    format!(
                        "go-database-sql supports MySQL, MariaDB, MSSQL, SQLite, and DuckDB, got engine '{}'",
                        engine
                    ),
                ));
            }
        };
        let manifest = super::load_or_default_manifest(
            "backends/go-database-sql/manifest.toml",
            manifest_toml,
        )?;
        Ok(Self {
            manifest,
            engine: engine.to_string(),
        })
    }
}

impl CodegenBackend for GoDatabaseSqlBackend {
    fn name(&self) -> &str {
        "go-database-sql"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn supported_engines(&self) -> &[&str] {
        &["mysql", "mariadb", "mssql", "sqlite", "duckdb"]
    }

    fn file_header(&self) -> String {
        // TODO: determine uses_time from actual column types instead of engine
        let uses_time = matches!(
            self.engine.as_str(),
            "mysql" | "mariadb" | "duckdb" | "mssql" | "snowflake"
        );
        let mut header =
            String::from("package queries\n\nimport (\n\t\"context\"\n\t\"database/sql\"");
        if uses_time {
            header.push_str("\n\t\"time\"");
        }
        header.push_str("\n)\n");
        header
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "type {} struct {{", struct_name);
        for col in columns {
            let field = to_pascal_case(&col.field_name);
            let json_tag = &col.field_name;
            let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field, col.full_type, json_tag);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let name = to_pascal_case(table_name);
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut sql = super::clean_sql_oneline_with_optional(
            &analyzed.sql,
            &analyzed.optional_params,
            &analyzed.params,
        );
        // MSSQL requires @pN placeholders instead of ?
        if self.engine == "mssql" {
            sql = super::rewrite_pg_placeholders(&sql, |n| format!("@p{n}"));
        }

        let param_list = params
            .iter()
            .map(|p| {
                let field = to_pascal_case(&p.field_name);
                format!("{} {}", field, p.full_type)
            })
            .collect::<Vec<_>>()
            .join(", ");
        let sep = if param_list.is_empty() { "" } else { ", " };

        let args = params
            .iter()
            .map(|p| to_pascal_case(&p.field_name).into_owned())
            .collect::<Vec<_>>();

        let mut out = String::new();

        match &analyzed.command {
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    "func {}(ctx context.Context, db *sql.DB{}{}) error {{",
                    func_name, sep, param_list
                );
                let args_str = if args.is_empty() {
                    String::new()
                } else {
                    format!(", {}", args.join(", "))
                };
                let _ = writeln!(
                    out,
                    "\t_, err := db.ExecContext(ctx, \"{}\"{})",
                    sql, args_str
                );
                let _ = writeln!(out, "\treturn err");
                let _ = write!(out, "}}");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "func {}(ctx context.Context, db *sql.DB{}{}) (int64, error) {{",
                    func_name, sep, param_list
                );
                let args_str = if args.is_empty() {
                    String::new()
                } else {
                    format!(", {}", args.join(", "))
                };
                let _ = writeln!(
                    out,
                    "\tresult, err := db.ExecContext(ctx, \"{}\"{})",
                    sql, args_str
                );
                let _ = writeln!(out, "\tif err != nil {{");
                let _ = writeln!(out, "\t\treturn 0, err");
                let _ = writeln!(out, "\t}}");
                let _ = writeln!(out, "\treturn result.RowsAffected()");
                let _ = write!(out, "}}");
            }
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(
                    out,
                    "func {}(ctx context.Context, db *sql.DB{}{}) ({}, error) {{",
                    func_name, sep, param_list, struct_name
                );
                let args_str = if args.is_empty() {
                    String::new()
                } else {
                    format!(", {}", args.join(", "))
                };
                let _ = writeln!(
                    out,
                    "\trow := db.QueryRowContext(ctx, \"{}\"{})",
                    sql, args_str
                );
                let _ = writeln!(out, "\tvar r {}", struct_name);
                let scan_fields: Vec<String> = columns
                    .iter()
                    .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
                    .collect();
                let _ = writeln!(out, "\terr := row.Scan({})", scan_fields.join(", "));
                let _ = writeln!(out, "\treturn r, err");
                let _ = write!(out, "}}");
            }
            QueryCommand::Batch => {
                let batch_fn_name = format!("{}Batch", func_name);
                if params.len() > 1 {
                    let params_struct_name = format!("{}BatchParams", func_name);
                    let _ = writeln!(out, "type {} struct {{", params_struct_name);
                    for p in params {
                        let field = to_pascal_case(&p.field_name);
                        let _ = writeln!(out, "\t{} {}", field, p.full_type);
                    }
                    let _ = writeln!(out, "}}");
                    let _ = writeln!(out);
                    let _ = writeln!(
                        out,
                        "func {}(ctx context.Context, db *sql.DB, items []{}) error {{",
                        batch_fn_name, params_struct_name
                    );
                } else if params.len() == 1 {
                    let _ = writeln!(
                        out,
                        "func {}(ctx context.Context, db *sql.DB, items []{}) error {{",
                        batch_fn_name, params[0].full_type
                    );
                } else {
                    let _ = writeln!(
                        out,
                        "func {}(ctx context.Context, db *sql.DB, count int) error {{",
                        batch_fn_name
                    );
                }
                let _ = writeln!(out, "\ttx, err := db.BeginTx(ctx, nil)");
                let _ = writeln!(out, "\tif err != nil {{");
                let _ = writeln!(out, "\t\treturn err");
                let _ = writeln!(out, "\t}}");
                let _ = writeln!(out, "\tdefer tx.Rollback()");
                if params.is_empty() {
                    let _ = writeln!(out, "\tfor i := 0; i < count; i++ {{");
                    let _ = writeln!(out, "\t\t_, err := tx.ExecContext(ctx, \"{}\")", sql);
                } else {
                    let _ = writeln!(out, "\tfor _, item := range items {{");
                    if params.len() > 1 {
                        let item_args: Vec<String> = params
                            .iter()
                            .map(|p| format!("item.{}", to_pascal_case(&p.field_name)))
                            .collect();
                        let _ = writeln!(
                            out,
                            "\t\t_, err := tx.ExecContext(ctx, \"{}\", {})",
                            sql,
                            item_args.join(", ")
                        );
                    } else {
                        let _ =
                            writeln!(out, "\t\t_, err := tx.ExecContext(ctx, \"{}\", item)", sql);
                    }
                }
                let _ = writeln!(out, "\t\tif err != nil {{");
                let _ = writeln!(out, "\t\t\treturn err");
                let _ = writeln!(out, "\t\t}}");
                let _ = writeln!(out, "\t}}");
                let _ = writeln!(out, "\treturn tx.Commit()");
                let _ = write!(out, "}}");
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "func {}(ctx context.Context, db *sql.DB{}{}) ([]{}, error) {{",
                    func_name, sep, param_list, struct_name
                );
                let args_str = if args.is_empty() {
                    String::new()
                } else {
                    format!(", {}", args.join(", "))
                };
                let _ = writeln!(
                    out,
                    "\trows, err := db.QueryContext(ctx, \"{}\"{})",
                    sql, args_str
                );
                let _ = writeln!(out, "\tif err != nil {{");
                let _ = writeln!(out, "\t\treturn nil, err");
                let _ = writeln!(out, "\t}}");
                let _ = writeln!(out, "\tdefer rows.Close()");
                let _ = writeln!(out, "\tvar result []{}", struct_name);
                let _ = writeln!(out, "\tfor rows.Next() {{");
                let _ = writeln!(out, "\t\tvar r {}", struct_name);
                let scan_fields: Vec<String> = columns
                    .iter()
                    .map(|c| format!("&r.{}", to_pascal_case(&c.field_name)))
                    .collect();
                let _ = writeln!(
                    out,
                    "\t\tif err := rows.Scan({}); err != nil {{",
                    scan_fields.join(", ")
                );
                let _ = writeln!(out, "\t\t\treturn nil, err");
                let _ = writeln!(out, "\t\t}}");
                let _ = writeln!(out, "\t\tresult = append(result, r)");
                let _ = writeln!(out, "\t}}");
                let _ = writeln!(out, "\treturn result, rows.Err()");
                let _ = write!(out, "}}");
            }
            QueryCommand::Grouped => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    "Grouped queries should be rewritten before codegen".to_string(),
                ));
            }
        }

        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "type {} string", type_name);
        let _ = writeln!(out);
        let _ = writeln!(out, "const (");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(
                out,
                "\t{}{} {} = \"{}\"",
                type_name, variant, type_name, value
            );
        }
        let _ = write!(out, ")");
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "type {} struct {{", name);
        if !composite.fields.is_empty() {
            for field in &composite.fields {
                let field_name = to_pascal_case(&field.name);
                let go_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .unwrap_or_else(|_| "any".to_string());
                let json_tag = &field.name;
                let _ = writeln!(out, "\t{} {} `json:\"{}\"`", field_name, go_type, json_tag);
            }
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}