scythe-codegen 0.6.9

Polyglot code generation backends for scythe
Documentation
use std::fmt::Write;

use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case,
};
use scythe_backend::types::resolve_type;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};

/// Manifest bundled into the binary; passed to `load_or_default_manifest` as the default.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/csharp-mysqlconnector.toml");

/// Code generation backend that emits C# using the MySqlConnector driver.
pub struct CsharpMysqlConnectorBackend {
    manifest: BackendManifest,
}

impl CsharpMysqlConnectorBackend {
    /// Build the backend for `engine`, loading its manifest.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` error when `engine` is neither `"mysql"` nor
    /// `"mariadb"`, and propagates any error from `load_or_default_manifest`.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let engine_supported = matches!(engine, "mysql" | "mariadb");
        if !engine_supported {
            let message =
                format!("csharp-mysqlconnector only supports MySQL, got engine '{engine}'");
            return Err(ScytheError::new(ErrorCode::InternalError, message));
        }

        let manifest = super::load_or_default_manifest(
            "backends/csharp-mysqlconnector/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )?;
        Ok(Self { manifest })
    }
}

/// Pick the `MySqlDataReader` accessor used to read a column of the given neutral type.
///
/// Neutral types without an entry in the table fall back to the untyped `GetValue`.
fn reader_method(neutral_type: &str) -> &'static str {
    // Each row lists the neutral type names that share one accessor.
    const ACCESSORS: &[(&[&str], &str)] = &[
        (&["bool"], "GetBoolean"),
        (&["int16"], "GetInt16"),
        (&["int32"], "GetInt32"),
        (&["int64"], "GetInt64"),
        (&["float32"], "GetFloat"),
        (&["float64"], "GetDouble"),
        (&["string", "json", "inet", "interval"], "GetString"),
        (&["decimal"], "GetDecimal"),
        (&["date"], "GetFieldValue<DateOnly>"),
        (&["time", "time_tz"], "GetFieldValue<TimeOnly>"),
        (&["datetime"], "GetDateTime"),
        (&["datetime_tz"], "GetFieldValue<DateTimeOffset>"),
    ];

    ACCESSORS
        .iter()
        .find(|(aliases, _)| aliases.contains(&neutral_type))
        .map_or("GetValue", |&(_, method)| method)
}

/// Render the C# expression that reads column `ordinal` from `reader` as `col`'s type.
///
/// - Enum columns are read as strings and parsed case-insensitively with `Enum.TryParse`;
///   an unrecognised value throws `InvalidOperationException`.
/// - UUID columns go through `GetValue(..).ToString()` to yield a string.
/// - Every other type uses the accessor chosen by `reader_method`.
fn column_read_expr(col: &ResolvedColumn, ordinal: usize) -> String {
    let neutral: &str = &col.neutral_type;
    match neutral {
        kind if kind.starts_with("enum::") => {
            let typ = &col.lang_type;
            let ord = ordinal;
            format!(
                "(Enum.TryParse<{typ}>(reader.GetString({ord}), true, out var enumVal{ord}) ? enumVal{ord} : throw new InvalidOperationException($\"Invalid enum value '{{reader.GetString({ord})}}' for {typ}\"))"
            )
        }
        // Stringify whatever CLR value the driver hands back for a UUID column.
        "uuid" => format!("reader.GetValue({ordinal}).ToString()!"),
        other => format!("reader.{}({})", reader_method(other), ordinal),
    }
}

impl CodegenBackend for CsharpMysqlConnectorBackend {
    fn name(&self) -> &str {
        "csharp-mysqlconnector"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn supported_engines(&self) -> &[&str] {
        &["mysql", "mariadb"]
    }

    fn file_header(&self) -> String {
        "// Auto-generated by scythe. Do not edit.\n#nullable enable\n\nusing MySqlConnector;\n\npublic static class Queries {"
            .to_string()
    }

    fn file_footer(&self) -> String {
        "}".to_string()
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "public record {}(", struct_name);
        for (i, c) in columns.iter().enumerate() {
            let field = to_pascal_case(&c.field_name);
            let sep = if i + 1 < columns.len() { "," } else { "" };
            let _ = writeln!(out, "    {} {}{}", c.full_type, field, sep);
        }
        let _ = write!(out, ");");
        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let name = to_pascal_case(table_name);
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_oneline_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |n| format!("@p{n}"),
        );
        let mut out = String::new();

        let param_list = params
            .iter()
            .map(|p| format!("{} {}", p.full_type, p.field_name))
            .collect::<Vec<_>>()
            .join(", ");
        let sep = if param_list.is_empty() { "" } else { ", " };

