use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
enum_type_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::backends::typescript_common::{TsRowType, generate_zod_row_struct};
use crate::singularize;
/// Maps a neutral type name to the `sql.*` type expression emitted as the
/// second argument of `request.input(...)` in generated code.
///
/// Any name that is not listed in the table falls back to `"sql.VarChar"`.
fn neutral_to_sql_type(neutral_type: &str) -> &'static str {
    // (neutral type name, driver type expression) pairs; names are unique.
    const SQL_TYPES: &[(&str, &str)] = &[
        ("int16", "sql.SmallInt"),
        ("int32", "sql.Int"),
        ("int64", "sql.BigInt"),
        ("float32", "sql.Real"),
        ("float64", "sql.Float"),
        ("numeric", "sql.VarChar"),
        ("decimal", "sql.VarChar"),
        ("bool", "sql.Bit"),
        ("string", "sql.NVarChar"),
        ("text", "sql.Text"),
        ("date", "sql.Date"),
        ("datetime", "sql.DateTime"),
        ("datetime_tz", "sql.DateTimeOffset"),
        ("uuid", "sql.UniqueIdentifier"),
        ("binary", "sql.Binary"),
    ];

    SQL_TYPES
        .iter()
        .find(|(neutral, _)| *neutral == neutral_type)
        .map_or("sql.VarChar", |&(_, sql_type)| sql_type)
}
/// Manifest TOML embedded at compile time from
/// `../../manifests/typescript-mssql.toml`; handed to
/// `load_or_default_manifest` as the default in `TypescriptMssqlBackend::new`.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-mssql.toml");
/// The `typescript-mssql` code generation backend: emits TypeScript that
/// imports `sql` from the `mssql` package.
pub struct TypescriptMssqlBackend {
// Backend manifest; its `naming` section drives generated identifiers and it
// is passed to `resolve_type` for composite fields.
manifest: BackendManifest,
// Row output style; `TsRowType::Zod` switches row structs and enums to the
// zod generators and adds the zod import to the file header.
row_type: TsRowType,
}
impl TypescriptMssqlBackend {
    /// Creates the backend for `engine`.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` when `engine` is anything other
    /// than `"mssql"`, and propagates any failure reported by
    /// `load_or_default_manifest`.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Guard clause: this backend targets exactly one engine.
        if engine != "mssql" {
            let message = format!(
                "typescript-mssql only supports MSSQL, got engine '{}'",
                engine
            );
            return Err(ScytheError::new(ErrorCode::InternalError, message));
        }

        let row_type = TsRowType::default();
        let manifest = super::load_or_default_manifest(
            "backends/typescript-mssql/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )?;

        Ok(Self { manifest, row_type })
    }
}
/// Writes the opening line(s) of an exported async TypeScript function.
///
/// `args` holds one `name: type` entry per argument. The signature goes on a
/// single line when that line is at most 80 bytes long; otherwise every
/// argument is written on its own tab-indented line with a trailing comma.
/// Both layouts are produced from the same `args`, so they always declare the
/// same parameters.
fn write_ts_fn_signature(out: &mut String, name: &str, args: &[String], ret: &str) {
    let oneliner = format!(
        "export async function {}({}): Promise<{}> {{",
        name,
        args.join(", "),
        ret
    );
    if oneliner.len() <= 80 {
        let _ = writeln!(out, "{}", oneliner);
        return;
    }

    let _ = writeln!(out, "export async function {}(", name);
    for arg in args {
        let _ = writeln!(out, "\t{},", arg);
    }
    let _ = writeln!(out, "): Promise<{}> {{", ret);
}

/// Writes one `request.input("pN", <sql type>, <value>);` line per parameter.
///
/// `N` is the 1-based position of the parameter in `params`, mirroring the
/// `@pN` placeholders produced for the query text. `value_prefix` is prepended
/// to each field name (`"item."` inside batch loops, empty otherwise).
fn write_request_inputs(
    out: &mut String,
    indent: &str,
    value_prefix: &str,
    params: &[ResolvedParam],
) {
    for (index, param) in params.iter().enumerate() {
        let _ = writeln!(
            out,
            "{}request.input(\"p{}\", {}, {}{});",
            indent,
            index + 1,
            neutral_to_sql_type(&param.neutral_type),
            value_prefix,
            param.field_name
        );
    }
}

