// betex 0.7.7
//
// Betfair / Prediction Market Exchange
// Documentation
use heed::{
    Database, EnvOpenOptions,
    byteorder::BigEndian,
    types::{Bytes, U64},
};

use crate::disruptor::{
    Envelope, RingSlot,
    traits::{RkyvError, RkyvToBytes},
};
use crate::error::{RecoveryError, WalError};
use core_affinity::CoreId;
use disrupt_rs::{EventPoller, GatedSequence, Polling, Sequence, wait_strategies::WaitStrategy};
use std::{
    cell::RefCell,
    path::Path,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    thread,
    time::Instant,
};
use tracing::{error, info, trace, warn};

/// Big-endian `u64` codec, used for log keys and for metadata values.
type BEU64 = U64<BigEndian>;
/// Log database: big-endian `u64` key -> raw bytes (a serialized envelope, see `put_envelope`).
type LogDb = Database<BEU64, Bytes>;
/// Metadata database: raw byte-string key -> big-endian `u64` value.
type MetaDb = Database<Bytes, BEU64>;

/// Key in the metadata database under which the WAL schema version is stored.
const META_SCHEMA_VERSION_KEY: &[u8] = b"schema_version";
/// Schema version written into a fresh WAL and required when an existing one is opened.
const WAL_SCHEMA_VERSION: u64 = 1;

/// Returns the memory-map size, in bytes, passed to the LMDB environment.
///
/// If the `BTX_WAL_MAP_SIZE_BYTES` environment variable is set and parses as a `u64`,
/// that value wins (clamped to `usize::MAX`). Otherwise — variable unset, not valid
/// Unicode, or not a number — the default is 32 MiB in test builds and 10 GiB elsewhere.
fn wal_map_size_bytes() -> usize {
    const DEFAULT_TEST: usize = 32 * 1024 * 1024;
    const DEFAULT: usize = 10 * 1024 * 1024 * 1024;

    let fallback = if cfg!(test) { DEFAULT_TEST } else { DEFAULT };

    std::env::var("BTX_WAL_MAP_SIZE_BYTES")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        // A value that does not fit in `usize` saturates instead of wrapping.
        .map_or(fallback, |requested| {
            usize::try_from(requested).unwrap_or(usize::MAX)
        })
}

thread_local! {
    // Per-thread serialization buffer reused by `JournalHandler::put_envelope`, so that
    // each write does not start from a fresh allocation. Starts with 1024 bytes of capacity.
    static SCRATCH: RefCell<rkyv::util::AlignedVec> =
        RefCell::new(rkyv::util::AlignedVec::with_capacity(1024));
}

/// Summary of a recovery pass.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads this type;
/// the field descriptions below are inferred from the names only — confirm against callers.
#[derive(Debug, Clone, Copy)]
pub struct RecoveryState {
    /// Presumably the last application sequence number observed during recovery.
    pub last_app_seq: u64,
    /// Presumably the id of the last transaction observed during recovery.
    pub last_tx_id: u64,
}

/// Higher-level durability + recovery handler backed by an LMDB log.
///
/// Stores archived `Envelope<T>` keyed by `Envelope::seq`.
pub struct JournalHandler {
    /// The LMDB environment; all read and write transactions are created from it.
    env: heed::Env,
    /// The log database inside `env` (sequence number -> serialized envelope bytes).
    log: LogDb,
}

/// Shorthand for the rkyv-archived form of `Envelope<T>`, i.e. the type that
/// `recover_from` hands to its callback.
pub type ArchivedEnvelopeOf<T> = <Envelope<T> as rkyv::Archive>::Archived;

