hermesmq-core 0.6.0

Core engine for hermesmq: queue state machine and Raft application types
Documentation
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::ops::{Bound, RangeBounds};
use std::sync::{Arc, Mutex};

use bytes::Bytes;
use openraft::storage::{LogFlushed, RaftLogReader, RaftLogStorage};
use openraft::{Entry, LogId, LogState, OptionalSend, StorageError, Vote};
use tokio::sync::{mpsc, oneshot};

use super::{dec, enc, sread, swrite};
use crate::raft::TypeConfig;
use crate::storage::Storage;
use crate::types::NodeId;
use crate::RedbStore;

const KEY_PURGED: &str = "log:purged";
const FLUSH_MAX_ENTRIES: usize = 512;
const FLUSH_MAX_BYTES: usize = 8 * 1024 * 1024;

enum FlushJob {
    Append(Vec<(u64, Bytes)>, LogFlushed<TypeConfig>),
    Barrier(oneshot::Sender<()>),
}

struct Shared<S> {
    db: Arc<S>,
    pending: Mutex<BTreeMap<u64, Bytes>>,
}

pub struct LogStore<S = RedbStore> {
    shared: Arc<Shared<S>>,
    jobs: mpsc::UnboundedSender<FlushJob>,
}

impl<S> Clone for LogStore<S> {
    fn clone(&self) -> Self {
        Self {
            shared: Arc::clone(&self.shared),
            jobs: self.jobs.clone(),
        }
    }
}

impl<S: Storage> LogStore<S> {
    pub fn new(db: Arc<S>) -> Self {
        let shared = Arc::new(Shared {
            db,
            pending: Mutex::new(BTreeMap::new()),
        });
        let (jobs, rx) = mpsc::unbounded_channel();
        let flusher = Arc::clone(&shared);
        std::thread::spawn(move || run_flusher(flusher, rx));
        Self { shared, jobs }
    }

    async fn barrier(&self) {
        let (tx, rx) = oneshot::channel();
        if self.jobs.send(FlushJob::Barrier(tx)).is_ok() {
            let _ = rx.await;
        }
    }

    fn pending_last(&self) -> Option<u64> {
        self.shared
            .pending
            .lock()
            .unwrap()
            .keys()
            .next_back()
            .copied()
    }
}

fn run_flusher<S: Storage>(shared: Arc<Shared<S>>, mut rx: mpsc::UnboundedReceiver<FlushJob>) {
    while let Some(first) = rx.blocking_recv() {
        let mut batch: Vec<(u64, Bytes)> = Vec::new();
        let mut callbacks = Vec::new();
        let mut barriers = Vec::new();
        let mut bytes = 0usize;
        let mut job = Some(first);
        loop {
            match job {
                Some(FlushJob::Append(entries, callback)) => {
                    bytes += entries.iter().map(|(_, b)| b.len()).sum::<usize>();
                    batch.extend(entries);
                    callbacks.push(callback);
                }
                Some(FlushJob::Barrier(done)) => {
                    barriers.push(done);
                    break;
                }
                None => break,
            }
            if batch.len() >= FLUSH_MAX_ENTRIES || bytes >= FLUSH_MAX_BYTES {
                break;
            }
            job = rx.try_recv().ok();
        }

        let result = if batch.is_empty() {
            Ok(())
        } else {
            shared.db.append_log(&batch)
        };
        match result {
            Ok(()) => {
                if !batch.is_empty() {
                    let mut pending = shared.pending.lock().unwrap();
                    for (index, _) in &batch {
                        pending.remove(index);
                    }
                }
                for callback in callbacks {
                    callback.log_io_completed(Ok(()));
                }
            }
            Err(e) => {
                let msg = e.to_string();
                for callback in callbacks {
                    callback.log_io_completed(Err(std::io::Error::other(msg.clone())));
                }
            }
        }
        for done in barriers {
            let _ = done.send(());
        }
    }
}

