scythe-codegen 0.6.8

Polyglot code generation backends for scythe
Documentation
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

/// Contents of the bundled `rust-sibyl.toml` manifest, embedded at compile time
/// via `include_str!`. `RustSibylBackend::new` hands it to
/// `load_or_default_manifest` as the default manifest text.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sibyl.toml");

/// Code generation backend registered under the name `rust-sibyl`.
///
/// Construction via [`RustSibylBackend::new`] accepts only the `oracle` engine.
pub struct RustSibylBackend {
    /// Manifest loaded at construction time; the generators read its naming
    /// rules and pass it to type resolution.
    manifest: BackendManifest,
}

impl RustSibylBackend {
    /// Creates the backend for the given database `engine`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorCode::InternalError`] when `engine` is anything other
    /// than the exact string `"oracle"`, and propagates any error reported by
    /// `load_or_default_manifest`.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Guard clause: this backend has exactly one supported engine.
        if engine != "oracle" {
            return Err(ScytheError::new(
                ErrorCode::InternalError,
                format!("rust-sibyl only supports Oracle, got engine '{}'", engine),
            ));
        }

        let loaded = super::load_or_default_manifest(
            "backends/rust-sibyl/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )?;
        Ok(Self { manifest: loaded })
    }
}

/// [`CodegenBackend`] implementation that emits Rust source for Oracle using the `sibyl` crate.
impl CodegenBackend for RustSibylBackend {
    /// Returns the identifier of this backend, `"rust-sibyl"`.
    fn name(&self) -> &str {
        const BACKEND_NAME: &str = "rust-sibyl";
        BACKEND_NAME
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Lists the database engines this backend can target: only `"oracle"`.
    fn supported_engines(&self) -> &[&str] {
        const ENGINES: &[&str] = &["oracle"];
        ENGINES
    }

    /// Returns the text placed at the top of every generated file: an
    /// "auto-generated" banner comment followed by the `sibyl` prelude import,
    /// each terminated by a newline.
    fn file_header(&self) -> String {
        String::from(concat!(
            "// Auto-generated by scythe. Do not edit.\n",
            "use sibyl::prelude::*;\n",
        ))
    }

    /// Renders the `pub struct` that holds one result row of `query_name`.
    ///
    /// The struct name comes from `row_struct_name` with the manifest's naming
    /// rules; each column becomes a `pub` field typed with its `full_type`.
    /// The returned text ends at the closing brace, without a trailing newline.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);

        // One "    pub name: Type,\n" line per column, concatenated in order.
        let field_lines: String = columns
            .iter()
            .map(|column| format!("    pub {}: {},\n", column.field_name, column.full_type))
            .collect();

        Ok(format!(
            "#[derive(Debug, Clone)]\npub struct {} {{\n{}}}",
            struct_name, field_lines
        ))
    }

