Skip to main content

actix_security_core/http/security/websocket/
config.rs

//! WebSocket security configuration.
//!
//! Provides a unified configuration for WebSocket security, covering origin
//! validation as well as authentication and role/authority requirements.
5
6use actix_web::{HttpMessage, HttpRequest};
7
8use crate::http::security::User;
9
10use super::error::WebSocketSecurityError;
11use super::extractor::WebSocketUpgrade;
12use super::origin::OriginValidator;
13
14/// Configuration for WebSocket security.
15///
16/// This provides a unified way to configure security for WebSocket endpoints,
17/// combining authentication requirements and origin validation.
18///
19/// # Spring Security Equivalent
20/// `WebSocketSecurityConfigurer` / `AbstractSecurityWebSocketMessageBrokerConfigurer`
21///
22/// # Example
23///
24/// ```ignore
25/// use actix_security::http::security::websocket::WebSocketSecurityConfig;
26///
27/// // Create configuration
28/// let ws_config = WebSocketSecurityConfig::new()
29///     .allowed_origins(vec!["https://myapp.com".into()])
30///     .require_authentication(true)
31///     .required_roles(vec!["USER".into()]);
32///
33/// // Use in handler
34/// #[get("/ws")]
35/// async fn ws_handler(
36///     req: HttpRequest,
37///     stream: web::Payload,
38///     config: web::Data<WebSocketSecurityConfig>,
39/// ) -> Result<HttpResponse, actix_web::Error> {
40///     let upgrade = config.validate_upgrade(&req)?;
41///     // ... upgrade to WebSocket
42/// }
43/// ```
44#[derive(Debug, Clone)]
45pub struct WebSocketSecurityConfig {
46    /// Origin validator
47    origin_validator: OriginValidator,
48    /// Require authentication for WebSocket connections
49    require_authentication: bool,
50    /// Required roles (any of these)
51    required_roles: Vec<String>,
52    /// Required authorities (any of these)
53    required_authorities: Vec<String>,
54}
55
56impl Default for WebSocketSecurityConfig {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl WebSocketSecurityConfig {
63    /// Creates a new WebSocket security configuration with default settings.
64    ///
65    /// Default settings:
66    /// - No origin validation (allow any)
67    /// - Authentication not required
68    /// - No role/authority requirements
69    pub fn new() -> Self {
70        Self {
71            origin_validator: OriginValidator::allow_any(),
72            require_authentication: false,
73            required_roles: Vec::new(),
74            required_authorities: Vec::new(),
75        }
76    }
77
78    /// Sets the allowed origins for WebSocket connections.
79    ///
80    /// # Arguments
81    /// * `origins` - List of allowed origin URLs
82    ///
83    /// # Example
84    /// ```ignore
85    /// let config = WebSocketSecurityConfig::new()
86    ///     .allowed_origins(vec!["https://myapp.com".into()]);
87    /// ```
88    pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
89        let origins_refs: Vec<&str> = origins.iter().map(|s| s.as_str()).collect();
90        self.origin_validator = OriginValidator::new(&origins_refs);
91        self
92    }
93
94    /// Sets a custom origin validator.
95    ///
96    /// # Example
97    /// ```ignore
98    /// let validator = OriginValidator::builder()
99    ///     .allow("https://myapp.com")
100    ///     .allow_localhost_in_dev(true)
101    ///     .build();
102    ///
103    /// let config = WebSocketSecurityConfig::new()
104    ///     .origin_validator(validator);
105    /// ```
106    pub fn origin_validator(mut self, validator: OriginValidator) -> Self {
107        self.origin_validator = validator;
108        self
109    }
110
111    /// Requires authentication for WebSocket connections.
112    ///
113    /// When enabled, unauthenticated WebSocket upgrade requests will be rejected.
114    ///
115    /// # Example
116    /// ```ignore
117    /// let config = WebSocketSecurityConfig::new()
118    ///     .require_authentication(true);
119    /// ```
120    pub fn require_authentication(mut self, require: bool) -> Self {
121        self.require_authentication = require;
122        self
123    }
124
125    /// Sets required roles for WebSocket connections.
126    ///
127    /// Users must have at least one of the specified roles.
128    /// Automatically enables authentication requirement.
129    ///
130    /// # Example
131    /// ```ignore
132    /// let config = WebSocketSecurityConfig::new()
133    ///     .required_roles(vec!["USER".into(), "ADMIN".into()]);
134    /// ```
135    pub fn required_roles(mut self, roles: Vec<String>) -> Self {
136        self.required_roles = roles;
137        if !self.required_roles.is_empty() {
138            self.require_authentication = true;
139        }
140        self
141    }
142
143    /// Sets required authorities for WebSocket connections.
144    ///
145    /// Users must have at least one of the specified authorities.
146    /// Automatically enables authentication requirement.
147    ///
148    /// # Example
149    /// ```ignore
150    /// let config = WebSocketSecurityConfig::new()
151    ///     .required_authorities(vec!["ws:connect".into()]);
152    /// ```
153    pub fn required_authorities(mut self, authorities: Vec<String>) -> Self {
154        self.required_authorities = authorities;
155        if !self.required_authorities.is_empty() {
156            self.require_authentication = true;
157        }
158        self
159    }
160
161    /// Validates a WebSocket upgrade request.
162    ///
163    /// This method performs all configured security checks:
164    /// 1. Origin validation (CSWSH prevention)
165    /// 2. Authentication check (if required)
166    /// 3. Role check (if configured)
167    /// 4. Authority check (if configured)
168    ///
169    /// # Returns
170    /// - `Ok(WebSocketUpgrade)` - Validation passed, safe to upgrade
171    /// - `Err(WebSocketSecurityError)` - Validation failed
172    ///
173    /// # Example
174    /// ```ignore
175    /// let config = WebSocketSecurityConfig::new()
176    ///     .allowed_origins(vec!["https://myapp.com".into()])
177    ///     .require_authentication(true);
178    ///
179    /// #[get("/ws")]
180    /// async fn ws_handler(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
181    ///     let upgrade = config.validate_upgrade(&req)?;
182    ///     let user = upgrade.into_user().unwrap();
183    ///     // ... upgrade to WebSocket
184    /// }
185    /// ```
186    pub fn validate_upgrade(
187        &self,
188        req: &HttpRequest,
189    ) -> Result<WebSocketUpgrade, WebSocketSecurityError> {
190        // 1. Validate origin
191        self.origin_validator.validate(req)?;
192
193        // 2. Get user from request extensions
194        let user = req.extensions().get::<User>().cloned();
195
196        // 3. Check authentication requirement
197        if self.require_authentication && user.is_none() {
198            return Err(WebSocketSecurityError::Unauthorized);
199        }
200
201        // 4. Check role requirement
202        if !self.required_roles.is_empty() {
203            let roles_refs: Vec<&str> = self.required_roles.iter().map(|s| s.as_str()).collect();
204            if !user.as_ref().is_some_and(|u| u.has_any_role(&roles_refs)) {
205                return Err(WebSocketSecurityError::MissingRole {
206                    role: self.required_roles.join(", "),
207                });
208            }
209        }
210
211        // 5. Check authority requirement
212        if !self.required_authorities.is_empty() {
213            let auth_refs: Vec<&str> = self
214                .required_authorities
215                .iter()
216                .map(|s| s.as_str())
217                .collect();
218            if !user
219                .as_ref()
220                .is_some_and(|u| u.has_any_authority(&auth_refs))
221            {
222                return Err(WebSocketSecurityError::MissingAuthority {
223                    authority: self.required_authorities.join(", "),
224                });
225            }
226        }
227
228        // 6. Extract origin for logging/debugging
229        let origin = req
230            .headers()
231            .get("origin")
232            .and_then(|h| h.to_str().ok())
233            .map(|s| s.to_string());
234
235        Ok(WebSocketUpgrade::new(user, origin))
236    }
237}
238
/// Builder for more complex WebSocket security configurations.
///
/// Each setter forwards to the corresponding [`WebSocketSecurityConfig`] method,
/// and `build` returns the accumulated configuration. The derived `Default`
/// starts from `WebSocketSecurityConfig::default()`, the same defaults as
/// `WebSocketSecurityConfig::new()`.
#[derive(Debug, Clone, Default)]
pub struct WebSocketSecurityConfigBuilder {
    // Configuration accumulated so far; returned by `build`.
    config: WebSocketSecurityConfig,
}
244
245impl WebSocketSecurityConfigBuilder {
246    /// Creates a new builder.
247    pub fn new() -> Self {
248        Self {
249            config: WebSocketSecurityConfig::new(),
250        }
251    }
252
253    /// Sets allowed origins.
254    pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
255        self.config = self.config.allowed_origins(origins);
256        self
257    }
258
259    /// Sets a custom origin validator.
260    pub fn origin_validator(mut self, validator: OriginValidator) -> Self {
261        self.config = self.config.origin_validator(validator);
262        self
263    }
264
265    /// Requires authentication.
266    pub fn require_authentication(mut self) -> Self {
267        self.config = self.config.require_authentication(true);
268        self
269    }
270
271    /// Sets required roles.
272    pub fn required_roles(mut self, roles: Vec<String>) -> Self {
273        self.config = self.config.required_roles(roles);
274        self
275    }
276
277    /// Sets required authorities.
278    pub fn required_authorities(mut self, authorities: Vec<String>) -> Self {
279        self.config = self.config.required_authorities(authorities);
280        self
281    }
282
283    /// Builds the configuration.
284    pub fn build(self) -> WebSocketSecurityConfig {
285        self.config
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;

    /// Builds a request from `https://myapp.com` with `user` stored in its
    /// extensions, mimicking what an authentication layer would do.
    fn create_request_with_user(user: User) -> HttpRequest {
        let req = TestRequest::default()
            .insert_header(("origin", "https://myapp.com"))
            .to_http_request();
        req.extensions_mut().insert(user);
        req
    }

    #[test]
    fn test_default_config_allows_all() {
        let config = WebSocketSecurityConfig::new();
        let req = TestRequest::default()
            .insert_header(("origin", "https://any-origin.com"))
            .to_http_request();

        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_origin_validation() {
        let config =
            WebSocketSecurityConfig::new().allowed_origins(vec!["https://myapp.com".into()]);

        // Valid origin
        let req = TestRequest::default()
            .insert_header(("origin", "https://myapp.com"))
            .to_http_request();
        assert!(config.validate_upgrade(&req).is_ok());

        // Invalid origin
        let req = TestRequest::default()
            .insert_header(("origin", "https://evil.com"))
            .to_http_request();
        assert!(config.validate_upgrade(&req).is_err());
    }

    #[test]
    fn test_authentication_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .require_authentication(true);

        // Without user
        let req = TestRequest::default().to_http_request();
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::Unauthorized)
        ));

        // With user
        let user = User::new("testuser".into(), "password".into());
        let req = create_request_with_user(user);
        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_role_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .required_roles(vec!["ADMIN".into()]);

        // User without required role
        let user = User::new("user".into(), "password".into()).roles(&["USER".into()]);
        let req = create_request_with_user(user);
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::MissingRole { .. })
        ));

        // User with required role
        let admin = User::new("admin".into(), "password".into()).roles(&["ADMIN".into()]);
        let req = create_request_with_user(admin);
        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_role_requirement_rejects_anonymous_as_unauthorized() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .required_roles(vec!["ADMIN".into()]);

        // Roles imply authentication, so an anonymous request is Unauthorized.
        let req = TestRequest::default().to_http_request();
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::Unauthorized)
        ));
    }

    #[test]
    fn test_authority_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .required_authorities(vec!["ws:connect".into()]);

        // User without required authority
        let user = User::new("user".into(), "password".into());
        let req = create_request_with_user(user);
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::MissingAuthority { .. })
        ));

        // User with required authority
        let ws_user =
            User::new("user".into(), "password".into()).authorities(&["ws:connect".into()]);
        let req = create_request_with_user(ws_user);
        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_combined_requirements() {
        let config = WebSocketSecurityConfig::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .required_roles(vec!["USER".into()])
            .required_authorities(vec!["ws:connect".into()]);

        // User with all requirements met
        let user = User::new("testuser".into(), "password".into())
            .roles(&["USER".into()])
            .authorities(&["ws:connect".into()]);

        let req = TestRequest::default()
            .insert_header(("origin", "https://myapp.com"))
            .to_http_request();
        req.extensions_mut().insert(user);

        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_combined_requirements_reject_partial_grants() {
        let config = WebSocketSecurityConfig::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .required_roles(vec!["USER".into()])
            .required_authorities(vec!["ws:connect".into()]);

        // Role present, authority missing
        let role_only = User::new("role_only".into(), "password".into()).roles(&["USER".into()]);
        let req = create_request_with_user(role_only);
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::MissingAuthority { .. })
        ));

        // Authority present, role missing
        let authority_only = User::new("authority_only".into(), "password".into())
            .authorities(&["ws:connect".into()]);
        let req = create_request_with_user(authority_only);
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::MissingRole { .. })
        ));

        // Fully authorized user, but from a disallowed origin
        let user = User::new("testuser".into(), "password".into())
            .roles(&["USER".into()])
            .authorities(&["ws:connect".into()]);
        let req = TestRequest::default()
            .insert_header(("origin", "https://evil.com"))
            .to_http_request();
        req.extensions_mut().insert(user);
        assert!(config.validate_upgrade(&req).is_err());
    }

    #[test]
    fn test_builder_pattern() {
        let config = WebSocketSecurityConfigBuilder::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .require_authentication()
            .required_roles(vec!["USER".into()])
            .build();

        let user = User::new("user".into(), "password".into()).roles(&["USER".into()]);
        let req = TestRequest::default()
            .insert_header(("origin", "https://myapp.com"))
            .to_http_request();
        req.extensions_mut().insert(user);

        assert!(config.validate_upgrade(&req).is_ok());
    }

    #[test]
    fn test_builder_pattern_rejects_missing_role() {
        let config = WebSocketSecurityConfigBuilder::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .require_authentication()
            .required_roles(vec!["USER".into()])
            .build();

        let guest = User::new("guest".into(), "password".into()).roles(&["GUEST".into()]);
        let req = create_request_with_user(guest);
        assert!(matches!(
            config.validate_upgrade(&req),
            Err(WebSocketSecurityError::MissingRole { .. })
        ));
    }
}