use crate::error::{Error, Result};
use crate::message::{Message, RaftResponse};
use crate::raft_node::RaftNode;
use crate::raft_server::RaftServer;
use crate::raft_service::raft_service_client::RaftServiceClient;
use crate::raft_service::{RequestIdArgs, ResultCode};
use async_trait::async_trait;
use bincode::{deserialize, serialize};
use log::{info, warn};
use raft::eraftpb::{ConfChange, ConfChangeType};
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use tonic::Request;
use std::collections::HashMap;
use std::time::Duration;
/// Application state machine that a Raft node drives.
///
/// Methods are `async` through `async_trait`. A `Store` value is moved into the Raft node
/// when [`Raft::lead`] or [`Raft::join`] starts it.
#[async_trait]
pub trait Store {
/// Applies one entry payload to the state machine and returns a result payload.
///
/// NOTE(review): presumably `message` is the byte payload proposed via `Mailbox::send`,
/// and the returned bytes are what `send` hands back to the proposer — confirm in RaftNode.
async fn apply(&mut self, message: &[u8]) -> Result<Vec<u8>>;
/// Serializes the current state into a byte snapshot.
async fn snapshot(&self) -> Result<Vec<u8>>;
/// Replaces the current state with the contents of a snapshot.
///
/// NOTE(review): assumes `snapshot` is bytes previously produced by [`Store::snapshot`] — confirm.
async fn restore(&mut self, snapshot: &[u8]) -> Result<()>;
}
#[derive(Clone)]
pub struct Mailbox(mpsc::Sender<Message>);
impl Mailbox {
pub async fn send(&self, message: Vec<u8>) -> Result<Vec<u8>> {
let (tx, rx) = oneshot::channel();
let proposal = Message::Propose {
proposal: message,
chan: tx,
};
let sender = self.0.clone();
match sender.send(proposal).await {
Ok(_) => match timeout(Duration::from_secs(2), rx).await {
Ok(Ok(RaftResponse::Response { data })) => Ok(data),
_ => Err(Error::Unknown),
},
_ => Err(Error::Unknown),
}
}
pub async fn leave(&self) -> Result<()> {
let mut change = ConfChange::default();
change.set_node_id(0);
change.set_change_type(ConfChangeType::RemoveNode);
let sender = self.0.clone();
let (chan, rx) = oneshot::channel();
match sender.send(Message::ConfigChange { change, chan }).await {
Ok(_) => match rx.await {
Ok(RaftResponse::Ok) => Ok(()),
_ => Err(Error::Unknown),
},
_ => Err(Error::Unknown),
}
}
}
/// A Raft participant backed by an application [`Store`].
///
/// Create it with [`Raft::new`], hand out [`Mailbox`]es with [`Raft::mailbox`], then consume it
/// with either [`Raft::lead`] or [`Raft::join`].
pub struct Raft<S: Store + 'static> {
    /// Application state machine; moved into the Raft node when it is started.
    store: S,
    /// Sending half of the node's message channel; cloned into mailboxes and the gRPC server.
    tx: mpsc::Sender<Message>,
    /// Receiving half of the node's message channel; moved into the Raft node.
    rx: mpsc::Receiver<Message>,
    /// Address given to this node's gRPC server and sent to the cluster when joining.
    addr: String,
    /// Logger passed to the Raft node.
    logger: slog::Logger,
}

