tiny-proxy 0.3.0

A high-performance HTTP reverse proxy server written in Rust with SSE support, connection pooling, and configurable routing
Documentation
use anyhow::Error;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::header;
use hyper::{Request, Response, StatusCode, Uri};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use tracing::{error, info};

use crate::config::Config;
use crate::proxy::ActionResult;

use crate::proxy::directives::{
    handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
    handle_strip_prefix, handle_uri_replace,
};

/// Boxed, type-erased body type used by every response this module produces.
///
/// Streaming backend bodies (`Incoming`) and proxy-generated buffered bodies
/// (`Full<Bytes>`) are both boxed into this single type. This lets the proxy return one
/// response type whether it streams (e.g. SSE) or answers directly.
type ResponseBody = http_body_util::combinators::BoxBody<
    Bytes,
    Box<dyn std::error::Error + Send + Sync>,
>;

/// Check if header is hop-by-hop (should not be proxied)
///
/// Hop-by-hop headers (RFC 7230 Section 6.1) describe one transport-level connection,
/// so a proxy must not copy them to the next hop. The fixed set checked here is:
/// - `Connection`, `Keep-Alive`, `Upgrade`
/// - `Transfer-Encoding` (framing is re-done by hyper on our side), `TE`, `Trailer`
/// - the proxy credential headers `Proxy-Authenticate` / `Proxy-Authorization`
/// - the non-standard but widely sent `Proxy-Connection`
///
/// Known-name constants are compared first (no allocation). The two names without a
/// `hyper::header` constant are compared via `HeaderName::as_str()`, which is always
/// lowercase.
///
/// NOTE: RFC 7230 also makes every header *listed inside* a `Connection` value
/// hop-by-hop; that dynamic set is not handled by this name-only check.
fn is_hop_header(name: &header::HeaderName) -> bool {
    matches!(
        name,
        &header::CONNECTION
            | &header::UPGRADE
            | &header::TE
            | &header::TRAILER
            | &header::TRANSFER_ENCODING
            | &header::PROXY_AUTHENTICATE
            | &header::PROXY_AUTHORIZATION
    ) || matches!(name.as_str(), "keep-alive" | "proxy-connection")
}

/// Process directives in order, applying modifications and returning final action
/// Supports recursive handling of handle_path blocks
pub fn process_directives(
    directives: &[crate::config::Directive],
    req: &mut Request<Incoming>,
    current_path: &str,
) -> Result<ActionResult, String> {
    let mut modified_path = current_path.to_string();

    for directive in directives {
        match directive {
            // Apply header modifications using directive handler

            // Apply header modifications using directive handler
            crate::config::Directive::Header { name, value } => {
                if let Err(e) = handle_header(name, value.as_deref(), req) {
                    info!("   Failed to apply header {}: {}", name, e);
                }
            }

            // Apply URI replacements using directive handler
            crate::config::Directive::UriReplace { find, replace } => {
                handle_uri_replace(find, replace, &mut modified_path);
            }

            // Strip prefix from URI path
            crate::config::Directive::StripPrefix { prefix } => {
                handle_strip_prefix(prefix, &mut modified_path);
            }

            // Handle path-based routing recursively
            crate::config::Directive::HandlePath {
                pattern,
                directives: nested_directives,
            } => {
                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
                    info!("   Matched handle_path: {}", pattern);
                    // Recursively process nested directives with remaining path
                    return process_directives(nested_directives, req, &remaining_path);
                }
            }

            // Method-based directives
            crate::config::Directive::Method {
                methods,
                directives: nested_directives,
            } => {
                if handle_method(methods, req) {
                    info!("   Matched method directive");
                    // Process nested directives with same path
                    return process_directives(nested_directives, req, &modified_path);
                }
            }

            // Redirect - return redirect response with Location header
            crate::config::Directive::Redirect { status, url } => {
                return Ok(handle_redirect(status, url));
            }

            // Direct response - return immediately using directive handler
            crate::config::Directive::Respond { status, body } => {
                return Ok(handle_respond(status, body));
            }

            // Reverse proxy - return action using directive handler
            crate::config::Directive::ReverseProxy {
                to,
                connect_timeout,
                read_timeout,
            } => {
                return Ok(handle_reverse_proxy(
                    to,
                    &modified_path,
                    *connect_timeout,
                    *read_timeout,
                ));
            }
        }
    }

    Err(format!(
        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
        current_path
    ))
}

