Skip to main content

rustapi_ws/
auth.rs

//! WebSocket authentication support
//!
//! This module provides authentication infrastructure for WebSocket connections,
//! allowing tokens to be validated before the WebSocket upgrade completes.
//!
//! # Example
//!
//! ```rust,ignore
//! use rustapi_ws::auth::{AuthError, Claims, TokenExtractor, TokenValidator, WsAuthConfig};
//! use async_trait::async_trait;
//!
//! struct MyTokenValidator;
//!
//! #[async_trait]
//! impl TokenValidator for MyTokenValidator {
//!     async fn validate(&self, _token: &str) -> Result<Claims, AuthError> {
//!         // Validate a JWT or other token format here.
//!         Ok(Claims::new("user_123"))
//!     }
//! }
//!
//! // `WsAuthConfig::new` takes the validator by value and wraps it in an `Arc` itself.
//! let config = WsAuthConfig::new(MyTokenValidator)
//!     .extractor(TokenExtractor::header("Authorization"));
//! ```
25
26use std::collections::HashMap;
27use std::sync::Arc;
28use thiserror::Error;
29
30/// Error type for WebSocket authentication
31#[derive(Error, Debug, Clone)]
32pub enum AuthError {
33    /// Token is missing from the request
34    #[error("Authentication token missing")]
35    TokenMissing,
36
37    /// Token format is invalid
38    #[error("Invalid token format: {0}")]
39    InvalidFormat(String),
40
41    /// Token has expired
42    #[error("Token has expired")]
43    TokenExpired,
44
45    /// Token signature is invalid
46    #[error("Invalid token signature")]
47    InvalidSignature,
48
49    /// Token validation failed
50    #[error("Token validation failed: {0}")]
51    ValidationFailed(String),
52
53    /// Insufficient permissions
54    #[error("Insufficient permissions: {0}")]
55    InsufficientPermissions(String),
56}
57
58impl AuthError {
59    /// Create a validation failed error
60    pub fn validation_failed(msg: impl Into<String>) -> Self {
61        Self::ValidationFailed(msg.into())
62    }
63
64    /// Create an invalid format error
65    pub fn invalid_format(msg: impl Into<String>) -> Self {
66        Self::InvalidFormat(msg.into())
67    }
68
69    /// Create an insufficient permissions error
70    pub fn insufficient_permissions(msg: impl Into<String>) -> Self {
71        Self::InsufficientPermissions(msg.into())
72    }
73}
74
/// Claims extracted from a validated token.
///
/// Holds the subject (user identity) and any additional string-valued claims
/// the validator chose to expose. Two `Claims` compare equal when both the
/// subject and every extra claim match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Claims {
    /// Subject (user ID)
    pub sub: String,
    /// Additional claims as key-value pairs
    pub extra: HashMap<String, String>,
}

impl Claims {
    /// Creates claims containing only a subject and no extra entries.
    pub fn new(sub: impl Into<String>) -> Self {
        Self::with_extra(sub, HashMap::new())
    }

    /// Creates claims from a subject and a prepared map of extra claims.
    pub fn with_extra(sub: impl Into<String>, extra: HashMap<String, String>) -> Self {
        Self {
            sub: sub.into(),
            extra,
        }
    }

    /// Returns the subject (user ID).
    pub fn subject(&self) -> &str {
        &self.sub
    }

    /// Looks up an extra claim by key, returning `None` if it is absent.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.extra.get(key).map(String::as_str)
    }

    /// Adds an extra claim, overwriting any existing value for `key`.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.extra.insert(key.into(), value.into());
    }
}
118
/// Where in the upgrade request the authentication token is looked up.
///
/// The default is the `Authorization` header.
#[derive(Debug, Clone)]
pub enum TokenExtractor {
    /// Read the named request header (for example `Authorization`).
    Header(String),
    /// Read the named query-string parameter (for example `token`).
    Query(String),
    /// Read the `Sec-WebSocket-Protocol` request header.
    Protocol,
}

