tcproxy 0.1.1

A TCP proxy for PostgreSQL connections with SSH tunnel support and runtime target switching
Documentation
use anyhow::{Context, Result};
use std::io;
use std::io::IsTerminal;
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};

/// Initialize structured logging with enhanced error handling and configuration
pub fn init_logging(log_level: &str, json_format: bool) -> Result<()> {
    let valid_levels = ["trace", "debug", "info", "warn", "error"];
    if !valid_levels.contains(&log_level.to_lowercase().as_str()) {
        return Err(anyhow::anyhow!(
            "Invalid log level '{}'. Valid levels are: {}",
            log_level,
            valid_levels.join(", ")
        ));
    }

    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
        let filter_str = format!(
            "tcproxy={},async_ssh2_tokio=warn,deadpool=info,tokio=warn,tracing=warn",
            log_level
        );
        EnvFilter::new(filter_str)
    });

    let registry = tracing_subscriber::registry().with(env_filter);

    if json_format {
        registry
            .with(
                fmt::layer()
                    .json()
                    .with_current_span(true)
                    .with_span_list(true)
                    .with_target(true)
                    .with_thread_ids(true)
                    .with_file(true)
                    .with_line_number(true)
                    .with_writer(io::stderr),
            )
            .try_init()
            .context("Failed to initialize JSON logging")?;
    } else {
        registry
            .with(
                fmt::layer()
                    .with_target(false)
                    .with_thread_ids(true)
                    .with_file(true)
                    .with_line_number(true)
                    .with_ansi(atty::is(atty::Stream::Stderr))
                    .with_writer(io::stderr)
                    .compact(),
            )
            .try_init()
            .context("Failed to initialize console logging")?;
    }

    tracing::info!(
        log_level = log_level,
        json_format = json_format,
        "Logging initialized successfully"
    );

    Ok(())
}

/// Log `error` at ERROR level together with every layer of its context chain.
///
/// `message` becomes the event message. The `error` field holds the outermost
/// error's `Display` text. The `error_chain` field lists the `Display` text of
/// each error yielded by `anyhow::Error::chain`, from the outermost context
/// down to the root cause.
pub fn log_error_with_context(error: &anyhow::Error, message: &str) {
    // Render each layer with `Display`: the `Debug` form of anyhow's context
    // wrappers embeds the remaining chain, which makes every element repeat
    // the ones after it.
    let chain: Vec<String> = error.chain().map(|cause| cause.to_string()).collect();
    tracing::error!(
        error = %error,
        error_chain = ?chain,
        "{}", message
    );
}

/// Create the span that wraps a single SSH operation.
///
/// The span is named `ssh_operation` at INFO level. It records the proxy
/// `target`, the `ssh_host`, and a freshly generated v4 UUID as
/// `operation_id`, so events from one operation can be correlated.
pub fn ssh_span(target: &str, ssh_host: &str) -> tracing::Span {
    tracing::span!(
        tracing::Level::INFO,
        "ssh_operation",
        target = target,
        ssh_host = ssh_host,
        operation_id = %uuid::Uuid::new_v4()
    )
}

/// Record the outcome of releasing a resource.
///
/// A failed cleanup is logged at WARN and a successful one at DEBUG. Both
/// events carry the `resource_type` and `resource_id` fields.
pub fn log_resource_cleanup(resource_type: &str, resource_id: &str, success: bool) {
    if !success {
        tracing::warn!(
            resource_type = resource_type,
            resource_id = resource_id,
            "Failed to clean up resource"
        );
        return;
    }

    tracing::debug!(
        resource_type = resource_type,
        resource_id = resource_id,
        "Resource cleaned up successfully"
    );
}

/// Enhanced logging macros with structured fields.
///
/// `log_error!` forwards its arguments unchanged to `tracing::error!`, so it
/// accepts the same field and message syntax.
#[macro_export]
macro_rules! log_error {
    ($($arg:tt)*) => {
        // Leading `::` resolves `tracing` as the external crate even if the
        // caller has a local item named `tracing`.
        ::tracing::error!($($arg)*)
    };
}

/// Forward the arguments unchanged to `tracing::warn!`.
#[macro_export]
macro_rules! log_warn {
    ($($arg:tt)*) => {
        // Leading `::` resolves `tracing` as the external crate even if the
        // caller has a local item named `tracing`.
        ::tracing::warn!($($arg)*)
    };
}

/// Forward the arguments unchanged to `tracing::info!`.
#[macro_export]
macro_rules! log_info {
    ($($arg:tt)*) => {
        // Leading `::` resolves `tracing` as the external crate even if the
        // caller has a local item named `tracing`.
        ::tracing::info!($($arg)*)
    };
}

