use crate::{log::LogEntries, message::Message, node::NodeId};
use std::collections::{BTreeMap, BTreeSet};
/// A single side effect that the owner of this state machine is asked to carry out.
///
/// Pending actions are accumulated and coalesced in [`Actions`]; draining
/// `Actions` as an iterator yields them in the same order as the variants are
/// declared below.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names and payload types; confirm the exact contract against the code that
/// emits these actions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Action {
/// (Re)arm the election timeout. Carries no payload.
SetElectionTimeout,
/// Persist the current term. Carries no payload.
SaveCurrentTerm,
/// Persist the voted-for value. Carries no payload.
SaveVotedFor,
/// Send the contained message to all peers.
BroadcastMessage(Message),
/// Append the contained entries to the local log.
AppendLogEntries(LogEntries),
/// Send the message to the given node only.
SendMessage(NodeId, Message),
/// Install a snapshot on the given node.
InstallSnapshot(NodeId),
}
/// The set of pending [`Action`]s, with actions of the same kind coalesced.
///
/// Each field is the pending slot for one `Action` variant. Populate it with
/// `Actions::set` and drain it through its `Iterator` implementation, which
/// clears each slot as the corresponding action is yielded.
///
/// `PartialEq`/`Eq` are derived so two pending sets can be compared directly;
/// every field type already supports them (`Action` derives them over the same
/// payload types).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Actions {
    /// Whether an `Action::SetElectionTimeout` is pending.
    pub set_election_timeout: bool,
    /// Whether an `Action::SaveCurrentTerm` is pending.
    pub save_current_term: bool,
    /// Whether an `Action::SaveVotedFor` is pending.
    pub save_voted_for: bool,
    /// Pending broadcast; later broadcasts are folded in with `Message::merge`.
    pub broadcast_message: Option<Message>,
    /// Pending log entries; later batches are folded in with `LogEntries::append`.
    pub append_log_entries: Option<LogEntries>,
    /// At most one pending message per destination node; later messages to the
    /// same node are folded in with `Message::merge`.
    pub send_messages: BTreeMap<NodeId, Message>,
    /// Nodes with a pending snapshot install; a set, so duplicates collapse.
    pub install_snapshots: BTreeSet<NodeId>,
}
impl Actions {
pub(crate) fn set(&mut self, action: Action) {
match action {
Action::SetElectionTimeout => self.set_election_timeout = true,
Action::SaveCurrentTerm => self.save_current_term = true,
Action::SaveVotedFor => self.save_voted_for = true,
Action::AppendLogEntries(log_entries) => {
if let Some(existing) = &mut self.append_log_entries {
existing.append(&log_entries);
} else {
self.append_log_entries = Some(log_entries);
}
}
Action::BroadcastMessage(message) => {
if let Some(existing) = &mut self.broadcast_message {
existing.merge(message);
} else {
self.broadcast_message = Some(message);
}
}
Action::SendMessage(node_id, message) => {
if let Some(existing) = self.send_messages.get_mut(&node_id) {
existing.merge(message);
} else {
self.send_messages.insert(node_id, message);
}
}
Action::InstallSnapshot(node_id) => {
self.install_snapshots.insert(node_id);
}
}
}
pub fn is_empty(&self) -> bool {
!self.set_election_timeout
&& !self.save_current_term
&& !self.save_voted_for
&& self.append_log_entries.is_none()
&& self.broadcast_message.is_none()
&& self.send_messages.is_empty()
&& self.install_snapshots.is_empty()
}
}
impl Iterator for Actions {
    type Item = Action;

    /// Removes and returns one pending action, or `None` once nothing is left.
    ///
    /// Actions are yielded in a fixed order:
    /// 1. the three flag actions;
    /// 2. the broadcast message;
    /// 3. the batched log entries;
    /// 4. per-node messages, then snapshot installs, each in ascending `NodeId` order.
    ///
    /// Every yielded action is cleared from `self`, so fully draining the
    /// iterator leaves `self` empty.
    fn next(&mut self) -> Option<Self::Item> {
        // `mem::take` reads the flag and resets it to `false` in one step.
        if std::mem::take(&mut self.set_election_timeout) {
            return Some(Action::SetElectionTimeout);
        }
        if std::mem::take(&mut self.save_current_term) {
            return Some(Action::SaveCurrentTerm);
        }
        if std::mem::take(&mut self.save_voted_for) {
            return Some(Action::SaveVotedFor);
        }
        // `or_else` is lazy, so later slots are only touched when every
        // earlier slot is empty.
        self.broadcast_message
            .take()
            .map(Action::BroadcastMessage)
            .or_else(|| self.append_log_entries.take().map(Action::AppendLogEntries))
            .or_else(|| {
                self.send_messages
                    .pop_first()
                    .map(|(destination, message)| Action::SendMessage(destination, message))
            })
            .or_else(|| self.install_snapshots.pop_first().map(Action::InstallSnapshot))
    }
}
#[cfg(test)]
mod tests {
    use crate::{LogEntry, LogIndex, LogPosition, MessageSeqNo, Term};

