use crate::{DBResult, Error};
use libsqlite3_sys as sq;
use std::{
cell::Cell,
ffi::{c_void, CStr, CString},
pin::Pin,
};
mod conn;
#[cfg(feature = "regex")]
mod regex;
pub use conn::*;
/// Converts a raw SQLite result code into a `Result`.
///
/// Only the low byte of `rcode` (the primary result code) is compared with
/// `SQLITE_OK`. On failure the returned `Error::Sqlite` carries:
/// - `code`: the primary result code,
/// - `extended_code`: the remaining high bits of `rcode`,
/// - `msg`: SQLite's generic description of `rcode` (`sqlite3_errstr`),
/// - `sql`: whatever `sql_gen` produces; it is only invoked on the error path.
///
/// # Errors
/// Returns `Error::Sqlite` for any non-OK code, or the converted UTF-8 error
/// if SQLite's description is not valid UTF-8.
fn check_rcode<'a>(sql_gen: impl FnOnce() -> Option<&'a str>, rcode: i32) -> Result<(), Error> {
    let primary = rcode & 0xff;
    if primary == sq::SQLITE_OK {
        return Ok(());
    }
    // SAFETY: sqlite3_errstr returns a NUL-terminated string whose memory is
    // managed by SQLite; it is only read here, never freed.
    let description = unsafe { CStr::from_ptr(sq::sqlite3_errstr(rcode)) };
    Err(Error::Sqlite {
        code: primary,
        extended_code: rcode & !0xff,
        msg: description.to_str()?.to_string(),
        sql: sql_gen().map(String::from),
    })
}
/// A value that can be reduced to the `u64` key passed to `with_prepared`
/// (presumably the key of a per-connection prepared-statement cache — confirm
/// in the `conn` module).
pub(crate) trait PreparedKey {
    fn into_u64(self) -> u64;
}

/// Integers are used as keys unchanged.
impl PreparedKey for u64 {
    fn into_u64(self) -> u64 {
        self
    }
}

/// A `TypeId` is reduced to a `u64` by feeding it to the standard library's
/// default hasher.
impl PreparedKey for std::any::TypeId {
    fn into_u64(self) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(&self, &mut state);
        state.finish()
    }
}
/// A SQLite transaction held open on a leased connection.
///
/// [`Transaction::new`] issues `BEGIN TRANSACTION`. The transaction ends with a
/// successful [`Transaction::commit`]. In every other case it is rolled back
/// with `ROLLBACK TRANSACTION` when the value is dropped: either the value is
/// dropped without committing, or `commit` fails.
pub struct Transaction {
    db: ConnectionLease,
    /// True only after `COMMIT TRANSACTION` succeeded; while false, `Drop`
    /// issues `ROLLBACK TRANSACTION`.
    committed: bool,
}

impl Transaction {
    /// Runs `BEGIN TRANSACTION`, keyed by a function-local marker type.
    fn begin_transaction(&mut self) -> DBResult<()> {
        struct BeginTransaction;
        self.db.with_prepared(
            std::any::TypeId::of::<BeginTransaction>(),
            || Ok(String::from("BEGIN TRANSACTION")),
            |ctx| {
                ctx.run()?;
                Ok(())
            },
        )
    }

    /// Runs `COMMIT TRANSACTION`, keyed by a function-local marker type.
    fn commit_transaction(&mut self) -> DBResult<()> {
        struct CommitTransaction;
        self.db.with_prepared(
            std::any::TypeId::of::<CommitTransaction>(),
            || Ok(String::from("COMMIT TRANSACTION")),
            |ctx| {
                ctx.run()?;
                Ok(())
            },
        )
    }

    /// Runs `ROLLBACK TRANSACTION`, keyed by a function-local marker type.
    fn rollback_transaction(&mut self) -> DBResult<()> {
        struct RollbackTransaction;
        self.db.with_prepared(
            std::any::TypeId::of::<RollbackTransaction>(),
            || Ok(String::from("ROLLBACK TRANSACTION")),
            |ctx| {
                ctx.run()?;
                Ok(())
            },
        )
    }

