use std::path::Path;
use std::sync::Arc;
use std::{collections::VecDeque, path::PathBuf};
use parking_lot::Mutex;
use rusqlite::OpenFlags;
use tokio::task::spawn_blocking;
use crate::frame::{Frame, FrameNo};
use crate::rpc::replication::Frame as RpcFrame;
use self::injector_wal::{
InjectorWal, InjectorWalManager, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN,
};
use super::error::Result;
use super::{Error, Injector};
mod headers;
mod injector_wal;
/// Queue of decoded frames waiting to be injected.
///
/// The same `Arc` is shared with the `InjectorWalManager` of every connection the injector
/// opens, so frames pushed here are visible to the WAL layer.
pub type FrameBuffer = Arc<Mutex<VecDeque<Frame>>>;

/// Async [`Injector`] that writes replicated frames into a local SQLite database file.
///
/// Every operation is delegated to the shared [`SqliteInjectorInner`] and executed on tokio's
/// blocking thread pool.
pub struct SqliteInjector {
    /// Shared injector state, locked for the duration of each blocking operation.
    pub(in super::super) inner: Arc<Mutex<SqliteInjectorInner>>,
}
impl Injector for SqliteInjector {
async fn inject_frame(&mut self, frame: RpcFrame) -> Result<Option<FrameNo>> {
let inner = self.inner.clone();
let frame =
Frame::try_from(&frame.data[..]).map_err(|e| Error::FatalInjectError(e.into()))?;
spawn_blocking(move || inner.lock().inject_frame(frame))
.await
.unwrap()
}
async fn rollback(&mut self) {
let inner = self.inner.clone();
spawn_blocking(move || inner.lock().rollback())
.await
.unwrap();
}
async fn flush(&mut self) -> Result<Option<FrameNo>> {
let inner = self.inner.clone();
spawn_blocking(move || inner.lock().flush()).await.unwrap()
}
#[inline]
fn durable_frame_no(&mut self, _frame_no: u64) {}
}
impl SqliteInjector {
    /// Opens the database at `path` on a blocking thread and wraps the resulting injector
    /// for shared use. The database is created if it does not exist.
    ///
    /// `capacity`, `auto_checkpoint` and `encryption_config` are forwarded unchanged to
    /// [`SqliteInjectorInner`].
    ///
    /// # Errors
    /// Propagates any error from opening the underlying connection.
    ///
    /// # Panics
    /// Panics if the blocking task does not complete.
    pub async fn new(
        path: PathBuf,
        capacity: usize,
        auto_checkpoint: u32,
        encryption_config: Option<libsql_sys::EncryptionConfig>,
    ) -> super::Result<Self> {
        let open_inner =
            move || SqliteInjectorInner::new(path, capacity, auto_checkpoint, encryption_config);
        let opened = spawn_blocking(open_inner).await.unwrap()?;
        let inner = Arc::new(Mutex::new(opened));
        Ok(Self { inner })
    }
}
/// Synchronous state of the SQLite frame injector.
///
/// [`SqliteInjector`] shares it behind an `Arc<Mutex<_>>` and drives it from tokio's
/// blocking thread pool.
// NOTE: field order is also drop order (`buffer` is dropped before `connection`); keep it.
pub(in super::super) struct SqliteInjectorInner {
    /// Whether a write transaction is currently open on `connection`.
    /// Set when injection returns `LIBSQL_INJECT_OK_TXN`; cleared on commit or rollback.
    is_txn: bool,
    /// Frames waiting to be injected; the same `Arc` is handed to each `InjectorWalManager`.
    buffer: FrameBuffer,
    /// Number of buffered frames that forces a flush even without a commit frame.
    capacity: usize,
    /// Injection connection; `begin_txn` replaces it with a freshly opened one.
    connection: Arc<Mutex<libsql_sys::Connection<InjectorWal>>>,
    /// Largest frame number flushed since the last commit. Returned (and reset) when a
    /// commit is reported; also reset when a flush fails.
    biggest_uncommitted_seen: FrameNo,
    /// Database path, kept so `begin_txn` can reopen the connection.
    path: PathBuf,
    /// Encryption settings passed to every connection opened.
    encryption_config: Option<libsql_sys::EncryptionConfig>,
    /// Auto-checkpoint setting passed to every connection opened.
    auto_checkpoint: u32,
}
impl SqliteInjectorInner {
fn new(
path: impl AsRef<Path>,
capacity: usize,
auto_checkpoint: u32,
encryption_config: Option<libsql_sys::EncryptionConfig>,
) -> Result<Self, Error> {
let path = path.as_ref().to_path_buf();
let buffer = FrameBuffer::default();
let wal_manager = InjectorWalManager::new(buffer.clone());
let connection = libsql_sys::Connection::open(
&path,
OpenFlags::SQLITE_OPEN_READ_WRITE
| OpenFlags::SQLITE_OPEN_CREATE
| OpenFlags::SQLITE_OPEN_URI
| OpenFlags::SQLITE_OPEN_NO_MUTEX,
wal_manager,
auto_checkpoint,
encryption_config.clone(),
)?;
Ok(Self {
is_txn: false,
buffer,
capacity,
connection: Arc::new(Mutex::new(connection)),
biggest_uncommitted_seen: 0,
path,
encryption_config,
auto_checkpoint,
})
}
pub fn inject_frame(&mut self, frame: Frame) -> Result<Option<FrameNo>, Error> {
let frame_close_txn = frame.header().size_after.get() != 0;
self.buffer.lock().push_back(frame);
if frame_close_txn || self.buffer.lock().len() >= self.capacity {
return self.flush();
}
Ok(None)
}
pub fn rollback(&mut self) {
self.clear_buffer();
let conn = self.connection.lock();
let mut rollback = conn.prepare_cached("ROLLBACK").unwrap();
let _ = rollback.execute(());
self.is_txn = false;
}
pub fn flush(&mut self) -> Result<Option<FrameNo>, Error> {
match self.try_flush() {
Err(e) => {
self.biggest_uncommitted_seen = 0;
self.rollback();
Err(e)
}
Ok(ret) => Ok(ret),
}
}
fn try_flush(&mut self) -> Result<Option<FrameNo>, Error> {
if !self.is_txn {
self.begin_txn()?;
}
let lock = self.buffer.lock();
let last_frame_no = match lock.back().zip(lock.front()) {
Some((b, f)) => f.header().frame_no.get().max(b.header().frame_no.get()),
None => {
tracing::trace!("nothing to inject");
return Ok(None);
}
};
self.biggest_uncommitted_seen = self.biggest_uncommitted_seen.max(last_frame_no);
drop(lock);
let connection = self.connection.lock();
let mut stmt =
connection.prepare_cached("INSERT INTO libsql_temp_injection VALUES (42)")?;
match stmt.execute(()).and_then(|_| connection.cache_flush()) {
Ok(_) => panic!("replication hook was not called"),
Err(e) => {
if let Some(err) = e.sqlite_error() {
if err.extended_code == LIBSQL_INJECT_OK {
connection.pragma_update(None, "writable_schema", "reset")?;
let mut rollback = connection.prepare_cached("ROLLBACK")?;
let _ = rollback.execute(());
self.is_txn = false;
assert!(self.buffer.lock().is_empty());
let commit_frame_no = self.biggest_uncommitted_seen;
self.biggest_uncommitted_seen = 0;
return Ok(Some(commit_frame_no));
} else if err.extended_code == LIBSQL_INJECT_OK_TXN {
self.is_txn = true;
assert!(self.buffer.lock().is_empty());
return Ok(None);
} else if err.extended_code == LIBSQL_INJECT_FATAL {
return Err(Error::FatalInjectError(e.into()));
}
}
Err(Error::FatalInjectError(e.into()))
}
}
}
fn begin_txn(&mut self) -> Result<(), Error> {
let mut conn = self.connection.lock();
{
let wal_manager = InjectorWalManager::new(self.buffer.clone());
let new_conn = libsql_sys::Connection::open(
&self.path,
OpenFlags::SQLITE_OPEN_READ_WRITE
| OpenFlags::SQLITE_OPEN_CREATE
| OpenFlags::SQLITE_OPEN_URI
| OpenFlags::SQLITE_OPEN_NO_MUTEX,
wal_manager,
self.auto_checkpoint,
self.encryption_config.clone(),
)?;
let _ = std::mem::replace(&mut *conn, new_conn);
}
conn.pragma_update(None, "writable_schema", "true")?;
let mut stmt = conn.prepare_cached("BEGIN IMMEDIATE")?;
stmt.execute(())?;
let mut stmt =
conn.prepare_cached("CREATE TABLE IF NOT EXISTS libsql_temp_injection (x)")?;
stmt.execute(())?;
Ok(())
}
pub fn clear_buffer(&mut self) {
self.buffer.lock().clear()
}
#[cfg(test)]
pub fn is_txn(&self) -> bool {
self.is_txn
}
}
#[cfg(test)]
mod test {
    use crate::frame::FrameBorrowed;
    use std::mem::size_of;

