use std::collections::BTreeMap;
use std::fmt::Debug;
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, Read, Write};
use std::ops::RangeBounds;
use std::path::Path;
use std::sync::{Arc, Mutex as StdMutex};
use openraft::storage::{LogFlushed, RaftLogReader, RaftLogStorage};
use openraft::{Entry, LogId, LogState, RaftLogId, StorageError, StorageIOError, Vote};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use super::{NodeId, TypeConfig};
/// One record in the on-disk log. Each record is postcard-encoded and written
/// as a 4-byte little-endian length prefix followed by the encoded bytes.
///
/// NOTE: postcard identifies enum variants by their declaration index, so
/// existing variants must never be reordered or removed (only appended at the
/// end), or existing log files will decode into the wrong variants.
#[derive(Serialize, Deserialize)]
enum Record {
/// The latest vote; a later `Vote` record replaces an earlier one on replay.
Vote(Vote<NodeId>),
/// A batch of entries, inserted by log index (an entry at an existing index overwrites it).
Append(Vec<Entry<TypeConfig>>),
/// Removes every entry whose index is >= the given index.
Truncate(u64),
/// Removes every entry whose index is <= `log_id.index` and records `log_id` as last purged.
Purge(LogId<NodeId>),
}
/// In-memory image of the log file, rebuilt by replaying every record on open
/// and kept up to date by the write paths afterwards.
#[derive(Default)]
struct Mem {
    /// Live (not truncated, not purged) entries, keyed by log index.
    log: BTreeMap<u64, Entry<TypeConfig>>,
    /// Most recent vote, if one was ever saved.
    vote: Option<Vote<NodeId>>,
    /// Id of the newest purged entry, if any purge has happened.
    last_purged_log_id: Option<LogId<NodeId>>,
    /// Committed log id as last reported by Raft. No `Record` variant carries
    /// it, so it lives in memory only and starts as `None` after a reopen.
    committed: Option<LogId<NodeId>>,
}
impl Mem {
    /// Folds a single log record into the in-memory state.
    ///
    /// This is the state transition replayed on open, so any live write path
    /// must produce exactly the same effect for replay to be faithful.
    fn apply(&mut self, rec: Record) {
        match rec {
            Record::Append(entries) => self.log.extend(
                entries
                    .into_iter()
                    .map(|entry| (entry.get_log_id().index, entry)),
            ),
            Record::Truncate(from_index) => self.remove_range(from_index..),
            Record::Purge(upto) => {
                self.remove_range(..=upto.index);
                self.last_purged_log_id = Some(upto);
            }
            Record::Vote(vote) => self.vote = Some(vote),
        }
    }

