openraft_memstore/
lib.rs

//! An in-memory storage implementation (`MemStore`) of the openraft `RaftStorage` trait,
//! together with the demo request/response types it works with.

#![deny(unused_crate_dependencies)]
#![deny(unused_qualifications)]

// The `test` module is compiled only for test builds (presumably it holds this crate's tests).
#[cfg(test)]
mod test;
6
7use std::collections::BTreeMap;
8use std::collections::HashMap;
9use std::fmt::Debug;
10use std::io::Cursor;
11use std::ops::RangeBounds;
12use std::sync::atomic::AtomicBool;
13use std::sync::atomic::Ordering;
14use std::sync::Arc;
15use std::sync::Mutex;
16
17use openraft::storage::LogState;
18use openraft::storage::RaftLogReader;
19use openraft::storage::RaftSnapshotBuilder;
20use openraft::storage::Snapshot;
21use openraft::Entry;
22use openraft::EntryPayload;
23use openraft::LogId;
24use openraft::OptionalSend;
25use openraft::RaftLogId;
26use openraft::RaftStorage;
27use openraft::RaftTypeConfig;
28use openraft::SnapshotMeta;
29use openraft::StorageError;
30use openraft::StorageIOError;
31use openraft::StoredMembership;
32use openraft::Vote;
33use serde::Deserialize;
34use serde::Serialize;
35use tokio::sync::RwLock;
36use tokio::time::Duration;
37
38/// The application data request type which the `MemStore` works with.
39///
40/// Conceptually, for demo purposes, this represents an update to a client's status info,
41/// returning the previously recorded status.
42#[derive(Serialize, Deserialize, Debug, Clone)]
43pub struct ClientRequest {
44    /// The ID of the client which has sent the request.
45    pub client: String,
46
47    /// The serial number of this request.
48    pub serial: u64,
49
50    /// A string describing the status of the client. For a real application, this should probably
51    /// be an enum representing all of the various types of requests / operations which a client
52    /// can perform.
53    pub status: String,
54}
55
/// Helper trait to build `ClientRequest` for `MemStore` in generic test code.
///
/// `T` is the request type that gets built.
pub trait IntoMemClientRequest<T> {
    /// Build a request of type `T` for the client `client_id` with serial number `serial`.
    fn make_request(client_id: impl ToString, serial: u64) -> T;
}
60
61impl IntoMemClientRequest<ClientRequest> for ClientRequest {
62    fn make_request(client_id: impl ToString, serial: u64) -> Self {
63        Self {
64            client: client_id.to_string(),
65            serial,
66            status: format!("request-{}", serial),
67        }
68    }
69}
70
71/// The application data response type which the `MemStore` works with.
72#[derive(Serialize, Deserialize, Debug, Clone)]
73pub struct ClientResponse(pub Option<String>);
74
/// The node id type used throughout this crate (e.g. in `LogId<MemNodeId>` and
/// `Vote<MemNodeId>`).
pub type MemNodeId = u64;
76
// `D` is the application request type and `R` the application response type; nodes carry no
// extra data (`Node = ()`). Parameters not listed are left to the macro's defaults
// (presumably including `NodeId = u64`, matching `MemNodeId` — confirm against the openraft
// version in use).
openraft::declare_raft_types!(
    /// Declare the type configuration for `MemStore`.
    pub TypeConfig:
        D = ClientRequest,
        R = ClientResponse,
        Node = (),
);
84
/// The application snapshot type which the `MemStore` works with.
#[derive(Debug)]
pub struct MemStoreSnapshot {
    /// Metadata of this snapshot: the last log id and membership it covers, and its snapshot id.
    pub meta: SnapshotMeta<MemNodeId, ()>,

    /// The data of the state machine at the time of this snapshot.
    ///
    /// Snapshots built by this store hold the JSON-serialized `MemStoreStateMachine`.
    pub data: Vec<u8>,
}
93
/// The state machine of the `MemStore`.
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct MemStoreStateMachine {
    /// The log id of the last entry applied to this state machine, `None` if nothing was applied.
    pub last_applied_log: Option<LogId<MemNodeId>>,

    /// The membership taken from the last applied membership entry, with that entry's log id.
    pub last_membership: StoredMembership<MemNodeId, ()>,

    /// A mapping of client IDs to `(serial, response)`: the serial number of the last request
    /// applied for that client and the response that was returned for it.
    pub client_serial_responses: HashMap<String, (u64, Option<String>)>,
    /// The current status of a client by ID.
    pub client_status: HashMap<String, String>,
}
106
/// Operations of the `MemStore` that can be artificially slowed down for testing.
///
/// The variant order is significant: it defines the derived `Ord` used when the operation is
/// a `BTreeMap` key.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum BlockOperation {
    /// Delay building a snapshot without holding a lock on the state machine.
    /// Building the snapshot returns later, but applying entries should not be blocked.
    DelayBuildingSnapshot,
    /// Block building a snapshot while the state machine lock is held.
    BuildSnapshot,
    /// Block purging logs before any log entry is removed.
    PurgeLog,
}
117
/// An in-memory storage system implementing the `RaftStorage` trait.
pub struct MemStore {
    /// The log id passed to the most recent `purge_logs_upto` call; `None` until the first purge.
    last_purged_log_id: RwLock<Option<LogId<MemNodeId>>>,