/// Forward the arguments unchanged to `tracing::debug!`.
#[macro_export]
macro_rules! log_debug {
    ($($arg:tt)*) => {
        // Leading `::` resolves `tracing` as the external crate even if the
        // caller has a local item named `tracing`.
        ::tracing::debug!($($arg)*)
    };
}

/// Forward the arguments unchanged to `tracing::trace!`.
#[macro_export]
macro_rules! log_trace {
    ($($arg:tt)*) => {
        // Leading `::` resolves `tracing` as the external crate even if the
        // caller has a local item named `tracing`.
        ::tracing::trace!($($arg)*)
    };
}

/// Log a connection event with structured data.
///
/// Emits a `tracing` event at `$level` (the name of a `tracing` level macro:
/// `trace`, `debug`, `info`, `warn` or `error`). The event carries
/// `client_addr` (recorded with `Display`), `target`, `event`, and any extra
/// `name = value` fields. Extra fields and a trailing comma are optional:
///
/// ```ignore
/// log_connection_event!(info, client_addr, "db-primary", "accepted");
/// log_connection_event!(warn, client_addr, "db-primary", "closed", bytes = 42);
/// ```
#[macro_export]
macro_rules! log_connection_event {
    // Each extra field is comma-led so that a call with no extra fields does
    // not need a comma after `$event`.
    ($level:ident, $client_addr:expr, $target:expr, $event:expr $(, $field:ident = $value:expr)* $(,)?) => {
        ::tracing::$level!(
            client_addr = %$client_addr,
            target = $target,
            event = $event,
            $($field = $value,)*
        );
    };
}

/// Log an SSH event with structured data.
///
/// Emits a `tracing` event at `$level` (the name of a `tracing` level macro:
/// `trace`, `debug`, `info`, `warn` or `error`). The event carries `target`,
/// `ssh_host`, `event`, and any extra `name = value` fields. Extra fields and a
/// trailing comma are optional:
///
/// ```ignore
/// log_ssh_event!(info, "db-primary", "bastion.example", "tunnel_opened");
/// log_ssh_event!(warn, "db-primary", "bastion.example", "retry", attempt = 2);
/// ```
#[macro_export]
macro_rules! log_ssh_event {
    // Each extra field is comma-led so that a call with no extra fields does
    // not need a comma after `$event`.
    ($level:ident, $target:expr, $ssh_host:expr, $event:expr $(, $field:ident = $value:expr)* $(,)?) => {
        ::tracing::$level!(
            target = $target,
            ssh_host = $ssh_host,
            event = $event,
            $($field = $value,)*
        );
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_logging_valid_levels() {
        let valid_levels = ["trace", "debug", "info", "warn", "error"];

        for level in valid_levels {
            // A global subscriber can only be installed once per process, so
            // accepted levels are checked against the mirrored validation
            // logic instead of calling `init_logging` repeatedly.
            let result = validate_log_level(level);
            assert!(result.is_ok(), "Level '{}' should be valid", level);

            let upper = level.to_uppercase();
            assert!(
                validate_log_level(&upper).is_ok(),
                "Level '{}' should be accepted case-insensitively",
                upper
            );
        }
    }

    #[test]
    fn test_init_logging_invalid_level() {
        // Invalid levels are rejected before any subscriber is installed, so the
        // real entry point can be exercised without touching global state.
        for json_format in [false, true] {
            let err = init_logging("invalid", json_format)
                .expect_err("an unknown level must be rejected");
            let error_msg = err.to_string();
            assert!(error_msg.contains("Invalid log level"));
            assert!(error_msg.contains("invalid"));
        }

        assert!(validate_log_level("invalid").is_err());
    }

    #[test]
    fn test_ssh_span_creation() {
        // Without an installed subscriber the span is inert; this only checks
        // that building it does not panic.
        let _span = ssh_span("test-target", "test-host");
    }

    #[test]
    fn test_logging_helpers_do_not_panic() {
        // No subscriber is installed, so these calls only exercise expansion
        // and argument handling.
        let err = anyhow::anyhow!("root cause").context("outer context");
        log_error_with_context(&err, "operation failed");
        log_resource_cleanup("tunnel", "t-1", true);
        log_resource_cleanup("tunnel", "t-1", false);

        let client_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 5432));
        log_connection_event!(info, client_addr, "test-target", "accepted", bytes = 0u64);
        log_ssh_event!(debug, "test-target", "test-host", "connected", attempt = 1u32);
    }

    // Mirror of the level validation in `init_logging`; keep the two in sync.
    fn validate_log_level(log_level: &str) -> Result<()> {
        let valid_levels = ["trace", "debug", "info", "warn", "error"];
        if !valid_levels.contains(&log_level.to_lowercase().as_str()) {
            return Err(anyhow::anyhow!(
                "Invalid log level '{}'. Valid levels are: {}",
                log_level,
                valid_levels.join(", ")
            ));
        }
        Ok(())
    }
}