pg_replica 0.2.0

Consensus-driven failover for PostgreSQL (Raft control plane)
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::io::{Cursor, Write};
use std::ops::RangeBounds;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex as StdMutex;

use openraft::storage::{
    LogFlushed, LogState, RaftLogReader, RaftLogStorage, RaftSnapshotBuilder, RaftStateMachine,
    Snapshot,
};
use openraft::{
    Entry, EntryPayload, LogId, OptionalSend, SnapshotMeta, StorageError, StorageIOError,
    StoredMembership, Vote,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::failover::Decision;
use crate::rtype::{Node, NodeId, TypeConfig};

/// Clonable handle to a mutex-protected, optional [`Decision`].
///
/// All clones point at the same slot, so a value stored through one handle
/// is observed through every other handle.
#[derive(Clone, Default)]
pub struct SharedDecision(Arc<StdMutex<Option<Decision>>>);

impl SharedDecision {
    /// Returns a copy of the decision currently held in the slot, or `None`
    /// when nothing is stored.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn current(&self) -> Option<Decision> {
        let slot = self.0.lock().unwrap();
        *slot
    }

    /// Overwrites the slot with `decision`.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    fn set(&self, decision: Option<Decision>) {
        let mut slot = self.0.lock().unwrap();
        *slot = decision;
    }
}

/// Everything the log store keeps; `LogStore::persist` serializes the whole
/// struct as one JSON document.
#[derive(Serialize, Deserialize, Default)]
struct LogStoreState {
    /// Vote most recently passed to `save_vote`, if any.
    vote: Option<Vote<NodeId>>,
    /// Log id recorded by the latest `purge` call, if any.
    last_purged: Option<LogId<NodeId>>,
    /// Value most recently passed to `save_committed`.
    committed: Option<LogId<NodeId>>,
    /// Stored entries keyed by their log index (`entry.log_id.index`).
    log: BTreeMap<u64, Entry<TypeConfig>>,
}

/// Raft log store that keeps its state in memory and mirrors it to a JSON
/// file after every mutation.
///
/// Clones share the same in-memory state through the inner `Arc`.
#[derive(Clone)]
pub struct LogStore {
    /// File rewritten by `persist`.
    path: PathBuf,
    /// In-memory state shared between clones.
    state: Arc<RwLock<LogStoreState>>,
}

impl LogStore {
    /// Serializes the complete in-memory state to JSON and writes it to
    /// `self.path` through `atomic_write`.
    ///
    /// The read lock is held only while serializing, not during file I/O.
    ///
    /// # Errors
    ///
    /// Serialization and I/O failures are both reported as a log-write
    /// storage error.
    async fn persist(&self) -> Result<(), StorageError<NodeId>> {
        // The guard is a temporary of this statement, so the lock is released
        // before the file is touched.
        let encoded = serde_json::to_vec(&*self.state.read().await)
            .map_err(|e| StorageIOError::write_logs(&e))?;
        atomic_write(&self.path, &encoded).map_err(|e| StorageIOError::write_logs(&e))?;
        Ok(())
    }
}

impl RaftLogReader<TypeConfig> for LogStore {
    /// Returns clones of the stored entries whose index lies inside `range`,
    /// in ascending index order. Indexes with no stored entry are skipped.
    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
        &mut self,
        range: RB,
    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<NodeId>> {
        let guard = self.state.read().await;
        let found: Vec<Entry<TypeConfig>> =
            guard.log.range(range).map(|(_, entry)| entry).cloned().collect();
        Ok(found)
    }
}

impl RaftLogStorage<TypeConfig> for LogStore {
    type LogReader = Self;

    /// Reports the purge marker and the id of the last log entry.
    ///
    /// When no entry is stored, the purge marker doubles as the last log id.
    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<NodeId>> {
        let guard = self.state.read().await;
        let last_purged_log_id = guard.last_purged;
        let last_log_id = guard
            .log
            .values()
            .next_back()
            .map(|entry| entry.log_id)
            .or(last_purged_log_id);
        Ok(LogState {
            last_purged_log_id,
            last_log_id,
        })
    }

    /// Hands out a clone of this store; the clone shares the same state.
    async fn get_log_reader(&mut self) -> Self::LogReader {
        self.clone()
    }

