1use std::net::{SocketAddr, ToSocketAddrs};
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use bincode::{deserialize, serialize};
7use bytestring::ByteString;
8use futures::channel::{mpsc, oneshot};
9use futures::future::FutureExt;
10use futures::SinkExt;
11use log::{debug, info, warn};
12use prost::Message as _;
13use tikv_raft::eraftpb::{ConfChange, ConfChangeType};
14use tokio::sync::RwLock;
15use tokio::time::timeout;
16use tonic::Request;
17
18use crate::error::{Error, Result};
19use crate::message::{Message, RaftResponse, RemoveNodeType, Status};
20use crate::raft_node::{Peer, RaftNode};
21use crate::raft_server::RaftServer;
22use crate::raft_service::connect;
23use crate::raft_service::{ConfChange as RiteraftConfChange, Empty, ResultCode};
24use crate::Config;
25
/// `dashmap::DashMap` fixed to the `ahash::RandomState` hasher for every map in this module.
type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
27
/// Application-provided storage that a [`Raft`] instance is parameterised over.
///
/// Implementations must be `Clone + Send + Sync`. All payloads cross this
/// boundary as opaque byte slices; their encoding is up to the implementor.
#[async_trait]
pub trait Store: Clone + Send + Sync {
    /// Applies `message` to the store and returns the resulting response bytes.
    // NOTE(review): presumably invoked for committed proposals — confirm in `RaftNode`.
    async fn apply(&mut self, message: &[u8]) -> Result<Vec<u8>>;
    /// Answers a read-only `query` (takes `&self`) and returns the response bytes.
    async fn query(&self, query: &[u8]) -> Result<Vec<u8>>;
    /// Serialises the current state of the store into a snapshot.
    async fn snapshot(&self) -> Result<Vec<u8>>;
    /// Replaces the state of the store with the contents of `snapshot`.
    async fn restore(&mut self, snapshot: &[u8]) -> Result<()>;
}
35
36struct ProposalSender {
37 proposal: Vec<u8>,
38 client: Peer,
39}
40
41impl ProposalSender {
42 async fn send(self) -> Result<RaftResponse> {
43 match self.client.send_proposal(self.proposal).await {
44 Ok(reply) => {
45 let raft_response: RaftResponse = deserialize(&reply)?;
46 Ok(raft_response)
47 }
48 Err(e) => {
49 warn!("error sending proposal {:?}", e);
50 Err(e)
51 }
52 }
53 }
54}
55
/// Result of asking the local Raft node who the current leader is.
#[derive(Clone)]
struct LeaderInfo {
    /// `true` when the local node answered as the leader itself.
    leader: bool,
    /// Leader id reported by the local node (proposals are not forwarded when it is `0`).
    target_leader_id: u64,
    /// Address of the leader when the local node is not the leader and knows
    /// where the leader is; `None` when the local node is the leader.
    target_leader_addr: Option<String>,
}
62
/// Textual form of a failed leader lookup, kept alongside the cached lookup result.
type LeaderInfoError = ByteString;
64
65#[derive(Clone)]
67pub struct Mailbox {
68 peers: Arc<DashMap<(u64, String), Peer>>,
69 sender: mpsc::Sender<Message>,
70 grpc_timeout: Duration,
71 grpc_concurrency_limit: usize,
72 grpc_message_size: usize,
73 grpc_breaker_threshold: u64,
74 grpc_breaker_retry_interval: i64,
75 #[allow(clippy::type_complexity)]
76 leader_info: Arc<
77 RwLock<
78 Option<(
79 Option<LeaderInfo>,
80 Option<LeaderInfoError>,
81 std::time::Instant,
82 )>,
83 >,
84 >,
85}
86
87impl Mailbox {
88 #[inline]
89 pub(crate) fn new(
90 peers: Arc<DashMap<(u64, String), Peer>>,
91 sender: mpsc::Sender<Message>,
92 grpc_timeout: Duration,
93 grpc_concurrency_limit: usize,
94 grpc_message_size: usize,
95 grpc_breaker_threshold: u64,
96 grpc_breaker_retry_interval: i64,
97 ) -> Self {
98 Self {
99 peers,
100 sender,
101 grpc_timeout,
102 grpc_concurrency_limit,
103 grpc_message_size,
104 grpc_breaker_threshold,
105 grpc_breaker_retry_interval,
106 leader_info: Arc::new(RwLock::new(None)),
107 }
108 }
109
110 #[inline]
114 pub fn pears(&self) -> Vec<(u64, Peer)> {
115 self.peers
116 .iter()
117 .map(|p| {
118 let (id, _) = p.key();
119 (*id, p.value().clone())
120 })
121 .collect::<Vec<_>>()
122 }
123
124 #[inline]
125 async fn peer(&self, leader_id: u64, leader_addr: String) -> Peer {
126 self.peers
127 .entry((leader_id, leader_addr.clone()))
128 .or_insert_with(|| {
129 Peer::new(
130 leader_addr,
131 self.grpc_timeout,
132 self.grpc_concurrency_limit,
133 self.grpc_message_size,
134 self.grpc_breaker_threshold,
135 self.grpc_breaker_retry_interval,
136 )
137 })
138 .clone()
139 }
140
141 #[inline]
142 async fn send_to_leader(
143 &self,
144 proposal: Vec<u8>,
145 leader_id: u64,
146 leader_addr: String,
147 ) -> Result<RaftResponse> {
148 let peer = self.peer(leader_id, leader_addr).await;
149 let proposal_sender = ProposalSender {
150 proposal,
151 client: peer,
152 };
153 proposal_sender.send().await
154 }
155
156 #[inline]
161 pub async fn send_proposal(&self, message: Vec<u8>) -> Result<Vec<u8>> {
162 match self.get_leader_info().await? {
163 LeaderInfo { leader: true, .. } => {
164 debug!("this node is leader");
165 let (tx, rx) = oneshot::channel();
166 let proposal = Message::Propose {
167 proposal: message.clone(),
168 chan: tx,
169 };
170 let mut sender = self.sender.clone();
171 sender
172 .send(proposal)
173 .await .map_err(|e| Error::SendError(e.to_string()))?;
175 let reply = timeout(self.grpc_timeout, rx).await;
176 let reply = reply
177 .map_err(|e| Error::RecvError(e.to_string()))?
178 .map_err(|e| Error::RecvError(e.to_string()))?;
179 match reply {
180 RaftResponse::Response { data } => return Ok(data),
181 RaftResponse::Busy => return Err(Error::Busy),
182 RaftResponse::Error(e) => return Err(Error::from(e)),
183 _ => {
184 warn!("Recv other raft response: {:?}", reply);
185 return Err(Error::Unknown);
186 }
187 }
188 }
189 LeaderInfo {
190 leader: false,
191 target_leader_id,
192 target_leader_addr,
193 ..
194 } => {
195 debug!(
196 "This node not is Leader, leader_id: {:?}, leader_addr: {:?}",
197 target_leader_id, target_leader_addr
198 );
199 if let Some(target_leader_addr) = target_leader_addr {
200 if target_leader_id != 0 {
201 return match self
202 .send_to_leader(message, target_leader_id, target_leader_addr.clone())
203 .await?
204 {
205 RaftResponse::Response { data } => return Ok(data),
206 RaftResponse::WrongLeader {
207 leader_id,
208 leader_addr,
209 } => {
210 warn!("The target node is not the Leader, target_leader_id: {}, target_leader_addr: {:?}, actual_leader_id: {}, actual_leader_addr: {:?}",
211 target_leader_id, target_leader_addr, leader_id, leader_addr);
212 return Err(Error::NotLeader);
213 }
214 RaftResponse::Busy => Err(Error::Busy),
215 RaftResponse::Error(e) => Err(Error::from(e)),
216 _ => {
217 warn!("Recv other raft response, target_leader_id: {}, target_leader_addr: {:?}", target_leader_id, target_leader_addr);
218 return Err(Error::Unknown);
219 }
220 };
221 }
222 }
223 }
224 }
225 Err(Error::LeaderNotExist)
226 }
227
228 #[inline]
230 #[deprecated]
231 pub async fn send(&self, message: Vec<u8>) -> Result<Vec<u8>> {
232 self.send_proposal(message).await
233 }
234
235 #[inline]
239 pub async fn query(&self, query: Vec<u8>) -> Result<Vec<u8>> {
240 let (tx, rx) = oneshot::channel();
241 let mut sender = self.sender.clone();
242 match sender.try_send(Message::Query { query, chan: tx }) {
243 Ok(()) => match timeout(self.grpc_timeout, rx).await {
244 Ok(Ok(RaftResponse::Response { data })) => Ok(data),
245 Ok(Ok(RaftResponse::Error(e))) => Err(Error::from(e)),
246 _ => Err(Error::Unknown),
247 },
248 Err(e) => Err(Error::SendError(e.to_string())),
249 }
250 }
251
252 #[inline]
255 pub async fn leave(&self) -> Result<()> {
256 let mut change = ConfChange::default();
257 change.set_node_id(0);
259 change.set_change_type(ConfChangeType::RemoveNode);
260 change.set_context(serialize(&RemoveNodeType::Normal)?);
261 let mut sender = self.sender.clone();
262 let (chan, rx) = oneshot::channel();
263 match sender.send(Message::ConfigChange { change, chan }).await {
264 Ok(()) => match rx.await {
265 Ok(RaftResponse::Ok) => Ok(()),
266 Ok(RaftResponse::Error(e)) => Err(Error::from(e)),
267 _ => Err(Error::Unknown),
268 },
269 Err(e) => Err(Error::SendError(e.to_string())),
270 }
271 }
272
273 #[inline]
276 pub async fn status(&self) -> Result<Status> {
277 let (tx, rx) = oneshot::channel();
278 let mut sender = self.sender.clone();
279 match sender.send(Message::Status { chan: tx }).await {
280 Ok(_) => match timeout(self.grpc_timeout, rx).await {
281 Ok(Ok(RaftResponse::Status(status))) => Ok(status),
282 Ok(Ok(RaftResponse::Error(e))) => Err(Error::from(e)),
283 _ => Err(Error::Unknown),
284 },
285 Err(e) => Err(Error::SendError(e.to_string())),
286 }
287 }
288
289 #[inline]
292 async fn _get_leader_info(&self) -> std::result::Result<LeaderInfo, LeaderInfoError> {
293 let (tx, rx) = oneshot::channel();
294 let mut sender = self.sender.clone();
295 match sender.send(Message::RequestId { chan: tx }).await {
296 Ok(_) => match timeout(self.grpc_timeout, rx).await {
297 Ok(Ok(RaftResponse::RequestId { leader_id })) => Ok(LeaderInfo {
298 leader: true,
299 target_leader_id: leader_id,
300 target_leader_addr: None,
301 }),
302 Ok(Ok(RaftResponse::WrongLeader {
303 leader_id,
304 leader_addr,
305 })) => Ok(LeaderInfo {
306 leader: false,
307 target_leader_id: leader_id,
308 target_leader_addr: leader_addr,
309 }),
310 Ok(Ok(RaftResponse::Error(e))) => Err(LeaderInfoError::from(e)),
311 _ => Err("Unknown".into()),
312 },
313 Err(e) => Err(LeaderInfoError::from(e.to_string())),
314 }
315 }
316
317 #[inline]
318 async fn get_leader_info(&self) -> Result<LeaderInfo> {
319 {
320 let leader_info = self.leader_info.read().await;
321 if let Some((leader_info, err, inst)) = leader_info.as_ref() {
322 if inst.elapsed().as_secs() < 5 {
323 if let Some(leader_info) = leader_info {
324 return Ok(leader_info.clone());
325 }
326 if let Some(err) = err {
327 return Err(err.to_string().into());
328 }
329 }
330 }
331 }
332
333 let mut write = self.leader_info.write().await;
334
335 return match self._get_leader_info().await {
336 Ok(leader_info) => {
337 write.replace((Some(leader_info.clone()), None, std::time::Instant::now()));
338 Ok(leader_info)
339 }
340 Err(e) => {
341 let err = e.to_string().into();
342 write.replace((None, Some(e), std::time::Instant::now()));
343 Err(err)
344 }
345 };
346 }
347}
348
/// Everything needed to start a Raft node over a user-supplied [`Store`]:
/// consumed by `lead` or `join`.
pub struct Raft<S: Store + 'static> {
    /// Application state machine handed to the node on start.
    store: S,
    /// Sending half of the node's message channel; cloned into mailboxes and the server.
    tx: mpsc::Sender<Message>,
    /// Receiving half of the node's message channel, moved into the node on start.
    rx: mpsc::Receiver<Message>,
    /// Local address passed to `RaftServer::new`.
    laddr: SocketAddr,
    /// Logger passed by reference to the `RaftNode` constructors.
    logger: slog::Logger,
    /// Shared configuration.
    cfg: Arc<Config>,
}
357
358impl<S: Store + Send + Sync + 'static> Raft<S> {
359 pub fn new<A: ToSocketAddrs>(
362 laddr: A,
363 store: S,
364 logger: slog::Logger,
365 cfg: Config,
366 ) -> Result<Self> {
367 let laddr = laddr
368 .to_socket_addrs()?
369 .next()
370 .ok_or_else(|| Error::from("None"))?;
371 let (tx, rx) = mpsc::channel(100_000);
372 let cfg = Arc::new(cfg);
373 Ok(Self {
374 store,
375 tx,
376 rx,
377 laddr,
378 logger,
379 cfg,
380 })
381 }
382
383 pub fn mailbox(&self) -> Mailbox {
385 Mailbox::new(
386 Arc::new(DashMap::default()),
387 self.tx.clone(),
388 self.cfg.grpc_timeout,
389 self.cfg.grpc_concurrency_limit,
390 self.cfg.grpc_message_size,
391 self.cfg.grpc_breaker_threshold,
392 self.cfg.grpc_breaker_retry_interval.as_millis() as i64,
393 )
394 }
395
396 pub async fn find_leader_info(&self, peer_addrs: Vec<String>) -> Result<Option<(u64, String)>> {
399 let mut futs = Vec::new();
400 for addr in peer_addrs {
401 let fut = async {
402 let _addr = addr.clone();
403 match self.request_leader(addr).await {
404 Ok(reply) => Ok(reply),
405 Err(e) => Err(e),
406 }
407 };
408 futs.push(fut.boxed());
409 }
410
411 let (leader_id, leader_addr) = match futures::future::select_ok(futs).await {
412 Ok((Some((leader_id, leader_addr)), _)) => (leader_id, leader_addr),
413 Ok((None, _)) => return Err(Error::LeaderNotExist),
414 Err(_e) => return Ok(None),
415 };
416
417 if leader_id == 0 {
418 Ok(None)
419 } else {
420 Ok(Some((leader_id, leader_addr)))
421 }
422 }
423
424 async fn request_leader(&self, peer_addr: String) -> Result<Option<(u64, String)>> {
427 let (leader_id, leader_addr): (u64, String) = {
428 let mut client = connect(
429 &peer_addr,
430 1,
431 self.cfg.grpc_message_size,
432 self.cfg.grpc_timeout,
433 )
434 .await?;
435 let response = client
436 .request_id(Request::new(Empty::default()))
437 .await?
438 .into_inner();
439 match response.code() {
440 ResultCode::WrongLeader => {
441 let (leader_id, addr): (u64, Option<String>) = deserialize(&response.data)?;
442 if let Some(addr) = addr {
443 (leader_id, addr)
444 } else {
445 return Ok(None);
446 }
447 }
448 ResultCode::Ok => (deserialize(&response.data)?, peer_addr),
449 ResultCode::Error => return Ok(None),
450 }
451 };
452 Ok(Some((leader_id, leader_addr)))
453 }
454
455 pub async fn lead(self, node_id: u64) -> Result<()> {
468 let node = RaftNode::new_leader(
469 self.rx,
470 self.tx.clone(),
471 node_id,
472 self.store,
473 &self.logger,
474 self.cfg.clone(),
475 )
476 .await?;
477
478 let server = RaftServer::new(self.tx, self.laddr, self.cfg.clone());
479 let server_handle = async {
480 if let Err(e) = server.run().await {
481 warn!("raft server run error: {:?}", e);
482 Err(e)
483 } else {
484 Ok(())
485 }
486 };
487 let node_handle = async {
488 if let Err(e) = node.run().await {
489 warn!("node run error: {:?}", e);
490 Err(e)
491 } else {
492 Ok(())
493 }
494 };
495
496 tokio::try_join!(server_handle, node_handle)?;
497 info!("leaving leader node");
498
499 Ok(())
500 }
501
502 pub async fn join(
517 self,
518 node_id: u64,
519 node_addr: String,
520 leader_id: Option<u64>,
521 leader_addr: String,
522 ) -> Result<()> {
523 info!("attempting to join peer cluster at {}", leader_addr);
525 let (leader_id, leader_addr): (u64, String) = if let Some(leader_id) = leader_id {
526 (leader_id, leader_addr)
527 } else {
528 self.request_leader(leader_addr)
529 .await?
530 .ok_or(Error::JoinError)?
531 };
532
533 let mut node = RaftNode::new_follower(
535 self.rx,
536 self.tx.clone(),
537 node_id,
538 self.store,
539 &self.logger,
540 self.cfg.clone(),
541 )?;
542 let peer = node.add_peer(&leader_addr, leader_id);
543 let mut client = peer.client().await?;
544 let server = RaftServer::new(self.tx, self.laddr, self.cfg.clone());
545 let server_handle = async {
546 if let Err(e) = server.run().await {
547 warn!("raft server run error: {:?}", e);
548 Err(e)
549 } else {
550 Ok(())
551 }
552 };
553
554 let node_handle = async {
555 tokio::time::sleep(Duration::from_millis(1500)).await;
556 let mut change_remove = ConfChange::default();
558 change_remove.set_node_id(node_id);
559 change_remove.set_change_type(ConfChangeType::RemoveNode);
560 change_remove.set_context(serialize(&RemoveNodeType::Stale)?);
561 let change_remove = RiteraftConfChange {
562 inner: ConfChange::encode_to_vec(&change_remove),
563 };
564
565 let raft_response = client
566 .change_config(Request::new(change_remove))
567 .await?
568 .into_inner();
569
570 info!(
571 "change_remove raft_response: {:?}",
572 deserialize::<RaftResponse>(&raft_response.inner)?
573 );
574
575 let mut change = ConfChange::default();
578 change.set_node_id(node_id);
579 change.set_change_type(ConfChangeType::AddNode);
580 change.set_context(serialize(&node_addr)?);
581 let change = RiteraftConfChange {
584 inner: ConfChange::encode_to_vec(&change),
585 };
586 let raft_response = client
587 .change_config(Request::new(change))
588 .await?
589 .into_inner();
590 if let RaftResponse::JoinSuccess {
591 assigned_id,
592 peer_addrs,
593 } = deserialize(&raft_response.inner)?
594 {
595 info!(
596 "change_config response.assigned_id: {:?}, peer_addrs: {:?}",
597 assigned_id, peer_addrs
598 );
599 for (id, addr) in peer_addrs {
600 if id != assigned_id {
601 node.add_peer(&addr, id);
602 }
603 }
604 } else {
605 return Err(Error::JoinError);
606 }
607
608 if let Err(e) = node.run().await {
609 warn!("node run error: {:?}", e);
610 Err(e)
611 } else {
612 Ok(())
613 }
614 };
615 let _ = tokio::try_join!(server_handle, node_handle)?;
616 info!("leaving follower node");
617 Ok(())
618 }
619}