use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::collections::HashMap;
use std::fmt::Write;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
use super::python_common::PythonRowType;
// Manifest baked into the binary at compile time; used as a fallback when the
// on-disk `backends/python-pyodbc/manifest.toml` is not available (see `new`).
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-pyodbc.toml");
/// Code generator that emits Python source targeting the `pyodbc` driver.
/// MSSQL only — enforced in [`PythonPyodbcBackend::new`].
pub struct PythonPyodbcBackend {
    // Type mappings and naming conventions, loaded from TOML in `new`.
    manifest: BackendManifest,
    // Flavor of generated row classes (decorator/class-def/imports are all
    // delegated to this). Starts at `PythonRowType::default()`; overridable
    // via the `row_type` option in `apply_options`.
    row_type: PythonRowType,
}
impl PythonPyodbcBackend {
    /// Constructs the pyodbc backend for `engine`.
    ///
    /// Only `"mssql"` is accepted; any other engine is rejected with an
    /// `InternalError`. The backend manifest is read from
    /// `backends/python-pyodbc/manifest.toml` when present, otherwise the
    /// compiled-in default TOML is used.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Guard clause: this backend is MSSQL-specific.
        if engine != "mssql" {
            let msg = format!("python-pyodbc only supports MSSQL, got engine '{}'", engine);
            return Err(ScytheError::new(ErrorCode::InternalError, msg));
        }

        super::load_or_default_manifest(
            "backends/python-pyodbc/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )
        .map(|manifest| Self {
            manifest,
            row_type: PythonRowType::default(),
        })
    }
}
impl CodegenBackend for PythonPyodbcBackend {
    /// Backend identifier used to select this generator.
    fn name(&self) -> &str {
        "python-pyodbc"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Restricted to SQL Server — mirrors the engine check in `new`.
    fn supported_engines(&self) -> &[&str] {
        &["mssql"]
    }

    /// Applies user-supplied backend options.
    ///
    /// Only the `row_type` key is recognized (parsed by
    /// `PythonRowType::from_option`); unknown keys are silently ignored.
    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
        if let Some(rt) = options.get("row_type") {
            self.row_type = PythonRowType::from_option(rt)?;
        }
        Ok(())
    }

    /// Emits the module preamble for every generated file: a "do not edit"
    /// docstring plus the import block. Import layout depends on whether the
    /// configured row type's import is stdlib or third-party, so that the
    /// generated file keeps stdlib and third-party imports in separate groups.
    fn file_header(&self) -> String {
        let import_line = self.row_type.import_line();
        if self.row_type.is_stdlib_import() {
            // Stdlib row type: its import joins the stdlib group, and pyodbc
            // forms its own third-party group below.
            format!(
                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
                \n\
                import datetime # noqa: F401\n\
                import decimal # noqa: F401\n\
                {import_line}\n\
                from enum import Enum # noqa: F401\n\
                \n\
                import pyodbc # noqa: F401\n\
                \n",
            )
        } else {
            // Third-party row type: merge its import(s) with pyodbc into one
            // sorted third-party group.
            let third_party = self
                .row_type
                .sorted_third_party_imports("import pyodbc # noqa: F401");
            format!(
                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
                \n\
                import datetime # noqa: F401\n\
                import decimal # noqa: F401\n\
                from enum import Enum # noqa: F401\n\
                \n\
                {third_party}\n\
                \n",
            )
        }
    }

