1use bincode::serialize;
2use std::{
3 net::{SocketAddr, ToSocketAddrs},
4 sync::Arc,
5 time::Duration,
6};
7use tokio::{
8 sync::{
9 mpsc,
10 oneshot::{self, Receiver},
11 },
12 time::timeout,
13};
14use tonic::{transport::Server, Request, Response, Status};
15
16#[cfg(feature = "tls")]
17use tonic::transport::{Certificate, Identity, ServerTlsConfig};
18
19use super::{
20 macro_utils::function_name,
21 raft_service::{
22 self,
23 raft_service_server::{RaftService, RaftServiceServer},
24 },
25 Config, Error,
26};
27use crate::{
28 create_client,
29 raft::{eraftpb::Message as RaftMessage, logger::Logger},
30 raft_service::ProposeArgs,
31 request::{
32 common::confchange_request::ConfChangeRequest, server_request_message::ServerRequestMsg,
33 },
34 response::server_response_message::{
35 ConfChangeResponseResult, RequestIdResponseResult, ResponseResult, ServerResponseMsg,
36 },
37 AbstractLogEntry, AbstractStateMachine, StableStorage,
38};
39
40#[derive(Clone)]
41pub struct RaftServer<
42 LogEntry: AbstractLogEntry,
43 LogStorage: StableStorage + 'static,
44 FSM: AbstractStateMachine,
45> {
46 tx: mpsc::Sender<ServerRequestMsg<LogEntry, LogStorage, FSM>>,
47 raft_addr: SocketAddr,
48 config: Config,
49 logger: Arc<dyn Logger>,
50}
51
52impl<
53 LogEntry: AbstractLogEntry + 'static,
54 LogStorage: StableStorage + Send + Sync + 'static,
55 FSM: AbstractStateMachine + 'static,
56 > RaftServer<LogEntry, LogStorage, FSM>
57{
58 pub fn new<A: ToSocketAddrs>(
59 tx: mpsc::Sender<ServerRequestMsg<LogEntry, LogStorage, FSM>>,
60 raft_addr: A,
61 config: Config,
62 logger: Arc<dyn Logger>,
63 ) -> Self {
64 let raft_addr = raft_addr.to_socket_addrs().unwrap().next().unwrap();
65 RaftServer {
66 tx,
67 raft_addr,
68 config,
69 logger,
70 }
71 }
72
73 pub(crate) async fn run(self, rx_quit_signal: Receiver<()>) -> Result<(), Error> {
74 let raft_addr = self.raft_addr;
75 let logger = self.logger.clone();
76 logger.debug(&format!(
77 "RaftServer starts to listen gRPC requests on \"{}\"...",
78 raft_addr
79 ));
80
81 let quit_signal = async {
82 rx_quit_signal.await.ok();
83 };
84
85 let mut server_builder = Server::builder();
86
87 #[cfg(feature = "tls")]
88 if let Some(tls_cfg) = &self.config.server_tls_config {
89 logger.debug("TLS enabled.");
90 let cert_path = tls_cfg
91 .cert_path
92 .as_ref()
93 .expect("Server requires cert_path");
94 let cert = tokio::fs::read(cert_path).await?;
95 let key_path = tls_cfg.key_path.as_ref().expect("Server requires key_path");
96 let key = tokio::fs::read(key_path).await?;
97 let identity = Identity::from_pem(cert, key);
98
99 let mut tls_config = ServerTlsConfig::new().identity(identity);
100
101 if let Some(ca_cert_path) = &tls_cfg.ca_cert_path {
103 let ca_cert = tokio::fs::read(ca_cert_path).await?;
104 let ca_cert = Certificate::from_pem(ca_cert);
105 tls_config = tls_config.client_ca_root(ca_cert);
106 }
107
108 server_builder = server_builder.tls_config(tls_config)?;
109 }
110
111 server_builder
112 .add_service(RaftServiceServer::new(self))
113 .serve_with_shutdown(raft_addr, quit_signal)
114 .await?;
115
116 Ok(())
117 }
118}
119
120impl<
121 LogEntry: AbstractLogEntry + 'static,
122 LogStorage: StableStorage + 'static,
123 FSM: AbstractStateMachine + 'static,
124 > RaftServer<LogEntry, LogStorage, FSM>
125{
126 fn print_send_error(&self, function_name: &str) {
127 self.logger.error(&format!(
128 "Error occurred in sending message ('RaftServer --> RaftNode'). Function: '{}'",
129 function_name
130 ));
131 }
132}
133
134#[tonic::async_trait]
135impl<
136 LogEntry: AbstractLogEntry + 'static,
137 LogStorage: StableStorage + Sync + Send + 'static,
138 FSM: AbstractStateMachine + 'static,
139 > RaftService for RaftServer<LogEntry, LogStorage, FSM>
140{
141 async fn request_id(
142 &self,
143 request: Request<raft_service::RequestIdArgs>,
144 ) -> Result<Response<raft_service::RequestIdResponse>, Status> {
145 let request_args = request.into_inner();
146 let sender = self.tx.clone();
147 let (tx_msg, rx_msg) = oneshot::channel();
148 sender
149 .send(ServerRequestMsg::RequestId {
150 raft_addr: request_args.raft_addr.clone(),
151 tx_msg,
152 })
153 .await
154 .unwrap();
155 let response = rx_msg.await.unwrap();
156
157 match response {
158 ServerResponseMsg::RequestId { result } => match result {
159 RequestIdResponseResult::Success {
160 reserved_id,
161 leader_id,
162 peers,
163 } => Ok(Response::new(raft_service::RequestIdResponse {
164 code: raft_service::ResultCode::Ok as i32,
165 leader_id,
166 reserved_id,
167 leader_addr: self.raft_addr.to_string(),
168 peers: serialize(&peers).unwrap(),
169 ..Default::default()
170 })),
171 RequestIdResponseResult::Error(e) => {
172 Ok(Response::new(raft_service::RequestIdResponse {
173 code: raft_service::ResultCode::Error as i32,
174 error: e.to_string().as_bytes().to_vec(),
175 ..Default::default()
176 }))
177 }
178 RequestIdResponseResult::WrongLeader { leader_addr, .. } => {
179 let mut client =
180 create_client(leader_addr, self.config.client_tls_config.clone())
181 .await
182 .unwrap();
183 let reply = client.request_id(request_args).await?.into_inner();
184
185 Ok(Response::new(reply))
186 }
187 },
188 _ => unreachable!(),
189 }
190 }
191
192 async fn change_config(
193 &self,
194 request: Request<raft_service::ChangeConfigArgs>,
195 ) -> Result<Response<raft_service::ChangeConfigResponse>, Status> {
196 let request_args = request.into_inner();
197 let sender = self.tx.clone();
198 let (tx_msg, rx_msg) = oneshot::channel();
199
200 let conf_change_request: ConfChangeRequest = request_args.clone().into();
201
202 let message = ServerRequestMsg::ChangeConfig {
203 conf_change: conf_change_request,
204 tx_msg,
205 };
206
207 match sender.send(message).await {
209 Ok(_) => {}
210 Err(_) => {
211 self.print_send_error(function_name!());
212 }
213 }
214
215 let mut reply = raft_service::ChangeConfigResponse::default();
216
217 match timeout(
218 Duration::from_secs_f32(self.config.conf_change_request_timeout),
219 rx_msg,
220 )
221 .await
222 {
223 Ok(Ok(raft_response)) => {
224 match raft_response {
225 ServerResponseMsg::ConfigChange { result } => match result {
226 ConfChangeResponseResult::JoinSuccess {
227 assigned_ids,
228 peers,
229 } => {
230 reply.result_type =
231 raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
232 reply.assigned_ids = assigned_ids;
233 reply.peers = serialize(&peers).unwrap();
234 }
235 ConfChangeResponseResult::RemoveSuccess {} => {
236 reply.result_type =
237 raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
238 }
239 ConfChangeResponseResult::Error(e) => {
240 reply.result_type =
241 raft_service::ChangeConfigResultType::ChangeConfigUnknownError
242 as i32;
243 reply.error = e.to_string().as_bytes().to_vec();
244 }
245 ConfChangeResponseResult::WrongLeader { leader_addr, .. } => {
246 reply.result_type =
247 raft_service::ChangeConfigResultType::ChangeConfigWrongLeader
248 as i32;
249
250 let mut client =
251 create_client(leader_addr, self.config.client_tls_config.clone())
252 .await
253 .unwrap();
254 reply = client.change_config(request_args).await?.into_inner();
255 }
256 },
257 _ => unreachable!(),
258 }
259 reply.result_type =
260 raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
261 }
262 Ok(Err(e)) => {
263 reply.result_type =
264 raft_service::ChangeConfigResultType::ChangeConfigUnknownError as i32;
265 reply.error = e.to_string().as_bytes().to_vec();
266 }
267 Err(e) => {
268 reply.result_type =
269 raft_service::ChangeConfigResultType::ChangeConfigTimeoutError as i32;
270 reply.error = e.to_string().as_bytes().to_vec();
271 self.logger.error(&format!(
272 "Confchange request timeout! (\"conf_change_request_timeout\" = {})",
273 self.config.conf_change_request_timeout
274 ));
275 }
276 }
277
278 Ok(Response::new(reply))
279 }
280
281 async fn send_message(
282 &self,
283 request: Request<RaftMessage>,
284 ) -> Result<Response<raft_service::Empty>, Status> {
285 let request_args = request.into_inner();
286 let sender = self.tx.clone();
287 match sender
288 .send(ServerRequestMsg::SendMessage {
289 message: Box::new(request_args),
290 })
291 .await
292 {
293 Ok(_) => (),
294 Err(_) => self.print_send_error(function_name!()),
295 }
296
297 Ok(Response::new(raft_service::Empty {}))
298 }
299
300 async fn propose(
301 &self,
302 request: Request<raft_service::ProposeArgs>,
303 ) -> Result<Response<raft_service::ProposeResponse>, Status> {
304 let request_args = request.into_inner();
305 let sender = self.tx.clone();
306
307 let (tx_msg, rx_msg) = oneshot::channel();
308 match sender
309 .send(ServerRequestMsg::Propose {
310 proposal: request_args.msg.clone(),
311 tx_msg,
312 })
313 .await
314 {
315 Ok(_) => (),
316 Err(_) => self.print_send_error(function_name!()),
317 }
318
319 let response = rx_msg.await.unwrap();
320 match response {
321 ServerResponseMsg::Propose { result } => {
322 match result {
323 ResponseResult::Success => Ok(Response::new(raft_service::ProposeResponse {
324 ..Default::default()
325 })),
326 ResponseResult::Error(error) => {
327 Ok(Response::new(raft_service::ProposeResponse {
328 error: error.to_string().as_bytes().to_vec(),
329 }))
330 }
331 ResponseResult::WrongLeader { leader_addr, .. } => {
332 let mut client =
334 create_client(leader_addr, self.config.client_tls_config.clone())
335 .await
336 .unwrap();
337 let _ = client
338 .propose(ProposeArgs {
339 msg: request_args.msg,
340 })
341 .await?;
342
343 Ok(Response::new(raft_service::ProposeResponse {
344 ..Default::default()
345 }))
346 }
347 }
348 }
349 _ => unreachable!(),
350 }
351 }
352
353 async fn debug_node(
354 &self,
355 request: Request<raft_service::Empty>,
356 ) -> Result<Response<raft_service::DebugNodeResponse>, Status> {
357 let _request_args = request.into_inner();
358 let sender = self.tx.clone();
359 let (tx_msg, rx_msg) = oneshot::channel();
360
361 match sender.send(ServerRequestMsg::DebugNode { tx_msg }).await {
362 Ok(_) => (),
363 Err(_) => self.print_send_error(function_name!()),
364 }
365
366 let response = rx_msg.await.unwrap();
367 match response {
368 ServerResponseMsg::DebugNode { result_json } => {
369 Ok(Response::new(raft_service::DebugNodeResponse {
370 result_json,
371 }))
372 }
373 _ => unreachable!(),
374 }
375 }
376
377 async fn get_peers(
378 &self,
379 request: Request<raft_service::Empty>,
380 ) -> Result<Response<raft_service::GetPeersResponse>, Status> {
381 let _request_args = request.into_inner();
382 let (tx_msg, rx_msg) = oneshot::channel();
383 let sender = self.tx.clone();
384 match sender.send(ServerRequestMsg::GetPeers { tx_msg }).await {
385 Ok(_) => (),
386 Err(_) => self.print_send_error(function_name!()),
387 }
388 let response = rx_msg.await.unwrap();
389
390 match response {
391 ServerResponseMsg::GetPeers { peers } => {
392 Ok(Response::new(raft_service::GetPeersResponse {
393 peers_json: peers.to_json(),
394 }))
395 }
396 _ => unreachable!(),
397 }
398 }
399
400 async fn leave_joint(
401 &self,
402 request: Request<raft_service::Empty>,
403 ) -> Result<Response<raft_service::Empty>, Status> {
404 let _request_args = request.into_inner();
405 let (tx_msg, rx_msg) = oneshot::channel();
406 let sender = self.tx.clone();
407 match sender.send(ServerRequestMsg::LeaveJoint { tx_msg }).await {
408 Ok(_) => (),
409 Err(_) => self.print_send_error(function_name!()),
410 }
411 let response = rx_msg.await.unwrap();
412
413 match response {
414 ServerResponseMsg::LeaveJoint {} => Ok(Response::new(raft_service::Empty {})),
415 _ => unreachable!(),
416 }
417 }
418
419 async fn set_peers(
420 &self,
421 request: Request<raft_service::Peers>,
422 ) -> Result<Response<raft_service::Empty>, Status> {
423 let request_args = request.into_inner();
424 let peers = request_args.into();
425
426 let (tx_msg, rx_msg) = oneshot::channel();
427 let sender = self.tx.clone();
428 match sender
429 .send(ServerRequestMsg::SetPeers { peers, tx_msg })
430 .await
431 {
432 Ok(_) => (),
433 Err(_) => self.print_send_error(function_name!()),
434 }
435 let response = rx_msg.await.unwrap();
436
437 match response {
438 ServerResponseMsg::SetPeers {} => Ok(Response::new(raft_service::Empty {})),
439 _ => unreachable!(),
440 }
441 }
442
443 async fn create_snapshot(
444 &self,
445 request: Request<raft_service::Empty>,
446 ) -> Result<Response<raft_service::Empty>, Status> {
447 let _request_args = request.into_inner();
448 let (tx_msg, rx_msg) = oneshot::channel();
449 let sender = self.tx.clone();
450 match sender
451 .send(ServerRequestMsg::CreateSnapshot { tx_msg })
452 .await
453 {
454 Ok(_) => (),
455 Err(_) => self.print_send_error(function_name!()),
456 }
457 let response = rx_msg.await.unwrap();
458
459 match response {
460 ServerResponseMsg::CreateSnapshot {} => Ok(Response::new(raft_service::Empty {})),
461 _ => unreachable!(),
462 }
463 }
464}