1use std::io::Cursor;
2use std::sync::{Arc, Mutex};
3
4use openraft::storage::{RaftSnapshotBuilder, RaftStateMachine};
5use openraft::{
6 BasicNode, Entry, EntryPayload, LogId, OptionalSend, Snapshot, SnapshotMeta, StorageError,
7 StoredMembership,
8};
9use serde::{Deserialize, Serialize};
10
11use super::{dec, enc, sread, swrite};
12use crate::queue::Queue;
13use crate::raft::TypeConfig;
14use crate::storage::Storage;
15use crate::types::NodeId;
16use crate::{AppResponse, RedbStore};
17
/// Storage key holding the single persisted snapshot (metadata plus encoded state).
///
/// Read on startup by `StateMachineStore::new` and by `get_current_snapshot`; written by
/// `persist_snapshot` whenever a snapshot is built or installed.
const KEY_SNAPSHOT: &str = "sm:snapshot";
19
/// The complete replicated state: this whole struct is what a snapshot encodes (`enc`)
/// and what `new` / `install_snapshot` decode back (`dec`).
///
/// NOTE(review): if `enc`/`dec` use a positional encoding, field names/order are part of
/// the persisted snapshot format — confirm before renaming or reordering fields.
#[derive(Default, Clone, Serialize, Deserialize)]
struct StateMachineData {
    /// Log id of the most recently applied entry; `None` until something is applied.
    last_applied: Option<LogId<NodeId>>,
    /// Membership taken from the most recently applied membership entry.
    last_membership: StoredMembership<NodeId, BasicNode>,
    /// Application state that `EntryPayload::Normal` requests are applied to.
    queue: Queue,
}
26
/// Record persisted under `KEY_SNAPSHOT`: the Raft snapshot metadata together with the
/// encoded `StateMachineData` bytes, written in a single `put` by `persist_snapshot`.
#[derive(Clone, Serialize, Deserialize)]
struct StoredSnapshot {
    /// Raft metadata (last included log id, membership, snapshot id).
    meta: SnapshotMeta<NodeId, BasicNode>,
    /// `enc`-encoded `StateMachineData`; also served verbatim as the snapshot stream.
    data: Vec<u8>,
}
32
33struct Inner {
34 data: StateMachineData,
35 snapshot_seq: u64,
36}
37
/// Raft state machine, generic over its storage backend `S` (defaults to `RedbStore`).
///
/// Clones share the same in-memory state and database handle (see the `Clone` impl).
pub struct StateMachineStore<S = RedbStore> {
    // Live state plus snapshot counter, shared with any `SnapshotBuilder` handed out.
    inner: Arc<Mutex<Inner>>,
    // Backend used to persist and reload the snapshot under `KEY_SNAPSHOT`.
    db: Arc<S>,
}
42
43impl<S> Clone for StateMachineStore<S> {
44 fn clone(&self) -> Self {
45 Self {
46 inner: Arc::clone(&self.inner),
47 db: Arc::clone(&self.db),
48 }
49 }
50}
51
52impl<S: Storage> StateMachineStore<S> {
53 pub fn new(db: Arc<S>) -> Result<Self, StorageError<NodeId>> {
54 let data = match db.get(KEY_SNAPSHOT).map_err(sread)? {
55 Some(bytes) => {
56 let stored: StoredSnapshot = dec(&bytes)?;
57 dec(&stored.data)?
58 }
59 None => StateMachineData::default(),
60 };
61 Ok(Self {
62 inner: Arc::new(Mutex::new(Inner {
63 data,
64 snapshot_seq: 0,
65 })),
66 db,
67 })
68 }
69
70 pub fn rate_config(&self, topic: &str) -> Option<(u64, u32)> {
71 self.inner.lock().unwrap().data.queue.rate_config(topic)
72 }
73
74 pub fn metrics(&self) -> crate::queue::QueueMetrics {
75 self.inner.lock().unwrap().data.queue.metrics()
76 }
77
78 pub fn has_deliverable(&self, topic: &str, group: &str, now_ms: u64) -> bool {
79 self.inner
80 .lock()
81 .unwrap()
82 .data
83 .queue
84 .has_deliverable(topic, group, now_ms)
85 }
86}
87
88fn persist_snapshot<S: Storage>(
89 db: &S,
90 stored: &StoredSnapshot,
91) -> Result<(), StorageError<NodeId>> {
92 let bytes = enc(stored)?;
93 db.put(KEY_SNAPSHOT, &bytes).map_err(swrite)
94}
95
96impl<S: Storage> RaftStateMachine<TypeConfig> for StateMachineStore<S> {
97 type SnapshotBuilder = SnapshotBuilder<S>;
98
99 async fn applied_state(
100 &mut self,
101 ) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, BasicNode>), StorageError<NodeId>>
102 {
103 let inner = self.inner.lock().unwrap();
104 Ok((inner.data.last_applied, inner.data.last_membership.clone()))
105 }
106
107 async fn apply<I>(&mut self, entries: I) -> Result<Vec<AppResponse>, StorageError<NodeId>>
108 where
109 I: IntoIterator<Item = Entry<TypeConfig>> + OptionalSend,
110 I::IntoIter: OptionalSend,
111 {
112 let mut inner = self.inner.lock().unwrap();
113 let mut responses = Vec::new();
114 for entry in entries {
115 let log_id = entry.log_id;
116 inner.data.last_applied = Some(log_id);
117 let response = match entry.payload {
118 EntryPayload::Blank => AppResponse::NoOp,
119 EntryPayload::Membership(m) => {
120 inner.data.last_membership = StoredMembership::new(Some(log_id), m);
121 AppResponse::NoOp
122 }
123 EntryPayload::Normal(req) => inner.data.queue.apply(req),
124 };
125 responses.push(response);
126 }
127 Ok(responses)
128 }
129
130 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
131 SnapshotBuilder {
132 inner: Arc::clone(&self.inner),
133 db: Arc::clone(&self.db),
134 }
135 }
136
137 async fn begin_receiving_snapshot(
138 &mut self,
139 ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
140 Ok(Box::new(Cursor::new(Vec::new())))
141 }
142
143 async fn install_snapshot(
144 &mut self,
145 meta: &SnapshotMeta<NodeId, BasicNode>,
146 snapshot: Box<Cursor<Vec<u8>>>,
147 ) -> Result<(), StorageError<NodeId>> {
148 let bytes = (*snapshot).into_inner();
149 let data: StateMachineData = dec(&bytes)?;
150 let stored = StoredSnapshot {
151 meta: meta.clone(),
152 data: bytes,
153 };
154 persist_snapshot(self.db.as_ref(), &stored)?;
155 self.inner.lock().unwrap().data = data;
156 Ok(())
157 }
158
159 async fn get_current_snapshot(
160 &mut self,
161 ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<NodeId>> {
162 match self.db.get(KEY_SNAPSHOT).map_err(sread)? {
163 Some(bytes) => {
164 let stored: StoredSnapshot = dec(&bytes)?;
165 Ok(Some(Snapshot {
166 meta: stored.meta,
167 snapshot: Box::new(Cursor::new(stored.data)),
168 }))
169 }
170 None => Ok(None),
171 }
172 }
173}
174
175pub struct SnapshotBuilder<S = RedbStore> {
176 inner: Arc<Mutex<Inner>>,
177 db: Arc<S>,
178}
179
180impl<S: Storage> RaftSnapshotBuilder<TypeConfig> for SnapshotBuilder<S> {
181 async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
182 let (meta, bytes) = {
183 let mut inner = self.inner.lock().unwrap();
184 let bytes = enc(&inner.data)?;
185 inner.snapshot_seq += 1;
186 let last = inner.data.last_applied;
187 let snapshot_id = match &last {
188 Some(log_id) => format!("{}-{}", log_id.index, inner.snapshot_seq),
189 None => format!("none-{}", inner.snapshot_seq),
190 };
191 let meta = SnapshotMeta {
192 last_log_id: last,
193 last_membership: inner.data.last_membership.clone(),
194 snapshot_id,
195 };
196 (meta, bytes)
197 };
198 let stored = StoredSnapshot { meta, data: bytes };
199 persist_snapshot(self.db.as_ref(), &stored)?;
200 Ok(Snapshot {
201 meta: stored.meta,
202 snapshot: Box::new(Cursor::new(stored.data)),
203 })
204 }
205}