        // Handle :batch separately
        if matches!(analyzed.command, QueryCommand::Batch) {
            let batch_fn_name = format!("{}Batch", func_name);
            if params.len() > 1 {
                let params_record_name = format!("{}BatchParams", to_pascal_case(&analyzed.name));
                let _ = writeln!(out, "public record {}(", params_record_name);
                for (i, p) in params.iter().enumerate() {
                    let field = to_pascal_case(&p.field_name);
                    let sep = if i + 1 < params.len() { "," } else { "" };
                    let _ = writeln!(out, "    {} {}{}", p.full_type, field, sep);
                }
                let _ = writeln!(out, ");");
                let _ = writeln!(out);
                let _ = writeln!(
                    out,
                    "public static async Task {}(MySqlConnection conn, List<{}> items) {{",
                    batch_fn_name, params_record_name
                );
            } else if params.len() == 1 {
                let _ = writeln!(
                    out,
                    "public static async Task {}(MySqlConnection conn, List<{}> items) {{",
                    batch_fn_name, params[0].full_type
                );
            } else {
                let _ = writeln!(
                    out,
                    "public static async Task {}(MySqlConnection conn, int count) {{",
                    batch_fn_name
                );
            }
            let _ = writeln!(
                out,
                "    await using var tx = await conn.BeginTransactionAsync();"
            );
            let _ = writeln!(out, "    try {{");
            if params.is_empty() {
                let _ = writeln!(out, "        for (int i = 0; i < count; i++) {{");
            } else {
                let _ = writeln!(out, "        foreach (var item in items) {{");
            }
            let _ = writeln!(
                out,
                "            await using var cmd = new MySqlCommand(\"{}\", conn, (MySqlTransaction)tx);",
                sql
            );
            for (i, p) in params.iter().enumerate() {
                let value_expr = if params.len() > 1 {
                    let field = to_pascal_case(&p.field_name);
                    format!("item.{}", field)
                } else {
                    "item".to_string()
                };
                let _ = writeln!(
                    out,
                    "            cmd.Parameters.AddWithValue(\"@p{}\", {});",
                    i + 1,
                    value_expr
                );
            }
            let _ = writeln!(out, "            await cmd.ExecuteNonQueryAsync();");
            let _ = writeln!(out, "        }}");
            let _ = writeln!(out, "        await tx.CommitAsync();");
            let _ = writeln!(out, "    }} catch {{");
            let _ = writeln!(out, "        await tx.RollbackAsync();");
            let _ = writeln!(out, "        throw;");
            let _ = writeln!(out, "    }}");
            let _ = write!(out, "}}");
            return Ok(out);
        }

        let return_type = match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => format!("{}?", struct_name),
            QueryCommand::Many => {
                format!("List<{}>", struct_name)
            }
            QueryCommand::Exec => "void".to_string(),
            QueryCommand::ExecResult | QueryCommand::ExecRows => "int".to_string(),
            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
        };

        let is_async_void = return_type == "void";
        let task_type = if is_async_void {
            "Task".to_string()
        } else {
            format!("Task<{}>", return_type)
        };

        let _ = writeln!(
            out,
            "public static async {} {}(MySqlConnection conn{}{}) {{",
            task_type, func_name, sep, param_list
        );

        let _ = writeln!(
            out,
            "    await using var cmd = new MySqlCommand(\"{}\", conn);",
            sql
        );
        for (i, p) in params.iter().enumerate() {
            let value_expr = if p.neutral_type.starts_with("enum::") {
                format!("{}.ToString().ToLower()", p.field_name)
            } else {
                p.field_name.clone()
            };
            let _ = writeln!(
                out,
                "    cmd.Parameters.AddWithValue(\"@p{}\", {});",
                i + 1,
                value_expr
            );
        }

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(
                    out,
                    "    await using var reader = await cmd.ExecuteReaderAsync();"
                );
                let _ = writeln!(out, "    if (!await reader.ReadAsync()) return null;");
                let _ = writeln!(out, "    return new {}(", struct_name);
                for (i, col) in columns.iter().enumerate() {
                    let expr = column_read_expr(col, i);
                    let sep = if i + 1 < columns.len() { "," } else { "" };
                    if col.nullable {
                        let _ = writeln!(out, "        reader.IsDBNull({i}) ? null : {expr}{sep}");
                    } else {
                        let _ = writeln!(out, "        {expr}{sep}");
                    }
                }
                let _ = writeln!(out, "    );");
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "    await using var reader = await cmd.ExecuteReaderAsync();"
                );
                let _ = writeln!(out, "    var results = new List<{}>();", struct_name);
                let _ = writeln!(out, "    while (await reader.ReadAsync()) {{");
                let _ = writeln!(out, "        results.Add(new {}(", struct_name);
                for (i, col) in columns.iter().enumerate() {
                    let expr = column_read_expr(col, i);
                    let sep = if i + 1 < columns.len() { "," } else { "" };
                    if col.nullable {
                        let _ =
                            writeln!(out, "            reader.IsDBNull({i}) ? null : {expr}{sep}");
                    } else {
                        let _ = writeln!(out, "            {expr}{sep}");
                    }
                }
                let _ = writeln!(out, "        ));");
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    return results;");
            }
            QueryCommand::Exec => {
                let _ = writeln!(out, "    await cmd.ExecuteNonQueryAsync();");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(out, "    return await cmd.ExecuteNonQueryAsync();");
            }
            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "public enum {} {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    {},", variant);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        if composite.fields.is_empty() {
            let _ = writeln!(out, "public record {}();", name);
        } else {
            let _ = writeln!(out, "public record {}(", name);
            for (i, field) in composite.fields.iter().enumerate() {
                let cs_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .unwrap_or_else(|_| "object".to_string());
                let field_name = to_pascal_case(&field.name);
                let sep = if i + 1 < composite.fields.len() {
                    ","
                } else {
                    ""
                };
                let _ = writeln!(out, "    {} {}{}", cs_type, field_name, sep);
            }
            let _ = write!(out, ");");
        }
        Ok(out)
    }
}