use crate::config::TargetConfig;
use crate::connection_manager::ConnectionManager;
use crate::logging::log_resource_cleanup;
use anyhow::{Context, Result};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
/// TCP proxy that accepts clients on a local address and forwards every
/// connection to a single configured target.
pub struct ProxyServer {
    /// Host or interface the listener binds to (joined with `listen_port`).
    listen_host: String,
    /// Port the listener binds to.
    listen_port: u16,
    /// Name under which the target is registered with the connection manager.
    target_name: String,
    /// Connection settings for the upstream target.
    target_config: TargetConfig,
    /// Shared manager used to initialize the target and query its health.
    connection_manager: Arc<ConnectionManager>,
    /// One permit per connection being handled; caps concurrency.
    connection_semaphore: Arc<Semaphore>,
    /// Settings handed to the connection manager when health checking starts.
    connection_management_config: crate::config::ConnectionManagementConfig,
}
impl ProxyServer {
pub fn new(
listen_host: String,
listen_port: u16,
target_name: String,
target_config: TargetConfig,
connection_management_config: crate::config::ConnectionManagementConfig,
) -> Self {
let connection_manager = Arc::new(ConnectionManager::new());
Self {
listen_host,
listen_port,
target_name,
target_config,
connection_manager,
connection_semaphore: Arc::new(Semaphore::new(1000)),
connection_management_config,
}
}
pub async fn run(&self) -> Result<()> {
info!(
target_name = %self.target_name,
listen_host = %self.listen_host,
listen_port = self.listen_port,
"Initializing proxy server"
);
self.connection_manager
.initialize_target(self.target_name.clone(), self.target_config.clone())
.await
.with_context(|| format!("Failed to initialize target '{}'", self.target_name))?;
self.connection_manager
.start_health_checking(
vec![self.target_name.clone()],
&self.connection_management_config,
)
.await;
let bind_addr = format!("{}:{}", self.listen_host, self.listen_port);
let listener = TcpListener::bind(&bind_addr)
.await
.with_context(|| format!("Failed to bind to address '{}'", bind_addr))?;
info!(
bind_addr = %bind_addr,
target_name = %self.target_name,
max_connections = self.connection_semaphore.available_permits(),
"Proxy server listening and ready to accept connections"
);
if let Some(ssh) = &self.target_config.ssh {
if ssh.enabled {
info!(
ssh_host = %ssh.host.as_ref().unwrap_or(&"<not configured>".to_string()),
"SSH tunnel enabled"
);
}
}
let mut connection_counter = 0u64;
loop {
match listener.accept().await {
Ok((client_stream, client_addr)) => {
connection_counter += 1;
info!(
"Connection accepted from {} for target {} (connection #{}, {} permits available)",
client_addr,
self.target_name,
connection_counter,
self.connection_semaphore.available_permits()
);
let target_config = self.target_config.clone();
let target_name = self.target_name.clone();
let connection_manager = Arc::clone(&self.connection_manager);
let semaphore = Arc::clone(&self.connection_semaphore);
tokio::spawn(async move {
let connection_start = Instant::now();
let permit = match semaphore.acquire().await {
Ok(permit) => {
debug!(
client_addr = %client_addr,
remaining_permits = semaphore.available_permits(),
"Connection permit acquired"
);
permit
}
Err(e) => {
error!(
"Failed to acquire connection permit for {}: {}",
client_addr, e
);
return;
}
};
let result = handle_connection_with_health_check(
client_stream,
target_config,
target_name.clone(),
connection_manager,
client_addr,
)
.await;
let connection_duration = connection_start.elapsed();
match result {
Ok(()) => {
debug!(
"Connection from {} completed successfully in {}ms",
client_addr,
connection_duration.as_millis()
);
}
Err(e) => {
error!(
"Connection from {} failed after {}ms: {}",
client_addr,
connection_duration.as_millis(),
e
);
}
}
drop(permit);
log_resource_cleanup("connection_permit", &client_addr.to_string(), true);
});
}
Err(e) => {
error!(
error = %e,
bind_addr = %bind_addr,
"Failed to accept connection, continuing to listen"
);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
}
/// Proxies a single client connection to the target named `target_name`.
///
/// The connection is refused up front if the connection manager reports the
/// target as unhealthy. Otherwise a target connection is opened, with a 30 s
/// timeout, and bytes are forwarded in both directions until both directions
/// end or either direction fails.
///
/// # Errors
///
/// Returns an error if the target is unhealthy, if connecting times out, or if
/// connecting fails. Errors during forwarding are logged at debug level and do
/// not produce an `Err`.
async fn handle_connection_with_health_check(
    client_stream: TcpStream,
    target_config: TargetConfig,
    target_name: String,
    connection_manager: Arc<ConnectionManager>,
    client_addr: std::net::SocketAddr,
) -> Result<()> {
    // Upper bound on establishing the upstream TCP connection.
    const TARGET_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    debug!(
        "Handling connection from {} to target {}",
        client_addr, target_name
    );
    if !connection_manager.is_target_healthy(&target_name).await {
        warn!(
            "Rejecting connection from {} - target {} is unhealthy",
            client_addr, target_name
        );
        return Err(anyhow::anyhow!("Target '{}' is unhealthy", target_name));
    }

    let target_stream = timeout(
        TARGET_CONNECT_TIMEOUT,
        create_target_connection(&target_config, client_addr),
    )
    .await
    .with_context(|| format!("Timeout connecting to target for client {}", client_addr))?
    .with_context(|| format!("Failed to connect to target for client {}", client_addr))?;
    info!(
        "Successfully connected client {} to target {}",
        client_addr, target_name
    );

    let (client_read, client_write) = client_stream.into_split();
    let (target_read, target_write) = target_stream.into_split();

    // Drive both directions to completion instead of racing them. When one
    // direction hits a clean EOF, its future finishes and its write half is
    // dropped, which half-closes that side. The other direction keeps flowing,
    // so a client that sends a request and then shuts down its write side still
    // receives the response. An I/O error in either direction aborts both.
    let forwarding = tokio::try_join!(
        forward_data(client_read, target_write, "client->target", client_addr),
        forward_data(target_read, client_write, "target->client", client_addr)
    );
    match forwarding {
        Ok(((), ())) => {
            debug!(
                "Client->target and target->client forwarding completed for {}",
                client_addr
            );
        }
        Err(e) => {
            debug!(
                "Connection forwarding ended with an error for {}: {}",
                client_addr, e
            );
        }
    }

    debug!(
        "Connection forwarding completed for {} to target {}",
        client_addr, target_name
    );
    Ok(())
}
/// Opens a direct TCP connection to `target_config.host:target_config.port`
/// on behalf of `client_addr`.
///
/// SSH tunnelling is not implemented. If the target has SSH enabled, a
/// warning is logged and a direct connection is made anyway.
///
/// # Errors
///
/// Returns the I/O error from the connect attempt, after logging it at error
/// level.
async fn create_target_connection(
    target_config: &TargetConfig,
    client_addr: std::net::SocketAddr,
) -> Result<TcpStream> {
    if target_config
        .ssh
        .as_ref()
        .filter(|ssh| ssh.enabled)
        .is_some()
    {
        warn!(
            "SSH tunnel requested for {} but using direct connection for now",
            client_addr
        );
    }

    let target_addr = format!("{}:{}", target_config.host, target_config.port);
    match TcpStream::connect(&target_addr).await {
        Ok(stream) => {
            debug!(
                "Connected to target {} for client {}",
                target_addr, client_addr
            );
            Ok(stream)
        }
        Err(connect_err) => {
            error!(
                "Failed to connect to target {} for client {}: {}",
                target_addr, client_addr, connect_err
            );
            Err(connect_err.into())
        }
    }
}
/// Copies bytes from `reader` to `writer` until EOF or an I/O error.
///
/// Each chunk, read into an 8 KiB buffer, is written in full and flushed
/// before the next read. On EOF the writer is shut down explicitly, so the
/// peer sees the half-close regardless of whether `W` shuts down on drop.
///
/// `direction` and `client_addr` are used only in log messages.
///
/// # Errors
///
/// Returns an error, after logging it, if a read, write or flush fails. A
/// failed shutdown after EOF is only logged at debug level, because all data
/// had already been written and flushed by then.
async fn forward_data<R, W>(
    mut reader: R,
    mut writer: W,
    direction: &str,
    client_addr: std::net::SocketAddr,
) -> Result<()>
where
    R: AsyncReadExt + Unpin,
    W: AsyncWriteExt + Unpin,
{
    let mut buffer = vec![0u8; 8192];
    loop {
        match reader.read(&mut buffer).await {
            Ok(0) => {
                debug!("EOF reached for {} from {}", direction, client_addr);
                break;
            }
            Ok(n) => {
                debug!("Read {} bytes for {} from {}", n, direction, client_addr);
                if let Err(e) = writer.write_all(&buffer[..n]).await {
                    error!(
                        "Failed to write {} bytes for {} from {}: {}",
                        n, direction, client_addr, e
                    );
                    return Err(anyhow::anyhow!("Write failed: {}", e));
                }
                if let Err(e) = writer.flush().await {
                    error!(
                        "Failed to flush data for {} from {}: {}",
                        direction, client_addr, e
                    );
                    return Err(anyhow::anyhow!("Flush failed: {}", e));
                }
            }
            Err(e) => {
                error!(
                    "Failed to read data for {} from {}: {}",
                    direction, client_addr, e
                );
                return Err(anyhow::anyhow!("Read failed: {}", e));
            }
        }
    }

    // Propagate the half-close explicitly instead of relying on `W`'s Drop impl.
    // Best-effort: the peer may already be gone (e.g. NotConnected).
    if let Err(e) = writer.shutdown().await {
        debug!(
            "Failed to shut down writer for {} from {}: {}",
            direction, client_addr, e
        );
    }

    debug!(
        "Data forwarding completed for {} from {}",
        direction, client_addr
    );
    Ok(())
}