impl JournalHandler {
    /// Opens the WAL stored in the directory `path`, creating the directory, the log
    /// database named `db`, and the `"meta"` database if they do not exist yet.
    ///
    /// The schema version stored under `schema_version` in the meta database is checked
    /// against `WAL_SCHEMA_VERSION`; a fresh WAL gets the current version written.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created, if any LMDB operation fails,
    /// or `WalError::SchemaMismatch` if the stored schema version differs from the
    /// expected one (in which case nothing is committed).
    pub fn open(path: impl AsRef<Path>, db: &str) -> anyhow::Result<Self> {
        let dir = path.as_ref();
        std::fs::create_dir_all(dir)?;

        let map_size = wal_map_size_bytes();
        // SAFETY: NOTE(review): heed marks opening an environment as `unsafe` because the
        // data file is memory-mapped. This presumably relies on the file not being modified
        // externally and on the environment not being opened twice in this process — confirm.
        let env = unsafe {
            EnvOpenOptions::new()
                .map_size(map_size)
                .max_dbs(8)
                .open(dir)?
        };

        // Databases are created and the schema version verified in one write transaction.
        let mut wtxn = env.write_txn()?;
        let log: LogDb = env.create_database(&mut wtxn, Some(db))?;
        let meta: MetaDb = env.create_database(&mut wtxn, Some("meta"))?;

        let stored_version = meta.get(&wtxn, META_SCHEMA_VERSION_KEY)?;
        match stored_version {
            // Fresh WAL: stamp it with the current schema version.
            None => meta.put(&mut wtxn, META_SCHEMA_VERSION_KEY, &WAL_SCHEMA_VERSION)?,
            Some(found) if found != WAL_SCHEMA_VERSION => {
                return Err(WalError::SchemaMismatch {
                    expected: WAL_SCHEMA_VERSION,
                    found,
                }
                .into());
            }
            // Stored version matches: nothing to do.
            Some(_) => {}
        }
        wtxn.commit()?;

        Ok(Self { env, log })
    }