    use super::*;

    /// Raw WAL-log fixture: consecutive serialized frames of `FrameBorrowed` size.
    const WAL: &[u8] = include_bytes!("../../../assets/test/test_wallog");

    /// Decodes the fixture into frames, one `FrameBorrowed`-sized chunk at a time.
    fn wal_log() -> impl Iterator<Item = Frame> {
        let frame_size = size_of::<FrameBorrowed>();
        WAL.chunks(frame_size)
            .map(|chunk| Frame::try_from(chunk).unwrap())
    }

    /// Location of the database file used by a test inside `dir`.
    fn db_path(dir: &tempfile::TempDir) -> std::path::PathBuf {
        dir.path().join("data")
    }

    /// Injects the whole fixture with the given buffer `capacity`.
    /// Then checks that an independent reader sees 5 rows in `test`.
    fn assert_full_log_injected(capacity: usize) {
        let dir = tempfile::tempdir().unwrap();
        let mut injector =
            SqliteInjectorInner::new(db_path(&dir), capacity, 10000, None).unwrap();
        wal_log().for_each(|frame| {
            injector.inject_frame(frame).unwrap();
        });
        let reader = rusqlite::Connection::open(db_path(&dir)).unwrap();
        reader
            .query_row("SELECT COUNT(*) FROM test", (), |row| {
                assert_eq!(row.get::<_, usize>(0).unwrap(), 5);
                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_simple_inject_frames() {
        assert_full_log_injected(10);
    }

    #[test]
    fn test_inject_frames_split_txn() {
        // Capacity 1 forces a flush after every frame, splitting the transaction.
        assert_full_log_injected(1);
    }

    #[test]
    fn test_inject_partial_txn_isolated() {
        let dir = tempfile::tempdir().unwrap();
        let mut injector = SqliteInjectorInner::new(db_path(&dir), 10, 1000, None).unwrap();
        let mut frames = wal_log();

        // The first frame does not commit anything.
        let first = injector.inject_frame(frames.next().unwrap()).unwrap();
        assert!(first.is_none());

        // An independent reader must not see the still-uncommitted `test` table.
        let reader = rusqlite::Connection::open(db_path(&dir)).unwrap();
        let before_commit = reader.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(()));
        assert!(before_commit.is_err());

        // Keep injecting until some frame commits the transaction.
        loop {
            let committed = injector.inject_frame(frames.next().unwrap()).unwrap();
            if committed.is_some() {
                break;
            }
        }

        reader
            .pragma_update(None, "writable_schema", "reset")
            .unwrap();
        reader
            .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(()))
            .unwrap();
    }
}