use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use uuid::Uuid;
use crate::{Result, ServiceConfig, TunnelError};
/// Connects incoming streams to services listening on `127.0.0.1` and copies bytes
/// between them.
///
/// Keeps the per-service configuration plus an abort handle for every running proxy
/// task, so tasks can be cancelled one at a time or all together.
pub struct LocalProxy {
    /// Upper bound on how long connecting to a local service may take.
    connect_timeout: Duration,
    /// Registered services, keyed by service id.
    services: DashMap<Uuid, ServiceConfig>,
    /// Running proxy tasks, keyed by connection id. Shared (via `Arc`) with each
    /// spawned task so the task can remove its own entry when it finishes.
    connections: Arc<DashMap<Uuid, ConnectionHandle>>,
}
/// Bookkeeping entry for one running proxy task.
struct ConnectionHandle {
    /// Handle used to abort the proxy task (by `cancel_connection` and `shutdown`).
    abort_handle: tokio::task::AbortHandle,
}
impl LocalProxy {
    /// Size of the per-direction copy buffer used by [`LocalProxy::proxy_streams`].
    const COPY_BUF_SIZE: usize = 8192;

    /// Creates a proxy with no registered services and no active connections.
    ///
    /// `connect_timeout` bounds how long [`LocalProxy::handle_connection`] waits for the
    /// TCP connect to a local service.
    #[must_use]
    pub fn new(connect_timeout: Duration) -> Self {
        Self {
            services: DashMap::new(),
            connections: Arc::new(DashMap::new()),
            connect_timeout,
        }
    }

    /// Registers `config` under `service_id`, replacing any previous configuration.
    pub fn register_service(&self, service_id: Uuid, config: ServiceConfig) {
        self.services.insert(service_id, config);
    }

    /// Removes the configuration for `service_id`; a no-op if it is not registered.
    ///
    /// Proxy tasks that were already started are not touched.
    pub fn unregister_service(&self, service_id: Uuid) {
        self.services.remove(&service_id);
    }

