use id::PublicId;
use itertools::Itertools;
use message_filter::MessageFilter;
use std::collections::{BTreeSet, HashMap};
use std::collections::hash_map::Entry;
use std::time::Duration;
/// Upper bound on the number of accepted client pairs. Once `Tunnels::clients` holds this many
/// entries, `Tunnels::consider_clients` returns `None` for any further pair.
const MAX_TUNNEL_CLIENT_PAIRS: usize = 40;
/// Bookkeeping for tunnelled connections.
///
/// Tracks two independent things: a map from a destination peer to the tunnel node registered
/// for it, and the client pairs (candidate and accepted) handled by `consider_clients` and
/// `accept_clients`.
// NOTE(review): presumably the client pairs are peers for which *this* node acts as the tunnel;
// that role is implied by the naming only - confirm against the callers.
pub struct Tunnels {
/// Maps a destination ID to the tunnel ID that was registered for it with `add`. At most one
/// tunnel is stored per destination.
tunnels: HashMap<PublicId, PublicId>,
/// Candidate pairs recorded by `consider_clients`, always stored with the smaller ID first.
/// `Default` builds this filter with an expiry duration of 60 seconds.
new_clients: MessageFilter<(PublicId, PublicId)>,
/// Accepted pairs, moved here from `new_clients` by `accept_clients`. The lookups in
/// `has_clients` and `drop_client_pair` use the smaller ID first.
clients: BTreeSet<(PublicId, PublicId)>,
}
impl Tunnels {
/// Returns `true` if the two IDs form an accepted client pair.
///
/// The argument order is irrelevant: the pair is normalised so that the smaller ID comes first
/// before it is looked up in `clients`.
pub fn has_clients(&self, src_id: PublicId, dst_id: PublicId) -> bool {
    let key = if src_id < dst_id {
        (src_id, dst_id)
    } else {
        (dst_id, src_id)
    };
    self.clients.contains(&key)
}
/// Records the two IDs as a candidate client pair and returns the pair with the smaller ID
/// first.
///
/// Returns `None`, without recording anything, if `clients` already holds
/// `MAX_TUNNEL_CLIENT_PAIRS` accepted pairs, or if either ID is a key of the `tunnels` map
/// (i.e. a tunnel has been registered for that ID with `add`).
///
/// A returned pair still has to be confirmed with `accept_clients`, using the IDs in the
/// returned order, before `has_clients` reports it.
pub fn consider_clients(
    &mut self,
    src_id: PublicId,
    dst_id: PublicId,
) -> Option<(PublicId, PublicId)> {
    // No capacity left for further accepted pairs.
    if self.clients.len() >= MAX_TUNNEL_CLIENT_PAIRS {
        return None;
    }
    // One of the two already has a tunnel registered for it.
    if self.tunnels.contains_key(&src_id) || self.tunnels.contains_key(&dst_id) {
        return None;
    }
    let sorted_pair = if src_id < dst_id {
        (src_id, dst_id)
    } else {
        (dst_id, src_id)
    };
    let _ = self.new_clients.insert(&sorted_pair);
    Some(sorted_pair)
}
/// Promotes a candidate pair to an accepted client pair.
///
/// Returns `true` if `(src_id, dst_id)` was found among the candidates, in which case it is
/// removed from `new_clients` and inserted into `clients`. Returns `false` and changes nothing
/// otherwise.
///
/// Unlike `has_clients` and `drop_client_pair`, the IDs are looked up exactly in the order
/// given. `consider_clients` records candidates with the smaller ID first, so the pair must be
/// passed in the order `consider_clients` returned it.
pub fn accept_clients(&mut self, src_id: PublicId, dst_id: PublicId) -> bool {
    let candidate = (src_id, dst_id);
    if !self.new_clients.contains(&candidate) {
        return false;
    }
    self.new_clients.remove(&candidate);
    let _ = self.clients.insert(candidate);
    true
}
/// Removes every accepted pair that contains `pub_id`.
///
/// Returns, for each removed pair, the other member of that pair, in the iteration order of the
/// `clients` set.
pub fn drop_client(&mut self, pub_id: &PublicId) -> Vec<PublicId> {
    // Snapshot the affected pairs first so that the set is not mutated while it is iterated.
    let affected = self.clients
        .iter()
        .cloned()
        .filter(|&(id0, id1)| id0 == *pub_id || id1 == *pub_id)
        .collect_vec();
    let mut partners = Vec::with_capacity(affected.len());
    for (id0, id1) in affected {
        let _ = self.clients.remove(&(id0, id1));
        partners.push(if id0 == *pub_id { id1 } else { id0 });
    }
    partners
}
/// Removes the accepted pair formed by the two IDs, regardless of the order they are given in.
///
/// Returns `true` if the pair was present in `clients` and has been removed.
pub fn drop_client_pair(&mut self, src_id: PublicId, dst_id: PublicId) -> bool {
    // Pairs are keyed with the smaller ID first.
    if src_id < dst_id {
        self.clients.remove(&(src_id, dst_id))
    } else {
        self.clients.remove(&(dst_id, src_id))
    }
}
/// Registers `tunnel_id` as the tunnel for `dst_id`.
///
/// Returns `true` if the entry was added. If a tunnel is already registered for `dst_id`, the
/// existing entry is left untouched and `false` is returned.
pub fn add(&mut self, dst_id: PublicId, tunnel_id: PublicId) -> bool {
    if let Entry::Vacant(slot) = self.tunnels.entry(dst_id) {
        let _ = slot.insert(tunnel_id);
        true
    } else {
        false
    }
}
/// Removes the tunnel registered for `dst_id`, but only if that tunnel is `tunnel_id`.
///
/// Returns `true` if the entry was removed; `false` if there is no entry for `dst_id` or it
/// refers to a different tunnel.
pub fn remove(&mut self, dst_id: PublicId, tunnel_id: PublicId) -> bool {
    match self.tunnels.entry(dst_id) {
        Entry::Occupied(slot) => {
            if *slot.get() == tunnel_id {
                let _ = slot.remove();
                true
            } else {
                false
            }
        }
        Entry::Vacant(_) => false,
    }
}
/// Removes the tunnel entry for `dst_id`, whichever tunnel it refers to.
///
/// Returns the tunnel ID that was registered, or `None` if there was no entry for `dst_id`.
pub fn remove_tunnel_for(&mut self, dst_id: &PublicId) -> Option<PublicId> {
self.tunnels.remove(dst_id)
}
/// Returns `true` if `tunnel_id` is registered as the tunnel for at least one destination.
///
/// This scans the entries of the `tunnels` map, since tunnel IDs are stored as values.
pub fn is_tunnel_node(&self, tunnel_id: &PublicId) -> bool {
    self.tunnels.iter().any(|(_, via)| via == tunnel_id)
}
/// Removes every entry whose tunnel is `tunnel_id`.
///
/// Returns the destination IDs that were registered for that tunnel, in the iteration order of
/// the underlying `HashMap` (callers that need a stable order must sort the result).
pub fn remove_tunnel(&mut self, tunnel_id: &PublicId) -> Vec<PublicId> {
    // Collect the affected destinations first; the map cannot be mutated while it is iterated.
    let orphaned = self.tunnels
        .iter()
        .filter_map(|(dst_id, via)| if via == tunnel_id {
            Some(*dst_id)
        } else {
            None
        })
        .collect_vec();
    for dst_id in &orphaned {
        let _ = self.tunnels.remove(dst_id);
    }
    orphaned
}
/// Returns the tunnel ID registered for `dst_id`, or `None` if there is no entry for it.
pub fn tunnel_for(&self, dst_id: &PublicId) -> Option<&PublicId> {
self.tunnels.get(dst_id)
}
/// Returns the number of accepted client pairs. Candidate pairs that have only been passed to
/// `consider_clients` are not counted.
pub fn client_count(&self) -> usize {
self.clients.len()
}
/// Returns the number of entries in the `tunnels` map, i.e. the number of destinations that
/// have a tunnel registered (not the number of distinct tunnel IDs).
pub fn tunnel_count(&self) -> usize {
self.tunnels.len()
}
}
impl Default for Tunnels {
    /// Creates an empty `Tunnels`: no registered tunnels, no accepted client pairs, and a
    /// candidate filter constructed with an expiry duration of 60 seconds.
    fn default() -> Tunnels {
        let candidate_expiry = Duration::from_secs(60);
        Tunnels {
            tunnels: HashMap::new(),
            new_clients: MessageFilter::with_expiry_duration(candidate_expiry),
            clients: BTreeSet::new(),
        }
    }
}
#[cfg(all(test, feature = "use-mock-crust"))]
mod tests {
    use super::*;
    use id::FullId;
    use itertools::Itertools;

