use std::fmt::Debug;
use log::{debug, error, info, trace};
use rand::Rng;
use rand_chacha::rand_core::SeedableRng;
use serde::{Deserialize, Serialize};
use dcs::communication::messages::{UpdateClusterVec, Header, Package, PackageBuilder};
use dcs::communication::service::CommunicationService;
use dcs::coordination::{CoordinationPackage, CoordinationService, Stopwatch};
use dcs::heapless;
use dcs::heapless::LinearMap;
use dcs::nodes::SystemNodeId;
use dcs::properties::CLUSTER_NODE_COUNT;
use dcs::rules::measurements::{Measurement, SystemState};
use dcs::rules::strategy::Rule;
use crate::messages::RaftMessage::{AppendLog, ReadRequestReply, WriteRequestReply};
use crate::messages::*;
use crate::metadata::RaftMetadata;
use crate::server::ElectionVote::Abstained;
mod candidate;
mod follower;
mod leader;
/// Types that can produce a placeholder "no operation" value.
///
/// `RaftService::commit_noop` appends an entry built from such a value to the log.
pub trait NoOp {
    /// Returns the placeholder value.
    fn noop() -> Self;
}

/// Types whose values can be combined pairwise.
pub trait Merge {
    /// Combines `self` with `other`, consuming both.
    fn merge(self, other: Self) -> Self;
}
/// Requirements on the payload type carried by Raft log entries.
///
/// This is a bound alias: the blanket impl right after it makes every type that
/// satisfies all of the listed bounds a `LogData` automatically.
pub trait LogData:
    Clone
    + Debug
    + NoOp
    + Merge
    + Serialize
    + From<(SystemNodeId, Measurement)>
    + Into<SystemState>
{
}

