1use std::any::Any;
2use std::collections::{HashMap, HashSet};
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use tokio_tungstenite::tungstenite::http::Request;
9
10pub use hyperstack_auth::AuthContext;
12pub use hyperstack_auth::AuthErrorCode;
14pub use hyperstack_auth::RetryPolicy;
16pub use hyperstack_auth::{
18 auth_failure_event, auth_success_event, rate_limit_event, AuditEvent, AuditSeverity,
19 ChannelAuditLogger, NoOpAuditLogger, SecurityAuditEvent, SecurityAuditLogger,
20};
21pub use hyperstack_auth::{AuthMetrics, AuthMetricsCollector, AuthMetricsSnapshot};
23pub use hyperstack_auth::{MultiKeyVerifier, MultiKeyVerifierBuilder, RotationKey};
25
/// Transport-level details of an incoming WebSocket upgrade request, captured so
/// that a [`WebSocketAuthPlugin`] can decide whether to accept the connection.
///
/// `Debug` is implemented by hand so that credentials never reach logs. It redacts
/// the values of the `authorization`, `cookie` and `proxy-authorization` headers
/// (case-insensitive) and the whole query string, which may carry a session token.
#[derive(Clone)]
pub struct ConnectionAuthRequest {
    /// Peer address of the underlying connection.
    pub remote_addr: SocketAddr,
    /// Request path, without the query string.
    pub path: String,
    /// Raw (not percent-decoded) query string, if the URI had one.
    pub query: Option<String>,
    /// Request headers; `from_http_request` stores names lowercased.
    pub headers: HashMap<String, String>,
    /// Value of the `Origin` header, if present.
    pub origin: Option<String>,
}

impl std::fmt::Debug for ConnectionAuthRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values that commonly carry credentials.
        const SENSITIVE_HEADERS: [&str; 3] = ["authorization", "cookie", "proxy-authorization"];
        const REDACTED: &str = "<redacted>";

        // BTreeMap gives a stable, sorted rendering regardless of HashMap order.
        let headers: std::collections::BTreeMap<&str, &str> = self
            .headers
            .iter()
            .map(|(name, value)| {
                let sensitive = SENSITIVE_HEADERS
                    .iter()
                    .any(|candidate| candidate.eq_ignore_ascii_case(name));
                (name.as_str(), if sensitive { REDACTED } else { value.as_str() })
            })
            .collect();

        f.debug_struct("ConnectionAuthRequest")
            .field("remote_addr", &self.remote_addr)
            .field("path", &self.path)
            // The query may hold a token (e.g. `hs_token=...`), so hide it entirely.
            .field("query", &self.query.as_ref().map(|_| REDACTED))
            .field("headers", &headers)
            .field("origin", &self.origin)
            .finish()
    }
}
35
36impl ConnectionAuthRequest {
37 pub fn from_http_request<B>(remote_addr: SocketAddr, request: &Request<B>) -> Self {
38 let mut headers = HashMap::new();
39 for (name, value) in request.headers() {
40 if let Ok(value_str) = value.to_str() {
41 headers.insert(name.as_str().to_ascii_lowercase(), value_str.to_string());
42 }
43 }
44
45 let origin = headers.get("origin").cloned();
46
47 Self {
48 remote_addr,
49 path: request.uri().path().to_string(),
50 query: request.uri().query().map(|q| q.to_string()),
51 headers,
52 origin,
53 }
54 }
55
56 pub fn header(&self, name: &str) -> Option<&str> {
57 self.headers
58 .get(&name.to_ascii_lowercase())
59 .map(String::as_str)
60 }
61
62 pub fn bearer_token(&self) -> Option<&str> {
63 let value = self.header("authorization")?;
64 let (scheme, token) = value.split_once(' ')?;
65 if scheme.eq_ignore_ascii_case("bearer") {
66 Some(token)
67 } else {
68 None
69 }
70 }
71
72 pub fn query_param(&self, key: &str) -> Option<&str> {
73 let query = self.query.as_deref()?;
74 query
75 .split('&')
76 .filter_map(|pair| pair.split_once('='))
77 .find_map(|(k, v)| if k == key { Some(v) } else { None })
78 }
79}
80
/// Optional structured context that can accompany an [`AuthDeny`].
///
/// Every field starts out as `None` (via `Default`) and is filled in by the
/// corresponding `AuthDeny::with_*` builder.
#[derive(Debug, Clone, Default)]
pub struct AuthErrorDetails {
    /// Free-form field name, set by `AuthDeny::with_field`.
    pub field: Option<String>,
    /// Extra free-form context, set by `AuthDeny::with_context`.
    pub context: Option<String>,
    /// Recovery hint surfaced as `suggested_action` in the error response.
    pub suggested_action: Option<String>,
    /// Documentation link surfaced as `docs_url` in the error response.
    pub docs_url: Option<String>,
}
93
94#[derive(Debug, Clone)]
96pub struct AuthDeny {
97 pub reason: String,
98 pub code: AuthErrorCode,
99 pub details: AuthErrorDetails,
101 pub retry_policy: RetryPolicy,
103 pub http_status: u16,
105 pub reset_at: Option<std::time::SystemTime>,
107}
108
109impl AuthDeny {
110 pub fn new(code: AuthErrorCode, reason: impl Into<String>) -> Self {
112 Self {
113 reason: reason.into(),
114 code,
115 details: AuthErrorDetails::default(),
116 retry_policy: code.default_retry_policy(),
117 http_status: code.http_status(),
118 reset_at: None,
119 }
120 }
121
122 pub fn token_missing() -> Self {
124 Self::new(
125 AuthErrorCode::TokenMissing,
126 "Missing session token (expected Authorization: Bearer <token> or query token)",
127 )
128 .with_suggested_action(
129 "Provide a valid session token in the Authorization header or as a query parameter",
130 )
131 }
132
133 pub fn from_verify_error(err: hyperstack_auth::VerifyError) -> Self {
135 let code = AuthErrorCode::from(&err);
136 Self::new(code, format!("Token verification failed: {}", err))
137 }
138
139 pub fn with_details(mut self, details: AuthErrorDetails) -> Self {
141 self.details = details;
142 self
143 }
144
145 pub fn with_field(mut self, field: impl Into<String>) -> Self {
147 self.details.field = Some(field.into());
148 self
149 }
150
151 pub fn with_context(mut self, context: impl Into<String>) -> Self {
153 self.details.context = Some(context.into());
154 self
155 }
156
157 pub fn with_suggested_action(mut self, action: impl Into<String>) -> Self {
159 self.details.suggested_action = Some(action.into());
160 self
161 }
162
163 pub fn with_docs_url(mut self, url: impl Into<String>) -> Self {
165 self.details.docs_url = Some(url.into());
166 self
167 }
168
169 pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
171 self.retry_policy = policy;
172 self
173 }
174
175 pub fn with_reset_at(mut self, reset_at: std::time::SystemTime) -> Self {
177 self.reset_at = Some(reset_at);
178 self
179 }
180
181 pub fn rate_limited(retry_after: Duration, limit_type: &str) -> Self {
183 let reset_at = std::time::SystemTime::now() + retry_after;
184 Self::new(
185 AuthErrorCode::RateLimitExceeded,
186 format!(
187 "Rate limit exceeded for {}. Please retry after {:?}.",
188 limit_type, retry_after
189 ),
190 )
191 .with_retry_policy(RetryPolicy::RetryAfter(retry_after))
192 .with_reset_at(reset_at)
193 .with_suggested_action(format!(
194 "Wait {:?} before retrying the request",
195 retry_after
196 ))
197 }
198
199 pub fn connection_limit_exceeded(limit_type: &str, current: usize, max: usize) -> Self {
201 Self::new(
202 AuthErrorCode::ConnectionLimitExceeded,
203 format!(
204 "Connection limit exceeded: {} has {} of {} allowed connections",
205 limit_type, current, max
206 ),
207 )
208 .with_suggested_action(
209 "Disconnect existing connections or wait for other connections to close",
210 )
211 }
212
213 pub fn to_error_response(&self) -> ErrorResponse {
215 ErrorResponse {
216 error: self.code.as_str().to_string(),
217 message: self.reason.clone(),
218 code: self.code.to_string(),
219 retryable: matches!(
220 self.retry_policy,
221 RetryPolicy::RetryImmediately
222 | RetryPolicy::RetryAfter(_)
223 | RetryPolicy::RetryWithBackoff { .. }
224 | RetryPolicy::RetryWithFreshToken
225 ),
226 retry_after: match self.retry_policy {
227 RetryPolicy::RetryAfter(d) => Some(d.as_secs()),
228 _ => None,
229 },
230 suggested_action: self.details.suggested_action.clone(),
231 docs_url: self.details.docs_url.clone(),
232 }
233 }
234}
235
236#[derive(Debug, Clone, serde::Serialize)]
238pub struct ErrorResponse {
239 pub error: String,
240 pub message: String,
241 pub code: String,
242 pub retryable: bool,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 pub retry_after: Option<u64>,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 pub suggested_action: Option<String>,
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub docs_url: Option<String>,
249}
250
251#[derive(Debug, Clone)]
253pub enum AuthDecision {
254 Allow(AuthContext),
256 Deny(AuthDeny),
258}
259
260impl AuthDecision {
261 pub fn is_allowed(&self) -> bool {
263 matches!(self, AuthDecision::Allow(_))
264 }
265
266 pub fn auth_context(&self) -> Option<&AuthContext> {
268 match self {
269 AuthDecision::Allow(ctx) => Some(ctx),
270 AuthDecision::Deny(_) => None,
271 }
272 }
273}
274
275#[async_trait]
276pub trait WebSocketAuthPlugin: Send + Sync + Any {
277 async fn authorize(&self, request: &ConnectionAuthRequest) -> AuthDecision;
278
279 fn as_any(&self) -> &dyn Any;
280
281 fn audit_logger(&self) -> Option<&dyn SecurityAuditLogger> {
283 None
284 }
285
286 async fn log_audit(&self, event: SecurityAuditEvent) {
288 if let Some(logger) = self.audit_logger() {
289 logger.log(event).await;
290 }
291 }
292
293 fn auth_metrics(&self) -> Option<&AuthMetrics> {
295 None
296 }
297}
298
/// Plugin that accepts every connection with a fixed anonymous context.
///
/// A stateless unit struct, so the standard derives are free and let it be
/// logged, copied and built via `Default` like other plugins.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAllAuthPlugin;
304
305#[async_trait]
306impl WebSocketAuthPlugin for AllowAllAuthPlugin {
307 async fn authorize(&self, _request: &ConnectionAuthRequest) -> AuthDecision {
308 let context = AuthContext {
310 subject: "anonymous".to_string(),
311 issuer: "allow-all".to_string(),
312 key_class: hyperstack_auth::KeyClass::Secret,
313 metering_key: "dev".to_string(),
314 deployment_id: None,
315 expires_at: u64::MAX, scope: "read write".to_string(),
317 limits: Default::default(),
318 plan: None,
319 origin: None,
320 client_ip: None,
321 jti: uuid::Uuid::new_v4().to_string(),
322 };
323 AuthDecision::Allow(context)
324 }
325
326 fn as_any(&self) -> &dyn Any {
327 self
328 }
329}
330
/// Authenticates connections against a fixed set of shared-secret tokens.
///
/// The token is taken from `Authorization: Bearer <token>` or, failing that,
/// from a query parameter.
#[derive(Clone)]
pub struct StaticTokenAuthPlugin {
    /// Accepted secrets.
    tokens: HashSet<String>,
    /// Query parameter consulted when no bearer header is present.
    query_param_name: String,
}

