scythe-codegen 0.1.0

Polyglot code generation backends for scythe
Documentation
use std::fmt::Write;
use std::path::Path;

use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

/// Backend manifest TOML embedded into the binary at compile time via `include_str!`.
///
/// [`PythonAsyncpgBackend::new`] parses this text when no on-disk override manifest
/// (`backends/python-asyncpg/manifest.toml`) exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");

/// Code generator that renders scythe queries as Python functions using the `asyncpg` driver.
pub struct PythonAsyncpgBackend {
    /// Naming rules and type mappings that drive code generation.
    manifest: BackendManifest,
}

impl PythonAsyncpgBackend {
    /// Builds the backend, preferring an on-disk manifest over the embedded default.
    ///
    /// If `backends/python-asyncpg/manifest.toml` exists (resolved against the current
    /// working directory), it is loaded. Otherwise the compiled-in
    /// [`DEFAULT_MANIFEST_TOML`] is parsed.
    ///
    /// # Errors
    ///
    /// Returns a [`ScytheError`] with [`ErrorCode::InternalError`] when the chosen
    /// manifest cannot be loaded or parsed.
    pub fn new() -> Result<Self, ScytheError> {
        // Both manifest sources report failures with the same `manifest: ...` prefix.
        fn manifest_error(err: impl std::fmt::Display) -> ScytheError {
            ScytheError::new(ErrorCode::InternalError, format!("manifest: {err}"))
        }

        let override_path = Path::new("backends/python-asyncpg/manifest.toml");
        let manifest = if override_path.exists() {
            load_manifest(override_path).map_err(manifest_error)?
        } else {
            toml::from_str(DEFAULT_MANIFEST_TOML).map_err(manifest_error)?
        };

        Ok(Self { manifest })
    }

    /// Borrows the manifest this backend was constructed with.
    pub fn manifest(&self) -> &BackendManifest {
        let Self { manifest } = self;
        manifest
    }
}

impl CodegenBackend for PythonAsyncpgBackend {
    fn name(&self) -> &str {
        "python-asyncpg"
    }

    fn file_header(&self) -> String {
        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
         \n\
         import datetime  # noqa: F401\n\
         from dataclasses import dataclass\n\
         from enum import Enum  # noqa: F401\n\
         \n\
         from asyncpg import Connection  # noqa: F401\n\
         \n"
        .to_string()
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "@dataclass");
        let _ = writeln!(out, "class {}:", struct_name);
        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
        if columns.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for col in columns {
                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
            }
        }
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // Build parameter list (keyword-only after conn)
        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };

        // Clean SQL — asyncpg uses $1, $2 positional params natively
        let sql = super::clean_sql(&analyzed.sql);

        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> {} | None:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    row = await conn.fetchrow(");
                let _ = writeln!(out, "        \"{}\",", sql);
                if !params.is_empty() {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
                let _ = writeln!(out, "    if row is None:");
                let _ = writeln!(out, "        return None");
                let field_assignments: Vec<String> = columns
                    .iter()
                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
                    .collect();
                let oneliner = format!(
                    "    return {}({})",
                    struct_name,
                    field_assignments.join(", ")
                );
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "        {},", fa);
                    }
                    let _ = writeln!(out, "    )");
                }
            }
            QueryCommand::Many | QueryCommand::Batch => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> list[{}]:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    rows = await conn.fetch(");
                let _ = writeln!(out, "        \"{}\",", sql);
                if !params.is_empty() {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
                let field_assignments: Vec<String> = columns
                    .iter()
                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
                    .collect();
                let oneliner = format!(
                    "    return [{}({}) for r in rows]",
                    struct_name,
                    field_assignments.join(", ")
                );
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return [");
                    let _ = writeln!(out, "        {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, "            {},", fa);
                    }
                    let _ = writeln!(out, "        )");
                    let _ = writeln!(out, "        for r in rows");
                    let _ = writeln!(out, "    ]");
                }
            }
            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "async def {}(conn: Connection{}{}) -> None:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, "    await conn.execute(");
                let _ = writeln!(out, "        \"{}\",", sql);
                if !params.is_empty() {
                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
                    let _ = writeln!(out, "        {},", args.join(", "));
                }
                let _ = writeln!(out, "    )");
            }
        }

        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "class {}(str, Enum):", type_name);
        let _ = writeln!(
            out,
            "    \"\"\"Database enum type {}.\"\"\"",
            enum_info.sql_name
        );
        if enum_info.values.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for value in &enum_info.values {
                let variant = enum_variant_name(value, &self.manifest.naming);
                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
            }
        }
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "@dataclass");
        let _ = writeln!(out, "class {}:", name);
        let _ = writeln!(
            out,
            "    \"\"\"Composite type {}.\"\"\"",
            composite.sql_name
        );
        if composite.fields.is_empty() {
            let _ = writeln!(out, "    pass");
        } else {
            let _ = writeln!(out);
            for field in &composite.fields {
                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .map_err(|e| {
                        ScytheError::new(
                            ErrorCode::InternalError,
                            format!("composite field type error: {}", e),
                        )
                    })?;
                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
            }
        }
        Ok(out)
    }
}