    /// Begins a transaction on the given leased connection.
    ///
    /// # Errors
    /// Propagates any error from running `BEGIN TRANSACTION`.
    pub(crate) fn new(db: ConnectionLease) -> DBResult<Self> {
        let mut r = Self {
            db,
            committed: false,
        };
        r.begin_transaction()?;
        Ok(r)
    }

    /// Mutable access to the connection lease, for running statements inside
    /// this transaction.
    pub(crate) fn lease(&mut self) -> &mut ConnectionLease {
        &mut self.db
    }

    /// Commits the transaction.
    ///
    /// # Errors
    /// - An `Error::Sqlite` with code `SQLITE_BUSY` is reported as
    ///   `Error::TransactionAbort`.
    /// - Any other error from `COMMIT TRANSACTION` is returned unchanged.
    ///
    /// On any error, `self` is dropped at the end of this call and its `Drop`
    /// rolls the transaction back, so the leased connection is not left inside
    /// an open transaction.
    pub fn commit(mut self) -> DBResult<()> {
        let result = self.commit_transaction();
        // Only a successful COMMIT ends the transaction. SQLite keeps the
        // transaction active when COMMIT fails with SQLITE_BUSY, so
        // `committed` must stay false on failure for `Drop` to roll back.
        if result.is_ok() {
            self.committed = true;
        }
        match result {
            Err(Error::Sqlite {
                code: sq::SQLITE_BUSY,
                ..
            }) => Err(Error::TransactionAbort),
            v => v,
        }
    }
}

impl Drop for Transaction {
    /// Rolls back unless `commit` succeeded.
    ///
    /// Rollback errors are ignored because `drop` cannot report them. For
    /// example, SQLite may already have rolled the transaction back itself
    /// after certain errors.
    fn drop(&mut self) {
        if !self.committed {
            let _ = self.rollback_transaction();
        }
    }
}
#[derive(Debug)]
pub(crate) struct RawSchemaRow {
pub type_: String,
pub name: String,
pub tbl_name: String,
pub _rootpage: String,
pub sql: String,
}
pub(crate) fn get_raw_schema(lease: &mut ConnectionLease) -> DBResult<Vec<RawSchemaRow>> {
let mut schema_entries: Vec<RawSchemaRow> = vec![];
unsafe extern "C" fn rowcb(
entries: *mut c_void,
ncols: i32,
rowdata: *mut *mut i8,
_cols: *mut *mut i8,
) -> i32 {
assert_eq!(ncols, 5);
let rows = std::slice::from_raw_parts(rowdata, ncols as usize);
let entries = (entries as *mut Vec<RawSchemaRow>).as_mut().unwrap();
let extract_value = |idx: usize| {
if !rows[idx].is_null() {
CStr::from_ptr(rows[idx]).to_str().map(ToString::to_string)
} else {
Ok(String::new())
}
};
let row = (|| {
DBResult::Ok(RawSchemaRow {
type_: extract_value(0)?,
name: extract_value(1)?,
tbl_name: extract_value(2)?,
_rootpage: extract_value(3)?,
sql: extract_value(4)?,
})
})();
match row {
Ok(row) => {
entries.push(row);
0
},
Err(err) => {
log::error!("error while getting raw schema: {err:?}");
1
},
}
}
unsafe {
let sql = "SELECT * FROM sqlite_schema";
let c_sql = CString::new(sql)?;
let entries_ptr = (&mut schema_entries) as *mut Vec<RawSchemaRow>;
check_rcode(
|| Some(sql),
sq::sqlite3_exec(
lease.conn.sqlite,
c_sql.as_ptr(),
Some(rowcb),
entries_ptr as *mut c_void,
std::ptr::null_mut(),
),
)?;
}
Ok(schema_entries)
}
struct Statement {
#[allow(unused)]
sqlite: *mut sq::sqlite3,
stmt: *mut sq::sqlite3_stmt,
}
impl Statement {
fn make_context(&mut self) -> DBResult<StatementContext<'_>> {
unsafe {
check_rcode(|| None, sq::sqlite3_reset(self.stmt))?;
}
Ok(StatementContext {
stmt: self,
owned_strings: Default::default(),
done: false.into(),
})
}
}
impl Drop for Statement {
fn drop(&mut self) {
unsafe {
sq::sqlite3_finalize(self.stmt);
}
}
}
#[cfg(test)]
mod test {
    use super::ConnectionPool;

