scythe-codegen 0.2.0

Polyglot code generation backends for scythe
Documentation
use std::fmt::Write;
use std::path::Path;

use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

/// Manifest TOML for rust-tokio-postgres compiled into the binary; used when no
/// on-disk manifest is present.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");

/// Code generation backend that emits Rust targeting the `tokio-postgres` crate.
pub struct TokioPostgresBackend {
    /// Manifest consulted for naming rules (`naming`) and type resolution.
    manifest: BackendManifest,
}

impl TokioPostgresBackend {
    /// Build a backend for the database `engine`.
    ///
    /// Accepted engine spellings are `"postgresql"`, `"postgres"` and `"pg"`.
    ///
    /// # Errors
    /// Returns an `ErrorCode::InternalError` error when `engine` is not PostgreSQL,
    /// or when the backend manifest cannot be loaded.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let is_postgres = matches!(engine, "postgresql" | "postgres" | "pg");
        if !is_postgres {
            return Err(ScytheError::new(
                ErrorCode::InternalError,
                format!(
                    "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
                    engine
                ),
            ));
        }
        load_tokio_postgres_manifest().map(|manifest| Self { manifest })
    }
}

/// Load the rust-tokio-postgres backend manifest.
///
/// Prefers `backends/rust-tokio-postgres/manifest.toml`, which is a relative path and so
/// is resolved against the current working directory. When that file does not exist,
/// it parses the embedded `DEFAULT_MANIFEST_TOML` instead.
///
/// # Errors
/// Returns an `ErrorCode::InternalError` error when the on-disk manifest fails to load,
/// or when the embedded manifest fails to parse.
fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
    let on_disk = Path::new("backends/rust-tokio-postgres/manifest.toml");
    if !on_disk.exists() {
        return toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|err| {
            ScytheError::new(
                ErrorCode::InternalError,
                format!("failed to parse embedded manifest: {err}"),
            )
        });
    }
    load_manifest(on_disk).map_err(|err| {
        ScytheError::new(
            ErrorCode::InternalError,
            format!("failed to load manifest: {err}"),
        )
    })
}

impl CodegenBackend for TokioPostgresBackend {
    fn name(&self) -> &str {
        "rust-tokio-postgres"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn file_header(&self) -> String {
        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
            .to_string()
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        generate_struct_with_from_row(&struct_name, columns)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let struct_name = to_pascal_case(&singular).into_owned();
        generate_struct_with_from_row(&struct_name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // Deprecated annotation
        if let Some(ref msg) = analyzed.deprecated {
            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
        }

        // Build parameter list
        let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
        for param in params {
            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
        }

        // Return type
        let return_type = match &analyzed.command {
            QueryCommand::One => struct_name.to_string(),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult => "u64".to_string(),
            QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => format!("Vec<{}>", struct_name),
        };

        // Function signature
        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
            func_name,
            param_parts.join(", "),
            return_type
        );

        // Clean SQL
        let sql = super::clean_sql(&analyzed.sql);

        // Build param references for the query call
        let param_refs: String = if params.is_empty() {
            "&[]".to_string()
        } else {
            let refs: Vec<String> = params
                .iter()
                .map(|p| {
                    if p.neutral_type.starts_with("enum::") {
                        format!("&{}.to_string()", p.field_name)
                    } else {
                        format!("&{}", p.field_name)
                    }
                })
                .collect();
            format!("&[{}]", refs.join(", "))
        };

        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    "    let row = client.query_one(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    Ok({}::from_row(&row))", struct_name);
            }
            QueryCommand::Many | QueryCommand::Batch => {
                let _ = writeln!(
                    out,
                    "    let rows = client.query(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(
                    out,
                    "    Ok(rows.iter().map({}::from_row).collect())",
                    struct_name
                );
            }
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    "    client.execute(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    Ok(())");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "    let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    Ok(rows_affected)");
            }
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::with_capacity(512);

        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
        let _ = writeln!(out, "pub enum {} {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    {},", variant);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        // impl Display for serialization
        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
        let _ = writeln!(
            out,
            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
        );
        let _ = writeln!(out, "        match self {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(
                out,
                "            {}::{} => write!(f, \"{}\"),",
                type_name, variant, value
            );
        }
        let _ = writeln!(out, "        }}");
        let _ = writeln!(out, "    }}");
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        // impl FromStr for deserialization
        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
        let _ = writeln!(out, "    type Err = String;");
        let _ = writeln!(
            out,
            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
        );
        let _ = writeln!(out, "        match s {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(
                out,
                "            \"{}\" => Ok({}::{}),",
                value, type_name, variant
            );
        }
        let _ = writeln!(
            out,
            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
        );
        let _ = writeln!(out, "        }}");
        let _ = writeln!(out, "    }}");
        let _ = write!(out, "}}");

        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, Clone)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                "    pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Generate a struct with a `from_row` method for tokio-postgres.
fn generate_struct_with_from_row(
    struct_name: &str,
    columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
    let mut out = String::new();

    let _ = writeln!(out, "#[derive(Debug, Clone)]");
    let _ = writeln!(out, "pub struct {} {{", struct_name);
    for col in columns {
        let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
    }
    let _ = writeln!(out, "}}");
    let _ = writeln!(out);

    let _ = writeln!(out, "impl {} {{", struct_name);
    let _ = writeln!(
        out,
        "    pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
    );
    let _ = writeln!(out, "        Self {{");
    for col in columns {
        if col.neutral_type.starts_with("enum::") {
            // Enum columns need string conversion
            if col.nullable {
                let _ = writeln!(
                    out,
                    "            {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
                    col.field_name, col.name
                );
            } else {
                let _ = writeln!(
                    out,
                    "            {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
                    col.field_name, col.name
                );
            }
        } else {
            let _ = writeln!(
                out,
                "            {}: row.get(\"{}\"),",
                col.field_name, col.name
            );
        }
    }
    let _ = writeln!(out, "        }}");
    let _ = writeln!(out, "    }}");
    let _ = write!(out, "}}");

    Ok(out)
}