use std::collections::BTreeMap;
use std::fmt::Debug;
use std::io::Cursor;
use std::ops::RangeBounds;
use std::sync::Arc;
use openraft::storage::LogState;
use openraft::{
Entry, EntryPayload, LogId, RaftLogReader, RaftSnapshotBuilder, RaftStorage, Snapshot,
SnapshotMeta, StorageError, StoredMembership, Vote,
};
use tracing::info;
use super::state_machine::{apply_command, CoordinatorState};
use super::{ClusterResponse, NodeId, RaftNode, TypeConfig};
/// Lock-protected copy of the coordinator state that [`MemStore`] overwrites after applying
/// entries or installing a snapshot; obtained from [`MemStore::with_shared_state`].
pub type SharedCoordinatorState = std::sync::Arc<std::sync::RwLock<CoordinatorState>>;
/// Raft log keyed by log index, shared between a [`MemStore`] and the [`LogReader`]s it creates.
type SharedLog = std::sync::Arc<std::sync::RwLock<BTreeMap<u64, Entry<TypeConfig>>>>;
/// Serializable view of the state machine. It is written into Raft snapshots as JSON and
/// parsed back when a snapshot is installed.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
struct StateMachineSnapshot {
    /// Id of the last log entry applied before the view was captured.
    last_applied_log: Option<LogId<NodeId>>,
    /// Membership as of the last applied membership entry.
    last_membership: StoredMembership<NodeId, RaftNode>,
    /// Coordinator state produced by the applied entries.
    state: CoordinatorState,
}
/// In-memory Raft storage holding the vote, the log and the coordinator state machine.
///
/// All state lives in this value's fields; nothing is written to disk.
#[derive(Debug)]
pub struct MemStore {
    /// Last vote saved by Raft, if any.
    vote: Option<Vote<NodeId>>,
    /// Log entries keyed by index; the same instance is shared with every [`LogReader`].
    log: SharedLog,
    /// Id of the last log entry removed by purging, if any.
    last_purged: Option<LogId<NodeId>>,
    /// Id of the last entry applied to the state machine, if any.
    last_applied_log: Option<LogId<NodeId>>,
    /// Membership as of the last applied membership entry or installed snapshot.
    last_membership: StoredMembership<NodeId, RaftNode>,
    /// Coordinator state machine, updated by applying `Normal` log entries.
    pub state: CoordinatorState,
    /// Optional handle that receives a copy of `state` after applies and snapshot installs.
    shared_state: Option<SharedCoordinatorState>,
}
impl Default for MemStore {
fn default() -> Self {
Self::new()
}
}
impl MemStore {
pub fn new() -> Self {
Self {
vote: None,
log: Arc::new(std::sync::RwLock::new(BTreeMap::new())),
last_purged: None,
last_applied_log: None,
last_membership: StoredMembership::default(),
state: CoordinatorState::default(),
shared_state: None,
}
}
pub fn with_shared_state() -> (Self, SharedCoordinatorState) {
let shared = Arc::new(std::sync::RwLock::new(CoordinatorState::default()));
let store = Self {
vote: None,
log: Arc::new(std::sync::RwLock::new(BTreeMap::new())),
last_purged: None,
last_applied_log: None,
last_membership: StoredMembership::default(),
state: CoordinatorState::default(),
shared_state: Some(shared.clone()),
};
(store, shared)
}
fn publish_state(&self) {
if let Some(ref shared) = self.shared_state {
if let Ok(mut state) = shared.write() {
*state = self.state.clone();
}
}
}
}
/// Reader over a [`MemStore`]'s log, created by the store's `get_log_reader`.
#[derive(Debug)]
pub struct LogReader {
    /// The same log instance the owning store appends to, truncates and purges.
    log: SharedLog,
}
impl RaftLogReader<TypeConfig> for LogReader {
    /// Returns clones of all log entries whose index lies in `range`, in ascending index order.
    ///
    /// A poisoned log lock is recovered instead of reported as an error.
    ///
    /// # Panics
    /// Panics in the same cases as [`BTreeMap::range`]: the range starts after it ends, or it
    /// starts and ends at the same index with both bounds excluded.
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + Send>(
        &mut self,
        range: RB,
    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<NodeId>> {
        let guard = match self.log.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        Ok(guard.range(range).map(|(_, entry)| entry).cloned().collect())
    }
}
impl RaftLogReader<TypeConfig> for MemStore {
async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + Send>(
&mut self,
range: RB,
) -> Result<Vec<Entry<TypeConfig>>, StorageError<NodeId>> {
let log = self.log.read().unwrap_or_else(|e| e.into_inner());
let entries: Vec<_> = log.range(range).map(|(_, entry)| entry.clone()).collect();
Ok(entries)
}
}
/// Builds a Raft snapshot from a state-machine view captured when the builder was created.
#[derive(Debug)]
pub struct SnapshotBuilder {
    /// State-machine contents copied out of the store at builder-creation time.
    snapshot: StateMachineSnapshot,
}
impl RaftSnapshotBuilder<TypeConfig> for SnapshotBuilder {
    /// Serializes the captured state-machine view to JSON and wraps it in a [`Snapshot`].
    ///
    /// The snapshot id is `"{leader_id}-{index}"` of the last applied log id. It is `"0-0"`
    /// when nothing has been applied.
    ///
    /// # Errors
    /// Returns [`StorageError::IO`] if JSON serialization fails.
    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
        let bytes = serde_json::to_vec(&self.snapshot).map_err(|e| StorageError::IO {
            source: openraft::StorageIOError::read_state_machine(&e),
        })?;

