use std::collections::BTreeMap;
use std::fmt::Debug;
use std::io;
use std::ops::RangeBounds;
use std::sync::Arc;
use openraft::storage::{IOFlushed, RaftLogStorage};
use openraft::{LogId, LogState, OptionalSend, RaftLogReader, RaftTypeConfig};
use tokio::sync::RwLock;
/// In-memory Raft log store.
///
/// All state sits behind a shared `Arc<RwLock<..>>`, so a clone of the store is another handle
/// to the same underlying state rather than an independent copy.
#[derive(Clone)]
pub struct MemLogStore<C>
where
    C: RaftTypeConfig,
{
    /// Shared, lock-protected storage state.
    inner: Arc<RwLock<MemLogStoreInner<C>>>,
}
/// The state guarded by the store's lock.
struct MemLogStoreInner<C>
where
    C: RaftTypeConfig,
{
    /// Most recently saved vote, if any.
    vote: Option<openraft::vote::Vote<C>>,
    /// Log entries keyed by their log index.
    log: BTreeMap<u64, openraft::Entry<C>>,
    /// Most recently saved committed log id, if any.
    committed: Option<LogId<C>>,
    /// Log id passed to the most recent purge, if any.
    last_purged: Option<LogId<C>>,
}
impl<C> Default for MemLogStore<C>
where
    C: RaftTypeConfig,
{
    /// Builds a store with no vote, no log entries, no committed id and no purge marker.
    fn default() -> Self {
        let empty = MemLogStoreInner {
            vote: None,
            log: BTreeMap::new(),
            committed: None,
            last_purged: None,
        };
        Self {
            inner: Arc::new(RwLock::new(empty)),
        }
    }
}
impl<C> MemLogStore<C>
where
    C: RaftTypeConfig,
{
    /// Creates an empty store; equivalent to the `Default` implementation.
    pub fn new() -> Self {
        Default::default()
    }
}
/// Read-side handle onto the state of a [`MemLogStore`].
///
/// It holds the same `Arc` as the store it was created from, so it observes every write made
/// through that store.
#[derive(Clone)]
pub struct MemLogReader<C>
where
    C: RaftTypeConfig,
{
    /// Shared, lock-protected storage state.
    inner: Arc<RwLock<MemLogStoreInner<C>>>,
}
impl<C> RaftLogReader<C> for MemLogReader<C>
where
    C: RaftTypeConfig<Entry = openraft::Entry<C>, Vote = openraft::vote::Vote<C>>,
    openraft::Entry<C>: Clone,
{
    /// Returns clones of all stored entries whose index lies in `range`, in ascending index
    /// order. Indexes in `range` that have no stored entry are simply absent from the result.
    ///
    /// # Panics
    ///
    /// Inherits the panics of `BTreeMap::range`: the range start is greater than its end, or
    /// start equals end with both bounds excluded.
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
        &mut self,
        range: RB,
    ) -> Result<Vec<C::Entry>, io::Error> {
        let guard = self.inner.read().await;
        let matching = guard.log.range(range).map(|(_index, entry)| entry).cloned();
        Ok(matching.collect())
    }

    /// Returns a clone of the currently stored vote, or `None` if no vote has been saved.
    async fn read_vote(&mut self) -> Result<Option<C::Vote>, io::Error> {
        let guard = self.inner.read().await;
        let stored = guard.vote.clone();
        Ok(stored)
    }
}
impl<C> RaftLogStorage<C> for MemLogStore<C>
where
C: RaftTypeConfig<Entry = openraft::Entry<C>, Vote = openraft::vote::Vote<C>>,
openraft::Entry<C>: Clone,
{
type LogReader = MemLogReader<C>;
async fn get_log_state(&mut self) -> Result<LogState<C>, io::Error> {
let inner = self.inner.read().await;
let last = inner.log.values().last().map(|e| e.log_id.clone());
Ok(LogState {
last_purged_log_id: inner.last_purged.clone(),
last_log_id: last,
})
}
async fn get_log_reader(&mut self) -> Self::LogReader {
MemLogReader {
inner: Arc::clone(&self.inner),
}
}
async fn save_vote(&mut self, vote: &C::Vote) -> Result<(), io::Error> {
let mut inner = self.inner.write().await;
inner.vote = Some(vote.clone());
Ok(())
}
async fn append<I>(&mut self, entries: I, callback: IOFlushed<C>) -> Result<(), io::Error>
where
I: IntoIterator<Item = C::Entry> + OptionalSend,
I::IntoIter: OptionalSend,
{
let mut inner = self.inner.write().await;
for entry in entries {
let index = entry.log_id.index;
inner.log.insert(index, entry);
}
callback.io_completed(Ok(()));
Ok(())
}
async fn truncate_after(&mut self, last_log_id: Option<LogId<C>>) -> Result<(), io::Error> {
let mut inner = self.inner.write().await;
match last_log_id {
Some(id) => {
let keys_to_remove: Vec<u64> =
inner.log.range((id.index + 1)..).map(|(k, _)| *k).collect();
for k in keys_to_remove {
inner.log.remove(&k);
}
}
None => {
inner.log.clear();
}
}
Ok(())
}
async fn purge(&mut self, log_id: LogId<C>) -> Result<(), io::Error> {
let mut inner = self.inner.write().await;
let keys_to_remove: Vec<u64> = inner.log.range(..=log_id.index).map(|(k, _)| *k).collect();
for k in keys_to_remove {
inner.log.remove(&k);
}
inner.last_purged = Some(log_id);
Ok(())
}
async fn save_committed(&mut self, committed: Option<LogId<C>>) -> Result<(), io::Error> {
let mut inner = self.inner.write().await;
inner.committed = committed;
Ok(())
}
async fn read_committed(&mut self) -> Result<Option<LogId<C>>, io::Error> {
let inner = self.inner.read().await;
Ok(inner.committed.clone())
}
}
#[cfg(test)]
mod tests {
    //! Smoke tests for the in-memory log store, driven through the openraft storage traits.

