Skip to main content

codive_relay/
state.rs

1//! Relay server shared state
2
3use dashmap::DashMap;
4use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use crate::tunnel::TunnelConnection;
12
// ============================================================================
// Auth Rate Limiting
// ============================================================================

/// Per-IP bookkeeping for failed authentication attempts.
#[derive(Debug)]
struct AuthAttempt {
    /// Failures recorded in the current counting window
    failed_count: u32,
    /// When the current counting window started (its first failure)
    first_failure: Instant,
    /// When the ban expires, if a ban has been issued
    banned_until: Option<Instant>,
}

/// Configuration for auth rate limiting
#[derive(Debug, Clone)]
pub struct AuthRateLimitConfig {
    /// Number of failed attempts within one counting window that triggers a
    /// temporary ban
    pub max_failed_attempts: u32,
    /// How long a ban lasts once triggered
    pub ban_duration: Duration,
    /// Length of the counting window. The window is fixed: it starts at the
    /// first failure and is NOT extended by later failures. A failure that
    /// arrives more than this long after the window started begins a new
    /// window; a quiet period by itself does not reset the stored count.
    pub attempt_window: Duration,
}

impl Default for AuthRateLimitConfig {
    /// 5 failures within a 1-minute window trigger a 5-minute ban.
    fn default() -> Self {
        Self {
            max_failed_attempts: 5,
            ban_duration: Duration::from_secs(5 * 60),
            attempt_window: Duration::from_secs(60),
        }
    }
}
48
49/// Rate limiter for authentication attempts
50pub struct AuthRateLimiter {
51    attempts: DashMap<String, AuthAttempt>,
52    config: AuthRateLimitConfig,
53}
54
55impl AuthRateLimiter {
56    pub fn new(config: AuthRateLimitConfig) -> Self {
57        Self {
58            attempts: DashMap::new(),
59            config,
60        }
61    }
62
63    /// Check if an IP is currently banned
64    pub fn is_banned(&self, ip: &str) -> bool {
65        if let Some(attempt) = self.attempts.get(ip) {
66            if let Some(banned_until) = attempt.banned_until {
67                if Instant::now() < banned_until {
68                    return true;
69                }
70            }
71        }
72        false
73    }
74
75    /// Record a failed authentication attempt
76    pub fn record_failure(&self, ip: &str) {
77        let now = Instant::now();
78
79        self.attempts
80            .entry(ip.to_string())
81            .and_modify(|attempt| {
82                // Reset if window expired
83                if now.duration_since(attempt.first_failure) > self.config.attempt_window {
84                    attempt.failed_count = 1;
85                    attempt.first_failure = now;
86                    attempt.banned_until = None;
87                } else {
88                    attempt.failed_count += 1;
89
90                    // Ban if exceeded max attempts
91                    if attempt.failed_count >= self.config.max_failed_attempts {
92                        attempt.banned_until = Some(now + self.config.ban_duration);
93                    }
94                }
95            })
96            .or_insert(AuthAttempt {
97                failed_count: 1,
98                first_failure: now,
99                banned_until: None,
100            });
101    }
102
103    /// Record a successful authentication (clears failure count)
104    pub fn record_success(&self, ip: &str) {
105        self.attempts.remove(ip);
106    }
107
108    /// Get the number of failed attempts for an IP
109    pub fn failed_attempts(&self, ip: &str) -> u32 {
110        self.attempts
111            .get(ip)
112            .map(|a| a.failed_count)
113            .unwrap_or(0)
114    }
115
116    /// Get time remaining on ban (if banned)
117    pub fn ban_remaining(&self, ip: &str) -> Option<Duration> {
118        self.attempts.get(ip).and_then(|attempt| {
119            attempt.banned_until.and_then(|until| {
120                let now = Instant::now();
121                if now < until {
122                    Some(until - now)
123                } else {
124                    None
125                }
126            })
127        })
128    }
129}
130
131// ============================================================================
132// JWT Token Management
133// ============================================================================
134
135/// JWT claims for tunnel authentication
136#[derive(Debug, Serialize, Deserialize)]
137pub struct TunnelClaims {
138    /// Subject (user/client identifier)
139    pub sub: String,
140    /// Expiration time (Unix timestamp)
141    pub exp: usize,
142    /// Issued at (Unix timestamp)
143    pub iat: usize,
144    /// Token ID (for revocation)
145    pub jti: String,
146    /// Optional: specific tunnel ID this token can use
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub tunnel_id: Option<String>,
149}
150
151/// Manages JWT token generation and validation
152pub struct TokenManager {
153    encoding_key: EncodingKey,
154    decoding_key: DecodingKey,
155    /// Revoked token IDs
156    revoked_tokens: DashMap<String, Instant>,
157    /// Default token validity duration
158    token_validity: Duration,
159}
160
161impl TokenManager {
162    /// Create a new token manager with the given secret
163    pub fn new(secret: &[u8], token_validity: Duration) -> Self {
164        Self {
165            encoding_key: EncodingKey::from_secret(secret),
166            decoding_key: DecodingKey::from_secret(secret),
167            revoked_tokens: DashMap::new(),
168            token_validity,
169        }
170    }
171
172    /// Generate a new JWT token
173    pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Result<String, jsonwebtoken::errors::Error> {
174        let now = std::time::SystemTime::now()
175            .duration_since(std::time::UNIX_EPOCH)
176            .unwrap()
177            .as_secs() as usize;
178
179        let claims = TunnelClaims {
180            sub: subject.to_string(),
181            exp: now + self.token_validity.as_secs() as usize,
182            iat: now,
183            jti: nanoid::nanoid!(16),
184            tunnel_id,
185        };
186
187        encode(&Header::default(), &claims, &self.encoding_key)
188    }
189
190    /// Validate a JWT token
191    pub fn validate_token(&self, token: &str) -> Result<TunnelClaims, TokenError> {
192        let validation = Validation::default();
193
194        let token_data = decode::<TunnelClaims>(token, &self.decoding_key, &validation)
195            .map_err(|e| match e.kind() {
196                jsonwebtoken::errors::ErrorKind::ExpiredSignature => TokenError::Expired,
197                jsonwebtoken::errors::ErrorKind::InvalidSignature => TokenError::InvalidSignature,
198                _ => TokenError::Invalid(e.to_string()),
199            })?;
200
201        // Check if token is revoked
202        if self.revoked_tokens.contains_key(&token_data.claims.jti) {
203            return Err(TokenError::Revoked);
204        }
205
206        Ok(token_data.claims)
207    }
208
209    /// Revoke a token by its ID
210    pub fn revoke_token(&self, jti: &str) {
211        self.revoked_tokens.insert(jti.to_string(), Instant::now());
212    }
213
214    /// Clean up expired revocations (call periodically)
215    pub fn cleanup_revocations(&self, max_age: Duration) {
216        let now = Instant::now();
217        self.revoked_tokens.retain(|_, revoked_at| {
218            now.duration_since(*revoked_at) < max_age
219        });
220    }
221}
222
/// Reasons a JWT can be rejected during validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenError {
    /// The token has expired (jsonwebtoken `ExpiredSignature`)
    Expired,
    /// The token's `jti` has been revoked
    Revoked,
    /// The signature does not match the configured secret
    InvalidSignature,
    /// Any other decoding/validation failure, with the library's message
    Invalid(String),
}

