use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
mpsc::{Receiver, Sender, channel},
},
thread::{self, JoinHandle},
time::Duration,
};
use crate::database::{CompareAndSwapTransaction, Database, DatabaseError, Db, TransactionError};
/// A boxed transaction body queued for the writer thread.
///
/// It is `Fn` rather than `FnOnce` because the writer re-runs every body of a
/// batch against a fresh transaction each time applying the batch hits a
/// compare-and-swap conflict.
type TransactionFn<D> = Box<
    dyn Fn(&mut CompareAndSwapTransaction<D>) -> Result<(), TransactionError>
        + Send
        + 'static,
>;
/// Handle to a background thread that collects queued compare-and-swap
/// transactions into batches and applies them to a database.
pub struct DatabaseWriter<D: Database> {
    /// Database handle kept by the caller side (used by `finish` to flush);
    /// the writer thread works on its own clone.
    db: Db<D>,
    /// Sending half of the work queue. Taken (set to `None`) on shutdown, which
    /// disconnects the channel and lets the writer thread exit.
    sender: Option<Sender<TransactionFn<D>>>,
    /// Flag shared with the writer thread and read by `is_closed`; once it is
    /// set, newly submitted transactions are ignored.
    closed: Arc<AtomicBool>,
    /// Join handle of the writer thread; only taken by `finish`.
    thread: Option<JoinHandle<Result<(), DatabaseError>>>,
}
impl<D: Database + 'static> DatabaseWriter<D> {
    /// How long the writer waits after receiving the first transaction of a
    /// batch before draining the queue, so bursts of submissions share a batch.
    const BATCH_WINDOW: Duration = Duration::from_millis(250);

    /// How many times a batch is attempted before the writer gives up on
    /// repeated compare-and-swap conflicts with `TooManyRetries`.
    const MAX_CAS_ATTEMPTS: usize = 10;

    /// Starts the background writer thread for `db` and returns a handle to it.
    ///
    /// The writer is marked closed (see [`Self::is_closed`]) whenever its thread
    /// stops: after an error, after a panic inside a queued transaction, or once
    /// the queue has been disconnected and drained.
    #[must_use]
    pub fn spawn(db: Db<D>) -> Self {
        // Sets the shared flag when dropped, so the writer is reported as closed
        // on every exit path of the thread, including unwinding from a panic.
        struct MarkClosedOnExit(Arc<AtomicBool>);

        impl Drop for MarkClosedOnExit {
            fn drop(&mut self) {
                self.0.store(true, Ordering::Relaxed);
            }
        }

        let (sender, rx) = channel();
        let closed = Arc::new(AtomicBool::new(false));
        let run_closed = Arc::clone(&closed);
        let run_db = db.clone();
        let thread = thread::spawn(move || {
            // Lives for the whole run; fires when this closure returns or unwinds.
            let _mark_closed = MarkClosedOnExit(run_closed);
            Self::run(run_db, rx)
        });
        Self {
            db,
            sender: Some(sender),
            closed,
            thread: Some(thread),
        }
    }

    /// Writer-thread loop: waits for queued transactions, gathers them into a
    /// batch and applies the batch as one compare-and-swap transaction.
    ///
    /// Returns `Ok(())` once every sender has been dropped and the queue is
    /// empty.
    ///
    /// # Errors
    ///
    /// Stops at the first failure and returns it: an error returned by a queued
    /// transaction function, a non-CAS error from applying a batch, or
    /// `TooManyRetries` when a batch keeps hitting CAS conflicts. The failed
    /// batch and anything still queued are then discarded.
    pub fn run(db: Db<D>, rx: Receiver<TransactionFn<D>>) -> Result<(), DatabaseError> {
        // `recv` only fails once the channel is disconnected and fully drained.
        while let Ok(first) = rx.recv() {
            thread::sleep(Self::BATCH_WINDOW);
            let mut batch = vec![first];
            batch.extend(rx.try_iter());
            Self::apply_batch(&db, &batch)?;
        }
        Ok(())
    }

    /// Runs every function of `batch` against a fresh compare-and-swap
    /// transaction and applies it, starting over on CAS conflicts for up to
    /// `MAX_CAS_ATTEMPTS` attempts.
    fn apply_batch(db: &Db<D>, batch: &[TransactionFn<D>]) -> Result<(), DatabaseError> {
        for _attempt in 0..Self::MAX_CAS_ATTEMPTS {
            let mut cas_tx = CompareAndSwapTransaction::with_db(db.clone());
            for func in batch {
                func(&mut cas_tx)?;
            }
            match cas_tx.apply(false) {
                Ok(()) => return Ok(()),
                Err(TransactionError::CompareAndSwapError) => {
                    #[cfg(feature = "log")]
                    log::trace!("Transaction ran into a CAS error and is retrying.");
                }
                Err(err) => return Err(DatabaseError::Transaction(err)),
            }
        }
        Err(DatabaseError::Transaction(TransactionError::TooManyRetries))
    }

    /// Returns `true` once the writer thread has stopped and no longer accepts
    /// transactions.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Relaxed)
    }

    /// Queues `func` to be run by the writer thread in a future batch.
    ///
    /// Submission is best-effort: if the writer is closed, or its thread has
    /// already exited, the transaction is dropped (and logged when the `log`
    /// feature is enabled). `func` may be invoked several times if the batch
    /// it lands in has to be retried after a compare-and-swap conflict.
    pub fn transaction<F>(&self, func: F)
    where
        F: Fn(&mut CompareAndSwapTransaction<D>) -> Result<(), TransactionError> + Send + 'static,
    {
        if self.is_closed() {
            #[cfg(feature = "log")]
            log::warn!("Database writer is closed; dropping transaction.");
            return;
        }
        let Some(sender) = &self.sender else {
            return;
        };
        // The box coerces to `TransactionFn<D>` at the call site.
        if sender.send(Box::new(func)).is_err() {
            #[cfg(feature = "log")]
            log::warn!("Database writer thread has stopped; dropping transaction.");
        }
    }

    /// Shuts the writer down: disconnects the queue, waits for the writer
    /// thread to apply everything already queued, then optionally flushes.
    ///
    /// # Errors
    ///
    /// Returns the writer thread's error if it stopped with one; otherwise the
    /// flush error, if `flush` is set and flushing fails. The flush is still
    /// attempted when the writer thread failed, so earlier batches are persisted.
    ///
    /// # Panics
    ///
    /// Panics if the writer thread panicked.
    pub fn finish(mut self, flush: bool) -> Result<(), DatabaseError> {
        // Disconnecting the channel makes the writer drain the queue and exit.
        drop(self.sender.take());
        let handle = self
            .thread
            .take()
            .expect("writer thread handle is only taken by `finish`, which consumes `self`");
        let run_result = handle.join().expect("Writer thread panicked");
        // Flush only after the writer has applied its last batch; flushing
        // earlier would leave the final transactions unpersisted.
        let flush_result = if flush {
            self.db.flush().map(|_| ())
        } else {
            Ok(())
        };
        run_result?;
        flush_result?;
        Ok(())
    }
}
impl<T: Database> Drop for DatabaseWriter<T> {
fn drop(&mut self) {
drop(self.sender.take());
}
}