/// Process a single request through the proxy
///
/// This implementation ALWAYS streams backend responses (nginx-style):
/// - No buffering of response body
/// - Direct streaming from backend to client
/// - Works for both SSE and regular HTTP
/// - Optimal performance and memory usage
///
/// For direct responses (Respond directive) and errors, buffering is used
/// since these are small and generated by the proxy itself
pub async fn proxy(
    mut req: Request<Incoming>,
    client: Client<HttpsConnector<HttpConnector>, Incoming>,
    config: Arc<Config>,
    remote_addr: std::net::SocketAddr,
) -> Result<Response<ResponseBody>, Error> {
    // Get path from URI (using String to avoid borrow conflict with mutable req)
    let path = req.uri().path().to_string();

    // Get host from Host header (includes port, e.g., "localhost:8080")
    let host = req
        .headers()
        .get(hyper::header::HOST)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("localhost");

    // Logging with enabled check to avoid string formatting when disabled
    if tracing::enabled!(tracing::Level::INFO) {
        // Removed info logging from hot path for performance
        // Use DEBUG level if needed for troubleshooting
    }

    // Find site configuration by host (with port!)
    let site_config = match config.sites.get(host) {
        Some(config) => config,
        None => {
            error!("No configuration found for host: {}", host);
            return Ok(error_response(
                StatusCode::NOT_FOUND,
                &format!("No configuration found for host: {}", host),
            ));
        }
    };

    // Process directives in correct order
    let action_result =
        process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;

    // Execute action
    match action_result {
        ActionResult::Redirect { status, url } => {
            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
            let boxed: ResponseBody = Full::new(Bytes::from(url.clone()))
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                .boxed();
            Ok(Response::builder()
                .status(status_code)
                .header("Location", &url)
                .body(boxed)?)
        }
        ActionResult::Respond { status, body } => {
            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
            let boxed: ResponseBody = Full::new(Bytes::from(body))
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                .boxed();
            Ok(Response::builder().status(status_code).body(boxed)?)
        }
        ActionResult::ReverseProxy {
            backend_url,
            path_to_send,
            connect_timeout: _,
            read_timeout,
        } => {
            // Add protocol if missing
            let backend_with_proto =
                if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
                    backend_url
                } else {
                    format!("http://{}", backend_url)
                };

            // Use Uri::from_parts() instead of format!() + parse() - faster!
            let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
            parts.path_and_query = Some(path_to_send.parse()?);
            let new_uri = Uri::from_parts(parts)?;

            // Logging with enabled check to avoid string formatting when disabled
            if tracing::enabled!(tracing::Level::INFO) {
                // Removed info logging from hot path for performance
                // Use DEBUG level if needed for troubleshooting
            }

            *req.uri_mut() = new_uri.clone();

            // Save original host for X-Forwarded headers
            // Clone HeaderValue directly - 0 allocations!
            let original_host_header = req.headers().get(hyper::header::HOST).cloned();

            // Update Host header for backend
            req.headers_mut().remove(hyper::header::HOST);
            if let Some(authority) = new_uri.authority() {
                if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
                    req.headers_mut().insert(hyper::header::HOST, host_value);
                }
            }

            // Add X-Forwarded-* headers for backend visibility
            // X-Forwarded-Host: original Host header from client
            if let Some(host_value) = original_host_header.clone() {
                req.headers_mut().insert("X-Forwarded-Host", host_value);
            }

            // X-Forwarded-Proto: scheme from original request (http or https)
            let original_scheme = req.uri().scheme_str().unwrap_or("http");
            // Use from_static for known values - 0 allocations!
            match original_scheme {
                "http" => {
                    req.headers_mut().insert(
                        "X-Forwarded-Proto",
                        hyper::header::HeaderValue::from_static("http"),
                    );
                }
                "https" => {
                    req.headers_mut().insert(
                        "X-Forwarded-Proto",
                        hyper::header::HeaderValue::from_static("https"),
                    );
                }
                _ => {} // ignore unknown schemes
            }

            // X-Forwarded-For: real client IP from TCP connection
            if let Ok(ip_value) =
                hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
            {
                req.headers_mut().insert("X-Forwarded-For", ip_value);
            }

            // Remove hop-by-hop headers from request before sending to backend
            // Connection header must not be proxied (hyper manages connections)
            req.headers_mut().remove(header::CONNECTION);

            // Remove Accept-Encoding to prevent compression
            // Compression breaks streaming and SSE
            req.headers_mut().remove("accept-encoding");

            // Forward request to backend with configurable timeout (default 30s)
            let backend_timeout = read_timeout.unwrap_or(30);
            match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
                Ok(Ok(response)) => {
                    // Successfully received response from backend
                    let status = response.status();
                    let headers = response.headers().clone();

                    // Logging with enabled check to avoid string formatting when disabled
                    if tracing::enabled!(tracing::Level::INFO) {
                        // Removed info logging from hot path for performance
                        // Use DEBUG level if needed for troubleshooting
                    }

                    // Stream response body directly (no buffering)
                    let mut builder = Response::builder().status(status);

                    // Copy all headers from backend, filtering out hop-by-hop headers
                    // Hop-by-hop headers should not be proxied per RFC 7230
                    // Also remove Content-Length to let hyper handle chunked encoding
                    for (name, value) in headers.iter() {
                        if !is_hop_header(name) && name != header::CONTENT_LENGTH {
                            builder = builder.header(name, value);
                        }
                    }

                    // Extract streaming body and convert to BoxBody
                    let (_, incoming_body) = response.into_parts();
                    let boxed: ResponseBody = incoming_body
                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                        .boxed();

                    Ok(builder.body(boxed)?)
                }
                Ok(Err(e)) => {
                    // Backend unavailable - return 502 Bad Gateway
                    error!("Backend connection failed: {:?}", e);

                    if e.is_connect() {
                        error!("   Reason: Connection refused - backend unavailable");
                    } else {
                        error!("   Reason: Other connection error");
                    }

                    Ok(error_response(
                        StatusCode::BAD_GATEWAY,
                        "Backend service unavailable",
                    ))
                }
                Err(_) => {
                    // Timeout - return 504 Gateway Timeout
                    error!(
                        "Backend request timed out after {} seconds",
                        backend_timeout
                    );

                    Ok(error_response(
                        StatusCode::GATEWAY_TIMEOUT,
                        "Backend request timed out",
                    ))
                }
            }
        }
    }
}

