wafrift-smuggling 0.2.15

HTTP request smuggling and HTTP/2 frame-level evasion payloads.
Documentation
//! Safety controls to prevent collateral damage during smuggling scans.

use rand::Rng;
use std::time::{Duration, Instant};

/// Per-request poison canary used to distinguish true smuggling responses
/// from coincidental server variance.
///
/// Two canaries compare equal exactly when their tokens are equal (derived
/// `PartialEq`/`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Canary {
    /// The canary text. [`Canary::generate`] fills this with 16 ASCII
    /// alphanumeric characters; because the field is public, callers may
    /// also build a `Canary` around any string of their own.
    pub token: String,
}

impl Canary {
    /// Build a fresh canary whose token is 16 random ASCII letters and digits.
    ///
    /// Each character is picked independently from the 62-symbol alphabet
    /// `A-Z`, `a-z`, `0-9` using the thread-local RNG.
    #[must_use]
    pub fn generate() -> Self {
        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        const TOKEN_LEN: usize = 16;

        let mut rng = rand::thread_rng();
        // Every symbol is one ASCII byte, so the final size is known up front.
        let mut token = String::with_capacity(TOKEN_LEN);
        for _ in 0..TOKEN_LEN {
            let pick = rng.gen_range(0..ALPHABET.len());
            token.push(char::from(ALPHABET[pick]));
        }
        Self { token }
    }
}

/// Policy that governs safe scanning behavior.
///
/// The delay fields feed [`ScanPolicy::backoff_delay`]; the remaining fields
/// are plain configuration carried alongside them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScanPolicy {
    /// Initial delay between probes (milliseconds). `backoff_delay` doubles
    /// this once per retry attempt.
    pub base_delay_ms: u64,
    /// Maximum delay between probes (milliseconds). `backoff_delay` applies
    /// this cap *before* adding jitter, so a jittered delay can exceed it by
    /// up to a quarter.
    pub max_delay_ms: u64,
    /// Maximum number of retries for a single probe.
    // NOTE(review): not read by any code visible in this module — presumably
    // enforced by the retry loop in the caller; confirm.
    pub max_retries: u32,
    /// Whether to add random jitter to delays. When set, `backoff_delay`
    /// adds a random extra of up to 25% of the capped delay.
    pub jitter: bool,
    /// Whether this probe requires a fresh TCP connection.
    // NOTE(review): not read by any code visible in this module — presumably
    // consumed by the probe sender; confirm.
    pub fresh_connection: bool,
}

impl Default for ScanPolicy {
    fn default() -> Self {
        Self {
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            max_retries: 3,
            jitter: true,
            fresh_connection: true,
        }
    }
}

impl ScanPolicy {
    /// Work out how long to wait before retry number `attempt`.
    ///
    /// The delay is `base_delay_ms * 2^attempt`, saturating on overflow and
    /// capped at `max_delay_ms`. With `jitter` enabled, a random extra between
    /// zero and a quarter of the capped delay is added on top, so the result
    /// may exceed `max_delay_ms` by up to 25%.
    #[must_use]
    pub fn backoff_delay(&self, attempt: u32) -> Duration {
        // Clamp the shift to 63 bits so `1 << n` stays within a u64.
        let multiplier = 1u64 << attempt.min(63);
        let capped_ms = self
            .max_delay_ms
            .min(self.base_delay_ms.saturating_mul(multiplier));

        if !self.jitter {
            return Duration::from_millis(capped_ms);
        }

        let extra_ms = rand::thread_rng().gen_range(0..=capped_ms / 4);
        // Saturate rather than add: with `max_delay_ms` near `u64::MAX` a plain
        // `+` would panic in debug builds and wrap in release builds.
        Duration::from_millis(capped_ms.saturating_add(extra_ms))
    }
}

/// Connection isolation policy for a probe.
///
/// A data-free `Copy` enum.
// NOTE(review): nothing in the visible part of this module reads this type —
// presumably the probe sender maps each variant to a transport behaviour;
// confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionPolicy {
    /// Send on a brand-new connection.
    Fresh,
    /// Reuse an existing connection.
    Reuse,
    /// Multiplex on an HTTP/2 stream.
    Multiplex,
}

/// Circuit-breaker state machine to halt scanning before `DoSing` the target.
///
/// The transitions are driven by [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: `CircuitBreaker::can_proceed` returns `true`.
    Closed,
    /// Tripped: `can_proceed` returns `false` until the recovery timeout has
    /// elapsed since the last recorded failure.
    Open,
    /// Entered from `Open` by `can_proceed` once the recovery timeout has
    /// elapsed. Probes are allowed again; `record_success` moves the breaker
    /// back to `Closed`.
    HalfOpen,
}

