#[cfg(not(feature = "use-mock-crust"))]
use crust::PeerId;
use itertools::Itertools;
use message_filter::MessageFilter;
#[cfg(feature = "use-mock-crust")]
use mock_crust::crust::PeerId;
use std::collections::{HashMap, HashSet};
use std::collections::hash_map::Entry;
use std::time::Duration;
/// Upper bound on the number of accepted client pairs this node will serve as a tunnel for;
/// `consider_clients` refuses new pairs once this many are accepted.
const MAX_TUNNEL_CLIENT_PAIRS: usize = 40;

/// Tunnel bookkeeping for this node.
///
/// Two independent relationships are tracked:
/// * destinations that this node reaches *through* some tunnel node, and
/// * pairs of peers for which this node itself acts as the tunnel (pending and accepted).
pub struct Tunnels {
    /// Accepted client pairs, always stored with the smaller peer ID first.
    clients: HashSet<(PeerId, PeerId)>,
    /// Client pairs that have been considered but not yet accepted. They are held in a
    /// `MessageFilter`, so they presumably lapse after the filter's expiry duration.
    new_clients: MessageFilter<(PeerId, PeerId)>,
    /// Maps a destination peer to the tunnel node used to reach it.
    tunnels: HashMap<PeerId, PeerId>,
}
impl Tunnels {
    /// Returns `true` if `src_id` and `dst_id` form an accepted pair of tunnel clients.
    ///
    /// The order of the two IDs does not matter.
    pub fn has_clients(&self, src_id: PeerId, dst_id: PeerId) -> bool {
        self.clients.contains(&Self::ordered_pair(src_id, dst_id))
    }

    /// Considers acting as the tunnel node for the pair `src_id` / `dst_id`.
    ///
    /// Returns `None` if `MAX_TUNNEL_CLIENT_PAIRS` accepted pairs already exist, or if either
    /// peer has a tunnel registered in this node's own tunnel map. Otherwise records the pair
    /// as pending and returns it with the smaller ID first; it must later be confirmed with
    /// `accept_clients`.
    pub fn consider_clients(&mut self, src_id: PeerId, dst_id: PeerId) -> Option<(PeerId, PeerId)> {
        if self.clients.len() >= MAX_TUNNEL_CLIENT_PAIRS
            || self.tunnels.contains_key(&src_id)
            || self.tunnels.contains_key(&dst_id)
        {
            return None;
        }
        let pair = Self::ordered_pair(src_id, dst_id);
        let _ = self.new_clients.insert(&pair);
        Some(pair)
    }

    /// Accepts a pair previously returned by `consider_clients`.
    ///
    /// If the pair is still present in the pending filter, moves it to the accepted set and
    /// returns `true`; otherwise returns `false`. Like the other pair-based methods, the
    /// order of the two IDs does not matter.
    pub fn accept_clients(&mut self, src_id: PeerId, dst_id: PeerId) -> bool {
        // Pending and accepted pairs are only ever stored with the smaller ID first, so the
        // lookup key must be normalised the same way.
        let pair = Self::ordered_pair(src_id, dst_id);
        if !self.new_clients.contains(&pair) {
            return false;
        }
        self.new_clients.remove(&pair);
        let _ = self.clients.insert(pair);
        true
    }

    /// Removes every accepted client pair that contains `peer_id`.
    ///
    /// Returns the other member of each removed pair.
    pub fn drop_client(&mut self, peer_id: &PeerId) -> Vec<PeerId> {
        // Collect first: `self.clients` cannot be mutated while it is being iterated.
        let pairs = self.clients
            .iter()
            .filter(|pair| pair.0 == *peer_id || pair.1 == *peer_id)
            .cloned()
            .collect_vec();
        let mut partners = Vec::with_capacity(pairs.len());
        for pair in pairs {
            let _ = self.clients.remove(&pair);
            partners.push(if pair.0 == *peer_id { pair.1 } else { pair.0 });
        }
        partners
    }

