//! In-memory implementation of openraft's storage traits (`RaftStorage`,
//! `RaftLogReader`, `RaftSnapshotBuilder`) for `Arc<MemStore>`.
//!
//! Log, vote, committed log id, state machine and snapshot are all kept in
//! memory behind `tokio` locks. The store also exposes hooks for injecting
//! delays into selected operations (`BlockOperation`) and for disabling the
//! persistence of the committed log id (`enable_saving_committed`).
#![deny(unused_crate_dependencies)]
#![deny(unused_qualifications)]

// Unit tests live in the `test` submodule and are compiled only for test builds.
#[cfg(test)]
mod test;
6
7use std::collections::BTreeMap;
8use std::collections::HashMap;
9use std::fmt::Debug;
10use std::io::Cursor;
11use std::ops::RangeBounds;
12use std::sync::atomic::AtomicBool;
13use std::sync::atomic::Ordering;
14use std::sync::Arc;
15use std::sync::Mutex;
16
17use openraft::storage::LogState;
18use openraft::storage::RaftLogReader;
19use openraft::storage::RaftSnapshotBuilder;
20use openraft::storage::Snapshot;
21use openraft::Entry;
22use openraft::EntryPayload;
23use openraft::LogId;
24use openraft::OptionalSend;
25use openraft::RaftLogId;
26use openraft::RaftStorage;
27use openraft::RaftTypeConfig;
28use openraft::SnapshotMeta;
29use openraft::StorageError;
30use openraft::StorageIOError;
31use openraft::StoredMembership;
32use openraft::Vote;
33use serde::Deserialize;
34use serde::Serialize;
35use tokio::sync::RwLock;
36use tokio::time::Duration;
37
/// A client write request: sets the status of `client`, tagged with a
/// per-client serial number used to detect repeated requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClientRequest {
    /// Id of the client issuing the request.
    pub client: String,

    /// Serial number of this request. When a request is applied with the same
    /// serial as the client's most recently applied one, the recorded response
    /// is replayed instead of applying it again.
    pub serial: u64,

    /// The status value to store for `client`.
    pub status: String,
}
55
/// Constructs a request value of type `T` from a client id and a serial number.
pub trait IntoMemClientRequest<T> {
    /// Builds a `T` for the client identified by `client_id`, carrying the
    /// request serial number `serial`.
    fn make_request(client_id: impl ToString, serial: u64) -> T;
}
60
61impl IntoMemClientRequest<ClientRequest> for ClientRequest {
62 fn make_request(client_id: impl ToString, serial: u64) -> Self {
63 Self {
64 client: client_id.to_string(),
65 serial,
66 status: format!("request-{}", serial),
67 }
68 }
69}
70
/// Result of applying one log entry.
///
/// For a client request this is the client's previous status (`None` if it had
/// none), or the recorded response when a repeated serial is replayed. Blank
/// and membership entries always produce `None`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClientResponse(pub Option<String>);
74
/// Node id type used throughout this store.
pub type MemNodeId = u64;
76
// Raft type configuration for this store: application requests are
// `ClientRequest`, responses are `ClientResponse`, and nodes carry no extra
// data (`()`). Unlisted associated types take the macro's defaults.
// NOTE(review): the storage code below treats `SnapshotData` as
// `Cursor<Vec<u8>>` (`get_ref`/`into_inner`), presumably the macro default —
// confirm if this config is ever changed.
openraft::declare_raft_types!(
    pub TypeConfig:
        D = ClientRequest,
        R = ClientResponse,
        Node = (),
);
84
/// A snapshot held by the store: its metadata plus the serialized state machine.
#[derive(Debug)]
pub struct MemStoreSnapshot {
    /// Snapshot metadata: last included log id, membership and snapshot id.
    pub meta: SnapshotMeta<MemNodeId, ()>,

    /// The state machine serialized as JSON.
    pub data: Vec<u8>,
}
93
/// The replicated state machine. It is serialized as JSON to form snapshot data.
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct MemStoreStateMachine {
    /// Id of the last log entry applied to this state machine.
    pub last_applied_log: Option<LogId<MemNodeId>>,

    /// The last applied membership config, together with the id of the entry
    /// that carried it.
    pub last_membership: StoredMembership<MemNodeId, ()>,

    /// Per client: the serial of its most recently applied request and the
    /// response returned for it. It is used to answer a repeated serial without
    /// applying the request again.
    pub client_serial_responses: HashMap<String, (u64, Option<String>)>,
    /// Latest status written by each client.
    pub client_status: HashMap<String, String>,
}
106
/// Storage operations that can be artificially delayed; see
/// `MemStore::set_blocking`.
///
/// Variant order is significant: it defines the derived `Ord`, which orders
/// the keys of the store's blocking table.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum BlockOperation {
    /// Sleep in `build_snapshot` before the state machine is locked.
    DelayBuildingSnapshot,
    /// Sleep in `build_snapshot` while the state-machine read lock is held.
    BuildSnapshot,
    /// Sleep at the start of `purge_logs_upto`, before anything is purged.
    PurgeLog,
}
117
/// In-memory Raft storage. Log, vote, committed log id, state machine and
/// snapshot each sit behind their own lock.
///
/// The openraft storage traits are implemented for `Arc<MemStore>`.
pub struct MemStore {
    /// Id of the last log entry removed by `purge_logs_upto`; `None` until
    /// something has been purged.
    last_purged_log_id: RwLock<Option<LogId<MemNodeId>>>,

