rmqtt_raft/
message.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use bytestring::ByteString;
5use serde::de::{self, Deserializer};
6use serde::ser::Serializer;
7use serde::{Deserialize, Serialize};
8
9use futures::channel::oneshot::Sender;
10use tikv_raft::eraftpb::{ConfChange, Message as RaftMessage};
11use tikv_raft::prelude::Snapshot;
12use tikv_raft::StateRole;
13
14/// Enumeration representing various types of responses that can be sent back to clients.
15#[derive(Serialize, Deserialize, Debug)]
16pub enum RaftResponse {
17    /// Indicates that the request was sent to the wrong leader.
18    WrongLeader {
19        leader_id: u64,
20        leader_addr: Option<String>,
21    },
22    /// Indicates that a join request was successful.
23    JoinSuccess {
24        assigned_id: u64,
25        peer_addrs: HashMap<u64, String>,
26    },
27    /// Contains the leader ID in response to a request for ID.
28    RequestId { leader_id: u64 },
29    /// Represents an error with a message.
30    Error(String),
31    /// Too busy
32    Busy,
33    /// Contains arbitrary response data.
34    Response { data: Vec<u8> },
35    /// Represents the status of the system.
36    Status(Status),
37    /// Represents a successful operation.
38    Ok,
39}
40
41/// Enumeration representing different types of messages that can be sent within the system.
42#[allow(dead_code)]
43pub enum Message {
44    /// A proposal message to be processed.
45    Propose {
46        proposal: Vec<u8>,
47        chan: Sender<RaftResponse>,
48    },
49    /// A query message to be processed.
50    Query {
51        query: Vec<u8>,
52        chan: Sender<RaftResponse>,
53    },
54    /// A configuration change message to be processed.
55    ConfigChange {
56        change: ConfChange,
57        chan: Sender<RaftResponse>,
58    },
59    /// A request for the leader's ID.
60    RequestId { chan: Sender<RaftResponse> },
61    /// Report that a node is unreachable.
62    ReportUnreachable { node_id: u64 },
63    /// A Raft message to be processed.
64    Raft(Box<RaftMessage>),
65    /// A request for the status of the system.
66    Status { chan: Sender<RaftResponse> },
67    /// Snapshot
68    Snapshot { snapshot: Snapshot },
69}
70
71#[derive(Serialize, Deserialize, Debug, Clone)]
72pub struct PeerState {
73    pub addr: ByteString,
74    pub available: bool,
75}
76
77/// Struct representing the status of the system.
78#[derive(Serialize, Deserialize, Debug, Clone)]
79pub struct Status {
80    pub id: u64,
81    pub leader_id: u64,
82    pub uncommitteds: usize,
83    pub merger_proposals: usize,
84    pub sending_raft_messages: isize,
85    pub timeout_max: isize,
86    pub timeout_recent_count: isize,
87    pub propose_count: isize,
88    pub propose_rate: f64,
89    pub peers: HashMap<u64, Option<PeerState>>,
90    #[serde(
91        serialize_with = "Status::serialize_role",
92        deserialize_with = "Status::deserialize_role"
93    )]
94    pub role: StateRole,
95}
96
97impl Status {
98    #[inline]
99    pub fn available(&self) -> bool {
100        if matches!(self.role, StateRole::Leader) {
101            //Check if the number of available nodes is greater than or equal to half of the total nodes.
102            let (all_count, available_count) = self.get_count();
103            let available = available_count >= ((all_count / 2) + (all_count % 2));
104            log::debug!(
105                "is Leader, all_count: {}, available_count: {} {}",
106                all_count,
107                available_count,
108                available
109            );
110            available
111        } else if self.leader_id > 0 {
112            //As long as a leader exists and is available, the system considers itself in a normal state.
113            let available = self
114                .peers
115                .get(&self.leader_id)
116                .and_then(|p| p.as_ref().map(|p| p.available))
117                .unwrap_or_default();
118            log::debug!("has Leader, available: {}", available);
119            available
120        } else {
121            //If there is no Leader, it's still necessary to check whether the number of all other
122            // available nodes is greater than or equal to half.
123            let (all_count, available_count) = self.get_count();
124            let available = available_count >= ((all_count / 2) + (all_count % 2));
125            log::debug!(
126                "no Leader, all_count: {}, available_count: {} {}",
127                all_count,
128                available_count,
129                available
130            );
131            available
132        }
133    }
134
135    #[inline]
136    fn get_count(&self) -> (usize, usize) {
137        let available_count = self
138            .peers
139            .iter()
140            .filter(|(_, p)| if let Some(p) = p { p.available } else { false })
141            .count();
142        if self.peers.contains_key(&self.id) {
143            (self.peers.len() - 1, available_count - 1)
144        } else {
145            (self.peers.len(), available_count)
146        }
147    }
148
149    /// Checks if the node has started.
150    #[inline]
151    pub fn is_started(&self) -> bool {
152        self.leader_id > 0
153    }
154
155    /// Checks if this node is the leader.
156    #[inline]
157    pub fn is_leader(&self) -> bool {
158        self.leader_id == self.id && matches!(self.role, StateRole::Leader)
159    }
160
161    #[inline]
162    pub fn deserialize_role<'de, D>(deserializer: D) -> Result<StateRole, D::Error>
163    where
164        D: Deserializer<'de>,
165    {
166        let role = match u8::deserialize(deserializer)? {
167            1 => StateRole::Follower,
168            2 => StateRole::Candidate,
169            3 => StateRole::Leader,
170            4 => StateRole::PreCandidate,
171            _ => return Err(de::Error::missing_field("role")),
172        };
173        Ok(role)
174    }
175
176    #[inline]
177    pub fn serialize_role<S>(role: &StateRole, s: S) -> std::result::Result<S::Ok, S::Error>
178    where
179        S: Serializer,
180    {
181        match role {
182            StateRole::Follower => 1u8,
183            StateRole::Candidate => 2u8,
184            StateRole::Leader => 3u8,
185            StateRole::PreCandidate => 4u8,
186        }
187        .serialize(s)
188    }
189}
190
191#[derive(Clone, Serialize, Deserialize, Debug)]
192pub(crate) enum RemoveNodeType {
193    Normal,
194    Stale,
195}
196
197/// Enumeration for reply channels which could be single or multiple.
198pub(crate) enum ReplyChan {
199    /// Single reply channel with its timestamp.
200    One((Sender<RaftResponse>, Instant)),
201    /// Multiple reply channels with their timestamps.
202    More(Vec<(Sender<RaftResponse>, Instant)>),
203}
204
205/// Enumeration for proposals which could be a single proposal or multiple proposals.
206#[derive(Serialize, Deserialize)]
207pub(crate) enum Proposals {
208    /// A single proposal.
209    One(Vec<u8>),
210    /// Multiple proposals.
211    More(Vec<Vec<u8>>),
212}
213
214/// A struct to manage proposal batching and sending.
215pub(crate) struct Merger {
216    proposals: Vec<Vec<u8>>,
217    chans: Vec<(Sender<RaftResponse>, Instant)>,
218    start_collection_time: i64,
219    proposal_batch_size: usize,
220    proposal_batch_timeout: i64,
221}
222
223impl Merger {
224    /// Creates a new `Merger` instance with the specified batch size and timeout.
225    ///
226    /// # Parameters
227    /// - `proposal_batch_size`: The maximum number of proposals to include in a batch.
228    /// - `proposal_batch_timeout`: The timeout duration for collecting proposals.
229    ///
230    /// # Returns
231    /// A new `Merger` instance.
232    pub fn new(proposal_batch_size: usize, proposal_batch_timeout: Duration) -> Self {
233        Self {
234            proposals: Vec::new(),
235            chans: Vec::new(),
236            start_collection_time: 0,
237            proposal_batch_size,
238            proposal_batch_timeout: proposal_batch_timeout.as_millis() as i64,
239        }
240    }
241
242    /// Adds a new proposal and its corresponding reply channel to the merger.
243    ///
244    /// # Parameters
245    /// - `proposal`: The proposal data to be added.
246    /// - `chan`: The reply channel for the proposal.
247    #[inline]
248    pub fn add(&mut self, proposal: Vec<u8>, chan: Sender<RaftResponse>) {
249        self.proposals.push(proposal);
250        self.chans.push((chan, Instant::now()));
251    }
252
253    /// Returns the number of proposals currently held by the merger.
254    ///
255    /// # Returns
256    /// The number of proposals.
257    #[inline]
258    pub fn len(&self) -> usize {
259        self.proposals.len()
260    }
261
262    /// Retrieves a batch of proposals and their corresponding reply channels if the batch size or timeout criteria are met.
263    ///
264    /// # Returns
265    /// An `Option` containing the proposals and reply channels, or `None` if no batch is ready.
266    #[inline]
267    pub fn take(&mut self) -> Option<(Proposals, ReplyChan)> {
268        let max = self.proposal_batch_size;
269        let len = self.len();
270        let len = if len > max { max } else { len };
271        if len > 0 && (len == max || self.timeout()) {
272            let data = if len == 1 {
273                match (self.proposals.pop(), self.chans.pop()) {
274                    (Some(proposal), Some(chan)) => {
275                        Some((Proposals::One(proposal), ReplyChan::One(chan)))
276                    }
277                    _ => unreachable!(),
278                }
279            } else {
280                let mut proposals = self.proposals.drain(0..len).collect::<Vec<_>>();
281                let mut chans = self.chans.drain(0..len).collect::<Vec<_>>();
282                proposals.reverse();
283                chans.reverse();
284                Some((Proposals::More(proposals), ReplyChan::More(chans)))
285            };
286            self.start_collection_time = chrono::Local::now().timestamp_millis();
287            data
288        } else {
289            None
290        }
291    }
292
293    #[inline]
294    fn timeout(&self) -> bool {
295        chrono::Local::now().timestamp_millis()
296            > (self.start_collection_time + self.proposal_batch_timeout)
297    }
298}
299
300#[tokio::test]
301async fn test_merger() -> std::result::Result<(), Box<dyn std::error::Error>> {
302    let mut merger = Merger::new(50, Duration::from_millis(200));
303    use futures::channel::oneshot::channel;
304    use std::time::Duration;
305
306    let add = |merger: &mut Merger| {
307        let (tx, rx) = channel();
308        merger.add(vec![1, 2, 3], tx);
309        rx
310    };
311
312    use std::sync::atomic::{AtomicI64, Ordering};
313    use std::sync::Arc;
314    const MAX: i64 = 111;
315    let count = Arc::new(AtomicI64::new(0));
316    let mut futs = Vec::new();
317    for _ in 0..MAX {
318        let rx = add(&mut merger);
319        let count1 = count.clone();
320        let fut = async move {
321            let r = tokio::time::timeout(Duration::from_secs(3), rx).await;
322            match r {
323                Ok(_) => {}
324                Err(_) => {
325                    println!("timeout ...");
326                }
327            }
328            count1.fetch_add(1, Ordering::SeqCst);
329        };
330
331        futs.push(fut);
332    }
333
334    let sends = async {
335        loop {
336            if let Some((_data, chan)) = merger.take() {
337                match chan {
338                    ReplyChan::One((tx, _)) => {
339                        let _ = tx.send(RaftResponse::Ok);
340                    }
341                    ReplyChan::More(txs) => {
342                        for (tx, _) in txs {
343                            let _ = tx.send(RaftResponse::Ok);
344                        }
345                    }
346                }
347            }
348            tokio::time::sleep(Duration::from_millis(100)).await;
349            if merger.len() == 0 {
350                break;
351            }
352        }
353    };
354
355    let count_p = count.clone();
356    let count_print = async move {
357        loop {
358            tokio::time::sleep(Duration::from_secs(2)).await;
359            println!("count_p: {}", count_p.load(Ordering::SeqCst));
360            if count_p.load(Ordering::SeqCst) >= MAX {
361                break;
362            }
363        }
364    };
365    println!("futs: {}", futs.len());
366    futures::future::join3(futures::future::join_all(futs), sends, count_print).await;
367
368    Ok(())
369}