/// Writes the body of a batch function, including its closing brace.
///
/// The emitted code begins a transaction, runs `sql` once per iteration of
/// `loop_header` on a fresh transaction request, commits afterwards, and rolls
/// back and rethrows on any error. `bindings` is the pre-rendered block of
/// `request.input(...)` lines (possibly empty) placed before the query call.
fn write_batch_transaction(
    out: &mut String,
    loop_header: &str,
    bindings: &str,
    sql: &dyn std::fmt::Display,
) {
    let _ = writeln!(out, "\tconst transaction = pool.transaction();");
    let _ = writeln!(out, "\tawait transaction.begin();");
    let _ = writeln!(out, "\ttry {{");
    let _ = writeln!(out, "\t\t{} {{", loop_header);
    let _ = writeln!(out, "\t\t\tconst request = transaction.request();");
    out.push_str(bindings);
    let _ = writeln!(out, "\t\t\tawait request.query(`{}`);", sql);
    let _ = writeln!(out, "\t\t}}");
    let _ = writeln!(out, "\t\tawait transaction.commit();");
    let _ = writeln!(out, "\t}} catch (e) {{");
    let _ = writeln!(out, "\t\tawait transaction.rollback();");
    let _ = writeln!(out, "\t\tthrow e;");
    let _ = writeln!(out, "\t}}");
    let _ = write!(out, "}}");
}

impl CodegenBackend for TypescriptMssqlBackend {
    /// Returns the backend identifier, `"typescript-mssql"`.
    fn name(&self) -> &str {
        "typescript-mssql"
    }

    /// Returns the manifest loaded when the backend was constructed.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Returns the engines this backend accepts: only `"mssql"`.
    fn supported_engines(&self) -> &[&str] {
        &["mssql"]
    }

    /// Returns the header for a generated file: the auto-generated banner and
    /// the `mssql` import, plus the `zod` import when the row type is Zod.
    fn file_header(&self) -> String {
        let mut header =
            "/** Auto-generated by scythe. Do not edit. */\n\nimport sql from \"mssql\";\n"
                .to_string();
        if self.row_type == TsRowType::Zod {
            header.push_str("import { z } from \"zod\";\n");
        }
        header
    }

