scythe-codegen 0.6.6

Polyglot code generation backends for scythe
Documentation
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::collections::HashMap;
use std::fmt::Write;

use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;

use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;

use super::python_common::PythonRowType;

/// Manifest TOML for this backend, embedded into the binary at compile time.
/// `PythonAiosqliteBackend::new` hands it to `load_or_default_manifest` as the default.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiosqlite.toml");

/// Code generation backend that emits async Python code built on `aiosqlite`.
pub struct PythonAiosqliteBackend {
    /// Backend manifest; its `naming` section drives generated identifiers and the
    /// manifest as a whole is used when resolving composite field types.
    manifest: BackendManifest,
    /// Flavor of Python row class to emit; replaced by the `row_type` option in
    /// `apply_options`.
    row_type: PythonRowType,
}

impl PythonAiosqliteBackend {
    /// Build a backend for the given database engine.
    ///
    /// Only `"sqlite"` and `"sqlite3"` are accepted.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` for any other engine name, and propagates
    /// whatever error `load_or_default_manifest` reports while obtaining the manifest.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Reject unsupported engines up front, before touching the manifest.
        if !matches!(engine, "sqlite" | "sqlite3") {
            let message = format!(
                "python-aiosqlite only supports SQLite, got engine '{}'",
                engine
            );
            return Err(ScytheError::new(ErrorCode::InternalError, message));
        }

        let manifest = super::load_or_default_manifest(
            "backends/python-aiosqlite/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )?;
        let row_type = PythonRowType::default();
        Ok(Self { manifest, row_type })
    }
}

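/// `CodegenBackend` implementation for `aiosqlite`. PostgreSQL-style `$1, $2, ...`
/// positional params are rewritten to `?` for SQLite inside `generate_query_fn`.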
impl CodegenBackend for PythonAiosqliteBackend {
    /// Identifier of this backend.
    fn name(&self) -> &str {
        const BACKEND_NAME: &str = "python-aiosqlite";
        BACKEND_NAME
    }

    /// Borrow the manifest that was loaded when the backend was constructed.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Engine names this backend advertises support for.
    fn supported_engines(&self) -> &[&str] {
        const ENGINES: &[&str] = &["sqlite"];
        ENGINES
    }

    /// Apply user-supplied backend options.
    ///
    /// Only the `row_type` key is read; when present it is parsed with
    /// `PythonRowType::from_option` and replaces the current row type.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PythonRowType::from_option` if the value is rejected.
    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
        match options.get("row_type") {
            Some(requested) => {
                self.row_type = PythonRowType::from_option(requested)?;
                Ok(())
            }
            // Option not supplied: keep the current row type.
            None => Ok(()),
        }
    }

    /// Produce the preamble of the generated Python module: the module docstring
    /// followed by the import section.
    ///
    /// When the row type's import is a standard-library one, it is placed next to the
    /// `enum` import and `aiosqlite` gets its own group. Otherwise the row type supplies
    /// the whole third-party group via `sorted_third_party_imports`.
    fn file_header(&self) -> String {
        const DOCSTRING: &str = "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"";
        const ENUM_IMPORT: &str = "from enum import Enum  # noqa: F401";
        const DRIVER_IMPORT: &str = "import aiosqlite  # noqa: F401";

        let import_line = self.row_type.import_line();

        if !self.row_type.is_stdlib_import() {
            let third_party = self.row_type.sorted_third_party_imports(DRIVER_IMPORT);
            return format!("{DOCSTRING}\n\n{ENUM_IMPORT}\n\n{third_party}\n\n");
        }

        format!("{DOCSTRING}\n\n{import_line}\n{ENUM_IMPORT}\n\n{DRIVER_IMPORT}\n\n")
    }

    /// Render the Python class that holds one result row of `query_name`.
    ///
    /// The class name comes from `row_struct_name`; decorator and class header come
    /// from the configured row type. A class without columns gets a `pass` body,
    /// otherwise one annotated attribute is emitted per column.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let class_name = row_struct_name(query_name, &self.manifest.naming);
        let mut code = format!(
            "{}{}\n    \"\"\"Row type for {} query.\"\"\"\n",
            self.row_type.decorator(),
            self.row_type.class_def(&class_name),
            query_name
        );

        if columns.is_empty() {
            code.push_str("    pass\n");
            return Ok(code);
        }

        // Blank line between the docstring and the attribute list.
        code.push('\n');
        for column in columns {
            let _ = writeln!(code, "    {}: {}", column.field_name, column.full_type);
        }
        Ok(code)
    }

    /// Render the Python class for a table model.
    ///
    /// The table name is singularized and converted to PascalCase, then the result is
    /// rendered through `generate_row_struct`.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        self.generate_row_struct(&to_pascal_case(&singularize(table_name)), columns)
    }

    /// Generate the async Python function that executes `analyzed` through `aiosqlite`.
    ///
    /// The SQL is cleaned with `clean_sql_with_optional`, its `$n` placeholders are
    /// rewritten to `?`, and the result is escaped and embedded in a triple-quoted
    /// Python string. The emitted function depends on `analyzed.command`:
    ///
    /// * `One` / `Opt` — calls `fetchone()` and returns `struct_name | None`.
    /// * `Many` — calls `fetchall()` and returns `list[struct_name]`.
    /// * `Batch` — emits `<fn>_batch`, which uses `executemany` (or a `count`-driven
    ///   loop when the query has no parameters) and then `conn.commit()`.
    /// * `Exec` — executes the statement; the Python function returns `None`.
    /// * `ExecResult` / `ExecRows` — executes the statement and returns
    ///   `cursor.rowcount`.
    ///
    /// Every parameter after `conn` is keyword-only in the generated signature.
    /// `return` statements longer than 88 bytes are wrapped with one field per line.
    ///
    /// # Panics
    ///
    /// Panics on `QueryCommand::Grouped`, which is expected to have been rewritten to
    /// `Many` before code generation.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        // Longest `return ...` line that is still emitted on a single line.
        const MAX_ONELINER_LEN: usize = 88;

        // Escape `sql` so it can sit between the `"""` delimiters of a non-raw Python
        // string literal and evaluate to exactly the original text.
        //
        // Backslashes are doubled. A double quote is escaped only where it would
        // otherwise terminate the literal: the third quote of a run, or a quote at
        // the very end (which would merge with the closing delimiter). SQL without
        // such sequences is emitted unchanged.
        fn escape_for_triple_quotes(sql: &str) -> String {
            let mut escaped = String::with_capacity(sql.len());
            // Count of unescaped `"` characters currently at the tail of `escaped`.
            let mut quote_run = 0usize;
            for ch in sql.chars() {
                match ch {
                    '\\' => {
                        escaped.push_str("\\\\");
                        quote_run = 0;
                    }
                    '"' if quote_run == 2 => {
                        escaped.push_str("\\\"");
                        quote_run = 0;
                    }
                    '"' => {
                        escaped.push('"');
                        quote_run += 1;
                    }
                    other => {
                        escaped.push(other);
                        quote_run = 0;
                    }
                }
            }
            if quote_run > 0 {
                // The text ends in a raw quote; escape it so it cannot fuse with the
                // closing `"""`.
                escaped.pop();
                escaped.push_str("\\\"");
            }
            escaped
        }

        // Render `conn.execute("""<sql>""")`, adding `, <args>` when `args` is non-empty.
        fn execute_call(sql: &str, args: &str) -> String {
            if args.is_empty() {
                format!("conn.execute(\"\"\"{}\"\"\")", sql)
            } else {
                format!("conn.execute(\"\"\"{}\"\"\", {})", sql, args)
            }
        }

        let func_name = fn_name(&analyzed.name, &self.manifest.naming);

        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        // `*` makes every query parameter keyword-only in the generated signature.
        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };

        let raw_sql = super::rewrite_pg_placeholders(
            &super::clean_sql_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |_| "?".to_string(),
        );
        let sql = escape_for_triple_quotes(&raw_sql);

        // Python tuple literal of the bound arguments; a single element needs the
        // trailing comma to be a tuple. Empty exactly when there are no parameters.
        let args_list = match params {
            [] => String::new(),
            [only] => format!("({},)", only.field_name),
            several => {
                let names: Vec<String> = several.iter().map(|p| p.field_name.clone()).collect();
                format!("({})", names.join(", "))
            }
        };
        let call = execute_call(&sql, &args_list);

        // `async def` line plus docstring shared by every non-batch command.
        let signature = |return_type: &str| -> String {
            format!(
                "async def {}(conn: aiosqlite.Connection{}{}) -> {}:\n    \"\"\"Execute {} query.\"\"\"\n",
                func_name, kw_sep, param_list, return_type, analyzed.name
            )
        };
        // `field=<row_var>[i]` keyword arguments mapping tuple positions to fields.
        let constructor_fields = |row_var: &str| -> Vec<String> {
            columns
                .iter()
                .enumerate()
                .map(|(i, col)| format!("{}={}[{}]", col.field_name, row_var, i))
                .collect()
        };

        let mut out = String::new();

        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                out.push_str(&signature(&format!("{} | None", struct_name)));
                let _ = writeln!(out, "    async with {} as cursor:", call);
                out.push_str("        row = await cursor.fetchone()\n");
                out.push_str("    if row is None:\n");
                out.push_str("        return None\n");

                let fields = constructor_fields("row");
                let oneliner = format!("    return {}({})", struct_name, fields.join(", "));
                if oneliner.len() <= MAX_ONELINER_LEN {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, "    return {}(", struct_name);
                    for field in &fields {
                        let _ = writeln!(out, "        {},", field);
                    }
                    out.push_str("    )\n");
                }
            }
            QueryCommand::Batch => {
                // Without parameters there is nothing to iterate over, so the caller
                // supplies how many times to run the statement instead.
                let (param_name, items_type) = match params {
                    [] => ("count", "int".to_string()),
                    [only] => ("items", format!("list[{}]", only.full_type)),
                    several => {
                        let tuple_types: Vec<String> =
                            several.iter().map(|p| p.full_type.clone()).collect();
                        ("items", format!("list[tuple[{}]]", tuple_types.join(", ")))
                    }
                };
                let _ = writeln!(
                    out,
                    "async def {}_batch(conn: aiosqlite.Connection, *, {}: {}) -> None:",
                    func_name, param_name, items_type
                );
                let _ = writeln!(
                    out,
                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
                    analyzed.name
                );
                match params.len() {
                    0 => {
                        out.push_str("    for _ in range(count):\n");
                        let _ = writeln!(out, "        await {}", call);
                    }
                    1 => {
                        // Bare items are wrapped into 1-tuples for `executemany`.
                        let _ = writeln!(
                            out,
                            "    await conn.executemany(\"\"\"{}\"\"\", [(item,) for item in items])",
                            sql
                        );
                    }
                    _ => {
                        let _ = writeln!(
                            out,
                            "    await conn.executemany(\"\"\"{}\"\"\", items)",
                            sql
                        );
                    }
                }
                out.push_str("    await conn.commit()\n");
            }
            QueryCommand::Many => {
                out.push_str(&signature(&format!("list[{}]", struct_name)));
                let _ = writeln!(out, "    async with {} as cursor:", call);
                out.push_str("        rows = await cursor.fetchall()\n");

                let fields = constructor_fields("r");
                let oneliner = format!(
                    "    return [{}({}) for r in rows]",
                    struct_name,
                    fields.join(", ")
                );
                if oneliner.len() <= MAX_ONELINER_LEN {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    out.push_str("    return [\n");
                    let _ = writeln!(out, "        {}(", struct_name);
                    for field in &fields {
                        let _ = writeln!(out, "            {},", field);
                    }
                    out.push_str("        )\n");
                    out.push_str("        for r in rows\n");
                    out.push_str("    ]\n");
                }
            }
            QueryCommand::Exec => {
                out.push_str(&signature("None"));
                let _ = writeln!(out, "    await {}", call);
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                out.push_str(&signature("int"));
                let _ = writeln!(out, "    cursor = await {}", call);
                out.push_str("    return cursor.rowcount\n");
            }
            QueryCommand::Grouped => {
                unreachable!("Grouped is rewritten to Many before codegen")
            }
        }

        Ok(out)
    }

    /// Render a Python `str`-backed `Enum` class for a database enum type.
    ///
    /// The class name comes from `enum_type_name` and each member name from
    /// `enum_variant_name`; the member value is the SQL value itself. An enum without
    /// values gets a `pass` body.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let class_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut code = format!(
            "class {}(str, Enum):\n    \"\"\"Database enum type {}.\"\"\"\n",
            class_name, enum_info.sql_name
        );

        if enum_info.values.is_empty() {
            code.push_str("    pass\n");
            return Ok(code);
        }

        // Blank line between the docstring and the members.
        code.push('\n');
        for raw in &enum_info.values {
            let member = enum_variant_name(raw, &self.manifest.naming);
            let _ = writeln!(code, "    {} = \"{}\"", member, raw);
        }
        Ok(code)
    }

    /// Render the Python class for a composite type.
    ///
    /// The class name is the PascalCase form of the SQL name; each field becomes a
    /// snake_case attribute annotated with the type `resolve_type` yields for it.
    /// A composite without fields gets a `pass` body.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` when a field's type cannot be resolved.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let class_name = to_pascal_case(&composite.sql_name);
        let mut code = format!(
            "{}{}\n    \"\"\"Composite type {}.\"\"\"\n",
            self.row_type.decorator(),
            self.row_type.class_def(&class_name),
            composite.sql_name
        );

        if composite.fields.is_empty() {
            code.push_str("    pass\n");
            return Ok(code);
        }

        // Blank line between the docstring and the attribute list.
        code.push('\n');
        for field in &composite.fields {
            let py_type = match resolve_type(&field.neutral_type, &self.manifest, false) {
                Ok(resolved) => resolved.into_owned(),
                Err(e) => {
                    return Err(ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    ));
                }
            };
            let _ = writeln!(code, "    {}: {}", to_snake_case(&field.name), py_type);
        }
        Ok(code)
    }
}