use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::time::{interval, timeout};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use uuid::Uuid;
use crate::client::connector::OverlayAwareConnector;
use crate::overlay::DynOverlayResolver;
use crate::{Message, Result, ServiceConfig, ServiceProtocol, TunnelClientConfig, TunnelError};
/// Connection lifecycle state of a [`TunnelAgent`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AgentState {
/// Not connected. This is the initial state and the state set by
/// [`TunnelAgent::disconnect`].
#[default]
Disconnected,
/// A connection attempt is in progress; set at the start of
/// [`TunnelAgent::run_once`].
Connecting,
/// The server accepted authentication and assigned `tunnel_id`.
Connected {
tunnel_id: Uuid,
},
/// [`TunnelAgent::run`] is about to start connection attempt number
/// `attempt` (counting from 1).
Reconnecting {
attempt: u32,
},
}
/// Registration status of a single service on the current connection.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ServiceStatus {
/// No registration reply has been recorded yet. Also the status every
/// service is reset to by [`TunnelAgent::run`] after a failed session.
#[default]
Pending,
/// The server replied with `RegisterOk` for this service.
Registered,
/// The server replied with `RegisterFail`; the payload is the server's reason.
Failed(String),
}
/// A configured service together with its registration outcome.
#[derive(Debug, Clone)]
pub struct RegisteredService {
/// The service definition taken from the client configuration.
pub config: ServiceConfig,
/// Identifier carried by the server's `RegisterOk` reply; `None` until then.
pub service_id: Option<Uuid>,
/// Current registration status.
pub status: ServiceStatus,
}
impl RegisteredService {
    /// Wraps `config` as a service that has not been registered yet.
    ///
    /// The new entry has no service id and a [`ServiceStatus::Pending`] status.
    #[must_use]
    pub fn new(config: ServiceConfig) -> Self {
        let service_id = None;
        let status = ServiceStatus::Pending;
        Self {
            config,
            service_id,
            status,
        }
    }

