#![forbid(unsafe_code)]
extern crate proc_macro;
mod codegen;
#[cfg(feature = "sqlite")]
mod codegen_sqlite;
mod connection;
mod dynamic;
#[cfg(feature = "explain")]
mod explain;
mod offline;
mod parse;
mod pg_enum;
mod sort_enum;
mod sql_norm;
mod stmt_name;
mod suggest;
mod test_macro;
pub(crate) mod types;
#[cfg(feature = "sqlite")]
mod types_sqlite;
mod validate;
#[cfg(feature = "sqlite")]
mod validate_sqlite;
use proc_macro::TokenStream;
/// Function-like macro `query!("SQL")`.
///
/// Parses and validates the SQL string at compile time and expands to the
/// code generated for it. Any failure is turned into a compile error via
/// `syn::Error::to_compile_error` rather than a panic.
#[proc_macro]
pub fn query(input: TokenStream) -> TokenStream {
    query_impl(proc_macro2::TokenStream::from(input))
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn query_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
let sql = extract_sql(input)?;
let parsed = parse::parse_query(&sql)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
#[cfg(feature = "sqlite")]
{
let backend = connection::detect_backend()
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
if backend == Some(connection::Backend::Sqlite) {
return query_impl_sqlite(parsed);
}
}
query_impl_postgres(parsed)
}
fn query_impl_postgres(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
if parsed.sort_placeholder.is_some() {
return query_impl_sort(parsed);
}
if parsed.optional_clauses.is_empty() {
let validation = if offline::is_offline() {
offline::lookup_cached_validation(&parsed)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
} else {
let result = connection::with_connection(|conn| {
validate::validate_query_with_suggestions(&parsed, conn)
})?;
offline::write_cache(&parsed, &result);
result
};
validate::check_param_types(&parsed, &validation)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
Ok(codegen::generate_query_code(&parsed, &validation))
} else {
let validation = if offline::is_offline() {
offline::lookup_cached_validation(&parsed)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
} else {
let result = connection::with_connection(|conn| {
let variants = dynamic::expand_variants(&parsed)?;
validate::validate_variants(&variants, &parsed, conn)
})?;
offline::write_cache(&parsed, &result);
result
};
Ok(codegen::generate_dynamic_query_code(&parsed, &validation))
}
}
/// SQLite code path for `query!`.
///
/// Queries with a sort placeholder are delegated to `query_impl_sqlite_sort`.
/// Otherwise the statement is validated, or in the presence of optional
/// clauses every expanded variant is. Validation runs against the SQLite
/// connection (with the result written to the offline cache) or, in offline
/// mode, is read from the cache. SQLite-specific code is then generated.
#[cfg(feature = "sqlite")]
fn query_impl_sqlite(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
    if parsed.sort_placeholder.is_some() {
        return query_impl_sqlite_sort(parsed);
    }

    let call_site = proc_macro2::Span::call_site();
    let is_dynamic = !parsed.optional_clauses.is_empty();

    if is_dynamic {
        // Optional clauses: validate each SQL variant the clauses expand to.
        let validation = if offline::is_offline() {
            offline::lookup_cached_validation(&parsed)
                .map_err(|msg| syn::Error::new(call_site, msg))?
        } else {
            let fresh = connection::with_sqlite_connection(|conn| {
                let variants = dynamic::expand_variants(&parsed)?;
                validate_sqlite::validate_variants_sqlite(&variants, &parsed, conn)
            })?;
            offline::write_cache(&parsed, &fresh);
            fresh
        };
        return Ok(codegen_sqlite::generate_dynamic_sqlite_query_code(
            &parsed,
            &validation,
        ));
    }

    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(call_site, msg))?
    } else {
        let fresh = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;
        offline::write_cache(&parsed, &fresh);
        fresh
    };
    Ok(codegen_sqlite::generate_sqlite_query_code(
        &parsed,
        &validation,
    ))
}
/// SQLite code path for `query!` when the SQL contains a sort placeholder.
///
/// Validation runs against a copy of the query in which the sort marker is
/// replaced by the constant `1` (`{SORT}` in `positional_sql`, `{sort}` in
/// `normalized_sql`), so the statement can be checked without a concrete sort
/// expression. The offline cache is read and written with the original
/// `parsed`. Code generation receives the original query plus the sort enum
/// name.
///
/// # Errors
///
/// Returns a `syn::Error` in three cases:
/// - online validation fails;
/// - the offline cache lookup fails;
/// - `parsed` carries no sort placeholder (an internal-invariant guard).
#[cfg(feature = "sqlite")]
fn query_impl_sqlite_sort(
    parsed: parse::ParsedQuery,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let call_site = proc_macro2::Span::call_site();

    // `query_impl_sqlite` only routes here when `sort_placeholder` is `Some`.
    // Report a violated invariant as a compile error instead of a macro panic.
    let sort_placeholder = parsed.sort_placeholder.as_ref().ok_or_else(|| {
        syn::Error::new(
            call_site,
            "internal error: sort query handler called without a sort placeholder",
        )
    })?;
    let sort_enum_name = &sort_placeholder.enum_name;

    let validation = if offline::is_offline() {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(call_site, msg))?
    } else {
        // Copy of the query with the sort marker replaced by a constant. It is
        // only needed for live validation, so it is built on this path only.
        // NOTE(review): the marker is `{SORT}` in positional SQL but `{sort}`
        // in normalized SQL; presumably normalization lowercases it. Confirm.
        let dummy_parsed = parse::ParsedQuery {
            normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
            positional_sql: parsed.positional_sql.replace("{SORT}", "1"),
            params: parsed.params.clone(),
            kind: parsed.kind,
            statement_name: parsed.statement_name.clone(),
            optional_clauses: parsed.optional_clauses.clone(),
            sort_placeholder: None,
        };
        let result = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&dummy_parsed, conn)
        })?;
        offline::write_cache(&parsed, &result);
        result
    };

    Ok(codegen_sqlite::generate_sort_sqlite_query_code(
        &parsed,
        &validation,
        sort_enum_name,
    ))
}
fn query_impl_sort(parsed: parse::ParsedQuery) -> Result<proc_macro2::TokenStream, syn::Error> {
let sort_placeholder = parsed.sort_placeholder.as_ref().unwrap();
let sort_enum_name = &sort_placeholder.enum_name;
let dummy_sql = parsed.positional_sql.replace("{SORT}", "1");
let dummy_parsed = parse::ParsedQuery {
normalized_sql: parsed.normalized_sql.replace("{sort}", "1"),
positional_sql: dummy_sql,
params: parsed.params.clone(),
kind: parsed.kind,
statement_name: parsed.statement_name.clone(),
optional_clauses: parsed.optional_clauses.clone(),
sort_placeholder: None,
};
let validation = if offline::is_offline() {
offline::lookup_cached_validation(&parsed)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
} else {
let result = connection::with_connection(|conn| {
validate::validate_query_with_suggestions(&dummy_parsed, conn)
})?;
let sorts_dir = std::env::var("CARGO_MANIFEST_DIR")
.map(|d| std::path::PathBuf::from(d).join(".bsql").join("sorts"))
.ok();
if let Some(sorts_dir) = sorts_dir {
let path = sorts_dir.join(format!("{}.txt", sort_enum_name));
if let Ok(content) = std::fs::read_to_string(&path) {
connection::with_connection(|conn| {
for fragment in content.lines().filter(|l| !l.is_empty()) {
let test_sql = parsed.positional_sql.replace("{SORT}", fragment);
let prepare = format!("PREPARE __bsql_sort_check AS {}", test_sql);
if let Err(e) = conn.simple_query(&prepare) {
return Err(format!("sort fragment '{}' is invalid: {}", fragment, e));
}
let _ = conn.simple_query("DEALLOCATE __bsql_sort_check");
}
Ok(())
})?;
}
}
offline::write_cache(&parsed, &result);
result
};
validate::check_param_types(&parsed, &validation)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
Ok(codegen::generate_sort_query_code(
&parsed,
&validation,
sort_enum_name,
))
}
/// Function-like macro `query_as!(Type, "SQL")`.
///
/// Validates the SQL at compile time and expands to the code generated for
/// the given target type path. Failures become compile errors via
/// `syn::Error::to_compile_error` instead of panicking.
#[proc_macro]
pub fn query_as(input: TokenStream) -> TokenStream {
    query_as_impl(proc_macro2::TokenStream::from(input))
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Parsed arguments of `query_as!`: a type path, a comma, then a SQL string literal.
struct QueryAsArgs {
    /// Path of the type the query results are generated for, e.g. `crate::models::User`.
    target_type: syn::Path,
    /// The separating comma; kept only so the grammar is enforced.
    _comma: syn::Token![,],
    /// The SQL string literal.
    sql: syn::LitStr,
}

impl syn::parse::Parse for QueryAsArgs {
    /// Parses `Path , "SQL"` in that order; any deviation is a parse error.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let target_type = input.parse()?;
        let _comma = input.parse()?;
        let sql = input.parse()?;
        Ok(Self {
            target_type,
            _comma,
            sql,
        })
    }
}
fn extract_type_and_sql(
input: proc_macro2::TokenStream,
) -> Result<(syn::Path, String), syn::Error> {
let args: QueryAsArgs = syn::parse2(input)?;
Ok((args.target_type, args.sql.value()))
}
fn query_as_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
let (target_type, sql) = extract_type_and_sql(input)?;
let parsed = parse::parse_query(&sql)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
if parsed.sort_placeholder.is_some() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"query_as! does not support $[sort: ...] placeholders; use query! instead",
));
}
if !parsed.optional_clauses.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"query_as! does not support optional clauses; use query! instead",
));
}
#[cfg(feature = "sqlite")]
{
let backend = connection::detect_backend()
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
if backend == Some(connection::Backend::Sqlite) {
return query_as_impl_sqlite(parsed, target_type);
}
}
query_as_impl_postgres(parsed, target_type)
}
fn query_as_impl_postgres(
parsed: parse::ParsedQuery,
target_type: syn::Path,
) -> Result<proc_macro2::TokenStream, syn::Error> {
let validation = if offline::is_offline() {
offline::lookup_cached_validation(&parsed)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
} else {
let result = connection::with_connection(|conn| {
validate::validate_query_with_suggestions(&parsed, conn)
})?;
offline::write_cache(&parsed, &result);
result
};
validate::check_param_types(&parsed, &validation)
.map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?;
Ok(codegen::generate_query_as_code(
&parsed,
&validation,
&target_type,
))
}
/// SQLite code path for `query_as!`.
///
/// Validates the statement against the SQLite connection (writing the result
/// to the offline cache) or, in offline mode, reads the cached validation. It
/// then generates SQLite `query_as!` code for `target_type`.
#[cfg(feature = "sqlite")]
fn query_as_impl_sqlite(
    parsed: parse::ParsedQuery,
    target_type: syn::Path,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let validation = if !offline::is_offline() {
        let live = connection::with_sqlite_connection(|conn| {
            validate_sqlite::validate_query_sqlite(&parsed, conn)
        })?;
        offline::write_cache(&parsed, &live);
        live
    } else {
        offline::lookup_cached_validation(&parsed)
            .map_err(|msg| syn::Error::new(proc_macro2::Span::call_site(), msg))?
    };

    let generated =
        codegen_sqlite::generate_sqlite_query_as_code(&parsed, &validation, &target_type);
    Ok(generated)
}
/// Parses the `query!` input as exactly one string literal and returns its
/// (unescaped) value.
fn extract_sql(input: proc_macro2::TokenStream) -> Result<String, syn::Error> {
    syn::parse2::<syn::LitStr>(input).map(|literal| literal.value())
}
/// Attribute macro `#[pg_enum]`.
///
/// Expansion is delegated to `pg_enum::expand_pg_enum`; an error from it is
/// emitted as a compile error.
#[proc_macro_attribute]
pub fn pg_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
    let expanded = pg_enum::expand_pg_enum(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    );
    match expanded {
        Ok(tokens) => tokens.into(),
        Err(error) => error.to_compile_error().into(),
    }
}
/// Attribute macro `#[sort]`.
///
/// Expansion is delegated to `sort_enum::expand_sort_enum`; an error from it
/// is emitted as a compile error.
#[proc_macro_attribute]
pub fn sort(attr: TokenStream, item: TokenStream) -> TokenStream {
    let expanded = sort_enum::expand_sort_enum(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    );
    match expanded {
        Ok(tokens) => tokens.into(),
        Err(error) => error.to_compile_error().into(),
    }
}
/// Attribute macro `#[test]` provided by this crate.
///
/// Expansion is delegated to `test_macro::expand_test`; an error from it is
/// emitted as a compile error.
#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
    let expanded = test_macro::expand_test(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    );
    match expanded {
        Ok(tokens) => tokens.into(),
        Err(error) => error.to_compile_error().into(),
    }
}
#[cfg(test)]
mod tests {
    use super::{extract_type_and_sql, QueryAsArgs};

