use std::sync::atomic::{AtomicU64, Ordering};
/// Compares two byte slices for equality without stopping at the first mismatch.
///
/// Slices of different length are rejected immediately, so only the *length* is
/// observable through timing. For equal lengths, every byte pair is XOR-ed and OR-ed
/// into one accumulator, and the result is checked once at the end. Note: this is a
/// source-level property; the compiler is not formally barred from optimizing it.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    } else {
        false
    }
}
/// Returns a fresh token: a newly generated version-4 UUID rendered with
/// `to_string`, i.e. the standard hyphenated 36-character form.
#[must_use]
pub fn generate_token() -> String {
    let id = uuid::Uuid::new_v4();
    id.to_string()
}
/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full (`max_tokens` tokens). Each successful
/// [`RateLimiter::try_acquire`] removes one token, and elapsed time is converted back
/// into tokens at `refill_rate_per_sec` per second, never exceeding `max_tokens`.
/// All mutable state lives in atomics, so a shared `&RateLimiter` (e.g. behind an
/// `Arc`) can be used from several threads at once.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; refills never push `tokens` above this.
    max_tokens: u64,
    /// Nanoseconds after `epoch` up to which elapsed time has been turned into tokens.
    last_refill_ns: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Reference instant that `last_refill_ns` is measured from.
    epoch: std::time::Instant,
}

