use dashmap::DashMap;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::tunnel::TunnelConnection;
/// Per-IP failure bookkeeping kept by [`AuthRateLimiter`].
#[derive(Debug)]
struct AuthAttempt {
    /// Failures counted in the current window.
    failed_count: u32,
    /// Start of the current counting window (the first failure recorded in it).
    first_failure: Instant,
    /// Instant until which the IP is banned, if a ban was imposed in this window;
    /// may lie in the past once the ban has lapsed.
    banned_until: Option<Instant>,
}
/// Tuning knobs for [`AuthRateLimiter`].
#[derive(Debug, Clone)]
pub struct AuthRateLimitConfig {
    /// Failures within `attempt_window` that trigger a ban.
    pub max_failed_attempts: u32,
    /// How long an IP stays banned once the threshold is reached.
    pub ban_duration: Duration,
    /// Span, measured from the first failure, over which failures are counted together.
    pub attempt_window: Duration,
}

impl Default for AuthRateLimitConfig {
    /// Five failures within one minute earn a five-minute ban.
    fn default() -> Self {
        let attempt_window = Duration::from_secs(60);
        let ban_duration = Duration::from_secs(5 * 60);
        AuthRateLimitConfig {
            max_failed_attempts: 5,
            ban_duration,
            attempt_window,
        }
    }
}
/// Tracks failed authentication attempts per client IP and bans IPs that fail too often.
///
/// NOTE(review): entries are removed only by `record_success`; an IP that fails and never
/// authenticates successfully keeps its entry indefinitely, so this map can grow without bound.
pub struct AuthRateLimiter {
    /// Failure state keyed by client IP string.
    attempts: DashMap<String, AuthAttempt>,
    /// Thresholds and durations applied to every IP.
    config: AuthRateLimitConfig,
}
impl AuthRateLimiter {
    /// Creates a limiter with no recorded failures.
    pub fn new(config: AuthRateLimitConfig) -> Self {
        Self {
            attempts: DashMap::new(),
            config,
        }
    }

    /// Returns `true` if `ip` is under a ban that has not yet lapsed.
    pub fn is_banned(&self, ip: &str) -> bool {
        self.ban_remaining(ip).is_some()
    }

    /// Records a failed authentication attempt from `ip`.
    ///
    /// The IP is banned for `ban_duration` once its count reaches `max_failed_attempts`
    /// within the current window. The count restarts when more than `attempt_window` has
    /// passed since the window's first failure — but never while a ban is active, so a
    /// further failure can no longer lift a ban early. The threshold is checked after every
    /// increment (including the first), so `max_failed_attempts == 1` bans immediately.
    pub fn record_failure(&self, ip: &str) {
        let now = Instant::now();
        let mut attempt = self
            .attempts
            .entry(ip.to_string())
            .or_insert_with(|| AuthAttempt {
                failed_count: 0,
                first_failure: now,
                banned_until: None,
            });

        let ban_active = matches!(attempt.banned_until, Some(until) if now < until);
        if !ban_active && now.duration_since(attempt.first_failure) > self.config.attempt_window {
            // Window expired with no ban in force: start counting afresh.
            attempt.failed_count = 0;
            attempt.first_failure = now;
            attempt.banned_until = None;
        }

        attempt.failed_count = attempt.failed_count.saturating_add(1);
        if attempt.failed_count >= self.config.max_failed_attempts {
            attempt.banned_until = Some(now + self.config.ban_duration);
        }
    }

    /// Forgets all failure history for `ip` (including any ban).
    pub fn record_success(&self, ip: &str) {
        self.attempts.remove(ip);
    }

    /// Number of failures counted for `ip` in its current window (0 if unknown).
    pub fn failed_attempts(&self, ip: &str) -> u32 {
        self.attempts
            .get(ip)
            .map(|a| a.failed_count)
            .unwrap_or(0)
    }