    pub fn write_txn(&self) -> anyhow::Result<heed::RwTxn<'_>> {
        Ok(self.env.write_txn()?)
    }

    /// Serializes `envelope` with rkyv and stores it in the log under `envelope.seq`,
    /// as part of the caller-supplied write transaction `wtxn`.
    ///
    /// The write is not committed here; the caller decides when to commit `wtxn`.
    /// Serialization goes through the thread-local `SCRATCH` buffer so that its
    /// allocation is reused between calls on the same thread.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or the LMDB put fails. In that case the buffer
    /// taken out of `SCRATCH` is not put back, so the next call starts with an empty one.
    pub fn put_envelope<T>(
        &self,
        wtxn: &mut heed::RwTxn<'_>,
        envelope: &Envelope<T>,
    ) -> anyhow::Result<()>
    where
        T: RkyvToBytes,
    {
        SCRATCH.with(|cell| -> anyhow::Result<()> {
            let mut slot = cell.borrow_mut();

            // The serializer takes its writer by value, so move the buffer out of the
            // thread-local slot, reset its length, and hand it over.
            let mut buffer = std::mem::take(&mut *slot);
            buffer.clear();
            let serialized = rkyv::api::high::to_bytes_in::<_, RkyvError>(envelope, buffer)?;

            self.log.put(wtxn, &envelope.seq, serialized.as_slice())?;

            // Return the buffer to the slot for the next call.
            *slot = serialized;
            Ok(())
        })
    }

    /// Replays complete transactions from the log, handing each event to `on_event`.
    ///
    /// Records are read inside a single read transaction, starting at key
    /// `after_seq + 1` (saturating), or at key 0 when `after_seq` is `None`. Every record
    /// is validated with rkyv before use. Events are grouped into transactions using the
    /// `tx_id` / `tx_len` / `tx_ix` fields of the archived envelope: a transaction is
    /// delivered only once `tx_len` events with matching `tx_id` and `tx_len` and
    /// consecutive `tx_ix` values (starting at 0) have been read.
    ///
    /// Records that do not fit the transaction currently being assembled are skipped, and
    /// a partially assembled transaction is abandoned without any callback. An incomplete
    /// transaction at the end of the log is reported with a warning; nothing is deleted
    /// from the log by this function.
    ///
    /// # Parameters
    ///
    /// * `after_seq` - last sequence number the caller already has, if any.
    /// * `on_event` - called as `on_event(event, is_last)` for every delivered event, in
    ///   the order read. `is_last` is `true` only for the final delivered event. It is
    ///   never called if no complete transaction is found.
    ///
    /// # Errors
    ///
    /// * `WalError::CorruptRecord` if a record fails rkyv validation.
    /// * `WalError::InvalidTxFraming` if a record has `tx_len == 0` or `tx_ix >= tx_len`.
    /// * Any LMDB error, and any error returned by `on_event`.
    pub fn recover_from<T, F>(&self, after_seq: Option<u64>, mut on_event: F) -> anyhow::Result<()>
    where
        T: rkyv::Archive,
        ArchivedEnvelopeOf<T>:
            for<'a> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, RkyvError>>,
        F: for<'a> FnMut(&'a ArchivedEnvelopeOf<T>, bool) -> anyhow::Result<()>,
    {
        // Resume right after the caller's last known sequence, or from the very beginning.
        let start_seq = after_seq.map(|s| s.saturating_add(1)).unwrap_or(0);
        info!(after_seq = ?after_seq, start_seq, "starting wal recovery");

        let rtxn = self.env.read_txn()?;
        let mut iter = self.log.range(&rtxn, &(start_seq..))?;

        // Transaction currently being assembled. `events` holds references into the
        // record bytes yielded by the iterator.
        struct TxInProgress<'a, AE> {
            tx_id: u64,
            tx_len: u16,
            events: Vec<&'a AE>,
        }

        let mut current_tx: Option<TxInProgress<'_, ArchivedEnvelopeOf<T>>> = None;
        // Key of the most recent record read, whether or not it was delivered.
        let mut last_seen_seq: Option<u64> = None;
        // One-event lookahead: delivery of each event is deferred until the next one is
        // known, so that the final event can be flagged with `is_last = true`.
        let mut pending_last: Option<&ArchivedEnvelopeOf<T>> = None;

        while let Some(entry) = iter.next().transpose()? {
            let (app_seq, val_bytes) = entry;
            last_seen_seq = Some(app_seq);

            // Validate the bytes and obtain a reference to the archived envelope.
            let archived =
                match rkyv::api::high::access::<ArchivedEnvelopeOf<T>, RkyvError>(val_bytes) {
                    Ok(a) => a,
                    Err(e) => {
                        error!(app_seq, error = %e, "corrupt WAL record");
                        return Err(WalError::CorruptRecord {
                            app_seq,
                            details: e.to_string(),
                        }
                        .into());
                    }
                };

            let tx_id: u64 = archived.tx_id.into();
            let tx_len: u16 = archived.tx_len.into();
            let tx_ix: u16 = archived.tx_ix.into();

            // A record whose own framing fields are inconsistent is a hard error.
            if tx_len == 0 || tx_ix >= tx_len {
                error!(app_seq, tx_id, tx_len, tx_ix, "invalid wal tx framing");
                return Err(WalError::InvalidTxFraming {
                    app_seq,
                    tx_id,
                    tx_len,
                    tx_ix,
                }
                .into());
            }

            match &mut current_tx {
                None => {
                    // Resync by only starting on the first event of a tx.
                    if tx_ix != 0 {
                        continue;
                    }
                    current_tx = Some(TxInProgress {
                        tx_id,
                        tx_len,
                        events: vec![archived],
                    });
                }
                Some(tx) => {
                    // The next event of this transaction must carry the index equal to
                    // the number of events collected so far.
                    let expected_ix: u16 = tx.events.len().try_into().unwrap_or(u16::MAX);
                    let framing_mismatch =
                        tx.tx_id != tx_id || tx.tx_len != tx_len || tx_ix != expected_ix;

                    if framing_mismatch {
                        // The transaction in progress is abandoned. A record with index 0
                        // starts a new one; anything else is skipped until the next start.
                        if tx_ix != 0 {
                            current_tx = None;
                            continue;
                        }
                        *tx = TxInProgress {
                            tx_id,
                            tx_len,
                            events: vec![archived],
                        };
                    } else {
                        tx.events.push(archived);
                    }
                }
            }

            // Transaction complete: deliver its events, always holding back the newest
            // one in `pending_last`.
            if let Some(tx) = &mut current_tx
                && tx.events.len() == tx.tx_len as usize
            {
                for e in tx.events.drain(..) {
                    if let Some(prev) = pending_last.take() {
                        on_event(prev, false)?;
                    }
                    pending_last = Some(e);
                }
                current_tx = None;
            }
        }

        // Report an incomplete trailing transaction. Its events were never passed to
        // `on_event`; this only logs and does not delete anything from the log.
        if let Some(tx) = current_tx.take() {
            let trimmed_count = tx.events.len();
            if trimmed_count != 0 {
                let Some(first) = tx.events.first() else {
                    return Err(RecoveryError::WalTailMissingFirst.into());
                };
                let Some(last) = tx.events.last() else {
                    return Err(RecoveryError::WalTailMissingLast.into());
                };
                let trimmed_start: u64 = first.seq.into();
                let trimmed_end: u64 = last.seq.into();
                let trim_to = trimmed_start.saturating_sub(1);
                warn!(
                    trimmed_start,
                    trimmed_end, trimmed_count, trim_to, "dropping incomplete wal tail"
                );
            }
        }

        // Flush the held-back event, marking it as the last one.
        if let Some(last) = pending_last.take() {
            on_event(last, true)?;
        }

        // Logged for diagnostics only; includes records that were read but not delivered.
        let last_seq = last_seen_seq.or(after_seq).unwrap_or(0);
        info!(last_seq, "wal recovery complete");

        Ok(())
    }

    pub fn into_poller<T, W, B>(
        self,
        poller: EventPoller<RingSlot<T>, B, W::Notifier>,
        wait_strategy: W,
        gate: GatedSequence<W::Notifier>,
        shutdown: Arc<AtomicBool>,
    ) -> JournalPoller<T, W, B>
    where
        W: WaitStrategy,
    {
        JournalPoller::<T, W, B>::new(self, poller, wait_strategy, gate, shutdown)
    }
}

