tiny-proxy 0.4.0

A high-performance HTTP reverse proxy server written in Rust with SSE support, connection pooling, and configurable routing
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
use anyhow::Error;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::header;
use hyper::{Request, Response, StatusCode, Uri};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use tracing::{error, info};

#[cfg(feature = "logging")]
use tracing::info_span;
#[cfg(feature = "logging")]
use tracing::Instrument;

use crate::config::{extract_hostname, Config, SiteConfig};
#[cfg(feature = "logging")]
use crate::proxy::access_log::AccessLogGuard;
use crate::proxy::access_log::{ensure_request_id, final_request_id};
use crate::proxy::ActionResult;

use crate::proxy::directives::{
    handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
    handle_strip_prefix, handle_uri_replace,
};

/// Unified response body type used by every response this module returns.
///
/// It is a type-erased `BoxBody` that yields `Bytes` frames and reports failures
/// as a boxed `std::error::Error + Send + Sync`. Both kinds of body produced in
/// this file are converted into it with `.map_err(..).boxed()`:
/// - streamed backend bodies (`Incoming`), which is what keeps SSE working, and
/// - buffered bodies generated by the proxy itself (`Full<Bytes>`).
type ResponseBody =
    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;

/// Check if header is hop-by-hop (should not be proxied)
///
/// Hop-by-hop headers are defined in RFC 7230 Section 6.1. They describe a single
/// transport connection and must NOT be forwarded by an intermediary. The full
/// set is covered here: `Connection`, `Keep-Alive`, `Proxy-Authenticate`,
/// `Proxy-Authorization`, `TE`, `Trailer`, `Transfer-Encoding` and `Upgrade`.
///
/// Comparisons use the `hyper::header` constants where one exists. `Keep-Alive`
/// is compared by name instead of through a constant; `HeaderName::as_str` is
/// always lower-case, so that comparison is effectively case-insensitive and
/// does not allocate.
///
/// Returns `true` when `name` must be dropped before forwarding.
fn is_hop_header(name: &header::HeaderName) -> bool {
    matches!(
        name,
        &header::CONNECTION
            | &header::UPGRADE
            | &header::TE
            | &header::TRAILER
            | &header::TRANSFER_ENCODING
            | &header::PROXY_AUTHENTICATE
            | &header::PROXY_AUTHORIZATION
    ) || name.as_str() == "keep-alive"
}

/// Process directives in order, applying modifications and returning final action.
/// Supports recursive handling of handle_path blocks.
///
/// Note: `info!` logs here are correlated with request ID only when the `logging`
/// feature is enabled (via the tracing span set in `proxy()`). Without `logging`,
/// these logs appear without request context.
pub fn process_directives(
    directives: &[crate::config::Directive],
    req: &mut Request<Incoming>,
    current_path: &str,
) -> Result<ActionResult, String> {
    let mut modified_path = current_path.to_string();

    for directive in directives {
        match directive {
            crate::config::Directive::Header { name, value } => {
                if let Err(e) = handle_header(name, value.as_deref(), req) {
                    info!("   Failed to apply header {}: {}", name, e);
                }
            }

            crate::config::Directive::UriReplace { find, replace } => {
                handle_uri_replace(find, replace, &mut modified_path);
            }

            crate::config::Directive::StripPrefix { prefix } => {
                handle_strip_prefix(prefix, &mut modified_path);
            }

            crate::config::Directive::HandlePath {
                pattern,
                directives: nested_directives,
            } => {
                if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
                    info!("   Matched handle_path: {}", pattern);
                    return process_directives(nested_directives, req, &remaining_path);
                }
            }

            crate::config::Directive::Method {
                methods,
                directives: nested_directives,
            } => {
                if handle_method(methods, req) {
                    info!("   Matched method directive");
                    return process_directives(nested_directives, req, &modified_path);
                }
            }

            crate::config::Directive::Redirect { status, url } => {
                return Ok(handle_redirect(status, url));
            }

            crate::config::Directive::Respond { status, body } => {
                return Ok(handle_respond(status, body));
            }

            crate::config::Directive::ReverseProxy {
                to,
                connect_timeout,
                read_timeout,
            } => {
                return Ok(handle_reverse_proxy(
                    to,
                    &modified_path,
                    *connect_timeout,
                    *read_timeout,
                ));
            }
        }
    }

    Err(format!(
        "No action directive (respond or reverse_proxy) found in configuration for path: {}",
        current_path
    ))
}