    /// Time left on `ip`'s ban, or `None` if it is not currently banned.
    pub fn ban_remaining(&self, ip: &str) -> Option<Duration> {
        let until = self.attempts.get(ip)?.banned_until?;
        let now = Instant::now();
        (now < until).then(|| until - now)
    }
}
/// JWT payload produced by [`TokenManager::generate_token`].
#[derive(Debug, Serialize, Deserialize)]
pub struct TunnelClaims {
    /// Subject the token was issued for.
    pub sub: String,
    /// Expiry, in seconds since the UNIX epoch (`iat` plus the manager's validity).
    pub exp: usize,
    /// Issue time, in seconds since the UNIX epoch.
    pub iat: usize,
    /// Random token id; revocation is tracked by this value.
    pub jti: String,
    /// Optional tunnel id supplied by the issuer; omitted from the serialized payload when
    /// `None`. NOTE(review): whether it is enforced against the requested tunnel is not
    /// visible here — confirm at the call site.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tunnel_id: Option<String>,
}
/// Issues, verifies, and revokes secret-keyed JWTs for tunnel clients.
pub struct TokenManager {
    /// Key used to sign new tokens.
    encoding_key: EncodingKey,
    /// Key used to verify presented tokens (built from the same secret as `encoding_key`).
    decoding_key: DecodingKey,
    /// Revoked token ids (`jti`) mapped to the instant they were revoked.
    revoked_tokens: DashMap<String, Instant>,
    /// Lifetime given to newly generated tokens.
    token_validity: Duration,
}
impl TokenManager {
pub fn new(secret: &[u8], token_validity: Duration) -> Self {
Self {
encoding_key: EncodingKey::from_secret(secret),
decoding_key: DecodingKey::from_secret(secret),
revoked_tokens: DashMap::new(),
token_validity,
}
}
pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Result<String, jsonwebtoken::errors::Error> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as usize;
let claims = TunnelClaims {
sub: subject.to_string(),
exp: now + self.token_validity.as_secs() as usize,
iat: now,
jti: nanoid::nanoid!(16),
tunnel_id,
};
encode(&Header::default(), &claims, &self.encoding_key)
}
pub fn validate_token(&self, token: &str) -> Result<TunnelClaims, TokenError> {
let validation = Validation::default();
let token_data = decode::<TunnelClaims>(token, &self.decoding_key, &validation)
.map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => TokenError::Expired,
jsonwebtoken::errors::ErrorKind::InvalidSignature => TokenError::InvalidSignature,
_ => TokenError::Invalid(e.to_string()),
})?;
if self.revoked_tokens.contains_key(&token_data.claims.jti) {
return Err(TokenError::Revoked);
}
Ok(token_data.claims)
}
pub fn revoke_token(&self, jti: &str) {
self.revoked_tokens.insert(jti.to_string(), Instant::now());
}
pub fn cleanup_revocations(&self, max_age: Duration) {
let now = Instant::now();
self.revoked_tokens.retain(|_, revoked_at| {
now.duration_since(*revoked_at) < max_age
});
}
}
/// Reasons a token is rejected by `TokenManager::validate_token`.
#[derive(Debug, Clone)]
pub enum TokenError {
    /// `jsonwebtoken` reported the `exp` claim as passed.
    Expired,
    /// The token's `jti` is on the revocation list.
    Revoked,
    /// The signature did not verify against the configured secret.
    InvalidSignature,
    /// Any other decoding/validation failure, with the underlying message.
    Invalid(String),
}

impl std::fmt::Display for TokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TokenError::Expired => write!(f, "Token has expired"),
            TokenError::Revoked => write!(f, "Token has been revoked"),
            TokenError::InvalidSignature => write!(f, "Invalid token signature"),
            TokenError::Invalid(msg) => write!(f, "Invalid token: {}", msg),
        }
    }
}