        let last_log_id = self.snapshot.last_applied_log;
        let snapshot_id = match last_log_id {
            Some(id) => format!("{}-{}", id.leader_id, id.index),
            None => "0-0".to_string(),
        };

        Ok(Snapshot {
            meta: SnapshotMeta {
                last_log_id,
                last_membership: self.snapshot.last_membership.clone(),
                snapshot_id,
            },
            snapshot: Box::new(Cursor::new(bytes)),
        })
    }
}
impl RaftStorage<TypeConfig> for MemStore {
type LogReader = LogReader;
type SnapshotBuilder = SnapshotBuilder;
async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
self.vote = Some(*vote);
Ok(())
}
async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
Ok(self.vote)
}
async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<NodeId>> {
use openraft::RaftLogId;
let log = self.log.read().unwrap_or_else(|e| e.into_inner());
let last = log.values().last().map(|e| *e.get_log_id());
Ok(LogState {
last_purged_log_id: self.last_purged,
last_log_id: last,
})
}
async fn get_log_reader(&mut self) -> Self::LogReader {
LogReader {
log: self.log.clone(),
}
}
async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
{
use openraft::RaftLogId;
let mut log = self.log.write().unwrap_or_else(|e| e.into_inner());
for entry in entries {
let idx = entry.get_log_id().index;
log.insert(idx, entry);
}
Ok(())
}
async fn delete_conflict_logs_since(
&mut self,
log_id: LogId<NodeId>,
) -> Result<(), StorageError<NodeId>> {
let mut log = self.log.write().unwrap_or_else(|e| e.into_inner());
let to_remove: Vec<u64> = log.range(log_id.index..).map(|(k, _)| *k).collect();
for k in to_remove {
log.remove(&k);
}
Ok(())
}
async fn purge_logs_upto(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
let mut log = self.log.write().unwrap_or_else(|e| e.into_inner());
let to_remove: Vec<u64> = log.range(..=log_id.index).map(|(k, _)| *k).collect();
for k in to_remove {
log.remove(&k);
}
drop(log);
self.last_purged = Some(log_id);
Ok(())
}
async fn last_applied_state(
&mut self,
) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, RaftNode>), StorageError<NodeId>>
{
Ok((self.last_applied_log, self.last_membership.clone()))
}
async fn apply_to_state_machine(
&mut self,
entries: &[Entry<TypeConfig>],
) -> Result<Vec<ClusterResponse>, StorageError<NodeId>> {
use openraft::RaftLogId;
let mut responses = Vec::with_capacity(entries.len());
for entry in entries {
self.last_applied_log = Some(*entry.get_log_id());
match &entry.payload {
EntryPayload::Blank => {
responses.push(ClusterResponse::Ok);
}
EntryPayload::Normal(cmd) => {
let resp = apply_command(&mut self.state, cmd.clone());
responses.push(resp);
}
EntryPayload::Membership(mem) => {
self.last_membership =
StoredMembership::new(Some(*entry.get_log_id()), mem.clone());
responses.push(ClusterResponse::Ok);
}
}
}
self.publish_state();
Ok(responses)
}
async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
SnapshotBuilder {
snapshot: StateMachineSnapshot {
last_applied_log: self.last_applied_log,
last_membership: self.last_membership.clone(),
state: self.state.clone(),
},
}
}
async fn begin_receiving_snapshot(
&mut self,
) -> Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
Ok(Box::new(Cursor::new(Vec::new())))
}
async fn install_snapshot(
&mut self,
meta: &SnapshotMeta<NodeId, RaftNode>,
snapshot: Box<Cursor<Vec<u8>>>,
) -> Result<(), StorageError<NodeId>> {
let data = snapshot.into_inner();
let snap: StateMachineSnapshot =
serde_json::from_slice(&data).map_err(|e| StorageError::IO {
source: openraft::StorageIOError::read_state_machine(&e),
})?;
self.last_applied_log = meta.last_log_id;
self.last_membership = meta.last_membership.clone();
self.state = snap.state;
self.publish_state();
info!("Installed Raft snapshot at {:?}", meta.last_log_id);
Ok(())
}
async fn get_current_snapshot(
&mut self,
) -> Result<Option<Snapshot<TypeConfig>>, StorageError<NodeId>> {
let mut builder = self.get_snapshot_builder().await;
let snapshot = builder.build_snapshot().await?;
Ok(Some(snapshot))
}
}