1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use bytestring::ByteString;
5use serde::de::{self, Deserializer};
6use serde::ser::Serializer;
7use serde::{Deserialize, Serialize};
8
9use futures::channel::oneshot::Sender;
10use tikv_raft::eraftpb::{ConfChange, Message as RaftMessage};
11use tikv_raft::prelude::Snapshot;
12use tikv_raft::StateRole;
13
14#[derive(Serialize, Deserialize, Debug)]
16pub enum RaftResponse {
17 WrongLeader {
19 leader_id: u64,
20 leader_addr: Option<String>,
21 },
22 JoinSuccess {
24 assigned_id: u64,
25 peer_addrs: HashMap<u64, String>,
26 },
27 RequestId { leader_id: u64 },
29 Error(String),
31 Busy,
33 Response { data: Vec<u8> },
35 Status(Status),
37 Ok,
39}
40
41#[allow(dead_code)]
43pub enum Message {
44 Propose {
46 proposal: Vec<u8>,
47 chan: Sender<RaftResponse>,
48 },
49 Query {
51 query: Vec<u8>,
52 chan: Sender<RaftResponse>,
53 },
54 ConfigChange {
56 change: ConfChange,
57 chan: Sender<RaftResponse>,
58 },
59 RequestId { chan: Sender<RaftResponse> },
61 ReportUnreachable { node_id: u64 },
63 Raft(Box<RaftMessage>),
65 Status { chan: Sender<RaftResponse> },
67 Snapshot { snapshot: Snapshot },
69}
70
71#[derive(Serialize, Deserialize, Debug, Clone)]
72pub struct PeerState {
73 pub addr: ByteString,
74 pub available: bool,
75}
76
77#[derive(Serialize, Deserialize, Debug, Clone)]
79pub struct Status {
80 pub id: u64,
81 pub leader_id: u64,
82 pub uncommitteds: usize,
83 pub merger_proposals: usize,
84 pub sending_raft_messages: isize,
85 pub timeout_max: isize,
86 pub timeout_recent_count: isize,
87 pub propose_count: isize,
88 pub propose_rate: f64,
89 pub peers: HashMap<u64, Option<PeerState>>,
90 #[serde(
91 serialize_with = "Status::serialize_role",
92 deserialize_with = "Status::deserialize_role"
93 )]
94 pub role: StateRole,
95}
96
97impl Status {
98 #[inline]
99 pub fn available(&self) -> bool {
100 if matches!(self.role, StateRole::Leader) {
101 let (all_count, available_count) = self.get_count();
103 let available = available_count >= ((all_count / 2) + (all_count % 2));
104 log::debug!(
105 "is Leader, all_count: {}, available_count: {} {}",
106 all_count,
107 available_count,
108 available
109 );
110 available
111 } else if self.leader_id > 0 {
112 let available = self
114 .peers
115 .get(&self.leader_id)
116 .and_then(|p| p.as_ref().map(|p| p.available))
117 .unwrap_or_default();
118 log::debug!("has Leader, available: {}", available);
119 available
120 } else {
121 let (all_count, available_count) = self.get_count();
124 let available = available_count >= ((all_count / 2) + (all_count % 2));
125 log::debug!(
126 "no Leader, all_count: {}, available_count: {} {}",
127 all_count,
128 available_count,
129 available
130 );
131 available
132 }
133 }
134
135 #[inline]
136 fn get_count(&self) -> (usize, usize) {
137 let available_count = self
138 .peers
139 .iter()
140 .filter(|(_, p)| if let Some(p) = p { p.available } else { false })
141 .count();
142 if self.peers.contains_key(&self.id) {
143 (self.peers.len() - 1, available_count - 1)
144 } else {
145 (self.peers.len(), available_count)
146 }
147 }
148
149 #[inline]
151 pub fn is_started(&self) -> bool {
152 self.leader_id > 0
153 }
154
155 #[inline]
157 pub fn is_leader(&self) -> bool {
158 self.leader_id == self.id && matches!(self.role, StateRole::Leader)
159 }
160
161 #[inline]
162 pub fn deserialize_role<'de, D>(deserializer: D) -> Result<StateRole, D::Error>
163 where
164 D: Deserializer<'de>,
165 {
166 let role = match u8::deserialize(deserializer)? {
167 1 => StateRole::Follower,
168 2 => StateRole::Candidate,
169 3 => StateRole::Leader,
170 4 => StateRole::PreCandidate,
171 _ => return Err(de::Error::missing_field("role")),
172 };
173 Ok(role)
174 }
175
176 #[inline]
177 pub fn serialize_role<S>(role: &StateRole, s: S) -> std::result::Result<S::Ok, S::Error>
178 where
179 S: Serializer,
180 {
181 match role {
182 StateRole::Follower => 1u8,
183 StateRole::Candidate => 2u8,
184 StateRole::Leader => 3u8,
185 StateRole::PreCandidate => 4u8,
186 }
187 .serialize(s)
188 }
189}
190
191#[derive(Clone, Serialize, Deserialize, Debug)]
192pub(crate) enum RemoveNodeType {
193 Normal,
194 Stale,
195}
196
197pub(crate) enum ReplyChan {
199 One((Sender<RaftResponse>, Instant)),
201 More(Vec<(Sender<RaftResponse>, Instant)>),
203}
204
205#[derive(Serialize, Deserialize)]
207pub(crate) enum Proposals {
208 One(Vec<u8>),
210 More(Vec<Vec<u8>>),
212}
213
214pub(crate) struct Merger {
216 proposals: Vec<Vec<u8>>,
217 chans: Vec<(Sender<RaftResponse>, Instant)>,
218 start_collection_time: i64,
219 proposal_batch_size: usize,
220 proposal_batch_timeout: i64,
221}
222
223impl Merger {
224 pub fn new(proposal_batch_size: usize, proposal_batch_timeout: Duration) -> Self {
233 Self {
234 proposals: Vec::new(),
235 chans: Vec::new(),
236 start_collection_time: 0,
237 proposal_batch_size,
238 proposal_batch_timeout: proposal_batch_timeout.as_millis() as i64,
239 }
240 }
241
242 #[inline]
248 pub fn add(&mut self, proposal: Vec<u8>, chan: Sender<RaftResponse>) {
249 self.proposals.push(proposal);
250 self.chans.push((chan, Instant::now()));
251 }
252
253 #[inline]
258 pub fn len(&self) -> usize {
259 self.proposals.len()
260 }
261
262 #[inline]
267 pub fn take(&mut self) -> Option<(Proposals, ReplyChan)> {
268 let max = self.proposal_batch_size;
269 let len = self.len();
270 let len = if len > max { max } else { len };
271 if len > 0 && (len == max || self.timeout()) {
272 let data = if len == 1 {
273 match (self.proposals.pop(), self.chans.pop()) {
274 (Some(proposal), Some(chan)) => {
275 Some((Proposals::One(proposal), ReplyChan::One(chan)))
276 }
277 _ => unreachable!(),
278 }
279 } else {
280 let mut proposals = self.proposals.drain(0..len).collect::<Vec<_>>();
281 let mut chans = self.chans.drain(0..len).collect::<Vec<_>>();
282 proposals.reverse();
283 chans.reverse();
284 Some((Proposals::More(proposals), ReplyChan::More(chans)))
285 };
286 self.start_collection_time = chrono::Local::now().timestamp_millis();
287 data
288 } else {
289 None
290 }
291 }
292
293 #[inline]
294 fn timeout(&self) -> bool {
295 chrono::Local::now().timestamp_millis()
296 > (self.start_collection_time + self.proposal_batch_timeout)
297 }
298}
299
300#[tokio::test]
301async fn test_merger() -> std::result::Result<(), Box<dyn std::error::Error>> {
302 let mut merger = Merger::new(50, Duration::from_millis(200));
303 use futures::channel::oneshot::channel;
304 use std::time::Duration;
305
306 let add = |merger: &mut Merger| {
307 let (tx, rx) = channel();
308 merger.add(vec![1, 2, 3], tx);
309 rx
310 };
311
312 use std::sync::atomic::{AtomicI64, Ordering};
313 use std::sync::Arc;
314 const MAX: i64 = 111;
315 let count = Arc::new(AtomicI64::new(0));
316 let mut futs = Vec::new();
317 for _ in 0..MAX {
318 let rx = add(&mut merger);
319 let count1 = count.clone();
320 let fut = async move {
321 let r = tokio::time::timeout(Duration::from_secs(3), rx).await;
322 match r {
323 Ok(_) => {}
324 Err(_) => {
325 println!("timeout ...");
326 }
327 }
328 count1.fetch_add(1, Ordering::SeqCst);
329 };
330
331 futs.push(fut);
332 }
333
334 let sends = async {
335 loop {
336 if let Some((_data, chan)) = merger.take() {
337 match chan {
338 ReplyChan::One((tx, _)) => {
339 let _ = tx.send(RaftResponse::Ok);
340 }
341 ReplyChan::More(txs) => {
342 for (tx, _) in txs {
343 let _ = tx.send(RaftResponse::Ok);
344 }
345 }
346 }
347 }
348 tokio::time::sleep(Duration::from_millis(100)).await;
349 if merger.len() == 0 {
350 break;
351 }
352 }
353 };
354
355 let count_p = count.clone();
356 let count_print = async move {
357 loop {
358 tokio::time::sleep(Duration::from_secs(2)).await;
359 println!("count_p: {}", count_p.load(Ordering::SeqCst));
360 if count_p.load(Ordering::SeqCst) >= MAX {
361 break;
362 }
363 }
364 };
365 println!("futs: {}", futs.len());
366 futures::future::join3(futures::future::join_all(futs), sends, count_print).await;
367
368 Ok(())
369}