    /// Opens a TCP connection to `127.0.0.1:<local_port>` of the service `service_id`.
    ///
    /// `connection_id` is only used for logging.
    ///
    /// # Errors
    ///
    /// - A registry error if `service_id` is not registered.
    /// - A timeout error if the connect does not complete within `connect_timeout`.
    /// - [`TunnelError::Connection`] if the connect fails (e.g. connection refused).
    pub async fn handle_connection(
        &self,
        service_id: Uuid,
        connection_id: Uuid,
    ) -> Result<TcpStream> {
        // Copy the port out so the DashMap shard guard is released before any `.await`.
        let local_port = self
            .services
            .get(&service_id)
            .map(|entry| entry.local_port)
            .ok_or_else(|| TunnelError::registry(format!("Unknown service: {service_id}")))?;
        let addr = format!("127.0.0.1:{local_port}");
        let local_stream = tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr))
            .await
            .map_err(|_| TunnelError::timeout())?
            .map_err(|e| TunnelError::Connection { source: e })?;
        tracing::debug!(
            service_id = %service_id,
            connection_id = %connection_id,
            local_addr = %addr,
            "Connected to local service"
        );
        Ok(local_stream)
    }

    /// Copies bytes in both directions between `local` and `remote` until both
    /// directions have ended.
    ///
    /// Each direction stops at EOF or at its first read/write error. It then shuts down
    /// the write half it was feeding, so the peer sees EOF (half-close). The other
    /// direction keeps running. I/O errors are not reported: this always returns `Ok`.
    ///
    /// Returns `(sent, received)`, the bytes forwarded local→remote and remote→local.
    pub async fn proxy_streams(local: TcpStream, remote: TcpStream) -> Result<(u64, u64)> {
        let (local_read, local_write) = local.into_split();
        let (remote_read, remote_write) = remote.into_split();
        let (sent, received) = tokio::join!(
            Self::pump(local_read, remote_write),
            Self::pump(remote_read, local_write),
        );
        Ok((sent, received))
    }

    /// Forwards `reader` into `writer` until EOF or an I/O error, then shuts `writer` down.
    ///
    /// Returns the number of bytes successfully written to `writer`.
    async fn pump<R, W>(mut reader: R, mut writer: W) -> u64
    where
        R: tokio::io::AsyncRead + Unpin,
        W: tokio::io::AsyncWrite + Unpin,
    {
        let mut buf = vec![0u8; Self::COPY_BUF_SIZE];
        let mut total = 0u64;
        loop {
            match reader.read(&mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(n) => {
                    if writer.write_all(&buf[..n]).await.is_err() {
                        break;
                    }
                    total += n as u64;
                }
            }
        }
        // Best effort: the peer may already have gone away.
        let _ = writer.shutdown().await;
        total
    }

    /// Spawns a task that proxies `local` ↔ `remote` and tracks it under `connection_id`.
    ///
    /// The task removes its own entry when the proxying finishes. Use
    /// [`LocalProxy::cancel_connection`] to abort it earlier.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (as [`tokio::spawn`] does).
    pub fn start_proxy(
        &self,
        connection_id: Uuid,
        local: TcpStream,
        remote: TcpStream,
    ) -> tokio::task::JoinHandle<Result<(u64, u64)>> {
        let connections = Arc::clone(&self.connections);
        // The task must not remove its entry before the entry has been inserted.
        // Otherwise, on a multi-threaded runtime, a fast-finishing proxy would leave a
        // stale entry behind forever. The task therefore waits for this signal, which
        // is sent only after the insert below.
        let (registered_tx, registered_rx) = tokio::sync::oneshot::channel::<()>();
        let handle = tokio::spawn(async move {
            // `Err` only means the sender was dropped; proceed either way.
            let _ = registered_rx.await;
            let result = Self::proxy_streams(local, remote).await;
            connections.remove(&connection_id);
            result
        });
        self.connections.insert(
            connection_id,
            ConnectionHandle {
                abort_handle: handle.abort_handle(),
            },
        );
        // Fails only if the task was already aborted, in which case nothing waits.
        let _ = registered_tx.send(());
        handle
    }

    /// Aborts and forgets the proxy task for `connection_id`, if one is tracked.
    pub fn cancel_connection(&self, connection_id: Uuid) {
        if let Some((_, handle)) = self.connections.remove(&connection_id) {
            handle.abort_handle.abort();
        }
    }

    /// Number of proxy tasks currently tracked (started, not yet finished or cancelled).
    #[must_use]
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }

    /// Number of registered services.
    #[must_use]
    pub fn service_count(&self) -> usize {
        self.services.len()
    }

    /// Aborts every tracked proxy task and forgets it. Registered services are kept.
    ///
    /// Each entry is aborted and removed in one `retain` pass. A connection inserted
    /// concurrently is therefore either aborted or left tracked, never dropped from
    /// the map while still running.
    pub fn shutdown(&self) {
        self.connections.retain(|_, handle| {
            handle.abort_handle.abort();
            false
        });
    }
}
impl Drop for LocalProxy {
fn drop(&mut self) {
self.shutdown();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Returns a connected `(client, server)` pair of loopback TCP streams.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener address");
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (
            client.expect("connect to listener"),
            accepted.expect("accept connection").0,
        )
    }