    use openraft::vote::leader_id_adv::CommittedLeaderId;
    use openraft::vote::RaftLeaderId;

    use super::*;
    use crate::test_types::TestTypeConfig;

    /// Store type under test.
    type Store = MemLogStore<TestTypeConfig>;

    /// Builds a log id at `index` with leader id `(1, 1)`.
    fn log_id_at(index: u64) -> LogId<TestTypeConfig> {
        openraft::LogId::new(CommittedLeaderId::new(1, 1), index)
    }

    #[tokio::test]
    async fn initial_state_is_empty() {
        let mut store = Store::new();
        let state = store.get_log_state().await.unwrap();
        assert!(state.last_log_id.is_none());
        assert!(state.last_purged_log_id.is_none());
    }

    #[tokio::test]
    async fn save_and_read_vote() {
        let mut store = Store::new();
        let vote = openraft::vote::Vote::new(1, 1);
        store.save_vote(&vote).await.unwrap();

        let mut reader = store.get_log_reader().await;
        let stored = reader.read_vote().await.unwrap();
        assert_eq!(stored.unwrap(), vote);
    }

    #[tokio::test]
    async fn truncate_after_none_on_empty() {
        let mut store = Store::new();
        store.truncate_after(None).await.unwrap();

        let state = store.get_log_state().await.unwrap();
        assert!(state.last_log_id.is_none());
    }

    #[tokio::test]
    async fn truncate_after_some_on_empty() {
        let mut store = Store::new();
        store.truncate_after(Some(log_id_at(5))).await.unwrap();

        let state = store.get_log_state().await.unwrap();
        assert!(state.last_log_id.is_none());
    }

    #[tokio::test]
    async fn purge_on_empty_sets_last_purged() {
        let mut store = Store::new();
        store.purge(log_id_at(3)).await.unwrap();

        let state = store.get_log_state().await.unwrap();
        assert_eq!(state.last_purged_log_id.unwrap().index, 3);
    }

    #[tokio::test]
    async fn save_and_read_committed() {
        let mut store = Store::new();

        // Nothing has been saved yet.
        let before = store.read_committed().await.unwrap();
        assert!(before.is_none());

        // A saved id is read back.
        store.save_committed(Some(log_id_at(42))).await.unwrap();
        let saved = store.read_committed().await.unwrap();
        assert_eq!(saved.unwrap().index, 42);

        // Saving `None` clears it again.
        store.save_committed(None).await.unwrap();
        let cleared = store.read_committed().await.unwrap();
        assert!(cleared.is_none());
    }

    #[tokio::test]
    async fn get_log_reader_shares_state() {
        let mut store = Store::new();
        let vote = openraft::vote::Vote::new(5, 5);
        store.save_vote(&vote).await.unwrap();

        let mut reader = store.get_log_reader().await;
        let stored = reader.read_vote().await.unwrap();
        assert_eq!(stored.unwrap(), vote);
    }

    #[tokio::test]
    async fn read_entries_on_empty() {
        let mut store = Store::new();
        let mut reader = store.get_log_reader().await;

        let entries = reader.try_get_log_entries(0..10).await.unwrap();
        assert!(entries.is_empty());
    }
}