/// Process a single request through the proxy
///
/// This implementation ALWAYS streams backend responses (nginx-style):
/// - No buffering of response body
/// - Direct streaming from backend to client
/// - Works for both SSE and regular HTTP
/// - Optimal performance and memory usage
///
/// For direct responses (Respond directive) and errors, buffering is used
/// since these are small and generated by the proxy itself
///
/// # Arguments
///
/// * `req` - incoming client request; it is mutated in place (URI, `Host`,
///   `X-Forwarded-*`) and then forwarded to the backend.
/// * `client` - HTTP(S) client used to reach the backend.
/// * `config` - site configuration, looked up by the request's `Host` header.
/// * `remote_addr` - peer address; its IP is sent as `X-Forwarded-For`.
/// * `is_tls` - whether the client connection is TLS; selects the default port
///   for site lookup and the value of `X-Forwarded-Proto`.
///
/// # Errors
///
/// Unknown hosts, directive failures, backend errors and backend timeouts are
/// all reported as `Ok` responses (404 / 500 / 502 / 504). `Err` is returned
/// only when a `?` below fails: building a response, or parsing the backend
/// URL / rewritten path into a `Uri`.
pub async fn proxy(
    mut req: Request<Incoming>,
    client: Client<HttpsConnector<HttpConnector>, Incoming>,
    config: Arc<Config>,
    remote_addr: std::net::SocketAddr,
    is_tls: bool,
) -> Result<Response<ResponseBody>, Error> {
    // Generate or reuse request ID
    let initial_request_id = ensure_request_id(&mut req);

    // Extract request info before processing
    #[cfg(feature = "logging")]
    let method = req.method().clone().to_string();
    let path = req.uri().path().to_string();
    // Falls back to "localhost" when the Host header is absent or `to_str()` rejects it.
    let host = req
        .headers()
        .get(hyper::header::HOST)
        .and_then(|h| h.to_str().ok())
        .unwrap_or("localhost")
        .to_string();

    #[cfg(feature = "logging")]
    let span = info_span!("request", req_id = %initial_request_id);

    // NOTE(review): the allow presumably covers bindings that are only read when
    // the `logging` feature is enabled — confirm.
    #[allow(unused_variables)]
    let future = async move {
        // Every exit path below calls `log_guard.finish(..)` with the final status.
        #[cfg(feature = "logging")]
        let mut log_guard = AccessLogGuard::new(
            initial_request_id.clone(),
            remote_addr,
            method,
            path.clone(),
            host.clone(),
        );

        // Find site configuration by host
        // Browsers send Host: example.com (no port) for default ports,
        // but config keys may be "example.com:443". Try both.
        let site_config = match find_site(&config, &host, is_tls) {
            Some(config) => config,
            None => {
                error!("No configuration found for host: {}", host);
                let (response, _body_len) = error_response_with_id(
                    StatusCode::NOT_FOUND,
                    &format!("No configuration found for host: {}", host),
                    &initial_request_id,
                );
                #[cfg(feature = "logging")]
                {
                    log_guard.set_bytes_sent(_body_len);
                    log_guard.finish(404);
                }
                return Ok(response);
            }
        };

        // Process directives in correct order
        let action_result = match process_directives(&site_config.directives, &mut req, &path) {
            Ok(result) => result,
            Err(e) => {
                error!("Directive processing error: {}", e);
                // Directives run before the failure may already have replaced the
                // request ID, so re-read it before building the error response.
                let final_id = final_request_id(&req, &initial_request_id);
                #[cfg(feature = "logging")]
                {
                    log_guard.set_request_id(final_id.clone());
                    tracing::Span::current().record("req_id", final_id.as_str());
                }
                let (response, _body_len) =
                    error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
                #[cfg(feature = "logging")]
                {
                    log_guard.set_bytes_sent(_body_len);
                    log_guard.finish(500);
                }
                return Ok(response);
            }
        };

        // Read final request ID (directive may have overwritten X-Request-ID)
        let request_id = final_request_id(&req, &initial_request_id);
        #[cfg(feature = "logging")]
        {
            log_guard.set_request_id(request_id.clone());
            // Update the tracing span with the final req_id
            tracing::Span::current().record("req_id", request_id.as_str());
        }

        // Execute action
        match action_result {
            ActionResult::Redirect { status, url } => {
                // A status that is not a valid HTTP code falls back to 302 Found.
                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);

                // Redirects carry an empty body; only the Location header matters.
                let boxed: ResponseBody = Full::new(Bytes::new())
                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                    .boxed();
                let response = Response::builder()
                    .status(status_code)
                    .header("Location", &url)
                    .header("X-Request-ID", &request_id)
                    .body(boxed)?;
                #[cfg(feature = "logging")]
                {
                    log_guard.set_bytes_sent(0);
                    log_guard.finish(status_code.as_u16());
                }
                Ok(response)
            }
            ActionResult::Respond { status, body } => {
                // A status that is not a valid HTTP code falls back to 200 OK.
                let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
                // Captured before `body` is moved into the response below.
                let _body_len = body.len();

                let boxed: ResponseBody = Full::new(Bytes::from(body))
                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                    .boxed();
                let response = Response::builder()
                    .status(status_code)
                    .header("X-Request-ID", &request_id)
                    .body(boxed)?;
                #[cfg(feature = "logging")]
                {
                    log_guard.set_bytes_sent(_body_len);
                    log_guard.finish(status_code.as_u16());
                }
                Ok(response)
            }
            ActionResult::ReverseProxy {
                backend_url,
                path_to_send,
                // NOTE(review): the configured connect timeout is ignored here; only
                // `read_timeout` bounds the backend request.
                connect_timeout: _,
                read_timeout,
            } => {
                // Add protocol if missing
                let backend_with_proto =
                    if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
                        backend_url
                    } else {
                        format!("http://{}", backend_url)
                    };

                // Use Uri::from_parts() instead of format!() + parse() - faster!
                // Scheme and authority come from the backend URL; the path and query
                // are replaced by the (possibly rewritten) `path_to_send`.
                let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
                parts.path_and_query = Some(path_to_send.parse()?);
                let new_uri = Uri::from_parts(parts)?;

                *req.uri_mut() = new_uri.clone();

                // Save original host for X-Forwarded headers
                let original_host_header = req.headers().get(hyper::header::HOST).cloned();

                // Update Host header for backend
                // If the authority cannot be turned into a header value, the request
                // is sent without a Host header set here.
                req.headers_mut().remove(hyper::header::HOST);
                if let Some(authority) = new_uri.authority() {
                    if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
                    {
                        req.headers_mut().insert(hyper::header::HOST, host_value);
                    }
                }

                // Add X-Forwarded-* headers for backend visibility
                if let Some(host_value) = original_host_header.clone() {
                    req.headers_mut().insert("X-Forwarded-Host", host_value);
                }

                // X-Forwarded-Proto: based on whether the connection is TLS
                req.headers_mut().insert(
                    "X-Forwarded-Proto",
                    hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
                );

                // X-Forwarded-For: real client IP
                // `insert` replaces any X-Forwarded-For sent by the client rather than
                // appending to it.
                if let Ok(ip_value) =
                    hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
                {
                    req.headers_mut().insert("X-Forwarded-For", ip_value);
                }

                // Remove hop-by-hop headers
                // Only `Connection` is stripped from the request here (the response
                // side uses `is_hop_header`). `Accept-Encoding` is dropped as well —
                // presumably so the backend answers uncompressed; TODO confirm intent.
                req.headers_mut().remove(header::CONNECTION);
                req.headers_mut().remove("accept-encoding");

                // Forward request to backend with configurable timeout (default 30s)
                // The timeout wraps only the `client.request(req)` future; the response
                // body is handed to the client afterwards, outside of this timeout.
                let backend_timeout = read_timeout.unwrap_or(30);
                match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
                    // Backend answered: relay status, filtered headers and body.
                    Ok(Ok(response)) => {
                        let status = response.status();
                        let headers = response.headers().clone();

                        // Stream response body directly (no buffering)
                        let mut builder = Response::builder().status(status);

                        // Copy headers, filtering hop-by-hop
                        // Content-Length is dropped too, so the length is not taken
                        // from the backend's header.
                        for (name, value) in headers.iter() {
                            if !is_hop_header(name) && name != header::CONTENT_LENGTH {
                                builder = builder.header(name, value);
                            }
                        }

                        let (_, incoming_body) = response.into_parts();
                        let boxed: ResponseBody = incoming_body
                            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                            .boxed();

                        let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
                        // bytes_sent is not set on this path: the body has not been
                        // read yet at this point.
                        #[cfg(feature = "logging")]
                        log_guard.finish(status.as_u16());
                        Ok(response)
                    }
                    // The client returned an error before a response was available -> 502.
                    Ok(Err(e)) => {
                        error!("Backend connection failed: {:?}", e);
                        if e.is_connect() {
                            error!("   Reason: Connection refused - backend unavailable");
                        } else {
                            error!("   Reason: Other connection error");
                        }

                        let (response, _body_len) = error_response_with_id(
                            StatusCode::BAD_GATEWAY,
                            "Backend service unavailable",
                            &request_id,
                        );
                        #[cfg(feature = "logging")]
                        {
                            log_guard.set_bytes_sent(_body_len);
                            log_guard.finish(502);
                        }
                        Ok(response)
                    }
                    // `timeout` elapsed before the backend request future completed -> 504.
                    Err(_) => {
                        error!(
                            "Backend request timed out after {} seconds",
                            backend_timeout
                        );

                        let (response, _body_len) = error_response_with_id(
                            StatusCode::GATEWAY_TIMEOUT,
                            "Backend request timed out",
                            &request_id,
                        );
                        #[cfg(feature = "logging")]
                        {
                            log_guard.set_bytes_sent(_body_len);
                            log_guard.finish(504);
                        }
                        Ok(response)
                    }
                }
            }
        }
    };

    // With logging enabled, run the whole handler inside the per-request span.
    #[cfg(feature = "logging")]
    let future = future.instrument(span);

    future.await
}

