use std::path::{Path, PathBuf};
use std::time::Duration;
use rusqlite::{Connection, OpenFlags};
use crate::EngineError;
/// Raw text of `sqlite.env` (resolved relative to this source file's directory),
/// embedded at compile time and parsed by [`shared_sqlite_policy`].
const SHARED_SQLITE_POLICY: &str = include_str!("../sqlite.env");
/// SQLite version policy shared across the repository, as parsed from the embedded
/// `sqlite.env` by [`shared_sqlite_policy`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SharedSqlitePolicy {
    /// Value of the `SQLITE_MIN_VERSION` key (e.g. `"3.41.0"`).
    pub minimum_supported_version: String,
    /// Value of the `SQLITE_VERSION` key (e.g. `"3.46.0"`).
    pub repo_dev_version: String,
    /// `.local/sqlite-<repo_dev_version>/bin/sqlite3`, a relative path to a locally
    /// built `sqlite3` binary (presumably relative to the repository root — confirm
    /// against callers).
    pub repo_local_binary_relpath: PathBuf,
}
/// Ensures the process-wide SQLite log callback (`sqlite_log_callback`) is handed to
/// `rusqlite::trace::config_log` at most once, on the first connection open.
#[cfg(feature = "tracing")]
static SQLITE_LOG_INIT: std::sync::Once = std::sync::Once::new();
/// Forwards one message from SQLite's error log (installed via `config_log`) to `tracing`.
///
/// The severity is chosen from the primary result code, i.e. the low byte of `code`:
/// `SQLITE_NOTICE` becomes `info`, `SQLITE_WARNING` becomes `warn`, and every other
/// code becomes `error`. The full (possibly extended) `code` is attached as the
/// `sqlite_error_code` field.
#[cfg(feature = "tracing")]
fn sqlite_log_callback(code: std::os::raw::c_int, msg: &str) {
    use std::os::raw::c_int;

    // Named constants so they can be used directly as match patterns.
    const NOTICE: c_int = rusqlite::ffi::SQLITE_NOTICE as c_int;
    const WARNING: c_int = rusqlite::ffi::SQLITE_WARNING as c_int;

    match code & 0xFF {
        NOTICE => {
            tracing::info!(target: "fathomdb_engine::sqlite", sqlite_error_code = code, "{msg}");
        }
        WARNING => {
            tracing::warn!(target: "fathomdb_engine::sqlite", sqlite_error_code = code, "{msg}");
        }
        _ => {
            tracing::error!(target: "fathomdb_engine::sqlite", sqlite_error_code = code, "{msg}");
        }
    }
}
/// Registers a `SQLITE_TRACE_PROFILE` hook on `conn`. For each profiled statement it
/// emits a `tracing::trace!` event (target `fathomdb_engine::sqlite`) carrying the
/// statement's SQL text and its duration in microseconds.
///
/// Only compiled when the `tracing` feature is enabled in a build with debug
/// assertions. The return code of `sqlite3_trace_v2` is discarded, so a failed
/// registration silently leaves statement profiling off.
#[cfg(all(feature = "tracing", debug_assertions))]
fn install_trace_v2(conn: &Connection) {
    use std::os::raw::{c_int, c_uint, c_void};
    // C callback handed to SQLite. `_ctx` is the context pointer given at registration
    // (null below), so it is unused.
    unsafe extern "C" fn trace_v2_callback(
        event_type: c_uint,
        _ctx: *mut c_void,
        p: *mut c_void,
        x: *mut c_void,
    ) -> c_int {
        // Only profile events are requested at registration; the check keeps the
        // pointer casts below valid if the mask is ever widened.
        if event_type == rusqlite::ffi::SQLITE_TRACE_PROFILE as c_uint {
            // For profile events `p` is the prepared statement and `x` points at an
            // i64 elapsed time in nanoseconds (per the sqlite3_trace_v2 docs).
            let stmt = p.cast::<rusqlite::ffi::sqlite3_stmt>();
            // SAFETY: relies on SQLite passing a valid i64 pointer for profile events.
            let nanos = unsafe { *(x.cast::<i64>()) };
            // SAFETY: `stmt` comes straight from SQLite for this callback.
            let sql_ptr = unsafe { rusqlite::ffi::sqlite3_sql(stmt) };
            // `sqlite3_sql` can return null; skip the event rather than dereference it.
            if !sql_ptr.is_null() {
                // SAFETY: non-null (checked above); assumes SQLite keeps the statement's
                // NUL-terminated SQL text alive while the callback runs.
                let sql = unsafe { std::ffi::CStr::from_ptr(sql_ptr) }.to_string_lossy();
                tracing::trace!(
                    target: "fathomdb_engine::sqlite",
                    sql = %sql,
                    duration_us = nanos / 1000,
                    "sqlite statement profile"
                );
            }
        }
        // The SQLite docs ask trace callbacks to return 0.
        0
    }
    // SAFETY: `conn.handle()` is a live connection handle for the duration of this call;
    // the callback is a plain `extern "C"` fn and the context pointer is null.
    unsafe {
        rusqlite::ffi::sqlite3_trace_v2(
            conn.handle(),
            rusqlite::ffi::SQLITE_TRACE_PROFILE as c_uint,
            Some(trace_v2_callback),
            std::ptr::null_mut(),
        );
    }
}
/// How long each connection keeps retrying (via `Connection::busy_timeout`) when the
/// database is locked by another connection.
const CONNECTION_BUSY_TIMEOUT: Duration = Duration::from_secs(5);

