asteroid_mq/protocol/node/raft/
state_machine.rs1pub mod node;
2pub mod topic;
3
4use std::{
5 io::{self, Cursor},
6 sync::{
7 atomic::{AtomicU64, Ordering},
8 Arc,
9 },
10};
11
12use asteroid_mq_model::codec::BINCODE_CONFIG;
13use node::NodeData;
14use openraft::{
15 storage::RaftStateMachine, EntryPayload, LogId, RaftSnapshotBuilder, RaftTypeConfig, Snapshot,
16 SnapshotMeta, StorageError, StoredMembership,
17};
18use tokio::sync::RwLock;
19
20use crate::{
21 prelude::NodeId,
22 protocol::node::{raft::proposal::ProposalContext, NodeRef},
23};
24
25use super::{raft_node::TcpNode, response::RaftResponse, TypeConfig};
26#[derive(Debug)]
27pub struct StoredSnapshot {
28 pub meta: SnapshotMeta<NodeId, TcpNode>,
29
30 pub data: Vec<u8>,
32}
33#[derive(Debug, Clone, Default)]
34pub struct StateMachineData<C: RaftTypeConfig> {
35 pub last_applied_log: Option<LogId<C::NodeId>>,
36
37 pub last_membership: StoredMembership<C::NodeId, C::Node>,
38
39 pub node: NodeData,
40}
41
42#[derive(Debug)]
45pub struct StateMachineStore {
46 pub state_machine: RwLock<StateMachineData<TypeConfig>>,
48
49 snapshot_idx: AtomicU64,
55
56 current_snapshot: RwLock<Option<StoredSnapshot>>,
58 node_ref: NodeRef,
59}
60
61impl StateMachineStore {
62 pub fn new(node_ref: NodeRef) -> Self {
63 Self {
64 state_machine: RwLock::new(StateMachineData::default()),
65 snapshot_idx: AtomicU64::new(0),
66 current_snapshot: RwLock::new(None),
67 node_ref,
68 }
69 }
70 #[cfg(test)]
71 pub(crate) unsafe fn new_uninitialized() -> Self {
72 Self {
73 state_machine: RwLock::new(StateMachineData::default()),
74 snapshot_idx: AtomicU64::new(0),
75 current_snapshot: RwLock::new(None),
76 node_ref: NodeRef::default(),
77 }
78 }
79}
80impl RaftSnapshotBuilder<TypeConfig> for Arc<StateMachineStore> {
81 #[tracing::instrument(level = "trace", skip(self))]
82 async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
83 let state_machine = self.state_machine.read().await;
85 let snapshot = &state_machine.node;
86
87 let last_applied_log = state_machine.last_applied_log;
88 let last_membership = state_machine.last_membership.clone();
89
90 let mut current_snapshot = self.current_snapshot.write().await;
93
94 let snapshot_idx = self.snapshot_idx.fetch_add(1, Ordering::Relaxed) + 1;
95 let snapshot_id = if let Some(last) = last_applied_log {
96 format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx)
97 } else {
98 format!("--{}", snapshot_idx)
99 };
100
101 let meta = SnapshotMeta {
102 last_log_id: last_applied_log,
103 last_membership,
104 snapshot_id,
105 };
106 let bytes = bincode::serde::encode_to_vec(snapshot, BINCODE_CONFIG).unwrap();
107 let stored = StoredSnapshot {
108 meta: meta.clone(),
109 data: bytes.clone(),
110 };
111 *current_snapshot = Some(stored);
112 drop(state_machine);
113 Ok(Snapshot {
114 meta,
115 snapshot: Box::new(Cursor::new(bytes)),
116 })
117 }
118}
119
120impl RaftStateMachine<TypeConfig> for Arc<StateMachineStore> {
121 type SnapshotBuilder = Arc<StateMachineStore>;
122 async fn applied_state(
123 &mut self,
124 ) -> Result<
125 (
126 Option<LogId<<TypeConfig as RaftTypeConfig>::NodeId>>,
127 StoredMembership<
128 <TypeConfig as RaftTypeConfig>::NodeId,
129 <TypeConfig as RaftTypeConfig>::Node,
130 >,
131 ),
132 StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
133 > {
134 let state_machine = self.state_machine.read().await;
135 Ok((
136 state_machine.last_applied_log,
137 state_machine.last_membership.clone(),
138 ))
139 }
140 #[tracing::instrument(name = "apply", skip_all)]
141 async fn apply<I>(
142 &mut self,
143 entries: I,
144 ) -> Result<
145 Vec<<TypeConfig as RaftTypeConfig>::R>,
146 StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
147 >
148 where
149 I: IntoIterator<Item = <TypeConfig as RaftTypeConfig>::Entry> + openraft::OptionalSend,
150 I::IntoIter: openraft::OptionalSend,
151 {
152 let mut sm = self.state_machine.write().await;
153 let mut res = Vec::new(); for entry in entries {
155 sm.last_applied_log = Some(entry.log_id);
156 match entry.payload {
157 EntryPayload::Blank => res.push(RaftResponse { result: Ok(()) }),
158 EntryPayload::Normal(ref proposal) => {
159 tracing::debug!(?proposal, "applying proposal to state machine");
160 let Some(node) = self.node_ref.upgrade() else {
161 res.push(RaftResponse { result: Err(()) });
162 continue;
163 };
164 let context = ProposalContext::new(node);
165 match proposal {
166 crate::protocol::node::raft::proposal::Proposal::DelegateMessage(
167 delegate_message,
168 ) => {
169 sm.node
170 .apply_delegate_message(delegate_message.clone(), context);
171 res.push(RaftResponse { result: Ok(()) })
172 }
173 crate::protocol::node::raft::proposal::Proposal::SetState(set_state) => {
174 sm.node.apply_set_state(set_state.clone(), context);
175 res.push(RaftResponse { result: Ok(()) })
176 }
177 crate::protocol::node::raft::proposal::Proposal::LoadTopic(load_topic) => {
178 sm.node.apply_load_topic(load_topic.clone(), context);
179 tracing::debug!(?load_topic, "topic loaded");
180 res.push(RaftResponse { result: Ok(()) })
181 }
182 crate::protocol::node::raft::proposal::Proposal::UnloadTopic(
183 unload_topic,
184 ) => {
185 sm.node.apply_unload_topic(unload_topic.clone());
186 res.push(RaftResponse { result: Ok(()) })
187 }
188 crate::protocol::node::raft::proposal::Proposal::EpOnline(ep_online) => {
189 sm.node.apply_ep_online(ep_online.clone(), context);
190 res.push(RaftResponse { result: Ok(()) })
191 }
192 crate::protocol::node::raft::proposal::Proposal::EpOffline(ep_offline) => {
193 sm.node.apply_ep_offline(ep_offline.clone(), context);
194 res.push(RaftResponse { result: Ok(()) })
195 }
196 crate::protocol::node::raft::proposal::Proposal::EpInterest(
197 ep_interest,
198 ) => {
199 sm.node.apply_ep_interest(ep_interest.clone(), context);
200 res.push(RaftResponse { result: Ok(()) })
201 }
202 crate::protocol::node::raft::proposal::Proposal::AckFinished(
203 ack_finished,
204 ) => {
205 sm.node.apply_ack_finished(ack_finished.clone(), context);
206 res.push(RaftResponse { result: Ok(()) })
207 }
208 }
209 }
210 EntryPayload::Membership(ref mem) => {
211 sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
212 res.push(RaftResponse { result: Ok(()) })
213 }
214 };
215 }
216 Ok(res)
217 }
218
219 async fn begin_receiving_snapshot(
220 &mut self,
221 ) -> Result<
222 Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
223 StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
224 > {
225 const SNAPSHOT_DEFAULT_CAPACITY: usize = 3 * (1 << 20);
227 tracing::info!("begin receiving snapshot");
228 Ok(Box::new(Cursor::new(Vec::with_capacity(
229 SNAPSHOT_DEFAULT_CAPACITY,
230 ))))
231 }
232
233 async fn get_current_snapshot(
234 &mut self,
235 ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<<TypeConfig as RaftTypeConfig>::NodeId>>
236 {
237 match &*self.current_snapshot.read().await {
238 Some(snapshot) => {
239 let bytes = snapshot.data.clone();
240 Ok(Some(Snapshot {
241 meta: snapshot.meta.clone(),
242 snapshot: Box::new(Cursor::new(bytes)),
243 }))
244 }
245 None => Ok(None),
246 }
247 }
248
249 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
250 self.clone()
251 }
252
253 async fn install_snapshot(
254 &mut self,
255 meta: &SnapshotMeta<
256 <TypeConfig as RaftTypeConfig>::NodeId,
257 <TypeConfig as RaftTypeConfig>::Node,
258 >,
259 mut snapshot: Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
260 ) -> Result<(), StorageError<<TypeConfig as RaftTypeConfig>::NodeId>> {
261 let id = self.node_ref.upgrade().map(|node| node.id());
262
263 tracing::info!(
264 { snapshot_size = snapshot.get_ref().len(), ?id },
265 "decoding snapshot for installation"
266 );
267 let (new_data, size) =
268 bincode::serde::decode_from_slice::<NodeData, _>(snapshot.get_ref(), BINCODE_CONFIG)
269 .map_err(|e| {
270 StorageError::from_io_error(
271 openraft::ErrorSubject::Snapshot(None),
272 openraft::ErrorVerb::Read,
273 io::Error::new(io::ErrorKind::InvalidData, e),
274 )
275 })?;
276 snapshot.set_position(size as u64);
277 let new_snapshot = StoredSnapshot {
278 meta: meta.clone(),
279 data: snapshot.into_inner(),
280 };
281
282 let mut state_machine = self.state_machine.write().await;
285 state_machine.last_membership = new_snapshot.meta.last_membership.clone();
286 state_machine.last_applied_log = new_snapshot.meta.last_log_id;
287 state_machine.node = new_data;
288
289 if let Some(node) = self.node_ref.upgrade() {
291 tracing::info!(?id, "installed, ready to flush: {:#?}", state_machine.node);
292 for (topic_code, topic) in &mut state_machine.node.topics {
293 let mut ctx = ProposalContext::new(node.clone());
294 ctx.set_topic_code(topic_code.clone());
295 topic
296 .queue
297 .flush_ack(&mut ctx, topic.queue.pending_ack.keys().copied());
298 }
299 };
300
301 let mut current_snapshot = self.current_snapshot.write().await;
304
305 *current_snapshot = Some(new_snapshot);
307 drop(current_snapshot);
308 drop(state_machine);
309
310 Ok(())
311 }
312}