    /// Renders one Python class representing a query's result row.
    ///
    /// The class flavor (decorator and class definition line) is delegated to
    /// `self.row_type`; an empty column list yields a `pass` body so the
    /// generated class is still valid Python.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        let mut out = String::new();
        // Writing into a String is infallible, so the Results are discarded.
        let _ = write!(out, "{}", self.row_type.decorator());
        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
        let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
        if columns.is_empty() {
            let _ = writeln!(out, " pass");
        } else {
            let _ = writeln!(out);
            for col in columns {
                // One annotated attribute per resolved column.
                let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
            }
        }
        Ok(out)
    }

    /// Table-backed model class: the table name is singularized and
    /// PascalCased (e.g. a plural table name becomes a singular class name),
    /// then reuses the row-struct rendering path.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let name = to_pascal_case(&singular);
        self.generate_row_struct(&name, columns)
    }

    /// Emits one Python function per analyzed query. The body shape
    /// (fetchone / fetchall / per-item loop / commit + rowcount) is chosen by
    /// the query's command verb.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();
        // Typed Python parameter list, e.g. "user_id: int, name: str".
        let param_list = params
            .iter()
            .map(|p| format!("{}: {}", p.field_name, p.full_type))
            .collect::<Vec<_>>()
            .join(", ");
        // Query parameters are keyword-only in the generated signature, hence
        // the bare "*" separator after the connection argument.
        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
        // Strip optional-parameter scaffolding from the SQL, then rewrite
        // Postgres-style placeholders to pyodbc's positional "?" (qmark) style
        // — the closure ignores the placeholder index on purpose.
        let sql = super::rewrite_pg_placeholders(
            &super::clean_sql_with_optional(
                &analyzed.sql,
                &analyzed.optional_params,
                &analyzed.params,
            ),
            |_| "?".to_string(),
        );
        // Python tuple literal handed to cursor.execute(). A one-element
        // tuple needs the trailing comma to remain a tuple.
        let args_tuple = if params.is_empty() {
            String::new()
        } else {
            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
            if args.len() == 1 {
                format!("({},)", args[0])
            } else {
                format!("({})", args.join(", "))
            }
        };
        match &analyzed.command {
            // Single-row queries: both One and Opt map to `T | None` here,
            // returning None when no row is found.
            QueryCommand::One | QueryCommand::Opt => {
                let _ = writeln!(
                    out,
                    "def {}(conn: pyodbc.Connection{}{}) -> {} | None:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, " cursor = conn.cursor()");
                if params.is_empty() {
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\")", sql);
                } else {
                    let _ = writeln!(
                        out,
                        " cursor.execute(\"\"\"{}\"\"\", {})",
                        sql, args_tuple
                    );
                }
                let _ = writeln!(out, " row = cursor.fetchone()");
                let _ = writeln!(out, " if row is None:");
                let _ = writeln!(out, " return None");
                // Positional row indexing: pyodbc rows are indexable, and
                // column order matches the resolved column order.
                let field_assignments: Vec<String> = columns
                    .iter()
                    .enumerate()
                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
                    .collect();
                let oneliner = format!(
                    " return {}({})",
                    struct_name,
                    field_assignments.join(", ")
                );
                // Keep the generated line within 88 columns (presumably to
                // match Black's default width); otherwise one field per line.
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, " return {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, " {},", fa);
                    }
                    let _ = writeln!(out, " )");
                }
            }
            // Batch execution: the generated function is named `<name>_batch`
            // and runs the statement once per item, committing at the end.
            QueryCommand::Batch => {
                let batch_fn_name = format!("{}_batch", func_name);
                // Item type: list of tuples (multi-param), list of scalars
                // (single param), or a bare repetition count (no params).
                let items_type = if params.len() > 1 {
                    let tuple_types: Vec<String> =
                        params.iter().map(|p| p.full_type.clone()).collect();
                    format!("list[tuple[{}]]", tuple_types.join(", "))
                } else if params.len() == 1 {
                    format!("list[{}]", params[0].full_type)
                } else {
                    "int".to_string()
                };
                let items_or_count = if params.is_empty() { "count" } else { "items" };
                let _ = writeln!(
                    out,
                    "def {}(conn: pyodbc.Connection, *, {}: {}) -> None:",
                    batch_fn_name, items_or_count, items_type
                );
                let _ = writeln!(
                    out,
                    " \"\"\"Execute {} query for each item in the batch.\"\"\"",
                    analyzed.name
                );
                let _ = writeln!(out, " cursor = conn.cursor()");
                if params.is_empty() {
                    let _ = writeln!(out, " for _ in range(count):");
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\")", sql);
                } else if params.len() == 1 {
                    // Single scalar param must still be wrapped in a tuple.
                    let _ = writeln!(out, " for item in items:");
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\", (item,))", sql);
                } else {
                    // Multi-param: each item is already a tuple.
                    let _ = writeln!(out, " for item in items:");
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\", item)", sql);
                }
                // One commit after the whole batch, not per item.
                let _ = writeln!(out, " conn.commit()");
            }
            // Multi-row queries: fetchall + list comprehension into row structs.
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    "def {}(conn: pyodbc.Connection{}{}) -> list[{}]:",
                    func_name, kw_sep, param_list, struct_name
                );
                let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, " cursor = conn.cursor()");
                if params.is_empty() {
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\")", sql);
                } else {
                    let _ = writeln!(
                        out,
                        " cursor.execute(\"\"\"{}\"\"\", {})",
                        sql, args_tuple
                    );
                }
                let _ = writeln!(out, " rows = cursor.fetchall()");
                let field_assignments: Vec<String> = columns
                    .iter()
                    .enumerate()
                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
                    .collect();
                let oneliner = format!(
                    " return [{}({}) for r in rows]",
                    struct_name,
                    field_assignments.join(", ")
                );
                // Same 88-column budget as the single-row case; the long form
                // expands the comprehension across multiple lines.
                if oneliner.len() <= 88 {
                    let _ = writeln!(out, "{}", oneliner);
                } else {
                    let _ = writeln!(out, " return [");
                    let _ = writeln!(out, " {}(", struct_name);
                    for fa in &field_assignments {
                        let _ = writeln!(out, " {},", fa);
                    }
                    let _ = writeln!(out, " )");
                    let _ = writeln!(out, " for r in rows");
                    let _ = writeln!(out, " ]");
                }
            }
            // Fire-and-forget statements: execute + commit, no result.
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    "def {}(conn: pyodbc.Connection{}{}) -> None:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, " cursor = conn.cursor()");
                if params.is_empty() {
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\")", sql);
                } else {
                    let _ = writeln!(
                        out,
                        " cursor.execute(\"\"\"{}\"\"\", {})",
                        sql, args_tuple
                    );
                }
                let _ = writeln!(out, " conn.commit()");
            }
            QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
            // Statements whose affected-row count matters: execute, commit,
            // then surface pyodbc's cursor.rowcount.
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    "def {}(conn: pyodbc.Connection{}{}) -> int:",
                    func_name, kw_sep, param_list
                );
                let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
                let _ = writeln!(out, " cursor = conn.cursor()");
                if params.is_empty() {
                    let _ = writeln!(out, " cursor.execute(\"\"\"{}\"\"\")", sql);
                } else {
                    let _ = writeln!(
                        out,
                        " cursor.execute(\"\"\"{}\"\"\", {})",
                        sql, args_tuple
                    );
                }
                let _ = writeln!(out, " conn.commit()");
                let _ = writeln!(out, " return cursor.rowcount");
            }
        }
        Ok(out)
    }

    /// Renders a database enum as a Python `class X(str, Enum)` so members
    /// compare equal to their string values. An enum with no values gets a
    /// `pass` body to remain valid Python.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::new();
        let _ = writeln!(out, "class {}(str, Enum):", type_name);
        let _ = writeln!(
            out,
            " \"\"\"Database enum type {}.\"\"\"",
            enum_info.sql_name
        );
        if enum_info.values.is_empty() {
            let _ = writeln!(out, " pass");
        } else {
            let _ = writeln!(out);
            for value in &enum_info.values {
                // Variant identifier per manifest naming; value stays the raw
                // database string.
                let variant = enum_variant_name(value, &self.manifest.naming);
                let _ = writeln!(out, " {} = \"{}\"", variant, value);
            }
        }
        Ok(out)
    }

    /// Renders a SQL composite type as a Python class (same row-type flavor
    /// as query rows). Each field's neutral type is resolved through the
    /// manifest; resolution failures are surfaced as `InternalError`.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        // NOTE(review): composites use plain PascalCase here, not a manifest
        // naming hook like enums/rows — confirm that is intentional.
        let name = to_pascal_case(&composite.sql_name);
        let mut out = String::new();
        let _ = write!(out, "{}", self.row_type.decorator());
        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
        let _ = writeln!(
            out,
            " \"\"\"Composite type {}.\"\"\"",
            composite.sql_name
        );
        if composite.fields.is_empty() {
            let _ = writeln!(out, " pass");
        } else {
            let _ = writeln!(out);
            for field in &composite.fields {
                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
                    .map(|t| t.into_owned())
                    .map_err(|e| {
                        ScytheError::new(
                            ErrorCode::InternalError,
                            format!("composite field type error: {}", e),
                        )
                    })?;
                let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
            }
        }
        Ok(out)
    }
}