use parking_lot::{
MappedRwLockReadGuard, MappedRwLockWriteGuard, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use std::collections::HashMap;
use std::fmt;
use std::fs::OpenOptions;
use std::io::{copy, Cursor as IOCursor, Read, Seek, SeekFrom, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc, Weak};
use std::thread;
use std::time::{Duration, SystemTime};
use crate::bucket::{Bucket, Cursor};
use crate::consts::{Flags, IGNORE_NOSYNC, PGID, TXID};
use crate::db::{CheckMode, WeakDB, DB};
use crate::errors::Error;
use crate::meta::Meta;
use crate::page::{OwnedPage, Page, PageInfo};
use super::stats::TxStats;
/// Shared state of a single transaction. `Tx` handles are `Arc`s over
/// this struct, so all clones observe the same locks, meta snapshot,
/// dirty pages and statistics.
pub(crate) struct TxInner {
// True for read-write transactions; checked before every mutation.
pub(crate) writable: bool,
// Set while the tx is driven by a managed closure; explicit
// commit/rollback is rejected in that state.
pub(crate) managed: AtomicBool,
// When set, a consistency check runs during commit/rollback
// (consumed via `swap(false)` so it fires at most once).
pub(crate) check: AtomicBool,
// Weak back-reference to the owning database; upgraded on demand so
// a lingering Tx cannot keep the DB alive.
pub(crate) db: RwLock<WeakDB>,
// This transaction's private copy of the database meta page.
pub(crate) meta: RwLock<Meta>,
// Root bucket; all key/bucket operations start here.
pub(crate) root: RwLock<Bucket>,
// Dirty pages allocated by this tx, keyed by page id; flushed and
// drained in `Tx::write`.
pub(crate) pages: RwLock<HashMap<PGID, OwnedPage>>,
pub(crate) stats: Mutex<TxStats>,
// Callbacks to invoke after a successful commit.
pub(crate) commit_handlers: Mutex<Vec<Box<dyn Fn()>>>,
pub(super) write_flag: usize,
}
impl fmt::Debug for TxInner {
// Best-effort debug output. Every field is read through
// `try_read`/`try_lock` + `unwrap`, so formatting panics if any lock
// is currently held elsewhere.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// NOTE(review): `&db as *const DB` takes the address of the
// closure-local `db` binding, not of the shared database object —
// the printed pointer dangles the moment `map` returns and is only
// useful as a "was Some" marker. If `DB` wraps an `Arc`, printing
// `Arc::as_ptr` of the inner allocation would be meaningful — confirm.
let db = self
.db
.try_read()
.unwrap()
.upgrade()
.map(|db| &db as *const DB);
f.debug_struct("TxInner")
.field("writable", &self.writable)
.field("managed", &self.managed)
.field("db", &db)
.field("meta", &*self.meta.try_read().unwrap())
.field("root", &*self.root.try_read().unwrap())
.field("pages", &*self.pages.try_read().unwrap())
.field("stats", &*self.stats.try_lock().unwrap())
.field(
"commit handlers len",
&self.commit_handlers.try_lock().unwrap().len(),
)
.field("write_flag", &self.write_flag)
.finish()
}
}
/// Public transaction handle. Cloning is cheap (shared `Arc`); the last
/// handle to drop implicitly commits (writable) or rolls back (read-only)
/// — see the `Drop` impl.
#[derive(Debug)]
pub struct Tx(pub(crate) Arc<TxInner>);
// SAFETY-NOTE(review): these impls assert thread-safety the compiler
// cannot verify — `TxInner` holds `Box<dyn Fn()>` commit handlers with no
// `Send`/`Sync` bounds, and raw page pointers are handed out elsewhere.
// Confirm that all cross-thread access really is mediated by the locks.
unsafe impl Sync for Tx {}
unsafe impl Send for Tx {}
impl Tx {
/// Returns another handle to the same transaction (refcount bump only).
pub(crate) fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
/// Whether this transaction may mutate the database.
pub fn writable(&self) -> bool {
self.0.writable
}
/// True while the owning database is still alive and open.
/// Panics if the `db` lock is contended (`try_read().unwrap()`).
pub(crate) fn opened(&self) -> bool {
match self.0.db.try_read().unwrap().upgrade() {
None => false,
Some(db) => db.opened(),
}
}
/// Upgrades the weak database reference.
///
/// # Errors
/// `Error::DatabaseGone` if the database has been dropped.
pub(crate) fn db(&self) -> Result<DB, Error> {
self.0
.db
.try_read()
.unwrap()
.upgrade()
.ok_or(Error::DatabaseGone)
}
/// Transaction id from this tx's meta snapshot.
pub(crate) fn id(&self) -> TXID {
self.0.meta.try_read().unwrap().txid
}
/// High-water page id from this tx's meta snapshot.
pub(crate) fn pgid(&self) -> PGID {
self.0.meta.try_read().unwrap().pgid
}
/// Updates the high-water page id; errors if the meta lock is held.
pub(crate) fn set_pgid(&mut self, id: PGID) -> Result<(), Error> {
self.0.meta.try_write().ok_or("pgid locked")?.pgid = id;
Ok(())
}
/// Registers a callback to run after a successful commit.
pub fn on_commit(&mut self, handler: Box<dyn Fn()>) {
self.0.commit_handlers.lock().push(handler);
}
/// Logical database size seen by this tx, in bytes
/// (`pgid * page_size`). Panics if the database is already gone.
pub(super) fn size(&self) -> i64 {
self.pgid() as i64 * self.db().unwrap().page_size() as i64
}
/// Creates a cursor over the root bucket. Acquires the root *write*
/// lock (blocking), so the cursor holds exclusive access while alive.
pub fn cursor(&self) -> Cursor<RwLockWriteGuard<Bucket>> {
self.0.stats.lock().cursor_count += 1;
Cursor::new(self.0.root.write())
}
/// Snapshot of this transaction's statistics.
pub fn stats(&self) -> TxStats {
self.0.stats.lock().clone()
}
/// Returns a read guard mapped to the named child bucket.
///
/// # Errors
/// If the root lock is contended or the bucket does not exist.
pub fn bucket(&self, key: &[u8]) -> Result<MappedRwLockReadGuard<Bucket>, Error> {
let bucket = self.0.root.try_read().ok_or("Can't acquire bucket")?;
RwLockReadGuard::try_map(bucket, |b| b.bucket(key)).map_err(|_| "Can't get bucket".into())
}
/// Mutable variant of [`Tx::bucket`]; rejects read-only transactions.
pub fn bucket_mut(&mut self, key: &[u8]) -> Result<MappedRwLockWriteGuard<Bucket>, Error> {
if !self.0.writable {
return Err(Error::TxReadonly);
};
let bucket = self.0.root.try_write().ok_or("Can't acquire bucket")?;
RwLockWriteGuard::try_map(bucket, |b| b.bucket_mut(key))
.map_err(|_| "Can't get bucket".into())
}
/// Names of all top-level buckets.
pub fn buckets(&self) -> Vec<Vec<u8>> {
self.0.root.read().buckets()
}
/// Creates a new top-level bucket and returns a guard to it.
pub fn create_bucket(&mut self, key: &[u8]) -> Result<MappedRwLockWriteGuard<Bucket>, Error> {
if !self.0.writable {
return Err(Error::TxReadonly);
};
let bucket = self.0.root.try_write().ok_or("Can't acquire bucket")?;
// NOTE(review): `.ok()` discards the concrete creation error
// (e.g. "bucket exists"); callers only ever see the generic
// "Can't get bucket". Consider propagating the original error.
RwLockWriteGuard::try_map(bucket, |b| b.create_bucket(key).ok())
.map_err(|_| "Can't get bucket".into())
}
/// Like [`Tx::create_bucket`] but succeeds when the bucket already
/// exists. The underlying error is likewise erased by `.ok()`.
pub fn create_bucket_if_not_exists(
&mut self,
key: &[u8],
) -> Result<MappedRwLockWriteGuard<Bucket>, Error> {
if !self.0.writable {
return Err(Error::TxReadonly);
};
let bucket = self.0.root.try_write().ok_or("Can't acquire bucket")?;
RwLockWriteGuard::try_map(bucket, |b| b.create_bucket_if_not_exists(key).ok())
.map_err(|_| "Can't get bucket".into())
}
/// Deletes a top-level bucket. Panics if the root lock is contended.
pub fn delete_bucket(&mut self, key: &[u8]) -> Result<(), Error> {
if !self.0.writable {
return Err(Error::TxReadonly);
};
self.0.root.try_write().unwrap().delete_bucket(key)
}
/// Invokes `handler` for every top-level key; the second argument is
/// the child bucket when the key names one.
#[allow(clippy::type_complexity)]
pub fn for_each<'a, E: Into<Error>>(
&self,
mut handler: Box<dyn FnMut(&[u8], Option<&Bucket>) -> Result<(), E> + 'a>,
) -> Result<(), Error> {
let root = self.0.root.try_write().unwrap();
// The raw value `_v` is ignored; each key is re-resolved to a
// bucket via `root.bucket(k)` inside the callback.
root.for_each(Box::new(|k: &[u8], _v: Option<&[u8]>| -> Result<(), E> {
handler(k, root.bucket(k))
}))
}
/// Streams a consistent copy of the database into `w`, mirroring
/// BoltDB's `Tx.WriteTo`: two meta pages first, then the raw data
/// region. Returns the number of bytes written.
pub fn write_to<W: Write>(&self, mut w: W) -> Result<i64, Error> {
let db = self.db()?;
let page_size = db.page_size();
let mut file =
db.0.file
.try_write()
.ok_or("can't obtain file write access")?;
let mut written = 0;
let mut page = OwnedPage::new(page_size);
page.flags = Flags::META;
{
// Meta page 0: this tx's snapshot with a fresh checksum.
*page.meta_mut() = self.0.meta.try_read().unwrap().clone();
page.meta_mut().checksum = page.meta().sum64();
w.write_all(page.buf()).map_err(|e| format!("{e}"))?;
written += page.size();
}
{
// Meta page 1: same snapshot with the *previous* txid, so the
// copy carries a lower-txid fallback meta (BoltDB convention).
page.id = 1;
page.meta_mut().txid -= 1;
page.meta_mut().checksum = page.meta().sum64();
w.write_all(page.buf()).map_err(|e| format!("{e}"))?;
written += page.size();
}
// Skip the on-disk meta pages and copy the data region verbatim.
file.seek(SeekFrom::Start(page_size as u64 * 2))
.map_err(|e| format!("{e}"))?;
let size = self.size() as u64 - (page_size as u64 * 2);
written += copy(&mut Read::by_ref(&mut file.get_mut()).take(size), &mut w)
.map_err(|e| format!("{e}"))? as usize;
Ok(written as i64)
}
/// Writes a copy of the database to `path` using the supplied open
/// options (convenience wrapper over [`Tx::write_to`]).
pub fn copy_to(&self, path: &str, mode: OpenOptions) -> Result<(), Error> {
let file = mode.open(path).map_err(|e| e.to_string())?;
self.write_to(file)?;
Ok(())
}
/// Detaches this tx from the database: deregisters it, severs the
/// weak DB link and clears the root bucket and dirty-page cache.
pub(crate) fn close(&self) -> Result<(), Error> {
let mut db = self.db()?;
let tx = db.remove_tx(self)?;
*tx.0.db.write() = WeakDB::new();
tx.0.root.try_write().unwrap().clear();
tx.0.pages.try_write().unwrap().clear();
Ok(())
}
/// Commits the transaction: rebalance → spill → rewrite freelist →
/// write dirty pages → (optional check) → write meta → close → run
/// commit handlers. Most mid-commit failures trigger `rollback`.
///
/// # Errors
/// `TxManaged` inside a managed closure; `TxReadonly` for read-only
/// transactions; any error from the underlying phases.
pub fn commit(&mut self) -> Result<(), Error> {
if self.0.managed.load(Ordering::Acquire) {
return Err(Error::TxManaged);
} else if !self.writable() {
return Err(Error::TxReadonly);
};
let mut db = self.db()?;
{
// Phase 1: merge under-filled nodes back together.
let start_time = SystemTime::now();
self.0.root.try_write().unwrap().rebalance();
let mut stats = self.0.stats.lock();
if stats.rebalance > 0 {
stats.rebalance_time += SystemTime::now().duration_since(start_time)?;
};
}
{
// Phase 2: split over-filled nodes and allocate their pages.
let start_time = SystemTime::now();
{
let mut root = self.0.root.try_write().unwrap();
root.spill()?;
}
self.0.stats.try_lock().unwrap().spill_time =
SystemTime::now().duration_since(start_time)?;
}
// Spilling may have moved the root bucket; record its new page.
self.0.meta.try_write().unwrap().root.root = self.0.root.try_read().unwrap().bucket.root;
let (txid, tx_pgid, freelist_pgid) = {
let meta = self.0.meta.try_read().unwrap();
(meta.txid as usize, meta.pgid, meta.freelist)
};
// Free the pages of the previous freelist before writing a new one.
let (freelist_size, page_size) = {
let page = db.page(freelist_pgid);
let mut freelist = db.0.freelist.try_write().unwrap();
freelist.free(txid as u64, &page)?;
let freelist_size = freelist.size();
let page_size = db.page_size();
(freelist_size, page_size)
};
{
// Allocate enough contiguous pages to hold the new freelist.
let page = self.allocate((freelist_size / page_size) as u64 + 1);
if let Err(e) = page {
self.rollback()?;
return Err(e);
}
let page = page?;
// The raw pointer targets an `OwnedPage` held in `self.0.pages`;
// it stays valid because the map is not drained until `write`.
let page = unsafe { &mut *page };
db.0.freelist.try_write().unwrap().write(page);
self.0.meta.try_write().unwrap().freelist = page.id;
if self.pgid() > tx_pgid {
// NOTE(review): grows to `(tx_pgid + 1)` — the pgid captured
// *before* the freelist allocation — although the trigger
// compares against the new `self.pgid()`. BoltDB grows to the
// new high-water mark; confirm this is intentional.
if let Err(e) = db.grow((tx_pgid + 1) * page_size as u64) {
self.rollback()?;
return Err(e);
}
}
let write_start_time = SystemTime::now();
// Flush all dirty pages to disk before touching the meta page.
if let Err(e) = self.write() {
self.rollback()?;
return Err(e);
}
// One-shot consistency check (flag consumed via swap); STRICT
// mode aborts the commit, otherwise failures are only printed.
if self.0.check.swap(false, Ordering::AcqRel) {
let strict = db.0.check_mode.contains(CheckMode::STRICT);
if let Err(e) = self.check_sync() {
if strict {
self.rollback()?;
return Err(e);
} else {
println!("{e}");
}
}
};
// Writing the meta page is the commit point.
if let Err(e) = self.write_meta() {
self.0.check.store(false, Ordering::Release);
self.rollback()?;
return Err(e);
};
self.0.stats.try_lock().unwrap().write_time +=
SystemTime::now().duration_since(write_start_time)?;
};
self.close()?;
{
for h in &*self.0.commit_handlers.try_lock().unwrap() {
h();
}
}
Ok(())
}
/// Aborts the transaction. Rejected with `TxManaged` when the tx is
/// driven by a managed closure.
pub fn rollback(&self) -> Result<(), Error> {
if self.0.managed.load(Ordering::Acquire) {
return Err(Error::TxManaged);
};
self.__rollback()?;
Ok(())
}
/// Internal rollback: optional one-shot check, then (for writable
/// txs) undo this txid's freelist changes and reload the freelist
/// from its on-disk page, then detach via `close`.
pub(crate) fn __rollback(&self) -> Result<(), Error> {
let db = self.db()?;
if self.0.check.swap(false, Ordering::AcqRel) {
let strict = db.0.check_mode.contains(CheckMode::STRICT);
if let Err(e) = self.check_sync() {
if strict {
return Err(e);
} else {
println!("{e}");
}
}
};
if self.writable() {
let txid = self.id();
let mut freelist = db.0.freelist.write();
freelist.rollback(txid);
let freelist_id = db.meta()?.freelist;
let freelist_page = db.page(freelist_id);
freelist.reload(&freelist_page);
};
self.close()?;
Ok(())
}
/// Runs the consistency check on a helper thread and blocks until it
/// finishes, collecting all reported problems.
///
/// # Errors
/// `Error::CheckFail` with every message (including a panic message
/// if the checker thread panicked).
pub fn check_sync(&self) -> Result<(), Error> {
let (sender, ch) = mpsc::channel::<String>();
let tx = self.clone();
let handle = thread::spawn(move || tx.__check(sender));
let mut errs = vec![];
// The receiver loop ends when the checker drops its sender.
for err in ch {
errs.push(err);
}
if let Err(e) = handle.join() {
let estr = e.downcast_ref::<String>();
if let Some(estr) = estr {
errs.push(estr.clone());
} else {
errs.push(format!("check thread panicked: {e:?}"));
}
}
if !errs.is_empty() {
return Err(Error::CheckFail(errs));
};
Ok(())
}
/// Non-blocking variant of [`Tx::check_sync`]: spawns the checker and
/// returns the receiver; the caller drains problem messages.
pub fn check(&self) -> mpsc::Receiver<String> {
let (sender, receiver) = mpsc::channel::<String>();
let tx = self.clone();
thread::spawn(move || tx.__check(sender));
receiver
}
/// Builds a map of every page id currently on the freelist, erroring
/// if the freelist reports the same id twice.
pub fn freed(&self) -> Result<HashMap<PGID, bool>, Error> {
let mut freed = HashMap::<PGID, bool>::new();
let all_pgids = self
.db()
.unwrap()
.0
.freelist
.try_read()
.unwrap()
.get_pgids();
for id in &all_pgids {
if freed.contains_key(id) {
return Err(format!("page {id}: already freed").into());
}
freed.insert(*id, true);
}
Ok(freed)
}
/// Consistency checker: walks all reachable pages from the root
/// bucket and reports (via `ch`) pages that are unreachable yet not
/// freed, plus any problems found by `check_bucket`.
pub(super) fn __check(&self, ch: mpsc::Sender<String>) {
let freed = self.freed();
if let Err(e) = freed {
ch.send(e.to_string()).unwrap();
return;
}
let freed = freed.unwrap();
let mut reachable = HashMap::new();
// Pages 0 and 1 are the meta pages — always reachable.
reachable.insert(0, true);
reachable.insert(1, true);
let freelist_pgid = self
.0
.meta
.try_read_for(Duration::from_secs(10))
.unwrap()
.freelist;
// Mark the freelist page and all of its overflow pages reachable.
let freelist_overflow = unsafe { &*self.page(freelist_pgid).unwrap() }.overflow;
for i in 0..=freelist_overflow {
reachable.insert(freelist_pgid + u64::from(i), true);
}
self.check_bucket(
&self.0.root.try_read().unwrap(),
&mut reachable,
&freed,
&ch,
);
// Every page below the high-water mark must be either reachable
// from the root or on the freelist.
for i in 0..self.0.meta.try_read().unwrap().pgid {
let is_reachable = reachable.contains_key(&i);
let is_freed = freed.contains_key(&i);
if !is_reachable && !is_freed {
ch.send(format!("page {i}: unreachable unfreed")).unwrap();
};
}
}
/// Recursively verifies one bucket: every page it references must be
/// in bounds, referenced exactly once, not on the freelist, and of a
/// valid type; then recurses into child buckets.
fn check_bucket(
&self,
b: &Bucket,
reachable: &mut HashMap<PGID, bool>,
freed: &HashMap<PGID, bool>,
ch: &mpsc::Sender<String>,
) {
// Inline buckets have no dedicated root page to walk.
if b.bucket.root == 0 {
return;
}
let meta_pgid = self.pgid();
let handler = Box::new(|p: &Page, _pgid: usize| {
if p.id > meta_pgid {
ch.send(format!("page {}: out of bounds: {}", p.id, meta_pgid))
.unwrap();
}
// A page and its overflow pages may be referenced only once.
for i in 0..=p.overflow {
let id = p.id + u64::from(i);
if reachable.contains_key(&id) {
ch.send(format!("page {id}: multiple references"))
.unwrap();
}
reachable.insert(id, true);
}
// Only branch and leaf pages may be reachable from a bucket.
let page_type_is_valid = matches!(p.flags, Flags::BRANCHES | Flags::LEAVES);
if freed.contains_key(&p.id) {
ch.send(format!("page {}: reachable freed", p.id)).unwrap();
} else if !page_type_is_valid {
ch.send(format!("page {}: invalid type: {}", p.id, p.flags))
.unwrap();
}
});
b.tx().unwrap().for_each_page(b.bucket.root, 0, handler);
b.for_each(Box::new(|k, _v| -> Result<(), String> {
let child = b.bucket(k);
if let Some(child) = child {
self.check_bucket(child, reachable, freed, ch);
};
Ok(())
}))
.unwrap();
}
/// Allocates `count` contiguous pages from the database and records
/// them as dirty. Returns a raw pointer into `self.0.pages`; it stays
/// valid until the map entry is removed (drained by [`Tx::write`]).
pub(crate) fn allocate(&mut self, count: u64) -> Result<*mut Page, Error> {
let mut db = match self.db() {
Err(_) => return Err(Error::TxClosed),
Ok(db) => db,
};
let mut page = db.allocate(count, self)?;
let page_id = page.id;
// Take the pointer *before* moving the page into the map; the heap
// allocation it points at is unaffected by the move.
let page_ptr = &mut *page as *mut Page;
self.0.pages.try_write().unwrap().insert(page_id, page);
{
let mut stats = self.0.stats.lock();
stats.page_count += 1;
stats.page_alloc += count as usize * self.db()?.page_size();
}
Ok(page_ptr)
}
/// Flushes all dirty pages to disk in ascending page-id order, syncs
/// unless `no_sync` is set, then zeroes and recycles single-page
/// buffers through the database's page pool.
pub(crate) fn write(&mut self) -> Result<(), Error> {
let mut pages: Vec<_> = self
.0
.pages
.write()
.drain()
.map(|x| {
let pgid = x.1.id;
(pgid, x.1)
})
.collect();
// Ascending order keeps writes mostly sequential on disk.
pages.sort_by(|a, b| a.0.cmp(&b.0));
let mut db = self.db()?;
let page_size = db.page_size();
for (id, p) in &pages {
// A page plus its overflow pages are written as one blob.
let size = (p.overflow + 1) as usize * page_size;
let offset = *id * page_size as u64;
let buf = unsafe { std::slice::from_raw_parts(p.as_ptr(), size) };
let cursor = IOCursor::new(buf);
db.write_at(offset, cursor)?;
}
if !db.0.no_sync || IGNORE_NOSYNC {
db.sync()?;
}
{
// Return zeroed single-page buffers to the pool for reuse;
// multi-page (overflow) buffers are simply dropped.
let mut page_pool = db.0.page_pool.lock();
let mut i = 0;
while i != pages.len() {
if pages[i].1.size() == page_size {
let mut page = pages.remove(i).1;
for i in page.buf_mut() {
*i = 0;
}
page_pool.push(page);
} else {
i += 1;
}
}
}
Ok(())
}
/// Serializes this tx's meta into a fresh page-sized buffer and
/// writes it at file offset 0, syncing unless `no_sync` is set.
pub(crate) fn write_meta(&mut self) -> Result<(), Error> {
let mut db = self.db()?;
let mut buf = vec![0u8; db.page_size()];
let page = Page::from_buf_mut(&mut buf);
self.0.meta.try_write().unwrap().write(page)?;
db.write_at(0, IOCursor::new(buf))?;
if !db.0.no_sync || IGNORE_NOSYNC {
db.sync()?;
}
self.0.stats.lock().write += 1;
Ok(())
}
/// Resolves a page id: dirty pages owned by this tx shadow the
/// database's on-disk pages. The returned raw pointer is valid only
/// while the page stays in the dirty map / the DB mapping is live.
pub(crate) fn page(&self, id: PGID) -> Result<*const Page, Error> {
{
let pages = self.0.pages.try_read().unwrap();
if let Some(p) = pages.get(&id) {
return Ok(&**p);
}
}
Ok(&*self.db()?.page(id))
}
/// Pre-order walk over the page tree rooted at `pgid`, invoking
/// `func(page, depth)` for each page and recursing into branches.
pub(crate) fn for_each_page<'a>(
&self,
pgid: PGID,
depth: usize,
mut func: Box<dyn FnMut(&Page, usize) + 'a>,
) {
let p = unsafe { &*self.page(pgid).unwrap() };
func(p, depth);
let flags = p.flags;
// Only branch pages have children to descend into.
if flags != Flags::BRANCHES {
return;
}
let count = p.count as usize;
for i in 0..count {
let el = p.branch_page_element(i);
self.for_each_page(el.pgid, depth + 1, Box::new(|p, d| func(p, d)));
}
}
/// Describes page `id`: `Ok(None)` when the id is past the high-water
/// mark; pages on the freelist report `Flags::FREELIST`, all others
/// report their stored flags.
pub fn page_info(&self, id: usize) -> Result<Option<PageInfo>, Error> {
if !self.opened() {
return Err(Error::TxClosed);
};
if id >= self.pgid() as usize {
return Ok(None);
};
let db = self.db()?;
let p = db.page(id as u64);
let mut info = PageInfo {
id: id as isize,
ptype: Flags::FREELIST,
count: p.count as usize,
overflow_count: p.overflow as usize,
};
if !db.0.freelist.try_read().unwrap().freed(id as u64) {
info.ptype = p.flags
}
Ok(Some(info))
}
}
impl Drop for Tx {
// Implicit finalization: when the last *external* handle drops, a
// writable tx commits and a read-only tx rolls back. Nothing happens
// if the database itself is already gone.
fn drop(&mut self) {
// The database apparently retains up to two internal references
// (e.g. its tx registry) — only act when no other external handle
// remains. TODO(review): confirm the `> 2` threshold against the
// DB-side bookkeeping.
let count = Arc::strong_count(&self.0);
if count > 2 {
return;
};
if let Ok(_db) = self.db() {
// NOTE(review): `unwrap()` inside `drop` aborts the process via
// double-panic if this runs during unwinding; consider logging
// the error instead.
if self.0.writable {
self.commit().unwrap();
} else {
self.rollback().unwrap();
}
}
}
}
/// Non-owning handle to a transaction; upgrade to a `Tx` before use so a
/// registry entry cannot keep a finished transaction alive.
#[derive(Clone)]
pub(crate) struct WeakTx(Weak<TxInner>);
impl WeakTx {
pub(crate) fn new() -> Self {
Self(Weak::new())
}
pub(crate) fn upgrade(&self) -> Option<Tx> {
self.0.upgrade().map(Tx)
}
pub(crate) fn from(tx: &Tx) -> Self {
Self(Arc::downgrade(&tx.0))
}
}