    /// Saving committed log id is optional in Openraft.
    ///
    /// This flag switches on the saving for testing purposes.
    pub enable_saving_committed: AtomicBool,

    /// The committed log id stored by `save_committed` while `enable_saving_committed` is set.
    committed: RwLock<Option<LogId<MemNodeId>>>,

    /// The Raft log, keyed by log index. Logs are stored in serialized json.
    log: RwLock<BTreeMap<u64, String>>,

    /// The Raft state machine.
    sm: RwLock<MemStoreStateMachine>,

    /// Block operations for testing purposes: maps an operation to how long it should sleep.
    block: Mutex<BTreeMap<BlockOperation, Duration>>,

    /// The vote most recently stored by `save_vote`; `None` if no vote was saved yet.
    vote: RwLock<Option<Vote<MemNodeId>>>,

    /// Counter incremented for every snapshot built; the new value becomes part of the
    /// snapshot id.
    snapshot_idx: Arc<Mutex<u64>>,

    /// The current snapshot.
    current_snapshot: RwLock<Option<MemStoreSnapshot>>,
}
146
147impl MemStore {
148    /// Create a new `MemStore` instance.
149    pub fn new() -> Self {
150        let log = RwLock::new(BTreeMap::new());
151        let sm = RwLock::new(MemStoreStateMachine::default());
152        let current_snapshot = RwLock::new(None);
153
154        Self {
155            last_purged_log_id: RwLock::new(None),
156            enable_saving_committed: AtomicBool::new(true),
157            committed: RwLock::new(None),
158            log,
159            sm,
160            block: Mutex::new(BTreeMap::new()),
161            vote: RwLock::new(None),
162            snapshot_idx: Arc::new(Mutex::new(0)),
163            current_snapshot,
164        }
165    }
166
167    pub async fn new_async() -> Arc<Self> {
168        Arc::new(Self::new())
169    }
170
171    /// Remove the current snapshot.
172    ///
173    /// This method is only used for testing purposes.
174    pub async fn drop_snapshot(&self) {
175        let mut current = self.current_snapshot.write().await;
176        *current = None;
177    }
178
179    /// Get a handle to the state machine for testing purposes.
180    pub async fn get_state_machine(&self) -> MemStoreStateMachine {
181        self.sm.write().await.clone()
182    }
183
184    /// Clear the state machine for testing purposes.
185    pub async fn clear_state_machine(&self) {
186        let mut sm = self.sm.write().await;
187        *sm = MemStoreStateMachine::default();
188    }
189
190    /// Block an operation for testing purposes.
191    pub fn set_blocking(&self, block: BlockOperation, d: Duration) {
192        self.block.lock().unwrap().insert(block, d);
193    }
194
195    /// Get the blocking flag for an operation.
196    pub fn get_blocking(&self, block: &BlockOperation) -> Option<Duration> {
197        self.block.lock().unwrap().get(block).cloned()
198    }
199
200    /// Clear a blocking flag for an operation.
201    pub fn clear_blocking(&mut self, block: BlockOperation) {
202        self.block.lock().unwrap().remove(&block);
203    }
204}
205
impl Default for MemStore {
    /// Equivalent to `MemStore::new()`.
    fn default() -> Self {
        Self::new()
    }
}
211
212impl RaftLogReader<TypeConfig> for Arc<MemStore> {
213    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
214        &mut self,
215        range: RB,
216    ) -> Result<Vec<Entry<TypeConfig>>, StorageError<MemNodeId>> {
217        let mut entries = vec![];
218        {
219            let log = self.log.read().await;
220            for (_, serialized) in log.range(range.clone()) {
221                let ent = serde_json::from_str(serialized).map_err(|e| StorageIOError::read_logs(&e))?;
222                entries.push(ent);
223            }
224        };
225
226        Ok(entries)
227    }
228}
229
230impl RaftSnapshotBuilder<TypeConfig> for Arc<MemStore> {
231    #[tracing::instrument(level = "trace", skip(self))]
232    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<MemNodeId>> {
233        let data;
234        let last_applied_log;
235        let last_membership;
236
237        if let Some(d) = self.get_blocking(&BlockOperation::DelayBuildingSnapshot) {
238            tracing::info!(?d, "delay snapshot build");
239            tokio::time::sleep(d).await;
240        }
241
242        {
243            // Serialize the data of the state machine.
244            let sm = self.sm.read().await;
245            data = serde_json::to_vec(&*sm).map_err(|e| StorageIOError::read_state_machine(&e))?;
246
247            last_applied_log = sm.last_applied_log;
248            last_membership = sm.last_membership.clone();
249
250            if let Some(d) = self.get_blocking(&BlockOperation::BuildSnapshot) {
251                tracing::info!(?d, "blocking snapshot build");
252                tokio::time::sleep(d).await;
253            }
254        }
255
256        let snapshot_size = data.len();
257
258        let snapshot_idx = {
259            let mut l = self.snapshot_idx.lock().unwrap();
260            *l += 1;
261            *l
262        };
263
264        let snapshot_id = if let Some(last) = last_applied_log {
265            format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx)
266        } else {
267            format!("--{}", snapshot_idx)
268        };
269
270        let meta = SnapshotMeta {
271            last_log_id: last_applied_log,
272            last_membership,
273            snapshot_id,
274        };
275
276        let snapshot = MemStoreSnapshot {
277            meta: meta.clone(),
278            data: data.clone(),
279        };
280
281        {
282            let mut current_snapshot = self.current_snapshot.write().await;
283            *current_snapshot = Some(snapshot);
284        }
285
286        tracing::info!(snapshot_size, "log compaction complete");
287
288        Ok(Snapshot {
289            meta,
290            snapshot: Box::new(Cursor::new(data)),
291        })
292    }
293}
294
295impl RaftStorage<TypeConfig> for Arc<MemStore> {
296    async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<MemNodeId>> {
297        let log = self.log.read().await;
298        let last_serialized = log.iter().next_back().map(|(_, ent)| ent);
299
300        let last = match last_serialized {
301            None => None,
302            Some(serialized) => {
303                let ent: Entry<TypeConfig> =
304                    serde_json::from_str(serialized).map_err(|e| StorageIOError::read_logs(&e))?;
305                Some(*ent.get_log_id())
306            }
307        };
308
309        let last_purged = *self.last_purged_log_id.read().await;
310
311        let last = match last {
312            None => last_purged,
313            Some(x) => Some(x),
314        };
315
316        Ok(LogState {
317            last_purged_log_id: last_purged,
318            last_log_id: last,
319        })
320    }
321
322    #[tracing::instrument(level = "trace", skip(self))]
323    async fn save_vote(&mut self, vote: &Vote<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
324        tracing::debug!(?vote, "save_vote");
325        let mut h = self.vote.write().await;
326
327        *h = Some(*vote);
328        Ok(())
329    }
330
331    async fn read_vote(&mut self) -> Result<Option<Vote<MemNodeId>>, StorageError<MemNodeId>> {
332        Ok(*self.vote.read().await)
333    }
334
335    async fn save_committed(&mut self, committed: Option<LogId<MemNodeId>>) -> Result<(), StorageError<MemNodeId>> {
336        let enabled = self.enable_saving_committed.load(Ordering::Relaxed);
337        tracing::debug!(?committed, "save_committed, enabled: {}", enabled);
338        if !enabled {
339            return Ok(());
340        }
341        let mut c = self.committed.write().await;
342        *c = committed;
343        Ok(())
344    }
345
346    async fn read_committed(&mut self) -> Result<Option<LogId<MemNodeId>>, StorageError<MemNodeId>> {
347        let enabled = self.enable_saving_committed.load(Ordering::Relaxed);
348        tracing::debug!("read_committed, enabled: {}", enabled);
349        if !enabled {
350            return Ok(None);
351        }
352
353        Ok(*self.committed.read().await)
354    }
355
356    async fn last_applied_state(
357        &mut self,
358    ) -> Result<(Option<LogId<MemNodeId>>, StoredMembership<MemNodeId, ()>), StorageError<MemNodeId>> {
359        let sm = self.sm.read().await;
360        Ok((sm.last_applied_log, sm.last_membership.clone()))
361    }
362
363    #[tracing::instrument(level = "debug", skip(self))]
364    async fn delete_conflict_logs_since(&mut self, log_id: LogId<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
365        tracing::debug!("delete_log: [{:?}, +oo)", log_id);
366
367        {
368            let mut log = self.log.write().await;
369
370            let keys = log.range(log_id.index..).map(|(k, _v)| *k).collect::<Vec<_>>();
371            for key in keys {
372                log.remove(&key);
373            }
374        }
375
376        Ok(())
377    }
378
379    #[tracing::instrument(level = "debug", skip_all)]
380    async fn purge_logs_upto(&mut self, log_id: LogId<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
381        tracing::debug!("purge_log_upto: {:?}", log_id);
382
383        if let Some(d) = self.get_blocking(&BlockOperation::PurgeLog) {
384            tracing::info!(?d, "block purging log");
385            tokio::time::sleep(d).await;
386        }
387
388        {
389            let mut ld = self.last_purged_log_id.write().await;
390            assert!(*ld <= Some(log_id));
391            *ld = Some(log_id);
392        }
393
394        {
395            let mut log = self.log.write().await;
396
397            let keys = log.range(..=log_id.index).map(|(k, _v)| *k).collect::<Vec<_>>();
398            for key in keys {
399                log.remove(&key);
400            }
401        }
402
403        Ok(())
404    }
405
406    #[tracing::instrument(level = "trace", skip(self, entries))]
407    async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<MemNodeId>>
408    where I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend {
409        let mut log = self.log.write().await;
410        for entry in entries {
411            let s =
412                serde_json::to_string(&entry).map_err(|e| StorageIOError::write_log_entry(*entry.get_log_id(), &e))?;
413            log.insert(entry.log_id.index, s);
414        }
415        Ok(())
416    }
417
418    #[tracing::instrument(level = "trace", skip(self, entries))]
419    async fn apply_to_state_machine(
420        &mut self,
421        entries: &[Entry<TypeConfig>],
422    ) -> Result<Vec<ClientResponse>, StorageError<MemNodeId>> {
423        let mut res = Vec::with_capacity(entries.len());
424
425        let mut sm = self.sm.write().await;
426
427        for entry in entries {
428            tracing::debug!(%entry.log_id, "replicate to sm");
429
430            sm.last_applied_log = Some(entry.log_id);
431
432            match entry.payload {
433                EntryPayload::Blank => res.push(ClientResponse(None)),
434                EntryPayload::Normal(ref data) => {
435                    if let Some((serial, r)) = sm.client_serial_responses.get(&data.client) {
436                        if serial == &data.serial {
437                            res.push(ClientResponse(r.clone()));
438                            continue;
439                        }
440                    }
441                    let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
442                    sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
443                    res.push(ClientResponse(previous));
444                }
445                EntryPayload::Membership(ref mem) => {
446                    sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
447                    res.push(ClientResponse(None))
448                }
449            };
450        }
451        Ok(res)
452    }
453
454    #[tracing::instrument(level = "trace", skip(self))]
455    async fn begin_receiving_snapshot(
456        &mut self,
457    ) -> Result<Box<<TypeConfig as RaftTypeConfig>::SnapshotData>, StorageError<MemNodeId>> {
458        Ok(Box::new(Cursor::new(Vec::new())))
459    }
460
461    #[tracing::instrument(level = "trace", skip(self, snapshot))]
462    async fn install_snapshot(
463        &mut self,
464        meta: &SnapshotMeta<MemNodeId, ()>,
465        snapshot: Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
466    ) -> Result<(), StorageError<MemNodeId>> {
467        tracing::info!(
468            { snapshot_size = snapshot.get_ref().len() },
469            "decoding snapshot for installation"
470        );
471
472        let new_snapshot = MemStoreSnapshot {
473            meta: meta.clone(),
474            data: snapshot.into_inner(),
475        };
476
477        {
478            let t = &new_snapshot.data;
479            let y = std::str::from_utf8(t).unwrap();
480            tracing::debug!("SNAP META:{:?}", meta);
481            tracing::debug!("JSON SNAP DATA:{}", y);
482        }
483
484        // Update the state machine.
485        {
486            let new_sm: MemStoreStateMachine = serde_json::from_slice(&new_snapshot.data)
487                .map_err(|e| StorageIOError::read_snapshot(Some(new_snapshot.meta.signature()), &e))?;
488            let mut sm = self.sm.write().await;
489            *sm = new_sm;
490        }
491
492        // Update current snapshot.
493        let mut current_snapshot = self.current_snapshot.write().await;
494        *current_snapshot = Some(new_snapshot);
495        Ok(())
496    }
497
498    #[tracing::instrument(level = "trace", skip(self))]
499    async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<TypeConfig>>, StorageError<MemNodeId>> {
500        match &*self.current_snapshot.read().await {
501            Some(snapshot) => {
502                let data = snapshot.data.clone();
503                Ok(Some(Snapshot {
504                    meta: snapshot.meta.clone(),
505                    snapshot: Box::new(Cursor::new(data)),
506                }))
507            }
508            None => Ok(None),
509        }
510    }
511
512    type LogReader = Self;
513    type SnapshotBuilder = Self;
514
515    async fn get_log_reader(&mut self) -> Self::LogReader {
516        self.clone()
517    }
518
519    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
520        self.clone()
521    }
522}