    // A tunnel registered for a peer is reported for that peer only, and is gone after `remove`.
    #[test]
    fn tunnel_nodes_test() {
        let our_id = *FullId::new().public_id();
        let their_id = *FullId::new().public_id();
        let mut tunnels = Tunnels::default();
        assert_eq!(None, tunnels.tunnel_for(&our_id));
        let _ = tunnels.add(our_id, their_id);
        assert_eq!(Some(&their_id), tunnels.tunnel_for(&our_id));
        // The mapping is directional: the tunnel node itself has no tunnel registered.
        assert_eq!(None, tunnels.tunnel_for(&their_id));
        let _ = tunnels.remove(our_id, their_id);
        assert_eq!(None, tunnels.tunnel_for(&our_id));
    }

    // `remove_tunnel` drops exactly the destinations that use the given tunnel node.
    #[test]
    fn remove_tunnel_test() {
        let mut sorted_ids = (0..5)
            .map(|_| *FullId::new().public_id())
            .collect::<Vec<_>>();
        sorted_ids.sort();
        let mut tunnels = Tunnels::default();
        let _ = tunnels.add(sorted_ids[1], sorted_ids[0]);
        let _ = tunnels.add(sorted_ids[2], sorted_ids[0]);
        let _ = tunnels.add(sorted_ids[3], sorted_ids[4]);
        // The result comes back in hash-map order, so sort it before comparing.
        let removed_peers = tunnels.remove_tunnel(&sorted_ids[0]).into_iter().sorted();
        assert_eq!(&[sorted_ids[1], sorted_ids[2]], &*removed_peers);
        assert_eq!(None, tunnels.tunnel_for(&sorted_ids[1]));
        assert_eq!(None, tunnels.tunnel_for(&sorted_ids[2]));
        // The entry that uses a different tunnel node is untouched.
        assert_eq!(Some(&sorted_ids[4]), tunnels.tunnel_for(&sorted_ids[3]));
    }