impl<D> LogData for D where
    D: Clone + Debug + NoOp + Merge + Serialize + From<(SystemNodeId, Measurement)> + Into<SystemState>
{
}
struct IO<T: Stopwatch> {
timer: T,
timeout_secs: u64,
}
impl<T: Stopwatch> IO<T> {
fn to_millis(secs: u64) -> u64 {
secs * 1000
}
pub fn set_heartbeat_timeout(&mut self) {
self.timer = Stopwatch::from_millis(Self::to_millis(self.timeout_secs) / 5);
}
pub fn set_follower_timeout(&mut self) {
self.timer = Stopwatch::from_millis(Self::to_millis(self.timeout_secs));
}
pub fn set_candidate_timeout(&mut self) {
self.timer = Stopwatch::from_millis((Self::to_millis(self.timeout_secs) * 3) / 4);
}
pub fn get_leasing_time_secs(&mut self) -> u64 {
self.timeout_secs
}
}
/// Raft role a node currently plays.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemberState {
    /// The cluster leader.
    Leader,
    /// Initial role of a freshly created `RaftService`.
    Follower,
    /// Role set by `start_election` while votes are being gathered.
    Candidate,
}
/// Behaviour of a node while it is the leader.
///
/// NOTE(review): implemented outside this file, presumably in the `leader` module.
trait LeaderBehavior<T: Stopwatch, L: LogData> {
/// Handles one incoming package and returns the role the node should take next;
/// `RaftService::parse_message` stores the result as the new `current_role`.
fn parse_message(
&mut self,
package: RaftPackage<L>,
communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) -> MemberState;
/// Per-tick work, called from `RaftService::after_tick` at the end of `process`.
fn after_tick(&mut self, communication_service: &mut dyn CommunicationService<RaftPackage<L>>);
}
/// Behaviour of a node while it is a candidate (see `LeaderBehavior` for the contract).
///
/// NOTE(review): implemented outside this file, presumably in the `candidate` module.
trait CandidateBehavior<T: Stopwatch, L: LogData> {
/// Handles one incoming package and returns the role the node should take next.
fn parse_message(
&mut self,
package: RaftPackage<L>,
communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) -> MemberState;
/// Per-tick work while in the candidate role.
fn after_tick(&mut self, communication_service: &mut dyn CommunicationService<RaftPackage<L>>);
}
/// Behaviour of a node while it is a follower (see `LeaderBehavior` for the contract).
///
/// NOTE(review): implemented outside this file, presumably in the `follower` module.
trait FollowerBehavior<T: Stopwatch, L: LogData> {
/// Handles one incoming package and returns the role the node should take next.
fn parse_message(
&mut self,
package: RaftPackage<L>,
communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) -> MemberState;
/// Per-tick work while in the follower role.
fn after_tick(&mut self, communication_service: &mut dyn CommunicationService<RaftPackage<L>>);
}
/// Package exchanged between nodes: addressed by `SystemNodeId`, carrying a `RaftMessage`.
pub type RaftPackage<T> = Package<SystemNodeId, RaftMessage<T>>;
/// One Raft node: role, term, replicated log and view of the other cluster members.
pub struct RaftService<T: Stopwatch, L: LogData> {
// Timer plus base timeout driving elections and heartbeats.
io: IO<T>,
// Current term; incremented by `start_election`.
term: Term,
// This node's id.
id: SystemNodeId,
// Leader this node currently follows, if known.
leader_id: Option<SystemNodeId>,
// Highest log index considered committed (advanced by `update_commit_index`).
commit_index: LogIndex,
// NOTE(review): only initialised to 0 in this file; confirm where it is advanced.
last_applied: LogIndex,
// Role whose behaviour trait handles messages and ticks.
current_role: MemberState,
// Replicated log of entries carrying `L` payloads and/or rules.
log: Log<L>,
// Other members of the cluster (this node is excluded), keyed by id.
cluster: LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT>,
// True once `cluster` has been populated; `process` ignores input until then.
pub members_already_setup: bool,
// NOTE(review): pending membership change, presumably set while a ConfigChange is in flight — confirm.
next_config: Option<UpdateClusterVec>,
}
impl<T: Stopwatch, L: LogData>
    CoordinationService<
        T,
        RaftMessage<L>,
        LinearMap<SystemNodeId, RaftMetadata, CLUSTER_NODE_COUNT>,
        RaftMetadata,
    > for RaftService<T, L>
{
    /// Creates a follower with an empty log.
    ///
    /// Cluster membership is filled in by the first `process` call that supplies it.
    fn new(id: SystemNodeId, metadata: RaftMetadata) -> Self {
        debug!(
            "Initializing Raft Coordination Service. Node timeout: {:?}s",
            metadata.timeout
        );
        // Start with the full timeout; saturate rather than overflow on huge values.
        let timer = T::from_millis(metadata.timeout.saturating_mul(1000));
        // Resolves to the inherent constructor (inherent items take precedence).
        Self::new(timer, None, id, metadata.timeout)
    }

    /// Leader this node currently knows about, if any.
    fn leader(&self) -> Option<SystemNodeId> {
        self.leader_id
    }

    /// Folds every data-carrying log entry, in log iteration order, into one state.
    fn get_state(&self) -> SystemState {
        debug!("Raft state: {:?}", self.log);
        let mut state = SystemState::default();
        for data in self.log.iter().filter_map(|entry| entry.data.clone()) {
            let entry_state: SystemState = data.into();
            state.extend(entry_state);
        }
        state
    }

    /// Last rule found in the log, if any entry carries one.
    fn get_current_rule(&self) -> Option<Rule> {
        self.log.iter().filter_map(|entry| entry.rule).last()
    }

    /// Submits `new_rule` as a write request addressed to this node itself.
    fn update_rule(
        &mut self,
        communication_service: &mut dyn CommunicationService<Package<SystemNodeId, RaftMessage<L>>>,
        new_rule: Rule,
    ) {
        let request = RaftMessage::WriteRequest(WriteRequestArgs::with_rule(self.id, new_rule));
        let local = self
            .build_message(request, self.id)
            .expect("a package with sender, recipient and body always builds");
        self.process(communication_service, Some(local), None)
    }

    /// Submits a membership change as a `ConfigChange` addressed to this node itself.
    fn update_members(
        &mut self,
        communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
        new_config: UpdateClusterVec,
    ) {
        let request = RaftMessage::ConfigChange(new_config);
        let local = self
            .build_message(request, self.id)
            .expect("a package with sender, recipient and body always builds");
        self.process(communication_service, Some(local), None)
    }

    /// Records a measurement through the replicated log.
    ///
    /// A non-leader that knows its leader forwards the write request to it. Otherwise
    /// (this node is the leader, or no leader is known) the request is processed locally.
    fn update_state(
        &mut self,
        communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
        measurement: Measurement,
    ) {
        let request = RaftMessage::WriteRequest(WriteRequestArgs::with_measurement(self.id, measurement));
        let remote_leader = if self.is_leader() { None } else { self.leader_id };
        if let Some(leader) = remote_leader {
            let forwarded = self
                .build_message(request, leader)
                .expect("a package with sender, recipient and body always builds");
            communication_service.push(forwarded);
        } else {
            let local = self
                .build_message(request, self.id)
                .expect("a package with sender, recipient and body always builds");
            self.process(communication_service, Some(local), None)
        }
    }

    /// Trait entry point: forwards to the inherent `process`, supplying the member list.
    fn process(
        &mut self,
        communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
        package: Option<CoordinationPackage<RaftMessage<L>>>,
        members: LinearMap<SystemNodeId, RaftMetadata, CLUSTER_NODE_COUNT>,
    ) {
        self.process(communication_service, package, Some(members))
    }
}
impl<T: Stopwatch, L: LogData> RaftService<T, L> {
pub(crate) fn new(
timer: T,
cluster: Option<LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT>>,
id: SystemNodeId,
timeout_secs: u64,
) -> Self {
RaftService {
id,
term: 0,
leader_id: None,
commit_index: 0,
last_applied: 0,
io: IO { timer, timeout_secs, },
log: Default::default(),
members_already_setup: cluster.is_some(),
current_role: MemberState::Follower,
cluster: cluster.unwrap_or_default(),
next_config: None,
}
}
/// True when this node currently acts as leader.
pub fn is_leader(&self) -> bool {
    matches!(self.current_role, MemberState::Leader)
}
/// Runs one step of the node.
///
/// If membership is not yet known, `members` (when given) is used to set it up.
/// Once membership is known, the timer is advanced, `message` (if any) is handled by
/// the current role, and the role's per-tick work runs. Until membership is known,
/// `message` is dropped.
pub fn process(
    &mut self,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
    message: Option<Package<SystemNodeId, RaftMessage<L>>>,
    members: Option<LinearMap<SystemNodeId, RaftMetadata, CLUSTER_NODE_COUNT>>,
) {
    if !self.members_already_setup {
        self.set_up_members_for_the_first_time(members);
    }
    // Borrow instead of cloning: the package is only needed for formatting here.
    // Self-addressed packages are local requests and are not logged as traffic.
    if let Some(msg) = &message {
        if msg.header.from != msg.header.to {
            debug!("[RAFT] ({}) ==|{}|==> ({})", msg.header.from, msg.body, msg.header.to);
        }
    }
    if !self.members_already_setup {
        return;
    }
    trace!("Current log status: {:?}", self.log);
    self.io.timer.update();
    if let Some(package) = message {
        self.parse_message(package, communication_service);
    } else {
        trace!("No message received, ticking!");
    }
    self.after_tick(communication_service);
}
/// Populates `cluster` from `members` (skipping this node) and marks membership as set up.
///
/// Does nothing when `members` is `None`.
fn set_up_members_for_the_first_time(
    &mut self,
    members: Option<LinearMap<SystemNodeId, RaftMetadata, CLUSTER_NODE_COUNT>>,
) {
    let members = match members {
        Some(members) => members,
        None => return,
    };
    debug!("Setting up members for the first time, cluster: {:?}", members);
    let own_id = self.id;
    self.cluster = members
        .keys()
        .copied()
        .filter(|&member_id| member_id != own_id)
        .map(|member_id| {
            let fresh = ClusterMember {
                id: member_id,
                vote_granted: ElectionVote::Abstained,
                next_idx: 1,
                match_idx: 0,
                last_successful_heartbeat: 0,
            };
            (member_id, fresh)
        })
        .collect();
    self.members_already_setup = true;
}
/// Pops at most one incoming package and runs one `process` step with it.
pub fn tick(&mut self, communication_service: &mut dyn CommunicationService<RaftPackage<L>>) {
    let incoming = communication_service.pop();
    self.process(communication_service, incoming, None);
}
fn parse_message(
&mut self,
package: RaftPackage<L>,
communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
debug!("Processing package {:?} as {:?}.", package, self.current_role);
self.current_role = match self.current_role {
MemberState::Leader => (self as &mut dyn LeaderBehavior<T, L>)
.parse_message(package, communication_service),
MemberState::Follower => (self as &mut dyn FollowerBehavior<T, L>)
.parse_message(package, communication_service),
MemberState::Candidate => (self as &mut dyn CandidateBehavior<T, L>)
.parse_message(package, communication_service),
};
}
/// Stamps the sender of `header` with the current timer time, if it is a known member.
fn update_ttl(&mut self, header: &Header<SystemNodeId>) {
    if let Some(sender) = self.cluster.get_mut(&header.from) {
        let now = self.io.timer.current_time_as_secs();
        sender.last_successful_heartbeat = now;
    }
}
/// Starts a new election.
///
/// Bumps the term, switches to `Candidate`, clears every member's recorded vote,
/// sends `RequestVote` to the other members, then arms the candidate timeout.
fn start_election(
&mut self,
communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
self.term += 1;
self.current_role = MemberState::Candidate;
// Votes recorded for an earlier term must not count towards this one.
self.clean_state_from_previous_election();
// Sent after the term bump so the requests carry the new term.
self.send_election_messages(communication_service);
self.io.set_candidate_timeout();
}
fn clean_state_from_previous_election(&mut self) {
self.cluster
.values_mut()
.for_each(|member| member.vote_granted = Abstained);
}
/// Sends a `RequestVote` for the current term to every other cluster member.
///
/// Packages that fail to build are skipped.
fn send_election_messages(
    &mut self,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    info!("Node #{} starting leader election.", self.id);
    debug!("Election started in cluster: {:?}.", self.cluster);
    // The request payload is identical for every member, so read the log once
    // instead of once per member inside the loop.
    let term = self.term;
    let own_id = self.id;
    let prev_log_index = self.get_last_log_index();
    let prev_log_term = self.get_last_log_term();
    for member in self.cluster.values().filter(|member| member.id != own_id) {
        let request = RaftPackageBuilder::default()
            .from(own_id)
            .to(member.id)
            .with_message(RaftMessage::RequestVote(RequestVoteArgs {
                term,
                prev_log_index,
                prev_log_term,
            }))
            .build();
        if let Ok(pkg) = request {
            communication_service.push(pkg)
        }
    }
}
/// Wraps `msg` in a package from this node to `to`; `None` if the builder rejects it.
pub(crate) fn build_message(
    &self,
    msg: RaftMessage<L>,
    to: SystemNodeId,
) -> Option<Package<SystemNodeId, RaftMessage<L>>> {
    match package(msg).from(self.id).to(to).build() {
        Ok(built) => Some(built),
        Err(_) => None,
    }
}
/// Sends a copy of `message` to every cluster member other than this node.
pub(crate) fn broadcast_message(
    &mut self,
    message: RaftMessage<L>,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let own_id = self.id;
    for recipient in self.cluster.values().map(|member| member.id) {
        if recipient == own_id {
            continue;
        }
        self.send_message_to(message.clone(), recipient, communication_service);
    }
}
/// Packages `message` for `member` and pushes it; silently skipped if building fails.
fn send_message_to(
    &self,
    message: RaftMessage<L>,
    member: SystemNodeId,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    self.build_message(message, member)
        .into_iter()
        .for_each(|pkg| communication_service.push(pkg));
}
/// Points every member's `next_idx` just past this node's last log entry.
pub(crate) fn reset_next_index(&mut self) {
    let next_idx = self.get_last_log_index() + 1;
    for member in self.cluster.values_mut() {
        member.next_idx = next_idx;
    }
}
/// Resets every member's `match_idx` to 0.
pub(crate) fn reset_match_index(&mut self) {
    for member in self.cluster.values_mut() {
        member.match_idx = 0;
    }
}
/// Appends a no-op entry for the current term and broadcasts it as an `AppendLog`.
///
/// Failures to push into either log are ignored.
pub(crate) fn commit_noop(
    &mut self,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    // Read the predecessor before appending, so it describes the entry preceding the no-op.
    let prev_log_index = self.get_last_log_index();
    let prev_log_term = self.get_last_log_term();
    let noop_entry: LogEntry<L> = LogEntry::with_data(self.term, Some(L::noop()));
    let mut entries = Log::new();
    let _ = entries.push(noop_entry.clone());
    let _ = self.log.push(noop_entry);
    let append = AppendLogArgs {
        term: self.term,
        prev_log_index,
        prev_log_term,
        entries,
        leader_commit: self.commit_index,
    };
    self.broadcast_message(AppendLog(append), communication_service);
}
/// Replies to the sender of `header` with a failed read, including the known leader.
///
/// # Panics
/// If the reply package cannot be built.
pub fn reject_read_request(
    &mut self,
    header: Header<SystemNodeId>,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let leader_hint = self.leader_id;
    let reply = self
        .build_message(ReadRequestReply(ReadRequestReplyArgs::fail(leader_hint)), header.from)
        .unwrap();
    communication_service.push(reply)
}
/// Replies to the sender of `header` with a failed write, including the known leader.
///
/// # Panics
/// If the reply package cannot be built.
pub fn reject_write_request(
    &mut self,
    header: Header<SystemNodeId>,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let rejection = WriteRequestReply(false, self.leader_id);
    communication_service.push(self.build_message(rejection, header.from).unwrap())
}
/// Term of the last log entry, or 0 for an empty log.
pub(crate) fn get_last_log_term(&self) -> Term {
    self.log.last().map_or(0, |entry| entry.term)
}
/// Index of the last log entry, as reported by `Log::last_index`.
pub(crate) fn get_last_log_index(&self) -> LogIndex {
    let last_index = self.log.last_index();
    last_index as LogIndex
}
fn after_tick(&mut self, communication_service: &mut dyn CommunicationService<RaftPackage<L>>) {
if self.log.capacity() < 0.25 {
info!("Creating snapshot with commit-idx {}.", self.commit_index);
debug!("Log to snapshot: {:?}", self.log);
self.log.snapshot(self.commit_index);
debug!("Snapshot result: {:?}", self.log);
}
match self.current_role {
MemberState::Leader => (self as &mut dyn LeaderBehavior<T, L>).after_tick(communication_service),
MemberState::Follower => (self as &mut dyn FollowerBehavior<T, L>).after_tick(communication_service),
MemberState::Candidate => (self as &mut dyn CandidateBehavior<T, L>).after_tick(communication_service),
}
}
/// Sends an empty `AppendLog` heartbeat to members that have not answered recently.
///
/// A member is due when more than half the lease time has elapsed since its last
/// successful heartbeat.
fn send_heartbeat_to_followers(
    &mut self,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let message = self.make_heartbeat();
    let own_id = self.id;
    let heartbeat_ttl = self.io.get_leasing_time_secs() / 2;
    let now = self.io.timer.current_time_as_secs();
    let recipients = self
        .cluster
        .values()
        .filter(|member| {
            // `now` can be behind a stored stamp (e.g. if a replacement timer's clock
            // starts over). Plain subtraction would underflow and panic in debug
            // builds, so treat such a member as due.
            now.checked_sub(member.last_successful_heartbeat)
                .map_or(true, |elapsed| elapsed > heartbeat_ttl)
        })
        .map(|member| member.id)
        .filter(|member_id| *member_id != own_id);
    for member in recipients {
        self.send_message_to(message.clone(), member, communication_service)
    }
}
/// Builds an `AppendLog` with no entries that describes this node's last log entry.
fn make_heartbeat(&mut self) -> RaftMessage<L> {
    RaftMessage::AppendLog(AppendLogArgs {
        term: self.term,
        prev_log_index: self.get_last_log_index(),
        prev_log_term: self.get_last_log_term(),
        entries: Log::new(),
        leader_commit: self.commit_index,
    })
}
/// Sends an `InstallSnapshot` built from log position 1 to `member`.
///
/// NOTE(review): assumes the compacted snapshot lives at position 1 of the log —
/// confirm against `Log::snapshot`.
///
/// # Panics
/// If the log has no entry at position 1.
fn send_install_snapshot_to(
    &self,
    member: &ClusterMember,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let snapshot_entry = self.log.get(1).unwrap();
    let message = RaftMessage::InstallSnapshot(InstallSnapshotArgs {
        term: self.term,
        leader_id: self.id,
        last_included_index: self.log.last_included_index(),
        last_included_term: snapshot_entry.term,
        data: snapshot_entry.clone(),
    });
    self.send_message_to(message, member.id, communication_service);
}
/// Sends `member` the single log entry at its `next_idx`, if the log has one.
///
/// The predecessor is `(next_idx - 1, its term)`, or `(0, 0)` when `next_idx <= 1`.
fn send_next_entry_to(
    &self,
    member: &ClusterMember,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let (prev_log_index, prev_log_term) = match member.next_idx {
        next_idx if next_idx > 1 => {
            let prev_idx = next_idx - 1;
            let prev_term = self
                .log
                .get(prev_idx)
                .map(|entry| entry.term)
                .unwrap_or_default();
            (prev_idx, prev_term)
        }
        _ => (0, 0),
    };
    trace!("Send next entry to {member:?} with previous-log-idx {prev_log_index} and term {prev_log_term}");
    let current_entry = match self.log.get(member.next_idx) {
        Some(entry) => entry,
        None => return,
    };
    let mut entries = Log::new();
    let _ = entries.push(current_entry.clone());
    let append_log_msg = AppendLog(AppendLogArgs {
        term: self.term,
        prev_log_index,
        prev_log_term,
        entries,
        leader_commit: self.commit_index,
    });
    trace!("Sending next entry to {member:?} entry: {append_log_msg:?}");
    self.send_message_to(append_log_msg, member.id, communication_service);
}
pub fn current_config(&self) -> UpdateClusterVec {
let mut current_config = UpdateClusterVec::from_iter(self.cluster.values().map(|v| v.id.into()));
current_config.push(self.id.into());
current_config.sort();
current_config
}
/// Reconciles the membership map with `new_cluster`.
///
/// Unknown ids in `new_cluster` (other than this node) are added as fresh members.
/// Known members missing from `new_cluster` are removed. A failed insertion is
/// logged, not propagated.
pub fn update_config(&mut self, new_cluster: &UpdateClusterVec) {
    for member_id in new_cluster {
        if *member_id == self.id || self.cluster.contains_key(member_id) {
            continue;
        }
        let new_member = ClusterMember {
            id: *member_id,
            vote_granted: ElectionVote::Abstained,
            next_idx: 1,
            match_idx: 0,
            last_successful_heartbeat: 0,
        };
        debug!("Adding new member to cluster {new_member:?}");
        if self.cluster.insert(*member_id, new_member).is_err() {
            error!("Couldn't insert node #{} into cluster data.", member_id)
        }
    }
    // Keys cannot be removed while the map is being iterated, so gather them first.
    // The buffer has the same capacity as `self.cluster`, so `collect` cannot overflow.
    let stale_members: heapless::Vec<SystemNodeId, CLUSTER_NODE_COUNT> = self
        .cluster
        .keys()
        .filter(|member_id| !new_cluster.contains(*member_id))
        .copied()
        .collect();
    for member_id in stale_members {
        self.cluster.remove(&member_id);
    }
}
/// Advances `commit_index` to the highest log index held by a majority of the
/// cluster, counting this leader as one holder.
///
/// A follower is taken to hold every entry below its `next_idx`. The leader is taken
/// to hold every entry up to its last log index. `commit_index` never decreases.
///
/// NOTE(review): Raft (§5.4.2) also requires the entry at the new commit index to
/// belong to the current term; that check is not performed here.
pub fn update_commit_index(&mut self) {
    // Highest index each follower is known to hold.
    let mut replicated: Vec<LogIndex> = self
        .cluster
        .values()
        .map(|member| member.next_idx.checked_sub(1).unwrap_or_default())
        .collect();
    // With N = followers + 1 nodes, a majority is N / 2 + 1 nodes. The leader is one
    // of them, so N / 2 followers must also hold an entry before it can be committed.
    let followers_needed = (self.cluster.len() + 1) / 2;
    let last_log_index = self.get_last_log_index();
    let candidate = if followers_needed == 0 {
        // Single-node cluster: the leader alone is the majority.
        last_log_index
    } else {
        // Sorted highest first, the `followers_needed`-th value is held by at least
        // that many followers.
        replicated.sort_unstable_by(|a, b| b.cmp(a));
        match replicated.get(followers_needed - 1) {
            // Never commit beyond what the leader itself holds.
            Some(&index) => index.min(last_log_index),
            None => return,
        }
    };
    if candidate > self.commit_index {
        self.commit_index = candidate;
    }
}
/// Broadcasts `entry` in an `AppendLog` to every other member.
///
/// `prev_log_*` describe this node's current last log entry. NOTE(review): if the
/// caller has already appended `entry` locally, these describe `entry` itself rather
/// than its predecessor — confirm against callers.
pub fn replicate_entry(
    &mut self,
    entry: LogEntry<L>,
    communication_service: &mut dyn CommunicationService<RaftPackage<L>>,
) {
    let prev_log_index = self.get_last_log_index();
    let prev_log_term = self.get_last_log_term();
    let mut entries = Log::new();
    let _ = entries.push(entry);
    let args = AppendLogArgs {
        term: self.term,
        prev_log_index,
        prev_log_term,
        entries,
        leader_commit: self.commit_index,
    };
    self.broadcast_message(RaftMessage::AppendLog(args), communication_service);
}
/// Re-arms the follower timeout, adopts `new_term`, and records `leader_id`.
///
/// Does not change `current_role`; callers decide the role.
fn become_follower_of(&mut self, leader_id: SystemNodeId, new_term: Term) {
    self.io.set_follower_timeout();
    self.term = new_term;
    debug!("Updating leader_id to: {}", leader_id);
    self.leader_id.replace(leader_id);
}
}
/// Starts a package builder pre-loaded with `msg`; sender and recipient are set by the caller.
fn package<T: LogData>(msg: RaftMessage<T>) -> RaftPackageBuilder<T> {
    RaftPackageBuilder::<T>::default().with_message(msg)
}
/// Vote a member cast in the current election, as recorded by this node.
// Variant order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Copy, Clone, Eq, PartialEq, Debug)]
pub enum ElectionVote {
Granted,
Against,
// Reset value: no vote recorded yet (set when members are created and at each new election).
Abstained,
}
/// This node's bookkeeping about one other cluster member.
// Field order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Copy, Clone, Eq, PartialEq, Debug)]
pub struct ClusterMember {
// Id of the member.
pub id: SystemNodeId,
// Vote recorded for the current election.
pub vote_granted: ElectionVote,
// Log index of the next entry to send to this member (starts at 1).
pub next_idx: LogIndex,
// NOTE(review): presumably the highest index known replicated on the member; only reset to 0 here — confirm.
pub match_idx: LogIndex,
// Timer time (`current_time_as_secs`) of the member's last successful heartbeat.
pub last_successful_heartbeat: u64,
}
#[cfg(test)]
#[allow(non_snake_case)]
mod test_utils {
extern crate log;
extern crate std;
use core::fmt::Debug;
use std::println;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use serde::{Deserialize, Serialize};
use dcs::communication::connection::{InMsgQueue, OutMsgQueue, Readable, Writable};
use dcs::communication::messages::{PackageBuilder};
use dcs::communication::service::CommunicationService;
use dcs::coordination::Stopwatch;
use dcs::nodes::SystemNodeId;
use dcs::rules::measurements::{Measurement, SystemState};
use crate::messages::RaftMessage::RequestVote;
use crate::messages::{RaftMessage, WriteRequestArgs};
use crate::server::{package, LogData, Merge, NoOp, RaftPackage, RaftService};
use crate::{CANDIATE_ELECTION_TIMEOUT, ELECTION_TIMEOUT};
/// Process-wide tick counter shared by every `FakeTimer`.
///
/// Because it is global, `current_time_as_secs` keeps increasing even when a node
/// replaces its timer with a fresh one.
static CLOCK: AtomicUsize = AtomicUsize::new(0);
/// Deterministic `Stopwatch` for tests: one `update` call is one unit of time.
pub struct FakeTimer {
    /// Number of `update` calls seen by this instance.
    clock: u32,
    /// Remaining `update` calls until `is_timeout` returns true.
    timeout: u32,
}
impl FakeTimer {
    /// Advances the global clock and this timer by one tick.
    fn update_timer(&mut self) {
        CLOCK.fetch_add(1, Ordering::Relaxed);
        self.clock += 1;
        self.timeout = self.timeout.saturating_sub(1);
    }
    /// Creates a timer that is already expired.
    pub fn new() -> Self {
        FakeTimer {
            clock: 0,
            timeout: 0,
        }
    }
}
impl Stopwatch for FakeTimer {
    /// Treats the requested duration as a number of `update` ticks.
    ///
    /// Values above `u32::MAX` saturate instead of being silently truncated.
    fn from_millis(millis: u64) -> Self {
        Self {
            clock: 0,
            timeout: millis.min(u64::from(u32::MAX)) as u32,
        }
    }
    /// No-op for the fake timer.
    fn restart(&mut self) {}
    fn is_timeout(&self) -> bool {
        self.timeout == 0
    }
    /// Returns the global tick count (ticks, not real seconds).
    fn current_time_as_secs(&mut self) -> u64 {
        CLOCK.load(Ordering::Relaxed) as u64
    }
    /// Not supported by the fake timer; panics if called.
    fn as_secs(&self) -> u64 {
        unreachable!()
    }
    fn update(&mut self) {
        self.update_timer()
    }
}
/// Minimal log payload for the tests: two counters merged component-wise.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct TestState(pub u32, pub u32);
impl From<(SystemNodeId, Measurement)> for TestState {
    /// Keeps only the measurement value (cast to `u32`); the node id is ignored.
    fn from(input: (SystemNodeId, Measurement)) -> Self {
        let measurement = input.1;
        TestState(measurement.value as u32, 0)
    }
}
impl From<TestState> for SystemState {
    /// Not supported for the test payload; panics if called.
    fn from(_unused: TestState) -> Self {
        unreachable!()
    }
}
impl NoOp for TestState {
    fn noop() -> Self {
        TestState::default()
    }
}
impl Merge for TestState {
    fn merge(self, rhs: Self) -> Self {
        let TestState(left_a, left_b) = self;
        let TestState(right_a, right_b) = rhs;
        TestState(left_a + right_a, left_b + right_b)
    }
}
/// Panics, showing the actual message, unless `actual_msg` is a `RequestVote`.
pub fn assert_message_is_vote_request<T: LogData + Debug>(actual_msg: &RaftMessage<T>) {
    assert!(
        matches!(actual_msg, RequestVote { .. }),
        "expected a RequestVote message, got {:?}",
        actual_msg
    );
}
/// Unconditionally fails the current test.
pub fn fail() {
    panic!("test expectation failed")
}
/// Marks a passing branch; intentionally does nothing.
pub fn success() {}
/// Ticks the server `CANDIATE_ELECTION_TIMEOUT` times.
pub fn timeout_candidate<T: LogData + Debug>(
    candidate: &mut RaftService<FakeTimer, T>,
    communication_service: &mut dyn CommunicationService<RaftPackage<T>>,
) {
    tick_times(CANDIATE_ELECTION_TIMEOUT, candidate, communication_service);
}
/// Ticks the server one time fewer than `ELECTION_TIMEOUT`.
pub fn almost_timeout_follower<T: LogData + Debug>(
    server: &mut RaftService<FakeTimer, T>,
    communication_service: &mut dyn CommunicationService<RaftPackage<T>>,
) {
    tick_times(ELECTION_TIMEOUT - 1, server, communication_service);
}
/// Ticks the server `ELECTION_TIMEOUT` times.
pub fn timeout_follower<T: LogData + Debug>(
    server: &mut RaftService<FakeTimer, T>,
    communication_service: &mut dyn CommunicationService<RaftPackage<T>>,
) {
    tick_times(ELECTION_TIMEOUT, server, communication_service);
}
/// Calls `tick` on `server` exactly `times` times.
fn tick_times<T: LogData + Debug>(
    times: u64,
    server: &mut RaftService<FakeTimer, T>,
    communication_service: &mut dyn CommunicationService<RaftPackage<T>>,
) {
    (0..times).for_each(|_| server.tick(communication_service));
}
/// Pops the next package the server sent and returns its message body.
///
/// # Panics
/// If nothing was sent.
pub fn read_msg_sent_by_server<T: LogData + Debug>(
    in_queue_test: &mut dyn InMsgQueue<RaftPackage<T>>,
) -> RaftMessage<T> {
    in_queue_test
        .pop()
        .map(|package| package.body)
        .expect("Message wasn't received by server")
}
/// Pops the next package the server sent and returns its recipient.
///
/// # Panics
/// If nothing was sent.
pub fn get_destination_of_msg<T: LogData + Debug>(
    in_queue_test: &mut dyn InMsgQueue<RaftPackage<T>>,
) -> SystemNodeId {
    let package = in_queue_test.pop().unwrap();
    package.header.to
}
/// Factory for connected test queues backed by an `mpsc` channel.
pub struct FakeDcsMsgQueue;
impl FakeDcsMsgQueue {
    /// Returns `(in_queue, out_queue)`: whatever is pushed into the out-queue can be
    /// popped from the in-queue.
    pub fn build_pair() -> (
        FakeDcsInMsgQueue<RaftPackage<TestState>>,
        FakeDcsOutMsgQueue<RaftPackage<TestState>>,
    ) {
        let (sender, receiver) = mpsc::channel::<RaftPackage<TestState>>();
        let out_queue = FakeDcsOutMsgQueue { out_channel: sender };
        let in_queue = FakeDcsInMsgQueue { in_channel: receiver };
        (in_queue, out_queue)
    }
}
/// Sending half of a test queue; prints every pushed message.
pub struct FakeDcsOutMsgQueue<T: Writable + Debug> {
    out_channel: Sender<T>,
}
/// Receiving half of a test queue; prints every popped message.
#[derive(Debug)]
pub struct FakeDcsInMsgQueue<T: Readable + Debug> {
    in_channel: Receiver<T>,
}
impl<T: Writable + Send + Debug> OutMsgQueue<T> for FakeDcsOutMsgQueue<T> {
    /// Sends `msg`; a disconnected receiver is ignored.
    fn push(&mut self, msg: T) {
        println!("PUSH {:?}", msg);
        self.out_channel.send(msg).ok();
    }
}
impl<T: Readable + Debug> InMsgQueue<T> for FakeDcsInMsgQueue<T> {
    /// Non-blocking receive; `None` when empty or disconnected.
    fn pop(&mut self) -> Option<T> {
        let received = self.in_channel.try_recv().ok();
        if let Some(msg) = &received {
            println!("POP {:?}", msg);
        }
        received
    }
}
/// `CommunicationService` for tests that forwards to a pair of borrowed queues.
pub struct FakeCommunicationService<'a> {
    out_queue: &'a mut dyn OutMsgQueue<RaftPackage<TestState>>,
    in_queue: &'a mut dyn InMsgQueue<RaftPackage<TestState>>,
}
impl<'a> FakeCommunicationService<'a> {
    /// Wraps the queue the node pushes to and the queue it pops from.
    pub fn new(
        out_queue: &'a mut dyn OutMsgQueue<RaftPackage<TestState>>,
        in_queue: &'a mut dyn InMsgQueue<RaftPackage<TestState>>,
    ) -> Self {
        FakeCommunicationService { in_queue, out_queue }
    }
}
impl OutMsgQueue<RaftPackage<TestState>> for FakeCommunicationService<'_> {
    fn push(&mut self, msg: RaftPackage<TestState>) {
        self.out_queue.push(msg)
    }
}
impl InMsgQueue<RaftPackage<TestState>> for FakeCommunicationService<'_> {
    fn pop(&mut self) -> Option<RaftPackage<TestState>> {
        self.in_queue.pop()
    }
}
impl CommunicationService<RaftPackage<TestState>> for FakeCommunicationService<'_> {}
pub fn configure_read_only_request<T: LogData + Debug>(id: SystemNodeId) -> RaftPackage<T> {
package(RaftMessage::ReadRequest)
.from(0.into())
.to(id)
.build()
.unwrap()
}
pub fn send_read_request(
server: &mut RaftService<FakeTimer, TestState>,
out_queue_test: &mut dyn OutMsgQueue<RaftPackage<TestState>>,
communication_service: &mut dyn CommunicationService<RaftPackage<TestState>>,
) {
let client_request = configure_read_only_request(server.id);
out_queue_test.push(client_request);
server.tick(communication_service);
}
pub fn send_write_request(
server: &mut RaftService<FakeTimer, TestState>,
out_queue_test: &mut dyn OutMsgQueue<RaftPackage<TestState>>,
data: Measurement,
communication_service: &mut dyn CommunicationService<RaftPackage<TestState>>,
) {
let server_id = server.id;
let client_request = configure_write_request(server_id, data);
out_queue_test.push(client_request);
server.tick(communication_service);
}
pub fn configure_write_request(
id: SystemNodeId,
data: Measurement,
) -> RaftPackage<TestState> {
let current_log = RaftMessage::WriteRequest(WriteRequestArgs::with_measurement(id, data));
package(current_log)
.from(2.into())
.to(id)
.build()
.unwrap()
}
}
#[cfg(test)]
pub mod test_server_builder {
extern crate log;
extern crate std;
use std::borrow::BorrowMut;
use std::collections::HashMap;
use std::iter::FromFn;
use std::ops::DerefMut;
use dcs::communication::connection::{InMsgQueue, OutMsgQueue};
use dcs::communication::messages::PackageBuilder;
use dcs::communication::service::CommunicationService;
use dcs::coordination::Stopwatch;
use dcs::heapless;
use dcs::heapless::LinearMap;
use dcs::nodes::SystemNodeId;
use dcs::rules::measurements::Measurement;
use dcs::rules::measurements::ClusterType::TEMPERATURE;
use crate::messages::*;
use crate::server::test_utils::*;
use crate::server::ElectionVote::Abstained;
use crate::server::*;
use crate::server::{package, CLUSTER_NODE_COUNT};
use crate::state::RaftState;
use crate::RaftMessage::*;
use crate::{RaftPackageBuilder, CANDIATE_ELECTION_TIMEOUT, ELECTION_TIMEOUT};
/// A node under test together with the queues used to talk to it.
pub struct TestContext<'a> {
// Receives the packages the node sends.
in_queue_test: FakeDcsInMsgQueue<RaftPackage<TestState>>,
// Queues packages for the node to pop on its next tick.
out_queue_test: FakeDcsOutMsgQueue<RaftPackage<TestState>>,
// The node under test.
server: RaftService<FakeTimer, TestState>,
// Membership the node was built with (snapshot; the node's own map may change later).
cluster: LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT>,
// Communication service handed to the node on every tick.
comm_service: &'a mut FakeCommunicationService<'a>,
}
impl<'a> TestContext<'a> {
pub fn get_node_mut(&mut self) -> &mut RaftService<FakeTimer, TestState> {
&mut self.server
}
pub fn get_node(&self) -> &RaftService<FakeTimer, TestState> {
&self.server
}
pub fn term(&self) -> Term {
self.server.term
}
pub fn tick(&mut self) {
let server = &mut self.server;
let mut comm_service = self.comm_service.deref_mut();
server.tick(comm_service)
}
pub fn cluster_size(&self) -> usize {
self.server.cluster.len() + 1
}
pub fn cluster(&self) -> LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT> {
self.server.cluster.clone()
}
pub fn id(&self) -> SystemNodeId {
self.server.id
}
pub fn commit_idx(&self) -> LogIndex {
self.server.commit_index
}
pub fn send_read_request(&mut self) {
let server_id = self.server.id;
let client_request = configure_read_only_request(server_id);
self.send(client_request);
}
pub fn send_read_request_and_tick(&mut self) {
self.send_read_request();
self.tick();
}
pub fn send_write_request(&mut self, data: Measurement) {
let server_id = self.server.id;
let client_request = configure_write_request(server_id, data);
self.send(client_request);
}
pub fn send_write_request_and_tick(&mut self, data: Measurement) {
self.send_write_request(data);
self.tick();
}
pub fn empty_queue(&mut self) {
let _ = self.messages();
}
pub fn messages(&mut self) -> std::vec::Vec<RaftPackage<TestState>> {
let mut messages = vec![];
let mut pkg = self.recv();
while pkg.is_some() {
messages.push(pkg.unwrap());
pkg = self.recv();
}
messages
}
pub fn append_log_requests(&mut self) -> Vec<AppendLogArgs<TestState>> {
let mut messages = self.messages();
let messages: Vec<_> = messages
.into_iter()
.filter_map(|pkg| {
if let RaftMessage::AppendLog(args) = pkg.body {
Some(args)
} else {
None
}
})
.collect();
messages
}
pub fn count_heartbeats(&mut self) -> usize {
let queue = &mut self.in_queue_test;
let iter = std::iter::from_fn(|| queue.pop());
iter.map(|pkg| pkg.body)
.filter(|msg| matches!(msg, RaftMessage::AppendLog(args)))
.count()
}
pub fn entry_was_broadcasted(&mut self, entry: LogEntry<TestState>) -> bool {
let mut appendlog_messages: Vec<RaftPackage<TestState>> = self
.messages()
.into_iter()
.filter(|pkg| matches!(pkg.get_message(), RaftMessage::AppendLog(args)))
.collect();
let mut expected_entries = Log::new();
expected_entries.push(entry);
let mut count = 0;
for member in self.cluster.keys() {
let broadcasted = appendlog_messages
.iter()
.filter(|pkg| pkg.header.to == *member)
.filter(|pkg| {
if let AppendLog(args) = pkg.get_message() {
args.entries == expected_entries
} else {
false
}
})
.count();
if broadcasted != 0 {
count += 1
}
}
count == self.cluster.len()
}
pub fn get_read_response(&mut self) -> Option<ReadRequestReplyArgs<TestState>> {
self.messages()
.into_iter()
.map(|pkg| pkg.body)
.find_map(|msg| {
if let ReadRequestReply(args) = msg {
Some(args)
} else {
None
}
})
}
pub fn send(&mut self, pkg: RaftPackage<TestState>) {
self.out_queue_test.push(pkg);
}
pub fn send_and_tick(&mut self, pkg: RaftPackage<TestState>) {
self.out_queue_test.push(pkg);
self.tick();
}
pub fn package(&self, msg: RaftMessage<TestState>) -> RaftPackage<TestState> {
package(msg)
.from(SystemNodeId::default())
.to(self.server.id.into())
.build()
.unwrap()
}
pub fn recv(&mut self) -> Option<RaftPackage<TestState>> {
self.in_queue_test.pop()
}
pub fn recv_message(&mut self) -> Option<RaftMessage<TestState>> {
self.recv().map(|p| p.body)
}
pub fn reject_append_log(&mut self, id: SystemNodeId) {
let append_log_rejection =
RaftMessage::AppendLogResponse::<TestState>(AppendLogResponseResult {
term: self.server.term,
success: false,
});
let pkg = RaftPackageBuilder::default()
.from(id)
.to(self.server.id.into())
.with_message(append_log_rejection)
.build()
.unwrap();
self.send(pkg);
}
pub fn accept_append_log(&mut self, id: SystemNodeId) {
self.answer_append_log(id, true)
}
pub fn answer_append_log(&mut self, id: SystemNodeId, success: bool) {
let append_log_rejection =
RaftMessage::AppendLogResponse::<TestState>(AppendLogResponseResult {
term: self.server.term,
success,
});
let pkg = RaftPackageBuilder::default()
.from(id)
.to(self.server.id.into())
.with_message(append_log_rejection)
.build()
.unwrap();
self.send(pkg);
}
pub fn followers_answers_append_log_request(&mut self, success: bool) {
self.empty_queue();
self.followers().into_iter().for_each(|follower| {
self.answer_append_log(follower, success);
self.tick();
});
}
pub fn followers_accept_append_log_request(&mut self) {
self.followers_answers_append_log_request(true)
}
pub fn followers(&mut self) -> Vec<SystemNodeId> {
self.server.cluster.keys().cloned().collect()
}
pub fn current_config(&self) -> UpdateClusterVec {
let mut config: UpdateClusterVec =
self.server.cluster.values().map(|node| node.id.into()).collect();
config.push(self.server.id.into());
config.sort();
config
}
pub fn send_append_log_and_tick(
&mut self,
prev_log_index: LogIndex,
prev_log_term: Term,
data: Option<TestState>,
) {
let mut entries: Log<TestState> = Default::default();
let _ = entries.push(LogEntry::with_data(prev_log_term, data));
let append_log = AppendLog(AppendLogArgs {
term: 1,
prev_log_index,
prev_log_term,
entries,
leader_commit: 0,
});
self.send_and_tick(self.package(append_log))
}
}
/// Builds a `TestContext` around a node prepared by `ServerBuilder`.
pub struct ContextBuilder<'a> {
    server_builder: ServerBuilder<'a>,
}
impl<'a> ContextBuilder<'a> {
    /// Creates a builder with fresh message channels.
    pub fn new() -> Self {
        Self {
            server_builder: ServerBuilder::new(),
        }
    }
    /// Context around the node produced by `ServerBuilder::follower`.
    pub fn follower<'b: 'a>(&'b mut self) -> TestContext {
        // Every binding is moved straight into the context, so none needs `mut`.
        let (in_queue_test, out_queue_test, server, cluster, comm_service) =
            self.server_builder.follower();
        TestContext {
            in_queue_test,
            out_queue_test,
            server,
            cluster,
            comm_service,
        }
    }
    /// Context around the node produced by `ServerBuilder::leader`.
    pub fn leader<'b: 'a>(&'b mut self) -> TestContext {
        let (in_queue_test, out_queue_test, server, cluster, comm_service) =
            self.server_builder.leader();
        TestContext {
            in_queue_test,
            out_queue_test,
            server,
            cluster,
            comm_service,
        }
    }
}
/// Pieces returned by the `ServerBuilder` constructors:
/// (test receive queue, test send queue, node, initial cluster, communication service).
pub type BuilderOutput<'b> = (
FakeDcsInMsgQueue<RaftPackage<TestState>>,
FakeDcsOutMsgQueue<RaftPackage<TestState>>,
RaftService<FakeTimer, TestState>,
LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT>,
&'b mut FakeCommunicationService<'b>,
);
/// Owns the two channels connecting a test to the node's communication service.
pub struct ServerBuilder<'a> {
// Test side: receives what the node pushes into `comm_service_tx`.
test_rx: Option<FakeDcsInMsgQueue<RaftPackage<TestState>>>,
// Node side: where the node's pushes go.
comm_service_tx: FakeDcsOutMsgQueue<RaftPackage<TestState>>,
// Test side: packages sent here are popped by the node from `comm_service_rx`.
test_tx: Option<FakeDcsOutMsgQueue<RaftPackage<TestState>>>,
// Node side: where the node pops incoming packages from.
comm_service_rx: FakeDcsInMsgQueue<RaftPackage<TestState>>,
// Service borrowing the two node-side queues; created by each builder method.
communication_service: Option<FakeCommunicationService<'a>>,
}
impl<'a> ServerBuilder<'a> {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let (out_rx, out_tx) = FakeDcsMsgQueue::build_pair();
let (in_rx, in_tx) = FakeDcsMsgQueue::build_pair();
Self {
test_rx: Some(out_rx),
comm_service_tx: out_tx,
test_tx: Some(in_tx),
comm_service_rx: in_rx,
communication_service: None,
}
}
pub fn follower<'b: 'a>(&'b mut self) -> BuilderOutput {
let mut timer: FakeTimer = FakeTimer::new();
self.communication_service = Some(FakeCommunicationService::new(
&mut self.comm_service_tx,
&mut self.comm_service_rx,
));
timer = Stopwatch::from_millis(ELECTION_TIMEOUT);
let mut cluster: LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT> =
LinearMap::new();
let _ = cluster.insert(
0.into(),
ClusterMember {
id: SystemNodeId::from(0),
vote_granted: Abstained,
match_idx: 0,
next_idx: 1,
last_successful_heartbeat: 0,
},
);
let _ = cluster.insert(
2.into(),
ClusterMember {
id: SystemNodeId::from(2),
vote_granted: Abstained,
match_idx: 0,
next_idx: 1,
last_successful_heartbeat: 0,
},
);
let mut server = RaftService::new(
timer, Some(cluster.clone()), SystemNodeId::from(1), ELECTION_TIMEOUT
);
server.leader_id = Some(SystemNodeId::from(0));
(
self.test_rx.take().unwrap(),
self.test_tx.take().unwrap(),
server,
cluster,
self.communication_service.as_mut().unwrap(),
)
}
pub fn server_in_cluster<'b: 'a>(&'b mut self) -> BuilderOutput {
self.communication_service = Some(FakeCommunicationService::new(
&mut self.comm_service_tx,
&mut self.comm_service_rx,
));
let mut timer: FakeTimer = FakeTimer::new();
let cluster = <ServerBuilder<'a>>::initialize_cluster_with_members();
timer = Stopwatch::from_millis(ELECTION_TIMEOUT);
let server = RaftService::new(
timer, Some(cluster.clone()), SystemNodeId::from(0), ELECTION_TIMEOUT
);
(
self.test_rx.take().unwrap(),
self.test_tx.take().unwrap(),
server,
cluster,
self.communication_service.as_mut().unwrap(),
)
}
pub fn candidate<'b: 'a>(&'b mut self) -> BuilderOutput {
self.communication_service = Some(FakeCommunicationService::new(
&mut self.comm_service_tx,
&mut self.comm_service_rx,
));
let timer: FakeTimer = FakeTimer::new();
let cluster = <ServerBuilder<'a>>::initialize_cluster_with_members();
let mut server = RaftService::new(
timer, Some(cluster.clone()), SystemNodeId::from(0), ELECTION_TIMEOUT
);
let communication_service = self.communication_service.as_mut().unwrap();
server.tick(communication_service); server.io.timer = Stopwatch::from_millis(CANDIATE_ELECTION_TIMEOUT);
self.test_rx.as_mut().unwrap().pop().unwrap(); self.test_rx.as_mut().unwrap().pop().unwrap(); (
self.test_rx.take().unwrap(),
self.test_tx.take().unwrap(),
server,
cluster,
communication_service,
)
}
pub fn leader<'b: 'a>(&'b mut self) -> BuilderOutput {
self.communication_service = Some(FakeCommunicationService::new(
&mut self.comm_service_tx,
&mut self.comm_service_rx,
));
let mut cluster = <ServerBuilder<'a>>::initialize_cluster_with_members();
let timer: FakeTimer = FakeTimer::new();
let mut server = RaftService::new(
timer, Some(cluster.clone()), SystemNodeId::from(0), ELECTION_TIMEOUT
);
let communication_service = self.communication_service.as_mut().unwrap();
server.tick(communication_service); self.test_rx.as_mut().unwrap().pop().unwrap(); self.test_rx.as_mut().unwrap().pop().unwrap(); let server_id = server.id.clone();
for member in cluster.values().filter(|member| member.id != server_id) {
let grant_vote = RaftMessage::RequestVoteResponse(RequestVoteResponseResult {
term: 1,
granted: true,
});
let msg = package(grant_vote)
.from(member.id)
.to(server.id)
.build()
.unwrap();
self.test_tx.as_mut().unwrap().push(msg);
server.tick(communication_service);
}
for member in cluster.values().filter(|member| member.id != server_id) {
let accept_appendlog = RaftMessage::AppendLogResponse(AppendLogResponseResult {
term: 1,
success: true,
});
let msg = package(accept_appendlog)
.from(member.id)
.to(server.id)
.build()
.unwrap();
self.test_tx.as_mut().unwrap().push(msg);
server.tick(communication_service);
}
for _ in cluster.values().filter(|member| member.id != server.id) {
self.test_rx.as_mut().unwrap().pop(); }
(
self.test_rx.take().unwrap(),
self.test_tx.take().unwrap(),
server,
cluster,
communication_service,
)
}
fn initialize_cluster_with_members() -> LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT> {
let mut cluster: LinearMap<SystemNodeId, ClusterMember, CLUSTER_NODE_COUNT> = LinearMap::new();
let _ = cluster.insert(
SystemNodeId::from(1),
ClusterMember {
id: SystemNodeId::from(1),
vote_granted: Abstained,
match_idx: 0,
next_idx: 1,
last_successful_heartbeat: 0,
},
);
let _ = cluster.insert(
SystemNodeId::from(2),
ClusterMember {
id: SystemNodeId::from(2),
vote_granted: Abstained,
match_idx: 0,
next_idx: 1,
last_successful_heartbeat: 0,
},
);
cluster
}
}
}
/// Asserts that `$message` is a `RequestVote` whose `$field` equals `$value`.
/// Any other variant calls `fail()` instead.
///
/// Usage: `assert_message!(term in msg == 1)`.
///
/// The macro consumes `$message`. `RequestVote`, `RequestVoteArgs` and `fail`
/// are resolved at the call site, so they must be in scope there.
#[macro_export]
macro_rules! assert_message {
    ($field:ident in $message:ident == $value:expr) => {{
        if let RequestVote(RequestVoteArgs { $field, .. }) = $message {
            assert_eq!($value, $field)
        } else {
            fail()
        }
    }};
}