    /// A DDL statement can be executed inside a transaction.
    #[test]
    fn simple_sql() {
        let c = ConnectionPool::new(":memory:").expect("couldn't open test db");
        let mut t = c.start().unwrap();
        t.lease()
            .execute_raw_sql("CREATE TABLE test_table (id integer primary key, value string)")
            .expect("couldn't execute sql");
    }

    /// A row inserted through one prepared statement is read back through another.
    #[test]
    fn prepare_stmt() {
        let c = ConnectionPool::new(":memory:").expect("couldn't open test db");
        let mut t = c.start().unwrap();
        t.lease()
            .execute_raw_sql("CREATE TABLE test_table (id integer primary key, value string)")
            .expect("couldn't execute sql");
        t.lease()
            .with_prepared(
                1,
                || Ok(String::from("INSERT INTO test_table VALUES (?, ?)")),
                |ctx| {
                    ctx.bind(1, 1usize)?;
                    ctx.bind(2, "value")?;
                    // INSERT yields no rows; `run` still propagates a failed
                    // step instead of discarding it.
                    ctx.run()?;
                    Ok(())
                },
            )
            .expect("couldn't run prepared INSERT statement");
        t.lease()
            .with_prepared(
                2,
                || Ok(String::from("SELECT * FROM test_table")),
                |ctx| {
                    let count = ctx
                        .iter()
                        .map(|row| {
                            let row = row.unwrap();
                            assert_eq!(row.read::<i64>(0).expect("couldn't read row ID"), 1);
                            assert_eq!(
                                row.read::<String>(1).expect("couldn't read row value"),
                                "value"
                            );
                        })
                        .count();
                    assert_eq!(count, 1);
                    Ok(())
                },
            )
            .expect("couldn't run prepared SELECT statement");
    }
}
/// One result row of a statement that has just been stepped.
///
/// Column values are read directly from the underlying statement with
/// `sqlite3_column_*`, so they reflect whatever row the statement is
/// currently on.
pub struct StatementRow<'a> {
    stmt: &'a Statement,
    /// Owns the execution context for rows returned by `StatementContext::run`;
    /// `None` for rows yielded by `StatementContext::iter`.
    _ctx: Option<StatementContext<'a>>,
}

impl StatementRow<'_> {
    /// Reads column `index` (0-based) as an owned value.
    pub fn read<T: Readable>(&self, index: i32) -> DBResult<T> {
        <T as Readable>::read_from(self, index)
    }

    /// Reads column `index` (0-based) as a value that may point into memory
    /// owned by SQLite.
    pub fn borrow<T: Borrowable>(&self, index: i32) -> DBResult<T> {
        <T as Borrowable>::borrow(self, index)
    }
}
/// One execution of a prepared [`Statement`]: binding parameters, stepping
/// through results, and cleaning up afterwards.
pub struct StatementContext<'a> {
    stmt: &'a Statement,
    /// Strings handed over through [`StatementContext::transfer`]. As fields,
    /// they are dropped only after `Drop::drop` has cleared the bindings.
    owned_strings: Vec<Pin<String>>,
    /// Set once this execution is over (after `SQLITE_DONE` or any error), so
    /// that later calls to `step` return `Ok(false)` without touching SQLite.
    done: Cell<bool>,
}

impl<'a> StatementContext<'a> {
    /// Binds `bindable` to parameter `index`, which is passed through to
    /// `sqlite3_bind_*` (the first parameter is 1).
    pub fn bind<B: Bindable>(&self, index: i32, bindable: B) -> DBResult<()> {
        bindable.bind(self, index)
    }

    /// Keeps `s` alive for as long as this context exists.
    pub fn transfer(&mut self, s: Pin<String>) {
        self.owned_strings.push(s);
    }