    /// Lexes `source` into a token stream; panics on lexer errors (test-only).
    fn tokens(source: &str) -> proc_macro2::TokenStream {
        source.parse().unwrap()
    }

    /// Identifiers of every segment of `path`, in order.
    fn segment_names(path: &syn::Path) -> Vec<String> {
        path.segments
            .iter()
            .map(|segment| segment.ident.to_string())
            .collect()
    }

    #[test]
    fn parse_query_as_args() {
        let args: QueryAsArgs = syn::parse2(tokens("User, \"SELECT id FROM users\"")).unwrap();
        assert_eq!(args.sql.value(), "SELECT id FROM users");
        let names = segment_names(&args.target_type);
        assert_eq!(names.last().map(String::as_str), Some("User"));
    }

    #[test]
    fn parse_query_as_args_module_path() {
        let args: QueryAsArgs =
            syn::parse2(tokens("crate::models::User, \"SELECT id FROM users\"")).unwrap();
        assert_eq!(args.sql.value(), "SELECT id FROM users");
        assert_eq!(
            segment_names(&args.target_type),
            vec!["crate", "models", "User"]
        );
    }

    #[test]
    fn extract_type_and_sql_basic() {
        let input = tokens("Row, \"SELECT name FROM t WHERE id = $id: i32\"");
        let (path, sql) = extract_type_and_sql(input).unwrap();
        assert_eq!(sql, "SELECT name FROM t WHERE id = $id: i32");
        let names = segment_names(&path);
        assert_eq!(names.last().map(String::as_str), Some("Row"));
    }

    #[test]
    fn extract_type_and_sql_missing_comma_fails() {
        let outcome = extract_type_and_sql(tokens("User \"SELECT id FROM t\""));
        assert!(outcome.is_err());
    }

    #[test]
    fn extract_type_and_sql_missing_sql_fails() {
        let outcome = extract_type_and_sql(tokens("User,"));
        assert!(outcome.is_err());
    }
}