1use dashmap::DashMap;
4use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use crate::tunnel::TunnelConnection;
12
/// Per-IP bookkeeping kept by `AuthRateLimiter` for failed authentication attempts.
#[derive(Debug)]
struct AuthAttempt {
    /// Number of failures counted in the current counting window.
    failed_count: u32,
    /// Start of the current counting window (the first failure recorded in it).
    first_failure: Instant,
    /// When set and still in the future, the IP is banned until this instant.
    banned_until: Option<Instant>,
}
27
/// Tuning knobs for [`AuthRateLimiter`].
#[derive(Debug, Clone)]
pub struct AuthRateLimitConfig {
    /// Number of failures inside `attempt_window` that triggers a ban.
    pub max_failed_attempts: u32,
    /// How long an IP stays banned once the threshold is reached.
    pub ban_duration: Duration,
    /// Length of the window in which failures are accumulated.
    pub attempt_window: Duration,
}

impl Default for AuthRateLimitConfig {
    /// Five failures within one minute earn a five-minute ban.
    fn default() -> Self {
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        AuthRateLimitConfig {
            attempt_window: ONE_MINUTE,
            ban_duration: FIVE_MINUTES,
            max_failed_attempts: 5,
        }
    }
}
48
49pub struct AuthRateLimiter {
51 attempts: DashMap<String, AuthAttempt>,
52 config: AuthRateLimitConfig,
53}
54
55impl AuthRateLimiter {
56 pub fn new(config: AuthRateLimitConfig) -> Self {
57 Self {
58 attempts: DashMap::new(),
59 config,
60 }
61 }
62
63 pub fn is_banned(&self, ip: &str) -> bool {
65 if let Some(attempt) = self.attempts.get(ip) {
66 if let Some(banned_until) = attempt.banned_until {
67 if Instant::now() < banned_until {
68 return true;
69 }
70 }
71 }
72 false
73 }
74
75 pub fn record_failure(&self, ip: &str) {
77 let now = Instant::now();
78
79 self.attempts
80 .entry(ip.to_string())
81 .and_modify(|attempt| {
82 if now.duration_since(attempt.first_failure) > self.config.attempt_window {
84 attempt.failed_count = 1;
85 attempt.first_failure = now;
86 attempt.banned_until = None;
87 } else {
88 attempt.failed_count += 1;
89
90 if attempt.failed_count >= self.config.max_failed_attempts {
92 attempt.banned_until = Some(now + self.config.ban_duration);
93 }
94 }
95 })
96 .or_insert(AuthAttempt {
97 failed_count: 1,
98 first_failure: now,
99 banned_until: None,
100 });
101 }
102
103 pub fn record_success(&self, ip: &str) {
105 self.attempts.remove(ip);
106 }
107
108 pub fn failed_attempts(&self, ip: &str) -> u32 {
110 self.attempts
111 .get(ip)
112 .map(|a| a.failed_count)
113 .unwrap_or(0)
114 }
115
116 pub fn ban_remaining(&self, ip: &str) -> Option<Duration> {
118 self.attempts.get(ip).and_then(|attempt| {
119 attempt.banned_until.and_then(|until| {
120 let now = Instant::now();
121 if now < until {
122 Some(until - now)
123 } else {
124 None
125 }
126 })
127 })
128 }
129}
130
/// JWT claims carried by tokens issued by `TokenManager::generate_token`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TunnelClaims {
    /// Subject the token was generated for.
    pub sub: String,
    /// Expiry, in seconds since the Unix epoch.
    pub exp: usize,
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: usize,
    /// Unique token id; this is the key used for revocation.
    pub jti: String,
    /// Optional tunnel id embedded at issue time; omitted from the encoded JWT when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tunnel_id: Option<String>,
}
150
151pub struct TokenManager {
153 encoding_key: EncodingKey,
154 decoding_key: DecodingKey,
155 revoked_tokens: DashMap<String, Instant>,
157 token_validity: Duration,
159}
160
161impl TokenManager {
162 pub fn new(secret: &[u8], token_validity: Duration) -> Self {
164 Self {
165 encoding_key: EncodingKey::from_secret(secret),
166 decoding_key: DecodingKey::from_secret(secret),
167 revoked_tokens: DashMap::new(),
168 token_validity,
169 }
170 }
171
172 pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Result<String, jsonwebtoken::errors::Error> {
174 let now = std::time::SystemTime::now()
175 .duration_since(std::time::UNIX_EPOCH)
176 .unwrap()
177 .as_secs() as usize;
178
179 let claims = TunnelClaims {
180 sub: subject.to_string(),
181 exp: now + self.token_validity.as_secs() as usize,
182 iat: now,
183 jti: nanoid::nanoid!(16),
184 tunnel_id,
185 };
186
187 encode(&Header::default(), &claims, &self.encoding_key)
188 }
189
190 pub fn validate_token(&self, token: &str) -> Result<TunnelClaims, TokenError> {
192 let validation = Validation::default();
193
194 let token_data = decode::<TunnelClaims>(token, &self.decoding_key, &validation)
195 .map_err(|e| match e.kind() {
196 jsonwebtoken::errors::ErrorKind::ExpiredSignature => TokenError::Expired,
197 jsonwebtoken::errors::ErrorKind::InvalidSignature => TokenError::InvalidSignature,
198 _ => TokenError::Invalid(e.to_string()),
199 })?;
200
201 if self.revoked_tokens.contains_key(&token_data.claims.jti) {
203 return Err(TokenError::Revoked);
204 }
205
206 Ok(token_data.claims)
207 }
208
209 pub fn revoke_token(&self, jti: &str) {
211 self.revoked_tokens.insert(jti.to_string(), Instant::now());
212 }
213
214 pub fn cleanup_revocations(&self, max_age: Duration) {
216 let now = Instant::now();
217 self.revoked_tokens.retain(|_, revoked_at| {
218 now.duration_since(*revoked_at) < max_age
219 });
220 }
221}
222
/// Reasons a JWT can be rejected by `TokenManager::validate_token`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenError {
    /// The token's expiry has passed.
    Expired,
    /// The token's `jti` is on the revocation list.
    Revoked,
    /// The signature does not match the configured secret.
    InvalidSignature,
    /// Any other decoding/validation failure, with the underlying message.
    Invalid(String),
}

