//! Raft state machine backed by `ClusterStore`.

3use std::io::Cursor;
4use std::sync::Arc;
5
6use openraft::storage::{RaftStateMachine, Snapshot};
7use openraft::{
8    Entry, EntryPayload, ErrorSubject, ErrorVerb, LogId, RaftSnapshotBuilder, SnapshotMeta,
9    StorageError, StoredMembership,
10};
11use tokio::sync::Mutex;
12use tracing::warn;
13
14use super::type_config::OrcaTypeConfig;
15use crate::store::{ClusterStore, RaftSnapshot};
16
17type C = OrcaTypeConfig;
18
19/// Helper to build a `StorageError` from an error reading the state machine.
20fn sm_read_err(e: impl std::fmt::Display) -> StorageError<u64> {
21    StorageError::from_io_error(
22        ErrorSubject::StateMachine,
23        ErrorVerb::Read,
24        std::io::Error::other(e.to_string()),
25    )
26}
27
28/// Raft state machine wrapping the persistent `ClusterStore`.
29pub struct StateMachine {
30    store: Arc<ClusterStore>,
31    last_applied: Arc<Mutex<Option<LogId<u64>>>>,
32    last_membership: Arc<Mutex<StoredMembership<u64, openraft::BasicNode>>>,
33    snapshot: Arc<Mutex<Option<StoredSnapshot>>>,
34}
35
36/// A stored snapshot including metadata and serialized data.
37struct StoredSnapshot {
38    meta: SnapshotMeta<u64, openraft::BasicNode>,
39    data: Vec<u8>,
40}
41
42impl StateMachine {
43    pub fn new(store: Arc<ClusterStore>) -> Self {
44        Self {
45            store,
46            last_applied: Arc::new(Mutex::new(None)),
47            last_membership: Arc::new(Mutex::new(StoredMembership::default())),
48            snapshot: Arc::new(Mutex::new(None)),
49        }
50    }
51
52    async fn build_snapshot_impl(&self) -> Result<Snapshot<C>, StorageError<u64>> {
53        let snap = self.store.snapshot().map_err(|e| sm_read_err(&e))?;
54        let data = serde_json::to_vec(&snap).map_err(|e| sm_read_err(&e))?;
55
56        let last_applied = *self.last_applied.lock().await;
57        let last_membership = self.last_membership.lock().await.clone();
58
59        let meta = SnapshotMeta {
60            last_log_id: last_applied,
61            last_membership,
62            snapshot_id: format!("snap-{}", chrono::Utc::now().timestamp_millis()),
63        };
64
65        {
66            let mut s = self.snapshot.lock().await;
67            *s = Some(StoredSnapshot {
68                meta: meta.clone(),
69                data: data.clone(),
70            });
71        }
72
73        Ok(Snapshot {
74            meta,
75            snapshot: Box::new(Cursor::new(data)),
76        })
77    }
78}
79
80impl RaftSnapshotBuilder<C> for Arc<StateMachine> {
81    async fn build_snapshot(&mut self) -> Result<Snapshot<C>, StorageError<u64>> {
82        self.build_snapshot_impl().await
83    }
84}
85
86impl RaftStateMachine<C> for StateMachine {
87    type SnapshotBuilder = Arc<Self>;
88
89    async fn applied_state(
90        &mut self,
91    ) -> Result<
92        (
93            Option<LogId<u64>>,
94            StoredMembership<u64, openraft::BasicNode>,
95        ),
96        StorageError<u64>,
97    > {
98        let applied = *self.last_applied.lock().await;
99        let membership = self.last_membership.lock().await.clone();
100        Ok((applied, membership))
101    }
102
103    async fn apply<I>(&mut self, entries: I) -> Result<Vec<()>, StorageError<u64>>
104    where
105        I: IntoIterator<Item = Entry<C>> + Send,
106        I::IntoIter: Send,
107    {
108        let mut results = Vec::new();
109        for entry in entries {
110            *self.last_applied.lock().await = Some(entry.log_id);
111
112            match &entry.payload {
113                EntryPayload::Normal(raft_entry) => {
114                    if let Err(e) = self.store.apply(raft_entry) {
115                        warn!("Failed to apply entry: {e}");
116                    }
117                }
118                EntryPayload::Membership(m) => {
119                    *self.last_membership.lock().await =
120                        StoredMembership::new(Some(entry.log_id), m.clone());
121                }
122                EntryPayload::Blank => {}
123            }
124            results.push(());
125        }
126        Ok(results)
127    }
128
129    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
130        Arc::new(Self {
131            store: self.store.clone(),
132            last_applied: self.last_applied.clone(),
133            last_membership: self.last_membership.clone(),
134            snapshot: self.snapshot.clone(),
135        })
136    }
137
138    async fn begin_receiving_snapshot(
139        &mut self,
140    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
141        Ok(Box::new(Cursor::new(Vec::new())))
142    }
143
144    async fn install_snapshot(
145        &mut self,
146        meta: &SnapshotMeta<u64, openraft::BasicNode>,
147        snapshot: Box<Cursor<Vec<u8>>>,
148    ) -> Result<(), StorageError<u64>> {
149        let data = snapshot.into_inner();
150        let snap: RaftSnapshot = serde_json::from_slice(&data).map_err(|e| sm_read_err(&e))?;
151
152        self.store
153            .restore_from_snapshot(&snap)
154            .map_err(|e| sm_read_err(&e))?;
155
156        *self.last_applied.lock().await = meta.last_log_id;
157        *self.last_membership.lock().await = meta.last_membership.clone();
158        *self.snapshot.lock().await = Some(StoredSnapshot {
159            meta: meta.clone(),
160            data,
161        });
162        Ok(())
163    }
164
165    async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<C>>, StorageError<u64>> {
166        let guard = self.snapshot.lock().await;
167        Ok(guard.as_ref().map(|s| Snapshot {
168            meta: s.meta.clone(),
169            snapshot: Box::new(Cursor::new(s.data.clone())),
170        }))
171    }
172}