use rand::Rng;
use std::time::{Duration, Instant};
/// A random alphanumeric marker token.
///
/// Tokens produced by [`Canary::generate`] are 16 characters drawn from the
/// ASCII letters and digits (`[A-Za-z0-9]`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Canary {
    /// The token text.
    pub token: String,
}

impl Canary {
    /// Creates a new canary from the thread-local RNG.
    ///
    /// Each of the 16 characters is chosen independently via `gen_range`
    /// over the 62-symbol alphabet below.
    #[must_use]
    pub fn generate() -> Self {
        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        const TOKEN_LEN: usize = 16;

        let mut source = rand::thread_rng();
        let mut token = String::with_capacity(TOKEN_LEN);
        for _ in 0..TOKEN_LEN {
            let pick = source.gen_range(0..ALPHABET.len());
            token.push(char::from(ALPHABET[pick]));
        }
        Canary { token }
    }
}
/// Retry and backoff settings for a scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScanPolicy {
    /// Delay in milliseconds for attempt 0; doubled for each later attempt.
    pub base_delay_ms: u64,
    /// Cap (ms) applied to the exponential delay *before* jitter is added.
    pub max_delay_ms: u64,
    /// Retry budget. Not read by `backoff_delay`; presumably consumed by the
    /// caller's retry loop — confirm.
    pub max_retries: u32,
    /// When `true`, `backoff_delay` adds a random extra of up to 25% of the
    /// capped delay.
    pub jitter: bool,
    /// Not read by `backoff_delay`; presumably asks callers to open a new
    /// connection per attempt — confirm against callers.
    pub fresh_connection: bool,
}

impl Default for ScanPolicy {
    /// 100 ms base delay, 10 s cap, 3 retries, jitter on, fresh connections on.
    fn default() -> Self {
        ScanPolicy {
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            max_retries: 3,
            jitter: true,
            fresh_connection: true,
        }
    }
}

impl ScanPolicy {
    /// Returns how long to wait before the given 0-based `attempt`.
    ///
    /// The delay is `base_delay_ms * 2^attempt`, clamped to `max_delay_ms`.
    /// The exponent is capped at 63 and the multiplication saturates.
    /// With `jitter` enabled, a random `0..=capped/4` ms is added on top, so
    /// the result may exceed `max_delay_ms` by up to a quarter of it.
    #[must_use]
    pub fn backoff_delay(&self, attempt: u32) -> Duration {
        // Clamp the shift so `1 << shift` can never overflow a u64.
        let growth = 1u64 << attempt.min(63);
        let capped = self
            .base_delay_ms
            .saturating_mul(growth)
            .min(self.max_delay_ms);
        // `then` takes a closure, so the RNG is only touched when jitter is on.
        let extra = self
            .jitter
            .then(|| rand::thread_rng().gen_range(0..=(capped / 4)))
            .unwrap_or(0);
        Duration::from_millis(capped.saturating_add(extra))
    }
}
/// Connection-handling strategy.
///
/// Not referenced by the code in this chunk. The per-variant notes below are
/// inferred from the names only — confirm against callers.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ConnectionPolicy {
    /// Presumably: open a new connection for every request.
    Fresh,
    /// Presumably: keep one connection and reuse it sequentially.
    Reuse,
    /// Presumably: share one connection among concurrent requests.
    Multiplex,
}
/// Phase of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: `can_proceed` returns `true`.
    Closed,
    /// Tripped: `can_proceed` returns `false` until the recovery timeout has
    /// elapsed since the most recent failure.
    Open,
    /// Trial phase entered from `Open`: `can_proceed` returns `true`, and a
    /// recorded success closes the circuit again.
    HalfOpen,
}
/// Consecutive-failure circuit breaker.
///
/// State transitions:
/// - It starts `Closed`.
/// - Reaching `failure_threshold` consecutive failures sets it to `Open`.
/// - While `Open`, `can_proceed` refuses until `recovery_timeout` has elapsed
///   since the most recent failure. The first call after that moves it to
///   `HalfOpen` and returns `true`.
/// - `record_success` always returns it to `Closed`.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Number of consecutive failures that opens the circuit.
    pub failure_threshold: u32,
    /// Quiet period after the last failure before a trial is allowed.
    pub recovery_timeout: Duration,
    /// Current phase.
    pub state: CircuitState,
    /// Failures since the last success (saturating at `u32::MAX`).
    pub consecutive_failures: u32,
    /// When `record_failure` was last called, if ever.
    pub last_failure: Option<Instant>,
}

