use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sibyl.toml");
/// Code-generation backend that emits async Rust functions targeting the
/// `sibyl` Oracle driver.
pub struct RustSibylBackend {
    // Loaded in `new`; supplies the naming rules (`manifest.naming`) and the
    // type mappings consulted by `resolve_type`.
    manifest: BackendManifest,
}
impl RustSibylBackend {
    /// Creates the backend for `engine`.
    ///
    /// The manifest is loaded through `super::load_or_default_manifest` from
    /// `backends/rust-sibyl/manifest.toml`, with the bundled
    /// `manifests/rust-sibyl.toml` contents passed as the default (presumably
    /// used when the on-disk file is absent — see that helper).
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::InternalError` when `engine` is anything other than
    /// `"oracle"`, and propagates any error from loading the manifest.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        // Guard clause: this backend only understands Oracle.
        if engine != "oracle" {
            return Err(ScytheError::new(
                ErrorCode::InternalError,
                format!("rust-sibyl only supports Oracle, got engine '{}'", engine),
            ));
        }
        let manifest = super::load_or_default_manifest(
            "backends/rust-sibyl/manifest.toml",
            DEFAULT_MANIFEST_TOML,
        )?;
        Ok(Self { manifest })
    }
}
impl CodegenBackend for RustSibylBackend {
/// Identifier of this backend as exposed through `CodegenBackend`.
fn name(&self) -> &str {
    const BACKEND_NAME: &str = "rust-sibyl";
    BACKEND_NAME
}
fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
&self.manifest
}
/// Database engines this backend can generate code for (Oracle only).
fn supported_engines(&self) -> &[&str] {
    const ENGINES: &[&str] = &["oracle"];
    ENGINES
}
/// Preamble placed at the top of every generated file: a do-not-edit banner
/// followed by a glob import of the `sibyl` prelude.
fn file_header(&self) -> String {
    concat!(
        "// Auto-generated by scythe. Do not edit.\n",
        "use sibyl::prelude::*;\n",
    )
    .to_string()
}
/// Renders the result-row struct for a query.
///
/// The struct name is produced by `row_struct_name(query_name, ..)` using the
/// manifest's naming rules. Every column becomes a `pub` field typed with its
/// `full_type`. The struct derives `Debug` and `Clone`. The returned text has
/// no trailing newline after the closing brace.
fn generate_row_struct(
    &self,
    query_name: &str,
    columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
    let fields: String = columns
        .iter()
        .map(|col| format!(" pub {}: {},\n", col.field_name, col.full_type))
        .collect();
    Ok(format!(
        "#[derive(Debug, Clone)]\npub struct {} {{\n{}}}",
        row_struct_name(query_name, &self.manifest.naming),
        fields
    ))
}
/// Renders the model struct for a table.
///
/// The table name is singularized and PascalCased, then handed to
/// `generate_row_struct`. The row-struct naming rule (`row_struct_name`) is
/// therefore applied to model names as well.
fn generate_model_struct(
    &self,
    table_name: &str,
    columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
    let model_name = to_pascal_case(&singularize(table_name));
    self.generate_row_struct(&model_name, columns)
}
fn generate_query_fn(
&self,
analyzed: &AnalyzedQuery,
struct_name: &str,
columns: &[ResolvedColumn],
params: &[ResolvedParam],
) -> Result<String, ScytheError> {
let func_name = fn_name(&analyzed.name, &self.manifest.naming);
let sql = super::rewrite_pg_placeholders(
&super::clean_sql_with_optional(
&analyzed.sql,
&analyzed.optional_params,
&analyzed.params,
),
|n| format!(":{n}"),
);
let param_list = params
.iter()
.map(|p| format!("{}: {}", p.field_name, p.borrowed_type))
.collect::<Vec<_>>()
.join(", ");
let sep = if param_list.is_empty() { "" } else { ", " };
let has_returning = sql.to_uppercase().contains("RETURNING");
let mut out = String::new();
match &analyzed.command {
QueryCommand::One | QueryCommand::Opt => {
let _ = writeln!(
out,
"pub async fn {}<'a>(session: &'a Session<'a>{}{}) -> sibyl::Result<Option<{}>> {{",
func_name, sep, param_list, struct_name
);
if has_returning {
let into_placeholders: Vec<String> = (0..columns.len())
.map(|i| format!(":{}", params.len() + i + 1))
.collect();
let full_sql = format!("{} INTO {}", sql, into_placeholders.join(", "));
let _ = writeln!(
out,
" let stmt = session.prepare(\"{}\").await?;",
full_sql
);
for (i, param) in params.iter().enumerate() {
let _ = writeln!(
out,
" stmt.bind({}, &{}).await?;",
i + 1,
param.field_name
);
}
for (i, col) in columns.iter().enumerate() {
let slot = params.len() + i + 1;
let sibyl_type = match col.neutral_type.as_str() {
"int16" | "int32" | "int64" | "float32" | "float64" | "decimal" => {
"sibyl::Number"
}
"date" | "datetime" | "datetime_tz" => "sibyl::Date",
_ => "sibyl::Varchar",
};
let _ = writeln!(
out,
" let out_{}: {} = stmt.returning_into({}).await?;",
col.field_name, sibyl_type, slot
);
}
let _ = writeln!(out, " stmt.execute(\"\").await?;");
for col in columns {
let extract = match col.neutral_type.as_str() {
"int16" => format!(
" let {} = out_{}.to_int::<i16>()? as {};",
col.field_name, col.field_name, col.lang_type
),
"int32" => format!(
" let {} = out_{}.to_int::<i32>()? as {};",
col.field_name, col.field_name, col.lang_type
),
"int64" => format!(
" let {} = out_{}.to_int::<i64>()? as {};",
col.field_name, col.field_name, col.lang_type
),
"float32" | "float64" | "decimal" => format!(
" let {} = out_{}.to_float::<f64>()? as {};",
col.field_name, col.field_name, col.lang_type
),
"date" | "datetime" | "datetime_tz" => format!(
" let {} = out_{}.timestamp()? as {};",
col.field_name, col.field_name, col.lang_type
),
_ => format!(
" let {} = out_{}.as_str()?.to_string();",
col.field_name, col.field_name
),
};
let _ = writeln!(out, "{}", extract);
}
let field_assigns: Vec<String> = columns
.iter()
.map(|c| format!("{}: {}", c.field_name, c.field_name))
.collect();
let _ = writeln!(
out,
" Ok(Some({} {{ {} }}))",
struct_name,
field_assigns.join(", ")
);
let _ = write!(out, "}}");
} else {
let _ = writeln!(out, " let stmt = session.prepare(\"{}\").await?;", sql);
for (i, param) in params.iter().enumerate() {
let _ = writeln!(
out,
" stmt.bind({}, &{}).await?;",
i + 1,
param.field_name
);
}
let _ = writeln!(out, " let rows = stmt.query(\"\").await?;");
let _ = writeln!(out, " if let Some(row) = rows.next().await? {{");
for (i, col) in columns.iter().enumerate() {
let _ = writeln!(
out,
" let {} = row.get::<{}>({})?;",
col.field_name, col.lang_type, i
);
}
let field_assigns: Vec<String> = columns
.iter()
.map(|c| format!("{}: {}", c.field_name, c.field_name))
.collect();
let _ = writeln!(
out,
" Ok(Some({} {{ {} }}))",
struct_name,
field_assigns.join(", ")
);
let _ = writeln!(out, " }} else {{");
let _ = writeln!(out, " Ok(None)");
let _ = writeln!(out, " }}");
let _ = write!(out, "}}");
}
}
QueryCommand::Many => {
let _ = writeln!(
out,
"pub async fn {}<'a>(session: &'a Session<'a>{}{}) -> sibyl::Result<Vec<{}>> {{",
func_name, sep, param_list, struct_name
);
let _ = writeln!(out, " let stmt = session.prepare(\"{}\").await?;", sql);
for (i, param) in params.iter().enumerate() {
let _ = writeln!(
out,
" stmt.bind({}, &{}).await?;",
i + 1,
param.field_name
);
}
let _ = writeln!(out, " let rows = stmt.query(\"\").await?;");
let _ = writeln!(out, " let mut results = Vec::new();");
let _ = writeln!(out, " while let Some(row) = rows.next().await? {{");
for (i, col) in columns.iter().enumerate() {
let _ = writeln!(
out,
" let {} = row.get::<{}>({})?;",
col.field_name, col.lang_type, i
);
}
let field_assigns: Vec<String> = columns
.iter()
.map(|c| format!("{}: {}", c.field_name, c.field_name))
.collect();
let _ = writeln!(
out,
" results.push({} {{ {} }});",
struct_name,
field_assigns.join(", ")
);
let _ = writeln!(out, " }}");
let _ = writeln!(out, " Ok(results)");
let _ = write!(out, "}}");
}
QueryCommand::Exec => {
let _ = writeln!(
out,
"pub async fn {}<'a>(session: &'a Session<'a>{}{}) -> sibyl::Result<()> {{",
func_name, sep, param_list
);
let _ = writeln!(out, " let stmt = session.prepare(\"{}\").await?;", sql);
for (i, param) in params.iter().enumerate() {
let _ = writeln!(
out,
" stmt.bind({}, &{}).await?;",
i + 1,
param.field_name
);
}
let _ = writeln!(out, " stmt.execute(\"\").await?;");
let _ = writeln!(out, " Ok(())");
let _ = write!(out, "}}");
}
QueryCommand::ExecResult | QueryCommand::ExecRows => {
let _ = writeln!(
out,
"pub async fn {}<'a>(session: &'a Session<'a>{}{}) -> sibyl::Result<usize> {{",
func_name, sep, param_list
);
let _ = writeln!(out, " let stmt = session.prepare(\"{}\").await?;", sql);
for (i, param) in params.iter().enumerate() {
let _ = writeln!(
out,
" stmt.bind({}, &{}).await?;",
i + 1,
param.field_name
);
}
let _ = writeln!(out, " let num_rows = stmt.execute(\"\").await?;");
let _ = writeln!(out, " Ok(num_rows)");
let _ = write!(out, "}}");
}
QueryCommand::Batch => {
let batch_fn_name = format!("{}_batch", func_name);
let _ = writeln!(
out,
"pub async fn {}<'a>(session: &'a Session<'a>, items: &[({})]) -> sibyl::Result<()> {{",
batch_fn_name,
params
.iter()
.map(|p| p.full_type.clone())
.collect::<Vec<_>>()
.join(", ")
);
let _ = writeln!(out, " let stmt = session.prepare(\"{}\").await?;", sql);
let _ = writeln!(out, " for item in items {{");
for (i, _param) in params.iter().enumerate() {
let _ = writeln!(out, " stmt.bind({}, &item.{}).await?;", i + 1, i);
}
let _ = writeln!(out, " stmt.execute(\"\").await?;");
let _ = writeln!(out, " }}");
let _ = writeln!(out, " Ok(())");
let _ = write!(out, "}}");
}
QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
}
Ok(out)
}
/// Generates a Rust enum for a database enum type, plus an `as_str` method
/// that maps each variant back to its original database value.
///
/// The type name comes from `enum_type_name` and variant names from
/// `enum_variant_name`. The generated enum derives `Debug`, `Clone` and
/// `PartialEq`. Database values are emitted as escaped Rust string literals,
/// so quotes or backslashes inside a value cannot break the generated source.
fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
    let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
    // Pair every database value with its Rust variant name once, up front.
    let variants: Vec<_> = enum_info
        .values
        .iter()
        .map(|value| (value, enum_variant_name(value, &self.manifest.naming)))
        .collect();
    let mut out = String::new();
    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq)]");
    let _ = writeln!(out, "pub enum {} {{", type_name);
    for (_, variant) in &variants {
        let _ = writeln!(out, " {},", variant);
    }
    let _ = writeln!(out, "}}");
    let _ = writeln!(out);
    let _ = writeln!(out, "impl {} {{", type_name);
    let _ = writeln!(out, " pub fn as_str(&self) -> &'static str {{");
    // rustc rejects an arm-less `match` on `&Self` for a variant-less enum
    // ("type `&E` is non-empty"). Matching the uninhabited place `*self`
    // is accepted.
    let scrutinee = if variants.is_empty() { "*self" } else { "self" };
    let _ = writeln!(out, " match {} {{", scrutinee);
    for (value, variant) in &variants {
        // `{:?}` renders the value as a quoted, escaped Rust string literal.
        let _ = writeln!(out, " {}::{} => {:?},", type_name, variant, value);
    }
    let _ = writeln!(out, " }}");
    let _ = writeln!(out, " }}");
    let _ = write!(out, "}}");
    Ok(out)
}
/// Generates a plain Rust struct mirroring a composite (record) type.
///
/// The struct is named by PascalCasing the SQL type name and derives `Debug`
/// and `Clone`. Each field name is snake_cased, and each field type is
/// resolved from its neutral type through the manifest as non-nullable.
///
/// # Errors
///
/// Returns `ErrorCode::InternalError` ("composite field type error: ...")
/// when a field's neutral type cannot be resolved.
fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
    let mut out = format!(
        "#[derive(Debug, Clone)]\npub struct {} {{\n",
        to_pascal_case(&composite.sql_name)
    );
    for field in &composite.fields {
        let rust_type = match resolve_type(&field.neutral_type, &self.manifest, false) {
            Ok(resolved) => resolved.into_owned(),
            Err(e) => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    format!("composite field type error: {}", e),
                ))
            }
        };
        let _ = writeln!(out, " pub {}: {},", to_snake_case(&field.name), rust_type);
    }
    out.push('}');
    Ok(out)
}
}