    /// Removes the accepted client pair `src_id` / `dst_id` (in either order).
    ///
    /// Returns `true` if the pair was present.
    pub fn drop_client_pair(&mut self, src_id: PeerId, dst_id: PeerId) -> bool {
        self.clients.remove(&Self::ordered_pair(src_id, dst_id))
    }

    /// Registers `tunnel_id` as the tunnel node for `dst_id`.
    ///
    /// Returns `false` and leaves the existing entry untouched if `dst_id` already has a
    /// tunnel.
    pub fn add(&mut self, dst_id: PeerId, tunnel_id: PeerId) -> bool {
        match self.tunnels.entry(dst_id) {
            Entry::Occupied(_) => false,
            Entry::Vacant(entry) => {
                let _ = entry.insert(tunnel_id);
                true
            }
        }
    }

    /// Removes the tunnel for `dst_id`, but only if it currently goes through `tunnel_id`.
    ///
    /// Returns `true` if an entry was removed.
    pub fn remove(&mut self, dst_id: PeerId, tunnel_id: PeerId) -> bool {
        if let Entry::Occupied(entry) = self.tunnels.entry(dst_id) {
            if entry.get() == &tunnel_id {
                let _ = entry.remove();
                return true;
            }
        }
        false
    }

    /// Removes the tunnel for `dst_id` regardless of which node it goes through, returning
    /// that tunnel node if there was one.
    pub fn remove_tunnel_for(&mut self, dst_id: &PeerId) -> Option<PeerId> {
        self.tunnels.remove(dst_id)
    }

    /// Returns `true` if `tunnel_id` is the tunnel node for at least one destination.
    pub fn is_tunnel_node(&self, tunnel_id: &PeerId) -> bool {
        self.tunnels.values().any(|id| id == tunnel_id)
    }

    /// Removes every tunnel that goes through `tunnel_id`.
    ///
    /// Returns the destination IDs whose tunnel entries were removed.
    pub fn remove_tunnel(&mut self, tunnel_id: &PeerId) -> Vec<PeerId> {
        let dst_ids = self.tunnels
            .iter()
            .filter(|&(_, id)| id == tunnel_id)
            .map(|(&dst_id, _)| dst_id)
            .collect_vec();
        for dst_id in &dst_ids {
            let _ = self.tunnels.remove(dst_id);
        }
        dst_ids
    }

    /// Returns the tunnel node registered for `dst_id`, if any.
    pub fn tunnel_for(&self, dst_id: &PeerId) -> Option<&PeerId> {
        self.tunnels.get(dst_id)
    }

    /// Number of accepted client pairs this node is tunnelling for.
    pub fn client_count(&self) -> usize {
        self.clients.len()
    }

    /// Number of destinations this node reaches through a tunnel.
    pub fn tunnel_count(&self) -> usize {
        self.tunnels.len()
    }

    /// Discards all pending (considered but not yet accepted) client pairs.
    #[cfg(feature = "use-mock-crust")]
    pub fn clear_new_clients(&mut self) {
        self.new_clients.clear();
    }

    /// Returns the two IDs ordered so that the smaller one comes first; this is the canonical
    /// key under which client pairs are stored.
    fn ordered_pair(id0: PeerId, id1: PeerId) -> (PeerId, PeerId) {
        if id0 < id1 { (id0, id1) } else { (id1, id0) }
    }
}
impl Default for Tunnels {
    /// Creates an empty `Tunnels`: no tunnels and no client pairs. The pending-pair filter is
    /// configured with a 60-second expiry duration.
    fn default() -> Self {
        let pending_expiry = Duration::from_secs(60);
        Tunnels {
            clients: HashSet::new(),
            new_clients: MessageFilter::with_expiry_duration(pending_expiry),
            tunnels: HashMap::new(),
        }
    }
}
#[cfg(all(test, feature = "use-mock-crust"))]
mod tests {
    use super::*;
    use itertools::Itertools;
    use mock_crust::crust::PeerId;