    /// Removes every entry whose index lies in `range`.
    fn remove_range<R: RangeBounds<u64>>(&mut self, range: R) {
        // Snapshot the keys first: the map cannot be mutated while `range` borrows it.
        let doomed: Vec<u64> = self
            .log
            .range(range)
            .map(|(&index, _)| index)
            .collect();
        for index in &doomed {
            self.log.remove(index);
        }
    }
}
/// Raft log storage backed by an append-only file of length-prefixed postcard
/// records, with the replayed state cached in memory.
///
/// Clones are cheap and share the same in-memory state and file handle; this
/// is how `get_log_reader` hands out readers.
#[derive(Clone)]
pub struct DurableLogStore {
    /// Append-mode handle to `<dir>/log`. It uses a std mutex because it is
    /// only locked inside `spawn_blocking`.
    file: Arc<StdMutex<File>>,
    /// Replayed log, vote, purge and commit state. It uses the async (tokio)
    /// mutex because it is locked from async methods.
    mem: Arc<Mutex<Mem>>,
}
fn store_err(e: impl std::error::Error + 'static) -> StorageError<NodeId> {
StorageIOError::write_logs(&e).into()
}
impl DurableLogStore {
pub fn open(dir: &Path) -> io::Result<Self> {
std::fs::create_dir_all(dir)?;
let path = dir.join("log");
let mut mem = Mem::default();
if let Ok(f) = File::open(&path) {
let mut r = BufReader::new(f);
loop {
let mut len = [0u8; 4];
if r.read_exact(&mut len).is_err() {
break; }
let mut buf = vec![0u8; u32::from_le_bytes(len) as usize];
if r.read_exact(&mut buf).is_err() {
break; }
match postcard::from_bytes::<Record>(&buf) {
Ok(rec) => mem.apply(rec),
Err(_) => break, }
}
}
let file = OpenOptions::new().create(true).append(true).open(&path)?;
Ok(Self {
mem: Arc::new(Mutex::new(mem)),
file: Arc::new(StdMutex::new(file)),
})
}
async fn durable(&self, record: Record) -> Result<(), StorageError<NodeId>> {
let bytes = postcard::to_allocvec(&record).map_err(store_err)?;
let file = Arc::clone(&self.file);
tokio::task::spawn_blocking(move || -> io::Result<()> {
let mut f = file
.lock()
.map_err(|_| io::Error::other("raft log file mutex poisoned"))?;
f.write_all(&(bytes.len() as u32).to_le_bytes())?;
f.write_all(&bytes)?;
f.sync_data()?;
Ok(())
})
.await
.map_err(store_err)?
.map_err(store_err)
}
async fn append_entries(
&self,
entries: Vec<Entry<TypeConfig>>,
) -> Result<(), StorageError<NodeId>> {
self.durable(Record::Append(entries.clone())).await?;
let mut mem = self.mem.lock().await;
for e in entries {
mem.log.insert(e.get_log_id().index, e);
}
Ok(())
}
}
impl RaftLogReader<TypeConfig> for DurableLogStore {
    /// Returns clones of the cached entries whose indices fall within `range`,
    /// in index order. Indices with no entry are skipped, not reported as errors.
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug>(
        &mut self,
        range: RB,
    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<NodeId>> {
        let guard = self.mem.lock().await;
        let entries: Vec<Entry<TypeConfig>> = guard
            .log
            .range(range)
            .map(|(_, entry)| entry)
            .cloned()
            .collect();
        Ok(entries)
    }
}
impl RaftLogStorage<TypeConfig> for DurableLogStore {
    type LogReader = Self;

    /// Reports the id of the newest entry still held and the last purged id.
    /// When no entries remain (a fresh or fully purged log), the last purged id
    /// doubles as the last log id.
    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<NodeId>> {
        let mem = self.mem.lock().await;
        let last_purged_log_id = mem.last_purged_log_id;
        let last_log_id = match mem.log.last_key_value() {
            Some((_, newest)) => Some(*newest.get_log_id()),
            None => last_purged_log_id,
        };
        Ok(LogState {
            last_purged_log_id,
            last_log_id,
        })
    }

    /// Remembers the committed log id in memory only. No record persists it,
    /// so a freshly opened store reports `None` from `read_committed`.
    async fn save_committed(
        &mut self,
        committed: Option<LogId<NodeId>>,
    ) -> Result<(), StorageError<NodeId>> {
        let mut mem = self.mem.lock().await;
        mem.committed = committed;
        Ok(())
    }

    /// Returns the committed log id last passed to `save_committed` since open.
    async fn read_committed(&mut self) -> Result<Option<LogId<NodeId>>, StorageError<NodeId>> {
        let committed = self.mem.lock().await.committed;
        Ok(committed)
    }

    /// Durably records `vote` in the log file, then caches it in memory.
    async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
        let vote = *vote;
        self.durable(Record::Vote(vote)).await?;
        self.mem.lock().await.apply(Record::Vote(vote));
        Ok(())
    }