// The token set holds secrets, so `Debug` reports only how many are configured.
// A derived impl would print them verbatim into any log that formats the plugin.
impl std::fmt::Debug for StaticTokenAuthPlugin {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenAuthPlugin")
            .field("token_count", &self.tokens.len())
            .field("query_param_name", &self.query_param_name)
            .finish()
    }
}
336
337impl StaticTokenAuthPlugin {
338 pub fn new(tokens: impl IntoIterator<Item = String>) -> Self {
339 Self {
340 tokens: tokens.into_iter().collect(),
341 query_param_name: "token".to_string(),
342 }
343 }
344
345 pub fn with_query_param_name(mut self, query_param_name: impl Into<String>) -> Self {
346 self.query_param_name = query_param_name.into();
347 self
348 }
349
350 fn extract_token<'a>(&self, request: &'a ConnectionAuthRequest) -> Option<&'a str> {
351 request
352 .bearer_token()
353 .or_else(|| request.query_param(&self.query_param_name))
354 }
355}
356
357#[async_trait]
358impl WebSocketAuthPlugin for StaticTokenAuthPlugin {
359 async fn authorize(&self, request: &ConnectionAuthRequest) -> AuthDecision {
360 let token = match self.extract_token(request) {
361 Some(token) => token,
362 None => {
363 return AuthDecision::Deny(AuthDeny::token_missing());
364 }
365 };
366
367 if self.tokens.contains(token) {
368 let context = AuthContext {
370 subject: format!("static:{}", &token[..token.len().min(8)]),
371 issuer: "static-token".to_string(),
372 key_class: hyperstack_auth::KeyClass::Secret,
373 metering_key: token.to_string(),
374 deployment_id: None,
375 expires_at: u64::MAX, scope: "read".to_string(),
377 limits: Default::default(),
378 plan: None,
379 origin: request.origin.clone(),
380 client_ip: None,
381 jti: uuid::Uuid::new_v4().to_string(),
382 };
383 AuthDecision::Allow(context)
384 } else {
385 AuthDecision::Deny(AuthDeny::new(
386 AuthErrorCode::InvalidStaticToken,
387 "Invalid auth token",
388 ))
389 }
390 }
391
392 fn as_any(&self) -> &dyn Any {
393 self
394 }
395}
396
397enum SignedSessionVerifier {
404 Static(hyperstack_auth::TokenVerifier),
405 CachedJwks(hyperstack_auth::AsyncVerifier),
406 MultiKey(hyperstack_auth::MultiKeyVerifier),
407}
408
409pub struct SignedSessionAuthPlugin {
410 verifier: SignedSessionVerifier,
411 query_param_name: String,
412 require_origin: bool,
413 audit_logger: Option<Arc<dyn SecurityAuditLogger>>,
414 metrics: Option<Arc<AuthMetrics>>,
415}
416
417impl SignedSessionAuthPlugin {
418 pub fn new(verifier: hyperstack_auth::TokenVerifier) -> Self {
420 Self {
421 verifier: SignedSessionVerifier::Static(verifier),
422 query_param_name: "hs_token".to_string(),
423 require_origin: false,
424 audit_logger: None,
425 metrics: None,
426 }
427 }
428
429 pub fn new_with_async_verifier(verifier: hyperstack_auth::AsyncVerifier) -> Self {
431 Self {
432 verifier: SignedSessionVerifier::CachedJwks(verifier),
433 query_param_name: "hs_token".to_string(),
434 require_origin: false,
435 audit_logger: None,
436 metrics: None,
437 }
438 }
439
440 pub fn new_with_multi_key_verifier(verifier: hyperstack_auth::MultiKeyVerifier) -> Self {
442 Self {
443 verifier: SignedSessionVerifier::MultiKey(verifier),
444 query_param_name: "hs_token".to_string(),
445 require_origin: false,
446 audit_logger: None,
447 metrics: None,
448 }
449 }
450
451 pub fn with_query_param_name(mut self, name: impl Into<String>) -> Self {
453 self.query_param_name = name.into();
454 self
455 }
456
457 pub fn with_origin_validation(mut self) -> Self {
459 self.require_origin = true;
460 self
461 }
462
463 pub fn with_audit_logger(mut self, logger: Arc<dyn SecurityAuditLogger>) -> Self {
465 self.audit_logger = Some(logger);
466 self
467 }
468
469 pub fn with_metrics(mut self, metrics: Arc<AuthMetrics>) -> Self {
471 self.metrics = Some(metrics);
472 self
473 }
474
475 pub fn metrics_snapshot(&self) -> Option<AuthMetricsSnapshot> {
477 self.metrics.as_ref().map(|m| m.snapshot())
478 }
479
480 fn extract_token<'a>(&self, request: &'a ConnectionAuthRequest) -> Option<&'a str> {
481 request
482 .bearer_token()
483 .or_else(|| request.query_param(&self.query_param_name))
484 }
485
486 pub async fn verify_refresh_token(&self, token: &str) -> Result<AuthContext, AuthDeny> {
492 let result = match &self.verifier {
493 SignedSessionVerifier::Static(verifier) => verifier.verify(token, None, None),
494 SignedSessionVerifier::CachedJwks(verifier) => {
495 verifier.verify_with_cache(token, None, None).await
496 }
497 SignedSessionVerifier::MultiKey(verifier) => verifier.verify(token, None, None).await,
498 };
499
500 match result {
501 Ok(context) => Ok(context),
502 Err(e) => Err(AuthDeny::from_verify_error(e)),
503 }
504 }
505}
506
507#[async_trait]
508impl WebSocketAuthPlugin for SignedSessionAuthPlugin {
509 async fn authorize(&self, request: &ConnectionAuthRequest) -> AuthDecision {
510 let token = match self.extract_token(request) {
511 Some(token) => token,
512 None => {
513 return AuthDecision::Deny(AuthDeny::token_missing());
514 }
515 };
516
517 let expected_origin = request.origin.as_deref();
518
519 let expected_client_ip = None; let result = match &self.verifier {
522 SignedSessionVerifier::Static(verifier) => {
523 verifier.verify(token, expected_origin, expected_client_ip)
524 }
525 SignedSessionVerifier::CachedJwks(verifier) => {
526 verifier
527 .verify_with_cache(token, expected_origin, expected_client_ip)
528 .await
529 }
530 SignedSessionVerifier::MultiKey(verifier) => {
531 verifier
532 .verify(token, expected_origin, expected_client_ip)
533 .await
534 }
535 };
536
537 match result {
538 Ok(context) => {
539 let event = auth_success_event(&context.subject)
541 .with_client_ip(request.remote_addr)
542 .with_path(&request.path);
543 if let Some(origin) = &request.origin {
544 let event = event.with_origin(origin.clone());
545 self.log_audit(event).await;
546 } else {
547 self.log_audit(event).await;
548 }
549 AuthDecision::Allow(context)
550 }
551 Err(e) => {
552 let deny = AuthDeny::from_verify_error(e);
553 let event = auth_failure_event(&deny.code, &deny.reason)
555 .with_client_ip(request.remote_addr)
556 .with_path(&request.path);
557 let event = if let Some(origin) = &request.origin {
558 event.with_origin(origin.clone())
559 } else {
560 event
561 };
562 self.log_audit(event).await;
563 AuthDecision::Deny(deny)
564 }
565 }
566 }
567
568 fn as_any(&self) -> &dyn Any {
569 self
570 }
571
572 fn audit_logger(&self) -> Option<&dyn SecurityAuditLogger> {
573 self.audit_logger.as_ref().map(|l| l.as_ref())
574 }
575
576 fn auth_metrics(&self) -> Option<&AuthMetrics> {
577 self.metrics.as_ref().map(|m| m.as_ref())
578 }
579}
580
581#[cfg(test)]
582mod tests {
583 use super::*;
584
585 #[test]
586 fn extracts_bearer_and_query_tokens() {
587 let request = Request::builder()
588 .uri("/ws?token=query-token")
589 .header("Authorization", "Bearer header-token")
590 .body(())
591 .expect("request should build");
592
593 let auth_request = ConnectionAuthRequest::from_http_request(
594 "127.0.0.1:8877".parse().expect("socket addr should parse"),
595 &request,
596 );
597
598 assert_eq!(auth_request.bearer_token(), Some("header-token"));
599 assert_eq!(auth_request.query_param("token"), Some("query-token"));
600 }
601
602 #[tokio::test]
603 async fn static_token_plugin_allows_matching_token() {
604 let plugin = StaticTokenAuthPlugin::new(["secret".to_string()]);
605 let request = Request::builder()
606 .uri("/ws?token=secret")
607 .body(())
608 .expect("request should build");
609 let auth_request = ConnectionAuthRequest::from_http_request(
610 "127.0.0.1:8877".parse().expect("socket addr should parse"),
611 &request,
612 );
613
614 let decision = plugin.authorize(&auth_request).await;
615 assert!(decision.is_allowed());
616 assert!(decision.auth_context().is_some());
617 }
618
619 #[tokio::test]
620 async fn static_token_plugin_denies_missing_token() {
621 let plugin = StaticTokenAuthPlugin::new(["secret".to_string()]);
622 let request = Request::builder()
623 .uri("/ws")
624 .body(())
625 .expect("request should build");
626 let auth_request = ConnectionAuthRequest::from_http_request(
627 "127.0.0.1:8877".parse().expect("socket addr should parse"),
628 &request,
629 );
630
631 let decision = plugin.authorize(&auth_request).await;
632 assert!(!decision.is_allowed());
633 }
634
635 #[tokio::test]
636 async fn allow_all_plugin_allows_with_context() {
637 let plugin = AllowAllAuthPlugin;
638 let request = Request::builder()
639 .uri("/ws")
640 .body(())
641 .expect("request should build");
642 let auth_request = ConnectionAuthRequest::from_http_request(
643 "127.0.0.1:8877".parse().expect("socket addr should parse"),
644 &request,
645 );
646
647 let decision = plugin.authorize(&auth_request).await;
648 assert!(decision.is_allowed());
649 let ctx = decision.auth_context().unwrap();
650 assert_eq!(ctx.subject, "anonymous");
651 }
652
653 #[tokio::test]
656 async fn signed_session_plugin_denies_missing_token() {
657 use hyperstack_auth::TokenSigner;
658
659 let signing_key = hyperstack_auth::SigningKey::generate();
660 let verifying_key = signing_key.verifying_key();
661 let verifier =
662 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
663 let plugin = SignedSessionAuthPlugin::new(verifier);
664
665 let request = Request::builder()
666 .uri("/ws")
667 .body(())
668 .expect("request should build");
669 let auth_request = ConnectionAuthRequest::from_http_request(
670 "127.0.0.1:8877".parse().expect("socket addr should parse"),
671 &request,
672 );
673
674 let decision = plugin.authorize(&auth_request).await;
675 assert!(!decision.is_allowed());
676
677 if let AuthDecision::Deny(deny) = decision {
678 assert_eq!(deny.code, AuthErrorCode::TokenMissing);
679 } else {
680 panic!("Expected Deny decision");
681 }
682 }
683
684 #[tokio::test]
685 async fn signed_session_plugin_denies_expired_token() {
686 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
687 use std::time::{SystemTime, UNIX_EPOCH};
688
689 let signing_key = hyperstack_auth::SigningKey::generate();
690 let verifying_key = signing_key.verifying_key();
691 let signer = TokenSigner::new(signing_key, "test-issuer");
692 let verifier =
693 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
694 let plugin = SignedSessionAuthPlugin::new(verifier);
695
696 let now = SystemTime::now()
698 .duration_since(UNIX_EPOCH)
699 .unwrap()
700 .as_secs();
701 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
702 .with_scope("read")
703 .with_key_class(KeyClass::Secret)
704 .build();
705
706 let mut expired_claims = claims;
708 expired_claims.exp = now - 3600; expired_claims.iat = now - 7200; expired_claims.nbf = now - 7200;
711
712 let token = signer.sign(expired_claims).unwrap();
713
714 let request = Request::builder()
715 .uri(format!("/ws?hs_token={}", token))
716 .body(())
717 .expect("request should build");
718 let auth_request = ConnectionAuthRequest::from_http_request(
719 "127.0.0.1:8877".parse().expect("socket addr should parse"),
720 &request,
721 );
722
723 let decision = plugin.authorize(&auth_request).await;
724 assert!(!decision.is_allowed());
725
726 if let AuthDecision::Deny(deny) = decision {
727 assert_eq!(deny.code, AuthErrorCode::TokenExpired);
728 } else {
729 panic!("Expected Deny decision for expired token");
730 }
731 }
732
733 #[tokio::test]
734 async fn signed_session_plugin_denies_invalid_signature() {
735 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
736
737 let signing_key = hyperstack_auth::SigningKey::generate();
739 let wrong_key = hyperstack_auth::SigningKey::generate();
740
741 let signer = TokenSigner::new(signing_key, "test-issuer");
743 let wrong_verifying_key = wrong_key.verifying_key();
744 let verifier = hyperstack_auth::TokenVerifier::new(
745 wrong_verifying_key,
746 "test-issuer",
747 "test-audience",
748 );
749 let plugin = SignedSessionAuthPlugin::new(verifier);
750
751 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
752 .with_scope("read")
753 .with_key_class(KeyClass::Secret)
754 .build();
755
756 let token = signer.sign(claims).unwrap();
757
758 let request = Request::builder()
759 .uri(format!("/ws?hs_token={}", token))
760 .body(())
761 .expect("request should build");
762 let auth_request = ConnectionAuthRequest::from_http_request(
763 "127.0.0.1:8877".parse().expect("socket addr should parse"),
764 &request,
765 );
766
767 let decision = plugin.authorize(&auth_request).await;
768 assert!(!decision.is_allowed());
769
770 if let AuthDecision::Deny(deny) = decision {
771 assert_eq!(deny.code, AuthErrorCode::TokenInvalidSignature);
772 } else {
773 panic!("Expected Deny decision for invalid signature");
774 }
775 }
776
777 #[tokio::test]
778 async fn signed_session_plugin_denies_wrong_audience() {
779 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
780
781 let signing_key = hyperstack_auth::SigningKey::generate();
782 let verifying_key = signing_key.verifying_key();
783 let signer = TokenSigner::new(signing_key, "test-issuer");
784
785 let verifier =
787 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
788 let plugin = SignedSessionAuthPlugin::new(verifier);
789
790 let claims = SessionClaims::builder("test-issuer", "test-subject", "wrong-audience")
791 .with_scope("read")
792 .with_key_class(KeyClass::Secret)
793 .build();
794
795 let token = signer.sign(claims).unwrap();
796
797 let request = Request::builder()
798 .uri(format!("/ws?hs_token={}", token))
799 .body(())
800 .expect("request should build");
801 let auth_request = ConnectionAuthRequest::from_http_request(
802 "127.0.0.1:8877".parse().expect("socket addr should parse"),
803 &request,
804 );
805
806 let decision = plugin.authorize(&auth_request).await;
807 assert!(!decision.is_allowed());
808
809 if let AuthDecision::Deny(deny) = decision {
810 assert_eq!(deny.code, AuthErrorCode::TokenInvalidAudience);
811 } else {
812 panic!("Expected Deny decision for wrong audience");
813 }
814 }
815
816 #[tokio::test]
817 async fn signed_session_plugin_denies_origin_mismatch() {
818 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
819
820 let signing_key = hyperstack_auth::SigningKey::generate();
821 let verifying_key = signing_key.verifying_key();
822 let signer = TokenSigner::new(signing_key, "test-issuer");
823
824 let verifier =
826 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
827 .with_origin_validation();
828 let plugin = SignedSessionAuthPlugin::new(verifier).with_origin_validation();
829
830 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
832 .with_scope("read")
833 .with_key_class(KeyClass::Secret)
834 .with_origin("https://allowed.example.com")
835 .build();
836
837 let token = signer.sign(claims).unwrap();
838
839 let request = Request::builder()
841 .uri(format!("/ws?hs_token={}", token))
842 .header("Origin", "https://evil.example.com")
843 .body(())
844 .expect("request should build");
845 let auth_request = ConnectionAuthRequest::from_http_request(
846 "127.0.0.1:8877".parse().expect("socket addr should parse"),
847 &request,
848 );
849
850 let decision = plugin.authorize(&auth_request).await;
851 assert!(!decision.is_allowed());
852
853 if let AuthDecision::Deny(deny) = decision {
854 assert_eq!(deny.code, AuthErrorCode::OriginMismatch);
855 } else {
856 panic!("Expected Deny decision for origin mismatch");
857 }
858 }
859
860 #[tokio::test]
861 async fn signed_session_plugin_allows_valid_token() {
862 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
863
864 let signing_key = hyperstack_auth::SigningKey::generate();
865 let verifying_key = signing_key.verifying_key();
866 let signer = TokenSigner::new(signing_key, "test-issuer");
867 let verifier =
868 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
869 let plugin = SignedSessionAuthPlugin::new(verifier);
870
871 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
872 .with_scope("read")
873 .with_key_class(KeyClass::Secret)
874 .with_metering_key("meter-123")
875 .build();
876
877 let token = signer.sign(claims).unwrap();
878
879 let request = Request::builder()
880 .uri(format!("/ws?hs_token={}", token))
881 .body(())
882 .expect("request should build");
883 let auth_request = ConnectionAuthRequest::from_http_request(
884 "127.0.0.1:8877".parse().expect("socket addr should parse"),
885 &request,
886 );
887
888 let decision = plugin.authorize(&auth_request).await;
889 assert!(decision.is_allowed());
890
891 if let AuthDecision::Allow(ctx) = decision {
892 assert_eq!(ctx.subject, "test-subject");
893 assert_eq!(ctx.metering_key, "meter-123");
894 assert_eq!(ctx.key_class, KeyClass::Secret);
895 } else {
896 panic!("Expected Allow decision");
897 }
898 }
899
900 #[tokio::test]
901 async fn signed_session_plugin_allows_with_matching_origin() {
902 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
903
904 let signing_key = hyperstack_auth::SigningKey::generate();
905 let verifying_key = signing_key.verifying_key();
906 let signer = TokenSigner::new(signing_key, "test-issuer");
907
908 let verifier =
909 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
910 .with_origin_validation();
911 let plugin = SignedSessionAuthPlugin::new(verifier).with_origin_validation();
912
913 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
914 .with_scope("read")
915 .with_key_class(KeyClass::Secret)
916 .with_origin("https://trusted.example.com")
917 .build();
918
919 let token = signer.sign(claims).unwrap();
920
921 let request = Request::builder()
922 .uri(format!("/ws?hs_token={}", token))
923 .header("Origin", "https://trusted.example.com")
924 .body(())
925 .expect("request should build");
926 let auth_request = ConnectionAuthRequest::from_http_request(
927 "127.0.0.1:8877".parse().expect("socket addr should parse"),
928 &request,
929 );
930
931 let decision = plugin.authorize(&auth_request).await;
932 assert!(decision.is_allowed());
933
934 if let AuthDecision::Allow(ctx) = decision {
935 assert_eq!(ctx.origin, Some("https://trusted.example.com".to_string()));
936 } else {
937 panic!("Expected Allow decision");
938 }
939 }
940
941 #[tokio::test]
942 async fn signed_session_plugin_allows_token_with_origin_when_no_origin_provided_and_not_required(
943 ) {
944 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
949
950 let signing_key = hyperstack_auth::SigningKey::generate();
951 let verifying_key = signing_key.verifying_key();
952 let signer = TokenSigner::new(signing_key, "test-issuer");
953
954 let verifier =
956 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
957 let plugin = SignedSessionAuthPlugin::new(verifier);
958
959 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
960 .with_scope("read")
961 .with_key_class(KeyClass::Publishable)
962 .with_origin("https://example.com") .build();
964
965 let token = signer.sign(claims).unwrap();
966
967 let request = Request::builder()
969 .uri(format!("/ws?hs_token={}", token))
970 .body(())
971 .expect("request should build");
972 let auth_request = ConnectionAuthRequest::from_http_request(
973 "127.0.0.1:8877".parse().expect("socket addr should parse"),
974 &request,
975 );
976
977 let decision = plugin.authorize(&auth_request).await;
979 assert!(decision.is_allowed(), "Expected Allow decision for non-browser client without Origin");
980
981 if let AuthDecision::Allow(ctx) = decision {
982 assert_eq!(ctx.origin, Some("https://example.com".to_string()));
983 } else {
984 panic!("Expected Allow decision");
985 }
986 }
987
988 #[tokio::test]
989 async fn signed_session_plugin_validates_origin_when_provided_even_when_not_required() {
990 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
993
994 let signing_key = hyperstack_auth::SigningKey::generate();
995 let verifying_key = signing_key.verifying_key();
996 let signer = TokenSigner::new(signing_key, "test-issuer");
997
998 let verifier =
1000 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
1001 let plugin = SignedSessionAuthPlugin::new(verifier);
1002
1003 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1004 .with_scope("read")
1005 .with_key_class(KeyClass::Publishable)
1006 .with_origin("https://allowed.example.com")
1007 .build();
1008
1009 let token = signer.sign(claims).unwrap();
1010
1011 let request = Request::builder()
1013 .uri(format!("/ws?hs_token={}", token))
1014 .header("Origin", "https://allowed.example.com")
1015 .body(())
1016 .expect("request should build");
1017 let auth_request = ConnectionAuthRequest::from_http_request(
1018 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1019 &request,
1020 );
1021
1022 let decision = plugin.authorize(&auth_request).await;
1023 assert!(decision.is_allowed());
1024
1025 let request = Request::builder()
1027 .uri(format!("/ws?hs_token={}", token))
1028 .header("Origin", "https://evil.example.com")
1029 .body(())
1030 .expect("request should build");
1031 let auth_request = ConnectionAuthRequest::from_http_request(
1032 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1033 &request,
1034 );
1035
1036 let decision = plugin.authorize(&auth_request).await;
1037 assert!(!decision.is_allowed());
1038
1039 if let AuthDecision::Deny(deny) = decision {
1040 assert_eq!(deny.code, AuthErrorCode::OriginMismatch);
1041 } else {
1042 panic!("Expected Deny decision for origin mismatch");
1043 }
1044 }
1045
1046 #[test]
1048 fn auth_error_code_should_retry_logic() {
1049 assert!(AuthErrorCode::RateLimitExceeded.should_retry());
1050 assert!(AuthErrorCode::InternalError.should_retry());
1051 assert!(!AuthErrorCode::TokenExpired.should_retry());
1052 assert!(!AuthErrorCode::TokenInvalidSignature.should_retry());
1053 assert!(!AuthErrorCode::TokenMissing.should_retry());
1054 }
1055
1056 #[test]
1057 fn auth_error_code_should_refresh_token_logic() {
1058 assert!(AuthErrorCode::TokenExpired.should_refresh_token());
1059 assert!(AuthErrorCode::TokenInvalidSignature.should_refresh_token());
1060 assert!(AuthErrorCode::TokenInvalidFormat.should_refresh_token());
1061 assert!(AuthErrorCode::TokenInvalidIssuer.should_refresh_token());
1062 assert!(AuthErrorCode::TokenInvalidAudience.should_refresh_token());
1063 assert!(AuthErrorCode::TokenKeyNotFound.should_refresh_token());
1064 assert!(!AuthErrorCode::TokenMissing.should_refresh_token());
1065 assert!(!AuthErrorCode::RateLimitExceeded.should_refresh_token());
1066 assert!(!AuthErrorCode::ConnectionLimitExceeded.should_refresh_token());
1067 }
1068
1069 #[test]
1070 fn auth_error_code_string_representation() {
1071 assert_eq!(AuthErrorCode::TokenMissing.as_str(), "token-missing");
1072 assert_eq!(AuthErrorCode::TokenExpired.as_str(), "token-expired");
1073 assert_eq!(
1074 AuthErrorCode::TokenInvalidSignature.as_str(),
1075 "token-invalid-signature"
1076 );
1077 assert_eq!(
1078 AuthErrorCode::RateLimitExceeded.as_str(),
1079 "rate-limit-exceeded"
1080 );
1081 assert_eq!(
1082 AuthErrorCode::ConnectionLimitExceeded.as_str(),
1083 "connection-limit-exceeded"
1084 );
1085 }
1086
1087 #[test]
1089 fn auth_deny_token_missing_factory() {
1090 let deny = AuthDeny::token_missing();
1091 assert_eq!(deny.code, AuthErrorCode::TokenMissing);
1092 assert!(deny.reason.contains("Missing session token"));
1093 }
1094
1095 #[test]
1096 fn auth_deny_from_verify_error_mapping() {
1097 use hyperstack_auth::VerifyError;
1098
1099 let test_cases = vec![
1100 (VerifyError::Expired, AuthErrorCode::TokenExpired),
1101 (
1102 VerifyError::InvalidSignature,
1103 AuthErrorCode::TokenInvalidSignature,
1104 ),
1105 (
1106 VerifyError::InvalidIssuer,
1107 AuthErrorCode::TokenInvalidIssuer,
1108 ),
1109 (
1110 VerifyError::InvalidAudience,
1111 AuthErrorCode::TokenInvalidAudience,
1112 ),
1113 (
1114 VerifyError::KeyNotFound("kid123".to_string()),
1115 AuthErrorCode::TokenKeyNotFound,
1116 ),
1117 (
1118 VerifyError::OriginMismatch {
1119 expected: "a".to_string(),
1120 actual: "b".to_string(),
1121 },
1122 AuthErrorCode::OriginMismatch,
1123 ),
1124 ];
1125
1126 for (err, expected_code) in test_cases {
1127 let deny = AuthDeny::from_verify_error(err);
1128 assert_eq!(deny.code, expected_code);
1129 }
1130 }
1131
1132 #[tokio::test]
1134 async fn signed_session_plugin_handles_multiple_failure_reasons() {
1135 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
1136
1137 let signing_key = hyperstack_auth::SigningKey::generate();
1138 let verifying_key = signing_key.verifying_key();
1139 let signer = TokenSigner::new(signing_key, "test-issuer");
1140 let verifier =
1141 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
1142 .with_origin_validation();
1143 let plugin = SignedSessionAuthPlugin::new(verifier).with_origin_validation();
1144
1145 let request = Request::builder()
1147 .uri("/ws")
1148 .body(())
1149 .expect("request should build");
1150 let auth_request = ConnectionAuthRequest::from_http_request(
1151 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1152 &request,
1153 );
1154 let decision = plugin.authorize(&auth_request).await;
1155 assert!(!decision.is_allowed());
1156 match decision {
1157 AuthDecision::Deny(deny) => assert_eq!(deny.code, AuthErrorCode::TokenMissing),
1158 _ => panic!("Expected Deny decision"),
1159 }
1160
1161 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1163 .with_scope("read")
1164 .with_key_class(KeyClass::Secret)
1165 .with_origin("https://allowed.example.com")
1166 .build();
1167 let token = signer.sign(claims).unwrap();
1168
1169 let request = Request::builder()
1170 .uri(format!("/ws?hs_token={}", token))
1171 .header("Origin", "https://evil.example.com")
1172 .body(())
1173 .expect("request should build");
1174 let auth_request = ConnectionAuthRequest::from_http_request(
1175 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1176 &request,
1177 );
1178 let decision = plugin.authorize(&auth_request).await;
1179 assert!(!decision.is_allowed());
1180 match decision {
1181 AuthDecision::Deny(deny) => assert_eq!(deny.code, AuthErrorCode::OriginMismatch),
1182 _ => panic!("Expected Deny decision for origin mismatch"),
1183 }
1184
1185 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1187 .with_scope("read")
1188 .with_key_class(KeyClass::Secret)
1189 .with_origin("https://allowed.example.com")
1190 .build();
1191 let token = signer.sign(claims).unwrap();
1192
1193 let request = Request::builder()
1194 .uri(format!("/ws?hs_token={}", token))
1195 .header("Origin", "https://allowed.example.com")
1196 .body(())
1197 .expect("request should build");
1198 let auth_request = ConnectionAuthRequest::from_http_request(
1199 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1200 &request,
1201 );
1202 let decision = plugin.authorize(&auth_request).await;
1203 assert!(decision.is_allowed());
1204 }
1205
1206 #[tokio::test]
1208 async fn auth_deney_with_rate_limit_code() {
1209 let deny = AuthDeny::new(
1210 AuthErrorCode::RateLimitExceeded,
1211 "Too many requests from this IP",
1212 );
1213 assert_eq!(deny.code, AuthErrorCode::RateLimitExceeded);
1214 assert!(deny.code.should_retry());
1215 assert!(!deny.code.should_refresh_token());
1216 }
1217
1218 #[tokio::test]
1220 async fn auth_deny_with_connection_limit_code() {
1221 let deny = AuthDeny::new(
1222 AuthErrorCode::ConnectionLimitExceeded,
1223 "Maximum connections exceeded for subject user-123",
1224 );
1225 assert_eq!(deny.code, AuthErrorCode::ConnectionLimitExceeded);
1226 assert!(!deny.code.should_retry());
1227 assert!(!deny.code.should_refresh_token());
1228 }
1229
1230 #[test]
1232 fn token_extraction_priority() {
1233 let request = Request::builder()
1235 .uri("/ws?hs_token=query-value")
1236 .header("Authorization", "Bearer header-value")
1237 .body(())
1238 .expect("request should build");
1239 let auth_request = ConnectionAuthRequest::from_http_request(
1240 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1241 &request,
1242 );
1243
1244 assert_eq!(auth_request.bearer_token(), Some("header-value"));
1246 assert_eq!(auth_request.query_param("hs_token"), Some("query-value"));
1248 }
1249
1250 #[test]
1252 fn malformed_authorization_header() {
1253 let test_cases = vec![
1254 ("Basic dXNlcjpwYXNz", None), ("Bearer", None), ("", None), ("Bearer token extra", Some("token extra")), ];
1259
1260 for (header_value, expected) in test_cases {
1261 let request = Request::builder()
1262 .uri("/ws")
1263 .header("Authorization", header_value)
1264 .body(())
1265 .expect("request should build");
1266 let auth_request = ConnectionAuthRequest::from_http_request(
1267 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1268 &request,
1269 );
1270 assert_eq!(
1271 auth_request.bearer_token(),
1272 expected,
1273 "Failed for header: {}",
1274 header_value
1275 );
1276 }
1277 }
1278
1279 #[test]
1285 fn auth_deny_error_response_structure() {
1286 let deny = AuthDeny::new(AuthErrorCode::TokenExpired, "Token has expired")
1287 .with_field("exp")
1288 .with_context("Token expired 5 minutes ago")
1289 .with_suggested_action("Refresh your authentication token")
1290 .with_docs_url("https://docs.usehyperstack.com/auth/errors#token-expired");
1291
1292 let response = deny.to_error_response();
1293
1294 assert_eq!(response.code, "token-expired");
1295 assert_eq!(response.message, "Token has expired");
1296 assert_eq!(response.error, "token-expired");
1297 assert!(response.retryable);
1298 assert_eq!(
1299 response.suggested_action,
1300 Some("Refresh your authentication token".to_string())
1301 );
1302 assert_eq!(
1303 response.docs_url,
1304 Some("https://docs.usehyperstack.com/auth/errors#token-expired".to_string())
1305 );
1306 }
1307
1308 #[test]
1309 fn auth_deny_rate_limited_response() {
1310 use std::time::Duration;
1311
1312 let deny = AuthDeny::rate_limited(Duration::from_secs(30), "websocket connections");
1313 let response = deny.to_error_response();
1314
1315 assert_eq!(response.code, "rate-limit-exceeded");
1316 assert!(response.message.contains("30s"));
1317 assert!(response.retryable);
1318 assert_eq!(response.retry_after, Some(30));
1319 }
1320
1321 #[test]
1322 fn auth_deny_connection_limit_response() {
1323 let deny = AuthDeny::connection_limit_exceeded("user-123", 5, 5);
1324 let response = deny.to_error_response();
1325
1326 assert_eq!(response.code, "connection-limit-exceeded");
1327 assert!(response.message.contains("user-123"));
1328 assert!(response.message.contains("5 of 5"));
1329 assert!(response.retryable); }
1331
1332 #[test]
1333 fn retry_policy_immediate() {
1334 let deny = AuthDeny::new(AuthErrorCode::InternalError, "Transient error")
1335 .with_retry_policy(RetryPolicy::RetryImmediately);
1336
1337 assert_eq!(deny.retry_policy, RetryPolicy::RetryImmediately);
1338 }
1339
1340 #[test]
1341 fn retry_policy_with_backoff() {
1342 use std::time::Duration;
1343
1344 let deny = AuthDeny::new(AuthErrorCode::RateLimitExceeded, "Too many requests")
1345 .with_retry_policy(RetryPolicy::RetryWithBackoff {
1346 initial: Duration::from_secs(1),
1347 max: Duration::from_secs(60),
1348 });
1349
1350 match deny.retry_policy {
1351 RetryPolicy::RetryWithBackoff { initial, max } => {
1352 assert_eq!(initial, Duration::from_secs(1));
1353 assert_eq!(max, Duration::from_secs(60));
1354 }
1355 _ => panic!("Expected RetryWithBackoff"),
1356 }
1357 }
1358
1359 #[test]
1360 fn auth_error_code_http_status_mapping() {
1361 assert_eq!(AuthErrorCode::TokenMissing.http_status(), 401);
1362 assert_eq!(AuthErrorCode::TokenExpired.http_status(), 401);
1363 assert_eq!(AuthErrorCode::TokenInvalidSignature.http_status(), 401);
1364 assert_eq!(AuthErrorCode::OriginMismatch.http_status(), 403);
1365 assert_eq!(AuthErrorCode::RateLimitExceeded.http_status(), 429);
1366 assert_eq!(AuthErrorCode::ConnectionLimitExceeded.http_status(), 429);
1367 assert_eq!(AuthErrorCode::InternalError.http_status(), 500);
1368 }
1369
1370 #[test]
1371 fn auth_error_code_default_retry_policies() {
1372 use std::time::Duration;
1373
1374 assert!(matches!(
1376 AuthErrorCode::TokenExpired.default_retry_policy(),
1377 RetryPolicy::RetryWithFreshToken
1378 ));
1379 assert!(matches!(
1380 AuthErrorCode::TokenInvalidSignature.default_retry_policy(),
1381 RetryPolicy::RetryWithFreshToken
1382 ));
1383
1384 assert!(matches!(
1386 AuthErrorCode::RateLimitExceeded.default_retry_policy(),
1387 RetryPolicy::RetryWithBackoff { .. }
1388 ));
1389 assert!(matches!(
1390 AuthErrorCode::InternalError.default_retry_policy(),
1391 RetryPolicy::RetryWithBackoff { .. }
1392 ));
1393
1394 assert!(matches!(
1396 AuthErrorCode::TokenMissing.default_retry_policy(),
1397 RetryPolicy::NoRetry
1398 ));
1399 assert!(matches!(
1400 AuthErrorCode::OriginMismatch.default_retry_policy(),
1401 RetryPolicy::NoRetry
1402 ));
1403 }
1404
1405 #[tokio::test]
1408 async fn handshake_rejects_missing_token_with_proper_error() {
1409 use tokio_tungstenite::tungstenite::http::StatusCode;
1410
1411 let plugin = AllowAllAuthPlugin;
1412
1413 let request = Request::builder()
1415 .uri("/ws")
1416 .body(())
1417 .expect("request should build");
1418
1419 let auth_request = ConnectionAuthRequest::from_http_request(
1420 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1421 &request,
1422 );
1423
1424 let static_plugin = StaticTokenAuthPlugin::new(["valid-token".to_string()]);
1427 let decision = static_plugin.authorize(&auth_request).await;
1428
1429 assert!(!decision.is_allowed());
1430
1431 if let AuthDecision::Deny(deny) = decision {
1432 assert_eq!(deny.code, AuthErrorCode::TokenMissing);
1433 assert_eq!(deny.http_status, 401);
1434 assert!(deny.reason.contains("Missing"));
1435 } else {
1436 panic!("Expected Deny decision");
1437 }
1438 }
1439
1440 #[tokio::test]
1441 async fn handshake_rejects_expired_token_with_retry_hint() {
1442 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
1443 use std::time::{SystemTime, UNIX_EPOCH};
1444
1445 let signing_key = hyperstack_auth::SigningKey::generate();
1446 let verifying_key = signing_key.verifying_key();
1447 let signer = TokenSigner::new(signing_key, "test-issuer");
1448
1449 let now = SystemTime::now()
1451 .duration_since(UNIX_EPOCH)
1452 .unwrap()
1453 .as_secs();
1454 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1455 .with_scope("read")
1456 .with_key_class(KeyClass::Secret)
1457 .build();
1458
1459 let mut expired_claims = claims;
1460 expired_claims.exp = now - 3600;
1461 expired_claims.iat = now - 7200;
1462 expired_claims.nbf = now - 7200;
1463
1464 let token = signer.sign(expired_claims).unwrap();
1465
1466 let verifier =
1468 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
1469 let plugin = SignedSessionAuthPlugin::new(verifier);
1470
1471 let request = Request::builder()
1472 .uri(format!("/ws?hs_token={}", token))
1473 .body(())
1474 .expect("request should build");
1475
1476 let auth_request = ConnectionAuthRequest::from_http_request(
1477 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1478 &request,
1479 );
1480
1481 let decision = plugin.authorize(&auth_request).await;
1482
1483 assert!(!decision.is_allowed());
1484
1485 if let AuthDecision::Deny(deny) = decision {
1486 assert_eq!(deny.code, AuthErrorCode::TokenExpired);
1487 assert_eq!(deny.http_status, 401);
1488 assert!(matches!(
1490 deny.retry_policy,
1491 RetryPolicy::RetryWithFreshToken
1492 ));
1493 } else {
1494 panic!("Expected Deny decision");
1495 }
1496 }
1497
1498 #[tokio::test]
1499 async fn handshake_rejects_invalid_signature_with_retry_hint() {
1500 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
1501
1502 let signing_key = hyperstack_auth::SigningKey::generate();
1504 let wrong_key = hyperstack_auth::SigningKey::generate();
1505
1506 let signer = TokenSigner::new(signing_key, "test-issuer");
1508 let wrong_verifying_key = wrong_key.verifying_key();
1509 let verifier = hyperstack_auth::TokenVerifier::new(
1510 wrong_verifying_key,
1511 "test-issuer",
1512 "test-audience",
1513 );
1514 let plugin = SignedSessionAuthPlugin::new(verifier);
1515
1516 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1517 .with_scope("read")
1518 .with_key_class(KeyClass::Secret)
1519 .build();
1520
1521 let token = signer.sign(claims).unwrap();
1522
1523 let request = Request::builder()
1524 .uri(format!("/ws?hs_token={}", token))
1525 .body(())
1526 .expect("request should build");
1527
1528 let auth_request = ConnectionAuthRequest::from_http_request(
1529 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1530 &request,
1531 );
1532
1533 let decision = plugin.authorize(&auth_request).await;
1534
1535 assert!(!decision.is_allowed());
1536
1537 if let AuthDecision::Deny(deny) = decision {
1538 assert_eq!(deny.code, AuthErrorCode::TokenInvalidSignature);
1539 assert_eq!(deny.http_status, 401);
1540 assert!(matches!(
1542 deny.retry_policy,
1543 RetryPolicy::RetryWithFreshToken
1544 ));
1545 } else {
1546 panic!("Expected Deny decision");
1547 }
1548 }
1549
1550 #[tokio::test]
1551 async fn handshake_rejects_origin_mismatch_without_retry() {
1552 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
1553
1554 let signing_key = hyperstack_auth::SigningKey::generate();
1555 let verifying_key = signing_key.verifying_key();
1556 let signer = TokenSigner::new(signing_key, "test-issuer");
1557
1558 let verifier =
1559 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience")
1560 .with_origin_validation();
1561 let plugin = SignedSessionAuthPlugin::new(verifier).with_origin_validation();
1562
1563 let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
1565 .with_scope("read")
1566 .with_key_class(KeyClass::Secret)
1567 .with_origin("https://allowed.example.com")
1568 .build();
1569
1570 let token = signer.sign(claims).unwrap();
1571
1572 let request = Request::builder()
1574 .uri(format!("/ws?hs_token={}", token))
1575 .header("Origin", "https://evil.example.com")
1576 .body(())
1577 .expect("request should build");
1578
1579 let auth_request = ConnectionAuthRequest::from_http_request(
1580 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1581 &request,
1582 );
1583
1584 let decision = plugin.authorize(&auth_request).await;
1585
1586 assert!(!decision.is_allowed());
1587
1588 if let AuthDecision::Deny(deny) = decision {
1589 assert_eq!(deny.code, AuthErrorCode::OriginMismatch);
1590 assert_eq!(deny.http_status, 403);
1591 assert!(matches!(deny.retry_policy, RetryPolicy::NoRetry));
1593 } else {
1594 panic!("Expected Deny decision");
1595 }
1596 }
1597
1598 #[test]
1600 fn auth_deny_to_http_response() {
1601 let deny = AuthDeny::new(AuthErrorCode::RateLimitExceeded, "Too many requests")
1602 .with_suggested_action("Wait before retrying")
1603 .with_retry_policy(RetryPolicy::RetryAfter(Duration::from_secs(30)));
1604
1605 let response = deny.to_error_response();
1606
1607 let json = serde_json::to_string(&response).expect("Should serialize");
1609 assert!(json.contains("rate-limit-exceeded"));
1610 assert!(json.contains("Too many requests"));
1611 assert!(json.contains("Wait before retrying"));
1612 assert!(json.contains("\"retryable\":true"));
1613 assert!(json.contains("\"retry_after\":30"));
1614 }
1615
1616 #[tokio::test]
1618 async fn comprehensive_auth_error_scenarios() {
1619 use hyperstack_auth::{KeyClass, SessionClaims, TokenSigner};
1620
1621 let signing_key = hyperstack_auth::SigningKey::generate();
1622 let verifying_key = signing_key.verifying_key();
1623 let signer = TokenSigner::new(signing_key, "test-issuer");
1624 let verifier =
1625 hyperstack_auth::TokenVerifier::new(verifying_key, "test-issuer", "test-audience");
1626 let plugin = SignedSessionAuthPlugin::new(verifier);
1627
1628 let test_cases = vec![
1629 ("missing_token", None, AuthErrorCode::TokenMissing),
1630 (
1631 "invalid_format",
1632 Some("not-a-valid-token"),
1633 AuthErrorCode::TokenInvalidFormat,
1634 ),
1635 ];
1636
1637 for (name, token, expected_code) in test_cases {
1638 let uri = token.map_or_else(|| "/ws".to_string(), |t| format!("/ws?hs_token={}", t));
1639
1640 let request = Request::builder()
1641 .uri(&uri)
1642 .body(())
1643 .expect("request should build");
1644
1645 let auth_request = ConnectionAuthRequest::from_http_request(
1646 "127.0.0.1:8877".parse().expect("socket addr should parse"),
1647 &request,
1648 );
1649
1650 let decision = plugin.authorize(&auth_request).await;
1651
1652 assert!(!decision.is_allowed(), "{}: should deny", name);
1653
1654 if let AuthDecision::Deny(deny) = decision {
1655 assert_eq!(deny.code, expected_code, "{}: wrong error code", name);
1656 } else {
1657 panic!("{}: Expected Deny decision", name);
1658 }
1659 }
1660 }
1661}