scythe-codegen 0.1.0

Polyglot code generation backends for scythe
Documentation
use std::fmt::Write;
use std::path::Path;

use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};

use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

/// Default rust-sqlx manifest TOML, embedded at compile time from
/// `../../manifests/rust-sqlx.toml` (resolved relative to this source file).
///
/// `load_sqlx_manifest` parses this text when no on-disk manifest exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");

/// Code-generation backend that emits Rust source targeting the `sqlx` crate.
///
/// Generated query functions take a `&sqlx::PgPool`, i.e. the output targets PostgreSQL.
pub struct SqlxBackend {
    /// Manifest supplying the naming rules and type resolution used for generated code.
    manifest: BackendManifest,
}

impl SqlxBackend {
    /// Build a backend configured from the rust-sqlx manifest.
    ///
    /// # Errors
    /// Propagates the [`ScytheError`] produced when the manifest cannot be loaded or parsed.
    pub fn new() -> Result<Self, ScytheError> {
        load_sqlx_manifest().map(|manifest| Self { manifest })
    }

    /// Borrow the manifest this backend was built with (kept for backward-compatible callers).
    pub fn manifest(&self) -> &BackendManifest {
        &self.manifest
    }
}

/// Load the rust-sqlx backend manifest.
///
/// A manifest file at `backends/rust-sqlx/manifest.toml` (a relative path, so it is
/// looked up from the current working directory) takes precedence. Without it, the
/// copy embedded in `DEFAULT_MANIFEST_TOML` is parsed instead.
///
/// # Errors
/// Returns an `ErrorCode::InternalError` [`ScytheError`] when the on-disk manifest fails
/// to load or the embedded manifest fails to parse.
fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
    let on_disk = Path::new("backends/rust-sqlx/manifest.toml");

    if !on_disk.exists() {
        return toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|err| {
            ScytheError::new(
                ErrorCode::InternalError,
                format!("failed to parse embedded manifest: {err}"),
            )
        });
    }

    load_manifest(on_disk).map_err(|err| {
        ScytheError::new(
            ErrorCode::InternalError,
            format!("failed to load manifest: {err}"),
        )
    })
}

impl CodegenBackend for SqlxBackend {
    fn name(&self) -> &str {
        "rust-sqlx"
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);

        for col in columns {
            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let struct_name = to_pascal_case(&singular).into_owned();
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);

        for col in columns {
            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // Deprecated annotation
        if let Some(ref msg) = analyzed.deprecated {
            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
        }

        // Build parameter list
        let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
        for param in params {
            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
        }

        // Return type
        let return_type = match &analyzed.command {
            QueryCommand::One => struct_name.to_string(),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
            QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => format!("Vec<{}>", struct_name),
        };

        // Function signature
        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
            func_name,
            param_parts.join(", "),
            return_type
        );

        // Clean SQL
        let sql_raw = super::clean_sql(&analyzed.sql);
        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);

        // Query body
        let has_row_struct = matches!(
            analyzed.command,
            QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
        );

        // Build bind params string
        let bind_params: String = analyzed
            .params
            .iter()
            .map(|p| {
                let param_name = to_snake_case(&p.name);
                if p.neutral_type.starts_with("enum::") {
                    let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
                    let rust_type = enum_type_name(enum_name, &self.manifest.naming);
                    format!(", {} as &{}", param_name, rust_type)
                } else {
                    format!(", {}", param_name)
                }
            })
            .collect();

        let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);

        if is_exec_rows {
            if has_row_struct && !analyzed.columns.is_empty() {
                let _ = write!(
                    out,
                    "    let result = sqlx::query_as!({}, \"{}\"{})",
                    struct_name, sql, bind_params
                );
            } else {
                let _ = write!(
                    out,
                    "    let result = sqlx::query!(\"{}\"{})",
                    sql, bind_params
                );
            }
        } else if has_row_struct && !analyzed.columns.is_empty() {
            let _ = write!(
                out,
                "    sqlx::query_as!({}, \"{}\"{})",
                struct_name, sql, bind_params
            );
        } else {
            let _ = write!(out, "    sqlx::query!(\"{}\"{})", sql, bind_params);
        }

        let _ = writeln!(out);

        // Fetch method
        let fetch_method = match &analyzed.command {
            QueryCommand::One => ".fetch_one(pool)",
            QueryCommand::Many => ".fetch_all(pool)",
            QueryCommand::Exec => ".execute(pool)",
            QueryCommand::ExecResult => ".execute(pool)",
            QueryCommand::ExecRows => ".execute(pool)",
            QueryCommand::Batch => ".fetch_all(pool)",
        };

        let _ = write!(out, "        {}", fetch_method);
        let _ = writeln!(out);

        // Post-processing for exec variants
        match &analyzed.command {
            QueryCommand::Exec => {
                let _ = writeln!(out, "        .await?;");
                let _ = writeln!(out, "    Ok(())");
            }
            QueryCommand::ExecRows => {
                let _ = writeln!(out, "        .await?;");
                let _ = writeln!(out, "    Ok(result.rows_affected())");
            }
            _ => {
                let _ = writeln!(out, "        .await");
            }
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let mut out = String::with_capacity(256);
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);

        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
        let _ = writeln!(
            out,
            "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
            enum_info.sql_name
        );
        let _ = writeln!(out, "pub enum {type_name} {{");

        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    {variant},");
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        use scythe_backend::types::resolve_type;

        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                "    pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}