    /// Records `vote` in memory and persists the state to disk.
    async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
        // The write guard is a temporary of this statement, so the lock is
        // released before `persist` takes its read lock.
        self.state.write().await.vote = Some(*vote);
        self.persist().await
    }

    /// Returns the vote held in memory, if any.
    async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
        let guard = self.state.read().await;
        Ok(guard.vote)
    }

    /// Records the committed log id in memory and persists the state to disk.
    async fn save_committed(
        &mut self,
        committed: Option<LogId<NodeId>>,
    ) -> Result<(), StorageError<NodeId>> {
        // Lock released at the end of this statement, before `persist` runs.
        self.state.write().await.committed = committed;
        self.persist().await
    }

    /// Returns the committed log id held in memory, if any.
    async fn read_committed(&mut self) -> Result<Option<LogId<NodeId>>, StorageError<NodeId>> {
        let guard = self.state.read().await;
        Ok(guard.committed)
    }

    /// Inserts `entries` keyed by log index, replacing any entry already
    /// stored at the same index, then persists the state.
    ///
    /// `callback` is invoked with `Ok(())` only after `persist` succeeded.
    /// If `persist` fails the error is returned and the callback is not
    /// invoked.
    async fn append<I>(
        &mut self,
        entries: I,
        callback: LogFlushed<TypeConfig>,
    ) -> Result<(), StorageError<NodeId>>
    where
        I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend,
        I::IntoIter: OptionalSend,
    {
        {
            let mut guard = self.state.write().await;
            guard
                .log
                .extend(entries.into_iter().map(|entry| (entry.log_id.index, entry)));
        }
        self.persist().await?;
        callback.log_io_completed(Ok(()));
        Ok(())
    }

    /// Drops every entry whose index is `log_id.index` or greater, then
    /// persists the state.
    async fn truncate(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        {
            let mut guard = self.state.write().await;
            guard.log.retain(|index, _| *index < log_id.index);
        }
        self.persist().await
    }

    /// Records `log_id` as the purge marker, drops every entry whose index is
    /// `log_id.index` or smaller, then persists the state.
    async fn purge(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
        {
            let mut guard = self.state.write().await;
            guard.last_purged = Some(log_id);
            guard.log.retain(|index, _| *index > log_id.index);
        }
        self.persist().await
    }
}

/// State-machine contents. Serialized to JSON both for the on-disk file
/// (`StateMachine::persist`) and as the snapshot payload (`build_snapshot`);
/// `install_snapshot` replaces it wholesale with the deserialized payload.
#[derive(Serialize, Deserialize, Default, Clone)]
struct SmState {
    /// Log id of the last entry processed by `apply`.
    last_applied: Option<LogId<NodeId>>,
    /// Membership taken from the latest applied membership entry.
    last_membership: StoredMembership<NodeId, Node>,
    /// Payload of the latest applied normal entry, if any.
    decision: Option<Decision>,
}

/// A snapshot kept in memory: its metadata plus the raw payload bytes.
struct StoredSnapshot {
    /// Metadata describing the snapshot (last log id, membership, id).
    meta: SnapshotMeta<NodeId, Node>,
    /// Snapshot payload; `build_snapshot` fills it with JSON-encoded `SmState`.
    data: Vec<u8>,
}

/// Raft state machine that keeps its contents in memory and mirrors them to
/// a JSON file.
///
/// Clones share all state through the inner `Arc`s.
#[derive(Clone)]
pub struct StateMachine {
    /// File rewritten by `persist`.
    path: PathBuf,
    /// Current state-machine contents.
    sm: Arc<RwLock<SmState>>,
    /// Latest snapshot that was built or installed; `persist` does not write it.
    snapshot: Arc<RwLock<Option<Arc<StoredSnapshot>>>>,
    /// Counter incremented for every built snapshot; part of the snapshot id.
    snapshot_idx: Arc<StdMutex<u64>>,
    /// Handle through which the current decision is published.
    shared: SharedDecision,
}

impl StateMachine {
    /// Serializes the state-machine contents to JSON and writes them to
    /// `self.path` through `atomic_write`.
    ///
    /// The read lock is held only while serializing, not during file I/O.
    ///
    /// # Errors
    ///
    /// Serialization and I/O failures are both reported as a
    /// state-machine-write storage error.
    async fn persist(&self) -> Result<(), StorageError<NodeId>> {
        // The guard is a temporary of this statement, so the lock is released
        // before the file is touched.
        let encoded = serde_json::to_vec(&*self.sm.read().await)
            .map_err(|e| StorageIOError::write_state_machine(&e))?;
        atomic_write(&self.path, &encoded)
            .map_err(|e| StorageIOError::write_state_machine(&e))?;
        Ok(())
    }
}