    /// When `false`, `save_committed` does nothing and `read_committed`
    /// returns `None`. Initialized to `true` by `new`.
    pub enable_saving_committed: AtomicBool,

    /// Committed log id last stored by `save_committed`.
    committed: RwLock<Option<LogId<MemNodeId>>>,

    /// Log entries keyed by log index, each stored as a JSON string.
    log: RwLock<BTreeMap<u64, String>>,

    /// The state machine that log entries are applied to.
    sm: RwLock<MemStoreStateMachine>,

    /// Artificial delays to inject into the corresponding operations.
    block: Mutex<BTreeMap<BlockOperation, Duration>>,

    /// The vote last stored by `save_vote`.
    vote: RwLock<Option<Vote<MemNodeId>>>,

    /// Counter incremented on every snapshot build. Its new value becomes the
    /// last component of the snapshot id.
    snapshot_idx: Arc<Mutex<u64>>,

    /// The most recently built or installed snapshot.
    current_snapshot: RwLock<Option<MemStoreSnapshot>>,
}
146
147impl MemStore {
148 pub fn new() -> Self {
150 let log = RwLock::new(BTreeMap::new());
151 let sm = RwLock::new(MemStoreStateMachine::default());
152 let current_snapshot = RwLock::new(None);
153
154 Self {
155 last_purged_log_id: RwLock::new(None),
156 enable_saving_committed: AtomicBool::new(true),
157 committed: RwLock::new(None),
158 log,
159 sm,
160 block: Mutex::new(BTreeMap::new()),
161 vote: RwLock::new(None),
162 snapshot_idx: Arc::new(Mutex::new(0)),
163 current_snapshot,
164 }
165 }
166
167 pub async fn new_async() -> Arc<Self> {
168 Arc::new(Self::new())
169 }
170
171 pub async fn drop_snapshot(&self) {
175 let mut current = self.current_snapshot.write().await;
176 *current = None;
177 }
178
179 pub async fn get_state_machine(&self) -> MemStoreStateMachine {
181 self.sm.write().await.clone()
182 }
183
184 pub async fn clear_state_machine(&self) {
186 let mut sm = self.sm.write().await;
187 *sm = MemStoreStateMachine::default();
188 }
189
190 pub fn set_blocking(&self, block: BlockOperation, d: Duration) {
192 self.block.lock().unwrap().insert(block, d);
193 }
194
195 pub fn get_blocking(&self, block: &BlockOperation) -> Option<Duration> {
197 self.block.lock().unwrap().get(block).cloned()
198 }
199
200 pub fn clear_blocking(&mut self, block: BlockOperation) {
202 self.block.lock().unwrap().remove(&block);
203 }
204}
205
206impl Default for MemStore {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl RaftLogReader<TypeConfig> for Arc<MemStore> {
213 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
214 &mut self,
215 range: RB,
216 ) -> Result<Vec<Entry<TypeConfig>>, StorageError<MemNodeId>> {
217 let mut entries = vec![];
218 {
219 let log = self.log.read().await;
220 for (_, serialized) in log.range(range.clone()) {
221 let ent = serde_json::from_str(serialized).map_err(|e| StorageIOError::read_logs(&e))?;
222 entries.push(ent);
223 }
224 };
225
226 Ok(entries)
227 }
228}
229
230impl RaftSnapshotBuilder<TypeConfig> for Arc<MemStore> {
231 #[tracing::instrument(level = "trace", skip(self))]
232 async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<MemNodeId>> {
233 let data;
234 let last_applied_log;
235 let last_membership;
236
237 if let Some(d) = self.get_blocking(&BlockOperation::DelayBuildingSnapshot) {
238 tracing::info!(?d, "delay snapshot build");
239 tokio::time::sleep(d).await;
240 }
241
242 {
243 let sm = self.sm.read().await;
245 data = serde_json::to_vec(&*sm).map_err(|e| StorageIOError::read_state_machine(&e))?;
246
247 last_applied_log = sm.last_applied_log;
248 last_membership = sm.last_membership.clone();
249
250 if let Some(d) = self.get_blocking(&BlockOperation::BuildSnapshot) {
251 tracing::info!(?d, "blocking snapshot build");
252 tokio::time::sleep(d).await;
253 }
254 }
255
256 let snapshot_size = data.len();
257
258 let snapshot_idx = {
259 let mut l = self.snapshot_idx.lock().unwrap();
260 *l += 1;
261 *l
262 };
263
264 let snapshot_id = if let Some(last) = last_applied_log {
265 format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx)
266 } else {
267 format!("--{}", snapshot_idx)
268 };
269
270 let meta = SnapshotMeta {
271 last_log_id: last_applied_log,
272 last_membership,
273 snapshot_id,
274 };
275
276 let snapshot = MemStoreSnapshot {
277 meta: meta.clone(),
278 data: data.clone(),
279 };
280
281 {
282 let mut current_snapshot = self.current_snapshot.write().await;
283 *current_snapshot = Some(snapshot);
284 }
285
286 tracing::info!(snapshot_size, "log compaction complete");
287
288 Ok(Snapshot {
289 meta,
290 snapshot: Box::new(Cursor::new(data)),
291 })
292 }
293}
294
295impl RaftStorage<TypeConfig> for Arc<MemStore> {
296 async fn get_log_state(&mut self) -> Result<LogState<TypeConfig>, StorageError<MemNodeId>> {
297 let log = self.log.read().await;
298 let last_serialized = log.iter().next_back().map(|(_, ent)| ent);
299
300 let last = match last_serialized {
301 None => None,
302 Some(serialized) => {
303 let ent: Entry<TypeConfig> =
304 serde_json::from_str(serialized).map_err(|e| StorageIOError::read_logs(&e))?;
305 Some(*ent.get_log_id())
306 }
307 };
308
309 let last_purged = *self.last_purged_log_id.read().await;
310
311 let last = match last {
312 None => last_purged,
313 Some(x) => Some(x),
314 };
315
316 Ok(LogState {
317 last_purged_log_id: last_purged,
318 last_log_id: last,
319 })
320 }
321
322 #[tracing::instrument(level = "trace", skip(self))]
323 async fn save_vote(&mut self, vote: &Vote<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
324 tracing::debug!(?vote, "save_vote");
325 let mut h = self.vote.write().await;
326
327 *h = Some(*vote);
328 Ok(())
329 }
330
331 async fn read_vote(&mut self) -> Result<Option<Vote<MemNodeId>>, StorageError<MemNodeId>> {
332 Ok(*self.vote.read().await)
333 }
334
335 async fn save_committed(&mut self, committed: Option<LogId<MemNodeId>>) -> Result<(), StorageError<MemNodeId>> {
336 let enabled = self.enable_saving_committed.load(Ordering::Relaxed);
337 tracing::debug!(?committed, "save_committed, enabled: {}", enabled);
338 if !enabled {
339 return Ok(());
340 }
341 let mut c = self.committed.write().await;
342 *c = committed;
343 Ok(())
344 }
345
346 async fn read_committed(&mut self) -> Result<Option<LogId<MemNodeId>>, StorageError<MemNodeId>> {
347 let enabled = self.enable_saving_committed.load(Ordering::Relaxed);
348 tracing::debug!("read_committed, enabled: {}", enabled);
349 if !enabled {
350 return Ok(None);
351 }
352
353 Ok(*self.committed.read().await)
354 }
355
356 async fn last_applied_state(
357 &mut self,
358 ) -> Result<(Option<LogId<MemNodeId>>, StoredMembership<MemNodeId, ()>), StorageError<MemNodeId>> {
359 let sm = self.sm.read().await;
360 Ok((sm.last_applied_log, sm.last_membership.clone()))
361 }
362
363 #[tracing::instrument(level = "debug", skip(self))]
364 async fn delete_conflict_logs_since(&mut self, log_id: LogId<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
365 tracing::debug!("delete_log: [{:?}, +oo)", log_id);
366
367 {
368 let mut log = self.log.write().await;
369
370 let keys = log.range(log_id.index..).map(|(k, _v)| *k).collect::<Vec<_>>();
371 for key in keys {
372 log.remove(&key);
373 }
374 }
375
376 Ok(())
377 }
378
379 #[tracing::instrument(level = "debug", skip_all)]
380 async fn purge_logs_upto(&mut self, log_id: LogId<MemNodeId>) -> Result<(), StorageError<MemNodeId>> {
381 tracing::debug!("purge_log_upto: {:?}", log_id);
382
383 if let Some(d) = self.get_blocking(&BlockOperation::PurgeLog) {
384 tracing::info!(?d, "block purging log");
385 tokio::time::sleep(d).await;
386 }
387
388 {
389 let mut ld = self.last_purged_log_id.write().await;
390 assert!(*ld <= Some(log_id));
391 *ld = Some(log_id);
392 }
393
394 {
395 let mut log = self.log.write().await;
396
397 let keys = log.range(..=log_id.index).map(|(k, _v)| *k).collect::<Vec<_>>();
398 for key in keys {
399 log.remove(&key);
400 }
401 }
402
403 Ok(())
404 }
405
406 #[tracing::instrument(level = "trace", skip(self, entries))]
407 async fn append_to_log<I>(&mut self, entries: I) -> Result<(), StorageError<MemNodeId>>
408 where I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend {
409 let mut log = self.log.write().await;
410 for entry in entries {
411 let s =
412 serde_json::to_string(&entry).map_err(|e| StorageIOError::write_log_entry(*entry.get_log_id(), &e))?;
413 log.insert(entry.log_id.index, s);
414 }
415 Ok(())
416 }
417
418 #[tracing::instrument(level = "trace", skip(self, entries))]
419 async fn apply_to_state_machine(
420 &mut self,
421 entries: &[Entry<TypeConfig>],
422 ) -> Result<Vec<ClientResponse>, StorageError<MemNodeId>> {
423 let mut res = Vec::with_capacity(entries.len());
424
425 let mut sm = self.sm.write().await;
426
427 for entry in entries {
428 tracing::debug!(%entry.log_id, "replicate to sm");
429
430 sm.last_applied_log = Some(entry.log_id);
431
432 match entry.payload {
433 EntryPayload::Blank => res.push(ClientResponse(None)),
434 EntryPayload::Normal(ref data) => {
435 if let Some((serial, r)) = sm.client_serial_responses.get(&data.client) {
436 if serial == &data.serial {
437 res.push(ClientResponse(r.clone()));
438 continue;
439 }
440 }
441 let previous = sm.client_status.insert(data.client.clone(), data.status.clone());
442 sm.client_serial_responses.insert(data.client.clone(), (data.serial, previous.clone()));
443 res.push(ClientResponse(previous));
444 }
445 EntryPayload::Membership(ref mem) => {
446 sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
447 res.push(ClientResponse(None))
448 }
449 };
450 }
451 Ok(res)
452 }
453
454 #[tracing::instrument(level = "trace", skip(self))]
455 async fn begin_receiving_snapshot(
456 &mut self,
457 ) -> Result<Box<<TypeConfig as RaftTypeConfig>::SnapshotData>, StorageError<MemNodeId>> {
458 Ok(Box::new(Cursor::new(Vec::new())))
459 }
460
461 #[tracing::instrument(level = "trace", skip(self, snapshot))]
462 async fn install_snapshot(
463 &mut self,
464 meta: &SnapshotMeta<MemNodeId, ()>,
465 snapshot: Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
466 ) -> Result<(), StorageError<MemNodeId>> {
467 tracing::info!(
468 { snapshot_size = snapshot.get_ref().len() },
469 "decoding snapshot for installation"
470 );
471
472 let new_snapshot = MemStoreSnapshot {
473 meta: meta.clone(),
474 data: snapshot.into_inner(),
475 };
476
477 {
478 let t = &new_snapshot.data;
479 let y = std::str::from_utf8(t).unwrap();
480 tracing::debug!("SNAP META:{:?}", meta);
481 tracing::debug!("JSON SNAP DATA:{}", y);
482 }
483
484 {
486 let new_sm: MemStoreStateMachine = serde_json::from_slice(&new_snapshot.data)
487 .map_err(|e| StorageIOError::read_snapshot(Some(new_snapshot.meta.signature()), &e))?;
488 let mut sm = self.sm.write().await;
489 *sm = new_sm;
490 }
491
492 let mut current_snapshot = self.current_snapshot.write().await;
494 *current_snapshot = Some(new_snapshot);
495 Ok(())
496 }
497
498 #[tracing::instrument(level = "trace", skip(self))]
499 async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<TypeConfig>>, StorageError<MemNodeId>> {
500 match &*self.current_snapshot.read().await {
501 Some(snapshot) => {
502 let data = snapshot.data.clone();
503 Ok(Some(Snapshot {
504 meta: snapshot.meta.clone(),
505 snapshot: Box::new(Cursor::new(data)),
506 }))
507 }
508 None => Ok(None),
509 }
510 }
511
512 type LogReader = Self;
513 type SnapshotBuilder = Self;
514
515 async fn get_log_reader(&mut self) -> Self::LogReader {
516 self.clone()
517 }
518
519 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
520 self.clone()
521 }
522}