/// Lets `TokenError` flow through `?` into `Box<dyn Error>` and similar error sinks.
impl std::error::Error for TokenError {}
/// Runtime configuration of the relay.
#[derive(Debug, Clone)]
pub struct RelayConfig {
    /// Domain appended to tunnel ids when building public URLs (`{id}.{base_domain}`).
    pub base_domain: String,
    /// Socket address for the relay. NOTE(review): not read in this file — presumably the
    /// listener bind address; confirm.
    pub listen_addr: SocketAddr,
    /// NOTE(review): not read in this file — presumably bounds proxied request time; confirm.
    pub request_timeout: Duration,
    /// Maximum number of registered tunnels per source IP (checked by `can_create_tunnel`).
    pub max_tunnels_per_ip: usize,
    /// Whether tunnel URLs use `https` instead of `http`.
    pub use_https: bool,
    /// Static tokens accepted verbatim when auth is required.
    pub auth_tokens: HashSet<String>,
    /// When `false`, authentication checks succeed without a token.
    pub require_auth: bool,
    /// Secret for JWT signing/verification; JWT support is enabled only when `Some`.
    pub jwt_secret: Option<Vec<u8>>,
    /// Lifetime of generated JWTs.
    pub jwt_validity: Duration,
    /// Failed-auth rate-limiting parameters.
    pub auth_rate_limit: AuthRateLimitConfig,
    /// Tunnels older than this are considered expired (`None` = no age limit).
    pub max_tunnel_age: Option<Duration>,
    /// Tunnels idle for at least this long are considered expired (`None` = no idle limit).
    pub max_idle_time: Option<Duration>,
    /// NOTE(review): not read in this file — presumably whether clients may choose their
    /// own tunnel ids; confirm.
    pub allow_custom_ids: bool,
}
impl Default for RelayConfig {
fn default() -> Self {
Self {
base_domain: "localhost:3001".to_string(),
listen_addr: "127.0.0.1:3001".parse().unwrap(),
request_timeout: Duration::from_secs(30),
max_tunnels_per_ip: 10,
use_https: false,
auth_tokens: HashSet::new(),
require_auth: false,
jwt_secret: None,
jwt_validity: Duration::from_secs(3600), auth_rate_limit: AuthRateLimitConfig::default(),
max_tunnel_age: None, max_idle_time: None, allow_custom_ids: true, }
}
}
/// Outcome of `RelayState::validate_auth`.
#[derive(Debug)]
pub enum AuthResult {
    /// A static token from `auth_tokens` matched.
    Success,
    /// A JWT was verified; carries its decoded claims.
    SuccessWithClaims(TunnelClaims),
    /// Authentication is disabled (`require_auth == false`).
    NotRequired,
    /// The client IP is currently banned; `remaining` is the time left on the ban.
    Banned { remaining: Duration },
    /// Authentication failed; carries a human-readable reason.
    Invalid(String),
}
impl AuthResult {
    /// Whether the client may proceed: a success variant, or authentication not required.
    ///
    /// Written as an exhaustive `match` so that adding a variant forces a decision here.
    pub fn is_success(&self) -> bool {
        match self {
            AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired => {
                true
            }
            AuthResult::Banned { .. } | AuthResult::Invalid(_) => false,
        }
    }
}
/// Shared runtime state of the relay: live tunnels, per-IP counters, and auth machinery.
pub struct RelayState {
    /// Live tunnels keyed by tunnel id.
    pub tunnels: DashMap<String, Arc<TunnelConnection>>,
    /// Registered-tunnel count per source IP; entries are removed when they reach zero.
    pub tunnels_per_ip: DashMap<String, usize>,
    /// Configuration this state was built from.
    pub config: RelayConfig,
    /// Failed-authentication tracker consulted by `validate_auth`.
    pub auth_rate_limiter: AuthRateLimiter,
    /// JWT support; `Some` only when `config.jwt_secret` was set.
    pub token_manager: Option<TokenManager>,
}
impl RelayState {
    /// Builds the relay state from `config`.
    ///
    /// The rate limiter is configured from `config.auth_rate_limit`; a [`TokenManager`] is
    /// created only when `config.jwt_secret` is set.
    pub fn new(config: RelayConfig) -> Self {
        let auth_rate_limiter = AuthRateLimiter::new(config.auth_rate_limit.clone());
        let token_manager = config
            .jwt_secret
            .as_ref()
            .map(|secret| TokenManager::new(secret, config.jwt_validity));
        Self {
            tunnels: DashMap::new(),
            tunnels_per_ip: DashMap::new(),
            auth_rate_limiter,
            token_manager,
            config,
        }
    }

