raftify/
raft_server.rs

1use bincode::serialize;
2use std::{
3    net::{SocketAddr, ToSocketAddrs},
4    sync::Arc,
5    time::Duration,
6};
7use tokio::{
8    sync::{
9        mpsc,
10        oneshot::{self, Receiver},
11    },
12    time::timeout,
13};
14use tonic::{transport::Server, Request, Response, Status};
15
16#[cfg(feature = "tls")]
17use tonic::transport::{Certificate, Identity, ServerTlsConfig};
18
19use super::{
20    macro_utils::function_name,
21    raft_service::{
22        self,
23        raft_service_server::{RaftService, RaftServiceServer},
24    },
25    Config, Error,
26};
27use crate::{
28    create_client,
29    raft::{eraftpb::Message as RaftMessage, logger::Logger},
30    raft_service::ProposeArgs,
31    request::{
32        common::confchange_request::ConfChangeRequest, server_request_message::ServerRequestMsg,
33    },
34    response::server_response_message::{
35        ConfChangeResponseResult, RequestIdResponseResult, ResponseResult, ServerResponseMsg,
36    },
37    AbstractLogEntry, AbstractStateMachine, StableStorage,
38};
39
40#[derive(Clone)]
41pub struct RaftServer<
42    LogEntry: AbstractLogEntry,
43    LogStorage: StableStorage + 'static,
44    FSM: AbstractStateMachine,
45> {
46    tx: mpsc::Sender<ServerRequestMsg<LogEntry, LogStorage, FSM>>,
47    raft_addr: SocketAddr,
48    config: Config,
49    logger: Arc<dyn Logger>,
50}
51
52impl<
53        LogEntry: AbstractLogEntry + 'static,
54        LogStorage: StableStorage + Send + Sync + 'static,
55        FSM: AbstractStateMachine + 'static,
56    > RaftServer<LogEntry, LogStorage, FSM>
57{
58    pub fn new<A: ToSocketAddrs>(
59        tx: mpsc::Sender<ServerRequestMsg<LogEntry, LogStorage, FSM>>,
60        raft_addr: A,
61        config: Config,
62        logger: Arc<dyn Logger>,
63    ) -> Self {
64        let raft_addr = raft_addr.to_socket_addrs().unwrap().next().unwrap();
65        RaftServer {
66            tx,
67            raft_addr,
68            config,
69            logger,
70        }
71    }
72
73    pub(crate) async fn run(self, rx_quit_signal: Receiver<()>) -> Result<(), Error> {
74        let raft_addr = self.raft_addr;
75        let logger = self.logger.clone();
76        logger.debug(&format!(
77            "RaftServer starts to listen gRPC requests on \"{}\"...",
78            raft_addr
79        ));
80
81        let quit_signal = async {
82            rx_quit_signal.await.ok();
83        };
84
85        let mut server_builder = Server::builder();
86
87        #[cfg(feature = "tls")]
88        if let Some(tls_cfg) = &self.config.server_tls_config {
89            logger.debug("TLS enabled.");
90            let cert_path = tls_cfg
91                .cert_path
92                .as_ref()
93                .expect("Server requires cert_path");
94            let cert = tokio::fs::read(cert_path).await?;
95            let key_path = tls_cfg.key_path.as_ref().expect("Server requires key_path");
96            let key = tokio::fs::read(key_path).await?;
97            let identity = Identity::from_pem(cert, key);
98
99            let mut tls_config = ServerTlsConfig::new().identity(identity);
100
101            // mTLS
102            if let Some(ca_cert_path) = &tls_cfg.ca_cert_path {
103                let ca_cert = tokio::fs::read(ca_cert_path).await?;
104                let ca_cert = Certificate::from_pem(ca_cert);
105                tls_config = tls_config.client_ca_root(ca_cert);
106            }
107
108            server_builder = server_builder.tls_config(tls_config)?;
109        }
110
111        server_builder
112            .add_service(RaftServiceServer::new(self))
113            .serve_with_shutdown(raft_addr, quit_signal)
114            .await?;
115
116        Ok(())
117    }
118}
119
120impl<
121        LogEntry: AbstractLogEntry + 'static,
122        LogStorage: StableStorage + 'static,
123        FSM: AbstractStateMachine + 'static,
124    > RaftServer<LogEntry, LogStorage, FSM>
125{
126    fn print_send_error(&self, function_name: &str) {
127        self.logger.error(&format!(
128            "Error occurred in sending message ('RaftServer --> RaftNode'). Function: '{}'",
129            function_name
130        ));
131    }
132}
133
134#[tonic::async_trait]
135impl<
136        LogEntry: AbstractLogEntry + 'static,
137        LogStorage: StableStorage + Sync + Send + 'static,
138        FSM: AbstractStateMachine + 'static,
139    > RaftService for RaftServer<LogEntry, LogStorage, FSM>
140{
141    async fn request_id(
142        &self,
143        request: Request<raft_service::RequestIdArgs>,
144    ) -> Result<Response<raft_service::RequestIdResponse>, Status> {
145        let request_args = request.into_inner();
146        let sender = self.tx.clone();
147        let (tx_msg, rx_msg) = oneshot::channel();
148        sender
149            .send(ServerRequestMsg::RequestId {
150                raft_addr: request_args.raft_addr.clone(),
151                tx_msg,
152            })
153            .await
154            .unwrap();
155        let response = rx_msg.await.unwrap();
156
157        match response {
158            ServerResponseMsg::RequestId { result } => match result {
159                RequestIdResponseResult::Success {
160                    reserved_id,
161                    leader_id,
162                    peers,
163                } => Ok(Response::new(raft_service::RequestIdResponse {
164                    code: raft_service::ResultCode::Ok as i32,
165                    leader_id,
166                    reserved_id,
167                    leader_addr: self.raft_addr.to_string(),
168                    peers: serialize(&peers).unwrap(),
169                    ..Default::default()
170                })),
171                RequestIdResponseResult::Error(e) => {
172                    Ok(Response::new(raft_service::RequestIdResponse {
173                        code: raft_service::ResultCode::Error as i32,
174                        error: e.to_string().as_bytes().to_vec(),
175                        ..Default::default()
176                    }))
177                }
178                RequestIdResponseResult::WrongLeader { leader_addr, .. } => {
179                    let mut client =
180                        create_client(leader_addr, self.config.client_tls_config.clone())
181                            .await
182                            .unwrap();
183                    let reply = client.request_id(request_args).await?.into_inner();
184
185                    Ok(Response::new(reply))
186                }
187            },
188            _ => unreachable!(),
189        }
190    }
191
192    async fn change_config(
193        &self,
194        request: Request<raft_service::ChangeConfigArgs>,
195    ) -> Result<Response<raft_service::ChangeConfigResponse>, Status> {
196        let request_args = request.into_inner();
197        let sender = self.tx.clone();
198        let (tx_msg, rx_msg) = oneshot::channel();
199
200        let conf_change_request: ConfChangeRequest = request_args.clone().into();
201
202        let message = ServerRequestMsg::ChangeConfig {
203            conf_change: conf_change_request,
204            tx_msg,
205        };
206
207        // TODO: Handle this kind of errors
208        match sender.send(message).await {
209            Ok(_) => {}
210            Err(_) => {
211                self.print_send_error(function_name!());
212            }
213        }
214
215        let mut reply = raft_service::ChangeConfigResponse::default();
216
217        match timeout(
218            Duration::from_secs_f32(self.config.conf_change_request_timeout),
219            rx_msg,
220        )
221        .await
222        {
223            Ok(Ok(raft_response)) => {
224                match raft_response {
225                    ServerResponseMsg::ConfigChange { result } => match result {
226                        ConfChangeResponseResult::JoinSuccess {
227                            assigned_ids,
228                            peers,
229                        } => {
230                            reply.result_type =
231                                raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
232                            reply.assigned_ids = assigned_ids;
233                            reply.peers = serialize(&peers).unwrap();
234                        }
235                        ConfChangeResponseResult::RemoveSuccess {} => {
236                            reply.result_type =
237                                raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
238                        }
239                        ConfChangeResponseResult::Error(e) => {
240                            reply.result_type =
241                                raft_service::ChangeConfigResultType::ChangeConfigUnknownError
242                                    as i32;
243                            reply.error = e.to_string().as_bytes().to_vec();
244                        }
245                        ConfChangeResponseResult::WrongLeader { leader_addr, .. } => {
246                            reply.result_type =
247                                raft_service::ChangeConfigResultType::ChangeConfigWrongLeader
248                                    as i32;
249
250                            let mut client =
251                                create_client(leader_addr, self.config.client_tls_config.clone())
252                                    .await
253                                    .unwrap();
254                            reply = client.change_config(request_args).await?.into_inner();
255                        }
256                    },
257                    _ => unreachable!(),
258                }
259                reply.result_type =
260                    raft_service::ChangeConfigResultType::ChangeConfigSuccess as i32;
261            }
262            Ok(Err(e)) => {
263                reply.result_type =
264                    raft_service::ChangeConfigResultType::ChangeConfigUnknownError as i32;
265                reply.error = e.to_string().as_bytes().to_vec();
266            }
267            Err(e) => {
268                reply.result_type =
269                    raft_service::ChangeConfigResultType::ChangeConfigTimeoutError as i32;
270                reply.error = e.to_string().as_bytes().to_vec();
271                self.logger.error(&format!(
272                    "Confchange request timeout! (\"conf_change_request_timeout\" = {})",
273                    self.config.conf_change_request_timeout
274                ));
275            }
276        }
277
278        Ok(Response::new(reply))
279    }
280
281    async fn send_message(
282        &self,
283        request: Request<RaftMessage>,
284    ) -> Result<Response<raft_service::Empty>, Status> {
285        let request_args = request.into_inner();
286        let sender = self.tx.clone();
287        match sender
288            .send(ServerRequestMsg::SendMessage {
289                message: Box::new(request_args),
290            })
291            .await
292        {
293            Ok(_) => (),
294            Err(_) => self.print_send_error(function_name!()),
295        }
296
297        Ok(Response::new(raft_service::Empty {}))
298    }
299
300    async fn propose(
301        &self,
302        request: Request<raft_service::ProposeArgs>,
303    ) -> Result<Response<raft_service::ProposeResponse>, Status> {
304        let request_args = request.into_inner();
305        let sender = self.tx.clone();
306
307        let (tx_msg, rx_msg) = oneshot::channel();
308        match sender
309            .send(ServerRequestMsg::Propose {
310                proposal: request_args.msg.clone(),
311                tx_msg,
312            })
313            .await
314        {
315            Ok(_) => (),
316            Err(_) => self.print_send_error(function_name!()),
317        }
318
319        let response = rx_msg.await.unwrap();
320        match response {
321            ServerResponseMsg::Propose { result } => {
322                match result {
323                    ResponseResult::Success => Ok(Response::new(raft_service::ProposeResponse {
324                        ..Default::default()
325                    })),
326                    ResponseResult::Error(error) => {
327                        Ok(Response::new(raft_service::ProposeResponse {
328                            error: error.to_string().as_bytes().to_vec(),
329                        }))
330                    }
331                    ResponseResult::WrongLeader { leader_addr, .. } => {
332                        // TODO: Handle this kind of errors
333                        let mut client =
334                            create_client(leader_addr, self.config.client_tls_config.clone())
335                                .await
336                                .unwrap();
337                        let _ = client
338                            .propose(ProposeArgs {
339                                msg: request_args.msg,
340                            })
341                            .await?;
342
343                        Ok(Response::new(raft_service::ProposeResponse {
344                            ..Default::default()
345                        }))
346                    }
347                }
348            }
349            _ => unreachable!(),
350        }
351    }
352
353    async fn debug_node(
354        &self,
355        request: Request<raft_service::Empty>,
356    ) -> Result<Response<raft_service::DebugNodeResponse>, Status> {
357        let _request_args = request.into_inner();
358        let sender = self.tx.clone();
359        let (tx_msg, rx_msg) = oneshot::channel();
360
361        match sender.send(ServerRequestMsg::DebugNode { tx_msg }).await {
362            Ok(_) => (),
363            Err(_) => self.print_send_error(function_name!()),
364        }
365
366        let response = rx_msg.await.unwrap();
367        match response {
368            ServerResponseMsg::DebugNode { result_json } => {
369                Ok(Response::new(raft_service::DebugNodeResponse {
370                    result_json,
371                }))
372            }
373            _ => unreachable!(),
374        }
375    }
376
377    async fn get_peers(
378        &self,
379        request: Request<raft_service::Empty>,
380    ) -> Result<Response<raft_service::GetPeersResponse>, Status> {
381        let _request_args = request.into_inner();
382        let (tx_msg, rx_msg) = oneshot::channel();
383        let sender = self.tx.clone();
384        match sender.send(ServerRequestMsg::GetPeers { tx_msg }).await {
385            Ok(_) => (),
386            Err(_) => self.print_send_error(function_name!()),
387        }
388        let response = rx_msg.await.unwrap();
389
390        match response {
391            ServerResponseMsg::GetPeers { peers } => {
392                Ok(Response::new(raft_service::GetPeersResponse {
393                    peers_json: peers.to_json(),
394                }))
395            }
396            _ => unreachable!(),
397        }
398    }
399
400    async fn leave_joint(
401        &self,
402        request: Request<raft_service::Empty>,
403    ) -> Result<Response<raft_service::Empty>, Status> {
404        let _request_args = request.into_inner();
405        let (tx_msg, rx_msg) = oneshot::channel();
406        let sender = self.tx.clone();
407        match sender.send(ServerRequestMsg::LeaveJoint { tx_msg }).await {
408            Ok(_) => (),
409            Err(_) => self.print_send_error(function_name!()),
410        }
411        let response = rx_msg.await.unwrap();
412
413        match response {
414            ServerResponseMsg::LeaveJoint {} => Ok(Response::new(raft_service::Empty {})),
415            _ => unreachable!(),
416        }
417    }
418
419    async fn set_peers(
420        &self,
421        request: Request<raft_service::Peers>,
422    ) -> Result<Response<raft_service::Empty>, Status> {
423        let request_args = request.into_inner();
424        let peers = request_args.into();
425
426        let (tx_msg, rx_msg) = oneshot::channel();
427        let sender = self.tx.clone();
428        match sender
429            .send(ServerRequestMsg::SetPeers { peers, tx_msg })
430            .await
431        {
432            Ok(_) => (),
433            Err(_) => self.print_send_error(function_name!()),
434        }
435        let response = rx_msg.await.unwrap();
436
437        match response {
438            ServerResponseMsg::SetPeers {} => Ok(Response::new(raft_service::Empty {})),
439            _ => unreachable!(),
440        }
441    }
442
443    async fn create_snapshot(
444        &self,
445        request: Request<raft_service::Empty>,
446    ) -> Result<Response<raft_service::Empty>, Status> {
447        let _request_args = request.into_inner();
448        let (tx_msg, rx_msg) = oneshot::channel();
449        let sender = self.tx.clone();
450        match sender
451            .send(ServerRequestMsg::CreateSnapshot { tx_msg })
452            .await
453        {
454            Ok(_) => (),
455            Err(_) => self.print_send_error(function_name!()),
456        }
457        let response = rx_msg.await.unwrap();
458
459        match response {
460            ServerResponseMsg::CreateSnapshot {} => Ok(Response::new(raft_service::Empty {})),
461            _ => unreachable!(),
462        }
463    }
464}