impl CircuitBreaker {
    /// Creates a `Closed` breaker with no recorded failures.
    ///
    /// `recovery_timeout_ms` is converted to a [`Duration`] in milliseconds.
    #[must_use]
    pub fn new(failure_threshold: u32, recovery_timeout_ms: u64) -> Self {
        let recovery_timeout = Duration::from_millis(recovery_timeout_ms);
        CircuitBreaker {
            failure_threshold,
            recovery_timeout,
            state: CircuitState::Closed,
            consecutive_failures: 0,
            last_failure: None,
        }
    }

    /// Records one failure.
    ///
    /// It timestamps the failure and opens the circuit once the consecutive
    /// count reaches `failure_threshold`. Only `record_success` clears the
    /// counter. So after the usual Open -> HalfOpen path, one more failure
    /// keeps the count at or above the threshold and re-opens the circuit.
    pub fn record_failure(&mut self) {
        self.last_failure = Some(Instant::now());
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        if self.consecutive_failures >= self.failure_threshold {
            self.state = CircuitState::Open;
        }
    }

    /// Records a success: clears the failure count and closes the circuit.
    ///
    /// `last_failure` is left untouched.
    pub fn record_success(&mut self) {
        self.state = CircuitState::Closed;
        self.consecutive_failures = 0;
    }