    /// Returns `true` if `token` would be accepted.
    ///
    /// Always `true` when `require_auth` is off. Otherwise the token must be non-empty and be
    /// either one of the static `auth_tokens` or a JWT accepted by the token manager. Unlike
    /// [`RelayState::validate_auth`], this neither checks nor updates the rate limiter.
    pub fn validate_token(&self, token: Option<&str>) -> bool {
        if !self.config.require_auth {
            return true;
        }
        match token {
            Some(t) if !t.is_empty() => {
                if self.config.auth_tokens.contains(t) {
                    return true;
                }
                if let Some(ref tm) = self.token_manager {
                    return tm.validate_token(t).is_ok();
                }
                false
            }
            _ => false,
        }
    }

    /// Authenticates a client at `ip`, applying the failed-attempt rate limiter.
    ///
    /// Checks, in order:
    /// 1. an active ban for `ip` yields [`AuthResult::Banned`] (even when auth is off);
    /// 2. with `require_auth` off, [`AuthResult::NotRequired`];
    /// 3. a static token yields [`AuthResult::Success`], a valid JWT yields
    ///    [`AuthResult::SuccessWithClaims`] — either clears the IP's failure record;
    /// 4. anything else records a failure and yields [`AuthResult::Invalid`].
    pub fn validate_auth(&self, ip: &str, token: Option<&str>) -> AuthResult {
        if let Some(remaining) = self.auth_rate_limiter.ban_remaining(ip) {
            return AuthResult::Banned { remaining };
        }
        if !self.config.require_auth {
            return AuthResult::NotRequired;
        }
        match token {
            Some(t) if !t.is_empty() => {
                if self.config.auth_tokens.contains(t) {
                    self.auth_rate_limiter.record_success(ip);
                    return AuthResult::Success;
                }
                if let Some(ref tm) = self.token_manager {
                    return match tm.validate_token(t) {
                        Ok(claims) => {
                            self.auth_rate_limiter.record_success(ip);
                            AuthResult::SuccessWithClaims(claims)
                        }
                        Err(e) => {
                            self.auth_rate_limiter.record_failure(ip);
                            AuthResult::Invalid(e.to_string())
                        }
                    };
                }
                self.auth_rate_limiter.record_failure(ip);
                AuthResult::Invalid("Invalid token".to_string())
            }
            _ => {
                self.auth_rate_limiter.record_failure(ip);
                AuthResult::Invalid("Missing token".to_string())
            }
        }
    }

    /// Issues a JWT for `subject`; `None` if JWT support is disabled or encoding fails.
    pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Option<String> {
        self.token_manager
            .as_ref()
            .and_then(|tm| tm.generate_token(subject, tunnel_id).ok())
    }

    /// Revokes the JWT with id `jti`; a no-op when JWT support is disabled.
    pub fn revoke_token(&self, jti: &str) {
        if let Some(ref tm) = self.token_manager {
            tm.revoke_token(jti);
        }
    }

    /// Whether `ip` is below its `max_tunnels_per_ip` quota.
    ///
    /// NOTE: this check is not atomic with [`RelayState::register_tunnel`]; two registrations
    /// racing for the same IP can both pass it.
    pub fn can_create_tunnel(&self, ip: &str) -> bool {
        self.tunnel_count_for_ip(ip) < self.config.max_tunnels_per_ip
    }

    /// Adds one to `ip`'s tunnel counter, creating it at 1 if absent.
    fn increment_ip_count(&self, ip: &str) {
        self.tunnels_per_ip
            .entry(ip.to_string())
            .and_modify(|c| *c += 1)
            .or_insert(1);
    }

    /// Subtracts one from `ip`'s tunnel counter and drops the entry once it reaches zero.
    fn decrement_ip_count(&self, ip: &str) {
        if let Some(mut count) = self.tunnels_per_ip.get_mut(ip) {
            *count = (*count).saturating_sub(1);
        }
        // Check-and-remove atomically under the shard lock: if another task re-incremented
        // the counter after the guard above was released, the entry is kept.
        self.tunnels_per_ip.remove_if(ip, |_, count| *count == 0);
    }

