use std::{collections::HashSet, io};
use crate::Pea2Pea;
/// The shape of the network created by [`connect_nodes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Topology {
    /// Node `i` connects to node `i + 1`, forming a chain.
    Line,
    /// A [`Line`](Topology::Line) whose last node also connects back to the first one.
    Ring,
    /// Every pair of nodes is connected exactly once.
    Mesh,
    /// Node `0` acts as the hub; every other node connects to it.
    Star,
    /// A `width x height` grid laid out row-major; each node connects to its right and lower
    /// neighbours. `width * height` must equal the number of nodes.
    Grid {
        /// Number of nodes per row.
        width: usize,
        /// Number of rows.
        height: usize,
    },
    /// A binary tree in heap layout: node `i` connects to nodes `2i + 1` and `2i + 2`.
    Tree,
    /// Every node connects to `degree` distinct, pseudo-randomly chosen peers (never itself).
    Random {
        /// Number of outbound connections per node; must be smaller than the number of nodes.
        degree: usize,
        /// Seed that makes the choice of peers reproducible.
        seed: u64,
    },
}

impl Topology {
    /// Returns the total number of connections expected among `num_nodes` nodes once they
    /// have been wired up with [`connect_nodes`] using this topology.
    ///
    /// Each link between two nodes contributes `2` to the total (presumably one per
    /// endpoint, i.e. the link as seen from each side).
    ///
    /// Notes:
    /// - Fewer than two nodes cannot form any connection, so `0` is returned in that case.
    /// - For [`Topology::Grid`] the link count is derived from `width` and `height`, which
    ///   [`connect_nodes`] requires to multiply to the node count. A zero dimension yields `0`.
    /// - For [`Topology::Random`] the value assumes every outbound connection attempt
    ///   succeeds; failed attempts would make the real number lower.
    pub fn num_expected_connections(&self, num_nodes: usize) -> usize {
        if num_nodes < 2 {
            return 0;
        }
        // Number of distinct links created by `connect_nodes` for each topology.
        let links = match *self {
            Self::Line | Self::Star | Self::Tree => num_nodes - 1,
            Self::Ring => num_nodes,
            Self::Mesh => num_nodes * (num_nodes - 1) / 2,
            // Horizontal links per row times rows, plus vertical links per column times
            // columns; `saturating_sub` keeps degenerate (zero-sized) grids from underflowing.
            Self::Grid { width, height } => {
                height * width.saturating_sub(1) + width * height.saturating_sub(1)
            }
            Self::Random { degree, .. } => num_nodes * degree,
        };
        links * 2
    }
}
/// Connects the given `nodes` to one another according to `topology`.
///
/// Connections are established sequentially and in a deterministic order. For every link,
/// the dialing node looks up the peer's listening address and connects to it:
///
/// - [`Topology::Line`]: node `i` connects to node `i + 1`.
/// - [`Topology::Ring`]: like `Line`, plus the last node connects to node `0`.
/// - [`Topology::Mesh`]: node `i` connects to every node `j > i`, so each pair is linked once.
/// - [`Topology::Star`]: every node except node `0` connects to node `0`.
/// - [`Topology::Grid`]: nodes are laid out row-major in a `width x height` grid and each
///   node connects to its right and lower neighbours.
/// - [`Topology::Tree`]: heap layout; node `i` connects to nodes `2i + 1` and `2i + 2`.
/// - [`Topology::Random`]: each node picks `degree` distinct peers (never itself) using a
///   small deterministic hash of its index, `seed` and an attempt counter, so the same seed
///   always yields the same choice of peers.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error in these cases:
/// - fewer than two nodes are given;
/// - the `Grid` dimensions do not multiply to `nodes.len()`, or their product overflows;
/// - the `Random` degree is not smaller than `nodes.len()`.
///
/// Errors from looking up listening addresses are always propagated. Errors from connecting
/// are propagated too, except for [`Topology::Random`], where connection attempts are
/// best-effort and their failures are ignored.
pub async fn connect_nodes<T: Pea2Pea>(nodes: &[T], topology: Topology) -> io::Result<()> {
    let count = nodes.len();
    if count < 2 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "connecting nodes requires at least two of them",
        ));
    }

    match topology {
        Topology::Line | Topology::Ring => {
            for i in 0..count - 1 {
                connect_by_index(nodes, i, i + 1).await?;
            }
            if topology == Topology::Ring {
                // Close the loop: the last node dials the first one.
                connect_by_index(nodes, count - 1, 0).await?;
            }
        }
        Topology::Mesh => {
            // Only dial "forward" (j > i) so that every pair is connected exactly once.
            for i in 0..count {
                for j in (i + 1)..count {
                    connect_by_index(nodes, i, j).await?;
                }
            }
        }
        Topology::Star => {
            let hub_addr = nodes[0].node().listening_addr().await?;
            for node in &nodes[1..] {
                node.node().connect(hub_addr).await?;
            }
        }
        Topology::Grid { width, height } => {
            // `checked_mul` avoids an overflow panic on absurd dimensions.
            match width.checked_mul(height) {
                Some(cells) if cells == count => {}
                Some(cells) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!(
                            "Grid topology dimensions ({width}x{height} = {cells}) do not match the number of nodes ({count})"
                        ),
                    ));
                }
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("Grid topology dimensions ({width}x{height}) overflow usize"),
                    ));
                }
            }
            // Row-major layout: each node dials its right and lower neighbours.
            for row in 0..height {
                for col in 0..width {
                    let i = row * width + col;
                    if col + 1 < width {
                        connect_by_index(nodes, i, i + 1).await?;
                    }
                    if row + 1 < height {
                        connect_by_index(nodes, i, i + width).await?;
                    }
                }
            }
        }
        Topology::Tree => {
            // Binary-heap layout: node `i` dials its children `2i + 1` and `2i + 2`.
            for i in 0..count {
                for child in [2 * i + 1, 2 * i + 2] {
                    if child < count {
                        connect_by_index(nodes, i, child).await?;
                    }
                }
            }
        }
        Topology::Random { degree, seed } => {
            if degree >= count {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Random topology degree cannot exceed N-1",
                ));
            }
            for i in 0..count {
                let mut chosen_targets = HashSet::with_capacity(degree);
                let mut attempt = 0u64;
                while chosen_targets.len() < degree {
                    attempt += 1;
                    // Cheap deterministic mixer (multiply + xor-shift) over the node index,
                    // the seed and an attempt counter.
                    // NOTE(review): because the input is `i + seed + attempt`, node `i + 1`
                    // sees node `i`'s candidate sequence shifted by one attempt, so the
                    // picks of neighbouring nodes are correlated. This is kept unchanged so
                    // that existing seeds keep producing the same topology.
                    let mut x = (i as u64).wrapping_add(seed).wrapping_add(attempt);
                    x = x.wrapping_mul(0x517cc1b727220a95);
                    x ^= x >> 32;
                    let target = (x as usize) % count;
                    if target != i && chosen_targets.insert(target) {
                        let addr = nodes[target].node().listening_addr().await?;
                        // Best-effort: connection failures are intentionally ignored here.
                        let _ = nodes[i].node().connect(addr).await;
                    }
                }
            }
        }
    }

    Ok(())
}

/// Makes `nodes[from]` connect to the listening address of `nodes[to]`.
async fn connect_by_index<T: Pea2Pea>(nodes: &[T], from: usize, to: usize) -> io::Result<()> {
    let addr = nodes[to].node().listening_addr().await?;
    nodes[from].node().connect(addr).await?;
    Ok(())
}