use std::{
collections::BTreeMap,
net::{IpAddr, SocketAddr},
sync::Arc,
};
use tiny_keccak::{Hasher, Sha3};
use tokio::sync::RwLock;
use xor_name::XorName;
#[derive(Clone)]
pub(crate) struct ConnectionPool<I: ConnId> {
store: Arc<RwLock<Store<I>>>,
}
impl<I: ConnId> ConnectionPool<I> {
pub(crate) fn new() -> Self {
Self {
store: Arc::new(RwLock::new(Store::default())),
}
}
pub(crate) async fn insert(
&self,
id: I,
addr: SocketAddr,
conn: quinn::Connection,
) -> ConnectionRemover<I> {
let mut store = self.store.write().await;
let key = Key {
addr,
id: store.id_gen.next(),
};
let _ = store.id_map.insert(id, (conn.clone(), key));
let _ = store.key_map.insert(key, (conn, id));
ConnectionRemover {
store: self.store.clone(),
id,
key,
}
}
#[allow(unused)]
pub(crate) async fn has_id(&self, id: &I) -> bool {
let store = self.store.read().await;
store.id_map.contains_key(id)
}
pub(crate) async fn remove(&self, addr: &SocketAddr) -> Vec<quinn::Connection> {
let mut store = self.store.write().await;
let keys_to_remove = store
.key_map
.range_mut(Key::min(*addr)..=Key::max(*addr))
.into_iter()
.map(|(key, _)| *key)
.collect::<Vec<_>>();
keys_to_remove
.iter()
.filter_map(|key| store.key_map.remove(key).map(|entry| entry.0))
.collect::<Vec<_>>()
}
pub(crate) async fn get_by_id(
&self,
addr: &I,
) -> Option<(quinn::Connection, ConnectionRemover<I>)> {
let store = self.store.read().await;
let (conn, key) = store.id_map.get(addr)?;
let remover = ConnectionRemover {
store: self.store.clone(),
key: *key,
id: *addr,
};
Some((conn.clone(), remover))
}
pub(crate) async fn get_by_addr(
&self,
addr: &SocketAddr,
) -> Option<(quinn::Connection, ConnectionRemover<I>)> {
let store = self.store.read().await;
let (key, entry) = store
.key_map
.range(Key::min(*addr)..=Key::max(*addr))
.next()?;
let conn = entry.clone().0;
let remover = ConnectionRemover {
store: self.store.clone(),
key: *key,
id: entry.1,
};
Some((conn, remover))
}
}
/// Handle that removes one specific connection entry from a [`ConnectionPool`].
///
/// It carries its own clone of the pool's shared store, together with the exact
/// address-index key and id the entry was registered under.
#[derive(Clone)]
pub(crate) struct ConnectionRemover<I: ConnId> {
    store: Arc<RwLock<Store<I>>>,
    key: Key,
    id: I,
}

impl<I: ConnId> ConnectionRemover<I> {
    /// Removes this handle's connection from the pool.
    ///
    /// The address-index entry for this handle's key is always dropped. The id-index entry
    /// is dropped only if it still refers to this handle's key. If a newer connection has
    /// since been inserted under the same id, that entry is left untouched.
    pub(crate) async fn remove(&self) {
        let mut store = self.store.write().await;
        let _ = store.key_map.remove(&self.key);
        // `insert` overwrites `id_map` for a reused id, so only evict the entry we own.
        if matches!(store.id_map.get(&self.id), Some((_, key)) if *key == self.key) {
            let _ = store.id_map.remove(&self.id);
        }
    }

    /// Remote address of the connection this handle refers to.
    pub(crate) fn remote_addr(&self) -> &SocketAddr {
        &self.key.addr
    }

    /// Id the connection was registered under.
    pub(crate) fn id(&self) -> I {
        self.id
    }
}
/// State shared (behind an `Arc<RwLock<_>>`) by a `ConnectionPool` and its
/// `ConnectionRemover`s.
///
/// Every inserted connection is indexed twice: by id in `id_map`, and by
/// `(address, per-insertion id)` in `key_map`.
#[derive(Default)]
struct Store<I: ConnId> {
    /// Hands out a fresh `Key::id` for every insertion.
    id_gen: IdGen,
    /// Most recently inserted connection for each id, with its address-index key.
    id_map: BTreeMap<I, (quinn::Connection, Key)>,
    /// Connections indexed by address; each value also records the id it was inserted with.
    key_map: BTreeMap<Key, (quinn::Connection, I)>,
}
/// Identifier type used to index connections in a `ConnectionPool`.
///
/// Implementors must be cheap to copy, totally ordered (ids are `BTreeMap` keys),
/// default-constructible, and shareable across threads.
pub trait ConnId:
    Copy + Clone + Ord + PartialOrd + Eq + PartialEq + Default + Send + Sync + 'static
{
    /// Derives an id for the peer at `socket_addr`.
    fn generate(socket_addr: &SocketAddr) -> Self;
}
impl ConnId for XorName {
    /// Derives a deterministic `XorName` from `addr` with SHA3-256.
    ///
    /// The hash input is the raw IP octets (4 bytes for IPv4, 16 for IPv6) followed by
    /// the port as big-endian bytes.
    fn generate(addr: &SocketAddr) -> Self {
        let mut sha3 = Sha3::v256();
        match addr.ip() {
            IpAddr::V4(v4) => sha3.update(&v4.octets()),
            IpAddr::V6(v6) => sha3.update(&v6.octets()),
        }
        let port_bytes = addr.port().to_be_bytes();
        sha3.update(&port_bytes);

        let mut digest = [0u8; 32];
        sha3.finalize(&mut digest);
        XorName(digest)
    }
}
/// Composite key for the address index.
///
/// The derived `Ord` compares fields in declaration order: `addr` first, then `id`. As a
/// result, all entries for one address are contiguous in a `BTreeMap<Key, _>` and can be
/// selected with the `Key::min(addr)..=Key::max(addr)` range. Do not reorder the fields.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
struct Key {
    /// Remote socket address of the connection.
    addr: SocketAddr,
    /// Per-insertion id distinguishing connections to the same address.
    id: u64,
}

impl Key {
    /// Smallest possible key for `addr` (inclusive lower range bound).
    fn min(addr: SocketAddr) -> Self {
        Self::with_id(addr, u64::MIN)
    }

    /// Largest possible key for `addr` (inclusive upper range bound).
    fn max(addr: SocketAddr) -> Self {
        Self::with_id(addr, u64::MAX)
    }

    /// Builds a key for `addr` carrying the given `id`.
    fn with_id(addr: SocketAddr, id: u64) -> Self {
        Key { id, addr }
    }
}
/// Counter for per-insertion key ids.
///
/// Yields 0, 1, 2, … and wraps back to 0 after `u64::MAX`.
#[derive(Default)]
struct IdGen(u64);

impl IdGen {
    /// Returns the current counter value and advances the counter by one (wrapping).
    fn next(&mut self) -> u64 {
        let following = self.0.wrapping_add(1);
        std::mem::replace(&mut self.0, following)
    }
}