    /// Returns the most recently saved vote, if any.
    async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
        let vote = self.mem.lock().await.vote;
        Ok(vote)
    }

    /// Persists `entries` via `append_entries` and signals `callback` only after
    /// that succeeds. On error the callback is not invoked and the error is
    /// returned.
    async fn append<I>(
        &mut self,
        entries: I,
        callback: LogFlushed<TypeConfig>,
    ) -> Result<(), StorageError<NodeId>>
    where
        I: IntoIterator<Item = Entry<TypeConfig>>,
    {
        let batch: Vec<Entry<TypeConfig>> = entries.into_iter().collect();
        self.append_entries(batch).await?;
        callback.log_io_completed(Ok(()));
        Ok(())
    }

    /// Durably removes every entry at or after `log_id.index`.
    async fn truncate(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        let since = log_id.index;
        self.durable(Record::Truncate(since)).await?;
        self.mem.lock().await.apply(Record::Truncate(since));
        Ok(())
    }

    /// Durably drops every entry up to and including `log_id.index` and
    /// records `log_id` as the last purged id.
    async fn purge(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        self.durable(Record::Purge(log_id)).await?;
        self.mem.lock().await.apply(Record::Purge(log_id));
        Ok(())
    }

    /// Hands out a clone that shares this store's in-memory state and file.
    async fn get_log_reader(&mut self) -> Self::LogReader {
        Self::clone(self)
    }
}
// Tests for DurableLogStore: reopen/replay of votes and entries, torn-tail
// handling, an end-to-end single-node Raft restart, and a model-based replay check.
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use openraft::{BasicNode, Config, ServerState, Vote};
use quiver_core::{CollectionId, WalOp};
use super::super::{ApplyOp, NoNetwork, StateMachineStore};
use super::*;
// State-machine applier that accepts every operation and does nothing with it,
// so the Raft test below exercises only the log store.
struct NoopApplier;
impl ApplyOp for NoopApplier {
async fn apply(&self, _op: WalOp) -> std::io::Result<()> {
Ok(())
}
}
// Builds a cheap payload for entry `i`: a delete of external id "e{i}" in collection 1.
fn del(i: u64) -> WalOp {
WalOp::Delete {
collection_id: CollectionId(1),
external_id: format!("e{i}"),
}
}
// A vote saved before the store is dropped is read back after reopening the same directory.
#[tokio::test]
async fn vote_survives_reopen() {
let dir = tempfile::tempdir().unwrap();
{
let mut s = DurableLogStore::open(dir.path()).unwrap();
s.save_vote(&Vote::new(3, 2)).await.unwrap();
}
let mut s = DurableLogStore::open(dir.path()).unwrap();
assert_eq!(s.read_vote().await.unwrap(), Some(Vote::new(3, 2)));
}
// Simulates a write torn by a crash: a frame whose length prefix (999) claims
// far more bytes than follow it. Reopening must ignore that frame and keep the
// vote recorded before it.
#[tokio::test]
async fn torn_tail_record_is_discarded() {
let dir = tempfile::tempdir().unwrap();
{
let mut s = DurableLogStore::open(dir.path()).unwrap();
s.save_vote(&Vote::new(1, 1)).await.unwrap();
}
{
let mut f = OpenOptions::new()
.append(true)
.open(dir.path().join("log"))
.unwrap();
f.write_all(&999u32.to_le_bytes()).unwrap();
f.write_all(b"short").unwrap();
f.sync_data().unwrap();
}
let mut s = DurableLogStore::open(dir.path()).unwrap();
assert_eq!(s.read_vote().await.unwrap(), Some(Vote::new(1, 1)));
}
// End-to-end: a single-member Raft node commits five client writes and shuts
// down; a fresh store opened on the same directory must still hold them.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn committed_log_survives_a_restart() {
let dir = tempfile::tempdir().unwrap();
{
let store = DurableLogStore::open(dir.path()).unwrap();
let config = Arc::new(
Config {
heartbeat_interval: 250,
election_timeout_min: 500,
election_timeout_max: 1000,
..Default::default()
}
.validate()
.unwrap(),
);
let sm = Arc::new(StateMachineStore::new(NoopApplier));
let raft = openraft::Raft::<TypeConfig>::new(1, config, NoNetwork, store, sm)
.await
.unwrap();
let mut members = BTreeMap::new();
members.insert(1, BasicNode::default());
raft.initialize(members).await.unwrap();
raft.wait(Some(Duration::from_secs(10)))
.state(ServerState::Leader, "single member becomes leader")
.await
.unwrap();
for i in 0..5 {
raft.client_write(del(i)).await.unwrap();
}
raft.shutdown().await.unwrap();
}
let mut reopened = DurableLogStore::open(dir.path()).unwrap();
let state = reopened.get_log_state().await.unwrap();
let last = state
.last_log_id
.expect("a restarted node recovers its log tail");
// `>=` presumably leaves room for entries openraft writes itself (e.g. on
// initialize or leader election) in addition to the five writes; confirm.
assert!(
last.index >= 5,
"the five committed writes survived the restart (last index {})",
last.index
);
let entries = reopened.try_get_log_entries(0..=last.index).await.unwrap();
assert!(
entries.len() >= 5,
"the durable log replays its entries after a restart"
);
}
// Drives 120 pseudo-random operations against the store while mirroring them
// in a simple model (set of live indices, last purged index, last vote). A
// reopened store must then replay to exactly the model. The generator is a
// fixed-seed xorshift, so the sequence is deterministic; the order of `rng()`
// calls below therefore defines the scenario.
#[tokio::test]
async fn replay_matches_a_reference_model_for_random_op_sequences() {
use std::collections::BTreeSet;
use openraft::{CommittedLeaderId, EntryPayload, LogId};
// Entry at `index` with a fixed leader id and a payload derived from the index.
fn log_entry(index: u64) -> Entry<TypeConfig> {
Entry {
log_id: LogId::new(CommittedLeaderId::new(1, 1), index),
payload: EntryPayload::Normal(del(index)),
}
}
let dir = tempfile::tempdir().unwrap();
let mut model: BTreeSet<u64> = BTreeSet::new();
let mut model_purged: Option<u64> = None;
let mut model_vote: Option<Vote<u64>> = None;
// Next index to append; truncation rewinds it so indices get reused.
let mut next = 1u64;
let mut seed = 0x9E37_79B9_7F4A_7C15u64;
let mut rng = move || {
seed ^= seed << 13;
seed ^= seed >> 7;
seed ^= seed << 17;
seed
};
{
let mut s = DurableLogStore::open(dir.path()).unwrap();
for _ in 0..120 {
// 0|1: append a batch; 2: truncate (only once something was appended);
// 3: purge; otherwise (including 2 while `next == 1`): save a vote.
match rng() % 5 {
0 | 1 => {
// Append 1..=3 consecutive entries.
let n = (rng() % 3) + 1;
let mut batch = Vec::new();
for _ in 0..n {
batch.push(log_entry(next));
model.insert(next);
next += 1;
}
s.append_entries(batch).await.unwrap();
}
2 if next > 1 => {
// Truncate from an index in 1..next; everything at or after it goes.
let at = (rng() % next).max(1);
s.truncate(LogId::new(CommittedLeaderId::new(1, 1), at))
.await
.unwrap();
model.retain(|&i| i < at);
next = at;
}
3 => {
// Purge up to an index in 0..=max live index (may not advance past an earlier purge).
if let Some(&max) = model.iter().next_back() {
let upto = rng() % (max + 1);
s.purge(LogId::new(CommittedLeaderId::new(1, 1), upto))
.await
.unwrap();
model.retain(|&i| i > upto);
model_purged = Some(upto);
}
}
_ => {
let v = Vote::new(rng() % 10, rng() % 3);
s.save_vote(&v).await.unwrap();
model_vote = Some(v);
}
}
}
}
// Reopen and compare every replayed piece of state against the model.
let mut s = DurableLogStore::open(dir.path()).unwrap();
assert_eq!(
s.read_vote().await.unwrap(),
model_vote,
"vote replays exactly"
);
let state = s.get_log_state().await.unwrap();
assert_eq!(
state.last_purged_log_id.map(|l| l.index),
model_purged,
"last-purged replays exactly"
);
let indices: BTreeSet<u64> = s
.try_get_log_entries(0..=u64::MAX)
.await
.unwrap()
.iter()
.map(|e| e.get_log_id().index)
.collect();
assert_eq!(
indices, model,
"the log replays to exactly the same entries"
);
}
}