use crate::connection::handle::ConnectionHandle;
use crate::connection::LogSettings;
use crate::connection::{ConnectionState, Statements};
use crate::error::Error;
use crate::SqliteConnectOptions;
use libsqlite3_sys::{
sqlite3_busy_timeout, SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY,
SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE,
SQLITE_OPEN_SHAREDCACHE, SQLITE_OPEN_URI,
};
use percent_encoding::NON_ALPHANUMERIC;
use std::collections::BTreeMap;
use std::ffi::CString;
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
#[cfg(feature = "load-extension")]
use sqlx_core::IndexMap;
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
pub struct EstablishParams {
filename: CString,
open_flags: i32,
busy_timeout: Duration,
statement_cache_capacity: usize,
log_settings: LogSettings,
#[cfg(feature = "load-extension")]
extensions: IndexMap<CString, Option<CString>>,
pub(crate) thread_name: String,
pub(crate) command_channel_size: usize,
#[cfg(feature = "regexp")]
register_regexp_function: bool,
}
impl EstablishParams {
pub fn from_options(options: &SqliteConnectOptions) -> Result<Self, Error> {
let mut filename = options
.filename
.to_str()
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must be valid UTF-8",
)
})?
.to_owned();
let mut flags = SQLITE_OPEN_URI;
flags |= if options.serialized {
SQLITE_OPEN_FULLMUTEX
} else {
SQLITE_OPEN_NOMUTEX
};
flags |= if options.read_only {
SQLITE_OPEN_READONLY
} else if options.create_if_missing {
SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE
} else {
SQLITE_OPEN_READWRITE
};
if options.in_memory {
flags |= SQLITE_OPEN_MEMORY;
}
flags |= if options.shared_cache {
SQLITE_OPEN_SHAREDCACHE
} else {
SQLITE_OPEN_PRIVATECACHE
};
let mut query_params = BTreeMap::new();
if options.immutable {
query_params.insert("immutable", "true");
}
if let Some(vfs) = options.vfs.as_deref() {
query_params.insert("vfs", vfs);
}
if !query_params.is_empty() {
filename = format!(
"file:{}?{}",
percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
serde_urlencoded::to_string(&query_params).unwrap()
);
}
let filename = CString::new(filename).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"filename passed to SQLite must not contain nul bytes",
)
})?;
#[cfg(feature = "load-extension")]
let extensions = options
.extensions
.iter()
.map(|(name, entry)| {
let entry = entry
.as_ref()
.map(|e| {
CString::new(e.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension entrypoint names passed to SQLite must not contain nul bytes"
)
})
})
.transpose()?;
Ok((
CString::new(name.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension names passed to SQLite must not contain nul bytes",
)
})?,
entry,
))
})
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;
let thread_id = THREAD_ID.fetch_add(1, Ordering::AcqRel);
Ok(Self {
filename,
open_flags: flags,
busy_timeout: options.busy_timeout,
statement_cache_capacity: options.statement_cache_capacity,
log_settings: options.log_settings.clone(),
#[cfg(feature = "load-extension")]
extensions,
thread_name: (options.thread_name)(thread_id as u64),
command_channel_size: options.command_channel_size,
#[cfg(feature = "regexp")]
register_regexp_function: options.register_regexp_function,
})
}
pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut handle = ConnectionHandle::open(&self.filename, self.open_flags)?;
#[cfg(feature = "load-extension")]
unsafe {
self.apply_extensions(&mut handle)?;
}
#[cfg(feature = "regexp")]
if self.register_regexp_function {
let status = crate::regexp::register(handle.as_ptr());
if status != libsqlite3_sys::SQLITE_OK {
return Err(Error::Database(Box::new(handle.expect_error())));
}
}
let ms = i32::try_from(self.busy_timeout.as_millis())
.expect("Given busy timeout value is too big.");
handle.call_with_result(|db| unsafe { sqlite3_busy_timeout(db, ms) })?;
Ok(ConnectionState {
handle,
statements: Statements::new(self.statement_cache_capacity),
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: None,
commit_hook_callback: None,
rollback_hook_callback: None,
})
}
#[cfg(feature = "load-extension")]
unsafe fn apply_extensions(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
use libsqlite3_sys::{sqlite3_free, sqlite3_load_extension};
use std::ffi::{c_int, CStr};
use std::ptr;
fn enable_load_extension(
handle: &mut ConnectionHandle,
enabled: bool,
) -> Result<(), Error> {
use libsqlite3_sys::{sqlite3_db_config, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION};
handle.call_with_result(|db| unsafe {
sqlite3_db_config(
db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
enabled as c_int,
ptr::null_mut::<c_int>(),
)
})?;
Ok(())
}
if self.extensions.is_empty() {
return Ok(());
}
enable_load_extension(handle, true)?;
for (name, entrypoint) in &self.extensions {
let name_ptr = name.as_ptr();
let entrypoint_ptr = entrypoint.as_ref().map_or_else(ptr::null, |s| s.as_ptr());
let mut err_msg_ptr = ptr::null_mut();
handle
.call_with_result(|db| unsafe {
sqlite3_load_extension(db, name_ptr, entrypoint_ptr, &mut err_msg_ptr)
})
.map_err(|e| {
if !err_msg_ptr.is_null() {
let err_msg = unsafe { CStr::from_ptr(err_msg_ptr) }
.to_string_lossy()
.into_owned();
unsafe {
sqlite3_free(err_msg_ptr.cast());
}
e.with_message(err_msg)
} else {
e
}
})?;
}
enable_load_extension(handle, false)
}
}