tcproxy 0.1.1

A TCP proxy for PostgreSQL connections with SSH tunnel support and runtime target switching
Documentation
use crate::config::TargetConfig;
use crate::connection_manager::ConnectionManager;
use crate::logging::log_resource_cleanup;
use anyhow::{Context, Result};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};

/// TCP proxy that listens on `listen_host:listen_port` and forwards every accepted
/// client connection to a single configured target.
pub struct ProxyServer {
    // --- listener side ---
    /// Host/interface the listener binds to.
    listen_host: String,
    /// Port the listener binds to.
    listen_port: u16,

    // --- target side ---
    /// Name under which the target is registered with the connection manager.
    target_name: String,
    /// Connection settings for the target (host, port and optional SSH section).
    target_config: TargetConfig,
    /// Settings handed to the connection manager when health checking starts.
    connection_management_config: crate::config::ConnectionManagementConfig,

    // --- shared runtime state ---
    /// Initializes the target and answers health queries for it.
    connection_manager: Arc<ConnectionManager>,
    /// Limits how many client connections are proxied concurrently.
    connection_semaphore: Arc<Semaphore>,
}

impl ProxyServer {
    pub fn new(
        listen_host: String,
        listen_port: u16,
        target_name: String,
        target_config: TargetConfig,
        connection_management_config: crate::config::ConnectionManagementConfig,
    ) -> Self {
        let connection_manager = Arc::new(ConnectionManager::new());

        Self {
            listen_host,
            listen_port,
            target_name,
            target_config,
            connection_manager,
            connection_semaphore: Arc::new(Semaphore::new(1000)),
            connection_management_config,
        }
    }

