use std::sync::Arc;
use std::time::Duration;
use noxu_sync::Mutex;
use crate::error::Result;
use crate::net::channel::{Channel, LocalChannel, LocalChannelPair};
/// One end of a bidirectional in-memory byte channel.
///
/// The wrapped [`LocalChannel`] sits behind an [`Arc`] so that
/// [`InMemoryEndpoint::channel_handle`] can hand out additional owners of
/// the very same channel while the endpoint itself stays where it is.
pub struct InMemoryEndpoint {
    inner: Arc<LocalChannel>,
}

impl InMemoryEndpoint {
    /// Takes ownership of one half of a channel pair.
    fn new(inner: LocalChannel) -> Self {
        let shared: Arc<LocalChannel> = inner.into();
        Self { inner: shared }
    }

    /// Returns a new, type-erased owner of this endpoint's channel.
    ///
    /// The handle and the endpoint share one underlying [`LocalChannel`], so
    /// an operation through either (including `close`) affects both.
    pub fn channel_handle(&self) -> Arc<dyn Channel> {
        self.inner.clone()
    }
}

/// Every [`Channel`] operation is forwarded unchanged to the wrapped
/// [`LocalChannel`].
impl Channel for InMemoryEndpoint {
    fn send(&self, data: &[u8]) -> Result<()> {
        self.inner.send(data)
    }

    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        self.inner.receive(timeout)
    }

    fn close(&self) -> Result<()> {
        self.inner.close()
    }

    fn is_open(&self) -> bool {
        self.inner.is_open()
    }
}
pub struct InMemoryTransport;
impl InMemoryTransport {
pub fn new_pair() -> (InMemoryEndpoint, InMemoryEndpoint) {
let pair = LocalChannelPair::new();
(
InMemoryEndpoint::new(pair.channel_a),
InMemoryEndpoint::new(pair.channel_b),
)
}
pub fn new_group(n: usize) -> InMemoryGroup {
InMemoryGroup::new(n)
}
}
/// A fully connected mesh of in-memory channels between `n` nodes, with
/// hooks for simulating node crashes and recoveries in tests.
///
/// `endpoints[from][to]` holds the endpoint node `from` uses to talk to node
/// `to`. For `i != j`, slots `[i][j]` and `[j][i]` are the two halves of one
/// [`LocalChannelPair`]; the diagonal `[i][i]` is never populated. A slot
/// holding `None` means the link was torn down by
/// [`InMemoryGroup::simulate_crash`].
///
/// Locking discipline: any operation that needs both directions of a link
/// locks `[lo][hi]` before `[hi][lo]` (with `lo < hi`); every other
/// operation holds at most one slot lock at a time. Using a single order for
/// the two-slot paths keeps them deadlock-free against each other.
pub struct InMemoryGroup {
    n: usize,
    endpoints: Vec<Vec<Mutex<Option<InMemoryEndpoint>>>>,
}