    use super::*;

    /// Exercises `Actions::set` coalescing, `Actions::is_empty`, and the drain
    /// order of the `Iterator` implementation.
    #[test]
    fn actions_set() {
        let mut actions = Actions::default();
        assert!(actions.is_empty());
        assert_eq!(actions.next(), None);

        // Flag actions: repeated sets collapse into a single yielded action.
        actions.set(Action::SetElectionTimeout);
        actions.set(Action::SetElectionTimeout);
        assert!(!actions.is_empty());
        assert_eq!(actions.next(), Some(Action::SetElectionTimeout));
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        actions.set(Action::SaveCurrentTerm);
        actions.set(Action::SaveCurrentTerm);
        assert!(!actions.is_empty());
        assert_eq!(actions.next(), Some(Action::SaveCurrentTerm));
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        actions.set(Action::SaveVotedFor);
        actions.set(Action::SaveVotedFor);
        assert!(!actions.is_empty());
        assert_eq!(actions.next(), Some(Action::SaveVotedFor));
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        // Broadcasts are merged into a single pending message.
        actions.set(Action::BroadcastMessage(Message::request_vote_call(
            Term::new(2),
            NodeId::new(3),
            MessageSeqNo::new(10),
            pos(2, 8),
        )));
        actions.set(Action::BroadcastMessage(Message::append_entries_call(
            Term::new(2),
            NodeId::new(3),
            LogIndex::new(10),
            MessageSeqNo::new(30),
            LogEntries::new(pos(2, 10)),
        )));
        assert!(!actions.is_empty());
        assert!(matches!(
            actions.next(),
            Some(Action::BroadcastMessage(Message::AppendEntriesCall { .. }))
        ));
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        // Log entry batches are appended into a single pending batch.
        actions.set(Action::AppendLogEntries(LogEntries::from_iter(
            pos(2, 3),
            std::iter::once(LogEntry::Command),
        )));
        actions.set(Action::AppendLogEntries(LogEntries::from_iter(
            pos(2, 4),
            std::iter::once(LogEntry::Command),
        )));
        assert!(!actions.is_empty());
        assert_eq!(
            actions.next(),
            Some(Action::AppendLogEntries(LogEntries::from_iter(
                pos(2, 3),
                [LogEntry::Command, LogEntry::Command].into_iter()
            )))
        );
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        // Per-node messages drain in ascending NodeId order, each keeping its
        // own destination.
        actions.set(Action::SendMessage(
            NodeId::new(4),
            Message::request_vote_call(
                Term::new(2),
                NodeId::new(3),
                MessageSeqNo::new(3),
                pos(2, 8),
            ),
        ));
        actions.set(Action::SendMessage(
            NodeId::new(2),
            Message::append_entries_call(
                Term::new(2),
                NodeId::new(3),
                LogIndex::new(10),
                MessageSeqNo::new(30),
                LogEntries::new(pos(2, 10)),
            ),
        ));
        assert!(!actions.is_empty());
        assert!(matches!(
            actions.next(),
            Some(Action::SendMessage(node_id, Message::AppendEntriesCall { .. }))
                if node_id == NodeId::new(2)
        ));
        assert!(matches!(
            actions.next(),
            Some(Action::SendMessage(node_id, Message::RequestVoteCall { .. }))
                if node_id == NodeId::new(4)
        ));
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());

        // Snapshot targets are deduplicated and drain in ascending NodeId order.
        actions.set(Action::InstallSnapshot(NodeId::new(3)));
        actions.set(Action::InstallSnapshot(NodeId::new(2)));
        actions.set(Action::InstallSnapshot(NodeId::new(3)));
        assert!(!actions.is_empty());
        assert_eq!(
            actions.next(),
            Some(Action::InstallSnapshot(NodeId::new(2)))
        );
        assert_eq!(
            actions.next(),
            Some(Action::InstallSnapshot(NodeId::new(3)))
        );
        assert_eq!(actions.next(), None);
        assert!(actions.is_empty());
    }

    /// Builds a `LogPosition` from raw term and index numbers.
    fn pos(term: u64, index: u64) -> LogPosition {
        let term = Term::new(term);
        let index = LogIndex::new(index);
        LogPosition { term, index }
    }
}