/// Creates HTTP response with error and X-Request-ID header
///
/// Returns both the response and the body length (for access logging).
fn error_response_with_id(
    status: StatusCode,
    message: &str,
    request_id: &str,
) -> (Response<ResponseBody>, usize) {
    let body = format!(
        r#"<!DOCTYPE html>
        <html>
        <head><title>{} {}</title></head>
        <body>
        <h1>{} {}</h1>
        <p>{}</p>
        <hr>
        <p><em>Rust Proxy Server</em></p>
        </body>
        </html>"#,
        status.as_u16(),
        status.canonical_reason().unwrap_or("Error"),
        status.as_u16(),
        status.canonical_reason().unwrap_or("Error"),
        message
    );

    let body_len = body.len();
    let full = Full::new(Bytes::from(body));
    let boxed: ResponseBody = full
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
        .boxed();

    let mut builder = Response::builder()
        .status(status)
        .header("Content-Type", "text/html; charset=utf-8");

    if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
        builder = builder.header("X-Request-ID", val);
    }

    let response = builder.body(boxed).unwrap_or_else(|_| {
        Response::new(
            Full::new(Bytes::from("Internal Server Error"))
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
                .boxed(),
        )
    });

    (response, body_len)
}

/// Match `path` against `pattern` and return the part of the path left over.
///
/// Two pattern forms are supported:
///
/// * `"<prefix>/*"` — matches `<prefix>` itself and anything below it, on a path
///   segment boundary. The remainder after `<prefix>` is returned, e.g.
///   `"/api/*"` vs `"/api/users"` yields `"/users"`, and `"/api/*"` vs `"/api"`
///   yields `""`. A path that merely shares the leading characters
///   (`"/apifoo"`) does NOT match.
/// * anything else — must equal `path` exactly; the remainder is then `"/"`.
///
/// The bare wildcard `"/*"` (empty prefix) matches every path and returns it
/// unchanged.
///
/// Returns `Some(remaining_path)` on a match, `None` otherwise.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let remaining = path.strip_prefix(prefix)?;
            // Only accept the prefix when it ends on a segment boundary, so that
            // "/api/*" does not swallow "/apifoo". An empty prefix matches all.
            let on_boundary =
                prefix.is_empty() || remaining.is_empty() || remaining.starts_with('/');
            if on_boundary {
                Some(remaining.to_string())
            } else {
                None
            }
        }
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}