/// Simple in-memory circuit breaker.
///
/// All fields are public, so callers can inspect and adjust the breaker
/// directly in addition to using its methods.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Number of consecutive failures at which `record_failure` opens the
    /// circuit.
    pub failure_threshold: u32,
    /// How long after the last failure an open circuit stays closed to
    /// probes before `can_proceed` moves it to `HalfOpen`.
    pub recovery_timeout: Duration,
    /// Current position in the state machine.
    pub state: CircuitState,
    /// Failures recorded since the last success; reset to zero by
    /// `record_success`.
    pub consecutive_failures: u32,
    /// When the most recent failure was recorded, or `None` if
    /// `record_failure` has never been called.
    pub last_failure: Option<Instant>,
}

impl CircuitBreaker {
    /// Create a breaker in the `Closed` state with no recorded failures.
    ///
    /// `failure_threshold` is the number of consecutive failures that opens
    /// the circuit. `recovery_timeout_ms` is how long, in milliseconds, an
    /// open circuit waits after the most recent failure before it admits a
    /// trial probe.
    #[must_use]
    pub fn new(failure_threshold: u32, recovery_timeout_ms: u64) -> Self {
        Self {
            failure_threshold,
            recovery_timeout: Duration::from_millis(recovery_timeout_ms),
            state: CircuitState::Closed,
            consecutive_failures: 0,
            last_failure: None,
        }
    }

    /// Record a probe failure (timeout, 5xx, or connection reset).
    ///
    /// Opens the circuit when the consecutive-failure count reaches
    /// `failure_threshold`, and always when the failure happened during a
    /// `HalfOpen` trial.
    pub fn record_failure(&mut self) {
        // Saturating add: a plain `+= 1` would panic in debug builds and wrap
        // in release builds after `u32::MAX` failures, silently resetting the
        // counter instead of keeping the breaker tripped.
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.last_failure = Some(Instant::now());

        // A failed trial must re-open the circuit on its own merits. Relying
        // only on the counter breaks if `failure_threshold` was raised, or
        // `state` was set by hand, while the breaker was open — both fields
        // are public — and would leave the breaker stuck admitting probes.
        let trial_failed = matches!(self.state, CircuitState::HalfOpen);
        if trial_failed || self.consecutive_failures >= self.failure_threshold {
            self.state = CircuitState::Open;
        }
    }

    /// Record a successful probe: clears the failure count and closes the
    /// circuit, whatever state it was in.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
        self.state = CircuitState::Closed;
    }

    /// Returns `true` if the circuit allows another probe.
    ///
    /// Takes `&mut self` because an `Open` breaker whose recovery timeout has
    /// elapsed since the last failure is moved to `HalfOpen` as a side effect.
    /// An `Open` breaker with no recorded failure time never recovers here.
    #[must_use]
    pub fn can_proceed(&mut self) -> bool {
        match self.state {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => {
                let recovered = self
                    .last_failure
                    .is_some_and(|last| last.elapsed() >= self.recovery_timeout);
                if recovered {
                    self.state = CircuitState::HalfOpen;
                }
                recovered
            }
        }
    }
}

/// Produce a cache-busting query parameter token.
///
/// The token is the decimal rendering of a random `u32` drawn from the
/// thread-local RNG, covering the full range `0..=u32::MAX`.
#[must_use]
pub fn cache_buster() -> String {
    let nonce: u32 = rand::thread_rng().gen_range(0..=u32::MAX);
    nonce.to_string()
}

/// Validate a user-supplied host/path/prefix and hand back an owned copy,
/// so that it cannot cause accidental header injection.
///
/// Input containing CR (`\r`), LF (`\n`), or NUL (`\0`) is refused. The NUL
/// check is load-bearing: many HTTP/1 stacks truncate header values at the
/// first NUL, turning a benign-looking `Host: example.com\0evil.com` into a
/// smuggled second host. The audit caught this as CRITICAL.
///
/// # Errors
/// Returns [`SafetyError::HeaderInjection`] if the input contains `\r`,
/// `\n`, or `\0`.
pub fn sanitize_input(input: &str) -> Result<String, SafetyError> {
    // All three rejected characters are single ASCII bytes, and ASCII bytes
    // never occur inside a multi-byte UTF-8 sequence, so one byte scan is
    // equivalent to searching for each character separately.
    let tainted = input
        .bytes()
        .any(|b| matches!(b, b'\r' | b'\n' | b'\0'));

    if tainted {
        Err(SafetyError::HeaderInjection)
    } else {
        Ok(input.to_owned())
    }
}

/// Check, without copying, that a string is free of CR, LF, and NUL.
///
/// # Errors
/// Returns [`SafetyError::HeaderInjection`] if the input contains `\r`,
/// `\n`, or `\0`.
pub fn guard_no_crlf(input: &str) -> Result<(), SafetyError> {
    let is_forbidden = |c: char| matches!(c, '\r' | '\n' | '\0');
    match input.find(is_forbidden) {
        Some(_) => Err(SafetyError::HeaderInjection),
        None => Ok(()),
    }
}

