use std::fmt::Write;
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{RbsGenerationContext, RbsQueryInfo, ResolvedColumn};
/// Translates a scythe neutral type name into the RBS type used in generated signatures.
///
/// * `array::<inner>` becomes `Array[<inner>]`; `<inner>` is translated recursively and is
///   never marked optional itself.
/// * `enum::<name>` is represented as `String`.
/// * Any name without an explicit mapping falls back to `untyped`.
///
/// When `nullable` is true a trailing `?` (RBS optional) is appended to the result.
fn neutral_to_rbs(neutral_type: &str, nullable: bool) -> String {
    // The `array::` prefix cannot collide with any of the exact names or the `enum::`
    // prefix below, so testing it first selects the same rule as a single flat match.
    let rbs_type = match neutral_type.strip_prefix("array::") {
        Some(element) => format!("Array[{}]", neutral_to_rbs(element, false)),
        None => {
            let scalar = match neutral_type {
                "int16" | "int32" | "int64" => "Integer",
                "float32" | "float64" => "Float",
                "decimal" => "BigDecimal",
                "bool" => "bool",
                "date" => "Date",
                "time" | "time_tz" | "datetime" | "datetime_tz" => "Time",
                "json" => "Hash[String, untyped]",
                "string" | "bytes" | "uuid" | "interval" | "inet" => "String",
                enum_type if enum_type.starts_with("enum::") => "String",
                _ => "untyped",
            };
            scalar.to_string()
        }
    };

    if nullable {
        rbs_type + "?"
    } else {
        rbs_type
    }
}
fn param_neutral_to_rbs(neutral_type: &str, nullable: bool) -> String {
neutral_to_rbs(neutral_type, nullable)
}
/// Renders the complete RBS signature file describing the generated `Queries` module.
///
/// The output starts with an auto-generated banner and opens `module Queries`. Inside it:
/// - every enum in `context.enums` becomes a nested module declaring one `String` constant
///   per value plus an `ALL: Array[String]` constant;
/// - every query in `context.queries` gets a row class (only when it has a `struct_name`
///   and at least one column) followed by its method signature.
///
/// Each enum module, row class and method signature is followed by a blank line, and the
/// file ends with the `end` that closes `Queries`. `connection_type` is the RBS type written
/// as the first argument of every query method (for example `PG::Connection`).
pub fn generate_rbs_content(context: &RbsGenerationContext, connection_type: &str) -> String {
    let mut rbs = String::new();

    // Banner, blank line, then open the wrapper module.
    let _ = writeln!(rbs, "# Auto-generated by scythe. Do not edit.");
    let _ = writeln!(rbs);
    let _ = writeln!(rbs, "module Queries");

    for enum_info in context.enums.iter() {
        let _ = writeln!(rbs, " module {}", enum_info.type_name);
        for constant in enum_info.values.iter() {
            let _ = writeln!(rbs, " {}: String", constant);
        }
        let _ = writeln!(rbs, " ALL: Array[String]");
        let _ = writeln!(rbs, " end");
        let _ = writeln!(rbs);
    }

    for query in context.queries.iter() {
        // A row class is only declared when there is a name for it and columns to list.
        let row_class = query
            .struct_name
            .as_deref()
            .filter(|_| !query.columns.is_empty());
        if let Some(class_name) = row_class {
            write_rbs_data_class(&mut rbs, class_name, &query.columns);
            let _ = writeln!(rbs);
        }

        write_rbs_method(&mut rbs, query, connection_type);
        let _ = writeln!(rbs);
    }

    // Close `module Queries`; the file ends with a single newline after `end`.
    let _ = writeln!(rbs, "end");
    rbs
}
/// Appends the RBS declaration of a query's row class named `struct_name` to `out`.
///
/// The class lists one `attr_reader` per column (in column order), then a `self.new`
/// signature that takes every column as a `field: Type` argument and returns the class.
/// Column types come from `neutral_to_rbs`, honouring each column's nullability.
fn write_rbs_data_class(out: &mut String, struct_name: &str, columns: &[ResolvedColumn]) {
    // Resolve each (field name, RBS type) pair once; the readers and constructor share it.
    let typed_fields: Vec<(&str, String)> = columns
        .iter()
        .map(|column| {
            (
                column.field_name.as_str(),
                neutral_to_rbs(&column.neutral_type, column.nullable),
            )
        })
        .collect();

    let _ = writeln!(out, " class {}", struct_name);

    for (field_name, rbs_type) in &typed_fields {
        let _ = writeln!(out, " attr_reader {}: {}", field_name, rbs_type);
    }

    let constructor_args = typed_fields
        .iter()
        .map(|(field_name, rbs_type)| format!("{}: {}", field_name, rbs_type))
        .collect::<Vec<String>>()
        .join(", ");
    let _ = writeln!(
        out,
        " def self.new: ({}) -> {}",
        constructor_args, struct_name
    );

    let _ = writeln!(out, " end");
}
/// Appends the RBS singleton-method signature for one generated query function to `out`.
///
/// Non-batch commands produce `def self.<func_name>: (<connection>, <params...>) -> <ret>`:
/// - `One` / `Opt`: `<Struct>?`, or `void` when the query has no `struct_name`;
/// - `Many` / `Grouped`: `Array[<Struct>]`, or `Array[untyped]` without a `struct_name`;
/// - `Exec`: `void`;
/// - `ExecResult` / `ExecRows`: `Integer`.
///
/// `Batch` produces only `def self.<func_name>_batch: (<connection>, <items>) -> void`, where
/// `<items>` is `Array[<T>]` for a single parameter, `Array[[<T1>, <T2>, ...]]` (an array of
/// tuples) for several, and `Array[untyped]` when the query has no parameters.
fn write_rbs_method(out: &mut String, query: &RbsQueryInfo, connection_type: &str) {
    // Map every SQL parameter to RBS exactly once; both signature shapes reuse the result.
    let param_types: Vec<String> = query
        .params
        .iter()
        .map(|p| param_neutral_to_rbs(&p.neutral_type, p.nullable))
        .collect();

    // Regular methods take the connection followed by each parameter positionally.
    let positional_args = || {
        std::iter::once(connection_type)
            .chain(param_types.iter().map(String::as_str))
            .collect::<Vec<&str>>()
            .join(", ")
    };

    // Each arm yields (method-name suffix, argument list, return type).
    let (name_suffix, arg_list, return_type) = match query.command {
        QueryCommand::One | QueryCommand::Opt => {
            let ret = match &query.struct_name {
                Some(sn) => format!("{}?", sn),
                // NOTE(review): a row-returning command without a row struct is declared
                // as returning nothing; confirm this matches the generated Ruby code.
                None => "void".to_string(),
            };
            ("", positional_args(), ret)
        }
        QueryCommand::Many | QueryCommand::Grouped => {
            let ret = match &query.struct_name {
                Some(sn) => format!("Array[{}]", sn),
                None => "Array[untyped]".to_string(),
            };
            ("", positional_args(), ret)
        }
        QueryCommand::Exec => ("", positional_args(), "void".to_string()),
        QueryCommand::ExecResult | QueryCommand::ExecRows => {
            ("", positional_args(), "Integer".to_string())
        }
        QueryCommand::Batch => {
            // One parameter: a flat array of values; several: an array of tuples.
            let items = match param_types.as_slice() {
                [] => "Array[untyped]".to_string(),
                [only] => format!("Array[{}]", only),
                all => format!("Array[[{}]]", all.join(", ")),
            };
            (
                "_batch",
                format!("{}, {}", connection_type, items),
                "void".to_string(),
            )
        }
    };

    let _ = writeln!(
        out,
        " def self.{}{}: ({}) -> {}",
        query.func_name, name_suffix, arg_list, return_type
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::{
        RbsEnumInfo, RbsGenerationContext, RbsQueryInfo, ResolvedColumn, ResolvedParam,
    };
    use scythe_core::parser::QueryCommand;

    /// Builds a result column whose Ruby field name equals its SQL name.
    fn col(name: &str, neutral_type: &str, nullable: bool) -> ResolvedColumn {
        ResolvedColumn {
            name: name.to_string(),
            field_name: name.to_string(),
            lang_type: String::new(),
            full_type: String::new(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// Builds a query parameter whose Ruby field name equals its SQL name.
    fn param(name: &str, neutral_type: &str, nullable: bool) -> ResolvedParam {
        ResolvedParam {
            name: name.to_string(),
            field_name: name.to_string(),
            lang_type: String::new(),
            full_type: String::new(),
            borrowed_type: String::new(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// Wraps a single query (and no enums) in a generation context.
    fn single_query_ctx(
        func_name: &str,
        struct_name: Option<&str>,
        columns: Vec<ResolvedColumn>,
        params: Vec<ResolvedParam>,
        command: QueryCommand,
    ) -> RbsGenerationContext {
        RbsGenerationContext {
            queries: vec![RbsQueryInfo {
                func_name: func_name.to_string(),
                struct_name: struct_name.map(str::to_string),
                columns,
                params,
                command,
            }],
            enums: vec![],
        }
    }

    #[test]
    fn test_neutral_to_rbs_scalars() {
        assert_eq!(neutral_to_rbs("int32", false), "Integer");
        assert_eq!(neutral_to_rbs("int64", false), "Integer");
        assert_eq!(neutral_to_rbs("string", false), "String");
        assert_eq!(neutral_to_rbs("bool", false), "bool");
        assert_eq!(neutral_to_rbs("float64", false), "Float");
        assert_eq!(neutral_to_rbs("decimal", false), "BigDecimal");
        assert_eq!(neutral_to_rbs("datetime_tz", false), "Time");
        assert_eq!(neutral_to_rbs("date", false), "Date");
        assert_eq!(neutral_to_rbs("uuid", false), "String");
        assert_eq!(neutral_to_rbs("json", false), "Hash[String, untyped]");
        assert_eq!(neutral_to_rbs("bytes", false), "String");
    }

    #[test]
    fn test_neutral_to_rbs_nullable() {
        assert_eq!(neutral_to_rbs("string", true), "String?");
        assert_eq!(neutral_to_rbs("int32", true), "Integer?");
        assert_eq!(neutral_to_rbs("bool", true), "bool?");
    }

    #[test]
    fn test_neutral_to_rbs_array() {
        assert_eq!(neutral_to_rbs("array::int32", false), "Array[Integer]");
        assert_eq!(neutral_to_rbs("array::string", true), "Array[String]?");
    }

    #[test]
    fn test_neutral_to_rbs_enum() {
        assert_eq!(neutral_to_rbs("enum::user_status", false), "String");
    }

    #[test]
    fn test_neutral_to_rbs_nested_array_and_fallback() {
        assert_eq!(
            neutral_to_rbs("array::array::int64", false),
            "Array[Array[Integer]]"
        );
        assert_eq!(neutral_to_rbs("array::enum::mood", true), "Array[String]?");
        assert_eq!(neutral_to_rbs("geometry", false), "untyped");
        assert_eq!(neutral_to_rbs("geometry", true), "untyped?");
    }

    #[test]
    fn test_param_mapping_matches_column_mapping() {
        for ty in ["int32", "string", "json", "array::uuid", "enum::mood", "unknown"] {
            for nullable in [false, true] {
                assert_eq!(
                    param_neutral_to_rbs(ty, nullable),
                    neutral_to_rbs(ty, nullable)
                );
            }
        }
    }

    #[test]
    fn test_generate_rbs_empty_context_framing() {
        let context = RbsGenerationContext {
            queries: vec![],
            enums: vec![],
        };
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert_eq!(
            rbs,
            "# Auto-generated by scythe. Do not edit.\n\nmodule Queries\nend\n"
        );
    }

    #[test]
    fn test_generate_rbs_one_query() {
        let context = single_query_ctx(
            "get_user",
            Some("GetUserRow"),
            vec![
                col("id", "int32", false),
                col("name", "string", false),
                col("email", "string", true),
            ],
            vec![param("id", "int32", false)],
            QueryCommand::One,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("module Queries"));
        assert!(rbs.contains("class GetUserRow"));
        assert!(rbs.contains("attr_reader id: Integer"));
        assert!(rbs.contains("attr_reader name: String"));
        assert!(rbs.contains("attr_reader email: String?"));
        assert!(
            rbs.contains("def self.new: (id: Integer, name: String, email: String?) -> GetUserRow")
        );
        assert!(rbs.contains("def self.get_user: (PG::Connection, Integer) -> GetUserRow?"));
        assert!(rbs.contains("end\n"));
    }

    #[test]
    fn test_generate_rbs_opt_query() {
        let context = single_query_ctx(
            "find_user",
            Some("FindUserRow"),
            vec![col("id", "int32", false)],
            vec![param("email", "string", true)],
            QueryCommand::Opt,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.find_user: (PG::Connection, String?) -> FindUserRow?"));
    }

    #[test]
    fn test_generate_rbs_many_query() {
        let context = single_query_ctx(
            "list_users",
            Some("ListUsersRow"),
            vec![col("id", "int32", false), col("name", "string", false)],
            vec![],
            QueryCommand::Many,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.list_users: (PG::Connection) -> Array[ListUsersRow]"));
    }

    #[test]
    fn test_generate_rbs_many_query_without_struct() {
        let context = single_query_ctx("list_ids", None, vec![], vec![], QueryCommand::Many);
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.list_ids: (PG::Connection) -> Array[untyped]"));
    }

    #[test]
    fn test_generate_rbs_grouped_query() {
        let context = single_query_ctx(
            "list_by_team",
            Some("ListByTeamRow"),
            vec![col("team", "string", false)],
            vec![],
            QueryCommand::Grouped,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.list_by_team: (PG::Connection) -> Array[ListByTeamRow]"));
    }

    #[test]
    fn test_generate_rbs_exec_query() {
        let context = single_query_ctx(
            "delete_user",
            None,
            vec![],
            vec![param("id", "int32", false)],
            QueryCommand::Exec,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.delete_user: (PG::Connection, Integer) -> void"));
        // Queries without a row struct must not declare any class.
        assert!(!rbs.contains("class "));
    }

    #[test]
    fn test_generate_rbs_exec_rows_query() {
        let context = single_query_ctx(
            "delete_user",
            None,
            vec![],
            vec![param("id", "int32", false)],
            QueryCommand::ExecRows,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.delete_user: (PG::Connection, Integer) -> Integer"));
    }

    #[test]
    fn test_generate_rbs_exec_result_query() {
        let context = single_query_ctx(
            "update_user",
            None,
            vec![],
            vec![param("name", "string", false), param("id", "int32", false)],
            QueryCommand::ExecResult,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(
            rbs.contains("def self.update_user: (PG::Connection, String, Integer) -> Integer")
        );
    }

    #[test]
    fn test_generate_rbs_batch_query() {
        let context = single_query_ctx(
            "insert_user",
            None,
            vec![],
            vec![
                param("name", "string", false),
                param("email", "string", true),
            ],
            QueryCommand::Batch,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains(
            "def self.insert_user_batch: (PG::Connection, Array[[String, String?]]) -> void"
        ));
        // Batch queries only expose the `_batch` entry point.
        assert!(!rbs.contains("def self.insert_user:"));
    }

    #[test]
    fn test_generate_rbs_batch_single_param() {
        let context = single_query_ctx(
            "delete_user",
            None,
            vec![],
            vec![param("id", "int32", false)],
            QueryCommand::Batch,
        );
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.delete_user_batch: (PG::Connection, Array[Integer]) -> void"));
    }

    #[test]
    fn test_generate_rbs_batch_no_params() {
        let context = single_query_ctx("reset", None, vec![], vec![], QueryCommand::Batch);
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("def self.reset_batch: (PG::Connection, Array[untyped]) -> void"));
    }

    #[test]
    fn test_generate_rbs_with_enums() {
        let context = RbsGenerationContext {
            queries: vec![],
            enums: vec![RbsEnumInfo {
                type_name: "UserStatus".to_string(),
                values: vec!["ACTIVE".to_string(), "INACTIVE".to_string()],
            }],
        };
        let rbs = generate_rbs_content(&context, "PG::Connection");
        assert!(rbs.contains("module UserStatus"));
        assert!(rbs.contains("ACTIVE: String"));
        assert!(rbs.contains("INACTIVE: String"));
        assert!(rbs.contains("ALL: Array[String]"));
    }

    #[test]
    fn test_generate_rbs_mysql2_connection_type() {
        let context = single_query_ctx(
            "get_user",
            Some("GetUserRow"),
            vec![col("id", "int32", false)],
            vec![param("id", "int32", false)],
            QueryCommand::One,
        );
        let rbs = generate_rbs_content(&context, "Mysql2::Client");
        assert!(rbs.contains("def self.get_user: (Mysql2::Client, Integer) -> GetUserRow?"));
    }

    #[test]
    fn test_generate_rbs_sqlite3_connection_type() {
        let context = single_query_ctx(
            "get_user",
            Some("GetUserRow"),
            vec![col("id", "int32", false)],
            vec![param("id", "int32", false)],
            QueryCommand::One,
        );
        let rbs = generate_rbs_content(&context, "SQLite3::Database");
        assert!(rbs.contains("def self.get_user: (SQLite3::Database, Integer) -> GetUserRow?"));
    }

    #[test]
    fn test_generate_rbs_trilogy_connection_type() {
        let context = single_query_ctx(
            "get_user",
            Some("GetUserRow"),
            vec![col("id", "int32", false)],
            vec![param("id", "int32", false)],
            QueryCommand::One,
        );
        let rbs = generate_rbs_content(&context, "Trilogy");
        assert!(rbs.contains("def self.get_user: (Trilogy, Integer) -> GetUserRow?"));
    }
}