/// Resolve the site configuration that should serve a request for `host`.
///
/// Clients on default ports omit the port from the `Host` header
/// (`Host: example.com` for both `:443` and `:80`), while configuration keys
/// usually carry it (`"example.com:443"`). The lookup therefore proceeds in
/// stages and returns the first hit:
///
/// 1. Exact key match on `host` as received.
/// 2. `host` has no port: try `host:443` when `is_tls`, otherwise `host:80`.
/// 3. `host` has no port and `is_tls`: scan all TLS-enabled sites and compare
///    the hostname part of their address with `host`, ignoring ASCII case.
///    The site is used only if exactly one matches, which covers TLS listeners
///    on non-standard ports (e.g. `alpha.local:8443`).
/// 4. `host` has a port: retry with the port removed, for configurations keyed
///    by bare hostname. For a bracketed IPv6 literal the brackets are removed
///    as well (`[::1]:8080` is looked up as `::1`).
///
/// Stages 1, 2 and 4 are plain key lookups on `config.sites`; only stage 3
/// compares case-insensitively.
///
/// Returns `None` when no stage produces a site.
pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
    // Stage 1: the header value is itself a configured key.
    if let Some(exact) = config.sites.get(host) {
        return Some(exact);
    }

    // Decide whether `host` carries a port.
    // Bracketed IPv6 (`[::1]:8080`): only a ':' at or after the closing ']' counts,
    // and a missing ']' means "no port". Otherwise any ':' marks a port.
    let bracketed = host.starts_with('[');
    let port_present = if bracketed {
        host.find(']')
            .map_or(false, |close| host[close..].contains(':'))
    } else {
        host.contains(':')
    };

    if port_present {
        // Stage 4: drop the port (and IPv6 brackets) and look up the bare host.
        let bare_host = if bracketed {
            let close = host.find(']').unwrap_or(host.len());
            &host[1..close]
        } else {
            host.split(':').next().unwrap_or(host)
        };
        return config.sites.get(bare_host);
    }

    // Stage 2: append the scheme's default port.
    let fallback_port = if is_tls { 443 } else { 80 };
    let with_default_port = format!("{}:{}", host, fallback_port);
    if let Some(site) = config.sites.get(&with_default_port) {
        return Some(site);
    }

    // Stage 3 applies to TLS connections only.
    if !is_tls {
        return None;
    }

    // Accept a hostname match among TLS sites only when it is unambiguous.
    let mut candidates = config.sites.values().filter(|site| {
        site.tls.is_some() && extract_hostname(&site.address).eq_ignore_ascii_case(host)
    });
    match (candidates.next(), candidates.next()) {
        (Some(only), None) => Some(only),
        _ => None,
    }
}