impl Default for TokenExtractor {
    /// Looks the token up in the `Authorization` header.
    fn default() -> Self {
        let header_name = String::from("Authorization");
        TokenExtractor::Header(header_name)
    }
}
135
136impl TokenExtractor {
137    /// Create a header extractor
138    pub fn header(name: impl Into<String>) -> Self {
139        Self::Header(name.into())
140    }
141
142    /// Create a query parameter extractor
143    pub fn query(name: impl Into<String>) -> Self {
144        Self::Query(name.into())
145    }
146
147    /// Create a protocol extractor
148    pub fn protocol() -> Self {
149        Self::Protocol
150    }
151
152    /// Extract the token from an HTTP request
153    pub fn extract<B>(&self, req: &http::Request<B>) -> Option<String> {
154        match self {
155            TokenExtractor::Header(name) => {
156                req.headers()
157                    .get(name)
158                    .and_then(|v| v.to_str().ok())
159                    .map(|s| {
160                        // Strip "Bearer " prefix if present
161                        if let Some(token) = s.strip_prefix("Bearer ") {
162                            token.to_string()
163                        } else {
164                            s.to_string()
165                        }
166                    })
167            }
168            TokenExtractor::Query(name) => req.uri().query().and_then(|query| {
169                url::form_urlencoded::parse(query.as_bytes())
170                    .find(|(key, _)| key == name)
171                    .map(|(_, value)| value.into_owned())
172            }),
173            TokenExtractor::Protocol => req
174                .headers()
175                .get("Sec-WebSocket-Protocol")
176                .and_then(|v| v.to_str().ok())
177                .map(|s| s.to_string()),
178        }
179    }
180}
181
182/// Trait for validating authentication tokens
183///
184/// Implement this trait to provide custom token validation logic.
185#[async_trait::async_trait]
186pub trait TokenValidator: Send + Sync {
187    /// Validate a token and return the claims if valid
188    async fn validate(&self, token: &str) -> Result<Claims, AuthError>;
189}
190
191/// Configuration for WebSocket authentication
192#[derive(Clone)]
193pub struct WsAuthConfig {
194    /// Token extractor configuration
195    pub extractor: TokenExtractor,
196    /// Token validator
197    pub validator: Arc<dyn TokenValidator>,
198    /// Whether authentication is required (if false, missing tokens are allowed)
199    pub required: bool,
200}
201
202impl WsAuthConfig {
203    /// Create a new authentication configuration with a validator
204    pub fn new<V: TokenValidator + 'static>(validator: V) -> Self {
205        Self {
206            extractor: TokenExtractor::default(),
207            validator: Arc::new(validator),
208            required: true,
209        }
210    }
211
212    /// Set the token extractor
213    pub fn extractor(mut self, extractor: TokenExtractor) -> Self {
214        self.extractor = extractor;
215        self
216    }
217
218    /// Set whether authentication is required
219    pub fn required(mut self, required: bool) -> Self {
220        self.required = required;
221        self
222    }
223
224    /// Extract and validate a token from a request
225    pub async fn authenticate<B>(
226        &self,
227        req: &http::Request<B>,
228    ) -> Result<Option<Claims>, AuthError> {
229        match self.extractor.extract(req) {
230            Some(token) => {
231                let claims = self.validator.validate(&token).await?;
232                Ok(Some(claims))
233            }
234            None if self.required => Err(AuthError::TokenMissing),
235            None => Ok(None),
236        }
237    }
238}
239
240impl std::fmt::Debug for WsAuthConfig {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        f.debug_struct("WsAuthConfig")
243            .field("extractor", &self.extractor)
244            .field("required", &self.required)
245            .finish()
246    }
247}
248
249/// A simple token validator that accepts any non-empty token
250///
251/// This is useful for testing or when token validation is handled elsewhere.
252pub struct AcceptAllValidator;
253
254#[async_trait::async_trait]
255impl TokenValidator for AcceptAllValidator {
256    async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
257        if token.is_empty() {
258            return Err(AuthError::invalid_format("Token cannot be empty"));
259        }
260        Ok(Claims::new(token))
261    }
262}
263
264/// A token validator that rejects all tokens
265///
266/// This is useful for testing authentication failure scenarios.
267pub struct RejectAllValidator;
268
269#[async_trait::async_trait]
270impl TokenValidator for RejectAllValidator {
271    async fn validate(&self, _token: &str) -> Result<Claims, AuthError> {
272        Err(AuthError::validation_failed("All tokens rejected"))
273    }
274}
275
276/// A token validator that validates against a static list of valid tokens
277pub struct StaticTokenValidator {
278    tokens: HashMap<String, Claims>,
279}
280
281impl StaticTokenValidator {
282    /// Create a new static token validator
283    pub fn new() -> Self {
284        Self {
285            tokens: HashMap::new(),
286        }
287    }
288
289    /// Add a valid token with associated claims
290    pub fn add_token(mut self, token: impl Into<String>, claims: Claims) -> Self {
291        self.tokens.insert(token.into(), claims);
292        self
293    }
294}
295
296impl Default for StaticTokenValidator {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302#[async_trait::async_trait]
303impl TokenValidator for StaticTokenValidator {
304    async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
305        self.tokens
306            .get(token)
307            .cloned()
308            .ok_or_else(|| AuthError::validation_failed("Invalid token"))
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    #[test]
    fn test_token_extractor_header() {
        let extractor = TokenExtractor::header("Authorization");

        let req = Request::builder()
            .header("Authorization", "Bearer test-token")
            .body(())
            .unwrap();

        assert_eq!(extractor.extract(&req), Some("test-token".to_string()));
    }

    #[test]
    fn test_token_extractor_header_no_bearer() {
        let extractor = TokenExtractor::header("X-API-Key");

        let req = Request::builder()
            .header("X-API-Key", "my-api-key")
            .body(())
            .unwrap();

        assert_eq!(extractor.extract(&req), Some("my-api-key".to_string()));
    }

    /// Values using a scheme other than Bearer are passed through verbatim.
    #[test]
    fn test_token_extractor_header_other_scheme_kept() {
        let extractor = TokenExtractor::header("Authorization");

        let req = Request::builder()
            .header("Authorization", "Basic dXNlcjpwdw==")
            .body(())
            .unwrap();

        assert_eq!(
            extractor.extract(&req),
            Some("Basic dXNlcjpwdw==".to_string())
        );
    }

    #[test]
    fn test_token_extractor_query() {
        let extractor = TokenExtractor::query("token");

        let req = Request::builder()
            .uri("ws://localhost/ws?token=query-token&other=value")
            .body(())
            .unwrap();

        assert_eq!(extractor.extract(&req), Some("query-token".to_string()));
    }

    /// Query values are form-url-decoded (`%2B` -> `+`, `%20` -> space).
    #[test]
    fn test_token_extractor_query_percent_decoded() {
        let extractor = TokenExtractor::query("token");

        let req = Request::builder()
            .uri("ws://localhost/ws?token=a%2Bb%20c")
            .body(())
            .unwrap();

        assert_eq!(extractor.extract(&req), Some("a+b c".to_string()));
    }

    #[test]
    fn test_token_extractor_query_missing() {
        let extractor = TokenExtractor::query("token");

        let with_other_param = Request::builder()
            .uri("ws://localhost/ws?other=value")
            .body(())
            .unwrap();
        assert_eq!(extractor.extract(&with_other_param), None);

        let without_query = Request::builder()
            .uri("ws://localhost/ws")
            .body(())
            .unwrap();
        assert_eq!(extractor.extract(&without_query), None);
    }

    #[test]
    fn test_token_extractor_protocol() {
        let extractor = TokenExtractor::protocol();

        let req = Request::builder()
            .header("Sec-WebSocket-Protocol", "my-protocol-token")
            .body(())
            .unwrap();

        assert_eq!(
            extractor.extract(&req),
            Some("my-protocol-token".to_string())
        );
    }

    #[test]
    fn test_token_extractor_missing() {
        let extractor = TokenExtractor::header("Authorization");

        let req = Request::builder().body(()).unwrap();

        assert_eq!(extractor.extract(&req), None);
    }

    #[tokio::test]
    async fn test_accept_all_validator() {
        let claims = AcceptAllValidator
            .validate("any-token")
            .await
            .expect("non-empty token should be accepted");
        assert_eq!(claims.subject(), "any-token");
    }

    #[tokio::test]
    async fn test_accept_all_validator_empty() {
        let result = AcceptAllValidator.validate("").await;
        assert!(matches!(result, Err(AuthError::InvalidFormat(_))));
    }

    #[tokio::test]
    async fn test_reject_all_validator() {
        let result = RejectAllValidator.validate("any-token").await;
        assert!(matches!(result, Err(AuthError::ValidationFailed(_))));
    }

    #[tokio::test]
    async fn test_static_token_validator() {
        let validator =
            StaticTokenValidator::new().add_token("valid-token", Claims::new("user-123"));

        let claims = validator
            .validate("valid-token")
            .await
            .expect("registered token should be accepted");
        assert_eq!(claims.subject(), "user-123");

        let result = validator.validate("invalid-token").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_ws_auth_config_defaults() {
        let config = WsAuthConfig::new(AcceptAllValidator);
        assert!(config.required);
        assert!(matches!(
            config.extractor,
            TokenExtractor::Header(ref name) if name == "Authorization"
        ));
    }

    #[tokio::test]
    async fn test_ws_auth_config_required() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(true);

        let req = Request::builder().body(()).unwrap();

        let result = config.authenticate(&req).await;
        assert!(matches!(result, Err(AuthError::TokenMissing)));
    }

    #[tokio::test]
    async fn test_ws_auth_config_optional() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(false);

        let req = Request::builder().body(()).unwrap();

        let claims = config
            .authenticate(&req)
            .await
            .expect("optional auth without a token should succeed");
        assert!(claims.is_none());
    }

    /// A token that is present but invalid is an error even when auth is optional.
    #[tokio::test]
    async fn test_ws_auth_config_optional_still_rejects_invalid_token() {
        let config = WsAuthConfig::new(RejectAllValidator).required(false);

        let req = Request::builder()
            .header("Authorization", "Bearer bad-token")
            .body(())
            .unwrap();

        let result = config.authenticate(&req).await;
        assert!(matches!(result, Err(AuthError::ValidationFailed(_))));
    }

    #[tokio::test]
    async fn test_ws_auth_config_with_token() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"));

        let req = Request::builder()
            .header("Authorization", "Bearer my-token")
            .body(())
            .unwrap();

        let claims = config
            .authenticate(&req)
            .await
            .expect("valid token should authenticate")
            .expect("claims should be present");
        assert_eq!(claims.subject(), "my-token");
    }

    #[tokio::test]
    async fn test_ws_auth_config_query_with_static_validator() {
        let validator = StaticTokenValidator::new().add_token("secret", Claims::new("user-9"));
        let config = WsAuthConfig::new(validator).extractor(TokenExtractor::query("token"));

        let req = Request::builder()
            .uri("ws://localhost/ws?token=secret")
            .body(())
            .unwrap();

        let claims = config
            .authenticate(&req)
            .await
            .expect("registered token should authenticate")
            .expect("claims should be present");
        assert_eq!(claims.subject(), "user-9");
    }

    #[test]
    fn test_claims_extra() {
        let mut claims = Claims::new("user-123");
        claims.insert("role", "admin");
        claims.insert("tenant", "acme");

        assert_eq!(claims.subject(), "user-123");
        assert_eq!(claims.get("role"), Some("admin"));
        assert_eq!(claims.get("tenant"), Some("acme"));
        assert_eq!(claims.get("missing"), None);
    }

    #[test]
    fn test_claims_with_extra() {
        let mut extra = HashMap::new();
        extra.insert("role".to_string(), "admin".to_string());

        let claims = Claims::with_extra("user-1", extra);
        assert_eq!(claims.subject(), "user-1");
        assert_eq!(claims.get("role"), Some("admin"));
        assert_eq!(claims.get("tenant"), None);
    }

    #[test]
    fn test_auth_error_display() {
        let err = AuthError::TokenMissing;
        assert_eq!(err.to_string(), "Authentication token missing");

        let err = AuthError::validation_failed("custom error");
        assert_eq!(err.to_string(), "Token validation failed: custom error");

        let err = AuthError::invalid_format("bad base64");
        assert_eq!(err.to_string(), "Invalid token format: bad base64");

        let err = AuthError::insufficient_permissions("admin role required");
        assert_eq!(
            err.to_string(),
            "Insufficient permissions: admin role required"
        );

        assert_eq!(AuthError::TokenExpired.to_string(), "Token has expired");
        assert_eq!(
            AuthError::InvalidSignature.to_string(),
            "Invalid token signature"
        );
    }

    #[test]
    fn test_token_extractor_default() {
        let extractor = TokenExtractor::default();
        match extractor {
            TokenExtractor::Header(name) => assert_eq!(name, "Authorization"),
            _ => panic!("Expected Header extractor"),
        }
    }
}
486
/// Property-based tests for WebSocket authentication
///
/// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
/// **Validates: Requirements 4.1, 4.3**
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;
    use std::future::Future;

    /// Drives `fut` to completion on a single-threaded tokio runtime.
    ///
    /// proptest runs each property many times. A current-thread runtime is
    /// enough for these futures, and it avoids spinning up a worker-thread
    /// pool for every generated case.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build a current-thread tokio runtime")
            .block_on(fut)
    }

    /// Strategy for generating random tokens
    fn token_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-zA-Z0-9._-]{1,100}").unwrap()
    }

    /// Strategy for generating random header names
    fn header_name_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[A-Za-z][A-Za-z0-9-]{0,30}").unwrap()
    }

    /// Strategy for generating random query parameter names
    fn query_param_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-z][a-z0-9_]{0,20}").unwrap()
    }

    /// Strategy for generating token extractors
    fn extractor_strategy() -> impl Strategy<Value = TokenExtractor> {
        prop_oneof![
            header_name_strategy().prop_map(TokenExtractor::Header),
            query_param_strategy().prop_map(TokenExtractor::Query),
            Just(TokenExtractor::Protocol),
        ]
    }

    proptest! {
        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any WebSocket connection attempt with required authentication,
        /// if no token is provided, authentication SHALL fail with TokenMissing error.
        #[test]
        fn prop_auth_required_rejects_missing_token(
            extractor in extractor_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(true);

                // Request without any token
                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();

                let result = config.authenticate(&req).await;
                prop_assert!(matches!(result, Err(AuthError::TokenMissing)));
                Ok(())
            })?;
        }

        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any WebSocket connection attempt with a valid token,
        /// authentication SHALL succeed and return claims.
        #[test]
        fn prop_auth_accepts_valid_token_in_header(
            token in token_strategy(),
            header_name in header_name_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Header(header_name.clone()))
                    .required(true);

                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header(&header_name, format!("Bearer {}", token))
                    .body(())
                    .unwrap();

                let result = config.authenticate(&req).await;
                prop_assert!(result.is_ok());
                let claims = result.unwrap();
                prop_assert!(claims.is_some());
                let claims = claims.unwrap();
                prop_assert_eq!(claims.subject(), &token);
                Ok(())
            })?;
        }

        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any WebSocket connection attempt with a valid token in query,
        /// authentication SHALL succeed and return claims.
        #[test]
        fn prop_auth_accepts_valid_token_in_query(
            token in token_strategy(),
            param_name in query_param_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Query(param_name.clone()))
                    .required(true);

                let uri = format!("ws://localhost/ws?{}={}", param_name, token);
                let req = http::Request::builder()
                    .uri(&uri)
                    .body(())
                    .unwrap();

                let result = config.authenticate(&req).await;
                prop_assert!(result.is_ok());
                let claims = result.unwrap();
                prop_assert!(claims.is_some());
                let claims = claims.unwrap();
                prop_assert_eq!(claims.subject(), &token);
                Ok(())
            })?;
        }

        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any WebSocket connection attempt with an invalid token,
        /// authentication SHALL fail with validation error.
        #[test]
        fn prop_auth_rejects_invalid_token(
            token in token_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(RejectAllValidator)
                    .extractor(TokenExtractor::Header("Authorization".to_string()))
                    .required(true);

                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header("Authorization", format!("Bearer {}", token))
                    .body(())
                    .unwrap();

                let result = config.authenticate(&req).await;
                prop_assert!(matches!(result, Err(AuthError::ValidationFailed(_))));
                Ok(())
            })?;
        }

        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any WebSocket connection with optional auth and no token,
        /// authentication SHALL succeed with None claims.
        #[test]
        fn prop_optional_auth_allows_missing_token(
            extractor in extractor_strategy()
        ) {
            block_on(async {
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(false);

                let req = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();

                let result = config.authenticate(&req).await;
                prop_assert!(result.is_ok());
                prop_assert!(result.unwrap().is_none());
                Ok(())
            })?;
        }

        /// **Feature: v1-features-roadmap, Property 10: WebSocket authentication enforcement**
        /// **Validates: Requirements 4.1, 4.3**
        ///
        /// For any static token validator with known valid tokens,
        /// only those exact tokens SHALL be accepted.
        #[test]
        fn prop_static_validator_only_accepts_known_tokens(
            valid_token in token_strategy(),
            test_token in token_strategy(),
            user_id in "[a-z]{3,10}"
        ) {
            block_on(async {
                let validator = StaticTokenValidator::new()
                    .add_token(valid_token.clone(), Claims::new(user_id.clone()));

                // Two independently generated tokens almost never collide, so
                // check the acceptance path explicitly with the registered token.
                let accepted = validator.validate(&valid_token).await;
                prop_assert!(accepted.is_ok());
                let accepted_claims = accepted.unwrap();
                prop_assert_eq!(accepted_claims.subject(), &user_id);

                let result = validator.validate(&test_token).await;

                if test_token == valid_token {
                    prop_assert!(result.is_ok());
                    let claims = result.unwrap();
                    prop_assert_eq!(claims.subject(), &user_id);
                } else {
                    prop_assert!(matches!(result, Err(AuthError::ValidationFailed(_))));
                }
                Ok(())
            })?;
        }
    }
}