use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use siphon_protocol::{ServerMessage, TunnelType};
use std::sync::Arc;
use tokio::sync::mpsc;
/// Per-tunnel state the router stores for each registered subdomain.
pub struct TunnelHandle {
    /// Sending half of this tunnel's message channel; `Router::get_sender`
    /// hands out clones of it. Presumably the receiver is owned by the task
    /// serving the tunnel — confirm against the caller.
    pub sender: mpsc::Sender<ServerMessage>,
    // Not read anywhere yet (hence the allow); kept as bookkeeping for the
    // tunnel's owner. NOTE(review): confirm intended consumer.
    #[allow(dead_code)]
    /// Identifier of the client that registered this tunnel (presumably).
    pub client_id: String,
    // Not read anywhere yet (hence the allow); kept alongside the handle.
    #[allow(dead_code)]
    /// Kind of tunnel this handle represents.
    pub tunnel_type: TunnelType,
    /// Optional DNS record id associated with this tunnel. NOTE(review):
    /// presumably the provider-side record to delete on teardown — verify.
    pub dns_record_id: Option<String>,
}
/// Concurrent registry mapping subdomains to live tunnels, plus an index of
/// TCP ports bound to those subdomains.
pub struct Router {
    /// Subdomain -> handle of the tunnel registered under it.
    routes: DashMap<String, TunnelHandle>,
    /// TCP port -> subdomain that owns the port; entries are dropped when the
    /// owning subdomain is unregistered.
    tcp_ports: DashMap<u16, String>,
}
impl Router {
pub fn new() -> Arc<Self> {
Arc::new(Self {
routes: DashMap::new(),
tcp_ports: DashMap::new(),
})
}
pub fn register(
&self,
subdomain: String,
handle: TunnelHandle,
tcp_port: Option<u16>,
) -> Result<(), RouterError> {
if self.routes.contains_key(&subdomain) {
return Err(RouterError::SubdomainTaken(subdomain));
}
if let Some(port) = tcp_port {
self.tcp_ports.insert(port, subdomain.clone());
}
self.routes.insert(subdomain, handle);
Ok(())
}
pub fn unregister(&self, subdomain: &str) -> Option<TunnelHandle> {
if let Some((_, handle)) = self.routes.remove(subdomain) {
self.tcp_ports.retain(|_, v| v != subdomain);
Some(handle)
} else {
None
}
}
pub fn get_sender(&self, subdomain: &str) -> Option<mpsc::Sender<ServerMessage>> {
self.routes.get(subdomain).map(|h| h.sender.clone())
}
#[allow(dead_code)]
pub fn get_subdomain_for_port(&self, port: u16) -> Option<String> {
self.tcp_ports.get(&port).map(|s| s.clone())
}
pub fn is_available(&self, subdomain: &str) -> bool {
!self.routes.contains_key(subdomain)
}
#[allow(dead_code)]
pub fn list_subdomains(&self) -> Vec<String> {
self.routes.iter().map(|r| r.key().clone()).collect()
}
}
impl Default for Router {
    /// Builds a router with no registered subdomains and no bound TCP ports.
    fn default() -> Self {
        let routes = DashMap::default();
        let tcp_ports = DashMap::default();
        Self { routes, tcp_ports }
    }
}
/// Errors produced by [`Router`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RouterError {
    /// A tunnel is already registered under the requested subdomain; carries
    /// that subdomain.
    #[error("Subdomain already taken: {0}")]
    SubdomainTaken(String),
}