use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::parse::ParsedQuery;
use crate::validate::ValidationResult;
fn effective_sql<'a>(parsed: &'a ParsedQuery, validation: &'a ValidationResult) -> &'a str {
validation
.rewritten_sql
.as_deref()
.unwrap_or(&parsed.positional_sql)
}
pub fn generate_query_code(parsed: &ParsedQuery, validation: &ValidationResult) -> TokenStream {
if parsed.optional_clauses.is_empty() {
let result_struct = gen_result_struct(parsed, validation);
let executor_struct = gen_executor_struct(parsed);
let executor_impls = gen_executor_impls(parsed, validation);
let constructor = gen_constructor(parsed);
return quote! {
{
#result_struct
#executor_struct
#executor_impls
#constructor
}
};
}
let msg = "internal error: generate_query_code called for dynamic query — use generate_dynamic_query_code";
quote! { compile_error!(#msg) }
}
pub fn generate_dynamic_query_code(
parsed: &ParsedQuery,
validation: &ValidationResult,
) -> TokenStream {
let result_struct = gen_result_struct(parsed, validation);
let for_each_row_struct = gen_pg_for_each_row_struct(parsed, validation);
let executor_struct = gen_dynamic_executor_struct(parsed);
let executor_impls = gen_dynamic_executor_impls(parsed, validation);
let constructor = gen_dynamic_constructor(parsed);
quote! {
{
#result_struct
#for_each_row_struct
#executor_struct
#executor_impls
#constructor
}
}
}
pub fn generate_sort_query_code(
parsed: &ParsedQuery,
validation: &ValidationResult,
sort_enum_name: &str,
) -> TokenStream {
let result_struct = gen_result_struct(parsed, validation);
let sort_enum_ident = format_ident!("{}", sort_enum_name);
let executor_name = executor_struct_name(parsed);
let param_fields: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
let ty = inject_lifetime(&p.rust_type);
quote! { #name: #ty }
})
.collect();
let executor_struct = quote! {
#[must_use = "query is not executed until .fetch_all(), .execute(), or another execution method is called"]
#[allow(non_camel_case_types)]
struct #executor_name<'_bsql> {
#(#param_fields,)*
sort: #sort_enum_ident,
_marker: ::std::marker::PhantomData<&'_bsql ()>,
}
};
let param_refs: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
quote! { &self.#name as &(dyn ::bsql_core::driver::Encode + Sync) }
})
.collect();
let params_slice = if param_refs.is_empty() {
quote! { &[] }
} else {
quote! { &[#(#param_refs),*] }
};
let is_select = parsed.kind == crate::parse::QueryKind::Select;
let query_method = if is_select {
quote! { query_raw_readonly }
} else {
quote! { query_raw }
};
let eff_sql_sort = effective_sql(parsed, validation);
let sql_template = eff_sql_sort;
let has_columns = !validation.columns.is_empty();
let sort_parts: Vec<&str> = sql_template.split("{SORT}").collect();
let sql_prefix = sort_parts[0];
let sql_suffix = if sort_parts.len() > 1 {
sort_parts[1]
} else {
""
};
let needs_limit = has_columns
&& is_select
&& !parsed.normalized_sql.contains(" limit ")
&& !parsed.normalized_sql.contains(" for ");
let limited_suffix = if needs_limit {
format!("{sql_suffix} LIMIT 2")
} else {
sql_suffix.to_owned()
};
let limited_suffix_lit = &limited_suffix;
let build_sql = quote! {
static SORT_SQL_CACHE: ::std::sync::OnceLock<::std::sync::Mutex<Vec<(usize, ::std::sync::Arc<str>, u64)>>> = ::std::sync::OnceLock::new();
let sort_fragment: &'static str = self.sort.sql();
let cache = SORT_SQL_CACHE.get_or_init(|| ::std::sync::Mutex::new(Vec::new()));
let key = sort_fragment.as_ptr() as usize;
let (sql_arc, sql_hash) = {
let guard = cache.lock().unwrap_or_else(|e| e.into_inner());
if let Some(entry) = guard.iter().find(|e| e.0 == key) {
(entry.1.clone(), entry.2)
} else {
drop(guard);
let built = format!("{}{}{}", #sql_prefix, sort_fragment, #sql_suffix);
let hash = ::bsql_core::driver::hash_sql(&built);
let arc: ::std::sync::Arc<str> = ::std::sync::Arc::from(built);
let mut guard = cache.lock().unwrap_or_else(|e| e.into_inner());
if let Some(entry) = guard.iter().find(|e| e.0 == key) {
(entry.1.clone(), entry.2)
} else {
guard.push((key, arc.clone(), hash));
(arc, hash)
}
}
};
let sql: &str = &sql_arc;
};
let build_limited_sql = if needs_limit {
quote! {
static SORT_LIMITED_SQL_CACHE: ::std::sync::OnceLock<::std::sync::Mutex<Vec<(usize, ::std::sync::Arc<str>, u64)>>> = ::std::sync::OnceLock::new();
let sort_fragment: &'static str = self.sort.sql();
let cache = SORT_LIMITED_SQL_CACHE.get_or_init(|| ::std::sync::Mutex::new(Vec::new()));
let key = sort_fragment.as_ptr() as usize;
let (sql_arc, sql_hash) = {
let guard = cache.lock().unwrap_or_else(|e| e.into_inner());
if let Some(entry) = guard.iter().find(|e| e.0 == key) {
(entry.1.clone(), entry.2)
} else {
drop(guard);
let built = format!("{}{}{}", #sql_prefix, sort_fragment, #limited_suffix_lit);
let hash = ::bsql_core::driver::hash_sql(&built);
let arc: ::std::sync::Arc<str> = ::std::sync::Arc::from(built);
let mut guard = cache.lock().unwrap_or_else(|e| e.into_inner());
if let Some(entry) = guard.iter().find(|e| e.0 == key) {
(entry.1.clone(), entry.2)
} else {
guard.push((key, arc.clone(), hash));
(arc, hash)
}
}
};
let sql: &str = &sql_arc;
}
} else {
build_sql.clone()
};
let fetch_methods = if has_columns {
let result_name = result_struct_name(parsed);
let stream_name = stream_struct_name(parsed);
let row_decode = gen_row_decode(validation);
let column_check = gen_column_count_check(validation);
let qm = &query_method;
quote! {
#[allow(non_camel_case_types)]
pub struct #stream_name {
inner: ::bsql_core::QueryStream,
}
#[allow(non_camel_case_types)]
impl #stream_name {
::bsql_core::__bsql_fn! {
pub fn next(&mut self) -> ::bsql_core::BsqlResult<Option<#result_name>> {
if let Some(row) = self.inner.next_row() {
#column_check
return Ok(Some(#result_name { #row_decode }));
}
if !::bsql_core::__bsql_call!(self.inner.fetch_next_chunk())? {
return Ok(None);
}
match self.inner.next_row() {
Some(row) => {
#column_check
Ok(Some(#result_name { #row_decode }))
},
None => Ok(None),
}
}
}
pub fn remaining(&self) -> usize {
self.inner.remaining()
}
}
#[allow(non_camel_case_types)]
impl<'_bsql> #executor_name<'_bsql> {
::bsql_core::__bsql_fn! {
pub fn fetch_one(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<#result_name> {
#build_limited_sql
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(sql, sql_hash, #params_slice))?;
if owned.len() != 1 {
return Err(::bsql_core::error::QueryError::row_count(
"exactly 1 row",
owned.len() as u64,
));
}
let row = owned.row(0);
#column_check
Ok(#result_name { #row_decode })
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_optional(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Option<#result_name>> {
#build_limited_sql
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(sql, sql_hash, #params_slice))?;
match owned.len() {
0 => Ok(None),
1 => {
let row = owned.row(0);
#column_check
Ok(Some(#result_name { #row_decode }))
}
n => Err(::bsql_core::error::QueryError::row_count(
"0 or 1 rows",
n as u64,
)),
}
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_all(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Vec<#result_name>> {
#build_sql
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(sql, sql_hash, #params_slice))?;
owned.iter().map(|row| {
#column_check
Ok(#result_name { #row_decode })
}).collect::<::bsql_core::BsqlResult<Vec<_>>>()
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_stream(
self,
pool: &::bsql_core::Pool,
) -> ::bsql_core::BsqlResult<#stream_name> {
#build_sql
let inner = ::bsql_core::__bsql_call!(pool.query_stream(sql, sql_hash, #params_slice))?;
Ok(#stream_name { inner })
}
}
::bsql_core::__bsql_fn! {
pub fn execute(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<u64> {
#build_sql
let executor = executor.into(); ::bsql_core::__bsql_call!(executor.execute_raw(sql, sql_hash, #params_slice))
}
}
::bsql_core::__bsql_fn! {
pub fn defer(self, tx: &mut ::bsql_core::Transaction) -> ::bsql_core::BsqlResult<()> {
#build_sql
::bsql_core::__bsql_call!(tx.defer_execute(sql, sql_hash, #params_slice))
}
}
}
}
} else {
quote! {
#[allow(non_camel_case_types)]
impl<'_bsql> #executor_name<'_bsql> {
::bsql_core::__bsql_fn! {
pub fn execute(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<u64> {
#build_sql
let executor = executor.into(); ::bsql_core::__bsql_call!(executor.execute_raw(sql, sql_hash, #params_slice))
}
}
::bsql_core::__bsql_fn! {
pub fn defer(self, tx: &mut ::bsql_core::Transaction) -> ::bsql_core::BsqlResult<()> {
#build_sql
::bsql_core::__bsql_call!(tx.defer_execute(sql, sql_hash, #params_slice))
}
}
}
}
};
let for_each_row_struct = gen_pg_for_each_row_struct(parsed, validation);
let field_inits: Vec<proc_macro2::Ident> =
parsed.params.iter().map(|p| param_ident(&p.name)).collect();
let constructor = quote! {
#executor_name {
#(#field_inits,)*
sort,
_marker: ::std::marker::PhantomData,
}
};
quote! {
{
#result_struct
#for_each_row_struct
#executor_struct
#fetch_methods
#constructor
}
}
}
pub fn generate_query_as_code(
parsed: &ParsedQuery,
validation: &ValidationResult,
target_type: &syn::Path,
) -> TokenStream {
let executor_struct = gen_executor_struct(parsed);
let executor_impls = gen_query_as_executor_impls(parsed, validation, target_type);
let constructor = gen_constructor(parsed);
quote! {
{
#executor_struct
#executor_impls
#constructor
}
}
}
fn gen_query_as_executor_impls(
parsed: &ParsedQuery,
validation: &ValidationResult,
target_type: &syn::Path,
) -> TokenStream {
let executor_name = executor_struct_name(parsed);
let eff_sql = effective_sql(parsed, validation);
let sql_lit = eff_sql;
let is_select = parsed.kind == crate::parse::QueryKind::Select;
let query_method = if is_select {
quote! { query_raw_readonly }
} else {
quote! { query_raw }
};
let param_refs: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
quote! { &self.#name as &(dyn ::bsql_core::driver::Encode + Sync) }
})
.collect();
let params_slice = if param_refs.is_empty() {
quote! { &[] }
} else {
quote! { &[#(#param_refs),*] }
};
let sql_hash_val = bsql_core::rapid_hash_str(eff_sql);
let has_columns = !validation.columns.is_empty();
let needs_limit = has_columns
&& is_select
&& !parsed.normalized_sql.contains(" limit ")
&& !parsed.normalized_sql.contains(" for ");
let limited_sql = if needs_limit {
format!("{} LIMIT 2", eff_sql)
} else {
eff_sql.to_owned()
};
let limited_sql_lit = &limited_sql;
let limited_sql_hash_val = bsql_core::rapid_hash_str(&limited_sql);
let row_decode = if has_columns {
gen_row_decode(validation)
} else {
TokenStream::new()
};
let column_check = gen_column_count_check(validation);
let fetch_methods = if has_columns {
let qm = &query_method;
quote! {
::bsql_core::__bsql_fn! {
pub fn fetch_one(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<#target_type> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(#limited_sql_lit, #limited_sql_hash_val, #params_slice))?;
if owned.len() != 1 {
return Err(::bsql_core::error::QueryError::row_count(
"exactly 1 row",
owned.len() as u64,
));
}
let row = owned.row(0);
#column_check
Ok(#target_type { #row_decode })
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_all(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Vec<#target_type>> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(#sql_lit, #sql_hash_val, #params_slice))?;
owned.iter().map(|row| {
#column_check
Ok(#target_type { #row_decode })
}).collect::<::bsql_core::BsqlResult<Vec<_>>>()
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_optional(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Option<#target_type>> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(#limited_sql_lit, #limited_sql_hash_val, #params_slice))?;
match owned.len() {
0 => Ok(None),
1 => {
let row = owned.row(0);
#column_check
Ok(Some(#target_type { #row_decode }))
}
n => Err(::bsql_core::error::QueryError::row_count(
"0 or 1 rows",
n as u64,
)),
}
}
}
}
} else {
TokenStream::new()
};
let execute_method = quote! {
::bsql_core::__bsql_fn! {
pub fn execute(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<u64> {
let executor = executor.into(); ::bsql_core::__bsql_call!(executor.execute_raw(#sql_lit, #sql_hash_val, #params_slice))
}
}
};
let defer_method = quote! {
::bsql_core::__bsql_fn! {
pub fn defer(self, tx: &mut ::bsql_core::Transaction) -> ::bsql_core::BsqlResult<()> {
::bsql_core::__bsql_call!(tx.defer_execute(#sql_lit, #sql_hash_val, #params_slice))
}
}
};
quote! {
#[allow(non_camel_case_types)]
impl<'_bsql> #executor_name<'_bsql> {
#fetch_methods
#execute_method
#defer_method
}
}
}
fn gen_result_struct(parsed: &ParsedQuery, validation: &ValidationResult) -> TokenStream {
if validation.columns.is_empty() {
return TokenStream::new(); }
let struct_name = result_struct_name(parsed);
let deduped_names = deduplicate_column_names(&validation.columns);
let fields = validation.columns.iter().enumerate().map(|(i, col)| {
let field_name = format_ident!("{}", deduped_names[i]);
let field_type = parse_result_type(&col.rust_type);
quote! { pub #field_name: #field_type }
});
#[cfg(feature = "explain")]
let explain_doc = if let Some(ref plan) = validation.explain_plan {
let doc_lines: Vec<TokenStream> = std::iter::once(quote! { #[doc = ""] })
.chain(std::iter::once(quote! { #[doc = "**Query plan:**"] }))
.chain(std::iter::once(quote! { #[doc = "```text"] }))
.chain(plan.lines().map(|line| {
let line_str = line.to_string();
quote! { #[doc = #line_str] }
}))
.chain(std::iter::once(quote! { #[doc = "```"] }))
.collect();
quote! { #(#doc_lines)* }
} else {
TokenStream::new()
};
#[cfg(not(feature = "explain"))]
let explain_doc = TokenStream::new();
quote! {
#explain_doc
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct #struct_name {
#(#fields,)*
}
}
}
fn gen_executor_struct(parsed: &ParsedQuery) -> TokenStream {
let struct_name = executor_struct_name(parsed);
let fields: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
let ty = inject_lifetime(&p.rust_type);
quote! { #name: #ty }
})
.collect();
quote! {
#[must_use = "query is not executed until .fetch_all(), .execute(), or another execution method is called"]
#[allow(non_camel_case_types)]
struct #struct_name<'_bsql> {
#(#fields,)*
_marker: ::std::marker::PhantomData<&'_bsql ()>,
}
}
}
fn gen_executor_impls(parsed: &ParsedQuery, validation: &ValidationResult) -> TokenStream {
let executor_name = executor_struct_name(parsed);
let eff_sql = effective_sql(parsed, validation);
let sql_lit = eff_sql;
let is_select = parsed.kind == crate::parse::QueryKind::Select;
let query_method = if is_select {
quote! { query_raw_readonly }
} else {
quote! { query_raw }
};
let param_refs: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
quote! { &self.#name as &(dyn ::bsql_core::driver::Encode + Sync) }
})
.collect();
let params_slice = if param_refs.is_empty() {
quote! { &[] }
} else {
quote! { &[#(#param_refs),*] }
};
let sql_hash_val = bsql_core::rapid_hash_str(eff_sql);
let has_columns = !validation.columns.is_empty();
let needs_limit = has_columns
&& is_select
&& !parsed.normalized_sql.contains(" limit ")
&& !parsed.normalized_sql.contains(" for ");
let limited_sql = if needs_limit {
format!("{} LIMIT 2", eff_sql)
} else {
eff_sql.to_owned()
};
let limited_sql_lit = &limited_sql;
let limited_sql_hash_val = bsql_core::rapid_hash_str(&limited_sql);
let row_decode = if has_columns {
gen_row_decode(validation)
} else {
TokenStream::new()
};
let column_check = gen_column_count_check(validation);
let fetch_methods = if has_columns {
let result_name = result_struct_name(parsed);
let stream_name = stream_struct_name(parsed);
quote! {
::bsql_core::__bsql_fn! {
pub fn fetch_one(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<#result_name> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#query_method(#limited_sql_lit, #limited_sql_hash_val, #params_slice))?;
if owned.len() != 1 {
return Err(::bsql_core::error::QueryError::row_count(
"exactly 1 row",
owned.len() as u64,
));
}
let row = owned.row(0);
#column_check
Ok(#result_name { #row_decode })
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_optional(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Option<#result_name>> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#query_method(#limited_sql_lit, #limited_sql_hash_val, #params_slice))?;
match owned.len() {
0 => Ok(None),
1 => {
let row = owned.row(0);
#column_check
Ok(Some(#result_name { #row_decode }))
}
n => Err(::bsql_core::error::QueryError::row_count(
"0 or 1 rows",
n as u64,
)),
}
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_all(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Vec<#result_name>> {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#query_method(#sql_lit, #sql_hash_val, #params_slice))?;
owned.iter().map(|row| {
#column_check
Ok(#result_name { #row_decode })
}).collect::<::bsql_core::BsqlResult<Vec<_>>>()
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_stream(
self,
pool: &::bsql_core::Pool,
) -> ::bsql_core::BsqlResult<#stream_name> {
let inner = ::bsql_core::__bsql_call!(pool.query_stream(#sql_lit, #sql_hash_val, #params_slice))?;
Ok(#stream_name { inner })
}
}
}
} else {
TokenStream::new()
};
let stream_struct = if has_columns {
let result_name = result_struct_name(parsed);
let stream_name = stream_struct_name(parsed);
gen_stream_struct(&result_name, &stream_name, &row_decode, &column_check)
} else {
TokenStream::new()
};
let execute_method = quote! {
::bsql_core::__bsql_fn! {
pub fn execute(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<u64> {
let executor = executor.into(); ::bsql_core::__bsql_call!(executor.execute_raw(#sql_lit, #sql_hash_val, #params_slice))
}
}
};
let defer_method = quote! {
::bsql_core::__bsql_fn! {
pub fn defer(self, tx: &mut ::bsql_core::Transaction) -> ::bsql_core::BsqlResult<()> {
::bsql_core::__bsql_call!(tx.defer_execute(#sql_lit, #sql_hash_val, #params_slice))
}
}
};
let for_each_row_struct = if has_columns {
gen_pg_for_each_row_struct(parsed, validation)
} else {
TokenStream::new()
};
let for_each_methods = if has_columns && is_select {
let fe_row_name = pg_for_each_row_struct_name(parsed);
let (fe_raw_stmts, fe_raw_inits) = gen_pg_for_each_raw_decode(validation);
let (fe_raw_stmts2, fe_raw_inits2) = gen_pg_for_each_raw_decode(validation);
let raw_column_check = gen_raw_column_count_check(validation);
quote! {
::bsql_core::__bsql_fn! {
pub fn for_each<_BsqlForEachF>(
self,
pool: &::bsql_core::Pool,
mut f: _BsqlForEachF,
) -> ::bsql_core::BsqlResult<()>
where
_BsqlForEachF: FnMut(#fe_row_name<'_>) -> Result<(), ::bsql_core::BsqlError>,
{
::bsql_core::__bsql_call!(pool.__for_each_raw_bytes(
#sql_lit,
#sql_hash_val,
#params_slice,
true,
|_bsql_data: &[u8]| -> ::bsql_core::BsqlResult<()> {
#raw_column_check
#fe_raw_stmts
let _bsql_typed = #fe_row_name { #fe_raw_inits };
f(_bsql_typed)
},
))
}
}
::bsql_core::__bsql_fn! {
pub fn for_each_map<_BsqlForEachF, _BsqlForEachT>(
self,
pool: &::bsql_core::Pool,
mut f: _BsqlForEachF,
) -> ::bsql_core::BsqlResult<Vec<_BsqlForEachT>>
where
_BsqlForEachF: FnMut(#fe_row_name<'_>) -> _BsqlForEachT,
{
let mut _bsql_results: Vec<_BsqlForEachT> = Vec::new();
::bsql_core::__bsql_call!(pool.__for_each_raw_bytes(
#sql_lit,
#sql_hash_val,
#params_slice,
true,
|_bsql_data: &[u8]| -> ::bsql_core::BsqlResult<()> {
#raw_column_check
#fe_raw_stmts2
let _bsql_typed = #fe_row_name { #fe_raw_inits2 };
_bsql_results.push(f(_bsql_typed));
Ok(())
},
))?;
Ok(_bsql_results)
}
}
}
} else {
TokenStream::new()
};
quote! {
#stream_struct
#for_each_row_struct
#[allow(non_camel_case_types)]
impl<'_bsql> #executor_name<'_bsql> {
#fetch_methods
#for_each_methods
#execute_method
#defer_method
}
}
}
fn gen_dynamic_executor_struct(parsed: &ParsedQuery) -> TokenStream {
let struct_name = executor_struct_name(parsed);
let mut fields: Vec<TokenStream> = Vec::new();
let mut seen_names: std::collections::HashSet<String> = std::collections::HashSet::new();
for p in &parsed.params {
let name = param_ident(&p.name);
let ty = inject_lifetime(&p.rust_type);
fields.push(quote! { #name: #ty });
seen_names.insert(p.name.clone());
}
for clause in &parsed.optional_clauses {
for p in &clause.params {
if seen_names.insert(p.name.clone()) {
let name = param_ident(&p.name);
let ty = inject_lifetime(&p.rust_type);
fields.push(quote! { #name: #ty });
}
}
}
quote! {
#[must_use = "query is not executed until .fetch_all(), .execute(), or another execution method is called"]
#[allow(non_camel_case_types)]
struct #struct_name<'_bsql> {
#(#fields,)*
_marker: ::std::marker::PhantomData<&'_bsql ()>,
}
}
}
fn gen_dynamic_executor_impls(parsed: &ParsedQuery, validation: &ValidationResult) -> TokenStream {
let executor_name = executor_struct_name(parsed);
let has_columns = !validation.columns.is_empty();
let is_select = parsed.kind == crate::parse::QueryKind::Select;
let query_method = if is_select {
quote! { query_raw_readonly }
} else {
quote! { query_raw }
};
let row_decode = if has_columns {
gen_row_decode(validation)
} else {
TokenStream::new()
};
let column_check = gen_column_count_check(validation);
let fetch_methods = if has_columns {
let result_name = result_struct_name(parsed);
let stream_name = stream_struct_name(parsed);
let needs_limit = has_columns
&& is_select
&& !parsed.normalized_sql.contains(" limit ")
&& !parsed.normalized_sql.contains(" for ");
let qm = &query_method;
let owned_fetch_one_optional = {
let fetch_one_dispatcher = gen_runtime_dispatcher(parsed, needs_limit, |_| {
quote! {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(&_bsql_sql, _bsql_hash, &_bsql_params[..]))?;
if owned.len() != 1 {
return Err(::bsql_core::error::QueryError::row_count(
"exactly 1 row",
owned.len() as u64,
));
}
let row = owned.row(0);
#column_check
Ok(#result_name { #row_decode })
}
});
let fetch_optional_dispatcher = gen_runtime_dispatcher(parsed, needs_limit, |_| {
quote! {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(&_bsql_sql, _bsql_hash, &_bsql_params[..]))?;
match owned.len() {
0 => Ok(None),
1 => {
let row = owned.row(0);
#column_check
Ok(Some(#result_name { #row_decode }))
}
n => Err(::bsql_core::error::QueryError::row_count(
"0 or 1 rows",
n as u64,
)),
}
}
});
quote! {
::bsql_core::__bsql_fn! {
pub fn fetch_one(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<#result_name> {
#fetch_one_dispatcher
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_optional(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Option<#result_name>> {
#fetch_optional_dispatcher
}
}
}
};
let fetch_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
let executor = executor.into(); let owned = ::bsql_core::__bsql_call!(executor.#qm(&_bsql_sql, _bsql_hash, &_bsql_params[..]))?;
owned.iter().map(|row| {
#column_check
Ok(#result_name { #row_decode })
}).collect::<::bsql_core::BsqlResult<Vec<_>>>()
}
});
let fetch_stream_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
let inner = ::bsql_core::__bsql_call!(pool.query_stream(&_bsql_sql, _bsql_hash, &_bsql_params[..]))?;
Ok(#stream_name { inner })
}
});
quote! {
#owned_fetch_one_optional
::bsql_core::__bsql_fn! {
pub fn fetch_all(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<Vec<#result_name>> {
#fetch_dispatcher
}
}
::bsql_core::__bsql_fn! {
pub fn fetch_stream(
self,
pool: &::bsql_core::Pool,
) -> ::bsql_core::BsqlResult<#stream_name> {
#fetch_stream_dispatcher
}
}
}
} else {
TokenStream::new()
};
let stream_struct = if has_columns {
let result_name = result_struct_name(parsed);
let stream_name = stream_struct_name(parsed);
gen_stream_struct(&result_name, &stream_name, &row_decode, &column_check)
} else {
TokenStream::new()
};
let execute_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
let executor = executor.into(); ::bsql_core::__bsql_call!(executor.execute_raw(&_bsql_sql, _bsql_hash, &_bsql_params[..]))
}
});
let execute_method = quote! {
::bsql_core::__bsql_fn! {
pub fn execute(
self,
executor: impl Into<::bsql_core::QueryTarget<'_>>,
) -> ::bsql_core::BsqlResult<u64> {
#execute_dispatcher
}
}
};
let defer_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
::bsql_core::__bsql_call!(tx.defer_execute(&_bsql_sql, _bsql_hash, &_bsql_params[..]))
}
});
let defer_method = quote! {
::bsql_core::__bsql_fn! {
pub fn defer(self, tx: &mut ::bsql_core::Transaction) -> ::bsql_core::BsqlResult<()> {
#defer_dispatcher
}
}
};
let for_each_methods = if has_columns && is_select {
let fe_row_name = pg_for_each_row_struct_name(parsed);
let (fe_raw_stmts, fe_raw_inits) = gen_pg_for_each_raw_decode(validation);
let (fe_raw_stmts2, fe_raw_inits2) = gen_pg_for_each_raw_decode(validation);
let raw_column_check = gen_raw_column_count_check(validation);
let for_each_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
::bsql_core::__bsql_call!(pool.__for_each_raw_bytes(
&_bsql_sql,
_bsql_hash,
&_bsql_params[..],
true,
|_bsql_data: &[u8]| -> ::bsql_core::BsqlResult<()> {
#raw_column_check
#fe_raw_stmts
let _bsql_typed = #fe_row_name { #fe_raw_inits };
f(_bsql_typed)
},
))
}
});
let for_each_map_dispatcher = gen_runtime_dispatcher(parsed, false, |_| {
quote! {
::bsql_core::__bsql_call!(pool.__for_each_raw_bytes(
&_bsql_sql,
_bsql_hash,
&_bsql_params[..],
true,
|_bsql_data: &[u8]| -> ::bsql_core::BsqlResult<()> {
#raw_column_check
#fe_raw_stmts2
let _bsql_typed = #fe_row_name { #fe_raw_inits2 };
_bsql_results.push(f(_bsql_typed));
Ok(())
},
))
}
});
quote! {
::bsql_core::__bsql_fn! {
pub fn for_each<_BsqlForEachF>(
self,
pool: &::bsql_core::Pool,
mut f: _BsqlForEachF,
) -> ::bsql_core::BsqlResult<()>
where
_BsqlForEachF: FnMut(#fe_row_name<'_>) -> Result<(), ::bsql_core::BsqlError>,
{
#for_each_dispatcher
}
}
::bsql_core::__bsql_fn! {
pub fn for_each_map<_BsqlForEachF, _BsqlForEachT>(
self,
pool: &::bsql_core::Pool,
mut f: _BsqlForEachF,
) -> ::bsql_core::BsqlResult<Vec<_BsqlForEachT>>
where
_BsqlForEachF: FnMut(#fe_row_name<'_>) -> _BsqlForEachT,
{
let mut _bsql_results: Vec<_BsqlForEachT> = Vec::new();
#for_each_map_dispatcher?;
Ok(_bsql_results)
}
}
}
} else {
TokenStream::new()
};
quote! {
#stream_struct
#[allow(non_camel_case_types)]
impl<'_bsql> #executor_name<'_bsql> {
#fetch_methods
#for_each_methods
#execute_method
#defer_method
}
}
}
fn gen_runtime_dispatcher<F>(parsed: &ParsedQuery, inject_limit: bool, body_fn: F) -> TokenStream
where
F: Fn(()) -> TokenStream,
{
let mut segments: Vec<String> = Vec::new();
let mut remaining = parsed.positional_sql.as_str();
for i in 0..parsed.optional_clauses.len() {
let marker = format!("{{OPT_{i}}}");
if let Some(pos) = remaining.find(&marker) {
segments.push(remaining[..pos].to_owned());
remaining = &remaining[pos + marker.len()..];
} else {
segments.push(remaining.to_owned());
remaining = "";
}
}
segments.push(remaining.to_owned());
let base_len: usize = segments.iter().map(|s| s.len()).sum();
let clause_len: usize = parsed
.optional_clauses
.iter()
.map(|c| c.sql_fragment.len() + 8) .sum();
let capacity = base_len + clause_len + if inject_limit { 10 } else { 0 };
let first_segment = &segments[0];
let base_param_pushes: Vec<TokenStream> = parsed
.params
.iter()
.map(|p| {
let name = param_ident(&p.name);
quote! {
_bsql_params.push(&self.#name as &(dyn ::bsql_core::driver::Encode + Sync));
}
})
.collect();
let clause_appends: Vec<TokenStream> = parsed
.optional_clauses
.iter()
.enumerate()
.map(|(i, clause)| {
let param_name = param_ident(&clause.params[0].name);
let trailing = &segments[i + 1];
let frag = &clause.sql_fragment;
let mut frag_parts: Vec<String> = Vec::new();
let mut frag_remaining = frag.as_str();
loop {
if let Some(pos) = frag_remaining.find("${P_") {
frag_parts.push(frag_remaining[..pos].to_owned());
let after = &frag_remaining[pos + 4..];
if let Some(end) = after.find('}') {
frag_remaining = &after[end + 1..];
} else {
frag_remaining = "";
break;
}
} else {
frag_parts.push(frag_remaining.to_owned());
break;
}
}
let num_refs = frag_parts.len() - 1;
let frag_append: Vec<TokenStream> = frag_parts
.iter()
.enumerate()
.map(|(j, part)| {
if j == 0 {
quote! {
_bsql_sql.push(' ');
_bsql_sql.push_str(#part);
}
} else {
quote! {
_bsql_sql.push('$');
{
let _n = _bsql_params.len() + 1;
if _n < 10 {
_bsql_sql.push((b'0' + _n as u8) as char);
} else if _n < 100 {
_bsql_sql.push((b'0' + (_n / 10) as u8) as char);
_bsql_sql.push((b'0' + (_n % 10) as u8) as char);
} else {
let _s = _n.to_string();
_bsql_sql.push_str(&_s);
}
}
_bsql_sql.push_str(#part);
}
}
})
.collect();
let param_push = quote! {
_bsql_params.push(self.#param_name.as_ref().unwrap() as &(dyn ::bsql_core::driver::Encode + Sync));
};
let trailing_push = if trailing.is_empty() {
TokenStream::new()
} else {
quote! { _bsql_sql.push_str(#trailing); }
};
if num_refs <= 1 {
quote! {
if self.#param_name.is_some() {
#(#frag_append)*
#param_push
}
#trailing_push
}
} else {
let frag_parts_multi: Vec<TokenStream> = frag_parts
.iter()
.enumerate()
.map(|(j, part)| {
if j == 0 {
quote! {
_bsql_sql.push(' ');
_bsql_sql.push_str(#part);
}
} else {
quote! {
_bsql_sql.push_str(&_bsql_pn);
_bsql_sql.push_str(#part);
}
}
})
.collect();
quote! {
if self.#param_name.is_some() {
#param_push
let _bsql_pn = format!("${}", _bsql_params.len());
#(#frag_parts_multi)*
}
#trailing_push
}
}
})
.collect();
let limit_push = if inject_limit {
quote! { _bsql_sql.push_str(" LIMIT 2"); }
} else {
TokenStream::new()
};
let body = body_fn(());
let n_clauses = parsed.optional_clauses.len();
let max_params = parsed.params.len() + n_clauses;
quote! {
{
let mut _bsql_sql = ::std::string::String::with_capacity(#capacity);
_bsql_sql.push_str(#first_segment);
let mut _bsql_params: ::std::vec::Vec<&(dyn ::bsql_core::driver::Encode + Sync)> =
::std::vec::Vec::with_capacity(#max_params);
#(#base_param_pushes)*
#(#clause_appends)*
#limit_push
let _bsql_hash = ::bsql_core::rapid_hash_str(&_bsql_sql);
#body
}
}
}
fn gen_dynamic_constructor(parsed: &ParsedQuery) -> TokenStream {
let executor_name = executor_struct_name(parsed);
let mut field_names: Vec<proc_macro2::Ident> = Vec::new();
let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
for p in &parsed.params {
field_names.push(param_ident(&p.name));
seen.insert(p.name.clone());
}
for clause in &parsed.optional_clauses {
for p in &clause.params {
if seen.insert(p.name.clone()) {
field_names.push(param_ident(&p.name));
}
}
}
quote! { #executor_name { #(#field_names,)* _marker: ::std::marker::PhantomData } }
}
fn gen_stream_struct(
result_name: &proc_macro2::Ident,
stream_name: &proc_macro2::Ident,
row_decode: &TokenStream,
column_check: &TokenStream,
) -> TokenStream {
quote! {
#[allow(non_camel_case_types)]
pub struct #stream_name {
inner: ::bsql_core::QueryStream,
}
#[allow(non_camel_case_types)]
impl #stream_name {
::bsql_core::__bsql_fn! {
pub fn next(&mut self) -> ::bsql_core::BsqlResult<Option<#result_name>> {
if let Some(row) = self.inner.next_row() {
#column_check
return Ok(Some(#result_name { #row_decode }));
}
if !::bsql_core::__bsql_call!(self.inner.fetch_next_chunk())? {
return Ok(None);
}
match self.inner.next_row() {
Some(row) => {
#column_check
Ok(Some(#result_name { #row_decode }))
},
None => Ok(None),
}
}
}
pub fn remaining(&self) -> usize {
self.inner.remaining()
}
}
}
}
fn pg_for_each_row_struct_name(parsed: &ParsedQuery) -> proc_macro2::Ident {
format_ident!("BsqlForEachRow_{}", &parsed.statement_name)
}
fn pg_for_each_result_type(type_str: &str) -> TokenStream {
match type_str {
"String" => quote! { &'a str },
"Vec<u8>" => quote! { &'a [u8] },
_ => {
if let Some(inner) = type_str
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix('>'))
{
match inner {
"String" => quote! { Option<&'a str> },
"Vec<u8>" => quote! { Option<&'a [u8]> },
_ => parse_result_type(type_str),
}
} else {
parse_result_type(type_str)
}
}
}
}
fn gen_pg_for_each_row_struct(parsed: &ParsedQuery, validation: &ValidationResult) -> TokenStream {
if validation.columns.is_empty() {
return TokenStream::new();
}
let struct_name = pg_for_each_row_struct_name(parsed);
let deduped_names = deduplicate_column_names(&validation.columns);
let fields = validation.columns.iter().enumerate().map(|(i, col)| {
let field_name = format_ident!("{}", deduped_names[i]);
let field_type = pg_for_each_result_type(&col.rust_type);
quote! { pub #field_name: #field_type }
});
let needs_lifetime = validation.columns.iter().any(|col| {
let rt = &col.rust_type;
matches!(rt.as_str(), "String" | "Vec<u8>")
|| rt.starts_with("Option<String>")
|| rt.starts_with("Option<Vec<u8>>")
});
let phantom_field = if needs_lifetime {
TokenStream::new()
} else {
quote! { pub _marker: ::std::marker::PhantomData<&'a ()>, }
};
quote! {
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct #struct_name<'a> {
#(#fields,)*
#phantom_field
}
}
}
fn gen_pg_for_each_raw_decode(validation: &ValidationResult) -> (TokenStream, TokenStream) {
let deduped_names = deduplicate_column_names(&validation.columns);
let decode_stmts: Vec<TokenStream> = deduped_names
.iter()
.enumerate()
.map(|(i, name)| {
let field_name = format_ident!("{}", name);
let col = &validation.columns[i];
gen_pg_raw_column_decode(&field_name, &col.rust_type)
})
.collect();
let field_inits: Vec<TokenStream> = deduped_names
.iter()
.enumerate()
.map(|(i, name)| {
let field_name = format_ident!("{}", name);
let _ = i;
quote! { #field_name }
})
.collect();
let needs_lifetime = validation.columns.iter().any(|col| {
let rt = &col.rust_type;
matches!(rt.as_str(), "String" | "Vec<u8>")
|| rt.starts_with("Option<String>")
|| rt.starts_with("Option<Vec<u8>>")
});
let phantom_init = if needs_lifetime {
TokenStream::new()
} else {
quote! { , _marker: ::std::marker::PhantomData }
};
let stmts = quote! {
let mut _bsql_pos: usize = 2; #(#decode_stmts)*
};
let inits = quote! {
#(#field_inits),* #phantom_init
};
(stmts, inits)
}
fn gen_pg_raw_column_decode(field_name: &proc_macro2::Ident, rust_type: &str) -> TokenStream {
if let Some(inner) = rust_type
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix('>'))
{
gen_pg_raw_nullable_decode(field_name, inner)
} else {
gen_pg_raw_not_null_decode(field_name, rust_type)
}
}
fn gen_pg_raw_not_null_decode(field_name: &proc_macro2::Ident, rust_type: &str) -> TokenStream {
let field_str = field_name.to_string();
match rust_type {
"bool" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "bool", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL bool")),
));
} else {
let _v = _bsql_data[_bsql_pos] != 0;
_bsql_pos += _bsql_len as usize;
_v
};
},
"i16" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "i16", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL i16")),
));
} else {
let _v = i16::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
]);
_bsql_pos += _bsql_len as usize;
_v
};
},
"i32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "i32", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL i32")),
));
} else {
let _v = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += _bsql_len as usize;
_v
};
},
"i64" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "i64", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL i64")),
));
} else {
let _v = i64::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
_bsql_data[_bsql_pos + 4], _bsql_data[_bsql_pos + 5],
_bsql_data[_bsql_pos + 6], _bsql_data[_bsql_pos + 7],
]);
_bsql_pos += _bsql_len as usize;
_v
};
},
"f32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "f32", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL f32")),
));
} else {
let _v = f32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += _bsql_len as usize;
_v
};
},
"f64" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "f64", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL f64")),
));
} else {
let _v = f64::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
_bsql_data[_bsql_pos + 4], _bsql_data[_bsql_pos + 5],
_bsql_data[_bsql_pos + 6], _bsql_data[_bsql_pos + 7],
]);
_bsql_pos += _bsql_len as usize;
_v
};
},
"String" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "&str", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL &str")),
));
} else {
let _end = _bsql_pos + _bsql_len as usize;
let _v = ::bsql_core::driver::decode_str(&_bsql_data[_bsql_pos.._end])
.map_err(|e| ::bsql_core::error::DecodeError::with_source(
#field_str, "&str", "invalid UTF-8", e,
))?;
_bsql_pos = _end;
_v
};
},
"Vec<u8>" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "&[u8]", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL &[u8]")),
));
} else {
let _end = _bsql_pos + _bsql_len as usize;
let _v = &_bsql_data[_bsql_pos.._end];
_bsql_pos = _end;
_v
};
},
"u32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name = if _bsql_len < 0 {
return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "u32", "NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL u32")),
));
} else {
let _v = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]) as u32;
_bsql_pos += _bsql_len as usize;
_v
};
},
"()" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
if _bsql_len > 0 { _bsql_pos += _bsql_len as usize; }
let #field_name = ();
},
_ => gen_pg_raw_feature_decode(field_name, rust_type),
}
}
fn gen_pg_raw_nullable_decode(field_name: &proc_macro2::Ident, inner_type: &str) -> TokenStream {
let field_str = field_name.to_string();
match inner_type {
"bool" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<bool> = if _bsql_len < 0 { None } else {
let _v = _bsql_data[_bsql_pos] != 0;
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"i16" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<i16> = if _bsql_len < 0 { None } else {
let _v = i16::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
]);
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"i32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<i32> = if _bsql_len < 0 { None } else {
let _v = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"i64" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<i64> = if _bsql_len < 0 { None } else {
let _v = i64::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
_bsql_data[_bsql_pos + 4], _bsql_data[_bsql_pos + 5],
_bsql_data[_bsql_pos + 6], _bsql_data[_bsql_pos + 7],
]);
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"f32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<f32> = if _bsql_len < 0 { None } else {
let _v = f32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"f64" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<f64> = if _bsql_len < 0 { None } else {
let _v = f64::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
_bsql_data[_bsql_pos + 4], _bsql_data[_bsql_pos + 5],
_bsql_data[_bsql_pos + 6], _bsql_data[_bsql_pos + 7],
]);
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
"String" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<&str> = if _bsql_len < 0 { None } else {
let _end = _bsql_pos + _bsql_len as usize;
let _v = ::bsql_core::driver::decode_str(&_bsql_data[_bsql_pos.._end])
.map_err(|e| ::bsql_core::error::DecodeError::with_source(
#field_str, "&str", "invalid UTF-8", e,
))?;
_bsql_pos = _end;
Some(_v)
};
},
"Vec<u8>" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<&[u8]> = if _bsql_len < 0 { None } else {
let _end = _bsql_pos + _bsql_len as usize;
let _v = &_bsql_data[_bsql_pos.._end];
_bsql_pos = _end;
Some(_v)
};
},
"u32" => quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let #field_name: Option<u32> = if _bsql_len < 0 { None } else {
let _v = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]) as u32;
_bsql_pos += _bsql_len as usize;
Some(_v)
};
},
_ => gen_pg_raw_nullable_feature_decode(field_name, inner_type),
}
}
fn gen_pg_raw_feature_decode(field_name: &proc_macro2::Ident, rust_type: &str) -> TokenStream {
let field_str = field_name.to_string();
let read_raw = quote! {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
let _bsql_raw: &[u8] = if _bsql_len < 0 {
&[]
} else {
let _end = _bsql_pos + _bsql_len as usize;
let _v = &_bsql_data[_bsql_pos.._end];
_bsql_pos = _end;
_v
};
};
let decode_expr = match rust_type {
"::uuid::Uuid" | "uuid::Uuid" => quote! {
let #field_name = match ::bsql_core::driver::decode_uuid_type(_bsql_raw) {
Ok(v) => v,
Err(e) => return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "uuid", "invalid data", e,
)),
};
},
_ => {
let col_idx_lit = 0usize;
let decode = gen_not_null_decode(col_idx_lit, rust_type);
quote! {
let #field_name = {
let mut _bsql_tmp = Vec::with_capacity(6 + _bsql_raw.len());
_bsql_tmp.extend_from_slice(&1i16.to_be_bytes());
_bsql_tmp.extend_from_slice(&(_bsql_raw.len() as i32).to_be_bytes());
_bsql_tmp.extend_from_slice(_bsql_raw);
let _bsql_row = ::bsql_core::driver::PgDataRow::new(&_bsql_tmp)
.map_err(|e| ::bsql_core::error::DecodeError::with_source(
#field_str, "decode", "invalid data", e,
))?;
let row = &_bsql_row;
#decode
};
}
}
};
quote! {
#read_raw
#decode_expr
}
}
fn gen_pg_raw_nullable_feature_decode(
field_name: &proc_macro2::Ident,
inner_type: &str,
) -> TokenStream {
let field_str = field_name.to_string();
match inner_type {
"::uuid::Uuid" | "uuid::Uuid" => quote! {
let #field_name = {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
if _bsql_len < 0 {
None
} else {
let _end = _bsql_pos + _bsql_len as usize;
let _raw = &_bsql_data[_bsql_pos.._end];
_bsql_pos = _end;
Some(match ::bsql_core::driver::decode_uuid_type(_raw) {
Ok(v) => v,
Err(e) => return Err(::bsql_core::error::DecodeError::with_source(
#field_str, "uuid", "invalid data", e,
)),
})
}
};
},
_ => {
let col_idx_lit = 0usize;
let decode = gen_nullable_decode(col_idx_lit, inner_type);
quote! {
let #field_name = {
let _bsql_len = i32::from_be_bytes([
_bsql_data[_bsql_pos], _bsql_data[_bsql_pos + 1],
_bsql_data[_bsql_pos + 2], _bsql_data[_bsql_pos + 3],
]);
_bsql_pos += 4;
if _bsql_len < 0 {
None
} else {
let _end = _bsql_pos + _bsql_len as usize;
let _raw = &_bsql_data[_bsql_pos.._end];
_bsql_pos = _end;
let mut _bsql_tmp = Vec::with_capacity(6 + _raw.len());
_bsql_tmp.extend_from_slice(&1i16.to_be_bytes());
_bsql_tmp.extend_from_slice(&(_raw.len() as i32).to_be_bytes());
_bsql_tmp.extend_from_slice(_raw);
let _bsql_row = ::bsql_core::driver::PgDataRow::new(&_bsql_tmp)
.map_err(|e| ::bsql_core::error::DecodeError::with_source(
#field_str, "decode", "invalid data", e,
))?;
let row = &_bsql_row;
#decode
}
};
}
}
}
}
fn gen_column_count_check(validation: &ValidationResult) -> TokenStream {
let expected = validation.columns.len();
if expected == 0 {
return TokenStream::new();
}
quote! {
if row.column_count() < #expected {
return Err(::bsql_core::error::DecodeError::column_count(
#expected,
row.column_count(),
));
}
}
}
fn gen_raw_column_count_check(validation: &ValidationResult) -> TokenStream {
let expected = validation.columns.len() as i16;
if expected == 0 {
return TokenStream::new();
}
quote! {
{
let _bsql_num_cols = i16::from_be_bytes([_bsql_data[0], _bsql_data[1]]);
if _bsql_num_cols < #expected {
return Err(::bsql_core::error::DecodeError::column_count(
#expected as usize,
_bsql_num_cols as usize,
));
}
}
}
}
fn gen_row_decode(validation: &ValidationResult) -> TokenStream {
let deduped_names = deduplicate_column_names(&validation.columns);
let fields = deduped_names.iter().enumerate().map(|(i, name)| {
let field_name = format_ident!("{}", name);
let idx = i;
let col = &validation.columns[i];
let decode_expr = gen_column_decode(idx, &col.rust_type);
quote! { #field_name: #decode_expr }
});
quote! { #(#fields),* }
}
fn gen_column_decode(idx: usize, rust_type: &str) -> TokenStream {
if let Some(inner) = rust_type
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix('>'))
{
gen_nullable_decode(idx, inner)
} else {
gen_not_null_decode(idx, rust_type)
}
}
fn gen_not_null_decode_error(col_idx: &str, type_name: &str) -> TokenStream {
quote! {
::bsql_core::error::DecodeError::with_source(
#col_idx,
#type_name,
"NULL or invalid data",
::std::io::Error::new(::std::io::ErrorKind::InvalidData, concat!("expected NOT NULL ", #type_name)),
)
}
}
fn gen_not_null_decode(idx: usize, rust_type: &str) -> TokenStream {
let col_idx = idx.to_string();
match rust_type {
"bool" => {
let err = gen_not_null_decode_error(&col_idx, "bool");
quote! { row.get_bool(#idx).ok_or_else(|| #err)? }
}
"i16" => {
let err = gen_not_null_decode_error(&col_idx, "i16");
quote! { row.get_i16(#idx).ok_or_else(|| #err)? }
}
"i32" => {
let err = gen_not_null_decode_error(&col_idx, "i32");
quote! { row.get_i32(#idx).ok_or_else(|| #err)? }
}
"i64" => {
let err = gen_not_null_decode_error(&col_idx, "i64");
quote! { row.get_i64(#idx).ok_or_else(|| #err)? }
}
"f32" => {
let err = gen_not_null_decode_error(&col_idx, "f32");
quote! { row.get_f32(#idx).ok_or_else(|| #err)? }
}
"f64" => {
let err = gen_not_null_decode_error(&col_idx, "f64");
quote! { row.get_f64(#idx).ok_or_else(|| #err)? }
}
"String" => {
let err = gen_not_null_decode_error(&col_idx, "String");
quote! { row.get_str(#idx).ok_or_else(|| #err)?.to_owned() }
}
"Vec<u8>" => {
let err = gen_not_null_decode_error(&col_idx, "Vec<u8>");
quote! { row.get_bytes(#idx).ok_or_else(|| #err)?.to_vec() }
}
"u32" => {
let err = gen_not_null_decode_error(&col_idx, "u32");
quote! { row.get_i32(#idx).ok_or_else(|| #err)? as u32 }
}
"()" => quote! { () },
_ => gen_feature_gated_decode(idx, rust_type),
}
}
fn gen_decode_match(idx: usize, type_name: &str, decode_expr: TokenStream) -> TokenStream {
let col_idx = idx.to_string();
quote! {
match #decode_expr {
Ok(v) => v,
Err(e) => return Err(::bsql_core::error::DecodeError::with_source(
#col_idx,
#type_name,
"invalid data",
e,
)),
}
}
}
fn gen_feature_gated_decode(idx: usize, rust_type: &str) -> TokenStream {
match rust_type {
"::uuid::Uuid" | "uuid::Uuid" => gen_decode_match(
idx,
"uuid",
quote! {
::bsql_core::driver::decode_uuid_type(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::time::OffsetDateTime" | "time::OffsetDateTime" => gen_decode_match(
idx,
"timestamptz",
quote! {
::bsql_core::driver::decode_timestamptz_time(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::time::PrimitiveDateTime" | "time::PrimitiveDateTime" => gen_decode_match(
idx,
"timestamp",
quote! {
::bsql_core::driver::decode_timestamptz_time(
row.get_raw(#idx).unwrap_or_default()
).map(|odt| ::time::PrimitiveDateTime::new(odt.date(), odt.time()))
},
),
"::time::Date" | "time::Date" => gen_decode_match(
idx,
"date",
quote! {
::bsql_core::driver::decode_date_time(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::time::Time" | "time::Time" => gen_decode_match(
idx,
"time",
quote! {
::bsql_core::driver::decode_time_time(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::chrono::DateTime<::chrono::Utc>"
| "::chrono::DateTime<chrono::Utc>"
| "chrono::DateTime<chrono::Utc>"
| "chrono::DateTime<Utc>" => gen_decode_match(
idx,
"timestamptz",
quote! {
::bsql_core::driver::decode_timestamptz_chrono(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::chrono::NaiveDateTime" | "chrono::NaiveDateTime" => gen_decode_match(
idx,
"timestamp",
quote! {
::bsql_core::driver::decode_timestamptz_chrono(
row.get_raw(#idx).unwrap_or_default()
).map(|dt| dt.naive_utc())
},
),
"::chrono::NaiveDate" | "chrono::NaiveDate" => gen_decode_match(
idx,
"date",
quote! {
::bsql_core::driver::decode_date_chrono(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::chrono::NaiveTime" | "chrono::NaiveTime" => gen_decode_match(
idx,
"time",
quote! {
::bsql_core::driver::decode_time_chrono(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"::rust_decimal::Decimal" | "rust_decimal::Decimal" => gen_decode_match(
idx,
"numeric",
quote! {
::bsql_core::driver::decode_numeric_decimal(
row.get_raw(#idx).unwrap_or_default()
)
},
),
"Vec<bool>" => quote! {
::bsql_core::driver::decode_array_bool(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<i16>" => quote! {
::bsql_core::driver::decode_array_i16(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<i32>" => quote! {
::bsql_core::driver::decode_array_i32(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<i64>" => quote! {
::bsql_core::driver::decode_array_i64(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<f32>" => quote! {
::bsql_core::driver::decode_array_f32(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<f64>" => quote! {
::bsql_core::driver::decode_array_f64(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<String>" => quote! {
::bsql_core::driver::decode_array_str(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<Vec<u8>>" => quote! {
::bsql_core::driver::decode_array_bytea(
row.get_raw(#idx).unwrap_or_default()
).unwrap_or_default()
},
"Vec<::time::OffsetDateTime>" | "Vec<time::OffsetDateTime>" => gen_decode_match(
idx,
"timestamptz[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i64(raw).and_then(|micros_vec| {
let mut out = Vec::with_capacity(micros_vec.len());
for micros in micros_vec {
let buf = micros.to_be_bytes();
out.push(::bsql_core::driver::decode_timestamptz_time(&buf)?);
}
Ok(out)
})
} },
),
"Vec<::time::PrimitiveDateTime>" | "Vec<time::PrimitiveDateTime>" => gen_decode_match(
idx,
"timestamp[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i64(raw).and_then(|micros_vec| {
let mut out = Vec::with_capacity(micros_vec.len());
for micros in micros_vec {
let buf = micros.to_be_bytes();
let odt = ::bsql_core::driver::decode_timestamptz_time(&buf)?;
out.push(::time::PrimitiveDateTime::new(odt.date(), odt.time()));
}
Ok(out)
})
} },
),
"Vec<::time::Date>" | "Vec<time::Date>" => gen_decode_match(
idx,
"date[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i32(raw).and_then(|days_vec| {
let mut out = Vec::with_capacity(days_vec.len());
for days in days_vec {
let buf = days.to_be_bytes();
out.push(::bsql_core::driver::decode_date_time(&buf)?);
}
Ok(out)
})
} },
),
"Vec<::time::Time>" | "Vec<time::Time>" => gen_decode_match(
idx,
"time[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i64(raw).and_then(|micros_vec| {
let mut out = Vec::with_capacity(micros_vec.len());
for micros in micros_vec {
let buf = micros.to_be_bytes();
out.push(::bsql_core::driver::decode_time_time(&buf)?);
}
Ok(out)
})
} },
),
"Vec<::uuid::Uuid>" | "Vec<uuid::Uuid>" => gen_decode_match(
idx,
"uuid[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_bytea(raw).and_then(|elements| {
let mut out = Vec::with_capacity(elements.len());
for bytes in &elements {
out.push(::bsql_core::driver::decode_uuid_type(bytes)?);
}
Ok(out)
})
} },
),
"Vec<::rust_decimal::Decimal>" | "Vec<rust_decimal::Decimal>" => gen_decode_match(
idx,
"numeric[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_bytea(raw).and_then(|elements| {
let mut out = Vec::with_capacity(elements.len());
for bytes in &elements {
out.push(::bsql_core::driver::decode_numeric_decimal(bytes)?);
}
Ok(out)
})
} },
),
"Vec<::chrono::DateTime<::chrono::Utc>>" | "Vec<chrono::DateTime<chrono::Utc>>" => {
gen_decode_match(
idx,
"timestamptz[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i64(raw).and_then(|micros_vec| {
let mut out = Vec::with_capacity(micros_vec.len());
for micros in micros_vec {
let buf = micros.to_be_bytes();
out.push(::bsql_core::driver::decode_timestamptz_chrono(&buf)?);
}
Ok(out)
})
} },
)
}
"Vec<::chrono::NaiveDateTime>" | "Vec<chrono::NaiveDateTime>" => gen_decode_match(
idx,
"timestamp[]",
quote! { {
let raw = row.get_raw(#idx).unwrap_or_default();
::bsql_core::driver::decode_array_i64(raw).and_then(|micros_vec| {
let mut out = Vec::with_capacity(micros_vec.len());
for micros in micros_vec {
let buf = micros.to_be_bytes();
let dt = ::bsql_core::driver::decode_timestamptz_chrono(&buf)?;
out.push(dt.naive_utc());
}
Ok(out)
})
} },
),
_ => {
quote! { {
let _raw = row.get_raw(#idx).unwrap_or_default();
compile_error!(concat!("bsql: unsupported type for decode: ", #rust_type))
} }
}
}
}
fn gen_nullable_decode(idx: usize, inner_type: &str) -> TokenStream {
match inner_type {
"bool" => quote! { row.get_bool(#idx) },
"i16" => quote! { row.get_i16(#idx) },
"i32" => quote! { row.get_i32(#idx) },
"i64" => quote! { row.get_i64(#idx) },
"f32" => quote! { row.get_f32(#idx) },
"f64" => quote! { row.get_f64(#idx) },
"String" => quote! { row.get_str(#idx).map(|s| s.to_owned()) },
"Vec<u8>" => quote! { row.get_bytes(#idx).map(|b| b.to_vec()) },
"u32" => quote! { row.get_i32(#idx).map(|v| v as u32) },
_ => {
let decode = gen_feature_gated_decode(idx, inner_type);
quote! { {
if row.is_null(#idx) {
None
} else {
Some(#decode)
}
} }
}
}
}
fn gen_constructor(parsed: &ParsedQuery) -> TokenStream {
let executor_name = executor_struct_name(parsed);
let field_inits = parsed.params.iter().map(|p| {
let name = param_ident(&p.name);
quote! { #name }
});
quote! { #executor_name { #(#field_inits,)* _marker: ::std::marker::PhantomData } }
}
fn inject_lifetime(type_str: &str) -> TokenStream {
match syn::parse_str::<syn::Type>(type_str) {
Ok(ty) => {
let rewritten = add_lifetime_to_refs(ty);
quote! { #rewritten }
}
Err(_) => {
let msg = format!("internal error: cannot parse type `{type_str}`");
quote! { compile_error!(#msg) }
}
}
}
fn add_lifetime_to_refs(ty: syn::Type) -> syn::Type {
match ty {
syn::Type::Reference(mut r) => {
if r.lifetime.is_none() {
r.lifetime = Some(syn::Lifetime::new("'_bsql", proc_macro2::Span::call_site()));
}
r.elem = Box::new(add_lifetime_to_refs(*r.elem));
syn::Type::Reference(r)
}
syn::Type::Slice(mut s) => {
s.elem = Box::new(add_lifetime_to_refs(*s.elem));
syn::Type::Slice(s)
}
syn::Type::Path(mut p) => {
for seg in &mut p.path.segments {
if let syn::PathArguments::AngleBracketed(args) = &mut seg.arguments {
for arg in &mut args.args {
if let syn::GenericArgument::Type(inner) = arg {
*inner = add_lifetime_to_refs(inner.clone());
}
}
}
}
syn::Type::Path(p)
}
other => other,
}
}
fn parse_result_type(type_str: &str) -> TokenStream {
match syn::parse_str::<syn::Type>(type_str) {
Ok(ty) => quote! { #ty },
Err(_) => {
let msg = format!("internal error: cannot parse type `{type_str}`");
quote! { compile_error!(#msg) }
}
}
}
fn deduplicate_column_names(columns: &[crate::validate::ColumnInfo]) -> Vec<String> {
let names: Vec<String> = columns
.iter()
.enumerate()
.map(|(i, col)| sanitize_column_name(&col.name, i))
.collect();
let mut final_names: Vec<String> = Vec::with_capacity(names.len());
for name in &names {
let mut candidate = name.clone();
let mut suffix = 1u32;
while final_names.contains(&candidate) {
candidate = format!("{name}_{suffix}");
suffix += 1;
}
final_names.push(candidate);
}
final_names
}
fn result_struct_name(parsed: &ParsedQuery) -> proc_macro2::Ident {
format_ident!("BsqlResult_{}", &parsed.statement_name)
}
fn executor_struct_name(parsed: &ParsedQuery) -> proc_macro2::Ident {
format_ident!("BsqlExecutor_{}", &parsed.statement_name)
}
fn stream_struct_name(parsed: &ParsedQuery) -> proc_macro2::Ident {
format_ident!("BsqlStream_{}", &parsed.statement_name)
}
const RUST_KEYWORDS: &[&str] = &[
"as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
"false", "fn", "for", "gen", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut",
"pub", "raw", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
"type", "unsafe", "use", "where", "while", "yield",
];
fn sanitize_param_name(name: &str) -> String {
if RUST_KEYWORDS.contains(&name) {
format!("{name}_")
} else {
name.to_owned()
}
}
fn param_ident(name: &str) -> proc_macro2::Ident {
format_ident!("{}", sanitize_param_name(name))
}
fn sanitize_column_name(name: &str, index: usize) -> String {
if name == "?column?" || name.is_empty() {
return format!("col_{index}");
}
let sanitized: String = name
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '_' {
c
} else {
'_'
}
})
.collect();
let sanitized = if sanitized.starts_with(|c: char| c.is_ascii_digit()) {
format!("col_{sanitized}")
} else {
sanitized
};
if RUST_KEYWORDS.contains(&sanitized.as_str()) {
format!("{sanitized}_")
} else {
sanitized
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parse::parse_query;
use crate::validate::ColumnInfo;
fn make_validation(columns: Vec<ColumnInfo>) -> ValidationResult {
ValidationResult {
columns,
param_pg_oids: smallvec::smallvec![],
param_is_pg_enum: smallvec::smallvec![],
rewritten_sql: None,
#[cfg(feature = "explain")]
explain_plan: None,
}
}
fn col(name: &str, rust_type: &str) -> ColumnInfo {
ColumnInfo {
name: name.into(),
pg_oid: 0,
pg_type_name: String::new(),
is_nullable: false,
rust_type: rust_type.into(),
}
}
#[test]
fn generates_result_struct_with_fields() {
let parsed = parse_query("SELECT id, name FROM users WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("pub id : i32"),
"missing id field: {code_str}"
);
assert!(
code_str.contains("pub name : String"),
"missing name field: {code_str}"
);
}
#[test]
fn generates_nullable_field_as_option() {
let parsed = parse_query("SELECT bio FROM users WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("bio", "Option<String>")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("Option < String >") || code_str.contains("Option<String>"),
"missing Option<String>: {code_str}"
);
}
#[test]
fn generates_fetch_one_method() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("fetch_one"),
"missing fetch_one: {code_str}"
);
assert!(
code_str.contains("fn fetch_all"),
"missing fetch_all method: {code_str}"
);
assert!(
code_str.contains("fn fetch_all"),
"missing fetch_all (renamed to fetch): {code_str}"
);
assert!(
code_str.contains("fetch_optional"),
"missing fetch_optional: {code_str}"
);
assert!(code_str.contains("execute"), "missing execute: {code_str}");
assert!(
!code_str.contains("fn run"),
"run alias should be removed: {code_str}"
);
assert!(
!code_str.contains("fn get (self , executor"),
"get alias should be removed: {code_str}"
);
assert!(
!code_str.contains("fn stream"),
"stream alias should be removed: {code_str}"
);
}
#[test]
fn no_params_generates_unit_struct() {
let parsed = parse_query("SELECT 1").unwrap();
let validation = make_validation(vec![col("col_0", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("struct BsqlExecutor_"),
"missing executor: {code_str}"
);
}
#[test]
fn execute_only_query_has_no_result_struct() {
let parsed = parse_query("UPDATE t SET a = $a: i32 WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("BsqlResult_"),
"should not have result struct: {code_str}"
);
assert!(code_str.contains("execute"), "missing execute: {code_str}");
}
#[test]
fn positional_sql_in_generated_code() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("$1"),
"should contain positional $1: {code_str}"
);
assert!(
!code_str.contains("$id"),
"should not contain named $id: {code_str}"
);
}
#[test]
fn uses_driver_encode_not_tosql() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("bsql_core :: driver :: Encode"),
"should use bsql_core::driver::Encode: {code_str}"
);
assert!(
!code_str.contains("ToSql"),
"should not use ToSql: {code_str}"
);
}
#[test]
fn uses_typed_getters_not_row_get() {
let parsed = parse_query("SELECT id, name FROM t WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("get_i32"),
"should use get_i32 for i32 column: {code_str}"
);
assert!(
code_str.contains("get_str"),
"should use get_str for String column: {code_str}"
);
}
#[test]
fn select_uses_query_raw_readonly() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("query_raw_readonly"),
"SELECT should use query_raw_readonly: {code_str}"
);
}
#[test]
fn insert_uses_query_raw_not_readonly() {
let parsed = parse_query("INSERT INTO t (a) VALUES ($a: i32) RETURNING id").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("query_raw_readonly"),
"INSERT should NOT use query_raw_readonly: {code_str}"
);
assert!(
code_str.contains("query_raw"),
"INSERT RETURNING should use query_raw: {code_str}"
);
}
#[test]
fn fetch_one_injects_limit_2() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("LIMIT 2"),
"missing LIMIT 2 in fetch_one: {code_str}"
);
}
#[test]
fn existing_limit_not_doubled() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32 LIMIT 10").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("LIMIT 2"),
"should not add LIMIT 2 when LIMIT exists: {code_str}"
);
}
#[test]
fn for_update_no_limit_injected() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32 FOR UPDATE").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("LIMIT 2"),
"FOR UPDATE query should NOT get LIMIT 2 injected: {code_str}"
);
}
#[test]
fn inject_lifetime_bare_ref_str() {
let ts = inject_lifetime("&str");
let s = ts.to_string();
assert!(s.contains("'_bsql"), "missing lifetime: {s}");
}
#[test]
fn inject_lifetime_bare_ref_slice() {
let ts = inject_lifetime("&[u8]");
let s = ts.to_string();
assert!(s.contains("'_bsql"), "missing lifetime: {s}");
}
#[test]
fn inject_lifetime_no_ref_passes_through() {
let ts = inject_lifetime("i32");
let s = ts.to_string();
assert!(!s.contains("'_bsql"), "i32 should have no lifetime: {s}");
}
#[test]
fn duplicate_column_names_deduplicated() {
let columns = vec![col("id", "i32"), col("id", "i32"), col("name", "String")];
let names = deduplicate_column_names(&columns);
assert_eq!(names, vec!["id", "id_1", "name"]);
}
#[test]
fn three_duplicate_columns() {
let columns = vec![col("id", "i32"), col("id", "i32"), col("id", "i32")];
let names = deduplicate_column_names(&columns);
assert_eq!(names, vec!["id", "id_1", "id_2"]);
}
#[test]
fn sanitize_unnamed_column() {
assert_eq!(sanitize_column_name("?column?", 0), "col_0");
}
#[test]
fn sanitize_column_keyword_type() {
assert_eq!(sanitize_column_name("type", 0), "type_");
}
#[test]
fn sanitize_param_keyword() {
assert_eq!(sanitize_param_name("type"), "type_");
assert_eq!(sanitize_param_name("fn"), "fn_");
}
#[test]
fn sanitize_param_non_keyword() {
assert_eq!(sanitize_param_name("id"), "id");
}
#[test]
fn sanitize_raw_keyword() {
assert_eq!(sanitize_param_name("raw"), "raw_");
assert_eq!(sanitize_column_name("raw", 0), "raw_");
}
#[test]
fn not_null_decode_uses_ok_or_else() {
let parsed = parse_query("SELECT id FROM t WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("unwrap_or_default"),
"should not use unwrap_or_default for NOT NULL decode: {code_str}"
);
assert!(
code_str.contains("ok_or_else"),
"should use ok_or_else for NOT NULL decode: {code_str}"
);
}
#[test]
fn timestamp_decode_has_primitive_date_time() {
let parsed = parse_query("SELECT ts FROM t WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("ts", "::time::PrimitiveDateTime")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("decode_timestamptz_time"),
"PrimitiveDateTime should use timestamptz decode: {code_str}"
);
assert!(
code_str.contains("PrimitiveDateTime"),
"should reference PrimitiveDateTime: {code_str}"
);
}
#[test]
fn timestamp_decode_has_naive_date_time() {
let parsed = parse_query("SELECT ts FROM t WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("ts", "::chrono::NaiveDateTime")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("decode_timestamptz_chrono"),
"NaiveDateTime should use timestamptz decode: {code_str}"
);
assert!(
code_str.contains("naive_utc"),
"should convert to naive_utc: {code_str}"
);
}
#[test]
fn select_returns_owned_structs_no_zero_copy() {
let parsed = parse_query("SELECT id, name FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("BsqlRows_"),
"BsqlRows should be removed: {code_str}"
);
assert!(
!code_str.contains("BsqlSingleRef_"),
"BsqlSingleRef should be removed: {code_str}"
);
assert!(
!code_str.contains("fetch_ref"),
"fetch_ref should not exist: {code_str}"
);
assert!(
code_str.contains("fn fetch_all"),
"missing fetch_all method: {code_str}"
);
assert!(
code_str.contains("fn fetch_one"),
"missing fetch_one method: {code_str}"
);
assert!(
code_str.contains("fn fetch_optional"),
"missing fetch_optional method: {code_str}"
);
assert!(
code_str.contains("Vec <"),
"fetch should return Vec: {code_str}"
);
}
#[test]
fn fetch_ref_not_generated_for_insert() {
let parsed = parse_query("INSERT INTO t (a) VALUES ($a: i32) RETURNING id").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("fetch_ref"),
"INSERT should NOT have fetch_ref: {code_str}"
);
}
#[test]
fn fetch_ref_not_generated_for_execute_only() {
let parsed = parse_query("DELETE FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("fetch_ref"),
"execute-only should NOT have fetch_ref: {code_str}"
);
assert!(
!code_str.contains("BsqlRows_"),
"execute-only should NOT have BsqlRows: {code_str}"
);
}
#[test]
fn generated_code_includes_column_count_check() {
let parsed = parse_query("SELECT id, name FROM t WHERE 1 = $a: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("column_count"),
"generated code should include column_count bounds check: {code_str}"
);
assert!(
code_str.contains("DecodeError"),
"generated code should reference DecodeError for column mismatch: {code_str}"
);
}
#[test]
fn column_count_check_uses_correct_count() {
let parsed = parse_query("SELECT a, b, c FROM t WHERE 1 = $x: i32").unwrap();
let validation = make_validation(vec![col("a", "i32"), col("b", "i32"), col("c", "i32")]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
code_str.contains("3usize") || code_str.contains("3 usize"),
"column_count check should use expected=3: {code_str}"
);
}
#[test]
fn no_column_count_check_for_execute_only() {
let parsed = parse_query("UPDATE t SET a = $a: i32 WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![]);
let code = generate_query_code(&parsed, &validation);
let code_str = code.to_string();
assert!(
!code_str.contains("column_count"),
"execute-only query should NOT have column_count check: {code_str}"
);
}
#[test]
fn gen_column_count_check_zero_columns_returns_empty() {
let validation = make_validation(vec![]);
let check = gen_column_count_check(&validation);
assert!(
check.is_empty(),
"0 columns should produce empty check: {}",
check
);
}
#[test]
fn gen_column_count_check_one_column() {
let validation = make_validation(vec![col("id", "i32")]);
let check = gen_column_count_check(&validation);
let code = check.to_string();
assert!(
code.contains("column_count"),
"1 column should produce a check: {code}"
);
assert!(
code.contains("1usize") || code.contains("1 usize"),
"should check for 1 column: {code}"
);
}
#[test]
fn gen_column_count_check_ten_columns() {
let cols: Vec<ColumnInfo> = (0..10).map(|i| col(&format!("c{i}"), "i32")).collect();
let validation = make_validation(cols);
let check = gen_column_count_check(&validation);
let code = check.to_string();
assert!(
code.contains("10usize") || code.contains("10 usize"),
"should check for 10 columns: {code}"
);
}
#[test]
fn gen_column_count_check_references_decode_error() {
let validation = make_validation(vec![col("a", "i32"), col("b", "String")]);
let check = gen_column_count_check(&validation);
let code = check.to_string();
assert!(
code.contains("DecodeError :: column_count"),
"should reference DecodeError::column_count: {code}"
);
}
#[test]
fn query_as_uses_target_type_not_anonymous_struct() {
let parsed = parse_query("SELECT id, name FROM users WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let target_type: syn::Path = syn::parse_str("User").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("User"),
"should reference target type User: {code_str}"
);
assert!(
!code_str.contains("BsqlResult_"),
"should not generate anonymous result struct: {code_str}"
);
assert!(
!code_str.contains("BsqlRows_"),
"should not generate rows struct: {code_str}"
);
assert!(
!code_str.contains("BsqlSingleRef_"),
"should not generate single ref struct: {code_str}"
);
}
#[test]
fn query_as_generates_fetch_methods() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let target_type: syn::Path = syn::parse_str("MyRow").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("fetch_one"),
"missing fetch_one: {code_str}"
);
assert!(
code_str.contains("fn fetch_all"),
"missing fetch_all: {code_str}"
);
assert!(
code_str.contains("fetch_optional"),
"missing fetch_optional: {code_str}"
);
assert!(code_str.contains("execute"), "missing execute: {code_str}");
assert!(
!code_str.contains("fn run"),
"run alias should be removed: {code_str}"
);
assert!(code_str.contains("fn defer"), "missing defer: {code_str}");
}
#[test]
fn query_as_with_module_path() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let target_type: syn::Path = syn::parse_str("crate::models::User").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("crate :: models :: User"),
"should use fully qualified path: {code_str}"
);
}
#[test]
fn query_as_no_columns_has_no_fetch() {
let parsed = parse_query("UPDATE t SET a = $a: i32 WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![]);
let target_type: syn::Path = syn::parse_str("User").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
!code_str.contains("fetch_one"),
"execute-only should not have fetch_one: {code_str}"
);
assert!(code_str.contains("execute"), "missing execute: {code_str}");
}
#[test]
fn query_as_has_column_count_check() {
let parsed = parse_query("SELECT id, name FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32"), col("name", "String")]);
let target_type: syn::Path = syn::parse_str("User").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("column_count"),
"should have column count check: {code_str}"
);
}
#[test]
fn query_as_nullable_column() {
let parsed = parse_query("SELECT email FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("email", "Option<String>")]);
let target_type: syn::Path = syn::parse_str("UserEmail").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("UserEmail"),
"should use target type: {code_str}"
);
assert!(
code_str.contains("get_str"),
"should decode String column: {code_str}"
);
}
#[test]
fn query_as_injects_limit_2() {
let parsed = parse_query("SELECT id FROM t WHERE id = $id: i32").unwrap();
let validation = make_validation(vec![col("id", "i32")]);
let target_type: syn::Path = syn::parse_str("Row").unwrap();
let code = generate_query_as_code(&parsed, &validation, &target_type);
let code_str = code.to_string();
assert!(
code_str.contains("LIMIT 2"),
"missing LIMIT 2 in query_as fetch_one: {code_str}"
);
}
}