use crate::Error;
use crate::error::StorageError;
use crate::transactions::WriteTransaction;
use std::fmt::{Display, Formatter};
use std::sync::mpsc;
use std::sync::{Condvar, Mutex};
/// Failure outcomes reported to a caller whose batch went through group commit.
#[non_exhaustive]
#[derive(Debug)]
pub enum GroupCommitError {
    /// This batch's own operations returned an error.
    BatchFailed(Error),
    /// This batch was rolled back because another batch in the same group failed.
    PeerFailed,
    /// Acquiring the write transaction for the group failed.
    TransactionFailed(StorageError),
    /// Committing the group's transaction failed.
    CommitFailed(StorageError),
    /// The committer has been shut down; the batch was not (or will not be) processed.
    Shutdown,
}
/// Human-readable messages; variants that wrap an inner error append that error's text.
impl Display for GroupCommitError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BatchFailed(inner) => write!(f, "Batch operation failed: {inner}"),
            Self::TransactionFailed(inner) => {
                write!(f, "Transaction acquisition failed: {inner}")
            }
            Self::CommitFailed(inner) => write!(f, "Commit failed: {inner}"),
            // Fixed messages need no formatting machinery.
            Self::PeerFailed => f.write_str("Rolled back: another batch in the group failed"),
            Self::Shutdown => f.write_str("Database is shutting down"),
        }
    }
}
impl std::error::Error for GroupCommitError {}
type BatchFn =
Box<dyn FnOnce(&WriteTransaction) -> std::result::Result<(), Error> + Send + 'static>;
pub struct WriteBatch {
operations: BatchFn,
}
impl WriteBatch {
pub fn new<F>(f: F) -> Self
where
F: FnOnce(&WriteTransaction) -> std::result::Result<(), Error> + Send + 'static,
{
Self {
operations: Box::new(f),
}
}
pub(crate) fn apply(self, txn: &WriteTransaction) -> std::result::Result<(), Error> {
(self.operations)(txn)
}
}
/// A queued batch paired with the sender used to deliver its outcome to the submitter.
pub(crate) struct PendingBatch {
    /// The work to apply.
    pub batch: WriteBatch,
    /// Sending half of the capacity-1 channel created by `GroupCommitter::enqueue`.
    pub result_tx: mpsc::SyncSender<Result<(), GroupCommitError>>,
}
/// Mutable state of a `GroupCommitter`, always accessed under its mutex.
struct GroupCommitState {
    /// Batches queued since the last drain.
    pending: Vec<PendingBatch>,
    /// `true` while some caller holds the leader role handed out by `enqueue`.
    active_leader: bool,
    /// Set once by `shutdown`; later `enqueue` calls are rejected.
    shutdown: bool,
}
/// Coordinates group commit.
///
/// Callers enqueue batches. The first caller that finds no active leader is told to lead,
/// and batches are later removed in bulk via `drain_pending`.
pub(crate) struct GroupCommitter {
    /// Queue, leader flag, and shutdown flag, protected together.
    state: Mutex<GroupCommitState>,
    /// Notified when leadership is released and on shutdown.
    /// NOTE(review): nothing in this file waits on it; confirm a waiter exists elsewhere.
    leader_done: Condvar,
}
impl GroupCommitter {
    /// Creates a committer with an empty queue, no leader, and shutdown not requested.
    pub fn new() -> Self {
        let initial_state = GroupCommitState {
            pending: Vec::new(),
            active_leader: false,
            shutdown: false,
        };
        Self {
            state: Mutex::new(initial_state),
            leader_done: Condvar::new(),
        }
    }

    /// Queues `batch` and tells the caller whether it must act as leader.
    ///
    /// Returns `(is_leader, receiver)`:
    /// - `is_leader` is `true` only when no leader was active. That caller now holds the
    ///   leader role until [`Self::finish_leader`] runs.
    /// - `receiver` yields the outcome sent on the sender stored alongside the batch.
    ///
    /// # Errors
    ///
    /// Returns [`GroupCommitError::Shutdown`] if [`Self::shutdown`] has already run. In that
    /// case the batch is dropped without being queued.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn enqueue(
        &self,
        batch: WriteBatch,
    ) -> Result<(bool, mpsc::Receiver<Result<(), GroupCommitError>>), GroupCommitError> {
        // Capacity 1 lets a single outcome be sent without waiting for the receiver.
        let (result_tx, result_rx) = mpsc::sync_channel(1);
        let mut guard = self.state.lock().unwrap();
        if guard.shutdown {
            return Err(GroupCommitError::Shutdown);
        }
        // Claim the role if it was free; a flag that is already `true` stays `true`. Because
        // this happens under the lock, only one caller sees the transition from free to held.
        let is_leader = !std::mem::replace(&mut guard.active_leader, true);
        guard.pending.push(PendingBatch { batch, result_tx });
        Ok((is_leader, result_rx))
    }

    /// Removes and returns every queued batch, leaving the queue empty.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn drain_pending(&self) -> Vec<PendingBatch> {
        std::mem::take(&mut self.state.lock().unwrap().pending)
    }

    /// Releases the leader role and wakes all threads waiting on `leader_done`.
    ///
    /// NOTE(review): `drain_pending` and this method take the lock separately. A batch
    /// enqueued between the leader's last drain and this call saw `active_leader == true`,
    /// so it was not told to lead. It stays queued until a later `enqueue` claims
    /// leadership. Confirm callers handle this window.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn finish_leader(&self) {
        let mut guard = self.state.lock().unwrap();
        guard.active_leader = false;
        // Notify while still holding the lock, then release it.
        self.leader_done.notify_all();
        drop(guard);
    }

    /// Rejects future batches and fails every batch still queued with `Shutdown`.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn shutdown(&self) {
        let stranded = {
            let mut guard = self.state.lock().unwrap();
            guard.shutdown = true;
            std::mem::take(&mut guard.pending)
        };
        // Replies go out after the lock is released. A submitter that already dropped its
        // receiver makes `send` fail, which is fine to ignore here.
        for waiter in stranded {
            let _ = waiter.result_tx.send(Err(GroupCommitError::Shutdown));
        }
        self.leader_done.notify_all();
    }
}