scythe-codegen 0.6.9

Polyglot code generation backends for scythe
Documentation
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};

// Built-in php-pdo manifests, one per engine family, embedded at compile time
// (`include_str!` paths are relative to this source file). `PhpPdoBackend::new`
// picks one by engine name and hands it to `load_or_default_manifest` as the
// fallback for `backends/php-pdo/manifest.toml`.
const DEFAULT_MANIFEST_PG: &str = include_str!("../../manifests/php-pdo.toml");
const DEFAULT_MANIFEST_MYSQL: &str = include_str!("../../manifests/php-pdo.mysql.toml");
const DEFAULT_MANIFEST_SQLITE: &str = include_str!("../../manifests/php-pdo.sqlite.toml");
const DEFAULT_MANIFEST_MSSQL: &str = include_str!("../../manifests/php-pdo.mssql.toml");
const DEFAULT_MANIFEST_REDSHIFT: &str = include_str!("../../manifests/php-pdo.redshift.toml");
const DEFAULT_MANIFEST_SNOWFLAKE: &str = include_str!("../../manifests/php-pdo.snowflake.toml");

/// Code generation backend that emits PHP 8 code on top of PDO.
pub struct PhpPdoBackend {
    /// Manifest (type mappings, naming rules) for the engine chosen in `new`.
    manifest: BackendManifest,
}

impl PhpPdoBackend {
    /// Creates the backend for `engine`.
    ///
    /// Accepted engine names: `postgresql`/`postgres`/`pg`, `mysql`/`mariadb`,
    /// `sqlite`/`sqlite3`, `mssql`, `redshift` and `snowflake`.
    ///
    /// # Errors
    /// Returns an `ErrorCode::InternalError` error for any other engine name, and
    /// propagates any error from `load_or_default_manifest`.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let fallback_toml = Self::embedded_manifest(engine).ok_or_else(|| {
            ScytheError::new(
                ErrorCode::InternalError,
                format!("unsupported engine '{}' for php-pdo backend", engine),
            )
        })?;

        let manifest =
            super::load_or_default_manifest("backends/php-pdo/manifest.toml", fallback_toml)?;
        Ok(Self { manifest })
    }

    /// Returns the embedded default manifest for `engine`, or `None` if unsupported.
    fn embedded_manifest(engine: &str) -> Option<&'static str> {
        let toml = match engine {
            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_PG,
            "mysql" | "mariadb" => DEFAULT_MANIFEST_MYSQL,
            "sqlite" | "sqlite3" => DEFAULT_MANIFEST_SQLITE,
            "mssql" => DEFAULT_MANIFEST_MSSQL,
            "redshift" => DEFAULT_MANIFEST_REDSHIFT,
            "snowflake" => DEFAULT_MANIFEST_SNOWFLAKE,
            _ => return None,
        };
        Some(toml)
    }
}

/// Returns the PHP cast prefix (including its trailing space) applied to a raw
/// PDO value of the given neutral type, or `""` when no cast is emitted.
///
/// NOTE(review): `bytes` is cast with `(string)`; if a PDO driver hands binary
/// columns back as a stream resource rather than a string, this cast would not
/// yield the payload — confirm per engine.
fn php_cast(neutral_type: &str) -> &'static str {
    // (cast prefix, neutral types that receive it)
    const CAST_TABLE: &[(&str, &[&str])] = &[
        ("(int) ", &["int16", "int32", "int64"]),
        ("(float) ", &["float32", "float64"]),
        ("(bool) ", &["bool"]),
        (
            "(string) ",
            &["string", "json", "inet", "interval", "uuid", "decimal", "bytes"],
        ),
    ];

    CAST_TABLE
        .iter()
        .find(|&&(_, types)| types.contains(&neutral_type))
        .map_or("", |&(cast, _)| cast)
}