    fn id(i: usize) -> PeerId {
        PeerId(i)
    }

    #[test]
    fn tunnel_nodes_test() {
        let mut tunnels: Tunnels = Default::default();
        assert_eq!(None, tunnels.tunnel_for(&id(0)));
        assert!(tunnels.add(id(0), id(1)));
        // A destination can only have one tunnel at a time.
        assert!(!tunnels.add(id(0), id(2)));
        assert_eq!(Some(&id(1)), tunnels.tunnel_for(&id(0)));
        assert_eq!(None, tunnels.tunnel_for(&id(1)));
        // Removal only succeeds when the given tunnel matches the registered one.
        assert!(!tunnels.remove(id(0), id(2)));
        assert_eq!(Some(&id(1)), tunnels.tunnel_for(&id(0)));
        assert!(tunnels.remove(id(0), id(1)));
        assert_eq!(None, tunnels.tunnel_for(&id(0)));
    }

    #[test]
    fn remove_tunnel_test() {
        let mut tunnels: Tunnels = Default::default();
        assert!(tunnels.add(id(1), id(0)));
        assert!(tunnels.add(id(2), id(0)));
        assert!(tunnels.add(id(3), id(4)));
        assert!(tunnels.is_tunnel_node(&id(0)));
        let removed_peers = tunnels.remove_tunnel(&id(0)).into_iter().sorted();
        assert_eq!(&[id(1), id(2)], &*removed_peers);
        assert!(!tunnels.is_tunnel_node(&id(0)));
        assert_eq!(None, tunnels.tunnel_for(&id(1)));
        assert_eq!(None, tunnels.tunnel_for(&id(2)));
        assert_eq!(Some(&id(4)), tunnels.tunnel_for(&id(3)));
    }

    #[test]
    fn clients_test() {
        let mut tunnels: Tunnels = Default::default();
        assert!(tunnels.add(id(0), id(1)));
        assert!(!tunnels.accept_clients(id(1), id(2)));
        assert!(!tunnels.accept_clients(id(3), id(4)));
        assert_eq!(None, tunnels.consider_clients(id(5), id(0)));
        assert_eq!(Some((id(1), id(2))), tunnels.consider_clients(id(1), id(2)));
        assert_eq!(Some((id(3), id(4))), tunnels.consider_clients(id(4), id(3)));
        assert!(tunnels.accept_clients(id(1), id(2)));
        assert!(tunnels.accept_clients(id(3), id(4)));
        assert!(tunnels.has_clients(id(2), id(1)));
        assert!(tunnels.has_clients(id(3), id(4)));
    }

    #[test]
    fn drop_client_test() {
        let mut tunnels: Tunnels = Default::default();
        assert_eq!(Some((id(1), id(2))), tunnels.consider_clients(id(2), id(1)));
        assert_eq!(Some((id(1), id(3))), tunnels.consider_clients(id(1), id(3)));
        assert!(tunnels.accept_clients(id(1), id(2)));
        assert!(tunnels.accept_clients(id(1), id(3)));
        assert_eq!(2, tunnels.client_count());
        let partners = tunnels.drop_client(&id(1)).into_iter().sorted();
        assert_eq!(&[id(2), id(3)], &*partners);
        assert_eq!(0, tunnels.client_count());
    }

    #[test]
    fn drop_client_pair_test() {
        let mut tunnels: Tunnels = Default::default();
        assert_eq!(Some((id(1), id(2))), tunnels.consider_clients(id(1), id(2)));
        assert!(tunnels.accept_clients(id(1), id(2)));
        // Dropping is order-insensitive and reports whether the pair existed.
        assert!(tunnels.drop_client_pair(id(2), id(1)));
        assert!(!tunnels.has_clients(id(1), id(2)));
        assert!(!tunnels.drop_client_pair(id(1), id(2)));
    }
}