impl RaftSnapshotBuilder<TypeConfig> for StateMachine {
    /// Builds a snapshot from the current state-machine contents.
    ///
    /// The payload is the JSON encoding of `SmState`. The snapshot id is
    /// `"<leader_id>-<index>-<n>"` when something has been applied and
    /// `"--<n>"` otherwise, where `n` is the incremented snapshot counter.
    /// The snapshot is also stored as the current in-memory snapshot.
    ///
    /// # Panics
    ///
    /// Panics if the snapshot-counter mutex is poisoned.
    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
        let (data, last_log_id, last_membership) = {
            let sm = self.sm.read().await;
            let encoded =
                serde_json::to_vec(&*sm).map_err(|e| StorageIOError::read_state_machine(&e))?;
            (encoded, sm.last_applied, sm.last_membership.clone())
        };

        // Scoped so the std mutex guard is gone before the next `.await`.
        let seq = {
            let mut counter = self.snapshot_idx.lock().unwrap();
            *counter += 1;
            *counter
        };

        let snapshot_id = if let Some(log_id) = last_log_id {
            format!("{}-{}-{}", log_id.leader_id, log_id.index, seq)
        } else {
            format!("--{}", seq)
        };
        let meta = SnapshotMeta {
            last_log_id,
            last_membership,
            snapshot_id,
        };

        let stored = Arc::new(StoredSnapshot {
            meta: meta.clone(),
            data: data.clone(),
        });
        *self.snapshot.write().await = Some(stored);

        Ok(Snapshot {
            meta,
            snapshot: Box::new(Cursor::new(data)),
        })
    }
}

impl RaftStateMachine<TypeConfig> for StateMachine {
    type SnapshotBuilder = Self;

    async fn applied_state(
        &mut self,
    ) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, Node>), StorageError<NodeId>> {
        let sm = self.sm.read().await;
        Ok((sm.last_applied, sm.last_membership.clone()))
    }

    async fn apply<I>(&mut self, entries: I) -> Result<Vec<()>, StorageError<NodeId>>
    where
        I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend,
        I::IntoIter: OptionalSend,
    {
        let mut responses = Vec::new();
        let mut changed = false;
        let decision;
        {
            let mut sm = self.sm.write().await;
            for entry in entries {
                sm.last_applied = Some(entry.log_id);
                match entry.payload {
                    EntryPayload::Blank => {}
                    EntryPayload::Normal(value) => {
                        sm.decision = Some(value);
                        changed = true;
                    }
                    EntryPayload::Membership(membership) => {
                        sm.last_membership =
                            StoredMembership::new(Some(entry.log_id), membership);
                    }
                }
                responses.push(());
            }
            decision = sm.decision;
        }
        self.persist().await?;
        if changed {
            self.shared.set(decision);
        }
        Ok(responses)
    }

    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
        self.clone()
    }

    async fn begin_receiving_snapshot(
        &mut self,
    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
        Ok(Box::new(Cursor::new(Vec::new())))
    }

    async fn install_snapshot(
        &mut self,
        meta: &SnapshotMeta<NodeId, Node>,
        snapshot: Box<Cursor<Vec<u8>>>,
    ) -> Result<(), StorageError<NodeId>> {
        let data = snapshot.into_inner();
        let new_sm: SmState = serde_json::from_slice(&data)
            .map_err(|e| StorageIOError::read_snapshot(Some(meta.signature()), &e))?;
        let decision = new_sm.decision;
        {
            let mut sm = self.sm.write().await;
            *sm = new_sm;
        }
        {
            let mut current = self.snapshot.write().await;
            *current = Some(Arc::new(StoredSnapshot {
                meta: meta.clone(),
                data,
            }));
        }
        self.persist().await?;
        self.shared.set(decision);
        Ok(())
    }

    async fn get_current_snapshot(
        &mut self,
    ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<NodeId>> {
        match &*self.snapshot.read().await {
            Some(snapshot) => Ok(Some(Snapshot {
                meta: snapshot.meta.clone(),
                snapshot: Box::new(Cursor::new(snapshot.data.clone())),
            })),
            None => Ok(None),
        }
    }
}