impl CodegenBackend for PhpPdoBackend {
    fn name(&self) -> &str {
        "php-pdo"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn supported_engines(&self) -> &[&str] {
        &[
            "postgresql",
            "mysql",
            "mariadb",
            "sqlite",
            "mssql",
            "redshift",
            "snowflake",
        ]
    }

    fn file_header(&self) -> String {
        "<?php\n\ndeclare(strict_types=1);\n\nnamespace App\\Generated;\n\n// Auto-generated by scythe. Do not edit.\n"
            .to_string()
    }

    fn query_class_header(&self) -> String {
        "final class Queries {".to_string()
    }

    fn file_footer(&self) -> String {
        "}".to_string()
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();

        // Readonly class with constructor
        let _ = writeln!(out, "readonly class {} {{", struct_name);
        let _ = writeln!(out, "    public function __construct(");
        for c in columns.iter() {
            let sep = ",";
            let _ = writeln!(
                out,
                "        public {} ${}{}",
                c.full_type, c.field_name, sep
            );
        }
        let _ = writeln!(out, "    ) {{}}");
        let _ = writeln!(out);

        // fromRow factory method
        let _ = writeln!(
            out,
            "    public static function fromRow(array $row): self {{"
        );
        let _ = writeln!(out, "        return new self(");
        for c in columns.iter() {
            let sep = ",";
            let is_enum = c.neutral_type.starts_with("enum::");
            let is_datetime = matches!(
                c.neutral_type.as_str(),
                "date" | "time" | "time_tz" | "datetime" | "datetime_tz"
            );
            if is_enum {
                // Enum columns: convert DB string to PHP backed enum via ::from()
                let enum_type = &c.lang_type;
                if c.nullable {
                    let _ = writeln!(
                        out,
                        "            {}: $row['{}'] !== null ? {}::from($row['{}']) : null{}",
                        c.field_name, c.name, enum_type, c.name, sep
                    );
                } else {
                    let _ = writeln!(
                        out,
                        "            {}: {}::from($row['{}']){}",
                        c.field_name, enum_type, c.name, sep
                    );
                }
            } else if is_datetime {
                // DateTime columns: PDO returns strings, wrap in DateTimeImmutable
                if c.nullable {
                    let _ = writeln!(
                        out,
                        "            {}: $row['{}'] !== null ? new \\DateTimeImmutable($row['{}']) : null{}",
                        c.field_name, c.name, c.name, sep
                    );
                } else {
                    let _ = writeln!(
                        out,
                        "            {}: new \\DateTimeImmutable($row['{}']){}",
                        c.field_name, c.name, sep
                    );
                }
            } else {
                let cast = php_cast(&c.neutral_type);
                if c.nullable {
                    let _ = writeln!(
                        out,
                        "            {}: $row['{}'] !== null ? {}{} : null{}",
                        c.field_name,
                        c.name,
                        cast,
                        format_args!("$row['{}']", c.name),
                        sep
                    );
                } else {
                    let _ = writeln!(
                        out,
                        "            {}: {}$row['{}']{}",
                        c.field_name, cast, c.name, sep
                    );
                }
            }
        }
        let _ = writeln!(out, "        );");
        let _ = writeln!(out, "    }}");
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let name = to_pascal_case(table_name);
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_oneline_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |n| format!(":p{n}"),
        );
        let mut out = String::new();

        // Handle :batch separately
        if matches!(analyzed.command, QueryCommand::Batch) {
            let batch_fn_name = format!("{}Batch", func_name);
            // PHPDoc for batch function
            let _ = writeln!(out, "    /**");
            let _ = writeln!(out, "     * @param \\PDO $pdo");
            let _ = writeln!(out, "     * @param array<int, array<int, mixed>> $items");
            let _ = writeln!(out, "     * @return void");
            let _ = writeln!(out, "     */");
            let _ = writeln!(
                out,
                "    public static function {}(\\PDO $pdo, array $items): void {{",
                batch_fn_name
            );
            let _ = writeln!(out, "        $stmt = $pdo->prepare(\"{}\");", sql);
            let _ = writeln!(out, "        $pdo->beginTransaction();");
            let _ = writeln!(out, "        try {{");
            let _ = writeln!(out, "            foreach ($items as $item) {{");
            if params.is_empty() {
                let _ = writeln!(out, "                $stmt->execute();");
            } else {
                let use_positional = sql.contains('?');
                if use_positional {
                    let _ = writeln!(out, "                $stmt->execute($item);");
                } else {
                    // Named params — build mapping from item array
                    let bindings = params
                        .iter()
                        .enumerate()
                        .map(|(i, _p)| format!("\"p{}\" => $item[{}]", i + 1, i))
                        .collect::<Vec<_>>()
                        .join(", ");
                    let _ = writeln!(out, "                $stmt->execute([{}]);", bindings);
                }
            }
            let _ = writeln!(out, "            }}");
            let _ = writeln!(out, "            $pdo->commit();");
            let _ = writeln!(out, "        }} catch (\\Throwable $e) {{");
            let _ = writeln!(out, "            $pdo->rollBack();");
            let _ = writeln!(out, "            throw $e;");
            let _ = writeln!(out, "        }}");
            let _ = write!(out, "    }}");
            return Ok(out);
        }

        // Build PHP parameter list
        let param_list = params
            .iter()
            .map(|p| format!("{} ${}", p.full_type, p.field_name))
            .collect::<Vec<_>>()
            .join(", ");
        let sep = if param_list.is_empty() { "" } else { ", " };

        // Return type depends on command
        let return_type = match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => format!("?{}", struct_name),
            QueryCommand::Many => "\\Generator".to_string(),
            QueryCommand::Exec => "void".to_string(),
            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
        };