    /// Stores `tunnel` under its id and counts it against its source IP.
    ///
    /// If a tunnel with the same id was already registered it is replaced, and the displaced
    /// tunnel's source-IP counter is released so quotas stay consistent with `tunnels`.
    pub fn register_tunnel(&self, tunnel: TunnelConnection) -> Arc<TunnelConnection> {
        let tunnel = Arc::new(tunnel);
        let displaced = self
            .tunnels
            .insert(tunnel.tunnel_id.clone(), Arc::clone(&tunnel));
        self.increment_ip_count(&tunnel.source_ip);
        if let Some(old) = displaced {
            self.decrement_ip_count(&old.source_ip);
        }
        tunnel
    }

    /// Removes the tunnel registered under `tunnel_id`, releasing its per-IP slot.
    pub fn remove_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
        let (_, tunnel) = self.tunnels.remove(tunnel_id)?;
        self.decrement_ip_count(&tunnel.source_ip);
        Some(tunnel)
    }

    /// Returns a handle to the tunnel registered under `tunnel_id`, if any.
    pub fn get_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
        self.tunnels.get(tunnel_id).map(|r| Arc::clone(r.value()))
    }

    /// Public URL for `tunnel_id`: `{scheme}://{tunnel_id}.{base_domain}`.
    pub fn tunnel_url(&self, tunnel_id: &str) -> String {
        let scheme = if self.config.use_https { "https" } else { "http" };
        format!("{}://{}.{}", scheme, tunnel_id, self.config.base_domain)
    }

    /// Number of tunnels currently counted against `ip`.
    pub fn tunnel_count_for_ip(&self, ip: &str) -> usize {
        self.tunnels_per_ip.get(ip).map(|r| *r).unwrap_or(0)
    }

    /// Whether `tunnel` exceeds `max_tunnel_age` or has been idle for `max_idle_time`.
    ///
    /// Negative elapsed times (clock skew) are treated as zero. Awaits the tunnel's
    /// `last_activity` read lock only when an idle limit is configured.
    pub async fn is_tunnel_expired(&self, tunnel: &crate::tunnel::TunnelConnection) -> bool {
        let now = chrono::Utc::now();
        if let Some(max_age) = self.config.max_tunnel_age {
            let age = now.signed_duration_since(tunnel.created_at);
            if age.to_std().unwrap_or(Duration::ZERO) >= max_age {
                return true;
            }
        }
        if let Some(max_idle) = self.config.max_idle_time {
            let last_activity = *tunnel.last_activity.read().await;
            let idle_time = now.signed_duration_since(last_activity);
            if idle_time.to_std().unwrap_or(Duration::ZERO) >= max_idle {
                return true;
            }
        }
        false
    }

    /// Removes every expired tunnel and returns how many were actually removed.
    pub async fn cleanup_expired_tunnels(&self) -> usize {
        // Snapshot first: `is_tunnel_expired` awaits a lock, and the DashMap shard guards held
        // by `iter()` must not live across an `.await` — a writer blocking on the same shard
        // would stall its executor thread and can deadlock the runtime.
        let snapshot: Vec<(String, Arc<TunnelConnection>)> = self
            .tunnels
            .iter()
            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
            .collect();

        let mut removed = 0;
        for (tunnel_id, tunnel) in snapshot {
            if !self.is_tunnel_expired(&tunnel).await {
                continue;
            }
            // Remove only the exact instance judged expired: since the snapshot the id may
            // have been removed, or re-registered for a fresh tunnel.
            let evicted = self
                .tunnels
                .remove_if(&tunnel_id, |_, current| Arc::ptr_eq(current, &tunnel));
            if let Some((_, evicted)) = evicted {
                self.decrement_ip_count(&evicted.source_ip);
                removed += 1;
                tracing::info!(
                    tunnel_id = %tunnel_id,
                    source_ip = %evicted.source_ip,
                    age_secs = (chrono::Utc::now() - evicted.created_at).num_seconds(),
                    "Tunnel expired and removed"
                );
            }
        }
        removed
    }

    /// Total number of registered tunnels.
    pub fn tunnel_count(&self) -> usize {
        self.tunnels.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Config with static tokens `token1`/`token2`, auth required, and a per-IP limit of 3.
    fn create_test_config() -> RelayConfig {
        RelayConfig {
            base_domain: "test.example.com".to_string(),
            listen_addr: "127.0.0.1:3001".parse().unwrap(),
            request_timeout: Duration::from_secs(30),
            max_tunnels_per_ip: 3,
            use_https: false,
            auth_tokens: ["token1".to_string(), "token2".to_string()].into_iter().collect(),
            require_auth: true,
            jwt_secret: None,
            jwt_validity: Duration::from_secs(3600),
            auth_rate_limit: AuthRateLimitConfig::default(),
            max_tunnel_age: None,
            max_idle_time: None,
            allow_custom_ids: true,
        }
    }

    /// Like [`create_test_config`] but JWT-only: a secret is set and no static tokens exist.
    fn create_jwt_config() -> RelayConfig {
        RelayConfig {
            jwt_secret: Some(b"test-secret-key-for-jwt-testing".to_vec()),
            jwt_validity: Duration::from_secs(3600),
            require_auth: true,
            auth_tokens: HashSet::new(),
            ..create_test_config()
        }
    }

    /// Builds an unregistered tunnel; the receiving half of its channel is dropped.
    fn create_test_tunnel(tunnel_id: &str, source_ip: &str) -> crate::tunnel::TunnelConnection {
        let (tx, _rx) = mpsc::channel(10);
        crate::tunnel::TunnelConnection::new(
            tunnel_id.to_string(),
            tx,
            source_ip.to_string(),
        )
    }

    #[test]
    fn test_validate_token_valid() {
        let state = RelayState::new(create_test_config());
        assert!(state.validate_token(Some("token1")));
        assert!(state.validate_token(Some("token2")));
    }

    #[test]
    fn test_validate_token_invalid() {
        let state = RelayState::new(create_test_config());
        assert!(!state.validate_token(Some("invalid-token")));
        assert!(!state.validate_token(Some("")));
        assert!(!state.validate_token(None));
    }

    #[test]
    fn test_validate_token_auth_not_required() {
        let mut config = create_test_config();
        config.require_auth = false;
        let state = RelayState::new(config);
        assert!(state.validate_token(None));
        assert!(state.validate_token(Some("")));
        assert!(state.validate_token(Some("random")));
    }

    #[test]
    fn test_can_create_tunnel_under_limit() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";
        assert!(state.can_create_tunnel(ip));
        assert_eq!(state.tunnel_count_for_ip(ip), 0);
    }

    #[test]
    fn test_rate_limiting_enforced() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";
        let t1 = create_test_tunnel("tunnel1", ip);
        let t2 = create_test_tunnel("tunnel2", ip);
        let t3 = create_test_tunnel("tunnel3", ip);
        state.register_tunnel(t1);
        assert_eq!(state.tunnel_count_for_ip(ip), 1);
        assert!(state.can_create_tunnel(ip));
        state.register_tunnel(t2);
        assert_eq!(state.tunnel_count_for_ip(ip), 2);
        assert!(state.can_create_tunnel(ip));
        state.register_tunnel(t3);
        assert_eq!(state.tunnel_count_for_ip(ip), 3);
        assert!(!state.can_create_tunnel(ip));
    }

    #[test]
    fn test_rate_limiting_per_ip() {
        let state = RelayState::new(create_test_config());
        let ip1 = "192.168.1.1";
        let ip2 = "192.168.1.2";
        for i in 0..3 {
            let t = create_test_tunnel(&format!("tunnel-ip1-{}", i), ip1);
            state.register_tunnel(t);
        }
        assert!(!state.can_create_tunnel(ip1));
        assert!(state.can_create_tunnel(ip2));
        assert_eq!(state.tunnel_count_for_ip(ip2), 0);
    }

    #[test]
    fn test_rate_limiting_released_on_disconnect() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";
        for i in 0..3 {
            let t = create_test_tunnel(&format!("tunnel{}", i), ip);
            state.register_tunnel(t);
        }
        assert!(!state.can_create_tunnel(ip));
        state.remove_tunnel("tunnel1");
        assert_eq!(state.tunnel_count_for_ip(ip), 2);
        assert!(state.can_create_tunnel(ip));
    }

    #[test]
    fn test_tunnel_url_generation() {
        let state = RelayState::new(create_test_config());
        let url = state.tunnel_url("abc123");
        assert_eq!(url, "http://abc123.test.example.com");
    }

    #[test]
    fn test_tunnel_url_with_https() {
        let mut config = create_test_config();
        config.use_https = true;
        let state = RelayState::new(config);
        let url = state.tunnel_url("xyz789");
        assert_eq!(url, "https://xyz789.test.example.com");
    }

    #[test]
    fn test_auth_rate_limiter_tracks_failures() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
            max_failed_attempts: 3,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        });
        let ip = "10.0.0.1";
        assert_eq!(limiter.failed_attempts(ip), 0);
        assert!(!limiter.is_banned(ip));
        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 1);
        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 2);
        assert!(!limiter.is_banned(ip));
    }

    #[test]
    fn test_auth_rate_limiter_bans_after_max_attempts() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
            max_failed_attempts: 3,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        });
        let ip = "10.0.0.2";
        for _ in 0..3 {
            limiter.record_failure(ip);
        }
        assert!(limiter.is_banned(ip));
        assert!(limiter.ban_remaining(ip).is_some());
    }

    #[test]
    fn test_auth_rate_limiter_success_clears_failures() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig::default());
        let ip = "10.0.0.3";
        limiter.record_failure(ip);
        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 2);
        limiter.record_success(ip);
        assert_eq!(limiter.failed_attempts(ip), 0);
    }

    #[test]
    fn test_validate_auth_with_rate_limiting() {
        let mut config = create_test_config();
        config.auth_rate_limit = AuthRateLimitConfig {
            max_failed_attempts: 2,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        };
        let state = RelayState::new(config);
        let ip = "10.0.0.4";
        assert!(state.validate_auth(ip, Some("token1")).is_success());
        assert!(!state.validate_auth(ip, Some("bad")).is_success());
        assert!(!state.validate_auth(ip, Some("bad")).is_success());
        // Once banned, even a valid token is refused.
        let result = state.validate_auth(ip, Some("token1"));
        assert!(matches!(result, AuthResult::Banned { .. }));
    }

    #[test]
    fn test_jwt_token_generation_and_validation() {
        let config = create_jwt_config();
        let state = RelayState::new(config);
        let token = state.generate_token("user123", None);
        assert!(token.is_some());
        let token = token.unwrap();
        assert!(!token.is_empty());
        assert!(state.validate_token(Some(&token)));
    }

    #[test]
    fn test_jwt_token_with_tunnel_id() {
        let config = create_jwt_config();
        let state = RelayState::new(config);
        let token = state.generate_token("user456", Some("my-tunnel".to_string()));
        assert!(token.is_some());
        let token = token.unwrap();
        let result = state.validate_auth("10.0.0.5", Some(&token));
        match result {
            AuthResult::SuccessWithClaims(claims) => {
                assert_eq!(claims.sub, "user456");
                assert_eq!(claims.tunnel_id, Some("my-tunnel".to_string()));
            }
            _ => panic!("Expected SuccessWithClaims, got {:?}", result),
        }
    }

    #[test]
    fn test_jwt_token_revocation() {
        let config = create_jwt_config();
        let state = RelayState::new(config);
        let token = state.generate_token("user789", None).unwrap();
        let result = state.validate_auth("10.0.0.6", Some(&token));
        let jti = match result {
            AuthResult::SuccessWithClaims(claims) => claims.jti,
            _ => panic!("Expected success"),
        };
        state.revoke_token(&jti);
        let result = state.validate_auth("10.0.0.6", Some(&token));
        assert!(matches!(result, AuthResult::Invalid(_)));
    }

    #[test]
    fn test_jwt_invalid_signature() {
        let config = create_jwt_config();
        let state = RelayState::new(config);
        let tampered = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyIiwiZXhwIjo5OTk5OTk5OTk5LCJpYXQiOjE3MDAwMDAwMDAsImp0aSI6InRlc3QifQ.invalid_signature";
        let result = state.validate_auth("10.0.0.7", Some(tampered));
        assert!(!result.is_success());
    }

    /// A cleanup pass must keep revocations younger than `max_age`, so revoked tokens stay
    /// rejected afterwards. (Previously this test made no assertions at all.)
    #[test]
    fn test_token_manager_cleanup_revocations() {
        let tm = TokenManager::new(b"secret", Duration::from_secs(3600));
        let token = tm.generate_token("user", None).unwrap();
        let jti = tm.validate_token(&token).unwrap().jti;
        tm.revoke_token(&jti);
        tm.revoke_token("token2");
        tm.revoke_token("token3");
        assert_eq!(tm.revoked_tokens.len(), 3);
        tm.cleanup_revocations(Duration::from_secs(86_400));
        assert_eq!(tm.revoked_tokens.len(), 3);
        assert!(matches!(tm.validate_token(&token), Err(TokenError::Revoked)));
    }

    #[test]
    fn test_auth_not_required_returns_not_required() {
        let mut config = create_test_config();
        config.require_auth = false;
        let state = RelayState::new(config);
        let result = state.validate_auth("10.0.0.8", None);
        assert!(matches!(result, AuthResult::NotRequired));
    }

    #[tokio::test]
    async fn test_tunnel_not_expired_without_limits() {
        let config = create_test_config();
        let state = RelayState::new(config);
        let tunnel = create_test_tunnel("test1", "192.168.1.1");
        let tunnel = state.register_tunnel(tunnel);
        assert!(!state.is_tunnel_expired(&tunnel).await);
    }

    #[tokio::test]
    async fn test_tunnel_expired_by_age() {
        let mut config = create_test_config();
        config.max_tunnel_age = Some(Duration::from_millis(50));
        let state = RelayState::new(config);
        let tunnel = create_test_tunnel("test-ttl", "192.168.1.1");
        let tunnel = state.register_tunnel(tunnel);
        assert!(!state.is_tunnel_expired(&tunnel).await);
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(state.is_tunnel_expired(&tunnel).await);
    }

    #[tokio::test]
    async fn test_cleanup_expired_tunnels() {
        let mut config = create_test_config();
        config.max_tunnel_age = Some(Duration::from_millis(50));
        let state = RelayState::new(config);
        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
        state.register_tunnel(t1);
        state.register_tunnel(t2);
        assert_eq!(state.tunnel_count(), 2);
        tokio::time::sleep(Duration::from_millis(100)).await;
        let removed = state.cleanup_expired_tunnels().await;
        assert_eq!(removed, 2);
        assert_eq!(state.tunnel_count(), 0);
    }

    #[test]
    fn test_allow_custom_ids_config() {
        let mut config = create_test_config();
        assert!(config.allow_custom_ids);
        config.allow_custom_ids = false;
        let state = RelayState::new(config);
        assert!(!state.config.allow_custom_ids);
    }

    #[test]
    fn test_tunnel_count() {
        let config = create_test_config();
        let state = RelayState::new(config);
        assert_eq!(state.tunnel_count(), 0);
        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
        state.register_tunnel(t1);
        assert_eq!(state.tunnel_count(), 1);
        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
        state.register_tunnel(t2);
        assert_eq!(state.tunnel_count(), 2);
        state.remove_tunnel("tunnel1");
        assert_eq!(state.tunnel_count(), 1);
    }

    #[test]
    fn test_public_relay_config() {
        let config = RelayConfig {
            require_auth: false,
            allow_custom_ids: false,
            max_tunnel_age: Some(Duration::from_secs(8 * 3600)),
            max_idle_time: Some(Duration::from_secs(1800)),
            max_tunnels_per_ip: 3,
            ..Default::default()
        };
        let state = RelayState::new(config);
        assert!(state.validate_auth("10.0.0.1", None).is_success());
        assert!(!state.config.allow_custom_ids);
        assert!(state.config.max_tunnel_age.is_some());
        assert!(state.config.max_idle_time.is_some());
    }
}