asteroid_mq/protocol/node/
raft.rs1use std::sync::{Arc, OnceLock};
2
3pub mod cluster;
4pub mod log_storage;
5pub mod network;
6pub mod network_factory;
7pub mod proposal;
8pub mod raft_node;
9pub mod response;
10pub mod state_machine;
11use network_factory::{RaftNodeInfo, TcpNetworkService};
12use openraft::Raft;
13use proposal::Proposal;
14use raft_node::TcpNode;
15use response::RaftResponse;
16use tokio_util::sync::CancellationToken;
17
18use super::NodeId;
19
/// Marker type that selects this crate's concrete types for openraft.
///
/// It carries no data; the private unit field only prevents construction with a
/// struct literal outside this module tree (use `TypeConfig::default()` instead).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeConfig {
    _private: (),
}
24
25impl openraft::RaftTypeConfig for TypeConfig {
26 type D = Proposal;
27 type R = RaftResponse;
28 type NodeId = NodeId;
29 type Node = TcpNode;
30 type Entry = openraft::Entry<TypeConfig>;
31 type SnapshotData = std::io::Cursor<Vec<u8>>;
32 type AsyncRuntime = openraft::TokioRuntime;
33 type Responder = openraft::raft::responder::OneshotResponder<Self>;
34}
35#[derive(Clone)]
36pub struct MaybeLoadingRaft {
37 loading: Arc<OnceLock<Raft<TypeConfig>>>,
38 signal: Arc<tokio::sync::Notify>,
39}
40
41impl std::fmt::Debug for MaybeLoadingRaft {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 let loaded = self.loading.get().is_some();
44 f.debug_struct("MaybeLoadingRaft")
45 .field("loaded", &loaded)
46 .finish()
47 }
48}
49
50impl Default for MaybeLoadingRaft {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl MaybeLoadingRaft {
57 pub fn new() -> Self {
58 Self {
59 loading: Default::default(),
60 signal: tokio::sync::Notify::new().into(),
61 }
62 }
63 pub fn set(&self, raft: Raft<TypeConfig>) {
64 if self.loading.set(raft).is_ok() {
65 self.signal.notify_waiters();
66 }
67 }
68 pub async fn get(&self) -> Raft<TypeConfig> {
69 loop {
70 if let Some(raft) = self.loading.get() {
71 return raft.clone();
72 } else {
73 self.signal.notified().await;
74 }
75 }
76 }
77 pub fn get_opt(&self) -> Option<Raft<TypeConfig>> {
78 self.loading.get().cloned()
79 }
80 pub fn net_work_service(
81 &self,
82 id: NodeId,
83 node: TcpNode,
84 ct: CancellationToken,
85 ) -> TcpNetworkService {
86 TcpNetworkService::new(RaftNodeInfo { id, node }, self.clone(), ct)
87 }
88}