    /// Renders the struct for a whole table.
    ///
    /// The table name is singularized and converted to PascalCase, then the
    /// result is passed to [`Self::generate_row_struct`] as the query name.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let model_name = to_pascal_case(&singularize(table_name));
        self.generate_row_struct(&model_name, columns)
    }

    /// Generates the `pub async fn` that runs one analyzed query through `sibyl`.
    ///
    /// PostgreSQL-style `$N` placeholders in the cleaned SQL are rewritten to
    /// `:N`. The shape of the emitted function depends on `analyzed.command`:
    ///
    /// * `One` / `Opt` — returns `sibyl::Result<Option<Row>>`. When the SQL
    ///   contains the `RETURNING` keyword, an `INTO :N+1, ...` clause is
    ///   appended and the columns are read from output bindings; otherwise the
    ///   first row of the result set is read.
    /// * `Many` — returns `sibyl::Result<Vec<Row>>`.
    /// * `Exec` — returns `sibyl::Result<()>`.
    /// * `ExecResult` / `ExecRows` — returns `sibyl::Result<usize>`.
    /// * `Batch` — emits `<fn>_batch`, taking a slice of parameter tuples and
    ///   executing the statement once per item.
    ///
    /// The returned text ends at the function's closing brace, without a
    /// trailing newline.
    ///
    /// NOTE(review): the emitted `bind` / `returning_into` / `timestamp` calls
    /// are reproduced as-is and have not been checked against the `sibyl` API.
    ///
    /// # Panics
    ///
    /// Panics on `QueryCommand::Grouped`, which is expected to have been
    /// rewritten to `Many` before code generation.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        // Renders `text` as a Rust string literal, quotes included. `Debug`
        // formatting escapes `"`, `\` and control characters, so SQL containing
        // quoted identifiers or backslashes still yields compilable output.
        fn rust_string_literal(text: &str) -> String {
            format!("{text:?}")
        }

        // True when `keyword` appears in `sql` as a standalone word
        // (ASCII case-insensitive). A plain substring search would also match
        // identifiers such as `is_returning`.
        fn contains_keyword(sql: &str, keyword: &str) -> bool {
            sql.split(|c: char| !(c.is_alphanumeric() || matches!(c, '_' | '$' | '#')))
                .any(|word| word.eq_ignore_ascii_case(keyword))
        }

        // Writes the opening line of the generated function. `extra_args` is
        // either empty or starts with ", ".
        fn write_signature(out: &mut String, name: &str, extra_args: &str, ok_type: &str) {
            let _ = writeln!(
                out,
                "pub async fn {}<'a>(session: &'a Session<'a>{}) -> sibyl::Result<{}> {{",
                name, extra_args, ok_type
            );
        }

        // Writes the `session.prepare(...)` line for `sql`.
        fn write_prepare(out: &mut String, sql: &str) {
            let _ = writeln!(
                out,
                "    let stmt = session.prepare({}).await?;",
                rust_string_literal(sql)
            );
        }

        // Writes one 1-based positional `stmt.bind` line per input parameter.
        fn write_binds(out: &mut String, params: &[ResolvedParam]) {
            for (i, param) in params.iter().enumerate() {
                let _ = writeln!(
                    out,
                    "    stmt.bind({}, &{}).await?;",
                    i + 1,
                    param.field_name
                );
            }
        }

        // Writes one 0-based `row.get` line per result column.
        fn write_row_reads(out: &mut String, columns: &[ResolvedColumn]) {
            for (i, col) in columns.iter().enumerate() {
                let _ = writeln!(
                    out,
                    "        let {} = row.get::<{}>({})?;",
                    col.field_name, col.lang_type, i
                );
            }
        }

        // Builds `Struct { a: a, b: b }` from the column field names.
        fn struct_literal(struct_name: &str, columns: &[ResolvedColumn]) -> String {
            let assigns = columns
                .iter()
                .map(|c| format!("{}: {}", c.field_name, c.field_name))
                .collect::<Vec<_>>()
                .join(", ");
            format!("{} {{ {} }}", struct_name, assigns)
        }

        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |n| format!(":{n}"),
        );

        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.borrowed_type))
            .collect::<Vec<_>>()
            .join(", ");
        let sep = if param_list.is_empty() { "" } else { ", " };
        let extra_args = format!("{}{}", sep, param_list);

        // DML with RETURNING (INSERT/UPDATE/DELETE ... RETURNING) needs output
        // bindings instead of a result set.
        let has_returning = contains_keyword(&sql, "RETURNING");

        let mut out = String::new();

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                write_signature(
                    &mut out,
                    &func_name,
                    &extra_args,
                    &format!("Option<{}>", struct_name),
                );

                if has_returning {
                    // Output placeholders continue numbering after the inputs:
                    // `INTO :N+1, :N+2, ...`.
                    let first_out_slot = params.len() + 1;
                    let into_placeholders = (0..columns.len())
                        .map(|i| format!(":{}", first_out_slot + i))
                        .collect::<Vec<_>>()
                        .join(", ");
                    let full_sql = format!("{} INTO {}", sql, into_placeholders);

                    write_prepare(&mut out, &full_sql);
                    write_binds(&mut out, params);

                    for (i, col) in columns.iter().enumerate() {
                        let sibyl_type = match col.neutral_type.as_str() {
                            "int16" | "int32" | "int64" | "float32" | "float64" | "decimal" => {
                                "sibyl::Number"
                            }
                            "date" | "datetime" | "datetime_tz" => "sibyl::Date",
                            _ => "sibyl::Varchar",
                        };
                        let _ = writeln!(
                            out,
                            "    let out_{}: {} = stmt.returning_into({}).await?;",
                            col.field_name,
                            sibyl_type,
                            first_out_slot + i
                        );
                    }
                    let _ = writeln!(out, "    stmt.execute(\"\").await?;");

                    for col in columns {
                        // Conversion call applied to the output binding; `None`
                        // means the value is copied out as a string.
                        let conversion = match col.neutral_type.as_str() {
                            "int16" => Some("to_int::<i16>()"),
                            "int32" => Some("to_int::<i32>()"),
                            "int64" => Some("to_int::<i64>()"),
                            "float32" | "float64" | "decimal" => Some("to_float::<f64>()"),
                            "date" | "datetime" | "datetime_tz" => Some("timestamp()"),
                            _ => None,
                        };
                        let _ = match conversion {
                            Some(call) => writeln!(
                                out,
                                "    let {0} = out_{0}.{1}? as {2};",
                                col.field_name, call, col.lang_type
                            ),
                            None => writeln!(
                                out,
                                "    let {0} = out_{0}.as_str()?.to_string();",
                                col.field_name
                            ),
                        };
                    }

                    let _ = writeln!(
                        out,
                        "    Ok(Some({}))",
                        struct_literal(struct_name, columns)
                    );
                } else {
                    write_prepare(&mut out, &sql);
                    write_binds(&mut out, params);
                    let _ = writeln!(out, "    let rows = stmt.query(\"\").await?;");
                    let _ = writeln!(out, "    if let Some(row) = rows.next().await? {{");
                    write_row_reads(&mut out, columns);
                    let _ = writeln!(
                        out,
                        "        Ok(Some({}))",
                        struct_literal(struct_name, columns)
                    );
                    let _ = writeln!(out, "    }} else {{");
                    let _ = writeln!(out, "        Ok(None)");
                    let _ = writeln!(out, "    }}");
                }
            }
            QueryCommand::Many => {
                write_signature(
                    &mut out,
                    &func_name,
                    &extra_args,
                    &format!("Vec<{}>", struct_name),
                );
                write_prepare(&mut out, &sql);
                write_binds(&mut out, params);
                let _ = writeln!(out, "    let rows = stmt.query(\"\").await?;");
                let _ = writeln!(out, "    let mut results = Vec::new();");
                let _ = writeln!(out, "    while let Some(row) = rows.next().await? {{");
                write_row_reads(&mut out, columns);
                let _ = writeln!(
                    out,
                    "        results.push({});",
                    struct_literal(struct_name, columns)
                );
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    Ok(results)");
            }
            QueryCommand::Exec => {
                write_signature(&mut out, &func_name, &extra_args, "()");
                write_prepare(&mut out, &sql);
                write_binds(&mut out, params);
                let _ = writeln!(out, "    stmt.execute(\"\").await?;");
                let _ = writeln!(out, "    Ok(())");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                write_signature(&mut out, &func_name, &extra_args, "usize");
                write_prepare(&mut out, &sql);
                write_binds(&mut out, params);
                let _ = writeln!(out, "    let num_rows = stmt.execute(\"\").await?;");
                let _ = writeln!(out, "    Ok(num_rows)");
            }
            QueryCommand::Batch => {
                let mut tuple_fields = params
                    .iter()
                    .map(|p| p.full_type.clone())
                    .collect::<Vec<_>>()
                    .join(", ");
                // `(T)` is just a parenthesized `T`; a one-element tuple needs
                // the trailing comma for the `item.0` access below to compile.
                if params.len() == 1 {
                    tuple_fields.push(',');
                }

                write_signature(
                    &mut out,
                    &format!("{}_batch", func_name),
                    &format!(", items: &[({})]", tuple_fields),
                    "()",
                );
                write_prepare(&mut out, &sql);
                let _ = writeln!(out, "    for item in items {{");
                for i in 0..params.len() {
                    let _ = writeln!(out, "        stmt.bind({}, &item.{}).await?;", i + 1, i);
                }
                let _ = writeln!(out, "        stmt.execute(\"\").await?;");
                let _ = writeln!(out, "    }}");
                let _ = writeln!(out, "    Ok(())");
            }
            QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
        }

        // Every generated function ends with its closing brace and no newline.
        let _ = write!(out, "}}");

        Ok(out)
    }

    /// Renders a Rust `enum` for a SQL enum type together with an `as_str`
    /// method that maps each variant back to its original SQL label.
    ///
    /// Type and variant names are derived with the manifest's naming rules.
    /// The returned text ends at the closing brace of the `impl` block,
    /// without a trailing newline.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq)]");
        let _ = writeln!(out, "pub enum {} {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, "    {},", variant);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);
        let _ = writeln!(out, "impl {} {{", type_name);
        let _ = writeln!(out, "    pub fn as_str(&self) -> &'static str {{");
        let _ = writeln!(out, "        match self {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            // `{:?}` emits the label as a quoted, escaped string literal, so a
            // label containing `"` or `\` cannot break the generated source.
            let _ = writeln!(
                out,
                "            {}::{} => {:?},",
                type_name, variant, value
            );
        }
        let _ = writeln!(out, "        }}");
        let _ = writeln!(out, "    }}");
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Renders a `pub struct` for a SQL composite type.
    ///
    /// The struct name is the PascalCase form of the composite's SQL name and
    /// each field name is converted to snake_case. Field types are resolved
    /// through the manifest with the third `resolve_type` argument set to
    /// `false`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorCode::InternalError`] when a field's neutral type
    /// cannot be resolved.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let struct_name = to_pascal_case(&composite.sql_name);
        let mut body = format!("#[derive(Debug, Clone)]\npub struct {} {{\n", struct_name);

        for field in &composite.fields {
            let resolved = match resolve_type(&field.neutral_type, &self.manifest, false) {
                Ok(ty) => ty.into_owned(),
                Err(e) => {
                    return Err(ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    ));
                }
            };
            body.push_str(&format!(
                "    pub {}: {},\n",
                to_snake_case(&field.name),
                resolved
            ));
        }

        body.push('}');
        Ok(body)
    }
}