    /// Advances the statement by one step.
    ///
    /// Returns `Ok(true)` when a result row is available and `Ok(false)` once
    /// the statement has completed. An error also ends the execution, so
    /// every later call returns `Ok(false)`.
    ///
    /// # Errors
    /// - `SQLITE_BUSY` becomes `Error::TransactionAbort`.
    /// - `SQLITE_CONSTRAINT` becomes `Error::ConstraintViolation` carrying
    ///   SQLite's message.
    /// - Anything else becomes the `Error::Sqlite` built by `check_rcode` from
    ///   the full (extended) result code.
    fn step(&self) -> DBResult<bool> {
        if self.done.get() {
            return Ok(false);
        }
        // SAFETY: `self.stmt` borrows a live `Statement`, whose handle is only
        // finalized when that `Statement` is dropped.
        let step_result = unsafe { sq::sqlite3_step(self.stmt.stmt) };
        let primary = step_result & 0xff;
        if primary == sq::SQLITE_ROW {
            return Ok(true);
        }
        // Every other outcome ends this execution. Stepping again after an
        // error would make SQLite (with its default automatic reset) re-run
        // the statement from the beginning. That would repeat its side
        // effects, and `iter()` could yield the same error indefinitely.
        self.done.set(true);
        match primary {
            sq::SQLITE_DONE => Ok(false),
            sq::SQLITE_BUSY => {
                log::trace!("Concurrent database access!");
                Err(Error::TransactionAbort)
            },
            sq::SQLITE_CONSTRAINT => {
                // SAFETY: `sqlite` is the connection owning `stmt`;
                // sqlite3_errmsg returns a NUL-terminated string managed by
                // SQLite, copied out immediately.
                let msg = unsafe { CStr::from_ptr(sq::sqlite3_errmsg(self.stmt.sqlite)) }
                    .to_string_lossy()
                    .into_owned();
                log::trace!("SQLite constraint violation: {msg}");
                Err(Error::ConstraintViolation(msg))
            },
            err => {
                log::trace!("unexpected error during sqlite3_step: {:?}", err);
                // Pass the full result code so the extended code is kept.
                check_rcode(|| None, step_result)?;
                unreachable!()
            },
        }
    }

    /// Steps once and, if a row is available, returns it.
    ///
    /// The returned row takes ownership of this context, keeping its bindings
    /// and transferred strings alive.
    #[doc(hidden)]
    pub fn run(self) -> DBResult<Option<StatementRow<'a>>> {
        if self.step()? {
            Ok(Some(StatementRow {
                stmt: self.stmt,
                _ctx: Some(self),
            }))
        } else {
            Ok(None)
        }
    }

    /// Iterates over the result rows, yielding each step error as an item.
    ///
    /// After the first error or `SQLITE_DONE` the iterator is exhausted.
    #[doc(hidden)]
    pub fn iter(self) -> impl Iterator<Item = DBResult<StatementRow<'a>>> {
        struct I<'a>(StatementContext<'a>);
        impl<'a> Iterator for I<'a> {
            type Item = DBResult<StatementRow<'a>>;
            fn next(&mut self) -> Option<Self::Item> {
                match self.0.step() {
                    Ok(true) => Some(Ok(StatementRow {
                        _ctx: None,
                        stmt: self.0.stmt,
                    })),
                    Ok(false) => None,
                    Err(e) => Some(Err(e)),
                }
            }
        }
        I(self)
    }
}

impl Drop for StatementContext<'_> {
    fn drop(&mut self) {
        // Run the statement to completion even if the caller stopped reading
        // rows early; stops at the first error.
        while self.step().is_ok_and(|v| v) {}
        // SAFETY: `self.stmt.stmt` is live for this context's lifetime.
        // sqlite3_reset's return value only repeats the outcome of the last
        // step, which `step` has already reported, so it is ignored.
        // Resetting here means the next execution does not start from a
        // stale error state. Bindings are cleared before `owned_strings` is
        // dropped.
        unsafe {
            sq::sqlite3_reset(self.stmt.stmt);
            sq::sqlite3_clear_bindings(self.stmt.stmt);
        }
    }
}
/// A Rust value that can be bound to a statement parameter.
///
/// `index` is passed unchanged to `sqlite3_bind_*`, where the first parameter
/// is 1.
///
/// NOTE(review): `StatementContext` is covariant in its lifetime, so the
/// `'data: 'ctx` bound does not by itself force `self` to outlive the
/// statement execution. An implementation must not leave SQLite pointing at
/// data that may be freed after `bind` returns.
pub trait Bindable {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()>;
}