impl<S: Storage> RaftLogReader<TypeConfig> for LogStore<S> {
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
        &mut self,
        range: RB,
    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<NodeId>> {
        let start = match range.start_bound() {
            Bound::Included(x) => *x,
            Bound::Excluded(x) => *x + 1,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(x) => *x + 1,
            Bound::Excluded(x) => *x,
            Bound::Unbounded => {
                let db_last = self.shared.db.last_log_index().map_err(sread)?;
                db_last
                    .into_iter()
                    .chain(self.pending_last())
                    .max()
                    .map(|i| i + 1)
                    .unwrap_or(0)
            }
        };

        let mut merged: BTreeMap<u64, Bytes> = BTreeMap::new();
        for (index, bytes) in self.shared.db.read_log(start, end).map_err(sread)? {
            merged.insert(index, Bytes::from(bytes));
        }
        {
            let pending = self.shared.pending.lock().unwrap();
            for (index, bytes) in pending.range(start..end) {
                merged.insert(*index, bytes.clone());
            }
        }
        let mut out = Vec::with_capacity(merged.len());
        for bytes in merged.values() {
            out.push(dec::<Entry<TypeConfig>>(bytes)?);
        }
        Ok(out)
    }
}

impl<S: Storage> RaftLogStorage<TypeConfig> for LogStore<S> {
    type LogReader = Self;

    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<NodeId>> {
        let last_purged: Option<LogId<NodeId>> = match self.shared.db.get(KEY_PURGED).map_err(sread)? {
            Some(b) => Some(dec(&b)?),
            None => None,
        };

        let db_last = self.shared.db.last_log_index().map_err(sread)?;
        let last_index = db_last.into_iter().chain(self.pending_last()).max();
        let last_log_id = match last_index {
            Some(index) => {
                let bytes = {
                    let pending = self.shared.pending.lock().unwrap();
                    pending.get(&index).cloned()
                };
                let bytes = match bytes {
                    Some(b) => Some(b),
                    None => self
                        .shared
                        .db
                        .read_log(index, index + 1)
                        .map_err(sread)?
                        .into_iter()
                        .next()
                        .map(|(_, b)| Bytes::from(b)),
                };
                match bytes {
                    Some(b) => Some(dec::<Entry<TypeConfig>>(&b)?.log_id),
                    None => last_purged,
                }
            }
            None => last_purged,
        };

        Ok(LogState {
            last_purged_log_id: last_purged,
            last_log_id,
        })
    }

    async fn get_log_reader(&mut self) -> Self::LogReader {
        self.clone()
    }

    async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
        let bytes = enc(vote)?;
        self.shared.db.save_vote(&bytes).map_err(swrite)
    }

    async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
        match self.shared.db.read_vote().map_err(sread)? {
            Some(b) => Ok(Some(dec(&b)?)),
            None => Ok(None),
        }
    }

    async fn save_committed(
        &mut self,
        committed: Option<LogId<NodeId>>,
    ) -> Result<(), StorageError<NodeId>> {
        let bytes = enc(&committed)?;
        self.shared.db.save_committed(&bytes).map_err(swrite)
    }

    async fn read_committed(&mut self) -> Result<Option<LogId<NodeId>>, StorageError<NodeId>> {
        match self.shared.db.read_committed().map_err(sread)? {
            Some(b) => dec::<Option<LogId<NodeId>>>(&b),
            None => Ok(None),
        }
    }

    async fn append<I>(
        &mut self,
        entries: I,
        callback: LogFlushed<TypeConfig>,
    ) -> Result<(), StorageError<NodeId>>
    where
        I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend,
        I::IntoIter: OptionalSend,
    {
        let mut batch = Vec::new();
        for entry in entries {
            batch.push((entry.log_id.index, Bytes::from(enc(&entry)?)));
        }
        {
            let mut pending = self.shared.pending.lock().unwrap();
            for (index, bytes) in &batch {
                pending.insert(*index, bytes.clone());
            }
        }
        if self.jobs.send(FlushJob::Append(batch, callback)).is_err() {
            return Err(swrite("log flusher thread is gone"));
        }
        Ok(())
    }

    async fn truncate(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        self.barrier().await;
        self.shared
            .pending
            .lock()
            .unwrap()
            .split_off(&log_id.index);
        self.shared
            .db
            .truncate_log_from(log_id.index)
            .map_err(swrite)
    }

    async fn purge(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        let bytes = enc(&log_id)?;
        self.shared.db.put(KEY_PURGED, &bytes).map_err(swrite)?;
        self.shared.db.purge_log_upto(log_id.index).map_err(swrite)
    }
}