/// Installs the process-wide SQLite log callback the first time any connection is
/// opened. Compiles to a no-op when the `tracing` feature is disabled.
fn init_sqlite_logging_once() {
    #[cfg(feature = "tracing")]
    SQLITE_LOG_INIT.call_once(|| {
        // SAFETY: runs at most once per process, before this module opens a connection
        // on this path. A rejected registration only disables log forwarding, so the
        // result is deliberately ignored (best effort).
        unsafe {
            let _ = rusqlite::trace::config_log(Some(sqlite_log_callback));
        }
    });
}

/// Applies the per-connection settings shared by every open path: the busy timeout,
/// plus (debug builds with `tracing`) the statement-profiling hook.
fn configure_connection(conn: &Connection) -> Result<(), EngineError> {
    conn.busy_timeout(CONNECTION_BUSY_TIMEOUT)?;
    #[cfg(all(feature = "tracing", debug_assertions))]
    install_trace_v2(conn);
    Ok(())
}

/// Opens a read-write connection to the database at `path`, creating the file if it
/// does not exist.
///
/// # Errors
///
/// Returns an [`EngineError`] if SQLite cannot open the database or the busy timeout
/// cannot be set.
pub fn open_connection(path: &Path) -> Result<Connection, EngineError> {
    init_sqlite_logging_once();
    let conn = Connection::open_with_flags(
        path,
        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
    )?;
    configure_connection(&conn)?;
    Ok(conn)
}

/// Opens a read-only connection to an existing database at `path`.
///
/// # Errors
///
/// Returns an [`EngineError`] if SQLite cannot open the database or the busy timeout
/// cannot be set.
pub fn open_readonly_connection(path: &Path) -> Result<Connection, EngineError> {
    init_sqlite_logging_once();
    let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_ONLY)?;
    configure_connection(&conn)?;
    Ok(conn)
}
/// Registers `sqlite-vec` as a SQLite auto-extension so that connections opened
/// afterwards load its functions. Shared by both `*_with_vec` open functions.
///
/// Registration is process-wide. SQLite documents re-registering an entry point that
/// is already on the auto-extension list as a harmless no-op, so repeated calls are fine.
///
/// # Errors
///
/// Propagates a failure code from `sqlite3_auto_extension` instead of ignoring it.
/// Ignoring it would let callers receive a connection that silently lacks the vector
/// functions.
#[cfg(feature = "sqlite-vec")]
fn register_sqlite_vec_auto_extension() -> Result<(), EngineError> {
    // C signature SQLite expects for an extension entry point.
    type ExtensionInit = unsafe extern "C" fn(
        *mut rusqlite::ffi::sqlite3,
        *mut *mut std::os::raw::c_char,
        *const rusqlite::ffi::sqlite3_api_routines,
    ) -> i32;

    // `sqlite3_auto_extension` initializes the SQLite library as a side effect.
    // Depending on the SQLite version, `sqlite3_config(SQLITE_CONFIG_LOG, ...)` (which
    // `config_log` wraps) can be rejected once the library is initialized, and that
    // rejection would be swallowed. Install the log callback first so it is not lost
    // when the process's first connection is opened through a `*_with_vec` function.
    #[cfg(feature = "tracing")]
    SQLITE_LOG_INIT.call_once(|| {
        // SAFETY: runs at most once per process, before any connection is opened on
        // this path. A rejected registration only disables log forwarding, so the
        // result is deliberately ignored (best effort).
        unsafe {
            let _ = rusqlite::trace::config_log(Some(sqlite_log_callback));
        }
    });

    // SAFETY: `sqlite3_vec_init` is sqlite-vec's C extension entry point. The transmute
    // only re-types that function pointer to the extension-init signature that
    // `sqlite3_auto_extension` expects.
    let rc = unsafe {
        rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute::<
            *const (),
            ExtensionInit,
        >(
            sqlite_vec::sqlite3_vec_init as *const ()
        )))
    };
    if rc != rusqlite::ffi::SQLITE_OK {
        return Err(rusqlite::Error::SqliteFailure(
            rusqlite::ffi::Error::new(rc),
            Some("failed to register the sqlite-vec auto-extension".to_owned()),
        )
        .into());
    }
    Ok(())
}

