use once_cell::sync::OnceCell;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::info;
/// Runtime override for SQL logging.
///
/// Stays uninitialized until [`set_sql_logging`] is first called; while it is unset,
/// [`is_sql_logging_enabled`] falls back to the build-profile default.
static SQL_LOG_OVERRIDE: OnceCell<AtomicBool> = OnceCell::new();

/// Turns SQL query logging on or off for the whole process, replacing the
/// build-profile default used by [`is_sql_logging_enabled`].
pub fn set_sql_logging(enabled: bool) {
    // First call creates the flag; every call (including the first) then stores the value.
    let flag = SQL_LOG_OVERRIDE.get_or_init(|| AtomicBool::new(enabled));
    // `Relaxed` is enough: the flag is a standalone on/off switch and is not used to
    // publish any other data between threads.
    flag.store(enabled, Ordering::Relaxed);
}
/// Reports whether SQL query logging is currently enabled.
///
/// Returns the value stored by [`set_sql_logging`]. If that has never been called, it
/// defaults to `cfg!(debug_assertions)`: on in builds with debug assertions, off otherwise.
#[inline]
pub fn is_sql_logging_enabled() -> bool {
    SQL_LOG_OVERRIDE
        .get()
        .map_or(cfg!(debug_assertions), |flag| flag.load(Ordering::Relaxed))
}
/// Emits `sql` as an `INFO` event on the `sql` tracing target, prefixed with the
/// caller's `[file:line:column]`.
///
/// Does nothing when [`is_sql_logging_enabled`] returns `false`. Because the function
/// is `#[track_caller]`, the reported location is that of the code calling it, not a
/// line inside this module.
#[track_caller]
pub fn log_query(sql: &str) {
    if !is_sql_logging_enabled() {
        return;
    }

    let caller = std::panic::Location::caller();
    info!(
        target: "sql",
        "[{}:{}:{}]\n sql | {}",
        caller.file(),
        caller.line(),
        caller.column(),
        sql,
    );
}
/// Logs the SQL of `$q` via [`log_query`](crate::sqlxhelper::logging::log_query),
/// then evaluates to `$q.fetch_all($exec)`.
///
/// `$q` is evaluated exactly once, and `$exec` only after the SQL has been logged.
/// Logging happens before `fetch_all` is called. If that call returns a lazy future
/// (as sqlx's does), the SQL is therefore logged when this macro is evaluated, not
/// when the future is awaited.
///
/// `$q` needs a `.sql()` method resolvable at the call site.
/// NOTE(review): for sqlx query types this comes from the `sqlx::Execute` trait,
/// which may need to be in scope.
///
/// A trailing comma after `$q` is accepted.
#[macro_export]
macro_rules! sqlx_fetch_all {
    ($exec:expr, $q:expr $(,)?) => {{
        let __q = $q;
        $crate::sqlxhelper::logging::log_query(__q.sql());
        __q.fetch_all($exec)
    }};
}
/// Logs the SQL of `$q` via [`log_query`](crate::sqlxhelper::logging::log_query),
/// then evaluates to `$q.fetch_one($exec)`.
///
/// `$q` is evaluated exactly once, and `$exec` only after the SQL has been logged.
/// Logging happens before `fetch_one` is called. If that call returns a lazy future
/// (as sqlx's does), the SQL is therefore logged when this macro is evaluated, not
/// when the future is awaited.
///
/// `$q` needs a `.sql()` method resolvable at the call site.
/// NOTE(review): for sqlx query types this comes from the `sqlx::Execute` trait,
/// which may need to be in scope.
///
/// A trailing comma after `$q` is accepted.
#[macro_export]
macro_rules! sqlx_fetch_one {
    ($exec:expr, $q:expr $(,)?) => {{
        let __q = $q;
        $crate::sqlxhelper::logging::log_query(__q.sql());
        __q.fetch_one($exec)
    }};
}
/// Logs the SQL of `$q` via [`log_query`](crate::sqlxhelper::logging::log_query),
/// then evaluates to `$q.fetch_optional($exec)`.
///
/// `$q` is evaluated exactly once, and `$exec` only after the SQL has been logged.
/// Logging happens before `fetch_optional` is called. If that call returns a lazy
/// future (as sqlx's does), the SQL is therefore logged when this macro is evaluated,
/// not when the future is awaited.
///
/// `$q` needs a `.sql()` method resolvable at the call site.
/// NOTE(review): for sqlx query types this comes from the `sqlx::Execute` trait,
/// which may need to be in scope.
///
/// A trailing comma after `$q` is accepted.
#[macro_export]
macro_rules! sqlx_fetch_optional {
    ($exec:expr, $q:expr $(,)?) => {{
        let __q = $q;
        $crate::sqlxhelper::logging::log_query(__q.sql());
        __q.fetch_optional($exec)
    }};
}
/// Logs the SQL of `$q` via [`log_query`](crate::sqlxhelper::logging::log_query),
/// then evaluates to `$q.execute($exec)`.
///
/// `$q` is evaluated exactly once, and `$exec` only after the SQL has been logged.
/// Logging happens before `execute` is called. If that call returns a lazy future
/// (as sqlx's does), the SQL is therefore logged when this macro is evaluated, not
/// when the future is awaited.
///
/// `$q` needs a `.sql()` method resolvable at the call site.
/// NOTE(review): for sqlx query types this comes from the `sqlx::Execute` trait,
/// which may need to be in scope.
///
/// A trailing comma after `$q` is accepted.
#[macro_export]
macro_rules! sqlx_execute {
    ($exec:expr, $q:expr $(,)?) => {{
        let __q = $q;
        $crate::sqlxhelper::logging::log_query(__q.sql());
        __q.execute($exec)
    }};
}
/// Logs the SQL of `$q` via [`log_query`](crate::sqlxhelper::logging::log_query),
/// then evaluates to `$q.fetch_one($exec)`.
///
/// The expansion is identical to `sqlx_fetch_one!`; only the name differs. It is
/// presumably meant for scalar queries (e.g. sqlx `query_scalar`), whose `fetch_one`
/// yields the single column value. NOTE(review): confirm intended usage.
///
/// `$q` is evaluated exactly once, and `$exec` only after the SQL has been logged.
/// If `fetch_one` returns a lazy future, the SQL is logged when this macro is
/// evaluated, not when the future is awaited.
///
/// A trailing comma after `$q` is accepted.
#[macro_export]
macro_rules! sqlx_fetch_scalar {
    ($exec:expr, $q:expr $(,)?) => {{
        let __q = $q;
        $crate::sqlxhelper::logging::log_query(__q.sql());
        __q.fetch_one($exec)
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use std::io;
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Serializes tests that read or write the process-wide SQL logging override,
    /// because the test harness runs tests on parallel threads.
    static LOG_STATE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    /// Exclusive access to the SQL logging override for the duration of one test.
    ///
    /// On drop it switches logging back off. This also happens during unwinding after
    /// a failed assertion, so a failing test cannot leak an enabled override into later
    /// tests.
    struct LogStateGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for LogStateGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so the reset is still serialized.
            set_sql_logging(false);
        }
    }

    fn lock_log_state() -> LogStateGuard {
        // A panicking test poisons the mutex. The guarded `()` carries no state, so
        // recovering the guard is always sound.
        let lock = LOG_STATE_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        LogStateGuard { _lock: lock }
    }

    /// `io::Write` sink that appends everything written into a shared byte buffer.
    struct CaptureWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for CaptureWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// `MakeWriter` that hands out a [`CaptureWriter`] over the same shared buffer.
    struct MakeCaptureWriter(Arc<Mutex<Vec<u8>>>);

    impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for MakeCaptureWriter {
        type Writer = CaptureWriter;

        fn make_writer(&'a self) -> CaptureWriter {
            CaptureWriter(Arc::clone(&self.0))
        }
    }

    /// Builds a plain-text (no ANSI) fmt subscriber and returns it together with the
    /// buffer it writes into.
    fn make_capture_subscriber() -> (impl tracing::Subscriber, Arc<Mutex<Vec<u8>>>) {
        let buf = Arc::new(Mutex::new(Vec::<u8>::new()));
        let subscriber = tracing_subscriber::fmt()
            .with_writer(MakeCaptureWriter(Arc::clone(&buf)))
            .with_ansi(false)
            .finish();
        (subscriber, buf)
    }

    #[test]
    fn test_sql_logging_override() {
        let _guard = lock_log_state();
        set_sql_logging(true);
        assert!(is_sql_logging_enabled(), "should be enabled");
        set_sql_logging(false);
        assert!(!is_sql_logging_enabled(), "should be disabled");
    }

    #[test]
    fn test_log_query_output_contains_location() {
        let _guard = lock_log_state();
        set_sql_logging(true);
        let (sub, buf) = make_capture_subscriber();
        tracing::subscriber::with_default(sub, || {
            log_query("SELECT 99");
        });
        let output = String::from_utf8_lossy(&buf.lock().unwrap()).into_owned();
        // `log_query` is `#[track_caller]`, so it reports this file. Compare against
        // `file!()` rather than a hard-coded file name so moving or renaming the file
        // does not break the test.
        assert!(
            output.contains(file!()),
            "output should contain the source file {}; got: {output}",
            file!()
        );
        assert!(
            output.contains("SELECT 99"),
            "output should contain the SQL; got: {output}"
        );
    }

    #[test]
    fn test_log_query_silent_when_disabled() {
        let _guard = lock_log_state();
        set_sql_logging(false);
        let (sub, buf) = make_capture_subscriber();
        tracing::subscriber::with_default(sub, || {
            log_query("SELECT secret");
        });
        let output = String::from_utf8_lossy(&buf.lock().unwrap()).into_owned();
        assert!(
            output.is_empty(),
            "should produce no output when disabled; got: {output}"
        );
    }
}