use std::collections::HashMap;
use std::sync::Arc;
use omnipaxos::messages::Message;
use omnipaxos::storage::Entry;
use parking_lot::Mutex;
use tokio::sync::mpsc;
use crate::test_fakes::partition::PartitionController;
/// Buffer size handed to `mpsc::channel` for every inbox created by `MemNetwork::register`.
const CHANNEL_CAPACITY: usize = 1024;
/// In-memory message router that hands OmniPaxos messages to per-node channels.
///
/// Nodes obtain an inbox through `register`; `deliver` looks up the
/// destination's sender and forwards the message unless the partition
/// controller reports the link as blocked.
pub struct MemNetwork<T>
where
    T: Entry,
{
    /// Sending half of each registered node's inbox, keyed by node id.
    senders: Mutex<HashMap<u64, mpsc::Sender<Message<T>>>>,
    /// Shared controller consulted before every delivery.
    partition: Arc<PartitionController>,
}
impl<T> MemNetwork<T>
where
    T: Entry + Send + 'static,
{
    /// Creates a network with no registered nodes and a freshly constructed
    /// [`PartitionController`].
    #[must_use]
    pub fn new() -> Self {
        let senders = Mutex::new(HashMap::new());
        let partition = Arc::new(PartitionController::new());
        Self { senders, partition }
    }

    /// Returns another handle to the partition controller this network
    /// consults in `deliver`.
    #[must_use]
    pub fn partition(&self) -> Arc<PartitionController> {
        Arc::clone(&self.partition)
    }

    /// Opens an inbox for `node_id` and returns its receiving half.
    ///
    /// The channel is bounded by `CHANNEL_CAPACITY`. The sending half is
    /// stored under `node_id`, replacing any sender stored there before.
    pub fn register(&self, node_id: u64) -> mpsc::Receiver<Message<T>> {
        let (outbox, inbox) = mpsc::channel(CHANNEL_CAPACITY);
        self.senders.lock().insert(node_id, outbox);
        inbox
    }

    /// Forwards `message` to the inbox registered for its destination.
    ///
    /// Delivery is best effort. The message is discarded when the partition
    /// controller blocks the `(from, to)` pair, when no sender is registered
    /// for the destination, or when `try_send` returns an error.
    ///
    /// Although declared `async`, the body never awaits.
    pub async fn deliver(&self, message: Message<T>) {
        let (source, destination) = endpoints(&message);
        if self.partition.is_blocked(source, destination) {
            return;
        }

        // The lock guard is a temporary of this statement, so the map is
        // unlocked again before the send is attempted.
        let outbox = self.senders.lock().get(&destination).cloned();

        match outbox {
            Some(outbox) => {
                // Non-waiting send; the result is ignored on purpose.
                let _ = outbox.try_send(message);
            }
            None => {}
        }
    }
}
impl<T: Entry + Send + 'static> Default for MemNetwork<T> {
fn default() -> Self {
Self::new()
}
}
/// Extracts the `(from, to)` node ids carried by `message`.
///
/// Both message variants expose `from` and `to` fields; the pair is returned
/// in that order.
fn endpoints<T>(message: &Message<T>) -> (u64, u64)
where
    T: Entry,
{
    match message {
        Message::SequencePaxos(inner) => (inner.from, inner.to),
        Message::BLE(inner) => (inner.from, inner.to),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use omnipaxos::messages::{
        sequence_paxos::{PaxosMessage, PaxosMsg},
        Message,
    };

    /// Minimal log entry type used to instantiate `MemNetwork` in tests.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
    struct Cmd(u64);

    impl omnipaxos::storage::Entry for Cmd {
        type Snapshot = Snap;
    }

    /// Empty snapshot type; `use_snapshots` reports `false`.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
    struct Snap;

    impl omnipaxos::storage::Snapshot<Cmd> for Snap {
        fn create(_entries: &[Cmd]) -> Self {
            Snap
        }

        fn merge(&mut self, _delta: Self) {}

        fn use_snapshots() -> bool {
            false
        }
    }

    /// Builds a `ProposalForward` message from `from` to `to` carrying one entry.
    fn proposal_forward(from: u64, to: u64) -> Message<Cmd> {
        let msg = PaxosMsg::ProposalForward(vec![Cmd(7)]);
        Message::SequencePaxos(PaxosMessage { from, to, msg })
    }

    #[tokio::test]
    async fn message_routes_to_registered_destination() {
        let net = MemNetwork::<Cmd>::new();
        let mut rx = net.register(2);

        net.deliver(proposal_forward(1, 2)).await;

        let delivered = rx.recv().await.expect("recv");
        if let Message::SequencePaxos(paxos) = delivered {
            assert_eq!(paxos.from, 1);
            assert_eq!(paxos.to, 2);
        } else {
            panic!("unexpected variant");
        }
    }

    #[tokio::test]
    async fn message_to_unregistered_node_is_silently_dropped() {
        // Nothing is registered for node 99; the call only has to complete.
        let net = MemNetwork::<Cmd>::new();
        net.deliver(proposal_forward(1, 99)).await;
    }

    #[tokio::test]
    async fn partitioned_endpoint_drops_message() {
        let net = MemNetwork::<Cmd>::new();
        let mut rx = net.register(2);
        net.partition().isolate(2);

        net.deliver(proposal_forward(1, 2)).await;

        // The inbox must stay empty, so waiting on it has to time out.
        let wait = std::time::Duration::from_millis(10);
        let outcome = tokio::time::timeout(wait, rx.recv()).await;
        assert!(outcome.is_err(), "isolated node must not receive messages");
    }

    #[tokio::test]
    async fn healed_partition_resumes_routing() {
        let net = MemNetwork::<Cmd>::new();
        let mut rx = net.register(2);
        let partition = net.partition();

        // First delivery is attempted while node 2 is isolated.
        partition.isolate(2);
        net.deliver(proposal_forward(1, 2)).await;

        // After healing, a second delivery is expected to reach the inbox.
        partition.heal();
        net.deliver(proposal_forward(1, 2)).await;

        let delivered = rx.recv().await.expect("recv");
        assert!(matches!(delivered, Message::SequencePaxos(_)));
    }
}