    // Client pairs must be considered before they can be accepted, and peers that already have a
    // tunnel registered are never considered.
    #[test]
    fn clients_test() {
        let mut sorted_ids = (0..6)
            .map(|_| *FullId::new().public_id())
            .collect::<Vec<_>>();
        sorted_ids.sort();
        let mut tunnels = Tunnels::default();
        let _ = tunnels.add(sorted_ids[0], sorted_ids[1]);
        // Nothing has been considered yet, so nothing can be accepted.
        assert!(!tunnels.accept_clients(sorted_ids[1], sorted_ids[2]));
        assert!(!tunnels.accept_clients(sorted_ids[3], sorted_ids[4]));
        // `sorted_ids[0]` has a tunnel registered, so a pair containing it is rejected.
        assert_eq!(None, tunnels.consider_clients(sorted_ids[5], sorted_ids[0]));
        assert_eq!(
            Some((sorted_ids[1], sorted_ids[2])),
            tunnels.consider_clients(sorted_ids[1], sorted_ids[2])
        );
        // The returned pair is sorted even when the arguments are not.
        assert_eq!(
            Some((sorted_ids[3], sorted_ids[4])),
            tunnels.consider_clients(sorted_ids[4], sorted_ids[3])
        );
        assert!(tunnels.accept_clients(sorted_ids[1], sorted_ids[2]));
        assert!(tunnels.accept_clients(sorted_ids[3], sorted_ids[4]));
        // `has_clients` accepts the IDs in either order.
        assert!(tunnels.has_clients(sorted_ids[2], sorted_ids[1]));
        assert!(tunnels.has_clients(sorted_ids[3], sorted_ids[4]));
    }
}