/// Opens the log store and state machine of `node_id` inside `dir`.
///
/// State is read from `raft_log_<node_id>.json` and `raft_sm_<node_id>.json`.
/// A file that cannot be read or decoded yields default (empty) state. The
/// directory itself is not created here.
///
/// The returned [`SharedDecision`] is the same handle the state machine
/// updates, pre-loaded with the decision found on disk.
pub fn open(dir: &Path, node_id: NodeId) -> (LogStore, StateMachine, SharedDecision) {
    let log_path = dir.join(format!("raft_log_{}.json", node_id));
    let sm_path = dir.join(format!("raft_sm_{}.json", node_id));

    let log_store = LogStore {
        // Loaded first: the borrow of `log_path` ends before it is moved below.
        state: Arc::new(RwLock::new(load::<LogStoreState>(&log_path))),
        path: log_path,
    };

    let sm_state = load::<SmState>(&sm_path);
    let shared = SharedDecision(Arc::new(StdMutex::new(sm_state.decision)));
    let state_machine = StateMachine {
        path: sm_path,
        sm: Arc::new(RwLock::new(sm_state)),
        snapshot: Arc::new(RwLock::new(None)),
        snapshot_idx: Arc::new(StdMutex::new(0)),
        shared: shared.clone(),
    };

    (log_store, state_machine, shared)
}

/// Reads `path` and decodes it as JSON into `T`.
///
/// Falls back to `T::default()` when the file cannot be read or its contents
/// cannot be decoded.
///
/// NOTE(review): a corrupt or unreadable file is indistinguishable from a
/// missing one here, so persisted state can be silently replaced by defaults.
/// Consider treating anything other than "file not found" as a hard error.
fn load<T>(path: &Path) -> T
where
    T: Default + for<'de> Deserialize<'de>,
{
    match std::fs::read(path) {
        Ok(bytes) => serde_json::from_slice(&bytes).unwrap_or_default(),
        Err(_) => T::default(),
    }
}

/// Replaces the contents of `path` with `bytes` by way of a temporary file.
///
/// The data is written to a sibling file with the extension `tmp`, flushed
/// to disk, and then renamed over `path`. On Unix the parent directory is
/// synced afterwards so that the rename itself is durable.
///
/// # Errors
///
/// Returns the first I/O error encountered. If writing or renaming fails,
/// the temporary file is removed on a best-effort basis and `path` is left
/// untouched.
fn atomic_write(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
    let tmp = path.with_extension("tmp");
    if let Err(err) = write_then_rename(&tmp, path, bytes) {
        // Best effort: do not leave a stale temporary file behind.
        let _ = std::fs::remove_file(&tmp);
        return Err(err);
    }
    sync_parent_dir(path)
}

/// Writes `bytes` to `tmp`, syncs the file, then renames `tmp` onto `path`.
fn write_then_rename(tmp: &Path, path: &Path, bytes: &[u8]) -> std::io::Result<()> {
    {
        let mut file = std::fs::File::create(tmp)?;
        file.write_all(bytes)?;
        file.sync_all()?;
        // The file is closed here, before the rename.
    }
    std::fs::rename(tmp, path)
}

/// Syncs the directory containing `path` so a preceding rename is durable.
#[cfg(unix)]
fn sync_parent_dir(path: &Path) -> std::io::Result<()> {
    // A bare file name has an empty parent; that means the current directory.
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    std::fs::File::open(parent)?.sync_all()
}

/// No directory sync is attempted on non-Unix targets.
#[cfg(not(unix))]
fn sync_parent_dir(_path: &Path) -> std::io::Result<()> {
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use openraft::testing::{StoreBuilder, Suite};
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Process-wide sequence number that keeps test directory names distinct
    /// even when two stores are built within the same clock tick.
    static NEXT_DIR: AtomicU64 = AtomicU64::new(0);

    /// Guard that removes the test directory when the store is dropped.
    struct TestDir(PathBuf);

    impl Drop for TestDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not fail the test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Builds a fresh store in its own temporary directory for every call.
    struct Builder;

    impl StoreBuilder<TypeConfig, LogStore, StateMachine, TestDir> for Builder {
        async fn build(&self) -> Result<(TestDir, LogStore, StateMachine), StorageError<NodeId>> {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            // pid + timestamp guard against leftovers from other runs; the
            // counter guards against two builds sharing a timestamp, which
            // would make them share (and delete) each other's files.
            let seq = NEXT_DIR.fetch_add(1, Ordering::Relaxed);
            let dir = std::env::temp_dir().join(format!(
                "pgr-store-test-{}-{}-{}",
                std::process::id(),
                nanos,
                seq
            ));
            std::fs::create_dir_all(&dir).unwrap();
            let (log_store, state_machine, _shared) = open(&dir, 1);
            Ok((TestDir(dir), log_store, state_machine))
        }
    }

    /// Runs openraft's storage test suite against this implementation.
    #[test]
    fn storage_conformance_suite() -> Result<(), StorageError<NodeId>> {
        Suite::test_all(Builder)
    }
}