use crate::wal::{Record, Result, WAL};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, Thread};
use std::time::{Duration, Instant};
/// Per-call handle that carries one record to the group-commit leader and
/// wakes the submitting thread once the record's outcome is known.
struct Writer {
    // Record to append; cloned into the leader's batch in `put`/`collect_batch`.
    record: Record,
    // Handle of the submitting thread, captured at construction; the leader
    // unparks it after publishing the result.
    thread: Thread,
    // Result slot, written by the leader *before* `done` is set. Note this is
    // parking_lot's (non-poisoning) mutex, unlike the std mutex guarding WAL.
    result: parking_lot::Mutex<Option<Result<u64>>>,
    // Completion flag; the Release store in `signal_done` pairs with the
    // Acquire load in `is_done` so the result write is visible first.
    done: AtomicBool,
}
impl Writer {
    /// Wraps `record` together with the *calling* thread's handle, so whichever
    /// thread becomes leader can wake this one later.
    fn new(record: Record) -> Self {
        Self {
            record,
            thread: thread::current(),
            result: parking_lot::Mutex::new(None),
            done: AtomicBool::new(false),
        }
    }
    /// Publishes the outcome and wakes the submitting thread.
    ///
    /// Order matters: the result is stored first, then `done` is set with
    /// Release ordering (pairing with the Acquire load in `is_done`), and only
    /// then is the thread unparked. Unpark tokens are sticky, so a wake
    /// delivered before the target actually parks is not lost.
    #[inline]
    fn signal_done(&self, res: Result<u64>) {
        *self.result.lock() = Some(res);
        self.done.store(true, Ordering::Release);
        self.thread.unpark();
    }
    /// True once `signal_done` has published a result; the Acquire load makes
    /// that result visible to this thread.
    #[inline]
    fn is_done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }
    /// Consumes the published result. Callers (see `put`) only reach this
    /// after the leader has processed this writer's batch, so an empty slot
    /// is an invariant violation, not a recoverable condition.
    fn take_result(&self) -> Result<u64> {
        self.result
            .lock()
            .take()
            .expect("result must be set before waking writer")
    }
}
/// Tuning knobs for the group-commit batching loop.
#[derive(Debug, Clone, Copy)]
pub struct PipelineConfig {
    /// Batching delay used when the submission queue is empty
    /// (see `adaptive_delay`).
    pub min_delay: Duration,
    /// Batching delay used once queue depth reaches `adaptive_threshold`.
    pub max_delay: Duration,
    /// Queue depth at which the delay saturates at `max_delay`; depths in
    /// between interpolate linearly between `min_delay` and `max_delay`.
    pub adaptive_threshold: usize,
    /// Maximum number of records committed in one WAL batch; also sizes the
    /// submission channel (4x) in `PipelinedWAL::with_config`.
    pub max_batch_size: usize,
    /// When true, batch N's memtable apply overlaps batch N+1's WAL write.
    pub enable_pipelining: bool,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
min_delay: Duration::from_micros(50), max_delay: Duration::from_micros(500), adaptive_threshold: 16, max_batch_size: 256,
enable_pipelining: true,
}
}
}
impl PipelineConfig {
    /// Maps the current queue depth to a batching delay.
    ///
    /// An empty queue yields `min_delay`; a depth at or beyond
    /// `adaptive_threshold` yields `max_delay`; anything in between is
    /// linearly interpolated (in whole microseconds) between the two.
    #[inline]
    const fn adaptive_delay(&self, queue_depth: usize) -> Duration {
        if queue_depth == 0 {
            return self.min_delay;
        }
        if queue_depth >= self.adaptive_threshold {
            return self.max_delay;
        }
        // Linear interpolation: min + (max - min) * depth / threshold.
        // (threshold == 0 is unreachable here: the guard above already
        // returned `max_delay` for it.)
        let lo = self.min_delay.as_micros() as u64;
        let hi = self.max_delay.as_micros() as u64;
        let step = (hi - lo) * (queue_depth as u64) / (self.adaptive_threshold as u64);
        Duration::from_micros(lo + step)
    }
}
/// WAL front-end that batches concurrent `put` calls using a leader/follower
/// group-commit protocol (see `process_batches_pipelined`).
pub struct PipelinedWAL {
    // Underlying WAL, shared with the rest of the engine; locked only for the
    // duration of each batch write (and for `sync`).
    wal: Arc<std::sync::Mutex<WAL>>,
    // Bounded MPMC submission queue (capacity 4 * max_batch_size, see
    // `with_config`); both halves kept so any thread can produce or drain.
    sender: Sender<Arc<Writer>>,
    receiver: Receiver<Arc<Writer>>,
    // CAS-elected flag: true while some thread is acting as batch leader.
    leader_active: AtomicBool,
    config: PipelineConfig,
    // Monotonic counters exposed via `stats()`.
    batches_processed: AtomicU64,
    writes_processed: AtomicU64,
}
impl PipelinedWAL {
    /// Creates a pipelined WAL with a fixed (non-adaptive) batching delay.
    ///
    /// Convenience wrapper around [`PipelinedWAL::with_config`] that sets
    /// `min_delay == max_delay == delay`, so the leader always waits the same
    /// amount of time while accumulating a batch.
    pub fn new(wal: Arc<Mutex<WAL>>, delay: Duration, max_batch_size: usize) -> Self {
        Self::with_config(
            wal,
            PipelineConfig {
                min_delay: delay,
                max_delay: delay,
                max_batch_size,
                enable_pipelining: true,
                ..Default::default()
            },
        )
    }
    /// Creates a pipelined WAL from an explicit [`PipelineConfig`].
    ///
    /// The submission channel is bounded at `4 * max_batch_size` so producers
    /// apply back-pressure instead of queueing without limit.
    pub fn with_config(wal: Arc<Mutex<WAL>>, config: PipelineConfig) -> Self {
        let (sender, receiver) = bounded(config.max_batch_size * 4);
        Self {
            wal,
            sender,
            receiver,
            leader_active: AtomicBool::new(false),
            config,
            batches_processed: AtomicU64::new(0),
            writes_processed: AtomicU64::new(0),
        }
    }
    /// Appends `record` to the WAL and returns its WAL offset.
    ///
    /// Group commit: every caller enqueues its record, then exactly one caller
    /// (the leader, elected via CAS on `leader_active`) drains the queue and
    /// writes batches, while the rest park until their individual result is
    /// signalled. `on_memtable` is invoked by the leader with each
    /// durably-written batch, in batch order, before that batch's writers are
    /// released.
    pub fn put<F>(&self, record: Record, on_memtable: F) -> Result<u64>
    where
        F: Fn(&[Record]),
    {
        let writer = Arc::new(Writer::new(record));
        // Fast path: non-blocking enqueue. On a full channel, fall back to a
        // blocking send, reusing the Arc handed back inside the error instead
        // of cloning a second time.
        if let Err(err) = self.sender.try_send(writer.clone()) {
            self.sender.send(err.into_inner()).expect("channel closed");
        }
        let is_leader = self
            .leader_active
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if is_leader {
            self.process_batches_pipelined(&on_memtable);
        } else {
            // Unpark tokens are sticky, so a signal delivered before we park
            // just makes the next `park()` return immediately; the loop guards
            // against spurious wakeups.
            loop {
                thread::park();
                if writer.is_done() {
                    break;
                }
            }
        }
        writer.take_result()
    }
    /// Forces the underlying WAL to sync to stable storage.
    pub fn sync(&self) -> Result<()> {
        let wal = self.wal.lock().expect("WAL mutex poisoned");
        wal.sync()
    }
    /// Returns the `(batches_processed, writes_processed)` counters.
    pub fn stats(&self) -> (u64, u64) {
        (
            self.batches_processed.load(Ordering::Relaxed),
            self.writes_processed.load(Ordering::Relaxed),
        )
    }
    /// Leader loop: drain batches from the queue, write each batch to the WAL,
    /// and (when pipelining is enabled) overlap batch N's memtable apply with
    /// batch N+1's WAL write.
    ///
    /// Invariant: a batch's writers are signalled only after both its WAL
    /// write and its memtable apply have completed, in batch order.
    #[allow(clippy::type_complexity)]
    fn process_batches_pipelined<F>(&self, on_memtable: &F)
    where
        F: Fn(&[Record]),
    {
        // Batch whose WAL write is durable but whose memtable apply is
        // deferred so it overlaps the next batch's WAL write.
        let mut pending_memtable: Option<(Vec<Arc<Writer>>, Vec<Record>, Vec<u64>)> = None;
        loop {
            let batch_writers = self.collect_batch();
            if batch_writers.is_empty() {
                // Flush the deferred batch before giving up leadership so no
                // writer is left parked with an unpublished result.
                if let Some((writers, records, offsets)) = pending_memtable.take() {
                    on_memtable(&records);
                    for (writer, offset) in writers.iter().zip(offsets) {
                        writer.signal_done(Ok(offset));
                    }
                }
                self.leader_active.store(false, Ordering::Release);
                // Close the race where a writer enqueued after our final
                // (empty) collect but lost the leader election while the flag
                // was still set: re-elect ourselves if work remains.
                if !self.receiver.is_empty()
                    && self
                        .leader_active
                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
                        .is_ok()
                {
                    continue;
                }
                return;
            }
            let records: Vec<Record> = batch_writers.iter().map(|w| w.record.clone()).collect();
            // Hold the WAL lock only for the batch write itself.
            let wal_result = {
                let mut wal = self.wal.lock().expect("WAL mutex poisoned");
                wal.write_batch(&records)
            };
            // The previous batch is already durable; apply it to the memtable
            // and release its writers now that this batch's write has issued.
            if let Some((prev_writers, prev_records, prev_offsets)) = pending_memtable.take() {
                on_memtable(&prev_records);
                for (writer, offset) in prev_writers.iter().zip(prev_offsets) {
                    writer.signal_done(Ok(offset));
                }
            }
            match wal_result {
                Ok(offsets) => {
                    self.batches_processed.fetch_add(1, Ordering::Relaxed);
                    self.writes_processed
                        .fetch_add(batch_writers.len() as u64, Ordering::Relaxed);
                    if self.config.enable_pipelining {
                        pending_memtable = Some((batch_writers, records, offsets));
                    } else {
                        on_memtable(&records);
                        for (writer, offset) in batch_writers.iter().zip(offsets) {
                            writer.signal_done(Ok(offset));
                        }
                    }
                }
                Err(e) => {
                    // The error is not Clone, so fan it out to every writer in
                    // the failed batch as a fresh I/O error carrying the
                    // original message.
                    let err_str = e.to_string();
                    for writer in batch_writers.iter() {
                        let err = crate::wal::WALError::Io(std::io::Error::other(err_str.clone()));
                        writer.signal_done(Err(err));
                    }
                }
            }
        }
    }
    /// Gathers up to `max_batch_size` writers from the submission queue.
    ///
    /// Drains whatever is immediately available; if the queue was empty,
    /// blocks up to `min_delay` for a first writer, then keeps accepting
    /// arrivals until an adaptive deadline (scaled by the batch size so far)
    /// expires.
    fn collect_batch(&self) -> Vec<Arc<Writer>> {
        let mut batch = Vec::with_capacity(self.config.max_batch_size);
        // Phase 1: take everything already queued, without blocking.
        loop {
            match self.receiver.try_recv() {
                Ok(writer) => {
                    batch.push(writer);
                    if batch.len() >= self.config.max_batch_size {
                        return batch;
                    }
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => return batch,
            }
        }
        // Phase 2: if nothing was queued, wait briefly for a first writer so
        // an idle leader can still seed a batch before giving up.
        if batch.is_empty() {
            let delay = self.config.adaptive_delay(0);
            match self.receiver.recv_timeout(delay) {
                Ok(writer) => batch.push(writer),
                Err(_) => return batch,
            }
        }
        // Phase 3: accumulate until an adaptive deadline. `recv_deadline`
        // blocks until a message arrives or the deadline passes, replacing the
        // previous try_recv + 10us sleep poll: no busy-waiting and no added
        // wake-up latency per arrival.
        let deadline = Instant::now() + self.config.adaptive_delay(batch.len());
        while batch.len() < self.config.max_batch_size {
            match self.receiver.recv_deadline(deadline) {
                Ok(writer) => batch.push(writer),
                // Timed out or disconnected: ship what we have.
                Err(_) => break,
            }
        }
        batch
    }
}