rmqtt_raft/
raft.rs

1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use bincode::{deserialize, serialize};
7use bytestring::ByteString;
8use futures::channel::{mpsc, oneshot};
9use futures::future::FutureExt;
10use futures::SinkExt;
11use log::{debug, info, warn};
12use prost::Message as _;
13use tikv_raft::eraftpb::{ConfChange, ConfChangeType};
14use tokio::sync::RwLock;
15use tokio::time::timeout;
16use tonic::Request;
17
18use crate::error::{Error, Result};
19use crate::message::{Message, RaftResponse, RemoveNodeType, Status};
20use crate::raft_node::{Peer, RaftNode};
21use crate::raft_server::RaftServer;
22use crate::raft_service::connect;
23use crate::raft_service::{ConfChange as RiteraftConfChange, Empty, ResultCode};
24use crate::Config;
25
26type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
27
28#[async_trait]
29pub trait Store: Clone + Send + Sync {
30    async fn apply(&mut self, message: &[u8]) -> Result<Vec<u8>>;
31    async fn query(&self, query: &[u8]) -> Result<Vec<u8>>;
32    async fn snapshot(&self) -> Result<Vec<u8>>;
33    async fn restore(&mut self, snapshot: &[u8]) -> Result<()>;
34}
35
36struct ProposalSender {
37    proposal: Vec<u8>,
38    client: Peer,
39}
40
41impl ProposalSender {
42    async fn send(self) -> Result<RaftResponse> {
43        match self.client.send_proposal(self.proposal).await {
44            Ok(reply) => {
45                let raft_response: RaftResponse = deserialize(&reply)?;
46                Ok(raft_response)
47            }
48            Err(e) => {
49                warn!("error sending proposal {:?}", e);
50                Err(e)
51            }
52        }
53    }
54}
55
56#[derive(Clone)]
57struct LeaderInfo {
58    leader: bool,
59    target_leader_id: u64,
60    target_leader_addr: Option<String>,
61}
62
63type LeaderInfoError = ByteString;
64
65/// A mailbox to send messages to a running raft node.
66#[derive(Clone)]
67pub struct Mailbox {
68    peers: Arc<DashMap<(u64, String), Peer>>,
69    sender: mpsc::Sender<Message>,
70    grpc_timeout: Duration,
71    grpc_concurrency_limit: usize,
72    grpc_message_size: usize,
73    grpc_breaker_threshold: u64,
74    grpc_breaker_retry_interval: i64,
75    #[allow(clippy::type_complexity)]
76    leader_info: Arc<
77        RwLock<
78            Option<(
79                Option<LeaderInfo>,
80                Option<LeaderInfoError>,
81                std::time::Instant,
82            )>,
83        >,
84    >,
85}
86
87impl Mailbox {
88    #[inline]
89    pub(crate) fn new(
90        peers: Arc<DashMap<(u64, String), Peer>>,
91        sender: mpsc::Sender<Message>,
92        grpc_timeout: Duration,
93        grpc_concurrency_limit: usize,
94        grpc_message_size: usize,
95        grpc_breaker_threshold: u64,
96        grpc_breaker_retry_interval: i64,
97    ) -> Self {
98        Self {
99            peers,
100            sender,
101            grpc_timeout,
102            grpc_concurrency_limit,
103            grpc_message_size,
104            grpc_breaker_threshold,
105            grpc_breaker_retry_interval,
106            leader_info: Arc::new(RwLock::new(None)),
107        }
108    }
109
110    /// Retrieves a list of peers with their IDs.
111    /// This method returns a vector containing tuples of peer IDs and their respective `Peer` objects.
112    /// It iterates over the internal `peers` map and collects the IDs and cloned `Peer` instances.
113    #[inline]
114    pub fn pears(&self) -> Vec<(u64, Peer)> {
115        self.peers
116            .iter()
117            .map(|p| {
118                let (id, _) = p.key();
119                (*id, p.value().clone())
120            })
121            .collect::<Vec<_>>()
122    }
123
124    #[inline]
125    async fn peer(&self, leader_id: u64, leader_addr: String) -> Peer {
126        self.peers
127            .entry((leader_id, leader_addr.clone()))
128            .or_insert_with(|| {
129                Peer::new(
130                    leader_addr,
131                    self.grpc_timeout,
132                    self.grpc_concurrency_limit,
133                    self.grpc_message_size,
134                    self.grpc_breaker_threshold,
135                    self.grpc_breaker_retry_interval,
136                )
137            })
138            .clone()
139    }
140
141    #[inline]
142    async fn send_to_leader(
143        &self,
144        proposal: Vec<u8>,
145        leader_id: u64,
146        leader_addr: String,
147    ) -> Result<RaftResponse> {
148        let peer = self.peer(leader_id, leader_addr).await;
149        let proposal_sender = ProposalSender {
150            proposal,
151            client: peer,
152        };
153        proposal_sender.send().await
154    }
155
156    /// Sends a proposal to the leader node.
157    /// This method first attempts to send the proposal to the local node if it is the leader.
158    /// If the node is not the leader, it retrieves the leader's address and sends the proposal to the leader node.
159    /// If the proposal is successfully handled, the method returns a `RaftResponse::Response` with the resulting data.
160    #[inline]
161    pub async fn send_proposal(&self, message: Vec<u8>) -> Result<Vec<u8>> {
162        match self.get_leader_info().await? {
163            LeaderInfo { leader: true, .. } => {
164                debug!("this node is leader");
165                let (tx, rx) = oneshot::channel();
166                let proposal = Message::Propose {
167                    proposal: message.clone(),
168                    chan: tx,
169                };
170                let mut sender = self.sender.clone();
171                sender
172                    .send(proposal)
173                    .await //.try_send(proposal)
174                    .map_err(|e| Error::SendError(e.to_string()))?;
175                let reply = timeout(self.grpc_timeout, rx).await;
176                let reply = reply
177                    .map_err(|e| Error::RecvError(e.to_string()))?
178                    .map_err(|e| Error::RecvError(e.to_string()))?;
179                match reply {
180                    RaftResponse::Response { data } => return Ok(data),
181                    RaftResponse::Busy => return Err(Error::Busy),
182                    RaftResponse::Error(e) => return Err(Error::from(e)),
183                    _ => {
184                        warn!("Recv other raft response: {:?}", reply);
185                        return Err(Error::Unknown);
186                    }
187                }
188            }
189            LeaderInfo {
190                leader: false,
191                target_leader_id,
192                target_leader_addr,
193                ..
194            } => {
195                debug!(
196                    "This node not is Leader, leader_id: {:?}, leader_addr: {:?}",
197                    target_leader_id, target_leader_addr
198                );
199                if let Some(target_leader_addr) = target_leader_addr {
200                    if target_leader_id != 0 {
201                        return match self
202                            .send_to_leader(message, target_leader_id, target_leader_addr.clone())
203                            .await?
204                        {
205                            RaftResponse::Response { data } => return Ok(data),
206                            RaftResponse::WrongLeader {
207                                leader_id,
208                                leader_addr,
209                            } => {
210                                warn!("The target node is not the Leader, target_leader_id: {}, target_leader_addr: {:?}, actual_leader_id: {}, actual_leader_addr: {:?}",
211                            target_leader_id, target_leader_addr, leader_id, leader_addr);
212                                return Err(Error::NotLeader);
213                            }
214                            RaftResponse::Busy => Err(Error::Busy),
215                            RaftResponse::Error(e) => Err(Error::from(e)),
216                            _ => {
217                                warn!("Recv other raft response, target_leader_id: {}, target_leader_addr: {:?}", target_leader_id, target_leader_addr);
218                                return Err(Error::Unknown);
219                            }
220                        };
221                    }
222                }
223            }
224        }
225        Err(Error::LeaderNotExist)
226    }
227
228    /// Deprecated method to send a message, internally calls `send_proposal`.
229    #[inline]
230    #[deprecated]
231    pub async fn send(&self, message: Vec<u8>) -> Result<Vec<u8>> {
232        self.send_proposal(message).await
233    }
234
235    /// Sends a query to the Raft node and returns the response data.
236    /// It sends a `Message::Query` containing the query bytes and waits for a response.
237    /// On success, it returns the data wrapped in `RaftResponse::Response`.
238    #[inline]
239    pub async fn query(&self, query: Vec<u8>) -> Result<Vec<u8>> {
240        let (tx, rx) = oneshot::channel();
241        let mut sender = self.sender.clone();
242        match sender.try_send(Message::Query { query, chan: tx }) {
243            Ok(()) => match timeout(self.grpc_timeout, rx).await {
244                Ok(Ok(RaftResponse::Response { data })) => Ok(data),
245                Ok(Ok(RaftResponse::Error(e))) => Err(Error::from(e)),
246                _ => Err(Error::Unknown),
247            },
248            Err(e) => Err(Error::SendError(e.to_string())),
249        }
250    }
251
252    /// Sends a request to leave the Raft cluster.
253    /// It initiates a `ConfigChange` to remove the node from the cluster and waits for a response.
254    #[inline]
255    pub async fn leave(&self) -> Result<()> {
256        let mut change = ConfChange::default();
257        // set node id to 0, the node will set it to self when it receives it.
258        change.set_node_id(0);
259        change.set_change_type(ConfChangeType::RemoveNode);
260        change.set_context(serialize(&RemoveNodeType::Normal)?);
261        let mut sender = self.sender.clone();
262        let (chan, rx) = oneshot::channel();
263        match sender.send(Message::ConfigChange { change, chan }).await {
264            Ok(()) => match rx.await {
265                Ok(RaftResponse::Ok) => Ok(()),
266                Ok(RaftResponse::Error(e)) => Err(Error::from(e)),
267                _ => Err(Error::Unknown),
268            },
269            Err(e) => Err(Error::SendError(e.to_string())),
270        }
271    }
272
273    /// Retrieves the current status of the Raft node.
274    /// Sends a `Message::Status` request and waits for a `RaftResponse::Status` reply, which contains the node's status.
275    #[inline]
276    pub async fn status(&self) -> Result<Status> {
277        let (tx, rx) = oneshot::channel();
278        let mut sender = self.sender.clone();
279        match sender.send(Message::Status { chan: tx }).await {
280            Ok(_) => match timeout(self.grpc_timeout, rx).await {
281                Ok(Ok(RaftResponse::Status(status))) => Ok(status),
282                Ok(Ok(RaftResponse::Error(e))) => Err(Error::from(e)),
283                _ => Err(Error::Unknown),
284            },
285            Err(e) => Err(Error::SendError(e.to_string())),
286        }
287    }
288
289    /// Retrieves leader information, including whether the current node is the leader, the leader ID, and its address.
290    /// This method sends a `Message::RequestId` and waits for a response with the leader's ID and address.
291    #[inline]
292    async fn _get_leader_info(&self) -> std::result::Result<LeaderInfo, LeaderInfoError> {
293        let (tx, rx) = oneshot::channel();
294        let mut sender = self.sender.clone();
295        match sender.send(Message::RequestId { chan: tx }).await {
296            Ok(_) => match timeout(self.grpc_timeout, rx).await {
297                Ok(Ok(RaftResponse::RequestId { leader_id })) => Ok(LeaderInfo {
298                    leader: true,
299                    target_leader_id: leader_id,
300                    target_leader_addr: None,
301                }),
302                Ok(Ok(RaftResponse::WrongLeader {
303                    leader_id,
304                    leader_addr,
305                })) => Ok(LeaderInfo {
306                    leader: false,
307                    target_leader_id: leader_id,
308                    target_leader_addr: leader_addr,
309                }),
310                Ok(Ok(RaftResponse::Error(e))) => Err(LeaderInfoError::from(e)),
311                _ => Err("Unknown".into()),
312            },
313            Err(e) => Err(LeaderInfoError::from(e.to_string())),
314        }
315    }
316
317    #[inline]
318    async fn get_leader_info(&self) -> Result<LeaderInfo> {
319        {
320            let leader_info = self.leader_info.read().await;
321            if let Some((leader_info, err, inst)) = leader_info.as_ref() {
322                if inst.elapsed().as_secs() < 5 {
323                    if let Some(leader_info) = leader_info {
324                        return Ok(leader_info.clone());
325                    }
326                    if let Some(err) = err {
327                        return Err(err.to_string().into());
328                    }
329                }
330            }
331        }
332
333        let mut write = self.leader_info.write().await;
334
335        return match self._get_leader_info().await {
336            Ok(leader_info) => {
337                write.replace((Some(leader_info.clone()), None, std::time::Instant::now()));
338                Ok(leader_info)
339            }
340            Err(e) => {
341                let err = e.to_string().into();
342                write.replace((None, Some(e), std::time::Instant::now()));
343                Err(err)
344            }
345        };
346    }
347}
348
349pub struct Raft<S: Store + 'static> {
350    store: S,
351    tx: mpsc::Sender<Message>,
352    rx: mpsc::Receiver<Message>,
353    laddr: SocketAddr,
354    logger: slog::Logger,
355    cfg: Arc<Config>,
356}
357
358impl<S: Store + Send + Sync + 'static> Raft<S> {
359    /// Creates a new Raft node with the provided address, store, logger, and configuration.
360    /// The node communicates with other peers using a mailbox.
361    pub fn new<A: ToSocketAddrs>(
362        laddr: A,
363        store: S,
364        logger: slog::Logger,
365        cfg: Config,
366    ) -> Result<Self> {
367        let laddr = laddr
368            .to_socket_addrs()?
369            .next()
370            .ok_or_else(|| Error::from("None"))?;
371        let (tx, rx) = mpsc::channel(100_000);
372        let cfg = Arc::new(cfg);
373        Ok(Self {
374            store,
375            tx,
376            rx,
377            laddr,
378            logger,
379            cfg,
380        })
381    }
382
383    /// Returns a `Mailbox` for the Raft node, which facilitates communication with peers.
384    pub fn mailbox(&self) -> Mailbox {
385        Mailbox::new(
386            Arc::new(DashMap::default()),
387            self.tx.clone(),
388            self.cfg.grpc_timeout,
389            self.cfg.grpc_concurrency_limit,
390            self.cfg.grpc_message_size,
391            self.cfg.grpc_breaker_threshold,
392            self.cfg.grpc_breaker_retry_interval.as_millis() as i64,
393        )
394    }
395
396    /// Finds leader information by querying a list of peer addresses.
397    /// Returns the leader ID and its address if found.
398    pub async fn find_leader_info(&self, peer_addrs: Vec<String>) -> Result<Option<(u64, String)>> {
399        let mut futs = Vec::new();
400        for addr in peer_addrs {
401            let fut = async {
402                let _addr = addr.clone();
403                match self.request_leader(addr).await {
404                    Ok(reply) => Ok(reply),
405                    Err(e) => Err(e),
406                }
407            };
408            futs.push(fut.boxed());
409        }
410
411        let (leader_id, leader_addr) = match futures::future::select_ok(futs).await {
412            Ok((Some((leader_id, leader_addr)), _)) => (leader_id, leader_addr),
413            Ok((None, _)) => return Err(Error::LeaderNotExist),
414            Err(_e) => return Ok(None),
415        };
416
417        if leader_id == 0 {
418            Ok(None)
419        } else {
420            Ok(Some((leader_id, leader_addr)))
421        }
422    }
423
424    /// Requests the leader information from a specific peer.
425    /// Sends a `Message::RequestId` to the peer and waits for the response.
426    async fn request_leader(&self, peer_addr: String) -> Result<Option<(u64, String)>> {
427        let (leader_id, leader_addr): (u64, String) = {
428            let mut client = connect(
429                &peer_addr,
430                1,
431                self.cfg.grpc_message_size,
432                self.cfg.grpc_timeout,
433            )
434            .await?;
435            let response = client
436                .request_id(Request::new(Empty::default()))
437                .await?
438                .into_inner();
439            match response.code() {
440                ResultCode::WrongLeader => {
441                    let (leader_id, addr): (u64, Option<String>) = deserialize(&response.data)?;
442                    if let Some(addr) = addr {
443                        (leader_id, addr)
444                    } else {
445                        return Ok(None);
446                    }
447                }
448                ResultCode::Ok => (deserialize(&response.data)?, peer_addr),
449                ResultCode::Error => return Ok(None),
450            }
451        };
452        Ok(Some((leader_id, leader_addr)))
453    }
454
455    /// The `lead` function transitions the current node to the leader role in a Raft cluster.
456    /// It initializes the leader node and runs both the Raft server and the node concurrently.
457    /// The function will return once the server or node experiences an error, or when the leader
458    /// role is relinquished.
459    ///
460    /// # Arguments
461    ///
462    /// * `node_id` - The unique identifier for the node.
463    ///
464    /// # Returns
465    ///
466    /// A `Result<()>` indicating success or failure during the process.
467    pub async fn lead(self, node_id: u64) -> Result<()> {
468        let node = RaftNode::new_leader(
469            self.rx,
470            self.tx.clone(),
471            node_id,
472            self.store,
473            &self.logger,
474            self.cfg.clone(),
475        )
476        .await?;
477
478        let server = RaftServer::new(self.tx, self.laddr, self.cfg.clone());
479        let server_handle = async {
480            if let Err(e) = server.run().await {
481                warn!("raft server run error: {:?}", e);
482                Err(e)
483            } else {
484                Ok(())
485            }
486        };
487        let node_handle = async {
488            if let Err(e) = node.run().await {
489                warn!("node run error: {:?}", e);
490                Err(e)
491            } else {
492                Ok(())
493            }
494        };
495
496        tokio::try_join!(server_handle, node_handle)?;
497        info!("leaving leader node");
498
499        Ok(())
500    }
501
502    /// The `join` function is used to make the current node join an existing Raft cluster.
503    /// It tries to discover the current leader, communicates with the leader to join the cluster,
504    /// and configures the node as a follower.
505    ///
506    /// # Arguments
507    ///
508    /// * `node_id` - The unique identifier for the current node.
509    /// * `node_addr` - The address of the current node.
510    /// * `leader_id` - The optional leader node's identifier (if already known).
511    /// * `leader_addr` - The address of the leader node.
512    ///
513    /// # Returns
514    ///
515    /// A `Result<()>` indicating success or failure during the joining process.
516    pub async fn join(
517        self,
518        node_id: u64,
519        node_addr: String,
520        leader_id: Option<u64>,
521        leader_addr: String,
522    ) -> Result<()> {
523        // 1. try to discover the leader and obtain an id from it, if leader_id is None.
524        info!("attempting to join peer cluster at {}", leader_addr);
525        let (leader_id, leader_addr): (u64, String) = if let Some(leader_id) = leader_id {
526            (leader_id, leader_addr)
527        } else {
528            self.request_leader(leader_addr)
529                .await?
530                .ok_or(Error::JoinError)?
531        };
532
533        // 2. run server and node to prepare for joining
534        let mut node = RaftNode::new_follower(
535            self.rx,
536            self.tx.clone(),
537            node_id,
538            self.store,
539            &self.logger,
540            self.cfg.clone(),
541        )?;
542        let peer = node.add_peer(&leader_addr, leader_id);
543        let mut client = peer.client().await?;
544        let server = RaftServer::new(self.tx, self.laddr, self.cfg.clone());
545        let server_handle = async {
546            if let Err(e) = server.run().await {
547                warn!("raft server run error: {:?}", e);
548                Err(e)
549            } else {
550                Ok(())
551            }
552        };
553
554        let node_handle = async {
555            tokio::time::sleep(Duration::from_millis(1500)).await;
556            //try remove from the cluster
557            let mut change_remove = ConfChange::default();
558            change_remove.set_node_id(node_id);
559            change_remove.set_change_type(ConfChangeType::RemoveNode);
560            change_remove.set_context(serialize(&RemoveNodeType::Stale)?);
561            let change_remove = RiteraftConfChange {
562                inner: ConfChange::encode_to_vec(&change_remove),
563            };
564
565            let raft_response = client
566                .change_config(Request::new(change_remove))
567                .await?
568                .into_inner();
569
570            info!(
571                "change_remove raft_response: {:?}",
572                deserialize::<RaftResponse>(&raft_response.inner)?
573            );
574
575            // 3. Join the cluster
576            // TODO: handle wrong leader
577            let mut change = ConfChange::default();
578            change.set_node_id(node_id);
579            change.set_change_type(ConfChangeType::AddNode);
580            change.set_context(serialize(&node_addr)?);
581            // change.set_context(serialize(&node_addr)?);
582
583            let change = RiteraftConfChange {
584                inner: ConfChange::encode_to_vec(&change),
585            };
586            let raft_response = client
587                .change_config(Request::new(change))
588                .await?
589                .into_inner();
590            if let RaftResponse::JoinSuccess {
591                assigned_id,
592                peer_addrs,
593            } = deserialize(&raft_response.inner)?
594            {
595                info!(
596                    "change_config response.assigned_id: {:?}, peer_addrs: {:?}",
597                    assigned_id, peer_addrs
598                );
599                for (id, addr) in peer_addrs {
600                    if id != assigned_id {
601                        node.add_peer(&addr, id);
602                    }
603                }
604            } else {
605                return Err(Error::JoinError);
606            }
607
608            if let Err(e) = node.run().await {
609                warn!("node run error: {:?}", e);
610                Err(e)
611            } else {
612                Ok(())
613            }
614        };
615        let _ = tokio::try_join!(server_handle, node_handle)?;
616        info!("leaving follower node");
617        Ok(())
618    }
619}