impl InMemoryGroup {
    /// Builds an `n`-node mesh with every pair of distinct nodes connected.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    fn new(n: usize) -> Self {
        assert!(n > 0, "InMemoryGroup requires at least one node");
        let endpoints: Vec<Vec<Mutex<Option<InMemoryEndpoint>>>> = (0..n)
            .map(|_| (0..n).map(|_| Mutex::new(None)).collect())
            .collect();
        let group = Self { n, endpoints };
        for lo in 0..n {
            for hi in (lo + 1)..n {
                group.with_link(lo, hi, Self::install_fresh_pair);
            }
        }
        group
    }

    /// Number of nodes in the mesh.
    pub fn size(&self) -> usize {
        self.n
    }

    /// Returns the channel node `from` uses to talk to node `to`.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range, if `from == to`, or if the
    /// link has been torn down by [`InMemoryGroup::simulate_crash`] and not
    /// yet restored with [`InMemoryGroup::reconnect`]. Use
    /// [`InMemoryGroup::try_channel`] to observe a missing link without
    /// panicking.
    pub fn channel(&self, from: usize, to: usize) -> Arc<dyn Channel> {
        self.try_channel(from, to).unwrap_or_else(|| {
            panic!(
                "in-memory channel {from}→{to} is closed; \
                 call reconnect({from}) before reuse"
            )
        })
    }

    /// Like [`InMemoryGroup::channel`], but returns `None` instead of
    /// panicking when the link has been torn down.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range or if `from == to`.
    pub fn try_channel(
        &self,
        from: usize,
        to: usize,
    ) -> Option<Arc<dyn Channel>> {
        self.check_link(from, to);
        let slot = self.endpoints[from][to].lock();
        slot.as_ref().map(InMemoryEndpoint::channel_handle)
    }

    /// Simulates a crash of `node`. Every link touching it is closed in both
    /// directions and its slots are emptied, so `try_channel` returns `None`
    /// for those links and previously obtained handles see a closed channel.
    ///
    /// Links between other nodes are unaffected. Crashing an already-crashed
    /// node changes nothing. Errors from `close` are ignored because the
    /// endpoints are being discarded either way.
    ///
    /// # Panics
    ///
    /// Panics if `node` is out of range.
    pub fn simulate_crash(&self, node: usize) {
        self.check_node(node);
        for peer in self.peers_of(node) {
            // Both directions are torn down under one lo-then-hi lock pair so
            // a concurrent `reconnect` never observes a half-torn link.
            self.with_link(node, peer, |lo_hi, hi_lo| {
                Self::tear_down(lo_hi);
                Self::tear_down(hi_lo);
            });
        }
    }

    /// Restores every link between `node` and its peers that is missing, or
    /// whose endpoint in either direction is no longer open, with a fresh
    /// channel pair.
    ///
    /// Links that are open in both directions are left untouched, so handles
    /// to them stay valid. A surviving half of a broken link is closed and
    /// replaced, so both directions always share one fresh pair. There is no
    /// per-node "crashed" flag, so this also restores links to peers that
    /// were themselves crashed.
    ///
    /// # Panics
    ///
    /// Panics if `node` is out of range.
    pub fn reconnect(&self, node: usize) {
        self.check_node(node);
        for peer in self.peers_of(node) {
            self.with_link(node, peer, |lo_hi, hi_lo| {
                let healthy =
                    Self::slot_is_open(lo_hi) && Self::slot_is_open(hi_lo);
                if !healthy {
                    Self::install_fresh_pair(lo_hi, hi_lo);
                }
            });
        }
    }

    /// Reports whether every link between `node` and its peers is present
    /// and open in both directions.
    ///
    /// Out-of-range nodes are never live. The only node of a single-node
    /// group has no links and is therefore always live.
    pub fn is_node_live(&self, node: usize) -> bool {
        // Take one slot lock at a time: holding `[node][peer]` while locking
        // `[peer][node]` would invert the `[lo][hi]`-first order whenever
        // `node > peer`.
        let open = |from: usize, to: usize| {
            let slot = self.endpoints[from][to].lock();
            Self::slot_is_open(&slot)
        };
        node < self.n
            && self
                .peers_of(node)
                .all(|peer| open(node, peer) && open(peer, node))
    }

    /// Asserts that `node` is a valid node index.
    fn check_node(&self, node: usize) {
        assert!(node < self.n, "node index {node} out of range (n={})", self.n);
    }

    /// Asserts that `from -> to` names an existing, non-self-loop link.
    fn check_link(&self, from: usize, to: usize) {
        assert!(from < self.n, "from index {from} out of range (n={})", self.n);
        assert!(to < self.n, "to index {to} out of range (n={})", self.n);
        assert!(from != to, "in-memory mesh has no self-loop channel");
    }

    /// Every node index other than `node`.
    fn peers_of(&self, node: usize) -> impl Iterator<Item = usize> {
        (0..self.n).filter(move |&peer| peer != node)
    }

    /// Runs `f` with both directions of the `x`–`y` link locked.
    ///
    /// The `[lo][hi]` slot is passed first and the `[hi][lo]` slot second,
    /// where `lo = min(x, y)`. The locks are always acquired in that order.
    fn with_link<R>(
        &self,
        x: usize,
        y: usize,
        f: impl FnOnce(
            &mut Option<InMemoryEndpoint>,
            &mut Option<InMemoryEndpoint>,
        ) -> R,
    ) -> R {
        let (lo, hi) = if x < y { (x, y) } else { (y, x) };
        let mut lo_hi = self.endpoints[lo][hi].lock();
        let mut hi_lo = self.endpoints[hi][lo].lock();
        f(&mut *lo_hi, &mut *hi_lo)
    }

    /// Closes whatever the two slots hold and refills them with the halves
    /// of a new pair: `channel_a` on the `[lo][hi]` side, `channel_b` on the
    /// `[hi][lo]` side.
    fn install_fresh_pair(
        lo_hi: &mut Option<InMemoryEndpoint>,
        hi_lo: &mut Option<InMemoryEndpoint>,
    ) {
        Self::tear_down(lo_hi);
        Self::tear_down(hi_lo);
        let pair = LocalChannelPair::new();
        *lo_hi = Some(InMemoryEndpoint::new(pair.channel_a));
        *hi_lo = Some(InMemoryEndpoint::new(pair.channel_b));
    }

    /// Empties `slot` and closes the endpoint it held, if any (best-effort).
    fn tear_down(slot: &mut Option<InMemoryEndpoint>) {
        if let Some(endpoint) = slot.take() {
            // Ignore close errors: the endpoint is being discarded regardless.
            let _ = endpoint.inner.close();
        }
    }

    /// `true` if `slot` holds an endpoint that still reports itself open.
    fn slot_is_open(slot: &Option<InMemoryEndpoint>) -> bool {
        matches!(slot, Some(endpoint) if endpoint.inner.is_open())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Receive timeout for reads that are expected to find a message.
    const RECV_WAIT: Duration = Duration::from_millis(50);

    /// Sends `payload` over the `from -> to` link of `mesh` and asserts that
    /// the `to -> from` side receives exactly that payload.
    fn assert_delivers(mesh: &InMemoryGroup, from: usize, to: usize, payload: &[u8]) {
        mesh.channel(from, to).send(payload).unwrap();
        let received = mesh.channel(to, from).receive(RECV_WAIT).unwrap();
        assert_eq!(received, Some(payload.to_vec()));
    }

    #[test]
    fn pair_round_trip() {
        let (left, right) = InMemoryTransport::new_pair();
        let legs = [(&left, &right, &b"hello"[..]), (&right, &left, &b"world"[..])];
        for (sender, receiver, message) in legs {
            sender.send(message).unwrap();
            let received = receiver.receive(RECV_WAIT).unwrap();
            assert_eq!(received, Some(message.to_vec()));
        }
    }

    #[test]
    fn group_3node_mesh_is_fully_connected() {
        let mesh = InMemoryTransport::new_group(3);
        assert_eq!(mesh.size(), 3);
        // `channel` panics on a missing link, so touching every
        // off-diagonal pair proves the mesh is complete.
        (0..3)
            .flat_map(|from| (0..3).map(move |to| (from, to)))
            .filter(|(from, to)| from != to)
            .for_each(|(from, to)| drop(mesh.channel(from, to)));
        assert_delivers(&mesh, 0, 1, b"01");
    }

    #[test]
    fn group_independent_pairs_do_not_cross_talk() {
        let mesh = InMemoryTransport::new_group(4);
        mesh.channel(0, 1).send(b"to-1").unwrap();
        mesh.channel(0, 2).send(b"to-2").unwrap();
        let inbox: Vec<Option<Vec<u8>>> = [1, 2, 3]
            .iter()
            .map(|&peer| mesh.channel(peer, 0).receive(RECV_WAIT).unwrap())
            .collect();
        assert_eq!(inbox[0], Some(b"to-1".to_vec()));
        assert_eq!(inbox[1], Some(b"to-2".to_vec()));
        assert_eq!(inbox[2], None, "node 3 must not see node 1's traffic");
    }

    #[test]
    fn simulate_crash_closes_all_channels_for_node() {
        let mesh = InMemoryTransport::new_group(3);
        let outbound = mesh.channel(0, 1);
        let inbound = mesh.channel(1, 0);
        mesh.simulate_crash(0);
        assert!(outbound.send(b"after-crash").is_err());
        let after = inbound.receive(Duration::from_millis(20));
        assert!(after.is_err(), "post-crash receive must surface error");
        for (from, to) in [(0, 1), (1, 0)] {
            assert!(mesh.try_channel(from, to).is_none());
        }
        assert!(mesh.try_channel(1, 2).is_some());
        assert_delivers(&mesh, 1, 2, b"alive");
    }

    #[test]
    fn simulate_crash_is_idempotent() {
        let mesh = InMemoryTransport::new_group(3);
        for _ in 0..2 {
            mesh.simulate_crash(2);
        }
        assert!(!mesh.is_node_live(2));
        assert!(mesh.try_channel(0, 1).is_some());
        assert!(mesh.try_channel(1, 0).is_some());
        assert_delivers(&mesh, 0, 1, b"alive");
    }

    #[test]
    fn reconnect_after_crash_restores_traffic() {
        let mesh = InMemoryTransport::new_group(3);
        mesh.simulate_crash(0);
        assert!(!mesh.is_node_live(0));
        mesh.reconnect(0);
        assert!(mesh.is_node_live(0));
        assert_delivers(&mesh, 0, 1, b"reborn");
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn channel_out_of_range_panics() {
        let mesh = InMemoryTransport::new_group(2);
        let _ = mesh.channel(5, 0);
    }

    #[test]
    #[should_panic(expected = "no self-loop")]
    fn channel_self_loop_panics() {
        let mesh = InMemoryTransport::new_group(2);
        let _ = mesh.channel(0, 0);
    }

    #[test]
    #[should_panic(expected = "at least one node")]
    fn empty_group_panics() {
        let _ = InMemoryTransport::new_group(0);
    }

    #[test]
    fn one_node_group_has_no_channels() {
        let solo = InMemoryTransport::new_group(1);
        assert_eq!(solo.size(), 1);
        assert!(solo.is_node_live(0));
    }

    #[test]
    fn channel_handle_outlives_borrow_of_group() {
        // The group is a temporary dropped at the end of this statement;
        // the returned handle must remain usable afterwards.
        let handle: Arc<dyn Channel> = InMemoryTransport::new_group(2).channel(0, 1);
        let _ = handle.is_open();
    }
}