/// Creates HTTP response with error
///
/// Builds a small `text/html; charset=utf-8` page. The title and heading show the
/// numeric status and its canonical reason phrase (falling back to `"Error"`), and the
/// page contains `message`.
///
/// `message` is HTML-escaped because callers may pass request-derived text, such as the
/// client's Host header. Unescaped, that text would be reflected into the page.
fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
    let code = status.as_u16();
    let reason = status.canonical_reason().unwrap_or("Error");
    let body = format!(
        r#"<!DOCTYPE html>
        <html>
        <head><title>{} {}</title></head>
        <body>
        <h1>{} {}</h1>
        <p>{}</p>
        <hr>
        <p><em>Rust Proxy Server</em></p>
        </body>
        </html>"#,
        code,
        reason,
        code,
        reason,
        escape_html(message)
    );

    let full = Full::new(Bytes::from(body));
    let boxed: ResponseBody = full
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
        .boxed();

    Response::builder()
        .status(status)
        .header("Content-Type", "text/html; charset=utf-8")
        .body(boxed)
        .expect("a StatusCode and a static, valid Content-Type header always build")
}

/// Escapes the HTML-significant characters so `text` is inert inside element content.
fn escape_html(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            _ => escaped.push(ch),
        }
    }
    escaped
}

/// Match path against pattern (supports wildcard *)
///
/// There are two pattern forms:
/// - `"<prefix>/*"` matches `<prefix>` itself and anything below it *on a path-segment
///   boundary*. So `/api/*` matches `/api`, `/api/` and `/api/users`, but not
///   `/apiary`. `"/*"` (empty prefix) matches every path.
/// - Any other pattern must equal `path` exactly.
///
/// Returns `Some(remaining_path)` on a match and `None` otherwise. `remaining_path` is
/// the part of `path` after the prefix, or `"/"` when nothing remains (including exact
/// matches), so nested routing always receives a non-empty path.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    if let Some(prefix) = pattern.strip_suffix("/*") {
        let remaining = path.strip_prefix(prefix)?;
        // Reject partial-segment matches such as "/api/*" against "/apiary"
        let on_segment_boundary =
            prefix.is_empty() || remaining.is_empty() || remaining.starts_with('/');
        if !on_segment_boundary {
            return None;
        }
        if remaining.is_empty() {
            Some("/".to_string())
        } else {
            Some(remaining.to_string())
        }
    } else if pattern == path {
        Some("/".to_string()) // Exact match, send root
    } else {
        None
    }
}