use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use sha2::{Digest, Sha256};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::{interval, timeout};
use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage, WebSocketStream};
use uuid::Uuid;
use crate::overlay::DynTunnelDnsRegistrar;
use crate::{
ControlMessage, Message, Result, ServiceProtocol, TunnelError, TunnelRegistry,
TunnelServerConfig,
};
/// Shared, thread-safe callback used to vet the token a client presents in
/// its `Auth` message.
///
/// Returning `Ok(())` accepts the token; returning `Err` rejects it, and the
/// error's display text is sent back to the client as the `AuthFail` reason.
pub type TokenValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` if the millisecond count does not fit in a `u64`.
#[inline]
#[allow(clippy::cast_possible_truncation)]
fn current_timestamp_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    let millis: u128 = match since_epoch {
        Ok(elapsed) => elapsed.as_millis(),
        // Clock is set before 1970: fall back to zero rather than failing.
        Err(_) => 0,
    };
    if millis >= u128::from(u64::MAX) {
        u64::MAX
    } else {
        // Guarded by the branch above, so the cast cannot truncate.
        millis as u64
    }
}
/// Server-side handler for tunnel control connections.
///
/// Each accepted TCP stream is upgraded to a WebSocket, authenticated with a
/// token, registered in the tunnel registry, and then served by a message
/// loop until the client disconnects or a heartbeat deadline is missed.
pub struct ControlHandler {
/// Registry that tracks authenticated tunnels and their services.
registry: Arc<TunnelRegistry>,
/// Server settings; the message loop reads `heartbeat_interval` and
/// `heartbeat_timeout` from it.
config: TunnelServerConfig,
/// Callback deciding whether a presented token is acceptable.
token_validator: TokenValidator,
/// Optional overlay DNS registrar. Only used when `local_overlay_ip` is
/// also set; successful service registrations are then published under
/// the name `tun-{service name}`.
dns_registrar: Option<DynTunnelDnsRegistrar>,
/// Optional overlay IPv4 address passed to the DNS registrar together
/// with the service's assigned port.
local_overlay_ip: Option<Ipv4Addr>,
}
impl ControlHandler {
/// Builds a handler backed by `registry`, configured by `config`, and
/// authenticating clients with `token_validator`.
///
/// Overlay DNS registration starts out disabled; enable it with
/// `with_dns_registrar` and `with_local_overlay_ip`.
#[must_use]
pub fn new(
    registry: Arc<TunnelRegistry>,
    config: TunnelServerConfig,
    token_validator: TokenValidator,
) -> Self {
    Self {
        // Optional overlay integration is off until explicitly configured.
        dns_registrar: None,
        local_overlay_ip: None,
        registry,
        config,
        token_validator,
    }
}
#[must_use]
pub fn with_dns_registrar(mut self, registrar: DynTunnelDnsRegistrar) -> Self {
self.dns_registrar = Some(registrar);
self
}
#[must_use]
pub fn with_local_overlay_ip(mut self, ip: Ipv4Addr) -> Self {
self.local_overlay_ip = Some(ip);
self
}
/// Serves one tunnel client's control connection from accept to teardown.
///
/// The steps are:
/// 1. Upgrade `stream` to a WebSocket.
/// 2. Wait (bounded by a fixed timeout) for the first binary frame and
///    require it to decode to `Message::Auth`. Non-binary, non-close frames
///    received while waiting are skipped.
/// 3. Run the token through the configured validator and reject tokens whose
///    hash is already present in the registry.
/// 4. Register the tunnel, reply with `AuthOk`, and run the message loop.
/// 5. Unregister the tunnel once the loop finishes.
///
/// # Errors
///
/// Returns an error if the WebSocket handshake fails, the client does not
/// authenticate in time, the first message is not a valid `Auth`, the token
/// is rejected or already in use, tunnel registration fails, the `AuthOk`
/// reply cannot be sent, or the message loop ends with an error.
pub async fn handle_connection(
    &self,
    stream: TcpStream,
    client_addr: SocketAddr,
) -> Result<()> {
    // Upper bound on how long a new client may take to send its first binary frame.
    const AUTH_TIMEOUT: Duration = Duration::from_secs(10);

    let ws_stream = accept_async(stream)
        .await
        .map_err(TunnelError::connection)?;
    let (mut ws_sink, mut ws_stream) = ws_stream.split();

    let auth_msg = timeout(AUTH_TIMEOUT, async {
        while let Some(msg) = ws_stream.next().await {
            match msg {
                Ok(WsMessage::Binary(data)) => {
                    return Message::decode(&data).map(|(m, _)| m);
                }
                Ok(WsMessage::Close(_)) => {
                    return Err(TunnelError::connection_msg("Client closed connection"));
                }
                // Any other frame type is ignored while waiting for AUTH.
                Ok(_) => {}
                Err(e) => return Err(TunnelError::connection(e)),
            }
        }
        Err(TunnelError::connection_msg("Connection closed before auth"))
    })
    .await
    .map_err(|_| TunnelError::timeout())??;

    let Message::Auth {
        token,
        client_id: _,
    } = auth_msg
    else {
        let fail = Message::AuthFail {
            reason: "Expected AUTH message".to_string(),
        };
        // Best effort: the connection is being rejected either way.
        let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
        return Err(TunnelError::auth("Expected AUTH message"));
    };

    if let Err(e) = (self.token_validator)(&token) {
        let fail = Message::AuthFail {
            reason: e.to_string(),
        };
        let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
        return Err(e);
    }

    // Only the hash of the token is handed to the registry.
    let token_hash = hash_token(&token);
    if self.registry.token_exists(&token_hash) {
        let fail = Message::AuthFail {
            reason: "Token already in use".to_string(),
        };
        let _ = ws_sink.send(WsMessage::Binary(fail.encode().into())).await;
        return Err(TunnelError::auth("Token already in use"));
    }

    let (control_tx, mut control_rx) = mpsc::channel::<ControlMessage>(256);
    let tunnel = self.registry.register_tunnel(
        token_hash,
        None,
        control_tx,
        Some(client_addr),
    )?;
    let tunnel_id = tunnel.id;

    // From here on the tunnel is registered, so every exit path must
    // unregister it; otherwise the token would stay marked as in use.
    let auth_ok = Message::AuthOk { tunnel_id };
    if let Err(e) = ws_sink
        .send(WsMessage::Binary(auth_ok.encode().into()))
        .await
    {
        self.registry.unregister_tunnel(tunnel_id);
        return Err(TunnelError::connection(e));
    }

    tracing::info!(
        tunnel_id = %tunnel_id,
        client_addr = %client_addr,
        "Tunnel authenticated"
    );

    let result = self
        .run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream, &mut control_rx)
        .await;

    self.registry.unregister_tunnel(tunnel_id);
    tracing::info!(tunnel_id = %tunnel_id, "Tunnel disconnected");
    result
}
/// Drives an authenticated tunnel until it ends.
///
/// Three event sources are multiplexed:
/// - a periodic timer that first checks the heartbeat deadline and then
///   sends a `Heartbeat` to the client;
/// - `control_rx`, whose messages are translated into wire messages and
///   forwarded to the client (`Disconnect` ends the loop after a best-effort
///   send);
/// - `ws_stream`, whose binary frames are decoded and dispatched to
///   `handle_client_message`, and whose pings are answered with pongs.
///
/// Returns `Ok(())` when the client sends a close frame, the WebSocket
/// stream ends, or a `Disconnect` control message is processed.
///
/// # Errors
///
/// Returns a timeout error when no `HeartbeatAck` has arrived within the
/// configured heartbeat timeout, a connection error when reading from or
/// writing to the WebSocket fails, and propagates decode or
/// message-handling errors.
async fn run_message_loop(
    &self,
    tunnel_id: Uuid,
    ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
    ws_stream: &mut futures_util::stream::SplitStream<WebSocketStream<TcpStream>>,
    control_rx: &mut mpsc::Receiver<ControlMessage>,
) -> Result<()> {
    let mut heartbeat_interval = interval(self.config.heartbeat_interval);
    let heartbeat_timeout = self.config.heartbeat_timeout;
    let mut last_heartbeat_ack = std::time::Instant::now();

    loop {
        tokio::select! {
            _ = heartbeat_interval.tick() => {
                // Check the deadline before sending the next heartbeat.
                if last_heartbeat_ack.elapsed() > heartbeat_timeout {
                    tracing::warn!(tunnel_id = %tunnel_id, "Heartbeat timeout");
                    return Err(TunnelError::timeout());
                }
                let timestamp = current_timestamp_ms();
                let hb = Message::Heartbeat { timestamp };
                ws_sink
                    .send(WsMessage::Binary(hb.encode().into()))
                    .await
                    .map_err(TunnelError::connection)?;
            }
            Some(ctrl_msg) = control_rx.recv() => {
                let msg = match ctrl_msg {
                    ControlMessage::Connect {
                        service_id,
                        connection_id,
                        client_addr,
                    } => Message::Connect {
                        service_id,
                        connection_id,
                        client_addr: client_addr.to_string(),
                    },
                    ControlMessage::Heartbeat { timestamp } => {
                        Message::Heartbeat { timestamp }
                    }
                    ControlMessage::Disconnect { reason } => {
                        // Best effort: the session ends whether or not the
                        // client receives the notice.
                        let _ = ws_sink
                            .send(WsMessage::Binary(
                                Message::Disconnect { reason }.encode().into(),
                            ))
                            .await;
                        return Ok(());
                    }
                };
                ws_sink
                    .send(WsMessage::Binary(msg.encode().into()))
                    .await
                    .map_err(TunnelError::connection)?;
            }
            incoming = ws_stream.next() => {
                // A refutable `Some(..)` pattern here would merely disable
                // this branch when the stream ends, leaving the session
                // alive until a later heartbeat send fails. Treat the end
                // of the stream as a disconnect right away instead.
                let Some(msg_result) = incoming else {
                    tracing::debug!(tunnel_id = %tunnel_id, "WebSocket stream ended");
                    return Ok(());
                };
                match msg_result {
                    Ok(WsMessage::Binary(data)) => {
                        let (msg, _) = Message::decode(&data)?;
                        if matches!(msg, Message::HeartbeatAck { .. }) {
                            last_heartbeat_ack = std::time::Instant::now();
                        }
                        self.handle_client_message(tunnel_id, msg, ws_sink).await?;
                        self.registry.touch_tunnel(tunnel_id);
                    }
                    Ok(WsMessage::Close(_)) => {
                        return Ok(());
                    }
                    Ok(WsMessage::Ping(data)) => {
                        ws_sink
                            .send(WsMessage::Pong(data))
                            .await
                            .map_err(TunnelError::connection)?;
                    }
                    // Other frame types carry nothing for the control protocol.
                    Ok(_) => {}
                    Err(e) => {
                        return Err(TunnelError::connection(e));
                    }
                }
            }
            else => break,
        }
    }
    Ok(())
}
/// Dispatches one decoded message received from the tunnel client.
///
/// - `Register` is delegated to `handle_register`, which also replies to the
///   client.
/// - `Unregister` removes the service from the registry; a failure is only
///   logged.
/// - `ConnectAck`, `ConnectFail` and `HeartbeatAck` are logged.
/// - Every other message kind is unexpected from a client and is logged as a
///   warning without ending the session.
///
/// # Errors
///
/// Only the `Register` path can fail, with whatever `handle_register`
/// returns.
#[allow(clippy::too_many_lines)]
async fn handle_client_message(
    &self,
    tunnel_id: Uuid,
    msg: Message,
    ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
) -> Result<()> {
    match msg {
        Message::HeartbeatAck { timestamp } => {
            // The ack carries a timestamp; the difference to now is logged as latency.
            let latency_ms = current_timestamp_ms().saturating_sub(timestamp);
            tracing::trace!(%tunnel_id, latency_ms, "Heartbeat ack received");
        }
        Message::ConnectAck { connection_id } => {
            tracing::debug!(%tunnel_id, %connection_id, "Connection acknowledged");
        }
        Message::ConnectFail {
            connection_id,
            reason,
        } => {
            tracing::warn!(%tunnel_id, %connection_id, %reason, "Connection failed");
        }
        Message::Register {
            name,
            protocol,
            local_port,
            remote_port,
        } => {
            return self
                .handle_register(tunnel_id, &name, protocol, local_port, remote_port, ws_sink)
                .await;
        }
        Message::Unregister { service_id } => {
            match self.registry.remove_service(tunnel_id, service_id) {
                Ok(_) => {
                    tracing::info!(%tunnel_id, %service_id, "Service unregistered");
                }
                Err(e) => {
                    tracing::warn!(
                        %tunnel_id,
                        %service_id,
                        error = %e,
                        "Service unregistration failed"
                    );
                }
            }
        }
        // Message kinds a client is not expected to send.
        Message::Auth { .. }
        | Message::AuthOk { .. }
        | Message::AuthFail { .. }
        | Message::RegisterOk { .. }
        | Message::RegisterFail { .. }
        | Message::Connect { .. }
        | Message::Heartbeat { .. }
        | Message::Disconnect { .. } => {
            tracing::warn!(
                %tunnel_id,
                msg_type = ?msg.message_type(),
                "Unexpected message from client"
            );
        }
    }
    Ok(())
}
/// Registers a service for `tunnel_id` and reports the outcome to the client.
///
/// On success the service's assigned port (or `remote_port` when none was
/// assigned) is logged. If both a DNS registrar and a local overlay IP are
/// configured, the service is also published as `tun-{name}`; a failure
/// there is logged and does not affect the reply. The client receives
/// `RegisterOk` with the service id, or `RegisterFail` with the registry
/// error's text.
///
/// # Errors
///
/// Returns a connection error if the reply cannot be written to `ws_sink`.
/// A failed registration is not an error of this function.
async fn handle_register(
    &self,
    tunnel_id: Uuid,
    name: &str,
    protocol: ServiceProtocol,
    local_port: u16,
    remote_port: u16,
    ws_sink: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, WsMessage>,
) -> Result<()> {
    let outcome = self
        .registry
        .add_service(tunnel_id, name, protocol, local_port, remote_port);

    let reply = match outcome {
        Err(e) => {
            tracing::warn!(
                %tunnel_id,
                service_name = %name,
                error = %e,
                "Service registration failed"
            );
            Message::RegisterFail {
                reason: e.to_string(),
            }
        }
        Ok(service) => {
            let assigned_port = service.assigned_port.unwrap_or(remote_port);
            tracing::info!(
                %tunnel_id,
                service_name = %name,
                local_port,
                remote_port = assigned_port,
                "Service registered"
            );

            // Overlay DNS publication needs both the registrar and an address.
            if let (Some(registrar), Some(overlay_ip)) =
                (&self.dns_registrar, self.local_overlay_ip)
            {
                let dns_name = format!("tun-{name}");
                match registrar
                    .register_service(&dns_name, overlay_ip, assigned_port)
                    .await
                {
                    Ok(_) => {
                        tracing::debug!(
                            %dns_name,
                            %overlay_ip,
                            port = assigned_port,
                            "Registered service in overlay DNS"
                        );
                    }
                    Err(e) => {
                        tracing::warn!(
                            service_name = %name,
                            %dns_name,
                            error = %e,
                            "Failed to register service in overlay DNS"
                        );
                    }
                }
            }

            Message::RegisterOk {
                service_id: service.id,
            }
        }
    };

    ws_sink
        .send(WsMessage::Binary(reply.encode().into()))
        .await
        .map_err(TunnelError::connection)
}
}
/// Returns the SHA-256 digest of `token` as a hex string.
///
/// The output is always 64 characters long, including for an empty token.
#[must_use]
pub fn hash_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    hex::encode(digest)
}
/// Token validator that accepts every non-empty token.
///
/// It performs no real credential check, only the emptiness test.
///
/// # Errors
///
/// Returns an auth error with the message `Token cannot be empty` when
/// `token` is the empty string.
pub fn accept_all_tokens(token: &str) -> Result<()> {
    match token {
        "" => Err(TunnelError::auth("Token cannot be empty")),
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of hex characters in an encoded SHA-256 digest.
    const SHA256_HEX_LEN: usize = 64;

    /// Hashing the same token twice yields the same 64-character string.
    #[test]
    fn test_hash_token_consistent() {
        let first = hash_token("my-secret-token");
        let second = hash_token("my-secret-token");
        assert_eq!(first, second);
        assert_eq!(first.len(), SHA256_HEX_LEN);
    }

    /// Distinct tokens produce distinct hashes.
    #[test]
    fn test_hash_token_different_tokens() {
        assert_ne!(hash_token("token1"), hash_token("token2"));
    }

    /// The empty token still hashes to a full-length digest.
    #[test]
    fn test_hash_token_empty() {
        assert_eq!(hash_token("").len(), SHA256_HEX_LEN);
    }

    /// Pins the output against the known SHA-256 digest of "test".
    #[test]
    fn test_hash_token_known_value() {
        let expected = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08";
        assert_eq!(hash_token("test"), expected);
    }

    /// Any non-empty token is accepted.
    #[test]
    fn test_accept_all_tokens_valid() {
        for token in ["valid-token", "a", "very-long-token-with-many-characters"] {
            assert!(accept_all_tokens(token).is_ok());
        }
    }

    /// The empty token is rejected with a descriptive message.
    #[test]
    fn test_accept_all_tokens_empty() {
        match accept_all_tokens("") {
            Ok(()) => panic!("empty token must be rejected"),
            Err(e) => assert!(e.to_string().contains("cannot be empty")),
        }
    }

    /// A handler can be constructed and keeps the registry it was given.
    #[test]
    fn test_control_handler_creation() {
        let registry = Arc::new(TunnelRegistry::default());
        let validator: TokenValidator = Arc::new(accept_all_tokens);
        let handler = ControlHandler::new(
            Arc::clone(&registry),
            TunnelServerConfig::default(),
            validator,
        );
        assert!(Arc::strong_count(&handler.registry) >= 1);
    }

    /// Non-ASCII input is hashed like any other byte sequence.
    #[test]
    fn test_hash_token_unicode() {
        assert_eq!(
            hash_token("token-with-unicode-\u{1F600}").len(),
            SHA256_HEX_LEN
        );
    }

    /// Punctuation in the token does not affect the digest length.
    #[test]
    fn test_hash_token_special_chars() {
        assert_eq!(hash_token("token!@#$%^&*()").len(), SHA256_HEX_LEN);
    }
}