// scythe-codegen 0.6.5
//
// Polyglot code generation backends for scythe (see the crate documentation).
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

/// Built-in manifest for this backend, embedded at compile time and passed as
/// the fallback argument to `load_or_default_manifest` in
/// [`RustTiberiusBackend::new`].
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tiberius.toml");

/// Code-generation backend that emits Rust code for the tiberius MSSQL client.
pub struct RustTiberiusBackend {
    /// Manifest loaded in [`RustTiberiusBackend::new`]; supplies the naming
    /// conventions and type resolution used by the generators.
    manifest: BackendManifest,
}

impl RustTiberiusBackend {
    /// Builds the backend for `engine`.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` error when `engine` is not
    /// `"mssql"`, and propagates any error from loading the backend manifest.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Guard clause: this backend only knows how to talk to MSSQL.
        if engine != "mssql" {
            return Err(ScytheError::new(
                ErrorCode::InternalError,
                format!("rust-tiberius only supports MSSQL, got engine '{}'", engine),
            ));
        }

        Ok(Self {
            manifest: super::load_or_default_manifest(
                "backends/rust-tiberius/manifest.toml",
                DEFAULT_MANIFEST_TOML,
            )?,
        })
    }
}

/// Rewrite $1, $2, ... positional params to @p1, @p2, ... for MSSQL.
impl CodegenBackend for RustTiberiusBackend {
    fn name(&self) -> &str {
        "rust-tiberius"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    fn supported_engines(&self) -> &[&str] {
        &["mssql"]
    }

    fn file_header(&self) -> String {
        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
            .to_string()
    }

    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, Clone)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for col in columns {
            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl {} {{", struct_name);
        let _ = writeln!(
            out,
            "    pub fn from_row(row: &tiberius::Row) -> Result<Self, tiberius::error::Error> {{"
        );
        let _ = writeln!(out, "        Ok(Self {{");
        for col in columns {
            if col.nullable {
                let _ = writeln!(
                    out,
                    "            {}: row.try_get(\"{}\")?,",
                    col.field_name, col.name
                );
            } else {
                let _ = writeln!(
                    out,
                    "            {}: row.try_get(\"{}\")?.ok_or_else(|| tiberius::error::Error::Protocol(\"unexpected NULL for non-nullable column '{}'\".into()))?,",
                    col.field_name, col.name, col.name
                );
            }
        }
        let _ = writeln!(out, "        }})");
        let _ = writeln!(out, "    }}");
        let _ = write!(out, "}}");

        Ok(out)
    }

    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular).into_owned();
        self.generate_row_struct(&name, columns)
    }

    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        if let Some(ref msg) = analyzed.deprecated {
            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
        }

        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |n| format!("@p{n}"),
        );

        let mut param_parts: Vec<String> =
            vec!["client: &mut tiberius::Client<tokio::net::TcpStream>".to_string()];
        for param in params {
            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
        }

        if matches!(analyzed.command, QueryCommand::Batch) {
            let batch_fn_name = format!("{}_batch", func_name);
            if params.len() > 1 {
                let params_struct_name = format!("{}BatchParams", struct_name);
                let _ = writeln!(out, "#[derive(Debug, Clone)]");
                let _ = writeln!(out, "pub struct {} {{", params_struct_name);
                for param in params {
                    let _ = writeln!(out, "    pub {}: {},", param.field_name, param.full_type);
                }
                let _ = writeln!(out, "}}");
                let _ = writeln!(out);
                let _ = writeln!(
                    out,
                    "pub async fn {}(client: &mut tiberius::Client<tokio::net::TcpStream>, items: &[{}]) -> Result<(), tiberius::error::Error> {{",
                    batch_fn_name, params_struct_name
                );
                let _ = writeln!(out, "    for item in items {{");
                let bind_args: Vec<String> = params
                    .iter()
                    .map(|p| format!("&item.{}", p.field_name))
                    .collect();
                let _ = writeln!(
                    out,
                    "        client.execute(r#\"{}\"#, &[{}]).await?;",
                    sql,
                    bind_args.join(", ")
                );
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    Ok(())");
            } else if params.len() == 1 {
                let _ = writeln!(
                    out,
                    "pub async fn {}(client: &mut tiberius::Client<tokio::net::TcpStream>, items: &[{}]) -> Result<(), tiberius::error::Error> {{",
                    batch_fn_name, params[0].full_type
                );
                let _ = writeln!(out, "    for item in items {{");
                let _ = writeln!(
                    out,
                    "        client.execute(r#\"{}\"#, &[item]).await?;",
                    sql
                );
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    Ok(())");
            } else {
                let _ = writeln!(
                    out,
                    "pub async fn {}(client: &mut tiberius::Client<tokio::net::TcpStream>, count: usize) -> Result<(), tiberius::error::Error> {{",
                    batch_fn_name
                );
                let _ = writeln!(out, "    for _ in 0..count {{");
                let _ = writeln!(out, "        client.execute(r#\"{}\"#, &[]).await?;", sql);
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    Ok(())");
            }
            let _ = write!(out, "}}");
            return Ok(out);
        }

        let return_type = match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => struct_name.to_string(),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult | QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => unreachable!(),
            QueryCommand::Grouped => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    "grouped queries are not yet supported for rust-tiberius".to_string(),
                ));
            }
        };

        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, tiberius::error::Error> {{",
            func_name,
            param_parts.join(", "),
            return_type
        );

        let param_refs: String = if params.is_empty() {
            "&[]".to_string()
        } else {
            let refs: Vec<String> = params
                .iter()
                .map(|p| format!("&{}", p.field_name))
                .collect();
            format!("&[{}]", refs.join(", "))
        };

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(
                    out,
                    "    let stream = client.query(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(
                    out,
                    "    let row = stream.into_row().await?.expect(\"expected one row\");"
                );
                let _ = writeln!(out, "    Ok({}::from_row(&row)?)", struct_name);
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "    let stream = client.query(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    let rows = stream.into_first_result().await?;");
                let _ = writeln!(
                    out,
                    "    rows.iter().map({}::from_row).collect()",
                    struct_name
                );
            }
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    "    client.execute(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    Ok(())");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "    let result = client.execute(r#\"{}\"#, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, "    Ok(result.total())");
            }
            QueryCommand::Batch => unreachable!(),
            QueryCommand::Grouped => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    "grouped queries are not yet supported for rust-tiberius".to_string(),
                ));
            }
        }

        let _ = write!(out, "}}");
        Ok(out)
    }

    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::with_capacity(512);

        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
        let _ = writeln!(out, "pub enum {} {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    {},", variant);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
        let _ = writeln!(
            out,
            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
        );
        let _ = writeln!(out, "        match self {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(
                out,
                "            {}::{} => write!(f, \"{}\"),",
                type_name, variant, value
            );
        }
        let _ = writeln!(out, "        }}");
        let _ = writeln!(out, "    }}");
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
        let _ = writeln!(out, "    type Err = String;");
        let _ = writeln!(
            out,
            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
        );
        let _ = writeln!(out, "        match s {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(
                out,
                "            \"{}\" => Ok({}::{}),",
                value, type_name, variant
            );
        }
        let _ = writeln!(
            out,
            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
        );
        let _ = writeln!(out, "        }}");
        let _ = writeln!(out, "    }}");
        let _ = write!(out, "}}");

        Ok(out)
    }

    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();

        let _ = writeln!(out, "#[derive(Debug, Clone)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                "    pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}