use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use super::registry::StreamRegistry;
/// Raw TCP pass-through proxy for a single listening port.
///
/// Every accepted connection is matched to a service through the shared
/// registry and relayed byte-for-byte to one of that service's backends.
pub struct TcpStreamService {
    /// Port whose registered service this proxy fronts; used as the registry
    /// lookup key and as a log field.
    listen_port: u16,
    /// Shared registry consulted once per connection to find the target service.
    registry: Arc<StreamRegistry>,
}
impl TcpStreamService {
    /// Pause before retrying after a failed `accept()`.
    ///
    /// Without a pause, an error that persists across calls would turn the
    /// accept loop into a busy spin that also floods the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(50);

    /// Creates a proxy for the TCP service registered under `listen_port`.
    ///
    /// `listen_port` is only used as the registry lookup key and as a log
    /// field; this constructor binds nothing. The listener is handed to
    /// [`serve`](Self::serve) separately and is expected to be bound to the
    /// same port.
    #[must_use]
    pub fn new(registry: Arc<StreamRegistry>, listen_port: u16) -> Self {
        Self {
            registry,
            listen_port,
        }
    }

    /// Returns the port this proxy resolves services for.
    #[must_use]
    pub fn port(&self) -> u16 {
        self.listen_port
    }

    /// Returns the shared registry used to resolve the service for each connection.
    #[must_use]
    pub fn registry(&self) -> &Arc<StreamRegistry> {
        &self.registry
    }

    /// Accepts connections from `listener` and proxies each one on its own task.
    ///
    /// The loop never exits on its own. Accept errors are logged and retried
    /// after a short backoff, so the returned future only stops when it is
    /// dropped. Failures while handling an individual connection are logged
    /// inside that connection's task and do not affect the accept loop.
    pub async fn serve(self: Arc<Self>, listener: TcpListener) {
        tracing::info!(port = self.listen_port, "TCP stream proxy listening");
        loop {
            let (client_stream, client_addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::warn!(
                        port = self.listen_port,
                        error = %e,
                        "TCP accept error, retrying"
                    );
                    tokio::time::sleep(Self::ACCEPT_ERROR_BACKOFF).await;
                    continue;
                }
            };
            // One detached task per connection, each with its own `Arc` handle,
            // so a long-lived tunnel never holds up the accept loop.
            let svc = Arc::clone(&self);
            tokio::spawn(async move {
                svc.handle_raw_connection(client_stream, client_addr).await;
            });
        }
    }

    /// Proxies one accepted client connection to a backend.
    ///
    /// Resolves the service registered for this port, asks it for a backend,
    /// connects, then relays bytes in both directions until the tunnel ends.
    /// Every failure is logged and ends the connection. Returning early drops
    /// `client_stream`, which closes the client socket.
    async fn handle_raw_connection(
        &self,
        client_stream: tokio::net::TcpStream,
        client_addr: SocketAddr,
    ) {
        let Some(service) = self.registry.resolve_tcp(self.listen_port) else {
            tracing::warn!(
                port = self.listen_port,
                client = %client_addr,
                "No service registered for TCP port"
            );
            return;
        };
        let Some(backend) = service.select_backend() else {
            tracing::warn!(
                port = self.listen_port,
                service = %service.name,
                client = %client_addr,
                "No backends available for TCP service"
            );
            return;
        };
        tracing::debug!(
            port = self.listen_port,
            service = %service.name,
            client = %client_addr,
            backend = %backend,
            "Proxying TCP connection"
        );
        let upstream = match tokio::net::TcpStream::connect(backend).await {
            Ok(stream) => stream,
            Err(e) => {
                tracing::warn!(
                    port = self.listen_port,
                    error = %e,
                    backend = %backend,
                    service = %service.name,
                    client = %client_addr,
                    "Failed to connect to TCP backend"
                );
                return;
            }
        };
        // The outcome is logged here, not inside `duplex_raw`, so it carries the
        // same connection identity as the earlier log lines. Otherwise
        // concurrent tunnels would be indistinguishable in the log.
        match Self::duplex_raw(client_stream, upstream).await {
            Ok((down_to_up, up_to_down)) => {
                tracing::debug!(
                    port = self.listen_port,
                    service = %service.name,
                    client = %client_addr,
                    backend = %backend,
                    down_to_up = down_to_up,
                    up_to_down = up_to_down,
                    "TCP tunnel closed"
                );
            }
            Err(e) => {
                tracing::debug!(
                    port = self.listen_port,
                    service = %service.name,
                    client = %client_addr,
                    backend = %backend,
                    error = %e,
                    "TCP tunnel error"
                );
            }
        }
    }

    /// Relays bytes between `downstream` (the client) and `upstream` (the
    /// backend) until both directions have finished.
    ///
    /// Both streams are taken by value so they are dropped, closing the
    /// sockets, as soon as the relay ends.
    ///
    /// On success returns `(client-to-backend bytes, backend-to-client bytes)`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by either side, which ends the relay.
    async fn duplex_raw(
        mut downstream: tokio::net::TcpStream,
        mut upstream: tokio::net::TcpStream,
    ) -> std::io::Result<(u64, u64)> {
        tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await
    }
}