impl std::fmt::Display for TokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TokenError::Expired => write!(f, "Token has expired"),
            TokenError::Revoked => write!(f, "Token has been revoked"),
            TokenError::InvalidSignature => write!(f, "Invalid token signature"),
            TokenError::Invalid(msg) => write!(f, "Invalid token: {}", msg),
        }
    }
}

/// Lets `TokenError` flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for TokenError {}
242
243// ============================================================================
244// Relay Configuration
245// ============================================================================
246
247/// Relay server configuration
248#[derive(Debug, Clone)]
249pub struct RelayConfig {
250    /// Base domain for tunnels (e.g., "relay.example.com")
251    pub base_domain: String,
252    /// Listen address
253    pub listen_addr: SocketAddr,
254    /// Request timeout
255    pub request_timeout: Duration,
256    /// Maximum tunnels per IP (rate limiting)
257    pub max_tunnels_per_ip: usize,
258    /// Whether to use HTTPS for tunnel URLs
259    pub use_https: bool,
260    /// Valid API tokens for authentication (empty = no auth required)
261    pub auth_tokens: HashSet<String>,
262    /// Whether authentication is required
263    pub require_auth: bool,
264    /// JWT secret for token-based auth (None = use simple token matching)
265    pub jwt_secret: Option<Vec<u8>>,
266    /// JWT token validity duration
267    pub jwt_validity: Duration,
268    /// Auth rate limiting configuration
269    pub auth_rate_limit: AuthRateLimitConfig,
270    /// Maximum tunnel age (TTL) - tunnels older than this are closed
271    /// None = no limit (tunnels live until disconnected)
272    pub max_tunnel_age: Option<Duration>,
273    /// Maximum idle time before tunnel is closed
274    /// None = no idle timeout
275    pub max_idle_time: Option<Duration>,
276    /// Allow custom tunnel IDs (if false, only random IDs are allowed)
277    /// Set to false for public relays to prevent subdomain squatting
278    pub allow_custom_ids: bool,
279}
280
281impl Default for RelayConfig {
282    fn default() -> Self {
283        Self {
284            base_domain: "localhost:3001".to_string(),
285            listen_addr: "127.0.0.1:3001".parse().unwrap(),
286            request_timeout: Duration::from_secs(30),
287            max_tunnels_per_ip: 10,
288            use_https: false,
289            auth_tokens: HashSet::new(),
290            require_auth: false,
291            jwt_secret: None,
292            jwt_validity: Duration::from_secs(3600), // 1 hour default
293            auth_rate_limit: AuthRateLimitConfig::default(),
294            max_tunnel_age: None,        // No TTL by default (for self-hosted)
295            max_idle_time: None,         // No idle timeout by default
296            allow_custom_ids: true,      // Allow custom IDs by default (for self-hosted)
297        }
298    }
299}
300
301// ============================================================================
302// Relay State
303// ============================================================================
304
305/// Result of token validation
306#[derive(Debug)]
307pub enum AuthResult {
308    /// Authentication successful
309    Success,
310    /// Authentication successful with JWT claims
311    SuccessWithClaims(TunnelClaims),
312    /// No authentication required
313    NotRequired,
314    /// IP is temporarily banned
315    Banned { remaining: Duration },
316    /// Token is invalid
317    Invalid(String),
318}
319
320impl AuthResult {
321    pub fn is_success(&self) -> bool {
322        matches!(self, AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired)
323    }
324}
325
326/// Shared state for the relay server
327pub struct RelayState {
328    /// Active tunnels indexed by tunnel_id
329    pub tunnels: DashMap<String, Arc<TunnelConnection>>,
330    /// Tunnel count per IP address (for rate limiting)
331    pub tunnels_per_ip: DashMap<String, usize>,
332    /// Configuration
333    pub config: RelayConfig,
334    /// Auth rate limiter (tracks failed attempts)
335    pub auth_rate_limiter: AuthRateLimiter,
336    /// JWT token manager (optional, for JWT-based auth)
337    pub token_manager: Option<TokenManager>,
338}
339
340impl RelayState {
341    /// Create a new relay state with the given configuration
342    pub fn new(config: RelayConfig) -> Self {
343        let auth_rate_limiter = AuthRateLimiter::new(config.auth_rate_limit.clone());
344        let token_manager = config.jwt_secret.as_ref().map(|secret| {
345            TokenManager::new(secret, config.jwt_validity)
346        });
347
348        Self {
349            tunnels: DashMap::new(),
350            tunnels_per_ip: DashMap::new(),
351            auth_rate_limiter,
352            token_manager,
353            config,
354        }
355    }
356
357    /// Validate an authentication token (simple check, backward compatible)
358    pub fn validate_token(&self, token: Option<&str>) -> bool {
359        if !self.config.require_auth {
360            return true; // Auth not required
361        }
362
363        match token {
364            Some(t) if !t.is_empty() => {
365                // Try simple token match first
366                if self.config.auth_tokens.contains(t) {
367                    return true;
368                }
369                // Try JWT if token manager is configured
370                if let Some(ref tm) = self.token_manager {
371                    return tm.validate_token(t).is_ok();
372                }
373                false
374            },
375            _ => false,
376        }
377    }
378
379    /// Validate authentication with rate limiting and detailed result
380    pub fn validate_auth(&self, ip: &str, token: Option<&str>) -> AuthResult {
381        // Check if IP is banned
382        if let Some(remaining) = self.auth_rate_limiter.ban_remaining(ip) {
383            return AuthResult::Banned { remaining };
384        }
385
386        // Auth not required
387        if !self.config.require_auth {
388            return AuthResult::NotRequired;
389        }
390
391        match token {
392            Some(t) if !t.is_empty() => {
393                // Try simple token match first
394                if self.config.auth_tokens.contains(t) {
395                    self.auth_rate_limiter.record_success(ip);
396                    return AuthResult::Success;
397                }
398
399                // Try JWT if token manager is configured
400                if let Some(ref tm) = self.token_manager {
401                    match tm.validate_token(t) {
402                        Ok(claims) => {
403                            self.auth_rate_limiter.record_success(ip);
404                            return AuthResult::SuccessWithClaims(claims);
405                        }
406                        Err(e) => {
407                            self.auth_rate_limiter.record_failure(ip);
408                            return AuthResult::Invalid(e.to_string());
409                        }
410                    }
411                }
412
413                // Token doesn't match any known tokens
414                self.auth_rate_limiter.record_failure(ip);
415                AuthResult::Invalid("Invalid token".to_string())
416            }
417            _ => {
418                self.auth_rate_limiter.record_failure(ip);
419                AuthResult::Invalid("Missing token".to_string())
420            }
421        }
422    }
423
424    /// Generate a JWT token (if JWT is configured)
425    pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Option<String> {
426        self.token_manager.as_ref().and_then(|tm| {
427            tm.generate_token(subject, tunnel_id).ok()
428        })
429    }
430
431    /// Revoke a JWT token by its ID
432    pub fn revoke_token(&self, jti: &str) {
433        if let Some(ref tm) = self.token_manager {
434            tm.revoke_token(jti);
435        }
436    }
437
438    /// Check if an IP can create more tunnels (rate limiting)
439    pub fn can_create_tunnel(&self, ip: &str) -> bool {
440        let count = self.tunnels_per_ip.get(ip).map(|r| *r).unwrap_or(0);
441        count < self.config.max_tunnels_per_ip
442    }
443
444    /// Increment tunnel count for an IP
445    fn increment_ip_count(&self, ip: &str) {
446        self.tunnels_per_ip
447            .entry(ip.to_string())
448            .and_modify(|c| *c += 1)
449            .or_insert(1);
450    }
451
452    /// Decrement tunnel count for an IP
453    fn decrement_ip_count(&self, ip: &str) {
454        if let Some(mut count) = self.tunnels_per_ip.get_mut(ip) {
455            if *count > 0 {
456                *count -= 1;
457            }
458            if *count == 0 {
459                drop(count);
460                self.tunnels_per_ip.remove(ip);
461            }
462        }
463    }
464
465    /// Register a new tunnel
466    pub fn register_tunnel(&self, tunnel: TunnelConnection) -> Arc<TunnelConnection> {
467        let tunnel_id = tunnel.tunnel_id.clone();
468        let source_ip = tunnel.source_ip.clone();
469        let tunnel = Arc::new(tunnel);
470        self.tunnels.insert(tunnel_id, tunnel.clone());
471        self.increment_ip_count(&source_ip);
472        tunnel
473    }
474
475    /// Remove a tunnel by ID
476    pub fn remove_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
477        if let Some((_, tunnel)) = self.tunnels.remove(tunnel_id) {
478            self.decrement_ip_count(&tunnel.source_ip);
479            Some(tunnel)
480        } else {
481            None
482        }
483    }
484
485    /// Get a tunnel by ID
486    pub fn get_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
487        self.tunnels.get(tunnel_id).map(|r| r.clone())
488    }
489
490    /// Get the URL for a tunnel
491    pub fn tunnel_url(&self, tunnel_id: &str) -> String {
492        let scheme = if self.config.use_https { "https" } else { "http" };
493        format!("{}://{}.{}", scheme, tunnel_id, self.config.base_domain)
494    }
495
496    /// Get the number of active tunnels for an IP
497    pub fn tunnel_count_for_ip(&self, ip: &str) -> usize {
498        self.tunnels_per_ip.get(ip).map(|r| *r).unwrap_or(0)
499    }
500
501    /// Check if a tunnel has expired (by age or idle time)
502    pub async fn is_tunnel_expired(&self, tunnel: &crate::tunnel::TunnelConnection) -> bool {
503        let now = chrono::Utc::now();
504
505        // Check max age (TTL)
506        if let Some(max_age) = self.config.max_tunnel_age {
507            let age = now.signed_duration_since(tunnel.created_at);
508            if age.to_std().unwrap_or(Duration::ZERO) >= max_age {
509                return true;
510            }
511        }
512
513        // Check idle time
514        if let Some(max_idle) = self.config.max_idle_time {
515            let last_activity = *tunnel.last_activity.read().await;
516            let idle_time = now.signed_duration_since(last_activity);
517            if idle_time.to_std().unwrap_or(Duration::ZERO) >= max_idle {
518                return true;
519            }
520        }
521
522        false
523    }
524
525    /// Clean up expired tunnels
526    /// Returns the number of tunnels removed
527    pub async fn cleanup_expired_tunnels(&self) -> usize {
528        let mut expired_ids = Vec::new();
529
530        // Find expired tunnels
531        for entry in self.tunnels.iter() {
532            if self.is_tunnel_expired(entry.value()).await {
533                expired_ids.push(entry.key().clone());
534            }
535        }
536
537        // Remove them
538        let count = expired_ids.len();
539        for tunnel_id in expired_ids {
540            if let Some(tunnel) = self.remove_tunnel(&tunnel_id) {
541                tracing::info!(
542                    tunnel_id = %tunnel_id,
543                    source_ip = %tunnel.source_ip,
544                    age_secs = (chrono::Utc::now() - tunnel.created_at).num_seconds(),
545                    "Tunnel expired and removed"
546                );
547            }
548        }
549
550        count
551    }
552
553    /// Get total number of active tunnels
554    pub fn tunnel_count(&self) -> usize {
555        self.tunnels.len()
556    }
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562    use tokio::sync::mpsc;
563
564    fn create_test_config() -> RelayConfig {
565        RelayConfig {
566            base_domain: "test.example.com".to_string(),
567            listen_addr: "127.0.0.1:3001".parse().unwrap(),
568            request_timeout: Duration::from_secs(30),
569            max_tunnels_per_ip: 3,
570            use_https: false,
571            auth_tokens: ["token1".to_string(), "token2".to_string()].into_iter().collect(),
572            require_auth: true,
573            jwt_secret: None,
574            jwt_validity: Duration::from_secs(3600),
575            auth_rate_limit: AuthRateLimitConfig::default(),
576            max_tunnel_age: None,
577            max_idle_time: None,
578            allow_custom_ids: true,
579        }
580    }
581
582    fn create_jwt_config() -> RelayConfig {
583        RelayConfig {
584            jwt_secret: Some(b"test-secret-key-for-jwt-testing".to_vec()),
585            jwt_validity: Duration::from_secs(3600),
586            require_auth: true,
587            auth_tokens: HashSet::new(), // No simple tokens, only JWT
588            ..create_test_config()
589        }
590    }
591
592    fn create_test_tunnel(tunnel_id: &str, source_ip: &str) -> crate::tunnel::TunnelConnection {
593        let (tx, _rx) = mpsc::channel(10);
594        crate::tunnel::TunnelConnection::new(
595            tunnel_id.to_string(),
596            tx,
597            source_ip.to_string(),
598        )
599    }
600
601    // ============================================================================
602    // Authentication Tests
603    // ============================================================================
604
605    #[test]
606    fn test_validate_token_valid() {
607        let state = RelayState::new(create_test_config());
608        assert!(state.validate_token(Some("token1")));
609        assert!(state.validate_token(Some("token2")));
610    }
611
612    #[test]
613    fn test_validate_token_invalid() {
614        let state = RelayState::new(create_test_config());
615        assert!(!state.validate_token(Some("invalid-token")));
616        assert!(!state.validate_token(Some("")));
617        assert!(!state.validate_token(None));
618    }
619
620    #[test]
621    fn test_validate_token_auth_not_required() {
622        let mut config = create_test_config();
623        config.require_auth = false;
624        let state = RelayState::new(config);
625
626        // Any token (or no token) should be valid when auth is not required
627        assert!(state.validate_token(None));
628        assert!(state.validate_token(Some("")));
629        assert!(state.validate_token(Some("random")));
630    }
631
632    // ============================================================================
633    // Rate Limiting Tests
634    // ============================================================================
635
636    #[test]
637    fn test_can_create_tunnel_under_limit() {
638        let state = RelayState::new(create_test_config());
639        let ip = "192.168.1.1";
640
641        assert!(state.can_create_tunnel(ip));
642        assert_eq!(state.tunnel_count_for_ip(ip), 0);
643    }
644
645    #[test]
646    fn test_rate_limiting_enforced() {
647        let state = RelayState::new(create_test_config());
648        let ip = "192.168.1.1";
649
650        // Create tunnels up to the limit (3)
651        let t1 = create_test_tunnel("tunnel1", ip);
652        let t2 = create_test_tunnel("tunnel2", ip);
653        let t3 = create_test_tunnel("tunnel3", ip);
654
655        state.register_tunnel(t1);
656        assert_eq!(state.tunnel_count_for_ip(ip), 1);
657        assert!(state.can_create_tunnel(ip));
658
659        state.register_tunnel(t2);
660        assert_eq!(state.tunnel_count_for_ip(ip), 2);
661        assert!(state.can_create_tunnel(ip));
662
663        state.register_tunnel(t3);
664        assert_eq!(state.tunnel_count_for_ip(ip), 3);
665        // Now at limit - should NOT be able to create more
666        assert!(!state.can_create_tunnel(ip));
667    }
668
669    #[test]
670    fn test_rate_limiting_per_ip() {
671        let state = RelayState::new(create_test_config());
672        let ip1 = "192.168.1.1";
673        let ip2 = "192.168.1.2";
674
675        // Fill up ip1's limit
676        for i in 0..3 {
677            let t = create_test_tunnel(&format!("tunnel-ip1-{}", i), ip1);
678            state.register_tunnel(t);
679        }
680
681        // ip1 should be at limit
682        assert!(!state.can_create_tunnel(ip1));
683
684        // ip2 should still be able to create tunnels
685        assert!(state.can_create_tunnel(ip2));
686        assert_eq!(state.tunnel_count_for_ip(ip2), 0);
687    }
688
689    #[test]
690    fn test_rate_limiting_released_on_disconnect() {
691        let state = RelayState::new(create_test_config());
692        let ip = "192.168.1.1";
693
694        // Fill up the limit
695        for i in 0..3 {
696            let t = create_test_tunnel(&format!("tunnel{}", i), ip);
697            state.register_tunnel(t);
698        }
699        assert!(!state.can_create_tunnel(ip));
700
701        // Remove one tunnel
702        state.remove_tunnel("tunnel1");
703        assert_eq!(state.tunnel_count_for_ip(ip), 2);
704
705        // Should be able to create again
706        assert!(state.can_create_tunnel(ip));
707    }
708
709    #[test]
710    fn test_tunnel_url_generation() {
711        let state = RelayState::new(create_test_config());
712
713        let url = state.tunnel_url("abc123");
714        assert_eq!(url, "http://abc123.test.example.com");
715    }
716
717    #[test]
718    fn test_tunnel_url_with_https() {
719        let mut config = create_test_config();
720        config.use_https = true;
721        let state = RelayState::new(config);
722
723        let url = state.tunnel_url("xyz789");
724        assert_eq!(url, "https://xyz789.test.example.com");
725    }
726
727    // ============================================================================
728    // Auth Rate Limiting Tests
729    // ============================================================================
730
731    #[test]
732    fn test_auth_rate_limiter_tracks_failures() {
733        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
734            max_failed_attempts: 3,
735            ban_duration: Duration::from_secs(60),
736            attempt_window: Duration::from_secs(30),
737        });
738        let ip = "10.0.0.1";
739
740        assert_eq!(limiter.failed_attempts(ip), 0);
741        assert!(!limiter.is_banned(ip));
742
743        limiter.record_failure(ip);
744        assert_eq!(limiter.failed_attempts(ip), 1);
745
746        limiter.record_failure(ip);
747        assert_eq!(limiter.failed_attempts(ip), 2);
748
749        // Not banned yet
750        assert!(!limiter.is_banned(ip));
751    }
752
753    #[test]
754    fn test_auth_rate_limiter_bans_after_max_attempts() {
755        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
756            max_failed_attempts: 3,
757            ban_duration: Duration::from_secs(60),
758            attempt_window: Duration::from_secs(30),
759        });
760        let ip = "10.0.0.2";
761
762        // Fail 3 times (the limit)
763        for _ in 0..3 {
764            limiter.record_failure(ip);
765        }
766
767        // Should be banned now
768        assert!(limiter.is_banned(ip));
769        assert!(limiter.ban_remaining(ip).is_some());
770    }
771
772    #[test]
773    fn test_auth_rate_limiter_success_clears_failures() {
774        let limiter = AuthRateLimiter::new(AuthRateLimitConfig::default());
775        let ip = "10.0.0.3";
776
777        limiter.record_failure(ip);
778        limiter.record_failure(ip);
779        assert_eq!(limiter.failed_attempts(ip), 2);
780
781        limiter.record_success(ip);
782        assert_eq!(limiter.failed_attempts(ip), 0);
783    }
784
785    #[test]
786    fn test_validate_auth_with_rate_limiting() {
787        let mut config = create_test_config();
788        config.auth_rate_limit = AuthRateLimitConfig {
789            max_failed_attempts: 2,
790            ban_duration: Duration::from_secs(60),
791            attempt_window: Duration::from_secs(30),
792        };
793        let state = RelayState::new(config);
794        let ip = "10.0.0.4";
795
796        // Valid token should succeed
797        assert!(state.validate_auth(ip, Some("token1")).is_success());
798
799        // Invalid tokens should fail and count
800        assert!(!state.validate_auth(ip, Some("bad")).is_success());
801        assert!(!state.validate_auth(ip, Some("bad")).is_success());
802
803        // Should be banned now
804        let result = state.validate_auth(ip, Some("token1")); // Even valid token
805        assert!(matches!(result, AuthResult::Banned { .. }));
806    }
807
808    // ============================================================================
809    // JWT Token Tests
810    // ============================================================================
811
812    #[test]
813    fn test_jwt_token_generation_and_validation() {
814        let config = create_jwt_config();
815        let state = RelayState::new(config);
816
817        // Generate a token
818        let token = state.generate_token("user123", None);
819        assert!(token.is_some());
820
821        let token = token.unwrap();
822        assert!(!token.is_empty());
823
824        // Validate the token
825        assert!(state.validate_token(Some(&token)));
826    }
827
828    #[test]
829    fn test_jwt_token_with_tunnel_id() {
830        let config = create_jwt_config();
831        let state = RelayState::new(config);
832
833        let token = state.generate_token("user456", Some("my-tunnel".to_string()));
834        assert!(token.is_some());
835
836        let token = token.unwrap();
837        let result = state.validate_auth("10.0.0.5", Some(&token));
838
839        match result {
840            AuthResult::SuccessWithClaims(claims) => {
841                assert_eq!(claims.sub, "user456");
842                assert_eq!(claims.tunnel_id, Some("my-tunnel".to_string()));
843            }
844            _ => panic!("Expected SuccessWithClaims, got {:?}", result),
845        }
846    }
847
848    #[test]
849    fn test_jwt_token_revocation() {
850        let config = create_jwt_config();
851        let state = RelayState::new(config);
852
853        // Generate and validate
854        let token = state.generate_token("user789", None).unwrap();
855        let result = state.validate_auth("10.0.0.6", Some(&token));
856
857        let jti = match result {
858            AuthResult::SuccessWithClaims(claims) => claims.jti,
859            _ => panic!("Expected success"),
860        };
861
862        // Revoke the token
863        state.revoke_token(&jti);
864
865        // Should now fail
866        let result = state.validate_auth("10.0.0.6", Some(&token));
867        assert!(matches!(result, AuthResult::Invalid(_)));
868    }
869
870    #[test]
871    fn test_jwt_invalid_signature() {
872        let config = create_jwt_config();
873        let state = RelayState::new(config);
874
875        // A token signed with a different secret
876        let tampered = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyIiwiZXhwIjo5OTk5OTk5OTk5LCJpYXQiOjE3MDAwMDAwMDAsImp0aSI6InRlc3QifQ.invalid_signature";
877
878        let result = state.validate_auth("10.0.0.7", Some(tampered));
879        assert!(!result.is_success());
880    }
881
882    #[test]
883    fn test_token_manager_cleanup_revocations() {
884        let tm = TokenManager::new(b"secret", Duration::from_secs(3600));
885
886        // Revoke some tokens
887        tm.revoke_token("token1");
888        tm.revoke_token("token2");
889        tm.revoke_token("token3");
890
891        // Cleanup with max age of 0 should remove all
892        tm.cleanup_revocations(Duration::from_secs(0));
893
894        // The cleanup doesn't affect validation (tokens are still in the map until expired)
895        // This is testing the cleanup mechanism exists
896    }
897
898    #[test]
899    fn test_auth_not_required_returns_not_required() {
900        let mut config = create_test_config();
901        config.require_auth = false;
902        let state = RelayState::new(config);
903
904        let result = state.validate_auth("10.0.0.8", None);
905        assert!(matches!(result, AuthResult::NotRequired));
906    }
907
908    // ============================================================================
909    // TTL and Tunnel Limit Tests
910    // ============================================================================
911
912    #[tokio::test]
913    async fn test_tunnel_not_expired_without_limits() {
914        let config = create_test_config();
915        let state = RelayState::new(config);
916        let tunnel = create_test_tunnel("test1", "192.168.1.1");
917        let tunnel = state.register_tunnel(tunnel);
918
919        // Without TTL limits, tunnel should not be expired
920        assert!(!state.is_tunnel_expired(&tunnel).await);
921    }
922
923    #[tokio::test]
924    async fn test_tunnel_expired_by_age() {
925        let mut config = create_test_config();
926        config.max_tunnel_age = Some(Duration::from_millis(50)); // Very short TTL
927        let state = RelayState::new(config);
928
929        let tunnel = create_test_tunnel("test-ttl", "192.168.1.1");
930        let tunnel = state.register_tunnel(tunnel);
931
932        // Should not be expired immediately
933        assert!(!state.is_tunnel_expired(&tunnel).await);
934
935        // Wait for TTL to expire
936        tokio::time::sleep(Duration::from_millis(100)).await;
937
938        // Should be expired now
939        assert!(state.is_tunnel_expired(&tunnel).await);
940    }
941
942    #[tokio::test]
943    async fn test_cleanup_expired_tunnels() {
944        let mut config = create_test_config();
945        config.max_tunnel_age = Some(Duration::from_millis(50));
946        let state = RelayState::new(config);
947
948        // Create some tunnels
949        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
950        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
951        state.register_tunnel(t1);
952        state.register_tunnel(t2);
953
954        assert_eq!(state.tunnel_count(), 2);
955
956        // Wait for expiration
957        tokio::time::sleep(Duration::from_millis(100)).await;
958
959        // Cleanup should remove both
960        let removed = state.cleanup_expired_tunnels().await;
961        assert_eq!(removed, 2);
962        assert_eq!(state.tunnel_count(), 0);
963    }
964
965    #[test]
966    fn test_allow_custom_ids_config() {
967        let mut config = create_test_config();
968        assert!(config.allow_custom_ids); // Default is true
969
970        config.allow_custom_ids = false;
971        let state = RelayState::new(config);
972        assert!(!state.config.allow_custom_ids);
973    }
974
975    #[test]
976    fn test_tunnel_count() {
977        let config = create_test_config();
978        let state = RelayState::new(config);
979
980        assert_eq!(state.tunnel_count(), 0);
981
982        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
983        state.register_tunnel(t1);
984        assert_eq!(state.tunnel_count(), 1);
985
986        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
987        state.register_tunnel(t2);
988        assert_eq!(state.tunnel_count(), 2);
989
990        state.remove_tunnel("tunnel1");
991        assert_eq!(state.tunnel_count(), 1);
992    }
993
994    #[test]
995    fn test_public_relay_config() {
996        // Configuration for a public relay (zero-friction)
997        let config = RelayConfig {
998            require_auth: false,              // No auth required
999            allow_custom_ids: false,          // Random IDs only
1000            max_tunnel_age: Some(Duration::from_secs(8 * 3600)), // 8 hour TTL
1001            max_idle_time: Some(Duration::from_secs(1800)),      // 30 min idle timeout
1002            max_tunnels_per_ip: 3,            // Limit per IP
1003            ..Default::default()
1004        };
1005
1006        let state = RelayState::new(config);
1007
1008        // Auth should not be required
1009        assert!(state.validate_auth("10.0.0.1", None).is_success());
1010
1011        // Should have the limits configured
1012        assert!(!state.config.allow_custom_ids);
1013        assert!(state.config.max_tunnel_age.is_some());
1014        assert!(state.config.max_idle_time.is_some());
1015    }
1016}