use std::{
collections::BTreeMap,
net::SocketAddr,
ops::Deref,
sync::{Arc, RwLock},
};
use scion_proto::address::EndhostAddr;
use serde::{Deserialize, Serialize};
use snap_tun::server::SnapTunToken;
pub mod dto;
/// Shared table mapping each [`EndhostAddr`] to the snaptun
/// [`Sender`](snap_tun::server::Sender) currently associated with it.
///
/// Access is synchronized through an inner [`RwLock`]. Cloning is cheap:
/// every clone shares the same underlying map via the inner [`Arc`].
#[derive(Debug)]
pub struct SharedTunnelGatewayState<T: SnapTunToken>(
    Arc<RwLock<BTreeMap<EndhostAddr, Arc<snap_tun::server::Sender<T>>>>>,
);

// `Clone` and `Default` are implemented by hand because `#[derive]` would add
// `T: Clone` / `T: Default` bounds, even though only the `Arc` is cloned or
// defaulted and no value of type `T` is ever touched.
impl<T: SnapTunToken> Clone for SharedTunnelGatewayState<T> {
    /// Returns a handle that shares the same underlying map.
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

impl<T: SnapTunToken> Default for SharedTunnelGatewayState<T> {
    /// Returns a state with no tunnel mappings.
    fn default() -> Self {
        Self(Arc::default())
    }
}
impl<T: SnapTunToken> SharedTunnelGatewayState<T> {
    /// Creates a state that holds no tunnel mappings yet.
    pub fn new() -> Self {
        let empty = BTreeMap::new();
        Self(Arc::new(RwLock::new(empty)))
    }
}
/// Exposes the shared, lock-protected map so that `read()` / `write()` can be
/// called directly on a [`SharedTunnelGatewayState`].
impl<T: SnapTunToken> Deref for SharedTunnelGatewayState<T> {
    type Target = Arc<RwLock<BTreeMap<EndhostAddr, Arc<snap_tun::server::Sender<T>>>>>;

    fn deref(&self) -> &Self::Target {
        let Self(shared) = self;
        shared
    }
}
impl<T: SnapTunToken> SharedTunnelGatewayState<T> {
    /// Maps `addr` to `tunnel`, replacing any existing mapping for `addr`.
    ///
    /// Replacing an existing mapping is logged as a warning that includes the
    /// remote underlay address of the displaced tunnel.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned (another thread panicked while
    /// holding it).
    pub(crate) fn add_tunnel_mapping(
        &self,
        addr: EndhostAddr,
        tunnel: Arc<snap_tun::server::Sender<T>>,
    ) {
        // The write guard is a temporary released at the end of this statement,
        // so the logging below and the drop of any displaced sender happen
        // outside the critical section.
        let displaced = self
            .write()
            .expect("tunnel gateway state lock poisoned")
            .insert(addr, tunnel);
        match displaced {
            Some(existing) => {
                tracing::warn!(%addr, existing_remote=%existing.remote_underlay_address(), "Overwriting existing snaptun connection mapping");
            }
            None => {
                tracing::debug!(%addr, "Adding snaptun connection mapping");
            }
        }
    }

    /// Removes the mapping for `addr`, but only if it still points at
    /// `should_contain` (compared by `Arc` pointer identity).
    ///
    /// If `addr` has since been re-mapped to a different tunnel (e.g. by
    /// [`Self::add_tunnel_mapping`] overwriting it), that mapping is left in
    /// place and a warning is logged.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned (another thread panicked while
    /// holding it).
    pub(crate) fn remove_tunnel_mapping_if_same(
        &self,
        addr: EndhostAddr,
        should_contain: &Arc<snap_tun::server::Sender<T>>,
    ) {
        use std::collections::btree_map::Entry;

        /// What the locked lookup found; acted upon after the lock is released.
        enum Outcome<S> {
            NotMapped,
            Removed(S),
            MappedElsewhere,
        }

        let outcome = {
            let mut tunnels = self
                .write()
                .expect("tunnel gateway state lock poisoned");
            match tunnels.entry(addr) {
                Entry::Vacant(_) => Outcome::NotMapped,
                Entry::Occupied(entry) if Arc::ptr_eq(entry.get(), should_contain) => {
                    Outcome::Removed(entry.remove())
                }
                Entry::Occupied(_) => Outcome::MappedElsewhere,
            }
        };

        match outcome {
            Outcome::NotMapped => {
                tracing::debug!(%addr, "No snaptun connection mapping found to remove");
            }
            Outcome::Removed(removed) => {
                tracing::debug!(%addr, "Removing snaptun connection mapping");
                // Released here, after the write lock has already been dropped.
                drop(removed);
            }
            Outcome::MappedElsewhere => {
                tracing::warn!(%addr, "Not removing snaptun connection mapping, is mapped to a different tunnel");
            }
        }
    }

    /// Returns the tunnel currently mapped to `addr`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned (another thread panicked while
    /// holding it).
    pub(crate) fn get_mapped_tunnel(
        &self,
        addr: EndhostAddr,
    ) -> Option<Arc<snap_tun::server::Sender<T>>> {
        self.read()
            .expect("tunnel gateway state lock poisoned")
            .get(&addr)
            .cloned()
    }
}
/// I/O configuration for the tunnel gateway.
///
/// `listen_addr` is optional; the derived [`Default`] leaves it unset.
#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)]
pub struct TunnelGatewayIoConfig {
    /// Socket address for the gateway to listen on, if configured.
    pub listen_addr: Option<SocketAddr>,
}

impl TunnelGatewayIoConfig {
    /// Builds a configuration with `listen_addr` set.
    pub fn new(listen_addr: SocketAddr) -> Self {
        let listen_addr = Some(listen_addr);
        Self { listen_addr }
    }
}