use heed::{
Database, EnvOpenOptions,
byteorder::BigEndian,
types::{Bytes, U64},
};
use crate::disruptor::{
Envelope, RingSlot,
traits::{RkyvError, RkyvToBytes},
};
use crate::error::{RecoveryError, WalError};
use core_affinity::CoreId;
use disrupt_rs::{EventPoller, GatedSequence, Polling, Sequence, wait_strategies::WaitStrategy};
use std::{
cell::RefCell,
path::Path,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
thread,
time::Instant,
};
use tracing::{error, info, trace, warn};
/// Codec for `u64` values stored in big-endian byte order.
// NOTE(review): big-endian is presumably chosen so byte-wise key order matches
// numeric order for the range scan in `recover_from` — confirm against heed/LMDB.
type BEU64 = U64<BigEndian>;
/// WAL log database: big-endian `u64` keys mapped to raw byte values.
type LogDb = Database<BEU64, Bytes>;
/// Metadata database: raw byte keys mapped to big-endian `u64` values.
type MetaDb = Database<Bytes, BEU64>;
/// Key in the metadata database under which the WAL schema version is stored.
const META_SCHEMA_VERSION_KEY: &[u8] = b"schema_version";
/// Schema version written on first open and required on every later open.
const WAL_SCHEMA_VERSION: u64 = 1;
/// Returns the memory-map size, in bytes, used when opening the WAL environment.
///
/// The `BTX_WAL_MAP_SIZE_BYTES` environment variable overrides the default when it
/// is set and parses as a `u64`; an unset or unparsable value falls back to the
/// default (32 MiB in test builds, 10 GiB otherwise). The result is clamped to
/// `usize::MAX`.
fn wal_map_size_bytes() -> usize {
    // Defaults are kept as `u64` so that the 10 GiB constant does not overflow
    // `usize` (a compile-time error) on 32-bit targets; it is clamped below instead.
    const DEFAULT_TEST: u64 = 32 * 1024 * 1024;
    const DEFAULT: u64 = 10 * 1024 * 1024 * 1024;

    let fallback = if cfg!(test) { DEFAULT_TEST } else { DEFAULT };
    let requested = std::env::var("BTX_WAL_MAP_SIZE_BYTES")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(fallback);

    // Clamp before narrowing so the cast can never truncate.
    requested.min(usize::MAX as u64) as usize
}
// Per-thread serialization buffer reused by `JournalHandler::put_envelope`, so
// that each write does not start from a fresh allocation. Starts with a capacity
// of 1024 bytes.
thread_local! {
static SCRATCH: RefCell<rkyv::util::AlignedVec> =
RefCell::new(rkyv::util::AlignedVec::with_capacity(1024));
}
/// Plain-data summary of recovery progress.
///
/// NOTE(review): this type is not referenced by the code visible in this part of
/// the file; the field descriptions below are inferred from the names only and
/// should be confirmed against the call sites.
#[derive(Debug, Clone, Copy)]
pub struct RecoveryState {
/// Presumably the last application sequence number that was recovered — confirm.
pub last_app_seq: u64,
/// Presumably the id of the last transaction that was recovered — confirm.
pub last_tx_id: u64,
}
/// Handle to the on-disk write-ahead log: an opened heed environment plus the
/// log database inside it.
pub struct JournalHandler {
// Environment used to begin read and write transactions.
env: heed::Env,
// Database holding serialized envelopes keyed by `u64` sequence number.
log: LogDb,
}
/// Shorthand for the rkyv archived form of `Envelope<T>`, i.e. the type that
/// recovery callbacks receive by reference.
pub type ArchivedEnvelopeOf<T> = <Envelope<T> as rkyv::Archive>::Archived;
impl JournalHandler {
/// Opens the WAL environment rooted at `path`, creating the directory, the log
/// database named `db`, and the `"meta"` database when they do not exist yet.
///
/// The schema version stored in the meta database is checked inside the same
/// write transaction that creates the databases: a missing version is
/// initialised to `WAL_SCHEMA_VERSION`, a matching one is accepted.
///
/// # Errors
///
/// Returns `WalError::SchemaMismatch` when a different schema version is already
/// stored, and propagates any filesystem or heed error.
pub fn open(path: impl AsRef<Path>, db: &str) -> anyhow::Result<Self> {
    let dir = path.as_ref();
    std::fs::create_dir_all(dir)?;

    // SAFETY(review): heed marks `open` as unsafe; this call relies on heed's
    // documented requirements for the environment at `dir` being upheld by the
    // process — confirm against the heed documentation.
    let env = unsafe {
        EnvOpenOptions::new()
            .map_size(wal_map_size_bytes())
            .max_dbs(8)
            .open(dir)?
    };

    // Database creation and the schema check share one transaction, so an early
    // return below drops (and thereby discards) everything done so far.
    let mut wtxn = env.write_txn()?;
    let log: LogDb = env.create_database(&mut wtxn, Some(db))?;
    let meta: MetaDb = env.create_database(&mut wtxn, Some("meta"))?;

    let stored_version = meta.get(&wtxn, META_SCHEMA_VERSION_KEY)?;
    match stored_version {
        // First open: stamp the current schema version.
        None => meta.put(&mut wtxn, META_SCHEMA_VERSION_KEY, &WAL_SCHEMA_VERSION)?,
        Some(WAL_SCHEMA_VERSION) => {}
        Some(found) => {
            let mismatch = WalError::SchemaMismatch {
                expected: WAL_SCHEMA_VERSION,
                found,
            };
            return Err(mismatch.into());
        }
    }
    wtxn.commit()?;

    Ok(Self { env, log })
}
/// Begins a new write transaction on the WAL environment.
///
/// # Errors
///
/// Propagates the heed error if the transaction cannot be started.
pub fn write_txn(&self) -> anyhow::Result<heed::RwTxn<'_>> {
    let txn = self.env.write_txn()?;
    Ok(txn)
}
/// Serializes `envelope` with rkyv and stores it in the log database under the
/// key `envelope.seq`, as part of the caller's write transaction `wtxn`.
///
/// Serialization goes through the thread-local `SCRATCH` buffer so that repeated
/// calls on the same thread reuse one allocation. Nothing is committed here; the
/// caller owns the transaction.
///
/// # Errors
///
/// Propagates the rkyv serialization error or the heed `put` error.
pub fn put_envelope<T>(
    &self,
    wtxn: &mut heed::RwTxn<'_>,
    envelope: &Envelope<T>,
) -> anyhow::Result<()>
where
    T: RkyvToBytes,
{
    SCRATCH.with(|scratch| -> anyhow::Result<()> {
        let mut slot = scratch.borrow_mut();

        // Move the buffer out of the thread-local so it can be handed to rkyv by
        // value; the slot temporarily holds an empty default buffer.
        let mut buf = std::mem::take(&mut *slot);
        buf.clear();
        let buf = rkyv::api::high::to_bytes_in::<_, RkyvError>(envelope, buf)?;

        let stored = self.log.put(wtxn, &envelope.seq, buf.as_slice());

        // Hand the (possibly grown) buffer back before looking at the result, so
        // a failed `put` does not throw away the allocation.
        *slot = buf;
        stored?;
        Ok(())
    })
}
/// Replays WAL records to `on_event`, starting just after `after_seq` (or from
/// key 0 when `after_seq` is `None`).
///
/// Records are grouped into transactions using the `tx_id` / `tx_len` / `tx_ix`
/// fields of each archived envelope. Only events of fully assembled
/// transactions are delivered, in log order. The `bool` passed to `on_event` is
/// `true` for the very last delivered event and `false` for all others; to make
/// that possible, delivery lags one event behind the scan.
///
/// An incomplete transaction at the end of the scan is reported with a warning
/// and not delivered. This function only reads: nothing is deleted from the log.
///
/// # Errors
///
/// * `WalError::CorruptRecord` when a record fails rkyv access.
/// * `WalError::InvalidTxFraming` when `tx_len == 0` or `tx_ix >= tx_len`.
/// * Any error returned by `on_event` or by heed while reading.
///
/// Events delivered before an error occurs have already been passed to
/// `on_event` (always with `false`).
pub fn recover_from<T, F>(&self, after_seq: Option<u64>, mut on_event: F) -> anyhow::Result<()>
where
T: rkyv::Archive,
ArchivedEnvelopeOf<T>:
for<'a> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, RkyvError>>,
F: for<'a> FnMut(&'a ArchivedEnvelopeOf<T>, bool) -> anyhow::Result<()>,
{
// First key to read: the one after `after_seq`, saturating at `u64::MAX`.
let start_seq = after_seq.map(|s| s.saturating_add(1)).unwrap_or(0);
info!(after_seq = ?after_seq, start_seq, "starting wal recovery");
// All archived references collected below borrow from this read transaction.
let rtxn = self.env.read_txn()?;
let mut iter = self.log.range(&rtxn, &(start_seq..))?;
// Transaction currently being assembled from consecutive records.
struct TxInProgress<'a, AE> {
tx_id: u64,
tx_len: u16,
events: Vec<&'a AE>,
}
let mut current_tx: Option<TxInProgress<'_, ArchivedEnvelopeOf<T>>> = None;
// Key of the most recent record read, whether or not it was delivered.
let mut last_seen_seq: Option<u64> = None;
// Most recent deliverable event, held back so that the final one can be
// flagged with `true`.
let mut pending_last: Option<&ArchivedEnvelopeOf<T>> = None;
while let Some(entry) = iter.next().transpose()? {
let (app_seq, val_bytes) = entry;
last_seen_seq = Some(app_seq);
let archived =
match rkyv::api::high::access::<ArchivedEnvelopeOf<T>, RkyvError>(val_bytes) {
Ok(a) => a,
Err(e) => {
error!(app_seq, error = %e, "corrupt WAL record");
return Err(WalError::CorruptRecord {
app_seq,
details: e.to_string(),
}
.into());
}
};
let tx_id: u64 = archived.tx_id.into();
let tx_len: u16 = archived.tx_len.into();
let tx_ix: u16 = archived.tx_ix.into();
// A record must belong to a non-empty transaction and index inside it.
if tx_len == 0 || tx_ix >= tx_len {
error!(app_seq, tx_id, tx_len, tx_ix, "invalid wal tx framing");
return Err(WalError::InvalidTxFraming {
app_seq,
tx_id,
tx_len,
tx_ix,
}
.into());
}
match &mut current_tx {
None => {
// Not inside a transaction: a record that is not a transaction start
// is skipped without being delivered or logged.
if tx_ix != 0 {
continue;
}
current_tx = Some(TxInProgress {
tx_id,
tx_len,
events: vec![archived],
});
}
Some(tx) => {
// The next record of the open transaction must carry the same id and
// length and the next consecutive index.
let expected_ix: u16 = tx.events.len().try_into().unwrap_or(u16::MAX);
let framing_mismatch =
tx.tx_id != tx_id || tx.tx_len != tx_len || tx_ix != expected_ix;
if framing_mismatch {
// The open transaction is abandoned (its events are never delivered).
// A non-start record is dropped too; a start record begins a new
// transaction in its place.
if tx_ix != 0 {
current_tx = None;
continue;
}
*tx = TxInProgress {
tx_id,
tx_len,
events: vec![archived],
};
} else {
tx.events.push(archived);
}
}
}
// Once every event of the transaction has been collected, release them:
// each one pushes the previously held-back event out with `false`.
if let Some(tx) = &mut current_tx
&& tx.events.len() == tx.tx_len as usize
{
for e in tx.events.drain(..) {
if let Some(prev) = pending_last.take() {
on_event(prev, false)?;
}
pending_last = Some(e);
}
current_tx = None;
}
}
// A transaction still open here never completed. It is only reported;
// `trim_to` is computed for the log line and nothing is removed from the log.
if let Some(tx) = current_tx.take() {
let trimmed_count = tx.events.len();
// An in-progress transaction always holds at least one event (it is created
// with one and cleared only together with `current_tx`), so the two
// `let ... else` error returns below are defensive.
if trimmed_count != 0 {
let Some(first) = tx.events.first() else {
return Err(RecoveryError::WalTailMissingFirst.into());
};
let Some(last) = tx.events.last() else {
return Err(RecoveryError::WalTailMissingLast.into());
};
let trimmed_start: u64 = first.seq.into();
let trimmed_end: u64 = last.seq.into();
let trim_to = trimmed_start.saturating_sub(1);
warn!(
trimmed_start,
trimmed_end, trimmed_count, trim_to, "dropping incomplete wal tail"
);
}
}
// Deliver the held-back event, flagged as the last one.
if let Some(last) = pending_last.take() {
on_event(last, true)?;
}
// Used for the log line only; includes records that were read but not delivered.
let last_seq = last_seen_seq.or(after_seq).unwrap_or(0);
info!(last_seq, "wal recovery complete");
Ok(())
}
/// Consumes the handler and wraps it, together with the given event poller,
/// wait strategy, gate and shutdown flag, into a `JournalPoller`.
///
/// This is a convenience over `JournalPoller::new`; no thread is started here.
pub fn into_poller<T, W, B>(
    self,
    poller: EventPoller<RingSlot<T>, B, W::Notifier>,
    wait_strategy: W,
    gate: GatedSequence<W::Notifier>,
    shutdown: Arc<AtomicBool>,
) -> JournalPoller<T, W, B>
where
    W: WaitStrategy,
{
    // Type parameters are inferred from the arguments and the return type.
    JournalPoller::new(self, poller, wait_strategy, gate, shutdown)
}
}
/// Everything the WAL writer thread needs: the journal to write to, the ring
/// poller to read events from, and the handles used to coordinate with the rest
/// of the pipeline.
pub struct JournalPoller<T, W: WaitStrategy, B> {
// Journal that polled envelopes are written to.
handler: JournalHandler,
// Source of ring events.
poller: EventPoller<RingSlot<T>, B, W::Notifier>,
// Produces the waiter passed to `poll_wait`.
wait_strategy: W,
// Sequence set after each committed batch.
gate: GatedSequence<W::Notifier>,
// Checked when no events are available; set when the poller fails.
shutdown: Arc<AtomicBool>,
}
impl<T, W: WaitStrategy, B> JournalPoller<T, W, B> {
pub fn new(
handler: JournalHandler,
poller: EventPoller<RingSlot<T>, B, W::Notifier>,
wait_strategy: W,
gate: GatedSequence<W::Notifier>,
shutdown: Arc<AtomicBool>,
) -> Self {
Self {
handler,
poller,
wait_strategy,
gate,
shutdown,
}
}
}
impl<T, W, B> JournalPoller<T, W, B>
where
    W: WaitStrategy + Send + 'static,
    W::Notifier: Send + 'static,
    T: RkyvToBytes + Send + Sync + 'static,
    B: disrupt_rs::Barrier + 'static,
{
    /// Starts the WAL writer on a new thread without pinning it to a core.
    ///
    /// The returned handle yields the poller's result when the thread exits.
    pub fn poll(self) -> thread::JoinHandle<anyhow::Result<()>> {
        self.spawn(None)
    }

    /// Starts the WAL writer on a new thread, pinning it to `core_id` when one is
    /// given. A failed pin is logged as a warning and the thread keeps running.
    pub fn poll_pinned(self, core_id: Option<usize>) -> thread::JoinHandle<anyhow::Result<()>> {
        self.spawn(core_id)
    }

    /// Spawns the writer thread.
    ///
    /// The thread repeatedly polls a batch of events, writes every envelope of
    /// the batch into one write transaction, commits it, and then advances the
    /// gate to the sequence of the most recent transaction-ending event seen so
    /// far. It exits when the poller reports shutdown, or when no events are
    /// available and the shutdown flag is set.
    ///
    /// On any error the shutdown flag is raised and the error is returned through
    /// the join handle.
    fn spawn(self, core_id: Option<usize>) -> thread::JoinHandle<anyhow::Result<()>> {
        let mut this = self;
        thread::spawn(move || {
            trace!(core_id = ?core_id, "starting wal poller");
            if let Some(id) = core_id
                && !core_affinity::set_for_current(CoreId { id })
            {
                warn!(core_id = id, "failed to pin wal poller thread to requested core");
            }
            crate::metrics_stage!("wal_poller");
            let mut waiter = this.wait_strategy.new_waiter();
            let res = (|| {
                // A write transaction is kept open across polls and replaced
                // right after each commit.
                let mut wtxn = this.handler.env.write_txn()?;
                // `None` until the first transaction end has been written, so the
                // gate is never advanced on behalf of a transaction that has not
                // completed. NOTE(review): assumes sequence 0 is a real event
                // sequence in this ring — confirm against disrupt_rs.
                let mut last_completed_tx_end: Option<Sequence> = None;
                'poll: loop {
                    match this.poller.poll_wait(&mut waiter) {
                        Ok(mut events) => {
                            let t0 = Instant::now();
                            let events_len = events.len();
                            crate::metric!({ inflight: 1, in_total: events_len as u64 });
                            for (s, e) in &mut events {
                                let envelope: &Envelope<T> = e;
                                this.handler.put_envelope(&mut wtxn, envelope)?;
                                if envelope.is_tx_end() {
                                    last_completed_tx_end = Some(s);
                                }
                            }
                            // Commit first, then publish: the gate only ever moves
                            // to a sequence whose data has been committed.
                            wtxn.commit()?;
                            if let Some(tx_end) = last_completed_tx_end {
                                this.gate.set(tx_end);
                            }
                            wtxn = this.handler.env.write_txn()?;
                            crate::metric!({
                                out_total: 1, inflight: 0,
                                duration_ns: t0.elapsed().as_nanos() as u64
                            });
                        }
                        Err(Polling::NoEvents) => {
                            if this.shutdown.load(Ordering::Acquire) {
                                break 'poll;
                            }
                            continue;
                        }
                        Err(Polling::Shutdown) => break 'poll,
                    }
                }
                Ok::<_, anyhow::Error>(())
            })();
            if let Err(e) = res {
                error!(error = ?e, "journal poller error");
                crate::metric!({ err_total: 1 });
                // Release pairs with the Acquire load of this flag in the poll loop.
                this.shutdown.store(true, Ordering::Release);
                return Err(e);
            }
            Ok(())
        })
    }
}