/// Reject absurdly long prefixes that could exceed proxy buffers.
///
/// The length is measured in bytes (`str::len`), and a prefix of exactly
/// `max` bytes is accepted.
///
/// # Errors
/// Returns [`SafetyError::PrefixTooLong`], carrying the actual length and the
/// limit, if `prefix` is longer than `max` bytes.
pub fn guard_prefix_len(prefix: &str, max: usize) -> Result<(), SafetyError> {
    match prefix.len() {
        len if len > max => Err(SafetyError::PrefixTooLong { len, max }),
        _ => Ok(()),
    }
}

/// Safety-related errors.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum SafetyError {
    #[error("input contains CRLF — possible accidental header injection")]
    HeaderInjection,
    #[error("prefix length {len} exceeds maximum {max}")]
    PrefixTooLong { len: usize, max: usize },
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// 100 generated canaries are each 16 bytes long and pairwise distinct.
    #[test]
    fn canary_unique() {
        let mut set = HashSet::new();
        for _ in 0..100 {
            let c = Canary::generate();
            assert_eq!(c.token.len(), 16);
            assert!(set.insert(c.token));
        }
    }

    /// Delays never shrink across the first retries and stay within the cap
    /// plus the 25% jitter allowance.
    #[test]
    fn scan_policy_backoff_monotonic() {
        let policy = ScanPolicy::default();
        let d0 = policy.backoff_delay(0);
        let d1 = policy.backoff_delay(1);
        let d2 = policy.backoff_delay(2);
        assert!(d1 >= d0);
        assert!(d2 >= d1);
        let d_max = policy.backoff_delay(100);
        assert!(d_max <= Duration::from_millis(policy.max_delay_ms + policy.max_delay_ms / 4));
    }

    /// Closed -> Open -> HalfOpen -> Closed round trip.
    #[test]
    fn circuit_breaker_cycles() {
        let mut cb = CircuitBreaker::new(2, 10);
        assert!(cb.can_proceed());
        cb.record_failure();
        assert!(cb.can_proceed());
        cb.record_failure();
        assert!(!cb.can_proceed());
        assert_eq!(cb.state, CircuitState::Open);
        // Sleep past the 10 ms recovery timeout.
        std::thread::sleep(Duration::from_millis(15));
        assert!(cb.can_proceed());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    /// A failure during the half-open trial trips the breaker again.
    #[test]
    fn circuit_breaker_reopens_after_failed_trial() {
        // Zero recovery timeout: the breaker is eligible for a trial at once,
        // so no sleep is needed.
        let mut cb = CircuitBreaker::new(1, 0);
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(cb.can_proceed());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
    }

    /// Regression: the failure counter saturates instead of overflowing.
    #[test]
    fn circuit_breaker_failure_counter_saturates() {
        let mut cb = CircuitBreaker::new(3, 10);
        cb.consecutive_failures = u32::MAX;
        cb.record_failure();
        assert_eq!(cb.consecutive_failures, u32::MAX);
        assert_eq!(cb.state, CircuitState::Open);
    }

    #[test]
    fn sanitize_rejects_crlf() {
        assert!(sanitize_input("a\r\nb").is_err());
        assert!(sanitize_input("a\nb").is_err());
        assert!(sanitize_input("a\rb").is_err());
        assert!(sanitize_input("safe").is_ok());
    }

    #[test]
    fn guard_no_crlf_rejects_newlines() {
        assert!(guard_no_crlf("a\r\nb").is_err());
        assert!(guard_no_crlf("a\nb").is_err());
        assert!(guard_no_crlf("a\rb").is_err());
        assert!(guard_no_crlf("safe").is_ok());
    }

    /// NUL rejection is documented as load-bearing, so pin it for both guards.
    #[test]
    fn guards_reject_nul() {
        for input in ["a\0b", "\0", "example.com\0evil.com"] {
            assert!(matches!(
                sanitize_input(input),
                Err(SafetyError::HeaderInjection)
            ));
            assert!(matches!(
                guard_no_crlf(input),
                Err(SafetyError::HeaderInjection)
            ));
        }
    }

    #[test]
    fn guard_prefix_len_rejects_overflow() {
        assert!(guard_prefix_len(&"A".repeat(100_000), 1000).is_err());
        assert!(guard_prefix_len("short", 1000).is_ok());
    }

    #[test]
    fn guard_prefix_len_error_variant() {
        let result = guard_prefix_len(&"A".repeat(100_000), 1000);
        match result {
            Err(SafetyError::PrefixTooLong { len, max }) => {
                assert_eq!(len, 100_000);
                assert_eq!(max, 1000);
            }
            other => panic!("expected PrefixTooLong error, got {other:?}"),
        }
    }

    #[test]
    fn sanitize_input_error_variant() {
        match sanitize_input("a\r\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match sanitize_input("a\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match sanitize_input("safe") {
            Ok(s) => assert_eq!(s, "safe"),
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn guard_no_crlf_error_variant() {
        match guard_no_crlf("a\r\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match guard_no_crlf("safe") {
            Ok(()) => {}
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn cache_buster_changes() {
        let a = cache_buster();
        let b = cache_buster();
        assert!(!a.is_empty());
        assert!(!b.is_empty());
        // Very unlikely to collide
        assert_ne!(a, b);
    }
}