/// Everything the WAL polling thread needs: the journal to write to and the disruptor
/// plumbing to read events from.
pub struct JournalPoller<T, W: WaitStrategy, B> {
    /// Journal that polled envelopes are written to.
    handler: JournalHandler,
    /// Source of events from the ring buffer.
    poller: EventPoller<RingSlot<T>, B, W::Notifier>,
    /// Used to create the waiter passed to `poller.poll_wait`.
    wait_strategy: W,
    /// Sequence updated after each committed batch with the ring sequence of the last
    /// event that ended a transaction.
    gate: GatedSequence<W::Notifier>,
    /// Shared flag: checked when no events are available, and set if the poller fails.
    shutdown: Arc<AtomicBool>,
}

impl<T, W: WaitStrategy, B> JournalPoller<T, W, B> {
    /// Bundles the given parts into a poller. No thread is started and no I/O is done.
    pub fn new(
        handler: JournalHandler,
        poller: EventPoller<RingSlot<T>, B, W::Notifier>,
        wait_strategy: W,
        gate: GatedSequence<W::Notifier>,
        shutdown: Arc<AtomicBool>,
    ) -> Self {
        Self {
            handler,
            poller,
            wait_strategy,
            gate,
            shutdown,
        }
    }
}

impl<T, W, B> JournalPoller<T, W, B>
where
    W: WaitStrategy + Send + 'static,
    W::Notifier: Send + 'static,
    T: RkyvToBytes + Send + Sync + 'static,
    B: disrupt_rs::Barrier + 'static,
{
    /// Spawns the polling thread without pinning it to a CPU core.
    ///
    /// The returned handle yields the poller's result when joined.
    pub fn poll(self) -> thread::JoinHandle<anyhow::Result<()>> {
        self.spawn(None)
    }

    /// Spawns the polling thread and, if `core_id` is `Some`, tries to pin it to that core.
    ///
    /// A failed pin attempt is logged as a warning and the poller runs unpinned.
    /// The returned handle yields the poller's result when joined.
    pub fn poll_pinned(self, core_id: Option<usize>) -> thread::JoinHandle<anyhow::Result<()>> {
        self.spawn(core_id)
    }

    /// Spawns the thread that drains the ring buffer into the journal.
    ///
    /// The thread optionally pins itself to `core_id` (a failure to pin is only logged),
    /// then loops: wait for a batch of events, write every envelope into one LMDB write
    /// transaction, commit, and publish progress through the gate.
    ///
    /// The gate is set to the ring sequence of the most recent event that ended a
    /// transaction, and only once such an event has been seen. Until the first
    /// transaction end is committed the gate is left untouched, so that no sequence is
    /// published on behalf of a transaction that has not completed.
    ///
    /// The loop exits when the poller reports shutdown, or when no events are available
    /// and the shared shutdown flag is set. On any error the shutdown flag is set so that
    /// other holders of the flag can observe the failure, and the error is returned
    /// through the join handle.
    fn spawn(self, core_id: Option<usize>) -> thread::JoinHandle<anyhow::Result<()>> {
        let mut this = self;
        thread::spawn(move || {
            trace!(core_id = ?core_id, "starting wal poller");
            if let Some(id) = core_id
                && !core_affinity::set_for_current(CoreId { id })
            {
                warn!(
                    core_id = id,
                    "failed to pin wal poller thread to requested core"
                );
            }

            crate::metrics_stage!("wal_poller");
            let mut waiter = this.wait_strategy.new_waiter();
            let res = (|| {
                // A write transaction is kept open across iterations; it is committed after
                // each batch and immediately replaced with a fresh one.
                let mut wtxn = this.handler.env.write_txn()?;
                // Ring sequence of the last event seen that ended a transaction. `None`
                // until the first one, so the gate is never advanced before any
                // transaction has actually completed.
                let mut last_completed_tx_end: Option<Sequence> = None;

                'poll: loop {
                    match this.poller.poll_wait(&mut waiter) {
                        Ok(mut events) => {
                            let t0 = Instant::now();
                            let events_len = events.len();
                            crate::metric!({ inflight: 1, in_total: events_len as u64 });

                            for (s, e) in &mut events {
                                let envelope: &Envelope<T> = e;
                                this.handler.put_envelope(&mut wtxn, envelope)?;
                                if envelope.is_tx_end() {
                                    last_completed_tx_end = Some(s);
                                }
                            }

                            // Make the batch durable before publishing progress.
                            wtxn.commit()?;
                            if let Some(seq) = last_completed_tx_end {
                                this.gate.set(seq);
                            }
                            wtxn = this.handler.env.write_txn()?;

                            crate::metric!({
                                out_total: 1, // 1 commit per batch
                                inflight: 0,
                                duration_ns: t0.elapsed().as_nanos() as u64
                            });
                        }
                        Err(Polling::NoEvents) => {
                            // The disruptor poller doesn't consult our shutdown flag, so ensure we
                            // can exit cleanly even if no new events are arriving.
                            if this.shutdown.load(Ordering::Acquire) {
                                break 'poll;
                            }
                            continue;
                        }
                        Err(Polling::Shutdown) => break 'poll,
                    }
                }
                Ok::<_, anyhow::Error>(())
            })();

            if let Err(e) = res {
                error!(error = ?e, "journal poller error");
                crate::metric!({ err_total: 1 });
                // `Release` pairs with the `Acquire` load used when checking the flag.
                this.shutdown.store(true, Ordering::Release);
                return Err(e);
            }

            Ok(())
        })
    }
}