    /// Generates the row type for `query_name`.
    ///
    /// Delegates to `generate_zod_row_struct` when the row type is Zod;
    /// otherwise emits an `export interface` with one `field: type;` member
    /// per column and no trailing newline.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        if self.row_type == TsRowType::Zod {
            return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
        }

        let mut out = String::new();
        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
        let _ = writeln!(out, "export interface {} {{", struct_name);
        for col in columns {
            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates the model type for a table by singularizing `table_name`,
    /// converting it to PascalCase, and reusing `generate_row_struct`.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    /// Generates the TypeScript function that runs one analyzed query.
    ///
    /// Placeholders in the query text are rewritten to `@pN`, and every
    /// parameter is bound with `request.input("pN", ...)` using its 1-based
    /// position in `params`. The shape of the function depends on the query
    /// command:
    ///
    /// * `One` / `Opt` return the first record or `null`.
    /// * `Many` returns the whole recordset.
    /// * `Exec` returns nothing.
    /// * `ExecResult` / `ExecRows` return `rowsAffected[0] ?? 0`.
    /// * `Batch` emits a `<name>Batch` function that runs the query once per
    ///   item inside a transaction. It takes `items: <Struct>BatchParams[]`
    ///   (with the interface emitted first) for several parameters,
    ///   `items: <type>[]` for exactly one, and `count: number` for none.
    ///
    /// # Panics
    ///
    /// Panics on `QueryCommand::Grouped`, which is expected to have been
    /// rewritten to `Many` before code generation.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        const POOL_ARG: &str = "pool: sql.ConnectionPool";

        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |n| format!("@p{n}"),
        );

        // Argument list of the non-batch functions: the pool, then every
        // query parameter in order.
        let fn_args: Vec<String> = std::iter::once(POOL_ARG.to_string())
            .chain(
                params
                    .iter()
                    .map(|p| format!("{}: {}", p.field_name, p.full_type)),
            )
            .collect();

        // Shared opening of every non-batch function: signature, request
        // creation, and parameter bindings.
        let open_request_fn = |out: &mut String, ret: &str| {
            write_ts_fn_signature(out, &func_name, &fn_args, ret);
            let _ = writeln!(out, "\tconst request = pool.request();");
            write_request_inputs(out, "\t", "", params);
        };

        let mut out = String::new();
        match &analyzed.command {
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
                open_request_fn(&mut out, &format!("{} | null", struct_name));
                let _ = writeln!(
                    out,
                    "\tconst result = await request.query<{}>(`{}`);",
                    struct_name, sql
                );
                let _ = writeln!(out, "\treturn result.recordset[0] ?? null;");
                let _ = write!(out, "}}");
            }
            QueryCommand::Batch => {
                let batch_fn_name = format!("{}Batch", func_name);

                // Each arity differs only in the extra argument, the loop
                // header, and how values are bound inside the loop.
                let (extra_arg, loop_header, bindings) = match params {
                    [] => (
                        "count: number".to_string(),
                        "for (let i = 0; i < count; i++)",
                        String::new(),
                    ),
                    [only] => (
                        format!("items: {}[]", only.full_type),
                        "for (const item of items)",
                        // A single parameter is passed as the item itself.
                        format!(
                            "\t\t\trequest.input(\"p1\", {}, item);\n",
                            neutral_to_sql_type(&only.neutral_type)
                        ),
                    ),
                    _ => {
                        // Several parameters are bundled into an interface
                        // that is emitted ahead of the function.
                        let params_type_name = format!("{}BatchParams", struct_name);
                        let _ =
                            writeln!(out, "/** Params for {} batch operation. */", struct_name);
                        let _ = writeln!(out, "export interface {} {{", params_type_name);
                        for p in params {
                            let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
                        }
                        let _ = writeln!(out, "}}");
                        let _ = writeln!(out);

                        let mut bindings = String::new();
                        write_request_inputs(&mut bindings, "\t\t\t", "item.", params);
                        (
                            format!("items: {}[]", params_type_name),
                            "for (const item of items)",
                            bindings,
                        )
                    }
                };

                let _ = writeln!(
                    out,
                    "/** Execute {} for each item in the batch. */",
                    analyzed.name
                );
                // The signature is built from the batch arguments themselves,
                // so the wrapped layout declares `items`/`count` exactly like
                // the one-line layout does.
                write_ts_fn_signature(
                    &mut out,
                    &batch_fn_name,
                    &[POOL_ARG.to_string(), extra_arg],
                    "void",
                );
                write_batch_transaction(&mut out, loop_header, &bindings, &sql);
            }
            QueryCommand::Many => {
                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
                open_request_fn(&mut out, &format!("{}[]", struct_name));
                let _ = writeln!(
                    out,
                    "\tconst result = await request.query<{}>(`{}`);",
                    struct_name, sql
                );
                let _ = writeln!(out, "\treturn result.recordset;");
                let _ = write!(out, "}}");
            }
            QueryCommand::Exec => {
                let _ = writeln!(out, "/** Execute a query returning no rows. */");
                open_request_fn(&mut out, "void");
                let _ = writeln!(out, "\tawait request.query(`{}`);", sql);
                let _ = write!(out, "}}");
            }
            QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "/** Execute a query and return the number of affected rows. */"
                );
                open_request_fn(&mut out, "number");
                let _ = writeln!(out, "\tconst result = await request.query(`{}`);", sql);
                let _ = writeln!(out, "\treturn result.rowsAffected[0] ?? 0;");
                let _ = write!(out, "}}");
            }
        }
        Ok(out)
    }

    /// Generates the TypeScript definition for an enum.
    ///
    /// Delegates to `generate_zod_enum` when the row type is Zod; otherwise
    /// emits `export type <Name> = "a" | "b" | ...;` from the enum's values.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        if self.row_type == TsRowType::Zod {
            return Ok(super::typescript_common::generate_zod_enum(
                &type_name,
                &enum_info.values,
            ));
        }

        let variants: Vec<String> = enum_info
            .values
            .iter()
            .map(|v| format!("\"{}\"", v))
            .collect();
        let mut out = String::new();
        let _ = write!(out, "export type {} = {};", type_name, variants.join(" | "));
        Ok(out)
    }

    /// Generates an `export interface` for a composite type, with one
    /// camelCase member per field.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` when `resolve_type` fails for a
    /// field's neutral type.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
        let _ = writeln!(out, "export interface {} {{", name);
        for field in &composite.fields {
            let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Applies backend options. Only `row_type` is read; other keys are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TsRowType::from_option` when the `row_type`
    /// value is rejected.
    fn apply_options(
        &mut self,
        options: &std::collections::HashMap<String, String>,
    ) -> Result<(), ScytheError> {
        if let Some(value) = options.get("row_type") {
            self.row_type = TsRowType::from_option(value)?;
        }
        Ok(())
    }
}