// ---------------------------------------------------------------------------
// Internal helpers (moved from old modules)
// ---------------------------------------------------------------------------

/// Rewrite the select list of `sql` so every enum-typed output column carries a sqlx
/// type-override alias of the form `col AS \"col: EnumType\"`.
///
/// For nullable columns the annotation is `Option<EnumType>`. The quotes are emitted
/// pre-escaped (`\"`) because `generate_query_fn` splices this SQL into a Rust string
/// literal. Only the part before the first ` FROM ` (matched ASCII case-insensitively)
/// is rewritten. SQL without such a clause is returned unchanged.
fn rewrite_sql_for_enums(
    sql: &str,
    columns: &[AnalyzedColumn],
    manifest: &BackendManifest,
) -> String {
    let enum_cols: Vec<(&str, String)> = columns
        .iter()
        .filter_map(|col| {
            let enum_name = col.neutral_type.strip_prefix("enum::")?;
            let rust_type = enum_type_name(enum_name, &manifest.naming);
            let annotation = if col.nullable {
                format!("Option<{}>", rust_type)
            } else {
                rust_type
            };
            Some((col.name.as_str(), annotation))
        })
        .collect();

    let mut result = sql.to_string();
    for (col_name, annotation) in &enum_cols {
        // ASCII-only uppercasing keeps byte offsets aligned with `result`. Full Unicode
        // `to_uppercase` can change lengths (e.g. 'ı' -> 'I'), which would misplace the
        // split or even slice through a multi-byte character and panic.
        let from_pos = match result.to_ascii_uppercase().find(" FROM ") {
            Some(pos) => pos,
            None => continue,
        };
        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
        let new_select = replace_column_in_select(&result[..from_pos], col_name, &alias);
        result = format!("{}{}", new_select, &result[from_pos..]);
    }
    result
}

/// Replace the last standalone occurrence of `col_name` in a select list with `replacement`.
///
/// A standalone occurrence is preceded by `", "` (tried first) or `" "`. It must be followed
/// by end of input, a space, a comma, or a newline. Occurrences that are prefixes of longer
/// identifiers (e.g. `status` inside `status_id`) are skipped, and the search keeps going
/// leftwards. Returns `select` unchanged when no occurrence qualifies.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    let is_boundary = |c: Option<char>| matches!(c, None | Some(' ' | ',' | '\n'));
    let patterns = [format!(", {col_name}"), format!(" {col_name}")];

    for pattern in &patterns {
        // Scan every match right-to-left rather than only the last one, so a later
        // non-boundary hit cannot hide an earlier valid column reference.
        let hit = select
            .rmatch_indices(pattern.as_str())
            .map(|(pos, _)| pos + pattern.len())
            .find(|&end| is_boundary(select[end..].chars().next()));

        if let Some(end) = hit {
            let start = end - col_name.len();
            return format!("{}{}{}", &select[..start], replacement, &select[end..]);
        }
    }
    select.to_string()
}