/// Like [`open_readonly_connection`], but first registers the `sqlite-vec` extension
/// so the returned connection has the vector functions available.
///
/// # Errors
///
/// Returns an [`EngineError`] if the extension cannot be registered or the database
/// cannot be opened.
#[cfg(feature = "sqlite-vec")]
pub fn open_readonly_connection_with_vec(path: &Path) -> Result<Connection, EngineError> {
    register_sqlite_vec_auto_extension()?;
    open_readonly_connection(path)
}

/// Like [`open_connection`], but first registers the `sqlite-vec` extension so the
/// returned connection has the vector functions available.
///
/// # Errors
///
/// Returns an [`EngineError`] if the extension cannot be registered or the database
/// cannot be opened.
#[cfg(feature = "sqlite-vec")]
pub fn open_connection_with_vec(path: &Path) -> Result<Connection, EngineError> {
    register_sqlite_vec_auto_extension()?;
    open_connection(path)
}
pub fn shared_sqlite_policy() -> Result<SharedSqlitePolicy, String> {
let mut minimum_supported_version = None;
let mut repo_dev_version = None;
for raw_line in SHARED_SQLITE_POLICY.lines() {
let line = raw_line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let Some((key, value)) = line.split_once('=') else {
return Err(format!("invalid sqlite policy line: {line}"));
};
match key.trim() {
"SQLITE_MIN_VERSION" => minimum_supported_version = Some(value.trim().to_owned()),
"SQLITE_VERSION" => repo_dev_version = Some(value.trim().to_owned()),
other => return Err(format!("unknown sqlite policy key: {other}")),
}
}
let minimum_supported_version =
minimum_supported_version.ok_or_else(|| "missing SQLITE_MIN_VERSION".to_owned())?;
let repo_dev_version = repo_dev_version.ok_or_else(|| "missing SQLITE_VERSION".to_owned())?;
let repo_local_binary_relpath =
PathBuf::from(format!(".local/sqlite-{repo_dev_version}/bin/sqlite3"));
Ok(SharedSqlitePolicy {
minimum_supported_version,
repo_dev_version,
repo_local_binary_relpath,
})
}
// `expect` is the idiomatic way to surface a failed precondition in a unit test.
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use std::path::Path;

    use super::{shared_sqlite_policy, SharedSqlitePolicy};

    /// The embedded `sqlite.env` must pin the SQLite versions the repository expects,
    /// and the local binary path must be derived from the dev version.
    #[test]
    fn shared_sqlite_policy_matches_repo_defaults() {
        let SharedSqlitePolicy {
            minimum_supported_version,
            repo_dev_version,
            repo_local_binary_relpath,
        } = shared_sqlite_policy().expect("shared sqlite policy");

        assert_eq!(minimum_supported_version.as_str(), "3.41.0");
        assert_eq!(repo_dev_version.as_str(), "3.46.0");

        let expected_suffix = Path::new("sqlite-3.46.0/bin/sqlite3");
        assert!(repo_local_binary_relpath.ends_with(expected_suffix));
    }
}