    /// Returns a loopback port that (very likely) has no listener. The OS hands out an
    /// ephemeral port, which is released again when the listener is dropped.
    async fn closed_local_port() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        listener.local_addr().expect("listener address").port()
    }

    #[test]
    fn test_register_and_unregister_services() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id1 = Uuid::new_v4();
        let service_id2 = Uuid::new_v4();
        let config1 = ServiceConfig::tcp("ssh", 22);
        let config2 = ServiceConfig::tcp("postgres", 5432);
        proxy.register_service(service_id1, config1);
        proxy.register_service(service_id2, config2);
        assert_eq!(proxy.service_count(), 2);
        assert!(proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));
        proxy.unregister_service(service_id1);
        assert_eq!(proxy.service_count(), 1);
        assert!(!proxy.services.contains_key(&service_id1));
        assert!(proxy.services.contains_key(&service_id2));
        proxy.unregister_service(service_id2);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_service_and_connection_counts() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        assert_eq!(proxy.service_count(), 0);
        assert_eq!(proxy.connection_count(), 0);
        let service_id = Uuid::new_v4();
        proxy.register_service(service_id, ServiceConfig::tcp("test", 8080));
        assert_eq!(proxy.service_count(), 1);
        assert_eq!(proxy.connection_count(), 0);
        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_cancel_connection() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let connection_id = Uuid::new_v4();
        proxy.cancel_connection(connection_id);
        assert_eq!(proxy.connection_count(), 0);
    }

    #[test]
    fn test_shutdown_clears_all() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        for _ in 0..5 {
            proxy.register_service(Uuid::new_v4(), ServiceConfig::tcp("test", 8080));
        }
        assert_eq!(proxy.service_count(), 5);
        proxy.shutdown();
        assert_eq!(proxy.connection_count(), 0);
        assert_eq!(proxy.service_count(), 5);
    }

    #[tokio::test]
    async fn test_handle_connection_unknown_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();
        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, TunnelError::Registry { .. }));
    }

    #[tokio::test]
    async fn test_handle_connection_closed_port_fails() {
        let proxy = LocalProxy::new(Duration::from_millis(100));
        let service_id = Uuid::new_v4();
        let connection_id = Uuid::new_v4();
        // A hard-coded port may be in use on the test machine; use one just released.
        let port = closed_local_port().await;
        proxy.register_service(service_id, ServiceConfig::tcp("test", port));
        let result = proxy.handle_connection(service_id, connection_id).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            TunnelError::Timeout | TunnelError::Connection { .. }
        ));
    }

    #[tokio::test]
    async fn test_proxy_streams_forwards_and_counts_bytes() {
        let (mut client, local_inner) = tcp_pair().await;
        let (mut server, remote_inner) = tcp_pair().await;
        let proxy_task = tokio::spawn(LocalProxy::proxy_streams(local_inner, remote_inner));

        client.write_all(b"hello").await.unwrap();
        let mut buf = [0u8; 5];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");

        server.write_all(b"world!").await.unwrap();
        let mut buf = [0u8; 6];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"world!");

        // Half-close each side; the proxy must propagate EOF and then finish.
        client.shutdown().await.unwrap();
        let mut rest = Vec::new();
        server.read_to_end(&mut rest).await.unwrap();
        assert!(rest.is_empty());
        server.shutdown().await.unwrap();
        let mut rest = Vec::new();
        client.read_to_end(&mut rest).await.unwrap();
        assert!(rest.is_empty());

        let (sent, received) = proxy_task.await.unwrap().unwrap();
        assert_eq!((sent, received), (5, 6));
    }

    #[tokio::test]
    async fn test_start_proxy_deregisters_finished_connection() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let (client, local_inner) = tcp_pair().await;
        let (server, remote_inner) = tcp_pair().await;
        // Closing both outer ends makes the proxy finish right away.
        drop(client);
        drop(server);
        let handle = proxy.start_proxy(Uuid::new_v4(), local_inner, remote_inner);
        let (sent, received) = handle.await.unwrap().unwrap();
        assert_eq!((sent, received), (0, 0));
        assert_eq!(proxy.connection_count(), 0);
    }

    #[test]
    fn test_unregister_nonexistent_service() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();
        proxy.unregister_service(service_id);
        assert_eq!(proxy.service_count(), 0);
    }

    #[test]
    fn test_register_overwrites_existing() {
        let proxy = LocalProxy::new(Duration::from_secs(5));
        let service_id = Uuid::new_v4();
        proxy.register_service(service_id, ServiceConfig::tcp("first", 8080));
        assert_eq!(proxy.service_count(), 1);
        proxy.register_service(service_id, ServiceConfig::tcp("second", 9090));
        assert_eq!(proxy.service_count(), 1);
        let config = proxy.services.get(&service_id).unwrap();
        assert_eq!(config.name, "second");
        assert_eq!(config.local_port, 9090);
    }
}