    pub async fn run(&self) -> Result<()> {
        info!(
            target_name = %self.target_name,
            listen_host = %self.listen_host,
            listen_port = self.listen_port,
            "Initializing proxy server"
        );

        self.connection_manager
            .initialize_target(self.target_name.clone(), self.target_config.clone())
            .await
            .with_context(|| format!("Failed to initialize target '{}'", self.target_name))?;

        self.connection_manager
            .start_health_checking(
                vec![self.target_name.clone()],
                &self.connection_management_config,
            )
            .await;

        let bind_addr = format!("{}:{}", self.listen_host, self.listen_port);
        let listener = TcpListener::bind(&bind_addr)
            .await
            .with_context(|| format!("Failed to bind to address '{}'", bind_addr))?;

        info!(
            bind_addr = %bind_addr,
            target_name = %self.target_name,
            max_connections = self.connection_semaphore.available_permits(),
            "Proxy server listening and ready to accept connections"
        );

        if let Some(ssh) = &self.target_config.ssh {
            if ssh.enabled {
                info!(
                    ssh_host = %ssh.host.as_ref().unwrap_or(&"<not configured>".to_string()),
                    "SSH tunnel enabled"
                );
            }
        }

        let mut connection_counter = 0u64;

        loop {
            match listener.accept().await {
                Ok((client_stream, client_addr)) => {
                    connection_counter += 1;

                    info!(
                        "Connection accepted from {} for target {} (connection #{}, {} permits available)",
                        client_addr,
                        self.target_name,
                        connection_counter,
                        self.connection_semaphore.available_permits()
                    );

                    let target_config = self.target_config.clone();
                    let target_name = self.target_name.clone();
                    let connection_manager = Arc::clone(&self.connection_manager);
                    let semaphore = Arc::clone(&self.connection_semaphore);

                    tokio::spawn(async move {
                        let connection_start = Instant::now();

                        let permit = match semaphore.acquire().await {
                            Ok(permit) => {
                                debug!(
                                    client_addr = %client_addr,
                                    remaining_permits = semaphore.available_permits(),
                                    "Connection permit acquired"
                                );
                                permit
                            }
                            Err(e) => {
                                error!(
                                    "Failed to acquire connection permit for {}: {}",
                                    client_addr, e
                                );
                                return;
                            }
                        };

                        let result = handle_connection_with_health_check(
                            client_stream,
                            target_config,
                            target_name.clone(),
                            connection_manager,
                            client_addr,
                        )
                        .await;

                        let connection_duration = connection_start.elapsed();

                        match result {
                            Ok(()) => {
                                debug!(
                                    "Connection from {} completed successfully in {}ms",
                                    client_addr,
                                    connection_duration.as_millis()
                                );
                            }
                            Err(e) => {
                                error!(
                                    "Connection from {} failed after {}ms: {}",
                                    client_addr,
                                    connection_duration.as_millis(),
                                    e
                                );
                            }
                        }

                        drop(permit);
                        log_resource_cleanup("connection_permit", &client_addr.to_string(), true);
                    });
                }
                Err(e) => {
                    error!(
                        error = %e,
                        bind_addr = %bind_addr,
                        "Failed to accept connection, continuing to listen"
                    );

                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }
    }
}

/// Proxies one accepted client connection to the target until both directions finish.
///
/// The connection is rejected up front if the connection manager reports the target
/// as unhealthy. Connecting to the target is bounded by a 30 second timeout.
///
/// # Errors
///
/// Returns an error if the target is unhealthy, or if connecting to it fails or times
/// out. I/O errors during forwarding on an established session are logged at debug
/// level and do not produce an `Err`.
async fn handle_connection_with_health_check(
    client_stream: TcpStream,
    target_config: TargetConfig,
    target_name: String,
    connection_manager: Arc<ConnectionManager>,
    client_addr: std::net::SocketAddr,
) -> Result<()> {
    /// Upper bound on establishing the upstream connection.
    const TARGET_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    debug!(
        "Handling connection from {} to target {}",
        client_addr, target_name
    );

    if !connection_manager.is_target_healthy(&target_name).await {
        warn!(
            "Rejecting connection from {} - target {} is unhealthy",
            client_addr, target_name
        );
        return Err(anyhow::anyhow!("Target '{}' is unhealthy", target_name));
    }

    let target_stream = timeout(
        TARGET_CONNECT_TIMEOUT,
        create_target_connection(&target_config, client_addr),
    )
    .await
    .with_context(|| format!("Timeout connecting to target for client {}", client_addr))?
    .with_context(|| format!("Failed to connect to target for client {}", client_addr))?;

    info!(
        "Successfully connected client {} to target {}",
        client_addr, target_name
    );

    let (client_read, client_write) = client_stream.into_split();
    let (target_read, target_write) = target_stream.into_split();

    let client_to_target = forward_data(client_read, target_write, "client->target", client_addr);
    let target_to_client = forward_data(target_read, client_write, "target->client", client_addr);

    // Drive both directions to completion. When one side hits EOF, its forwarder returns
    // and releases its write half, which propagates the half-close to the other peer
    // while the opposite direction keeps draining. Racing the two with `select!` would
    // drop the other direction mid-stream and lose data still in flight. `try_join!`
    // still cancels both directions as soon as either fails.
    match tokio::try_join!(client_to_target, target_to_client) {
        Ok(((), ())) => {
            debug!(
                "Bidirectional forwarding completed for {}",
                client_addr
            );
        }
        Err(e) => {
            // Errors on an established session (e.g. a peer reset) are routine and are
            // not reported to the caller as a failed connection.
            debug!(
                "Forwarding ended with error for {}: {:#}",
                client_addr, e
            );
        }
    }

    debug!(
        "Connection forwarding completed for {} to target {}",
        client_addr, target_name
    );

    Ok(())
}

/// Opens a plain TCP connection to `target_config.host:target_config.port`.
///
/// SSH tunnelling is not implemented here: when the target's SSH section is enabled,
/// a warning is logged and the connection is still made directly to the target host.
///
/// # Errors
///
/// Returns the connect I/O error (converted into [`anyhow::Error`]); the failure is
/// also logged at error level together with the target address.
async fn create_target_connection(
    target_config: &TargetConfig,
    client_addr: std::net::SocketAddr,
) -> Result<TcpStream> {
    if matches!(&target_config.ssh, Some(ssh) if ssh.enabled) {
        // NOTE(review): SSH is configured but ignored; connecting directly may fail or
        // bypass the intended network path. TODO confirm this fallback is acceptable.
        warn!(
            "SSH tunnel requested for {} but using direct connection for now",
            client_addr
        );
    }

    let target_addr = format!("{}:{}", target_config.host, target_config.port);

    match TcpStream::connect(&target_addr).await {
        Ok(stream) => {
            debug!(
                "Connected to target {} for client {}",
                target_addr, client_addr
            );
            Ok(stream)
        }
        Err(connect_err) => {
            error!(
                "Failed to connect to target {} for client {}: {}",
                target_addr, client_addr, connect_err
            );
            Err(connect_err.into())
        }
    }
}

/// Copies bytes from `reader` to `writer` until `reader` reaches EOF.
///
/// Each chunk read (up to 8 KiB) is written in full and flushed before the next read.
/// On EOF the writer is shut down so the peer sees the half-close. `direction` and
/// `client_addr` are used only in log messages.
///
/// # Errors
///
/// Returns an error if a read, write or flush fails; the failure is also logged at
/// error level. A failed shutdown after EOF is logged at debug level only, because all
/// data has already been written and flushed by then.
async fn forward_data<R, W>(
    mut reader: R,
    mut writer: W,
    direction: &str,
    client_addr: std::net::SocketAddr,
) -> Result<()>
where
    R: AsyncReadExt + Unpin,
    W: AsyncWriteExt + Unpin,
{
    let mut buffer = vec![0u8; 8192];

    loop {
        let n = match reader.read(&mut buffer).await {
            Ok(0) => {
                debug!("EOF reached for {} from {}", direction, client_addr);
                break;
            }
            Ok(n) => n,
            Err(e) => {
                error!(
                    "Failed to read data for {} from {}: {}",
                    direction, client_addr, e
                );
                return Err(anyhow::anyhow!("Read failed: {}", e));
            }
        };

        debug!("Read {} bytes for {} from {}", n, direction, client_addr);

        if let Err(e) = writer.write_all(&buffer[..n]).await {
            error!(
                "Failed to write {} bytes for {} from {}: {}",
                n, direction, client_addr, e
            );
            return Err(anyhow::anyhow!("Write failed: {}", e));
        }

        if let Err(e) = writer.flush().await {
            error!(
                "Failed to flush data for {} from {}: {}",
                direction, client_addr, e
            );
            return Err(anyhow::anyhow!("Flush failed: {}", e));
        }
    }

    // Propagate EOF to the other peer as a half-close. Some writers shut down on drop
    // (tokio's `OwnedWriteHalf`) but others do not (e.g. `&mut TcpStream`), so do it
    // explicitly. If the peer is already gone this can fail harmlessly.
    if let Err(e) = writer.shutdown().await {
        debug!(
            "Failed to shut down writer for {} from {}: {}",
            direction, client_addr, e
        );
    }

    debug!(
        "Data forwarding completed for {} from {}",
        direction, client_addr
    );
    Ok(())
}