1use crate::{
2 config::TlsConfig, raft::logger::Logger, request::server_request_message::ServerRequestMsg,
3 ClusterJoinTicket, InitialRole, Peers, StableStorage,
4};
5use bincode::deserialize;
6use std::{net::ToSocketAddrs, ops::Deref, sync::Arc};
7use tokio::{
8 signal,
9 sync::{mpsc, oneshot},
10};
11
12use super::{
13 create_client,
14 error::{Error, Result},
15 raft_node::RaftNode,
16 raft_server::RaftServer,
17 raft_service::{self, ResultCode},
18 AbstractLogEntry, AbstractStateMachine, Config,
19};
20
21#[derive(Clone)]
25pub struct Raft<
26 LogEntry: AbstractLogEntry + 'static,
27 LogStorage: StableStorage + Send + Clone + 'static,
28 FSM: AbstractStateMachine + Clone + 'static,
29> {
30 pub raft_node: RaftNode<LogEntry, LogStorage, FSM>,
31 pub raft_server: RaftServer<LogEntry, LogStorage, FSM>,
32 pub tx_server: mpsc::Sender<ServerRequestMsg<LogEntry, LogStorage, FSM>>,
33 pub logger: Arc<dyn Logger>,
34}
35
36impl<
37 LogEntry: AbstractLogEntry + 'static,
38 LogStorage: StableStorage + Send + Clone + 'static,
39 FSM: AbstractStateMachine + Clone + 'static,
40 > Deref for Raft<LogEntry, LogStorage, FSM>
41{
42 type Target = RaftNode<LogEntry, LogStorage, FSM>;
43
44 fn deref(&self) -> &Self::Target {
45 &self.raft_node
46 }
47}
48
49impl<
50 LogEntry: AbstractLogEntry,
51 LogStorage: StableStorage + Send + Sync + Clone + 'static,
52 FSM: AbstractStateMachine + Send + Sync + Clone + 'static,
53 > Raft<LogEntry, LogStorage, FSM>
54{
55 pub fn bootstrap<A: ToSocketAddrs>(
59 node_id: u64,
60 raft_addr: A,
61 log_storage: LogStorage,
62 fsm: FSM,
63 config: Config,
64 logger: Arc<dyn Logger>,
65 ) -> Result<Self> {
66 logger.info(&format!("RaftNode bootstrapped. {:?}", config));
67
68 let raft_addr = raft_addr.to_socket_addrs()?.next().unwrap();
69 let mut should_be_leader = config.initial_peers.is_none();
70
71 if config.initial_peers.is_some() {
72 let leaders = config
73 .initial_peers
74 .clone()
75 .unwrap()
76 .inner
77 .into_iter()
78 .filter(|(_, peer)| peer.initial_role == InitialRole::Leader)
79 .map(|(key, _)| key)
80 .collect::<Vec<_>>();
81
82 assert!(leaders.len() < 2);
83 should_be_leader = leaders.contains(&node_id);
84 }
85
86 let (tx_server, rx_server) = mpsc::channel(100);
87 let raft_node = RaftNode::bootstrap(
88 node_id,
89 should_be_leader,
90 log_storage,
91 fsm,
92 config.clone(),
93 raft_addr,
94 logger.clone(),
95 tx_server.clone(),
96 rx_server,
97 )?;
98
99 let raft_server =
100 RaftServer::new(tx_server.clone(), raft_addr, config.clone(), logger.clone());
101
102 Ok(Self {
103 tx_server: tx_server.clone(),
104 raft_node,
105 raft_server,
106 logger,
107 })
108 }
109
110 pub async fn run(self) -> Result<()> {
112 let (tx_quit_signal, rx_quit_signal) = oneshot::channel::<()>();
113
114 let raft_node = self.raft_node.clone();
115 let raft_node_handle = tokio::spawn(raft_node.run());
116 let raft_server = self.raft_server.clone();
117 let raft_server_handle = tokio::spawn(raft_server.run(rx_quit_signal));
118
119 tokio::select! {
120 _ = signal::ctrl_c() => {
121 self.logger.info("Ctrl+C signal detected. Shutting down...");
122 Ok(())
123 }
124 result = raft_node_handle => {
125 tx_quit_signal.send(()).unwrap();
126
127 match result {
128 Ok(raft_node_result) => {
129 match raft_node_result {
130 Ok(_) => {
131 self.logger.info("RaftNode quitted. Shutting down...");
132 Ok(())
133 },
134 Err(err) => {
135 self.logger.error(&format!("RaftNode quitted with the error. Shutting down... {:?}", err));
136 Err(Error::Other(Box::new(err)))
137 }
138 }
139 },
140 Err(err) => {
141 self.logger.error(&format!("RaftNode quitted with the error. Shutting down... {:?}", err));
142 Err(Error::Unknown)
143 }
144 }
145 }
146 result = raft_server_handle => {
147 match result {
148 Ok(raft_server_result) => {
149 match raft_server_result {
150 Ok(_) => {
151 self.logger.info("RaftServer quitted. Shutting down...");
152 Ok(())
153 },
154 Err(err) => {
155 self.logger.error(&format!("RaftServer quitted with error. Shutting down... {:?}", err));
156 Err(Error::Other(Box::new(err)))
157 }
158 }
159 },
160 Err(err) => {
161 self.logger.error(&format!("RaftServer quitted with the error. Shutting down... {:?}", err));
162 Err(Error::Unknown)
163 }
164 }
165 }
166 }
167 }
168
169 pub async fn request_id<A: ToSocketAddrs>(
172 raft_addr: A,
173 peer_addr: String,
174 tls_config: Option<TlsConfig>,
175 ) -> Result<ClusterJoinTicket> {
176 let raft_addr = raft_addr
177 .to_socket_addrs()
178 .unwrap()
179 .next()
180 .unwrap()
181 .to_string();
182
183 let mut client = create_client(&peer_addr, tls_config).await?;
184 let response = client
185 .request_id(raft_service::RequestIdArgs {
186 raft_addr: raft_addr.to_string(),
187 })
188 .await?
189 .into_inner();
190
191 let peers: Peers = deserialize(&response.peers)?;
192 match response.code() {
193 ResultCode::Ok => Ok(ClusterJoinTicket {
194 raft_addr,
195 reserved_id: response.reserved_id,
196 leader_addr: response.leader_addr,
197 peers: peers.into(),
198 }),
199 ResultCode::Error => Err(Error::JoinError),
200 ResultCode::WrongLeader => {
201 unreachable!();
202 }
203 }
204 }
205}