        // PHPDoc block
        let _ = writeln!(out, "    /**");
        let _ = writeln!(out, "     * @param \\PDO $pdo");
        for p in params {
            let _ = writeln!(out, "     * @param {} ${}", p.full_type, p.field_name);
        }
        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(out, "     * @return {}|null", struct_name);
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "     * @return \\Generator<int, {}, mixed, void>",
                    struct_name
                );
            }
            QueryCommand::Exec => {
                let _ = writeln!(out, "     * @return void");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(out, "     * @return int");
            }
            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
        }
        let _ = writeln!(out, "     */");

        let _ = writeln!(
            out,
            "    public static function {}(\\PDO $pdo{}{}): {} {{",
            func_name, sep, param_list, return_type
        );

        // Prepare statement
        let _ = writeln!(out, "        $stmt = $pdo->prepare(\"{}\");", sql);

        // Build execute params
        // If the SQL contains `?` placeholders (MySQL/SQLite), use positional array.
        // If it contains `:pN` placeholders (PostgreSQL), use named array.
        if params.is_empty() {
            let _ = writeln!(out, "        $stmt->execute();");
        } else {
            let use_positional = sql.contains('?');
            let bindings = params
                .iter()
                .enumerate()
                .map(|(i, p)| {
                    let value = if p.neutral_type.starts_with("enum::") {
                        format!("${}->value", p.field_name)
                    } else {
                        format!("${}", p.field_name)
                    };
                    if use_positional {
                        value
                    } else {
                        format!("\"p{}\" => {}", i + 1, value)
                    }
                })
                .collect::<Vec<_>>()
                .join(", ");
            let _ = writeln!(out, "        $stmt->execute([{}]);", bindings);
        }

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(out, "        $row = $stmt->fetch(\\PDO::FETCH_ASSOC);");
                let _ = writeln!(
                    out,
                    "        return $row ? {}::fromRow($row) : null;",
                    struct_name
                );
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "        while ($row = $stmt->fetch(\\PDO::FETCH_ASSOC)) {{"
                );
                let _ = writeln!(out, "            yield {}::fromRow($row);", struct_name);
                let _ = writeln!(out, "        }}");
            }
            QueryCommand::Exec => {
                // nothing else needed
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(out, "        return $stmt->rowCount();");
            }
            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
        }

        let _ = write!(out, "    }}");
        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "enum {}: string {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    case {} = \"{}\";", variant, value);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "readonly class {} {{", name);
        let _ = writeln!(out, "    public function __construct(");
        if composite.fields.is_empty() {
            // empty constructor
        } else {
            for field in &composite.fields {
                let field_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .unwrap_or_else(|_| "mixed".to_string());
                let _ = writeln!(out, "        public {} ${},", field_type, field.name);
            }
        }
        let _ = writeln!(out, "    ) {{}}");
        let _ = write!(out, "}}");
        Ok(out)
    }
}