Skip to main content

asteroid_mq/protocol/node/
raft.rs

1use std::sync::{Arc, OnceLock};
2
3pub mod cluster;
4pub mod log_storage;
5pub mod network;
6pub mod network_factory;
7pub mod proposal;
8pub mod raft_node;
9pub mod response;
10pub mod state_machine;
11use network_factory::{RaftNodeInfo, TcpNetworkService};
12use openraft::Raft;
13use proposal::Proposal;
14use raft_node::TcpNode;
15use response::RaftResponse;
16use tokio_util::sync::CancellationToken;
17
18use super::NodeId;
19
/// Zero-sized marker type used as this module's openraft type configuration.
///
/// The single private unit field prevents construction through a struct
/// literal outside this module; use [`Default::default`] (or copy an
/// existing value) to obtain one.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeConfig {
    // Carries no data; exists only to keep the struct literal private.
    _private: (),
}
24
25impl openraft::RaftTypeConfig for TypeConfig {
26    type D = Proposal;
27    type R = RaftResponse;
28    type NodeId = NodeId;
29    type Node = TcpNode;
30    type Entry = openraft::Entry<TypeConfig>;
31    type SnapshotData = std::io::Cursor<Vec<u8>>;
32    type AsyncRuntime = openraft::TokioRuntime;
33    type Responder = openraft::raft::responder::OneshotResponder<Self>;
34}
/// Cloneable handle to a `Raft<TypeConfig>` that may not have been set yet.
///
/// Both fields are behind `Arc`, so every clone of the handle shares the
/// same write-once slot and the same notifier.
#[derive(Clone)]
pub struct MaybeLoadingRaft {
    /// Write-once slot holding the Raft instance once it has been provided.
    loading: Arc<OnceLock<Raft<TypeConfig>>>,
    /// Notifier used to wake tasks waiting for the slot to be filled.
    signal: Arc<tokio::sync::Notify>,
}
40
41impl std::fmt::Debug for MaybeLoadingRaft {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        let loaded = self.loading.get().is_some();
44        f.debug_struct("MaybeLoadingRaft")
45            .field("loaded", &loaded)
46            .finish()
47    }
48}
49
50impl Default for MaybeLoadingRaft {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl MaybeLoadingRaft {
57    pub fn new() -> Self {
58        Self {
59            loading: Default::default(),
60            signal: tokio::sync::Notify::new().into(),
61        }
62    }
63    pub fn set(&self, raft: Raft<TypeConfig>) {
64        if self.loading.set(raft).is_ok() {
65            self.signal.notify_waiters();
66        }
67    }
68    pub async fn get(&self) -> Raft<TypeConfig> {
69        loop {
70            if let Some(raft) = self.loading.get() {
71                return raft.clone();
72            } else {
73                self.signal.notified().await;
74            }
75        }
76    }
77    pub fn get_opt(&self) -> Option<Raft<TypeConfig>> {
78        self.loading.get().cloned()
79    }
80    pub fn net_work_service(
81        &self,
82        id: NodeId,
83        node: TcpNode,
84        ct: CancellationToken,
85    ) -> TcpNetworkService {
86        TcpNetworkService::new(RaftNodeInfo { id, node }, self.clone(), ct)
87    }
88}