impl RateLimiter {
    /// Nanoseconds per second, as `u128` so token arithmetic cannot overflow.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a limiter that allows bursts of up to `max_requests_per_sec` requests
    /// and refills at that same number of tokens per second. The bucket starts full.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ns: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Credits tokens for elapsed time, then tries to take one.
    ///
    /// Returns `true` if a token was taken, `false` if the bucket was empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` returns `None` at zero, which makes `fetch_update` give up
        // instead of underflowing; otherwise it retries until its CAS succeeds.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Bucket capacity (equal to the value passed to [`RateLimiter::new`]).
    #[must_use]
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens
    }

    /// Tokens available right now, without applying any pending refill.
    #[must_use]
    pub fn current_tokens(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed)
    }

    /// Nanoseconds since `epoch`, clamped to `u64::MAX` (about 584 years).
    fn elapsed_ns(&self) -> u64 {
        // Clamped before the cast, so the `as` conversion cannot truncate.
        self.epoch.elapsed().as_nanos().min(u128::from(u64::MAX)) as u64
    }

    /// Converts time elapsed since the last refill into whole tokens.
    ///
    /// Only the time actually "spent" on whole tokens is consumed. The fractional
    /// remainder stays on the clock for the next call, so the long-run rate matches
    /// `refill_rate_per_sec` regardless of how often this runs.
    fn refill(&self) {
        let rate = u128::from(self.refill_rate_per_sec);
        if rate == 0 {
            // A zero rate never produces tokens; this also guards the division below.
            return;
        }
        let now = self.elapsed_ns();
        let last = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed = u128::from(now.saturating_sub(last));
        // Both factors fit in u64, so the u128 product cannot overflow.
        let scaled = elapsed * rate;
        let add = scaled / Self::NANOS_PER_SEC;
        if add == 0 {
            return;
        }
        // Elapsed time not yet worth a whole token. Rounding it down means the clock
        // advances by at least the cost of the credited tokens (never over-credits).
        // It is < 1e9 and <= `elapsed`, so the cast is lossless and `now - leftover`
        // cannot underflow.
        let leftover = ((scaled % Self::NANOS_PER_SEC) / rate) as u64;
        let new_last = now - leftover;
        if self
            .last_refill_ns
            .compare_exchange(last, new_last, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another thread moved the clock first; it credits that interval.
            return;
        }
        let add = add.min(u128::from(u64::MAX)) as u64;
        let max = self.max_tokens;
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| {
                Some(t.saturating_add(add).min(max))
            });
    }
}
/// Default request budget per second.
///
/// NOTE(review): presumably passed as `max_requests_per_sec` to
/// [`RateLimiter::new`]; confirm at the call sites.
pub const DEFAULT_RATE_LIMIT: u64 = 1_000;
/// Returns `true` when an HTTP `Host` value names the local machine.
///
/// The host part must be exactly `localhost`, `127.0.0.1`, or `::1`, optionally
/// followed by `:port`, where the port is empty or ASCII digits (RFC 3986
/// `port = *DIGIT`). IPv6 is accepted bracketed (`[::1]`, `[::1]:7373`) or as a bare
/// `::1` without a port. Malformed values are rejected rather than truncated:
/// - an unterminated `[`
/// - text after `]` that is not a port
/// - a non-numeric port
#[must_use]
pub fn is_localhost_host(host: &str) -> bool {
    let (host_name, port) = if let Some(bracketed) = host.strip_prefix('[') {
        // "[addr]" or "[addr]:port"; anything else after ']' is rejected.
        let Some((addr, rest)) = bracketed.split_once(']') else {
            return false;
        };
        if rest.is_empty() {
            (addr, "")
        } else if let Some(port) = rest.strip_prefix(':') {
            (addr, port)
        } else {
            return false;
        }
    } else if host.contains("::") {
        // Unbracketed IPv6 literal: a port cannot be split off unambiguously.
        (host, "")
    } else {
        host.split_once(':').unwrap_or((host, ""))
    };
    port.bytes().all(|b| b.is_ascii_digit())
        && matches!(host_name, "localhost" | "127.0.0.1" | "::1")
}
/// Decides whether a request `Origin` value is trusted.
///
/// Any value beginning with `tauri://` is accepted outright. Otherwise the value must
/// parse as a URL that:
/// - uses `http` or `https`,
/// - carries no username or password,
/// - has a loopback host (`localhost`, `127.0.0.1`, or IPv6 `::1`).
///
/// Unparseable input is rejected.
#[must_use]
pub fn is_allowed_origin(origin: &str) -> bool {
    if origin.starts_with("tauri://") {
        return true;
    }
    match url::Url::parse(origin) {
        Ok(url) => {
            let has_credentials = !url.username().is_empty() || url.password().is_some();
            let web_scheme = matches!(url.scheme(), "http" | "https");
            let loopback_host = matches!(
                url.host_str(),
                Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
            );
            !has_credentials && web_scheme && loopback_host
        }
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- constant_time_eq ----

    #[test]
    fn ct_eq_equal() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }
    #[test]
    fn ct_eq_different() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }
    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }
    #[test]
    fn ct_eq_empty() {
        assert!(constant_time_eq(b"", b""));
    }
    #[test]
    fn ct_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }
    #[test]
    fn ct_eq_single_bit_difference() {
        // 'A' (0x41) and 'C' (0x43) differ only in bit 1 (xor == 0x02).
        assert!(!constant_time_eq(b"A", b"C"));
    }
    #[test]
    fn ct_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));
    }
    #[test]
    fn ct_eq_all_byte_values() {
        // Every byte equals itself and differs from its successor.
        for b in 0..=255u8 {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            if b < 255 {
                assert!(!constant_time_eq(&a, &[b + 1]));
            }
        }
    }

    // ---- generate_token ----

    #[test]
    fn tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        // Hyphenated UUID text form is 36 characters.
        assert_eq!(t1.len(), 36);
    }
    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }
    #[test]
    fn token_uniqueness_over_1000() {
        let mut set = std::collections::HashSet::new();
        for _ in 0..1000 {
            assert!(set.insert(generate_token()), "duplicate token");
        }
    }

    // ---- RateLimiter ----

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiter::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }
    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiter::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }
    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiter::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }
    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiter::new(0);
        assert!(!limiter.try_acquire());
    }
    #[test]
    fn rate_limiter_concurrent() {
        // 10 threads x 200 attempts against a 1000-token bucket. The upper bound
        // leaves headroom for tokens refilled while the threads are running.
        let limiter = std::sync::Arc::new(RateLimiter::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired: u64 = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    // ---- is_localhost_host ----

    #[test]
    fn host_allows_localhost() {
        assert!(is_localhost_host("localhost"));
        assert!(is_localhost_host("localhost:7373"));
    }
    #[test]
    fn host_allows_ipv4() {
        assert!(is_localhost_host("127.0.0.1"));
        assert!(is_localhost_host("127.0.0.1:7373"));
    }
    #[test]
    fn host_allows_ipv6() {
        assert!(is_localhost_host("[::1]"));
        assert!(is_localhost_host("[::1]:7373"));
        assert!(is_localhost_host("::1"));
    }
    #[test]
    fn host_blocks_evil() {
        assert!(!is_localhost_host("evil.com"));
        assert!(!is_localhost_host("localhost.evil.com"));
        assert!(!is_localhost_host("127.0.0.1.evil.com"));
        assert!(!is_localhost_host(""));
    }

    // ---- is_allowed_origin ----

    #[test]
    fn origin_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }
    #[test]
    fn origin_blocks_smuggling() {
        // Look-alike hosts and userinfo tricks must not pass.
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://user:pass@localhost:7373"));
    }
    #[test]
    fn origin_blocks_external() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("null"));
        assert!(!is_allowed_origin("ftp://localhost"));
    }
}