/// `()` binds SQL `NULL`.
impl Bindable for () {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        // SAFETY: `ctx.stmt.stmt` is a live statement handle.
        let rc = unsafe { sq::sqlite3_bind_null(ctx.stmt.stmt, index) };
        check_rcode(|| None, rc)
    }
}

/// Binds a 64-bit integer.
impl Bindable for i64 {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        let value = *self;
        // SAFETY: `ctx.stmt.stmt` is a live statement handle.
        let rc = unsafe { sq::sqlite3_bind_int64(ctx.stmt.stmt, index, value) };
        check_rcode(|| None, rc)
    }
}

/// Binds as `i64` via an `as` cast; values above `i64::MAX` wrap to negative.
impl Bindable for usize {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        let as_signed = *self as i64;
        as_signed.bind(ctx, index)
    }
}

/// Binds as `f64` (lossless widening).
impl Bindable for f32 {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        f64::from(*self).bind(ctx, index)
    }
}

/// Binds a double-precision float.
impl Bindable for f64 {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        let value = *self;
        // SAFETY: `ctx.stmt.stmt` is a live statement handle.
        let rc = unsafe { sq::sqlite3_bind_double(ctx.stmt.stmt, index, value) };
        check_rcode(|| None, rc)
    }
}
/// Binds text; SQLite stores its own copy.
///
/// `SQLITE_TRANSIENT` makes SQLite copy the bytes before `sqlite3_bind_text`
/// returns. That copy is required for soundness because nothing ties the
/// borrowed bytes to the statement execution. For example,
/// `StatementContext::bind` takes its argument by value, so a bound `String`
/// is freed as soon as that call returns.
impl Bindable for &str {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        // sqlite3_bind_text takes an `int` length. A wrapped, negative length
        // would make SQLite read up to a NUL terminator that `str` does not
        // have, so over-long text is rejected as SQLITE_TOOBIG instead.
        let Ok(len) = i32::try_from(self.len()) else {
            return check_rcode(|| None, sq::SQLITE_TOOBIG);
        };
        // SAFETY: `self.as_ptr()`/`len` describe initialized bytes that stay
        // valid for this call; SQLITE_TRANSIENT makes SQLite copy them before
        // returning.
        unsafe {
            check_rcode(
                || None,
                sq::sqlite3_bind_text(
                    ctx.stmt.stmt,
                    index,
                    self.as_ptr().cast(),
                    len,
                    sq::SQLITE_TRANSIENT(),
                ),
            )
        }
    }
}

/// Binds text by delegating to the `&str` impl.
impl Bindable for str {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        <&'_ str>::bind(&self, ctx, index)
    }
}

/// Binds the string's contents as text (copied by SQLite).
impl Bindable for String {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        self.as_str().bind(ctx, index)
    }
}

/// Binds a blob; `SQLITE_TRANSIENT` makes SQLite store its own copy, for the
/// same reason as the `&str` impl.
impl Bindable for &[u8] {
    fn bind<'ctx, 'data: 'ctx>(
        &'data self,
        ctx: &StatementContext<'ctx>,
        index: i32,
    ) -> DBResult<()> {
        // SAFETY: pointer and length describe this slice, valid for this call;
        // SQLITE_TRANSIENT makes SQLite copy the bytes before returning.
        unsafe {
            check_rcode(
                || None,
                sq::sqlite3_bind_blob64(
                    ctx.stmt.stmt,
                    index,
                    self.as_ptr().cast(),
                    self.len() as u64,
                    sq::SQLITE_TRANSIENT(),
                ),
            )
        }
    }
}
/// A type that can be produced from one column (0-based `index`) of the
/// current result row.
pub trait Readable: Sized {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self>;
}
/// Like [`Readable`], but an implementation may return data that points into
/// memory owned by SQLite instead of copying it (see the `&str` and `&[u8]`
/// impls).
///
/// NOTE(review): nothing in this signature ties `Self` to the lifetime of
/// `sr`. A borrowed result can therefore outlive the point where SQLite
/// invalidates it: the next step, reset or finalize of the statement. Callers
/// must not hold such values across those calls.
pub trait Borrowable: Sized {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self>;
}
/// Whether a column's value is SQL `NULL`, as reported by
/// `sqlite3_column_type`.
pub struct IsNull(pub bool);

