#![allow(dead_code)]
pub mod flush;
pub mod memtable;
pub mod wal;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::{Condvar, Mutex, RwLock};
use flush::{FlushCoordinator, FlushError, FlushOutcome};
use memtable::{MemTable, MemTableEntry, MemTableError};
use wal::{Wal, WalTicket};
// Fixed framing bytes per WAL entry: 1 (tombstone flag) + 4 (key len, u32 LE)
// + 4 (value len, u32 LE) + 8 (sequence, u64 LE) = 17; see `encode_wal_entry`.
const WAL_ENTRY_HEADER: usize = 17;
use crate::observe::{OperationKind, OperationObserver};
use crate::storage::Segment;
/// A single write operation carried inside a [`Transaction`].
#[derive(Clone, Debug)]
pub enum Mutation {
/// Insert or overwrite `key` with `value`.
Put { key: Vec<u8>, value: Vec<u8> },
/// Remove `key` (applied to the memtable as a tombstone).
Delete { key: Vec<u8> },
}
/// An ordered list of mutations that commits (and fails) atomically as one
/// unit through the group-commit pipeline.
#[derive(Clone, Debug, Default)]
pub struct Transaction {
/// Mutations applied in order during commit.
pub operations: Vec<Mutation>,
}
impl Transaction {
pub fn new(operations: Vec<Mutation>) -> Self {
Self { operations }
}
pub fn is_empty(&self) -> bool {
self.operations.is_empty()
}
}
/// Durable position and sequence range assigned to one committed transaction.
#[derive(Clone, Debug, Default)]
pub struct CommitReceipt {
/// Byte offset of the transaction's first WAL entry.
pub wal_offset: u64,
/// Total encoded WAL bytes for the transaction (0 if it was empty).
pub wal_len: u32,
/// First memtable sequence number assigned (0 for an empty transaction).
pub sequence_start: u64,
/// Last memtable sequence number assigned (0 for an empty transaction).
pub sequence_end: u64,
/// Number of operations that were applied.
pub operations: usize,
}
/// Result handed back to a submitter: the receipt plus the memtable entries
/// that were created on its behalf.
#[derive(Clone, Debug)]
pub struct CommitOutcome {
/// WAL position and sequence bookkeeping for the transaction.
pub receipt: CommitReceipt,
/// The entries as applied to the memtable, in operation order.
pub entries: Vec<MemTableEntry>,
}
/// Cloneable, shareable commit error. The inner cause lives behind an `Arc`
/// so one batch failure can be fanned out to every transaction that was
/// grouped into (or queued behind) the failing batch.
#[derive(Debug, Clone)]
pub struct GroupCommitError(Arc<GroupCommitErrorInner>);
/// The concrete failure behind a [`GroupCommitError`].
#[derive(Debug, thiserror::Error)]
enum GroupCommitErrorInner {
#[error("wal error: {0}")]
Wal(#[from] std::io::Error),
#[error("memtable error: {0}")]
MemTable(#[from] MemTableError),
#[error("flush error: {0}")]
Flush(#[from] FlushError),
}
impl std::fmt::Display for GroupCommitError {
    /// Delegates to the wrapped error's `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified call: `self.0` is an `Arc`, which implements both
        // `Display` and `Debug`, and neither trait is imported in this file,
        // so the method-call form `self.0.fmt(f)` is ambiguous/unresolvable.
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl std::error::Error for GroupCommitError {}
impl From<std::io::Error> for GroupCommitError {
fn from(value: std::io::Error) -> Self {
Self(Arc::new(GroupCommitErrorInner::Wal(value)))
}
}
impl From<MemTableError> for GroupCommitError {
fn from(value: MemTableError) -> Self {
Self(Arc::new(GroupCommitErrorInner::MemTable(value)))
}
}
impl From<FlushError> for GroupCommitError {
fn from(value: FlushError) -> Self {
Self(Arc::new(GroupCommitErrorInner::Flush(value)))
}
}
/// Tuning knobs for the group-commit worker.
#[derive(Clone, Debug)]
pub struct GroupCommitConfig {
/// Maximum number of transactions drained into a single batch.
pub max_batch_ops: usize,
/// Approximate byte cap on one batch's WAL payload; `0` disables the cap.
pub max_batch_bytes: u64,
/// Dwell time that lets additional submitters join a forming batch.
pub min_interval: Duration,
/// Hard upper bound on how long a batch may keep accumulating.
pub max_interval: Duration,
/// When `true`, fsync the WAL on every batch; otherwise only flush it.
pub fsync_on_commit: bool,
}
impl Default for GroupCommitConfig {
fn default() -> Self {
Self {
max_batch_ops: 1024,
max_batch_bytes: 0,
min_interval: Duration::from_micros(500),
max_interval: Duration::from_micros(1000),
fsync_on_commit: true,
}
}
}
/// A queued transaction awaiting the worker, tagged with its ticket id.
struct PendingTxn {
// Ticket id that pairs this transaction with its eventual result.
id: u64,
txn: Transaction,
}
/// What the worker posts back for a ticket: success or the batch error.
enum PendingResult {
Receipt(CommitOutcome),
Error(GroupCommitError),
}
/// Queue state shared (under a mutex) between submitters and the worker.
struct BatchState {
// Transactions waiting to be batched, in FIFO order.
pending: VecDeque<PendingTxn>,
// Finished outcomes keyed by ticket id, awaiting pickup by submitters.
results: HashMap<u64, PendingResult>,
// When the last batch started; drives the worker's dwell-time heuristic.
last_flush: Instant,
// Set by `Drop` to tell the worker to exit.
shutdown: bool,
}
impl BatchState {
fn new() -> Self {
Self {
pending: VecDeque::new(),
results: HashMap::new(),
last_flush: Instant::now(),
shutdown: false,
}
}
}
/// Groups concurrently submitted transactions into batched WAL commits.
///
/// Submitters enqueue work on a shared `(Mutex<BatchState>, Condvar)` pair;
/// a dedicated worker thread drains batches, applies them to the memtable,
/// writes and syncs the WAL, and posts per-ticket results back.
pub struct GroupCommitController {
wal: Arc<Wal>,
flush: Arc<FlushCoordinator>,
config: GroupCommitConfig,
// Monotonic ticket ids pairing submitters with their results.
next_ticket: AtomicU64,
// Shared queue/results state plus the condvar both sides wait on.
state: Arc<(Mutex<BatchState>, Condvar)>,
// Currently active memtable; the inner `Arc` is swapped on rotation.
memtable: Arc<RwLock<Arc<MemTable>>>,
// Serializes memtable rotation so only one thread flushes at a time.
rotation: Arc<Mutex<()>>,
// Flush outcomes awaiting collection via `drain_completed_flushes`.
completed_flushes: Arc<Mutex<Vec<FlushOutcome>>>,
// Join handle of the background worker; taken and joined on drop.
worker: Option<std::thread::JoinHandle<()>>,
metrics: OperationObserver,
}
impl GroupCommitController {
pub fn new(
wal: Wal,
flush: FlushCoordinator,
config: GroupCommitConfig,
metrics: OperationObserver,
) -> Self {
let controller = Self {
wal: Arc::new(wal),
flush: Arc::new(flush),
config,
next_ticket: AtomicU64::new(0),
state: Arc::new((Mutex::new(BatchState::new()), Condvar::new())),
memtable: Arc::new(RwLock::new(Arc::new(MemTable::new()))),
rotation: Arc::new(Mutex::new(())),
completed_flushes: Arc::new(Mutex::new(Vec::new())),
worker: None,
metrics,
};
controller.spawn_worker()
}
fn spawn_worker(mut self) -> Self {
let state = Arc::clone(&self.state);
let wal = Arc::clone(&self.wal);
let flush = Arc::clone(&self.flush);
let config = self.config.clone();
let memtable = Arc::clone(&self.memtable);
let rotation = Arc::clone(&self.rotation);
let completed = Arc::clone(&self.completed_flushes);
let metrics = self.metrics.clone();
let handle = std::thread::Builder::new()
.name("group-commit".to_string())
.spawn(move || {
let (lock, cv) = &*state;
loop {
let mut guard = lock.lock();
while guard.pending.is_empty() && !guard.shutdown {
cv.wait(&mut guard);
}
if guard.shutdown {
break;
}
let now = Instant::now();
let elapsed = now.saturating_duration_since(guard.last_flush);
let pending_len = guard.pending.len();
if pending_len > 1
&& pending_len < config.max_batch_ops
&& elapsed < config.min_interval
{
let timeout = config.min_interval - elapsed;
let wait_result = cv.wait_for(&mut guard, timeout);
if guard.shutdown {
break;
}
if guard.pending.is_empty() {
continue;
}
if !wait_result.timed_out()
&& guard.pending.len() < config.max_batch_ops
&& Instant::now().saturating_duration_since(guard.last_flush)
< config.max_interval
{
continue;
}
}
let mut batch = Vec::new();
let mut approx_bytes: usize = 0;
while let Some(txn) = guard.pending.pop_front() {
let mut txn_bytes = 0usize;
for op in &txn.txn.operations {
match op {
Mutation::Put { key, value } => {
txn_bytes += 17 + key.len() + value.len();
}
Mutation::Delete { key } => {
txn_bytes += 17 + key.len();
}
}
}
if config.max_batch_bytes > 0
&& approx_bytes > 0
&& (approx_bytes + txn_bytes) as u64 > config.max_batch_bytes
{
guard.pending.push_front(txn);
break;
}
approx_bytes += txn_bytes;
batch.push(txn);
if batch.len() >= config.max_batch_ops {
break;
}
}
guard.last_flush = Instant::now();
let batch_ids: Vec<u64> = batch.iter().map(|txn| txn.id).collect();
drop(guard);
let results = process_batch(
&wal,
&flush,
&memtable,
&rotation,
&completed,
&metrics,
batch,
config.fsync_on_commit,
);
let mut guard = lock.lock();
match results {
Ok(entries) => {
for (id, outcome) in entries {
guard.results.insert(id, PendingResult::Receipt(outcome));
}
}
Err(err) => {
for id in &batch_ids {
guard.results.insert(*id, PendingResult::Error(err.clone()));
}
let remaining: Vec<u64> =
guard.pending.drain(..).map(|txn| txn.id).collect();
for id in remaining {
guard.results.insert(id, PendingResult::Error(err.clone()));
}
}
}
cv.notify_all();
}
})
.expect("failed to spawn group-commit worker");
self.worker = Some(handle);
self
}
pub fn submit(&self, txn: Transaction) -> Result<CommitOutcome, GroupCommitError> {
let id = self.next_ticket.fetch_add(1, Ordering::SeqCst);
let pending = PendingTxn { id, txn };
let (lock, cv) = &*self.state;
let mut guard = lock.lock();
guard.pending.push_back(pending);
cv.notify_one();
loop {
if let Some(result) = guard.results.remove(&id) {
match result {
PendingResult::Receipt(outcome) => return Ok(outcome),
PendingResult::Error(err) => return Err(err),
}
}
cv.wait(&mut guard);
}
}
pub fn current_memtable(&self) -> Arc<MemTable> {
Arc::clone(&self.memtable.read())
}
pub fn drain_completed_flushes(&self) -> Vec<FlushOutcome> {
let mut guard = self.completed_flushes.lock();
guard.drain(..).collect()
}
pub fn sync_wal(&self) -> std::io::Result<()> {
self.wal.commit()
}
fn rotate_memtable(
wal: &Arc<Wal>,
flush: &Arc<FlushCoordinator>,
memtable: &Arc<RwLock<Arc<MemTable>>>,
rotation: &Arc<Mutex<()>>,
completed: &Arc<Mutex<Vec<FlushOutcome>>>,
metrics: &OperationObserver,
current: &Arc<MemTable>,
fsync_on_commit: bool,
) -> Result<Arc<MemTable>, GroupCommitError> {
let _guard = rotation.lock();
let mut slot = memtable.write();
if !Arc::ptr_eq(&slot, current) {
return Ok(slot.clone());
}
let frozen = current.freeze();
let timer = metrics.timer(OperationKind::Flush);
let outcome = flush.flush(frozen)?;
drop(timer);
metrics.record_value(OperationKind::Flush, outcome.segment.metadata().size_bytes);
completed.lock().push(outcome);
let new_table = Arc::new(MemTable::new());
*slot = new_table.clone();
if fsync_on_commit {
wal.commit()?; } else {
wal.flush()?; }
Ok(new_table)
}
}
impl Drop for GroupCommitController {
    /// Signals the worker thread to shut down and waits for it to exit.
    fn drop(&mut self) {
        if let Some(handle) = self.worker.take() {
            let (lock, cv) = &*self.state;
            {
                let mut guard = lock.lock();
                guard.shutdown = true;
                cv.notify_all();
            }
            // Join outside the lock so the worker can re-acquire it to observe
            // the shutdown flag.
            handle.join().ok();
        }
    }
}
/// Applies a batch of transactions to the active memtable, encodes the
/// resulting entries into WAL frames, makes the WAL durable, and produces one
/// `CommitOutcome` per transaction. Any error fails the whole batch; the
/// worker fans the shared error out to every waiter.
fn process_batch(
wal: &Arc<Wal>,
flush: &Arc<FlushCoordinator>,
memtable: &Arc<RwLock<Arc<MemTable>>>,
rotation: &Arc<Mutex<()>>,
completed: &Arc<Mutex<Vec<FlushOutcome>>>,
metrics: &OperationObserver,
batch: Vec<PendingTxn>,
fsync_on_commit: bool,
) -> Result<Vec<(u64, CommitOutcome)>, GroupCommitError> {
if batch.is_empty() {
return Ok(Vec::new());
}
// Handle to the currently active memtable; replaced on rotation below.
let mut active = memtable.read().clone();
// Per transaction: (ticket id, applied entries, first and last sequence).
let mut per_txn_entries: Vec<(u64, Vec<MemTableEntry>, Option<u64>, Option<u64>)> = Vec::new();
let mut total_wal_bytes: usize = 0;
let mut applied_ops = 0u64;
let memtable_timer = metrics.timer(OperationKind::MemtableApply);
// Phase 1: apply every mutation to the memtable, rotating to a fresh table
// whenever the current one reports backpressure or is frozen.
// NOTE(review): entries become visible in the memtable before the WAL write
// and sync below — a crash between apply and sync loses data that readers may
// already have observed; confirm this ordering is intended.
for pending in batch {
let mut txn_entries = Vec::new();
let mut seq_start = None;
let mut seq_end = None;
for op in pending.txn.operations.into_iter() {
let entry = match op {
Mutation::Put { key, value } => {
let key_arc: Arc<[u8]> = Arc::from(key);
let val_arc: Arc<[u8]> = Arc::from(value);
// Retry the put until a writable memtable accepts it.
loop {
match active.put_arc(Arc::clone(&key_arc), Arc::clone(&val_arc)) {
Ok(entry) => break entry,
Err(MemTableError::Backpressure) | Err(MemTableError::Frozen) => {
active = GroupCommitController::rotate_memtable(
wal,
flush,
memtable,
rotation,
completed,
metrics,
&active,
fsync_on_commit,
)?;
}
}
}
}
Mutation::Delete { key } => {
// Same retry-on-rotation loop for tombstones.
loop {
match active.delete(&key) {
Ok(entry) => break entry,
Err(MemTableError::Backpressure) | Err(MemTableError::Frozen) => {
active = GroupCommitController::rotate_memtable(
wal,
flush,
memtable,
rotation,
completed,
metrics,
&active,
fsync_on_commit,
)?;
}
}
}
}
};
// Track the first and last sequence numbers assigned to this txn.
seq_start = Some(seq_start.unwrap_or(entry.sequence));
seq_end = Some(entry.sequence);
total_wal_bytes += entry.key.len() + entry.value.len() + WAL_ENTRY_HEADER;
txn_entries.push(entry);
applied_ops += 1;
}
per_txn_entries.push((pending.id, txn_entries, seq_start, seq_end));
}
drop(memtable_timer);
if applied_ops > 0 {
metrics.record_value(OperationKind::MemtableApply, applied_ops);
}
// Phase 2: encode entries into WAL frames, reusing a per-thread buffer to
// avoid reallocating on every batch.
thread_local! {
static WAL_FRAME_BUF: RefCell<Vec<u8>> = RefCell::new(Vec::new());
}
let max_payload = wal.max_payload();
// Logical byte position within this batch's WAL payload; per-txn receipt
// offsets are derived from it relative to the first frame's offset.
let mut logical_pos: u64 = 0;
let mut txn_results = Vec::with_capacity(per_txn_entries.len());
let mut first_offset: Option<u64> = None;
let mut total_len: usize = 0;
WAL_FRAME_BUF.with(|cell| -> Result<(), GroupCommitError> {
let mut buf = cell.borrow_mut();
buf.clear();
buf.reserve(total_wal_bytes.min(max_payload));
for (id, entries, seq_start, seq_end) in per_txn_entries {
let txn_start = logical_pos;
for entry in &entries {
let entry_len = WAL_ENTRY_HEADER + entry.key.len() + entry.value.len();
// An entry that cannot fit in any frame is a hard error.
if entry_len > max_payload {
return Err(GroupCommitError::from(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"single wal entry larger than frame",
)));
}
// Flush the current frame when the next entry would overflow it.
if buf.len() + entry_len > max_payload {
let t = wal.append(&buf)?;
if first_offset.is_none() {
first_offset = Some(t.offset);
}
total_len += buf.len();
buf.clear();
}
encode_wal_entry(entry, &mut buf);
logical_pos += entry_len as u64;
}
let txn_len = (logical_pos - txn_start) as u32;
txn_results.push((
id,
txn_start,
txn_len,
seq_start.unwrap_or(0),
seq_end.unwrap_or(0),
entries,
));
}
// Flush the final partial frame, if any.
if !buf.is_empty() {
let t = wal.append(&buf)?;
if first_offset.is_none() {
first_offset = Some(t.offset);
}
total_len += buf.len();
buf.clear();
}
Ok(())
})?;
// Phase 3: make the batch durable (fsync) or at least flushed, then build
// a synthetic ticket covering the whole batch.
let ticket = if total_len > 0 {
let wal_timer = metrics.timer(OperationKind::WalAppend);
if fsync_on_commit {
wal.commit()?;
} else {
wal.flush()?;
}
drop(wal_timer);
metrics.record_value(OperationKind::WalAppend, total_len as u64);
WalTicket { offset: first_offset.unwrap_or_else(|| wal.committed_bytes()), len: total_len as u32 }
} else {
WalTicket { offset: wal.committed_bytes(), len: 0 }
};
// Phase 4: derive per-transaction receipts from batch-relative positions.
// NOTE(review): this assumes successive `wal.append` calls return contiguous
// offsets, so `ticket.offset + start` lands on the txn's first entry —
// confirm against the Wal implementation.
let mut results = Vec::with_capacity(txn_results.len());
for (id, start, len, seq_start, seq_end, entries) in txn_results {
let receipt = CommitReceipt {
wal_offset: ticket.offset + start,
wal_len: len,
sequence_start: seq_start,
sequence_end: seq_end,
operations: entries.len(),
};
results.push((id, CommitOutcome { receipt, entries }));
}
Ok(results)
}
/// Serializes `entry` into the WAL wire format and appends it to `buffer`.
///
/// Layout (little-endian): 1-byte tombstone flag, u32 key length, u32 value
/// length, u64 sequence number, then the raw key and value bytes; the fixed
/// part is `WAL_ENTRY_HEADER` bytes.
fn encode_wal_entry(entry: &MemTableEntry, buffer: &mut Vec<u8>) {
    let key = entry.key.as_ref();
    let value = entry.value.as_ref();
    buffer.reserve(WAL_ENTRY_HEADER + key.len() + value.len());
    buffer.push(u8::from(entry.tombstone));
    buffer.extend_from_slice(&(key.len() as u32).to_le_bytes());
    buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
    buffer.extend_from_slice(&entry.sequence.to_le_bytes());
    buffer.extend_from_slice(key);
    buffer.extend_from_slice(value);
}