impl<S: Store + Send + Sync + 'static> Raft<S> {
    /// Creates a not-yet-started participant that will serve on `addr`.
    ///
    /// The internal message channel is bounded to 100 pending messages.
    pub fn new(addr: String, store: S, logger: slog::Logger) -> Self {
        let (tx, rx) = mpsc::channel(100);
        Self {
            store,
            tx,
            rx,
            addr,
            logger,
        }
    }

    /// Returns a new [`Mailbox`] connected to this participant's message channel.
    pub fn mailbox(&self) -> Mailbox {
        Mailbox(self.tx.clone())
    }

    /// Starts this participant as the initial leader (via `RaftNode::new_leader`) and runs it
    /// until the node task finishes.
    ///
    /// The gRPC server task is detached. This always returns `Ok(())`. If the node task ends
    /// abnormally (panic or cancellation), that is logged.
    ///
    /// NOTE(review): the value returned by the node's `run` future is not inspected.
    pub async fn lead(self) -> Result<()> {
        let node = RaftNode::new_leader(self.rx, self.tx.clone(), self.store, &self.logger);
        let server = RaftServer::new(self.tx, self.addr);
        // Dropping the JoinHandle detaches the server task; it is not awaited here.
        let _server_handle = tokio::spawn(server.run());
        let node_handle = tokio::spawn(node.run());
        // Wait for the node task. Surface a panic/cancellation instead of discarding it.
        if let Err(e) = node_handle.await {
            warn!("raft node task terminated abnormally: {}", e);
        }
        warn!("leaving leader node");
        Ok(())
    }

    /// Joins an existing cluster through the node at `addr`, then runs this participant as a
    /// follower until its node task finishes.
    ///
    /// The contacted node is asked for a node id. `WrongLeader` replies are followed to the
    /// address they carry. With the id in hand, this participant:
    ///
    /// 1. registers the returned peers and the leader,
    /// 2. starts its gRPC server and Raft node, and
    /// 3. asks the leader to apply an `AddNode` configuration change whose context is this
    ///    participant's serialized address.
    ///
    /// # Errors
    ///
    /// Returns `Error::JoinError` if the contacted node answers `ResultCode::Error`. It also
    /// propagates connection, RPC, (de)serialization, follower-construction and
    /// peer-registration failures. If the `AddNode` request fails, the server and node tasks
    /// spawned by this call are aborted before the error is returned.
    ///
    /// # Panics
    ///
    /// Panics if the leader peer cannot be found right after it was registered (an internal
    /// invariant violation).
    pub async fn join(self, addr: String) -> Result<()> {
        info!("attempting to join peer cluster at {}", addr);
        // NOTE(review): redirects are followed without a bound. Nodes that redirect to each
        // other would keep this loop spinning.
        let mut leader_addr = addr;
        let (leader_id, node_id, peer_addrs): (u64, u64, HashMap<u64, String>) = loop {
            let mut client = RaftServiceClient::connect(format!("http://{}", leader_addr)).await?;
            let response = client
                .request_id(Request::new(RequestIdArgs {
                    addr: self.addr.clone(),
                }))
                .await?
                .into_inner();
            match response.code() {
                ResultCode::WrongLeader => {
                    info!("this is the wrong leader");
                    let (_leader_id, redirect_addr): (u64, String) =
                        deserialize(&response.data)?;
                    leader_addr = redirect_addr;
                    info!("Wrong leader, retrying with leader at {}", leader_addr);
                }
                ResultCode::Ok => break deserialize(&response.data)?,
                ResultCode::Error => return Err(Error::JoinError),
            }
        };
        info!("obtained ID from leader: {}", node_id);

        let mut node =
            RaftNode::new_follower(self.rx, self.tx.clone(), node_id, self.store, &self.logger)?;
        for (&id, peer_addr) in &peer_addrs {
            node.add_peer(peer_addr, id).await?;
        }
        node.add_peer(&leader_addr, leader_id).await?;
        let mut client = node
            .peer_mut(leader_id)
            .expect("leader peer was registered by add_peer just above")
            .clone();

        // Build the membership change before spawning anything, so a serialization failure
        // cannot leave background tasks behind.
        let mut change = ConfChange::default();
        change.set_node_id(node_id);
        change.set_change_type(ConfChangeType::AddNode);
        change.set_context(serialize(&self.addr)?);

        let server = RaftServer::new(self.tx, self.addr);
        let server_handle = tokio::spawn(server.run());
        let node_handle = tokio::spawn(node.run());

        // If the leader rejects or cannot receive the AddNode request, stop the tasks spawned
        // above. Otherwise they would keep running detached after this call reports failure
        // (the server presumably still holding its address).
        if let Err(e) = client.change_config(Request::new(change)).await {
            node_handle.abort();
            server_handle.abort();
            return Err(e.into());
        }

        if let Err(e) = node_handle.await {
            warn!("raft node task terminated abnormally: {}", e);
        }
        Ok(())
    }
}