impl Readable for IsNull {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        // SAFETY: `sr.stmt.stmt` is a live statement handle.
        let kind = unsafe { sq::sqlite3_column_type(sr.stmt.stmt, index) };
        Ok(IsNull(kind == sq::SQLITE_NULL))
    }
}

impl Borrowable for IsNull {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        <IsNull as Readable>::read_from(sr, index)
    }
}
/// Reads the column with `sqlite3_column_int64`; SQLite's own type
/// conversions apply.
impl Readable for i64 {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        // SAFETY: `sr.stmt.stmt` is a live statement handle.
        Ok(unsafe { sq::sqlite3_column_int64(sr.stmt.stmt, index) })
    }
}

impl Borrowable for i64 {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        <i64 as Readable>::read_from(sr, index)
    }
}

/// Reads as `i64` and converts with an `as` cast, so negative stored values
/// wrap (mirroring how `usize` is bound).
impl Readable for usize {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        let signed = <i64 as Readable>::read_from(sr, index)?;
        Ok(signed as usize)
    }
}

impl Borrowable for usize {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        <usize as Readable>::read_from(sr, index)
    }
}

/// Reads as `f64` and narrows with an `as` cast.
impl Readable for f32 {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        let wide = <f64 as Readable>::read_from(sr, index)?;
        Ok(wide as f32)
    }
}

impl Borrowable for f32 {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        <f32 as Readable>::read_from(sr, index)
    }
}

/// Reads the column with `sqlite3_column_double`; SQLite's own type
/// conversions apply.
impl Readable for f64 {
    fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        // SAFETY: `sr.stmt.stmt` is a live statement handle.
        Ok(unsafe { sq::sqlite3_column_double(sr.stmt.stmt, index) })
    }
}

impl Borrowable for f64 {
    fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
        <f64 as Readable>::read_from(sr, index)
    }
}
impl Readable for String {
fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
unsafe {
let text = sq::sqlite3_column_text(sr.stmt.stmt, index);
if text.is_null() {
Err(Error::InternalError(
"NULL pointer result from sqlite3_column_text",
))
} else {
Ok(CStr::from_ptr(text.cast()).to_str()?.to_string())
}
}
}
}
impl Borrowable for &str {
fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
unsafe {
let text = sq::sqlite3_column_text(sr.stmt.stmt, index);
if text.is_null() {
Err(Error::InternalError(
"NULL pointer result from sqlite3_column_text",
))
} else {
Ok(CStr::from_ptr(text.cast()).to_str()?)
}
}
}
}
impl Readable for Vec<u8> {
fn read_from(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
unsafe {
let ptr = sq::sqlite3_column_blob(sr.stmt.stmt, index);
let len = sq::sqlite3_column_bytes(sr.stmt.stmt, index);
match len.cmp(&0) {
std::cmp::Ordering::Equal => Ok(vec![]),
std::cmp::Ordering::Less => Err(Error::InternalError(
"negative length returned from sqlite3_column_bytes",
)),
std::cmp::Ordering::Greater => {
Ok(std::slice::from_raw_parts(ptr.cast(), len as usize).to_vec())
},
}
}
}
}
impl Borrowable for &[u8] {
fn borrow(sr: &StatementRow<'_>, index: i32) -> DBResult<Self> {
unsafe {
let ptr = sq::sqlite3_column_blob(sr.stmt.stmt, index);
let len = sq::sqlite3_column_bytes(sr.stmt.stmt, index);
match len.cmp(&0) {
std::cmp::Ordering::Equal => Ok(&[]),
std::cmp::Ordering::Less => Err(Error::InternalError(
"negative length returned from sqlite3_column_bytes",
)),
std::cmp::Ordering::Greater => {
Ok(std::slice::from_raw_parts(ptr.cast(), len as usize))
},
}
}
}
}
/// Compile-time checks that `ConnectionPool` is `Send` and `Sync`.
#[cfg(test)]
mod sendsync_check {
    fn assert_send<T: Send>() {}
    fn assert_sync<T: Sync>() {}

    #[test]
    fn check_send() {
        assert_send::<super::ConnectionPool>();
    }

    #[test]
    fn check_sync() {
        assert_sync::<super::ConnectionPool>();
    }
}