impl std::fmt::Display for TokenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TokenError::Expired => write!(f, "Token has expired"),
            TokenError::Revoked => write!(f, "Token has been revoked"),
            TokenError::InvalidSignature => write!(f, "Invalid token signature"),
            TokenError::Invalid(msg) => write!(f, "Invalid token: {}", msg),
        }
    }
}

// Lets callers propagate `TokenError` with `?` into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for TokenError {}
242
243#[derive(Debug, Clone)]
249pub struct RelayConfig {
250 pub base_domain: String,
252 pub listen_addr: SocketAddr,
254 pub request_timeout: Duration,
256 pub max_tunnels_per_ip: usize,
258 pub use_https: bool,
260 pub auth_tokens: HashSet<String>,
262 pub require_auth: bool,
264 pub jwt_secret: Option<Vec<u8>>,
266 pub jwt_validity: Duration,
268 pub auth_rate_limit: AuthRateLimitConfig,
270 pub max_tunnel_age: Option<Duration>,
273 pub max_idle_time: Option<Duration>,
276 pub allow_custom_ids: bool,
279}
280
281impl Default for RelayConfig {
282 fn default() -> Self {
283 Self {
284 base_domain: "localhost:3001".to_string(),
285 listen_addr: "127.0.0.1:3001".parse().unwrap(),
286 request_timeout: Duration::from_secs(30),
287 max_tunnels_per_ip: 10,
288 use_https: false,
289 auth_tokens: HashSet::new(),
290 require_auth: false,
291 jwt_secret: None,
292 jwt_validity: Duration::from_secs(3600), auth_rate_limit: AuthRateLimitConfig::default(),
294 max_tunnel_age: None, max_idle_time: None, allow_custom_ids: true, }
298 }
299}
300
301#[derive(Debug)]
307pub enum AuthResult {
308 Success,
310 SuccessWithClaims(TunnelClaims),
312 NotRequired,
314 Banned { remaining: Duration },
316 Invalid(String),
318}
319
320impl AuthResult {
321 pub fn is_success(&self) -> bool {
322 matches!(self, AuthResult::Success | AuthResult::SuccessWithClaims(_) | AuthResult::NotRequired)
323 }
324}
325
/// Shared relay state: registered tunnels, per-IP tunnel counts, and authentication helpers.
pub struct RelayState {
    /// Registered tunnels keyed by tunnel id.
    pub tunnels: DashMap<String, Arc<TunnelConnection>>,
    /// Registered-tunnel count per source IP; `decrement_ip_count` removes entries that reach zero.
    pub tunnels_per_ip: DashMap<String, usize>,
    /// Relay configuration.
    pub config: RelayConfig,
    /// Per-IP failed-authentication tracking and bans.
    pub auth_rate_limiter: AuthRateLimiter,
    /// JWT issuer/validator; `None` when `config.jwt_secret` is `None`.
    pub token_manager: Option<TokenManager>,
}
339
340impl RelayState {
341 pub fn new(config: RelayConfig) -> Self {
343 let auth_rate_limiter = AuthRateLimiter::new(config.auth_rate_limit.clone());
344 let token_manager = config.jwt_secret.as_ref().map(|secret| {
345 TokenManager::new(secret, config.jwt_validity)
346 });
347
348 Self {
349 tunnels: DashMap::new(),
350 tunnels_per_ip: DashMap::new(),
351 auth_rate_limiter,
352 token_manager,
353 config,
354 }
355 }
356
357 pub fn validate_token(&self, token: Option<&str>) -> bool {
359 if !self.config.require_auth {
360 return true; }
362
363 match token {
364 Some(t) if !t.is_empty() => {
365 if self.config.auth_tokens.contains(t) {
367 return true;
368 }
369 if let Some(ref tm) = self.token_manager {
371 return tm.validate_token(t).is_ok();
372 }
373 false
374 },
375 _ => false,
376 }
377 }
378
379 pub fn validate_auth(&self, ip: &str, token: Option<&str>) -> AuthResult {
381 if let Some(remaining) = self.auth_rate_limiter.ban_remaining(ip) {
383 return AuthResult::Banned { remaining };
384 }
385
386 if !self.config.require_auth {
388 return AuthResult::NotRequired;
389 }
390
391 match token {
392 Some(t) if !t.is_empty() => {
393 if self.config.auth_tokens.contains(t) {
395 self.auth_rate_limiter.record_success(ip);
396 return AuthResult::Success;
397 }
398
399 if let Some(ref tm) = self.token_manager {
401 match tm.validate_token(t) {
402 Ok(claims) => {
403 self.auth_rate_limiter.record_success(ip);
404 return AuthResult::SuccessWithClaims(claims);
405 }
406 Err(e) => {
407 self.auth_rate_limiter.record_failure(ip);
408 return AuthResult::Invalid(e.to_string());
409 }
410 }
411 }
412
413 self.auth_rate_limiter.record_failure(ip);
415 AuthResult::Invalid("Invalid token".to_string())
416 }
417 _ => {
418 self.auth_rate_limiter.record_failure(ip);
419 AuthResult::Invalid("Missing token".to_string())
420 }
421 }
422 }
423
424 pub fn generate_token(&self, subject: &str, tunnel_id: Option<String>) -> Option<String> {
426 self.token_manager.as_ref().and_then(|tm| {
427 tm.generate_token(subject, tunnel_id).ok()
428 })
429 }
430
431 pub fn revoke_token(&self, jti: &str) {
433 if let Some(ref tm) = self.token_manager {
434 tm.revoke_token(jti);
435 }
436 }
437
438 pub fn can_create_tunnel(&self, ip: &str) -> bool {
440 let count = self.tunnels_per_ip.get(ip).map(|r| *r).unwrap_or(0);
441 count < self.config.max_tunnels_per_ip
442 }
443
444 fn increment_ip_count(&self, ip: &str) {
446 self.tunnels_per_ip
447 .entry(ip.to_string())
448 .and_modify(|c| *c += 1)
449 .or_insert(1);
450 }
451
452 fn decrement_ip_count(&self, ip: &str) {
454 if let Some(mut count) = self.tunnels_per_ip.get_mut(ip) {
455 if *count > 0 {
456 *count -= 1;
457 }
458 if *count == 0 {
459 drop(count);
460 self.tunnels_per_ip.remove(ip);
461 }
462 }
463 }
464
465 pub fn register_tunnel(&self, tunnel: TunnelConnection) -> Arc<TunnelConnection> {
467 let tunnel_id = tunnel.tunnel_id.clone();
468 let source_ip = tunnel.source_ip.clone();
469 let tunnel = Arc::new(tunnel);
470 self.tunnels.insert(tunnel_id, tunnel.clone());
471 self.increment_ip_count(&source_ip);
472 tunnel
473 }
474
475 pub fn remove_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
477 if let Some((_, tunnel)) = self.tunnels.remove(tunnel_id) {
478 self.decrement_ip_count(&tunnel.source_ip);
479 Some(tunnel)
480 } else {
481 None
482 }
483 }
484
485 pub fn get_tunnel(&self, tunnel_id: &str) -> Option<Arc<TunnelConnection>> {
487 self.tunnels.get(tunnel_id).map(|r| r.clone())
488 }
489
490 pub fn tunnel_url(&self, tunnel_id: &str) -> String {
492 let scheme = if self.config.use_https { "https" } else { "http" };
493 format!("{}://{}.{}", scheme, tunnel_id, self.config.base_domain)
494 }
495
496 pub fn tunnel_count_for_ip(&self, ip: &str) -> usize {
498 self.tunnels_per_ip.get(ip).map(|r| *r).unwrap_or(0)
499 }
500
501 pub async fn is_tunnel_expired(&self, tunnel: &crate::tunnel::TunnelConnection) -> bool {
503 let now = chrono::Utc::now();
504
505 if let Some(max_age) = self.config.max_tunnel_age {
507 let age = now.signed_duration_since(tunnel.created_at);
508 if age.to_std().unwrap_or(Duration::ZERO) >= max_age {
509 return true;
510 }
511 }
512
513 if let Some(max_idle) = self.config.max_idle_time {
515 let last_activity = *tunnel.last_activity.read().await;
516 let idle_time = now.signed_duration_since(last_activity);
517 if idle_time.to_std().unwrap_or(Duration::ZERO) >= max_idle {
518 return true;
519 }
520 }
521
522 false
523 }
524
525 pub async fn cleanup_expired_tunnels(&self) -> usize {
528 let mut expired_ids = Vec::new();
529
530 for entry in self.tunnels.iter() {
532 if self.is_tunnel_expired(entry.value()).await {
533 expired_ids.push(entry.key().clone());
534 }
535 }
536
537 let count = expired_ids.len();
539 for tunnel_id in expired_ids {
540 if let Some(tunnel) = self.remove_tunnel(&tunnel_id) {
541 tracing::info!(
542 tunnel_id = %tunnel_id,
543 source_ip = %tunnel.source_ip,
544 age_secs = (chrono::Utc::now() - tunnel.created_at).num_seconds(),
545 "Tunnel expired and removed"
546 );
547 }
548 }
549
550 count
551 }
552
553 pub fn tunnel_count(&self) -> usize {
555 self.tunnels.len()
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Auth-required config with two static tokens and a limit of 3 tunnels per IP.
    fn create_test_config() -> RelayConfig {
        RelayConfig {
            base_domain: "test.example.com".to_string(),
            listen_addr: "127.0.0.1:3001".parse().unwrap(),
            request_timeout: Duration::from_secs(30),
            max_tunnels_per_ip: 3,
            use_https: false,
            auth_tokens: ["token1".to_string(), "token2".to_string()].into_iter().collect(),
            require_auth: true,
            jwt_secret: None,
            jwt_validity: Duration::from_secs(3600),
            auth_rate_limit: AuthRateLimitConfig::default(),
            max_tunnel_age: None,
            max_idle_time: None,
            allow_custom_ids: true,
        }
    }

    /// Auth-required config that accepts only JWTs (no static tokens).
    fn create_jwt_config() -> RelayConfig {
        RelayConfig {
            jwt_secret: Some(b"test-secret-key-for-jwt-testing".to_vec()),
            jwt_validity: Duration::from_secs(3600),
            require_auth: true,
            auth_tokens: HashSet::new(),
            ..create_test_config()
        }
    }

    fn create_test_tunnel(tunnel_id: &str, source_ip: &str) -> crate::tunnel::TunnelConnection {
        let (tx, _rx) = mpsc::channel(10);
        crate::tunnel::TunnelConnection::new(tunnel_id.to_string(), tx, source_ip.to_string())
    }

    #[test]
    fn test_validate_token_valid() {
        let state = RelayState::new(create_test_config());
        assert!(state.validate_token(Some("token1")));
        assert!(state.validate_token(Some("token2")));
    }

    #[test]
    fn test_validate_token_invalid() {
        let state = RelayState::new(create_test_config());
        assert!(!state.validate_token(Some("invalid-token")));
        assert!(!state.validate_token(Some("")));
        assert!(!state.validate_token(None));
    }

    #[test]
    fn test_validate_token_auth_not_required() {
        let mut config = create_test_config();
        config.require_auth = false;
        let state = RelayState::new(config);

        assert!(state.validate_token(None));
        assert!(state.validate_token(Some("")));
        assert!(state.validate_token(Some("random")));
    }

    #[test]
    fn test_can_create_tunnel_under_limit() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";

        assert!(state.can_create_tunnel(ip));
        assert_eq!(state.tunnel_count_for_ip(ip), 0);
    }

    #[test]
    fn test_rate_limiting_enforced() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";

        let t1 = create_test_tunnel("tunnel1", ip);
        let t2 = create_test_tunnel("tunnel2", ip);
        let t3 = create_test_tunnel("tunnel3", ip);

        state.register_tunnel(t1);
        assert_eq!(state.tunnel_count_for_ip(ip), 1);
        assert!(state.can_create_tunnel(ip));

        state.register_tunnel(t2);
        assert_eq!(state.tunnel_count_for_ip(ip), 2);
        assert!(state.can_create_tunnel(ip));

        state.register_tunnel(t3);
        assert_eq!(state.tunnel_count_for_ip(ip), 3);
        assert!(!state.can_create_tunnel(ip));
    }

    #[test]
    fn test_rate_limiting_per_ip() {
        let state = RelayState::new(create_test_config());
        let ip1 = "192.168.1.1";
        let ip2 = "192.168.1.2";

        for i in 0..3 {
            let t = create_test_tunnel(&format!("tunnel-ip1-{}", i), ip1);
            state.register_tunnel(t);
        }

        // ip1 is at its limit...
        assert!(!state.can_create_tunnel(ip1));

        // ...while ip2 is unaffected.
        assert!(state.can_create_tunnel(ip2));
        assert_eq!(state.tunnel_count_for_ip(ip2), 0);
    }

    #[test]
    fn test_rate_limiting_released_on_disconnect() {
        let state = RelayState::new(create_test_config());
        let ip = "192.168.1.1";

        for i in 0..3 {
            let t = create_test_tunnel(&format!("tunnel{}", i), ip);
            state.register_tunnel(t);
        }
        assert!(!state.can_create_tunnel(ip));

        state.remove_tunnel("tunnel1");
        assert_eq!(state.tunnel_count_for_ip(ip), 2);

        assert!(state.can_create_tunnel(ip));
    }

    #[test]
    fn test_tunnel_url_generation() {
        let state = RelayState::new(create_test_config());

        let url = state.tunnel_url("abc123");
        assert_eq!(url, "http://abc123.test.example.com");
    }

    #[test]
    fn test_tunnel_url_with_https() {
        let mut config = create_test_config();
        config.use_https = true;
        let state = RelayState::new(config);

        let url = state.tunnel_url("xyz789");
        assert_eq!(url, "https://xyz789.test.example.com");
    }

    #[test]
    fn test_auth_rate_limiter_tracks_failures() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
            max_failed_attempts: 3,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        });
        let ip = "10.0.0.1";

        assert_eq!(limiter.failed_attempts(ip), 0);
        assert!(!limiter.is_banned(ip));

        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 1);

        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 2);

        // Still below the threshold of 3.
        assert!(!limiter.is_banned(ip));
    }

    #[test]
    fn test_auth_rate_limiter_bans_after_max_attempts() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig {
            max_failed_attempts: 3,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        });
        let ip = "10.0.0.2";

        for _ in 0..3 {
            limiter.record_failure(ip);
        }

        assert!(limiter.is_banned(ip));
        assert!(limiter.ban_remaining(ip).is_some());
    }

    #[test]
    fn test_auth_rate_limiter_success_clears_failures() {
        let limiter = AuthRateLimiter::new(AuthRateLimitConfig::default());
        let ip = "10.0.0.3";

        limiter.record_failure(ip);
        limiter.record_failure(ip);
        assert_eq!(limiter.failed_attempts(ip), 2);

        limiter.record_success(ip);
        assert_eq!(limiter.failed_attempts(ip), 0);
    }

    #[test]
    fn test_validate_auth_with_rate_limiting() {
        let mut config = create_test_config();
        config.auth_rate_limit = AuthRateLimitConfig {
            max_failed_attempts: 2,
            ban_duration: Duration::from_secs(60),
            attempt_window: Duration::from_secs(30),
        };
        let state = RelayState::new(config);
        let ip = "10.0.0.4";

        assert!(state.validate_auth(ip, Some("token1")).is_success());

        assert!(!state.validate_auth(ip, Some("bad")).is_success());
        assert!(!state.validate_auth(ip, Some("bad")).is_success());

        // Once banned, even a valid token is rejected.
        let result = state.validate_auth(ip, Some("token1"));
        assert!(matches!(result, AuthResult::Banned { .. }));
    }

    #[test]
    fn test_jwt_token_generation_and_validation() {
        let config = create_jwt_config();
        let state = RelayState::new(config);

        let token = state.generate_token("user123", None);
        assert!(token.is_some());

        let token = token.unwrap();
        assert!(!token.is_empty());

        assert!(state.validate_token(Some(&token)));
    }

    #[test]
    fn test_jwt_token_with_tunnel_id() {
        let config = create_jwt_config();
        let state = RelayState::new(config);

        let token = state.generate_token("user456", Some("my-tunnel".to_string()));
        assert!(token.is_some());

        let token = token.unwrap();
        let result = state.validate_auth("10.0.0.5", Some(&token));

        match result {
            AuthResult::SuccessWithClaims(claims) => {
                assert_eq!(claims.sub, "user456");
                assert_eq!(claims.tunnel_id, Some("my-tunnel".to_string()));
            }
            _ => panic!("Expected SuccessWithClaims, got {:?}", result),
        }
    }

    #[test]
    fn test_jwt_token_revocation() {
        let config = create_jwt_config();
        let state = RelayState::new(config);

        let token = state.generate_token("user789", None).unwrap();
        let result = state.validate_auth("10.0.0.6", Some(&token));

        let jti = match result {
            AuthResult::SuccessWithClaims(claims) => claims.jti,
            _ => panic!("Expected success"),
        };

        state.revoke_token(&jti);

        // The same token must now be rejected, specifically because it was revoked.
        match state.validate_auth("10.0.0.6", Some(&token)) {
            AuthResult::Invalid(reason) => assert_eq!(reason, TokenError::Revoked.to_string()),
            other => panic!("Expected Invalid, got {:?}", other),
        }
    }

    #[test]
    fn test_jwt_invalid_signature() {
        let config = create_jwt_config();
        let state = RelayState::new(config);

        let tampered = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyIiwiZXhwIjo5OTk5OTk5OTk5LCJpYXQiOjE3MDAwMDAwMDAsImp0aSI6InRlc3QifQ.invalid_signature";

        let result = state.validate_auth("10.0.0.7", Some(tampered));
        assert!(!result.is_success());
    }

    #[test]
    fn test_token_manager_rejects_revoked_jti() {
        let tm = TokenManager::new(b"secret", Duration::from_secs(3600));

        let token = tm.generate_token("user", None).unwrap();
        let claims = tm.validate_token(&token).unwrap();

        tm.revoke_token(&claims.jti);
        assert!(matches!(tm.validate_token(&token), Err(TokenError::Revoked)));
    }

    #[test]
    fn test_token_manager_cleanup_revocations() {
        let tm = TokenManager::new(b"secret", Duration::from_secs(3600));

        tm.revoke_token("token1");
        tm.revoke_token("token2");
        tm.revoke_token("token3");
        assert_eq!(tm.revoked_tokens.len(), 3);

        // Revocations younger than `max_age` must survive a cleanup pass.
        tm.cleanup_revocations(Duration::from_secs(24 * 3600));
        assert_eq!(tm.revoked_tokens.len(), 3);
    }

    #[test]
    fn test_auth_not_required_returns_not_required() {
        let mut config = create_test_config();
        config.require_auth = false;
        let state = RelayState::new(config);

        let result = state.validate_auth("10.0.0.8", None);
        assert!(matches!(result, AuthResult::NotRequired));
    }

    #[tokio::test]
    async fn test_tunnel_not_expired_without_limits() {
        let config = create_test_config();
        let state = RelayState::new(config);
        let tunnel = create_test_tunnel("test1", "192.168.1.1");
        let tunnel = state.register_tunnel(tunnel);

        assert!(!state.is_tunnel_expired(&tunnel).await);
    }

    #[tokio::test]
    async fn test_tunnel_expired_by_age() {
        let mut config = create_test_config();
        config.max_tunnel_age = Some(Duration::from_millis(50));
        let state = RelayState::new(config);

        let tunnel = create_test_tunnel("test-ttl", "192.168.1.1");
        let tunnel = state.register_tunnel(tunnel);

        assert!(!state.is_tunnel_expired(&tunnel).await);

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(state.is_tunnel_expired(&tunnel).await);
    }

    #[tokio::test]
    async fn test_cleanup_expired_tunnels() {
        let mut config = create_test_config();
        config.max_tunnel_age = Some(Duration::from_millis(50));
        let state = RelayState::new(config);

        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
        state.register_tunnel(t1);
        state.register_tunnel(t2);

        assert_eq!(state.tunnel_count(), 2);

        tokio::time::sleep(Duration::from_millis(100)).await;

        let removed = state.cleanup_expired_tunnels().await;
        assert_eq!(removed, 2);
        assert_eq!(state.tunnel_count(), 0);
    }

    #[test]
    fn test_allow_custom_ids_config() {
        let mut config = create_test_config();
        assert!(config.allow_custom_ids);
        config.allow_custom_ids = false;
        let state = RelayState::new(config);
        assert!(!state.config.allow_custom_ids);
    }

    #[test]
    fn test_tunnel_count() {
        let config = create_test_config();
        let state = RelayState::new(config);

        assert_eq!(state.tunnel_count(), 0);

        let t1 = create_test_tunnel("tunnel1", "192.168.1.1");
        state.register_tunnel(t1);
        assert_eq!(state.tunnel_count(), 1);

        let t2 = create_test_tunnel("tunnel2", "192.168.1.2");
        state.register_tunnel(t2);
        assert_eq!(state.tunnel_count(), 2);

        state.remove_tunnel("tunnel1");
        assert_eq!(state.tunnel_count(), 1);
    }

    #[test]
    fn test_public_relay_config() {
        let config = RelayConfig {
            require_auth: false,
            allow_custom_ids: false,
            max_tunnel_age: Some(Duration::from_secs(8 * 3600)),
            max_idle_time: Some(Duration::from_secs(1800)),
            max_tunnels_per_ip: 3,
            ..Default::default()
        };

        let state = RelayState::new(config);

        assert!(state.validate_auth("10.0.0.1", None).is_success());

        assert!(!state.config.allow_custom_ids);
        assert!(state.config.max_tunnel_age.is_some());
        assert!(state.config.max_idle_time.is_some());
    }
}