#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` from `(address, has_tls)` pairs.
    ///
    /// Each address is used both as the map key and as the site's `address`;
    /// TLS-enabled sites get placeholder certificate paths and every site has
    /// an empty directive list.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let sites = sites
            .into_iter()
            .map(|(addr, has_tls)| {
                let tls = has_tls.then(|| crate::config::TlsConfig {
                    cert_path: "/fake/cert.pem".to_string(),
                    key_path: "/fake/key.pem".to_string(),
                });
                let site = crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls,
                };
                (addr.to_string(), site)
            })
            .collect::<HashMap<_, _>>();
        Config { sites }
    }

    /// Host header equal to the config key is found directly.
    #[test]
    fn test_exact_match() {
        let cfg = make_config(vec![("example.com:443", true)]);
        let found = find_site(&cfg, "example.com:443", true);
        assert!(found.is_some());
    }

    /// Browser sends `Host: example.com` (no `:443`) on HTTPS.
    #[test]
    fn test_tls_host_without_port_finds_443() {
        let cfg = make_config(vec![("example.com:443", true)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(
            found.is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    /// Browser sends `Host: example.com` (no `:80`) on HTTP.
    #[test]
    fn test_http_host_without_port_finds_80() {
        let cfg = make_config(vec![("example.com:80", false)]);
        let found = find_site(&cfg, "example.com", false);
        assert!(
            found.is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    /// `Host: example.com` on TLS must NOT fall back to the `:80` site.
    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let cfg = make_config(vec![("example.com:80", false)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(
            found.is_none(),
            "TLS on port 443 should not find :80 site"
        );
    }

    /// Config key has no port, Host header has one: the port is stripped.
    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let cfg = make_config(vec![("example.com", false)]);
        let found = find_site(&cfg, "example.com:8080", false);
        assert!(
            found.is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    /// A single TLS site on a non-standard port is matched by hostname alone.
    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let cfg = make_config(vec![("alpha.local:8443", true)]);
        let found = find_site(&cfg, "alpha.local", true);
        assert!(
            found.is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    /// A host that is not configured at all yields `None`.
    #[test]
    fn test_no_match() {
        let cfg = make_config(vec![("other.com:443", true)]);
        let found = find_site(&cfg, "example.com", true);
        assert!(found.is_none());
    }
}