// asteroid_mq/protocol/node/raft/state_machine.rs

1pub mod node;
2pub mod topic;
3
4use std::{
5    io::{self, Cursor},
6    sync::{
7        atomic::{AtomicU64, Ordering},
8        Arc,
9    },
10};
11
12use asteroid_mq_model::codec::BINCODE_CONFIG;
13use node::NodeData;
14use openraft::{
15    storage::RaftStateMachine, EntryPayload, LogId, RaftSnapshotBuilder, RaftTypeConfig, Snapshot,
16    SnapshotMeta, StorageError, StoredMembership,
17};
18use tokio::sync::RwLock;
19
20use crate::{
21    prelude::NodeId,
22    protocol::node::{raft::proposal::ProposalContext, NodeRef},
23};
24
25use super::{raft_node::TcpNode, response::RaftResponse, TypeConfig};
26#[derive(Debug)]
27pub struct StoredSnapshot {
28    pub meta: SnapshotMeta<NodeId, TcpNode>,
29
30    /// The data of the state machine at the time of this snapshot.
31    pub data: Vec<u8>,
32}
33#[derive(Debug, Clone, Default)]
34pub struct StateMachineData<C: RaftTypeConfig> {
35    pub last_applied_log: Option<LogId<C::NodeId>>,
36
37    pub last_membership: StoredMembership<C::NodeId, C::Node>,
38
39    pub node: NodeData,
40}
41
42/// Defines a state machine for the Raft cluster. This state machine represents a copy of the
43/// data for this node. Additionally, it is responsible for storing the last snapshot of the data.
44#[derive(Debug)]
45pub struct StateMachineStore {
46    /// The Raft state machine.
47    pub state_machine: RwLock<StateMachineData<TypeConfig>>,
48
49    /// Used in identifier for snapshot.
50    ///
51    /// Note that concurrently created snapshots and snapshots created on different nodes
52    /// are not guaranteed to have sequential `snapshot_idx` values, but this does not matter for
53    /// correctness.
54    snapshot_idx: AtomicU64,
55
56    /// The last received snapshot.
57    current_snapshot: RwLock<Option<StoredSnapshot>>,
58    node_ref: NodeRef,
59}
60
61impl StateMachineStore {
62    pub fn new(node_ref: NodeRef) -> Self {
63        Self {
64            state_machine: RwLock::new(StateMachineData::default()),
65            snapshot_idx: AtomicU64::new(0),
66            current_snapshot: RwLock::new(None),
67            node_ref,
68        }
69    }
70    #[cfg(test)]
71    pub(crate) unsafe fn new_uninitialized() -> Self {
72        Self {
73            state_machine: RwLock::new(StateMachineData::default()),
74            snapshot_idx: AtomicU64::new(0),
75            current_snapshot: RwLock::new(None),
76            node_ref: NodeRef::default(),
77        }
78    }
79}
80impl RaftSnapshotBuilder<TypeConfig> for Arc<StateMachineStore> {
81    #[tracing::instrument(level = "trace", skip(self))]
82    async fn build_snapshot(&mut self) -> Result<Snapshot<TypeConfig>, StorageError<NodeId>> {
83        // Serialize the data of the state machine.
84        let state_machine = self.state_machine.read().await;
85        let snapshot = &state_machine.node;
86
87        let last_applied_log = state_machine.last_applied_log;
88        let last_membership = state_machine.last_membership.clone();
89
90        // Lock the current snapshot before releasing the lock on the state machine, to avoid a race
91        // condition on the written snapshot
92        let mut current_snapshot = self.current_snapshot.write().await;
93
94        let snapshot_idx = self.snapshot_idx.fetch_add(1, Ordering::Relaxed) + 1;
95        let snapshot_id = if let Some(last) = last_applied_log {
96            format!("{}-{}-{}", last.leader_id, last.index, snapshot_idx)
97        } else {
98            format!("--{}", snapshot_idx)
99        };
100
101        let meta = SnapshotMeta {
102            last_log_id: last_applied_log,
103            last_membership,
104            snapshot_id,
105        };
106        let bytes = bincode::serde::encode_to_vec(snapshot, BINCODE_CONFIG).unwrap();
107        let stored = StoredSnapshot {
108            meta: meta.clone(),
109            data: bytes.clone(),
110        };
111        *current_snapshot = Some(stored);
112        drop(state_machine);
113        Ok(Snapshot {
114            meta,
115            snapshot: Box::new(Cursor::new(bytes)),
116        })
117    }
118}
119
120impl RaftStateMachine<TypeConfig> for Arc<StateMachineStore> {
121    type SnapshotBuilder = Arc<StateMachineStore>;
122    async fn applied_state(
123        &mut self,
124    ) -> Result<
125        (
126            Option<LogId<<TypeConfig as RaftTypeConfig>::NodeId>>,
127            StoredMembership<
128                <TypeConfig as RaftTypeConfig>::NodeId,
129                <TypeConfig as RaftTypeConfig>::Node,
130            >,
131        ),
132        StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
133    > {
134        let state_machine = self.state_machine.read().await;
135        Ok((
136            state_machine.last_applied_log,
137            state_machine.last_membership.clone(),
138        ))
139    }
140    #[tracing::instrument(name = "apply", skip_all)]
141    async fn apply<I>(
142        &mut self,
143        entries: I,
144    ) -> Result<
145        Vec<<TypeConfig as RaftTypeConfig>::R>,
146        StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
147    >
148    where
149        I: IntoIterator<Item = <TypeConfig as RaftTypeConfig>::Entry> + openraft::OptionalSend,
150        I::IntoIter: openraft::OptionalSend,
151    {
152        let mut sm = self.state_machine.write().await;
153        let mut res = Vec::new(); //No `with_capacity`; do not know `len` of iterator
154        for entry in entries {
155            sm.last_applied_log = Some(entry.log_id);
156            match entry.payload {
157                EntryPayload::Blank => res.push(RaftResponse { result: Ok(()) }),
158                EntryPayload::Normal(ref proposal) => {
159                    tracing::debug!(?proposal, "applying proposal to state machine");
160                    let Some(node) = self.node_ref.upgrade() else {
161                        res.push(RaftResponse { result: Err(()) });
162                        continue;
163                    };
164                    let context = ProposalContext::new(node);
165                    match proposal {
166                        crate::protocol::node::raft::proposal::Proposal::DelegateMessage(
167                            delegate_message,
168                        ) => {
169                            sm.node
170                                .apply_delegate_message(delegate_message.clone(), context);
171                            res.push(RaftResponse { result: Ok(()) })
172                        }
173                        crate::protocol::node::raft::proposal::Proposal::SetState(set_state) => {
174                            sm.node.apply_set_state(set_state.clone(), context);
175                            res.push(RaftResponse { result: Ok(()) })
176                        }
177                        crate::protocol::node::raft::proposal::Proposal::LoadTopic(load_topic) => {
178                            sm.node.apply_load_topic(load_topic.clone(), context);
179                            tracing::debug!(?load_topic, "topic loaded");
180                            res.push(RaftResponse { result: Ok(()) })
181                        }
182                        crate::protocol::node::raft::proposal::Proposal::UnloadTopic(
183                            unload_topic,
184                        ) => {
185                            sm.node.apply_unload_topic(unload_topic.clone());
186                            res.push(RaftResponse { result: Ok(()) })
187                        }
188                        crate::protocol::node::raft::proposal::Proposal::EpOnline(ep_online) => {
189                            sm.node.apply_ep_online(ep_online.clone(), context);
190                            res.push(RaftResponse { result: Ok(()) })
191                        }
192                        crate::protocol::node::raft::proposal::Proposal::EpOffline(ep_offline) => {
193                            sm.node.apply_ep_offline(ep_offline.clone(), context);
194                            res.push(RaftResponse { result: Ok(()) })
195                        }
196                        crate::protocol::node::raft::proposal::Proposal::EpInterest(
197                            ep_interest,
198                        ) => {
199                            sm.node.apply_ep_interest(ep_interest.clone(), context);
200                            res.push(RaftResponse { result: Ok(()) })
201                        }
202                        crate::protocol::node::raft::proposal::Proposal::AckFinished(
203                            ack_finished,
204                        ) => {
205                            sm.node.apply_ack_finished(ack_finished.clone(), context);
206                            res.push(RaftResponse { result: Ok(()) })
207                        }
208                    }
209                }
210                EntryPayload::Membership(ref mem) => {
211                    sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
212                    res.push(RaftResponse { result: Ok(()) })
213                }
214            };
215        }
216        Ok(res)
217    }
218
219    async fn begin_receiving_snapshot(
220        &mut self,
221    ) -> Result<
222        Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
223        StorageError<<TypeConfig as RaftTypeConfig>::NodeId>,
224    > {
225        // 3 Mb
226        const SNAPSHOT_DEFAULT_CAPACITY: usize = 3 * (1 << 20);
227        tracing::info!("begin receiving snapshot");
228        Ok(Box::new(Cursor::new(Vec::with_capacity(
229            SNAPSHOT_DEFAULT_CAPACITY,
230        ))))
231    }
232
233    async fn get_current_snapshot(
234        &mut self,
235    ) -> Result<Option<Snapshot<TypeConfig>>, StorageError<<TypeConfig as RaftTypeConfig>::NodeId>>
236    {
237        match &*self.current_snapshot.read().await {
238            Some(snapshot) => {
239                let bytes = snapshot.data.clone();
240                Ok(Some(Snapshot {
241                    meta: snapshot.meta.clone(),
242                    snapshot: Box::new(Cursor::new(bytes)),
243                }))
244            }
245            None => Ok(None),
246        }
247    }
248
249    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
250        self.clone()
251    }
252
253    async fn install_snapshot(
254        &mut self,
255        meta: &SnapshotMeta<
256            <TypeConfig as RaftTypeConfig>::NodeId,
257            <TypeConfig as RaftTypeConfig>::Node,
258        >,
259        mut snapshot: Box<<TypeConfig as RaftTypeConfig>::SnapshotData>,
260    ) -> Result<(), StorageError<<TypeConfig as RaftTypeConfig>::NodeId>> {
261        let id = self.node_ref.upgrade().map(|node| node.id());
262
263        tracing::info!(
264            { snapshot_size = snapshot.get_ref().len(), ?id },
265            "decoding snapshot for installation"
266        );
267        let (new_data, size) =
268            bincode::serde::decode_from_slice::<NodeData, _>(snapshot.get_ref(), BINCODE_CONFIG)
269                .map_err(|e| {
270                    StorageError::from_io_error(
271                        openraft::ErrorSubject::Snapshot(None),
272                        openraft::ErrorVerb::Read,
273                        io::Error::new(io::ErrorKind::InvalidData, e),
274                    )
275                })?;
276        snapshot.set_position(size as u64);
277        let new_snapshot = StoredSnapshot {
278            meta: meta.clone(),
279            data: snapshot.into_inner(),
280        };
281
282        // Update the state machine.
283
284        let mut state_machine = self.state_machine.write().await;
285        state_machine.last_membership = new_snapshot.meta.last_membership.clone();
286        state_machine.last_applied_log = new_snapshot.meta.last_log_id;
287        state_machine.node = new_data;
288
289        // flush, we can do this for it's immutable
290        if let Some(node) = self.node_ref.upgrade() {
291            tracing::info!(?id, "installed, ready to flush: {:#?}", state_machine.node);
292            for (topic_code, topic) in &mut state_machine.node.topics {
293                let mut ctx = ProposalContext::new(node.clone());
294                ctx.set_topic_code(topic_code.clone());
295                topic
296                    .queue
297                    .flush_ack(&mut ctx, topic.queue.pending_ack.keys().copied());
298            }
299        };
300
301        // Lock the current snapshot before releasing the lock on the state machine, to avoid a race
302        // condition on the written snapshot
303        let mut current_snapshot = self.current_snapshot.write().await;
304
305        // Update current snapshot.
306        *current_snapshot = Some(new_snapshot);
307        drop(current_snapshot);
308        drop(state_machine);
309
310        Ok(())
311    }
312}