    /// Returns `true` only while the status is [`ServiceStatus::Registered`].
    #[must_use]
    pub fn is_registered(&self) -> bool {
        self.status == ServiceStatus::Registered
    }
}
/// Notifications the agent pushes to the optional event channel installed
/// with [`TunnelAgent::with_event_channel`].
#[derive(Debug, Clone)]
pub enum ControlEvent {
/// Authentication succeeded and the server assigned `tunnel_id`.
Authenticated {
tunnel_id: Uuid,
},
/// The server acknowledged registration of the service `name`.
ServiceRegistered {
name: String,
service_id: Uuid,
},
/// The server rejected registration of the service `name`.
ServiceFailed {
name: String,
reason: String,
},
/// The server announced a new inbound connection for a service.
IncomingConnection {
service_id: Uuid,
connection_id: Uuid,
client_addr: String,
},
/// A heartbeat carrying `timestamp` arrived and was acknowledged.
Heartbeat {
timestamp: u64,
},
/// The session ended; `reason` is the error text or the server's reason.
Disconnected {
reason: String,
},
/// Generic error notification.
// NOTE(review): nothing in this module constructs this variant — confirm
// whether it is emitted elsewhere.
Error {
message: String,
},
}
/// Commands consumed by the agent's message loop and translated into
/// protocol messages for the server.
#[derive(Debug, Clone)]
pub enum ControlCommand {
/// Send a `Register` message and queue `name` as awaiting a reply.
Register {
name: String,
protocol: ServiceProtocol,
local_port: u16,
remote_port: u16,
},
/// Send an `Unregister` message for `service_id`.
Unregister {
service_id: Uuid,
},
/// Send a `ConnectAck` message for `connection_id`.
ConnectAck {
connection_id: Uuid,
},
/// Send a `ConnectFail` message for `connection_id` with `reason`.
ConnectFail {
connection_id: Uuid,
reason: String,
},
/// Stop the message loop; the loop returns `Ok(())`.
Disconnect,
}
/// Accept/reject hook for incoming connections.
///
/// Invoked as `(service_id, connection_id, client_addr)`. Returning `true`
/// makes the agent answer with `ConnectAck`; returning `false` makes it
/// answer with `ConnectFail`.
pub type ConnectionCallback = Arc<dyn Fn(Uuid, Uuid, String) -> bool + Send + Sync>;
/// Client-side tunnel agent: connects to the server, authenticates,
/// registers the configured services and answers server messages.
pub struct TunnelAgent {
/// Client configuration (server URL, token, services, reconnect timing).
config: TunnelClientConfig,
/// Current lifecycle state; the `Arc` is shared with clones of the agent.
state: Arc<RwLock<AgentState>>,
/// Services keyed by name; the `Arc` is shared with clones of the agent.
services: Arc<RwLock<HashMap<String, RegisteredService>>>,
/// Optional hook deciding whether an incoming connection is accepted.
connection_callback: Option<ConnectionCallback>,
/// Sender used by `send_command` and `disconnect`.
// NOTE(review): initialised to `None` in `new`, and no code in this module
// assigns `Some` — confirm where (or whether) it gets wired up.
command_tx: Option<mpsc::Sender<ControlCommand>>,
/// Optional sink for [`ControlEvent`] notifications.
event_tx: Option<mpsc::Sender<ControlEvent>>,
/// Optional resolver handed to the `OverlayAwareConnector` on each connect.
overlay_resolver: Option<DynOverlayResolver>,
}
impl TunnelAgent {
/// Builds a disconnected agent from `config`.
///
/// Every service listed in the configuration is entered into the service
/// table under its name with a pending status. No callback, event channel,
/// command channel or overlay resolver is installed.
#[must_use]
pub fn new(config: TunnelClientConfig) -> Self {
    let mut table: HashMap<String, RegisteredService> = HashMap::new();
    for service in config.services.iter() {
        // A later entry with the same name replaces an earlier one.
        table.insert(service.name.clone(), RegisteredService::new(service.clone()));
    }
    let state = Arc::new(RwLock::new(AgentState::Disconnected));
    let services = Arc::new(RwLock::new(table));
    Self {
        config,
        state,
        services,
        connection_callback: None,
        command_tx: None,
        event_tx: None,
        overlay_resolver: None,
    }
}
/// Installs `callback` as the accept/reject hook for incoming connections
/// and returns the updated agent.
#[must_use]
pub fn on_connection(self, callback: ConnectionCallback) -> Self {
    let mut agent = self;
    agent.connection_callback = Some(callback);
    agent
}
/// Routes [`ControlEvent`] notifications to `tx` and returns the updated agent.
#[must_use]
pub fn with_event_channel(self, tx: mpsc::Sender<ControlEvent>) -> Self {
    let mut agent = self;
    agent.event_tx = Some(tx);
    agent
}
/// Stores `resolver` for use by the connector on every connection attempt
/// and returns the updated agent.
#[must_use]
pub fn with_overlay_resolver(self, resolver: DynOverlayResolver) -> Self {
    let mut agent = self;
    agent.overlay_resolver = Some(resolver);
    agent
}
#[must_use]
pub fn state(&self) -> AgentState {
self.state.read().clone()
}
/// Returns a copy of the service stored under `name`, or `None` when no
/// service with that name exists.
#[must_use]
pub fn get_service(&self, name: &str) -> Option<RegisteredService> {
    let table = self.services.read();
    let entry = table.get(name)?;
    Some(entry.clone())
}
/// Returns copies of all services in the table's iteration order.
#[must_use]
pub fn services(&self) -> Vec<RegisteredService> {
    let table = self.services.read();
    let mut snapshot = Vec::with_capacity(table.len());
    snapshot.extend(table.values().cloned());
    snapshot
}
#[must_use]
pub fn is_connected(&self) -> bool {
matches!(*self.state.read(), AgentState::Connected { .. })
}
/// Returns the tunnel id of the current session, or `None` when the state
/// is anything other than [`AgentState::Connected`].
#[must_use]
pub fn tunnel_id(&self) -> Option<Uuid> {
    if let AgentState::Connected { tunnel_id } = *self.state.read() {
        Some(tunnel_id)
    } else {
        None
    }
}
/// Queues `command` on the agent's command channel.
///
/// # Errors
///
/// Returns a connection error with the text `"agent not running"` when no
/// command sender is installed, or `"command channel closed"` when the
/// receiving half has been dropped.
pub async fn send_command(&self, command: ControlCommand) -> Result<()> {
    let Some(sender) = self.command_tx.as_ref() else {
        return Err(TunnelError::connection_msg("agent not running"));
    };
    match sender.send(command).await {
        Ok(()) => Ok(()),
        Err(_) => Err(TunnelError::connection_msg("command channel closed")),
    }
}
pub async fn run(&self) -> Result<()> {
self.config.validate().map_err(TunnelError::config)?;
let mut current_interval = self.config.reconnect_interval;
let mut attempt = 0u32;
loop {
attempt += 1;
*self.state.write() = AgentState::Reconnecting { attempt };
tracing::info!(
attempt = attempt,
interval_ms = current_interval.as_millis(),
"attempting to connect"
);
match self.run_once().await {
Ok(()) => {
tracing::info!("agent shutting down");
return Ok(());
}
Err(TunnelError::Shutdown) => {
tracing::info!("agent received shutdown signal");
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "connection failed, will retry");
if let Some(ref tx) = self.event_tx {
let _ = tx
.send(ControlEvent::Disconnected {
reason: e.to_string(),
})
.await;
}
}
}
{
let mut services = self.services.write();
for service in services.values_mut() {
service.service_id = None;
service.status = ServiceStatus::Pending;
}
}
tokio::time::sleep(current_interval).await;
current_interval = std::cmp::min(
current_interval.saturating_mul(2),
self.config.max_reconnect_interval,
);
}
}
/// Performs a single connect → authenticate → register → serve cycle.
///
/// The state is set to `Connecting` on entry and to `Connected` once the
/// server answers the auth message with `AuthOk`. After that the configured
/// services are registered and the message loop runs until the session ends.
///
/// # Errors
///
/// Fails when the connection cannot be established, when the auth message
/// cannot be sent, when no auth reply arrives within 10 seconds, when the
/// server answers with `AuthFail` or an unexpected message, or when
/// registration or the message loop fails.
pub async fn run_once(&self) -> Result<()> {
    *self.state.write() = AgentState::Connecting;
    tracing::debug!(url = %self.config.server_url, "connecting to server");

    let connector = OverlayAwareConnector::new(
        &self.config.server_url,
        self.config.overlay_server_url.as_deref(),
        self.config.routing_mode,
        self.overlay_resolver.clone(),
    );
    let (socket, _response) = connector.connect().await?;
    let (mut ws_sink, mut ws_stream) = socket.split();

    // Authenticate with a fresh client id for this connection.
    let client_id = Uuid::new_v4();
    let auth_frame = Message::Auth {
        token: self.config.token.clone(),
        client_id,
    }
    .encode();
    ws_sink
        .send(WsMessage::Binary(auth_frame.into()))
        .await
        .map_err(TunnelError::connection)?;

    // The first binary frame is the auth reply; other non-close frames are skipped.
    let wait_for_reply = async {
        loop {
            match ws_stream.next().await {
                Some(Ok(WsMessage::Binary(data))) => {
                    return Message::decode(&data).map(|(reply, _)| reply);
                }
                Some(Ok(WsMessage::Close(frame))) => {
                    let reason = match frame {
                        Some(f) => f.reason.to_string(),
                        None => "connection closed".to_string(),
                    };
                    return Err(TunnelError::connection_msg(reason));
                }
                Some(Ok(_)) => {}
                Some(Err(e)) => return Err(TunnelError::connection(e)),
                None => {
                    return Err(TunnelError::connection_msg("connection closed before auth"));
                }
            }
        }
    };
    let auth_response = match timeout(Duration::from_secs(10), wait_for_reply).await {
        Ok(reply) => reply?,
        Err(_) => return Err(TunnelError::timeout()),
    };

    let tunnel_id = match auth_response {
        Message::AuthOk { tunnel_id } => tunnel_id,
        Message::AuthFail { reason } => return Err(TunnelError::auth(reason)),
        other => {
            return Err(TunnelError::protocol(format!(
                "expected AuthOk or AuthFail, got {:?}",
                other.message_type()
            )));
        }
    };

    *self.state.write() = AgentState::Connected { tunnel_id };
    tracing::info!(
        tunnel_id = %tunnel_id,
        client_id = %client_id,
        "authenticated with server"
    );
    if let Some(events) = self.event_tx.as_ref() {
        // Best-effort notification: a closed event channel is ignored.
        let _ = events.send(ControlEvent::Authenticated { tunnel_id }).await;
    }

    self.register_services(&mut ws_sink).await?;
    self.run_message_loop(tunnel_id, &mut ws_sink, &mut ws_stream)
        .await
}
/// Sends one `Register` message per configured service.
///
/// Replies are not awaited here; they are processed by the message loop.
///
/// # Errors
///
/// Returns a connection error carrying the sink's error text as soon as one
/// send fails; the remaining services are then not sent.
async fn register_services<S>(&self, ws_sink: &mut S) -> Result<()>
where
    S: SinkExt<WsMessage> + Unpin,
    S::Error: std::error::Error,
{
    // Copy the configs out so the read lock is released before any await.
    let configs: Vec<ServiceConfig> = {
        let table = self.services.read();
        table.values().map(|entry| entry.config.clone()).collect()
    };

    for service in configs {
        let frame = Message::Register {
            name: service.name.clone(),
            protocol: service.protocol,
            local_port: service.local_port,
            remote_port: service.remote_port,
        }
        .encode();
        tracing::debug!(
            service_name = %service.name,
            local_port = service.local_port,
            "registering service"
        );
        if let Err(e) = ws_sink.send(WsMessage::Binary(frame.into())).await {
            return Err(TunnelError::connection_msg(e.to_string()));
        }
    }
    Ok(())
}
/// Drives an authenticated session until it ends.
///
/// The loop multiplexes three sources:
/// * a 5-second timer tick, which currently performs no work;
/// * [`ControlCommand`]s, which are translated into protocol messages and
///   written to `ws_sink`;
/// * frames from `ws_stream`, which are decoded and passed to
///   `handle_server_message`.
///
/// `pending_services` starts as the names of all configured services and is
/// consumed in order as `RegisterOk` / `RegisterFail` replies arrive.
///
/// # Errors
///
/// Returns `Ok(())` only when a `Disconnect` command is received. Every other
/// way out is an error: a failed send, an undecodable frame, a handler error,
/// a Close frame, a transport error, or the stream ending without a Close
/// frame.
async fn run_message_loop<Sink, Stream>(
    &self,
    tunnel_id: Uuid,
    ws_sink: &mut Sink,
    ws_stream: &mut Stream,
) -> Result<()>
where
    Sink: SinkExt<WsMessage> + Unpin,
    Sink::Error: std::error::Error,
    Stream: StreamExt<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin,
{
    // NOTE(review): the sender half stays local to this function and is not
    // published through `self.command_tx`, so with the code visible here no
    // command can reach this receiver — confirm the intended wiring.
    let (_command_tx, mut command_rx) = mpsc::channel::<ControlCommand>(256);
    let mut pending_services: Vec<String> = { self.services.read().keys().cloned().collect() };
    let mut check_interval = interval(Duration::from_secs(5));
    loop {
        tokio::select! {
            _ = check_interval.tick() => {
                // Periodic wake-up; no work is attached to it at present.
            }
            Some(command) = command_rx.recv() => {
                // Translate the command into its wire message. `Register`
                // additionally yields the name that now awaits a reply.
                let (outgoing, newly_pending) = match command {
                    ControlCommand::Register { name, protocol, local_port, remote_port } => (
                        Message::Register {
                            name: name.clone(),
                            protocol,
                            local_port,
                            remote_port,
                        },
                        Some(name),
                    ),
                    ControlCommand::Unregister { service_id } => {
                        (Message::Unregister { service_id }, None)
                    }
                    ControlCommand::ConnectAck { connection_id } => {
                        (Message::ConnectAck { connection_id }, None)
                    }
                    ControlCommand::ConnectFail { connection_id, reason } => {
                        (Message::ConnectFail { connection_id, reason }, None)
                    }
                    ControlCommand::Disconnect => {
                        tracing::info!("disconnect command received");
                        return Ok(());
                    }
                };
                ws_sink
                    .send(WsMessage::Binary(outgoing.encode().into()))
                    .await
                    .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
                // Queue the name only after the message was actually sent.
                pending_services.extend(newly_pending);
            }
            incoming = ws_stream.next() => {
                // Bind the raw `Option` rather than matching `Some(..)` in the
                // branch pattern: a refutable pattern would merely disable
                // this branch when the stream ends, and because the timer
                // branch is always ready the loop would then idle forever
                // instead of reporting the lost connection.
                let Some(msg_result) = incoming else {
                    tracing::info!("connection closed without close frame");
                    return Err(TunnelError::connection_msg("connection closed"));
                };
                match msg_result {
                    Ok(WsMessage::Binary(data)) => {
                        let (msg, _) = Message::decode(&data)?;
                        self.handle_server_message(
                            tunnel_id,
                            msg,
                            ws_sink,
                            &mut pending_services,
                        ).await?;
                    }
                    Ok(WsMessage::Close(frame)) => {
                        let reason = frame.map_or_else(
                            || "server closed connection".to_string(),
                            |f| f.reason.to_string(),
                        );
                        tracing::info!(reason = %reason, "server closed connection");
                        return Err(TunnelError::connection_msg(reason));
                    }
                    Ok(WsMessage::Ping(data)) => {
                        ws_sink
                            .send(WsMessage::Pong(data))
                            .await
                            .map_err(|e| TunnelError::connection_msg(e.to_string()))?;
                    }
                    // Text, pong and other frame kinds carry nothing for the agent.
                    Ok(_) => {}
                    Err(e) => {
                        return Err(TunnelError::connection(e));
                    }
                }
            }
        }
    }
}
/// Dispatches one decoded server message to its handler.
///
/// Registration replies update the service table, `Connect` and `Heartbeat`
/// are answered on `ws_sink`, and `Disconnect` ends the session with an
/// error. Message kinds that only a client should send are logged as
/// unexpected and otherwise ignored.
///
/// # Errors
///
/// Propagates the error of the `Connect`, `Heartbeat` or `Disconnect` handler.
async fn handle_server_message<S>(
    &self,
    tunnel_id: Uuid,
    msg: Message,
    ws_sink: &mut S,
    pending_services: &mut Vec<String>,
) -> Result<()>
where
    S: SinkExt<WsMessage> + Unpin,
    S::Error: std::error::Error,
{
    match msg {
        Message::RegisterOk { service_id } => {
            self.handle_register_ok(service_id, pending_services).await;
            Ok(())
        }
        Message::RegisterFail { reason } => {
            self.handle_register_fail(reason, pending_services).await;
            Ok(())
        }
        Message::Connect {
            service_id,
            connection_id,
            client_addr,
        } => {
            self.handle_connect(service_id, connection_id, client_addr, ws_sink)
                .await
        }
        Message::Heartbeat { timestamp } => self.handle_heartbeat(timestamp, ws_sink).await,
        Message::Disconnect { reason } => self.handle_disconnect(reason).await,
        Message::Auth { .. }
        | Message::AuthOk { .. }
        | Message::AuthFail { .. }
        | Message::Register { .. }
        | Message::Unregister { .. }
        | Message::ConnectAck { .. }
        | Message::ConnectFail { .. }
        | Message::HeartbeatAck { .. } => {
            tracing::warn!(
                tunnel_id = %tunnel_id,
                msg_type = ?msg.message_type(),
                "unexpected message from server"
            );
            Ok(())
        }
    }
}
async fn handle_register_ok(&self, service_id: Uuid, pending_services: &mut Vec<String>) {
let name = match pending_services.first().cloned() {
Some(n) => {
pending_services.remove(0);
n
}
None => return,
};
{
let mut services = self.services.write();
if let Some(service) = services.get_mut(&name) {
service.service_id = Some(service_id);
service.status = ServiceStatus::Registered;
}
}
tracing::info!(
service_name = %name,
service_id = %service_id,
"service registered"
);
if let Some(ref tx) = self.event_tx {
let _ = tx
.send(ControlEvent::ServiceRegistered { name, service_id })
.await;
}
}
async fn handle_register_fail(&self, reason: String, pending_services: &mut Vec<String>) {
let name = match pending_services.first().cloned() {
Some(n) => {
pending_services.remove(0);
n
}
None => return,
};
{
let mut services = self.services.write();
if let Some(service) = services.get_mut(&name) {
service.status = ServiceStatus::Failed(reason.clone());
}
}
tracing::warn!(
service_name = %name,
reason = %reason,
"service registration failed"
);
if let Some(ref tx) = self.event_tx {
let _ = tx.send(ControlEvent::ServiceFailed { name, reason }).await;
}
}
/// Answers a server `Connect` request.
///
/// An `IncomingConnection` event is emitted first. The connection callback,
/// when one is installed, then decides whether the connection is accepted;
/// without a callback every connection is accepted. The decision is sent
/// back as `ConnectAck` or `ConnectFail`.
///
/// # Errors
///
/// Returns a connection error carrying the sink's error text when the reply
/// cannot be sent.
async fn handle_connect<S>(
    &self,
    service_id: Uuid,
    connection_id: Uuid,
    client_addr: String,
    ws_sink: &mut S,
) -> Result<()>
where
    S: SinkExt<WsMessage> + Unpin,
    S::Error: std::error::Error,
{
    tracing::debug!(
        service_id = %service_id,
        connection_id = %connection_id,
        client_addr = %client_addr,
        "incoming connection"
    );
    if let Some(events) = self.event_tx.as_ref() {
        let notice = ControlEvent::IncomingConnection {
            service_id,
            connection_id,
            client_addr: client_addr.clone(),
        };
        let _ = events.send(notice).await;
    }

    let accepted = match self.connection_callback.as_ref() {
        Some(decide) => decide(service_id, connection_id, client_addr),
        None => true,
    };
    let reply = if accepted {
        Message::ConnectAck { connection_id }
    } else {
        Message::ConnectFail {
            connection_id,
            reason: "connection rejected by client".to_string(),
        }
    };

    ws_sink
        .send(WsMessage::Binary(reply.encode().into()))
        .await
        .map_err(|e| TunnelError::connection_msg(e.to_string()))
}
/// Acknowledges a server heartbeat and reports it on the event channel.
///
/// The `HeartbeatAck` echoes the received `timestamp`. The `Heartbeat` event
/// is emitted only after the acknowledgement was sent successfully.
///
/// # Errors
///
/// Returns a connection error carrying the sink's error text when the
/// acknowledgement cannot be sent.
async fn handle_heartbeat<S>(&self, timestamp: u64, ws_sink: &mut S) -> Result<()>
where
    S: SinkExt<WsMessage> + Unpin,
    S::Error: std::error::Error,
{
    tracing::trace!(timestamp = timestamp, "heartbeat received");

    let frame = Message::HeartbeatAck { timestamp }.encode();
    if let Err(e) = ws_sink.send(WsMessage::Binary(frame.into())).await {
        return Err(TunnelError::connection_msg(e.to_string()));
    }

    if let Some(events) = self.event_tx.as_ref() {
        let _ = events.send(ControlEvent::Heartbeat { timestamp }).await;
    }
    Ok(())
}
/// Handles a server-initiated `Disconnect`.
///
/// Emits a `Disconnected` event carrying the server's reason and then ends
/// the session.
///
/// # Errors
///
/// Always returns a connection error whose message is `reason`.
async fn handle_disconnect(&self, reason: String) -> Result<()> {
    tracing::info!(reason = %reason, "server requested disconnect");
    if let Some(events) = self.event_tx.as_ref() {
        let notice = ControlEvent::Disconnected {
            reason: reason.clone(),
        };
        let _ = events.send(notice).await;
    }
    Err(TunnelError::connection_msg(reason))
}
pub fn disconnect(&self) {
*self.state.write() = AgentState::Disconnected;
if let Some(ref tx) = self.command_tx {
let _ = tx.try_send(ControlCommand::Disconnect);
}
}
}
impl Clone for TunnelAgent {
    /// Creates a handle that shares `state` and `services` with the original
    /// (both are reference-counted) and carries clones of every other field.
    fn clone(&self) -> Self {
        // Destructure exhaustively so that adding a field forces an update here.
        let Self {
            config,
            state,
            services,
            connection_callback,
            command_tx,
            event_tx,
            overlay_resolver,
        } = self;
        Self {
            config: config.clone(),
            state: Arc::clone(state),
            services: Arc::clone(services),
            connection_callback: connection_callback.clone(),
            command_tx: command_tx.clone(),
            event_tx: event_tx.clone(),
            overlay_resolver: overlay_resolver.clone(),
        }
    }
}
// Unit tests for the agent's data types and its synchronous API. No test in
// this module opens a network connection.
#[cfg(test)]
mod tests {
use super::*;
// Builds a client configuration with two services: TCP "ssh" on local port 22
// (remote port 2222) and UDP "game" on local port 27015.
fn create_test_config() -> TunnelClientConfig {
TunnelClientConfig::new("ws://localhost:8080/tunnel/v1", "test-token")
.with_service(ServiceConfig::tcp("ssh", 22).with_remote_port(2222))
.with_service(ServiceConfig::udp("game", 27015))
}
// The default agent state is `Disconnected`.
#[test]
fn test_agent_state_default() {
let state = AgentState::default();
assert_eq!(state, AgentState::Disconnected);
}
// Distinct state variants compare as unequal.
#[test]
fn test_agent_state_variants() {
let disconnected = AgentState::Disconnected;
let connecting = AgentState::Connecting;
let connected = AgentState::Connected {
tunnel_id: Uuid::new_v4(),
};
let reconnecting = AgentState::Reconnecting { attempt: 3 };
assert_ne!(disconnected, connecting);
assert_ne!(connecting, connected);
assert_ne!(connected, reconnecting);
}
// The default service status is `Pending`.
#[test]
fn test_service_status_default() {
let status = ServiceStatus::default();
assert_eq!(status, ServiceStatus::Pending);
}
// Status equality takes the failure reason into account.
#[test]
fn test_service_status_variants() {
assert_eq!(ServiceStatus::Pending, ServiceStatus::Pending);
assert_eq!(ServiceStatus::Registered, ServiceStatus::Registered);
assert_eq!(
ServiceStatus::Failed("error".to_string()),
ServiceStatus::Failed("error".to_string())
);
assert_ne!(
ServiceStatus::Failed("error1".to_string()),
ServiceStatus::Failed("error2".to_string())
);
}
// A freshly wrapped service has no id, is pending and is not registered.
#[test]
fn test_registered_service_new() {
let config = ServiceConfig::tcp("ssh", 22);
let service = RegisteredService::new(config.clone());
assert_eq!(service.config.name, "ssh");
assert!(service.service_id.is_none());
assert_eq!(service.status, ServiceStatus::Pending);
assert!(!service.is_registered());
}
// `is_registered` is true only for the `Registered` status.
#[test]
fn test_registered_service_is_registered() {
let config = ServiceConfig::tcp("ssh", 22);
let mut service = RegisteredService::new(config);
assert!(!service.is_registered());
service.status = ServiceStatus::Registered;
assert!(service.is_registered());
service.status = ServiceStatus::Failed("error".to_string());
assert!(!service.is_registered());
}
// A new agent is disconnected, has no tunnel id and knows both configured services.
#[test]
fn test_tunnel_agent_new() {
let config = create_test_config();
let agent = TunnelAgent::new(config);
assert_eq!(agent.state(), AgentState::Disconnected);
assert!(!agent.is_connected());
assert!(agent.tunnel_id().is_none());
let services = agent.services();
assert_eq!(services.len(), 2);
}
// Services are looked up by name; unknown names yield `None`.
#[test]
fn test_tunnel_agent_get_service() {
let config = create_test_config();
let agent = TunnelAgent::new(config);
let ssh = agent.get_service("ssh");
assert!(ssh.is_some());
assert_eq!(ssh.unwrap().config.local_port, 22);
let game = agent.get_service("game");
assert!(game.is_some());
assert_eq!(game.unwrap().config.protocol, ServiceProtocol::Udp);
let nonexistent = agent.get_service("nonexistent");
assert!(nonexistent.is_none());
}
// Installing a callback stores it without invoking it.
#[test]
fn test_tunnel_agent_on_connection() {
let config = create_test_config();
let callback_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let callback_called_clone = Arc::clone(&callback_called);
let callback: ConnectionCallback = Arc::new(move |_service_id, _conn_id, _addr| {
callback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
true
});
let agent = TunnelAgent::new(config).on_connection(callback);
assert!(!callback_called.load(std::sync::atomic::Ordering::SeqCst));
assert!(agent.connection_callback.is_some());
}
// A clone reports the same state and the same number of services.
#[test]
fn test_tunnel_agent_clone() {
let config = create_test_config();
let agent = TunnelAgent::new(config);
let cloned = agent.clone();
assert_eq!(agent.state(), cloned.state());
assert_eq!(agent.services().len(), cloned.services().len());
}
// `disconnect` moves a connected agent back to `Disconnected`.
#[test]
fn test_tunnel_agent_disconnect() {
let config = create_test_config();
let agent = TunnelAgent::new(config);
// Force the connected state directly; no real session exists in this test.
*agent.state.write() = AgentState::Connected {
tunnel_id: Uuid::new_v4(),
};
assert!(agent.is_connected());
agent.disconnect();
assert_eq!(agent.state(), AgentState::Disconnected);
assert!(!agent.is_connected());
}
// Every `ControlEvent` variant can be constructed.
#[test]
fn test_control_event_variants() {
let _auth = ControlEvent::Authenticated {
tunnel_id: Uuid::new_v4(),
};
let _registered = ControlEvent::ServiceRegistered {
name: "ssh".to_string(),
service_id: Uuid::new_v4(),
};
let _failed = ControlEvent::ServiceFailed {
name: "ssh".to_string(),
reason: "error".to_string(),
};
let _incoming = ControlEvent::IncomingConnection {
service_id: Uuid::new_v4(),
connection_id: Uuid::new_v4(),
client_addr: "127.0.0.1:12345".to_string(),
};
let heartbeat = ControlEvent::Heartbeat { timestamp: 12345 };
assert!(matches!(heartbeat, ControlEvent::Heartbeat { .. }));
let _disconnected = ControlEvent::Disconnected {
reason: "test".to_string(),
};
let _error = ControlEvent::Error {
message: "test error".to_string(),
};
}
// Every `ControlCommand` variant can be constructed.
#[test]
fn test_control_command_variants() {
let _register = ControlCommand::Register {
name: "ssh".to_string(),
protocol: ServiceProtocol::Tcp,
local_port: 22,
remote_port: 2222,
};
let _unregister = ControlCommand::Unregister {
service_id: Uuid::new_v4(),
};
let _ack = ControlCommand::ConnectAck {
connection_id: Uuid::new_v4(),
};
let _fail = ControlCommand::ConnectFail {
connection_id: Uuid::new_v4(),
reason: "error".to_string(),
};
let disconnect = ControlCommand::Disconnect;
assert!(matches!(disconnect, ControlCommand::Disconnect));
}
// Installing an event channel stores the sender.
#[test]
fn test_tunnel_agent_with_event_channel() {
let config = create_test_config();
let (tx, _rx) = mpsc::channel(16);
let agent = TunnelAgent::new(config).with_event_channel(tx);
assert!(agent.event_tx.is_some());
}
// Without a command channel, `send_command` fails with "agent not running".
#[tokio::test]
async fn test_send_command_not_running() {
let config = create_test_config();
let agent = TunnelAgent::new(config);
let result = agent.send_command(ControlCommand::Disconnect).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("agent not running"));
}
}