    /// Reports whether a call may go ahead, possibly moving `Open` to `HalfOpen`.
    ///
    /// `Closed` and `HalfOpen` always allow the call. `Open` allows it only once
    /// `recovery_timeout` has elapsed since `last_failure`; in that case the
    /// state becomes `HalfOpen`. An `Open` breaker with no recorded failure
    /// never allows the call.
    #[must_use]
    pub fn can_proceed(&mut self) -> bool {
        match self.state {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => {
                let cooled_down = self
                    .last_failure
                    .is_some_and(|at| at.elapsed() >= self.recovery_timeout);
                if cooled_down {
                    self.state = CircuitState::HalfOpen;
                }
                cooled_down
            }
        }
    }
}
/// Returns a freshly drawn random `u32` rendered as a plain decimal string.
///
/// The result has no padding, so it is 1 to 10 characters long.
#[must_use]
pub fn cache_buster() -> String {
    let nonce: u32 = rand::thread_rng().gen_range(0..=u32::MAX);
    nonce.to_string()
}
/// Checks `input` for line-break and NUL characters and returns an owned copy.
///
/// Despite the name, nothing is stripped or escaped. The input is either
/// returned unchanged or rejected.
///
/// # Errors
///
/// Returns [`SafetyError::HeaderInjection`] if `input` contains `'\r'`,
/// `'\n'`, or `'\0'`.
pub fn sanitize_input(input: &str) -> Result<String, SafetyError> {
    if input.chars().any(|c| matches!(c, '\r' | '\n' | '\0')) {
        Err(SafetyError::HeaderInjection)
    } else {
        Ok(String::from(input))
    }
}
/// Verifies that `input` contains no carriage return, line feed, or NUL.
///
/// # Errors
///
/// Returns [`SafetyError::HeaderInjection`] if any `'\r'`, `'\n'`, or `'\0'`
/// occurs anywhere in `input`.
pub fn guard_no_crlf(input: &str) -> Result<(), SafetyError> {
    const FORBIDDEN: [char; 3] = ['\r', '\n', '\0'];
    if !input.contains(FORBIDDEN) {
        return Ok(());
    }
    Err(SafetyError::HeaderInjection)
}
/// Verifies that `prefix` is at most `max` bytes long.
///
/// The length is `str::len`, i.e. UTF-8 bytes, not characters. A prefix of
/// exactly `max` bytes is accepted.
///
/// # Errors
///
/// Returns [`SafetyError::PrefixTooLong`] carrying the actual byte length and
/// `max` when the prefix is longer than `max`.
pub fn guard_prefix_len(prefix: &str, max: usize) -> Result<(), SafetyError> {
    let len = prefix.len();
    if len <= max {
        Ok(())
    } else {
        Err(SafetyError::PrefixTooLong { len, max })
    }
}
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum SafetyError {
#[error("input contains CRLF — possible accidental header injection")]
HeaderInjection,
#[error("prefix length {len} exceeds maximum {max}")]
PrefixTooLong { len: usize, max: usize },
}
// Unit tests for the canary, backoff, circuit-breaker and input-guard helpers.
// In addition to the original cases, these also pin:
// - NUL rejection,
// - exact jitter-free backoff values,
// - re-opening on a failure in HalfOpen,
// - the inclusive, byte-based prefix limit.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn canary_unique() {
        let mut set = HashSet::new();
        for _ in 0..100 {
            let c = Canary::generate();
            assert_eq!(c.token.len(), 16);
            assert!(set.insert(c.token));
        }
    }

    #[test]
    fn scan_policy_backoff_monotonic() {
        let policy = ScanPolicy::default();
        let d0 = policy.backoff_delay(0);
        let d1 = policy.backoff_delay(1);
        let d2 = policy.backoff_delay(2);
        assert!(d1 >= d0);
        assert!(d2 >= d1);
        let d_max = policy.backoff_delay(100);
        assert!(d_max <= Duration::from_millis(policy.max_delay_ms + policy.max_delay_ms / 4));
    }

    #[test]
    fn scan_policy_backoff_without_jitter_is_exact() {
        let policy = ScanPolicy { jitter: false, ..ScanPolicy::default() };
        assert_eq!(policy.backoff_delay(0), Duration::from_millis(100));
        assert_eq!(policy.backoff_delay(1), Duration::from_millis(200));
        assert_eq!(policy.backoff_delay(3), Duration::from_millis(800));
        // 100 * 2^7 = 12_800 is clamped to the 10_000 ms cap.
        assert_eq!(policy.backoff_delay(7), Duration::from_millis(10_000));
        // The shift is clamped to 63 and the multiply saturates, so no overflow.
        assert_eq!(policy.backoff_delay(u32::MAX), Duration::from_millis(10_000));
    }

    #[test]
    fn circuit_breaker_cycles() {
        let mut cb = CircuitBreaker::new(2, 10);
        assert!(cb.can_proceed());
        cb.record_failure();
        assert!(cb.can_proceed());
        cb.record_failure();
        assert!(!cb.can_proceed());
        assert_eq!(cb.state, CircuitState::Open);
        std::thread::sleep(Duration::from_millis(15));
        assert!(cb.can_proceed());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    #[test]
    fn circuit_breaker_half_open_failure_reopens() {
        // A zero timeout lets Open -> HalfOpen happen without sleeping.
        let mut cb = CircuitBreaker::new(1, 0);
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(cb.can_proceed());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        // Lengthen the timeout so the re-opened breaker visibly refuses.
        cb.recovery_timeout = Duration::from_secs(60);
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(!cb.can_proceed());
    }

    #[test]
    fn sanitize_rejects_crlf() {
        assert!(sanitize_input("a\r\nb").is_err());
        assert!(sanitize_input("a\nb").is_err());
        assert!(sanitize_input("a\rb").is_err());
        assert!(sanitize_input("safe").is_ok());
    }

    #[test]
    fn guard_no_crlf_rejects_newlines() {
        assert!(guard_no_crlf("a\r\nb").is_err());
        assert!(guard_no_crlf("a\nb").is_err());
        assert!(guard_no_crlf("a\rb").is_err());
        assert!(guard_no_crlf("safe").is_ok());
    }

    #[test]
    fn guards_reject_nul_and_accept_empty() {
        assert_eq!(sanitize_input("a\0b"), Err(SafetyError::HeaderInjection));
        assert_eq!(guard_no_crlf("a\0b"), Err(SafetyError::HeaderInjection));
        assert_eq!(sanitize_input(""), Ok(String::new()));
        assert_eq!(guard_no_crlf(""), Ok(()));
    }

    #[test]
    fn guard_prefix_len_rejects_overflow() {
        assert!(guard_prefix_len(&"A".repeat(100_000), 1000).is_err());
        assert!(guard_prefix_len("short", 1000).is_ok());
    }

    #[test]
    fn guard_prefix_len_boundary_is_inclusive_and_bytewise() {
        assert_eq!(guard_prefix_len("abc", 3), Ok(()));
        assert_eq!(
            guard_prefix_len("abcd", 3),
            Err(SafetyError::PrefixTooLong { len: 4, max: 3 })
        );
        // "é" is one char but two UTF-8 bytes.
        assert_eq!(
            guard_prefix_len("é", 1),
            Err(SafetyError::PrefixTooLong { len: 2, max: 1 })
        );
    }

    #[test]
    fn guard_prefix_len_error_variant() {
        let result = guard_prefix_len(&"A".repeat(100_000), 1000);
        match result {
            Err(SafetyError::PrefixTooLong { len, max }) => {
                assert_eq!(len, 100_000);
                assert_eq!(max, 1000);
            }
            other => panic!("expected PrefixTooLong error, got {other:?}"),
        }
    }

    #[test]
    fn sanitize_input_error_variant() {
        match sanitize_input("a\r\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match sanitize_input("a\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match sanitize_input("safe") {
            Ok(s) => assert_eq!(s, "safe"),
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn guard_no_crlf_error_variant() {
        match guard_no_crlf("a\r\nb") {
            Err(SafetyError::HeaderInjection) => {}
            other => panic!("expected HeaderInjection error, got {other:?}"),
        }
        match guard_no_crlf("safe") {
            Ok(()) => {}
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn cache_buster_changes() {
        let a = cache_buster();
        let b = cache_buster();
        assert!(!a.is_empty());
        assert!(!b.is_empty());
        assert_ne!(a, b);
    }
}