actix_security_core/http/security/websocket/
config.rs1use actix_web::{HttpMessage, HttpRequest};
7
8use crate::http::security::User;
9
10use super::error::WebSocketSecurityError;
11use super::extractor::WebSocketUpgrade;
12use super::origin::OriginValidator;
13
14#[derive(Debug, Clone)]
45pub struct WebSocketSecurityConfig {
46 origin_validator: OriginValidator,
48 require_authentication: bool,
50 required_roles: Vec<String>,
52 required_authorities: Vec<String>,
54}
55
56impl Default for WebSocketSecurityConfig {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl WebSocketSecurityConfig {
63 pub fn new() -> Self {
70 Self {
71 origin_validator: OriginValidator::allow_any(),
72 require_authentication: false,
73 required_roles: Vec::new(),
74 required_authorities: Vec::new(),
75 }
76 }
77
78 pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
89 let origins_refs: Vec<&str> = origins.iter().map(|s| s.as_str()).collect();
90 self.origin_validator = OriginValidator::new(&origins_refs);
91 self
92 }
93
94 pub fn origin_validator(mut self, validator: OriginValidator) -> Self {
107 self.origin_validator = validator;
108 self
109 }
110
111 pub fn require_authentication(mut self, require: bool) -> Self {
121 self.require_authentication = require;
122 self
123 }
124
125 pub fn required_roles(mut self, roles: Vec<String>) -> Self {
136 self.required_roles = roles;
137 if !self.required_roles.is_empty() {
138 self.require_authentication = true;
139 }
140 self
141 }
142
143 pub fn required_authorities(mut self, authorities: Vec<String>) -> Self {
154 self.required_authorities = authorities;
155 if !self.required_authorities.is_empty() {
156 self.require_authentication = true;
157 }
158 self
159 }
160
161 pub fn validate_upgrade(
187 &self,
188 req: &HttpRequest,
189 ) -> Result<WebSocketUpgrade, WebSocketSecurityError> {
190 self.origin_validator.validate(req)?;
192
193 let user = req.extensions().get::<User>().cloned();
195
196 if self.require_authentication && user.is_none() {
198 return Err(WebSocketSecurityError::Unauthorized);
199 }
200
201 if !self.required_roles.is_empty() {
203 let roles_refs: Vec<&str> = self.required_roles.iter().map(|s| s.as_str()).collect();
204 if !user.as_ref().is_some_and(|u| u.has_any_role(&roles_refs)) {
205 return Err(WebSocketSecurityError::MissingRole {
206 role: self.required_roles.join(", "),
207 });
208 }
209 }
210
211 if !self.required_authorities.is_empty() {
213 let auth_refs: Vec<&str> = self
214 .required_authorities
215 .iter()
216 .map(|s| s.as_str())
217 .collect();
218 if !user
219 .as_ref()
220 .is_some_and(|u| u.has_any_authority(&auth_refs))
221 {
222 return Err(WebSocketSecurityError::MissingAuthority {
223 authority: self.required_authorities.join(", "),
224 });
225 }
226 }
227
228 let origin = req
230 .headers()
231 .get("origin")
232 .and_then(|h| h.to_str().ok())
233 .map(|s| s.to_string());
234
235 Ok(WebSocketUpgrade::new(user, origin))
236 }
237}
238
239#[derive(Debug, Clone, Default)]
241pub struct WebSocketSecurityConfigBuilder {
242 config: WebSocketSecurityConfig,
243}
244
245impl WebSocketSecurityConfigBuilder {
246 pub fn new() -> Self {
248 Self {
249 config: WebSocketSecurityConfig::new(),
250 }
251 }
252
253 pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
255 self.config = self.config.allowed_origins(origins);
256 self
257 }
258
259 pub fn origin_validator(mut self, validator: OriginValidator) -> Self {
261 self.config = self.config.origin_validator(validator);
262 self
263 }
264
265 pub fn require_authentication(mut self) -> Self {
267 self.config = self.config.require_authentication(true);
268 self
269 }
270
271 pub fn required_roles(mut self, roles: Vec<String>) -> Self {
273 self.config = self.config.required_roles(roles);
274 self
275 }
276
277 pub fn required_authorities(mut self, authorities: Vec<String>) -> Self {
279 self.config = self.config.required_authorities(authorities);
280 self
281 }
282
283 pub fn build(self) -> WebSocketSecurityConfig {
285 self.config
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;

    /// Builds a request whose only header is `origin: <origin>`.
    fn request_with_origin(origin: &'static str) -> HttpRequest {
        TestRequest::default()
            .insert_header(("origin", origin))
            .to_http_request()
    }

    /// Builds a request from `https://myapp.com` with `user` stored in its extensions.
    fn create_request_with_user(user: User) -> HttpRequest {
        let request = request_with_origin("https://myapp.com");
        request.extensions_mut().insert(user);
        request
    }

    #[test]
    fn test_default_config_allows_all() {
        let config = WebSocketSecurityConfig::new();
        let request = request_with_origin("https://any-origin.com");

        assert!(config.validate_upgrade(&request).is_ok());
    }

    #[test]
    fn test_origin_validation() {
        let config =
            WebSocketSecurityConfig::new().allowed_origins(vec!["https://myapp.com".into()]);

        // A listed origin is accepted.
        let accepted = request_with_origin("https://myapp.com");
        assert!(config.validate_upgrade(&accepted).is_ok());

        // An origin that is not listed is rejected.
        let rejected = request_with_origin("https://evil.com");
        assert!(config.validate_upgrade(&rejected).is_err());
    }

    #[test]
    fn test_authentication_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .require_authentication(true);

        // No user attached to the request.
        let anonymous = TestRequest::default().to_http_request();
        let outcome = config.validate_upgrade(&anonymous);
        assert!(matches!(outcome, Err(WebSocketSecurityError::Unauthorized)));

        // User attached to the request.
        let user = User::new("testuser".into(), "password".into());
        let authenticated = create_request_with_user(user);
        assert!(config.validate_upgrade(&authenticated).is_ok());
    }

    #[test]
    fn test_role_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .required_roles(vec!["ADMIN".into()]);

        // User holding a different role.
        let user = User::new("user".into(), "password".into()).roles(&["USER".into()]);
        let outcome = config.validate_upgrade(&create_request_with_user(user));
        assert!(matches!(
            outcome,
            Err(WebSocketSecurityError::MissingRole { .. })
        ));

        // User holding the required role.
        let admin = User::new("admin".into(), "password".into()).roles(&["ADMIN".into()]);
        let outcome = config.validate_upgrade(&create_request_with_user(admin));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_authority_requirement() {
        let config = WebSocketSecurityConfig::new()
            .origin_validator(OriginValidator::allow_any())
            .required_authorities(vec!["ws:connect".into()]);

        // User without any authority.
        let user = User::new("user".into(), "password".into());
        let outcome = config.validate_upgrade(&create_request_with_user(user));
        assert!(matches!(
            outcome,
            Err(WebSocketSecurityError::MissingAuthority { .. })
        ));

        // User holding the required authority.
        let ws_user =
            User::new("user".into(), "password".into()).authorities(&["ws:connect".into()]);
        let outcome = config.validate_upgrade(&create_request_with_user(ws_user));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_combined_requirements() {
        let config = WebSocketSecurityConfig::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .required_roles(vec!["USER".into()])
            .required_authorities(vec!["ws:connect".into()]);

        let user = User::new("testuser".into(), "password".into())
            .roles(&["USER".into()])
            .authorities(&["ws:connect".into()]);
        let request = create_request_with_user(user);

        assert!(config.validate_upgrade(&request).is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let config = WebSocketSecurityConfigBuilder::new()
            .allowed_origins(vec!["https://myapp.com".into()])
            .require_authentication()
            .required_roles(vec!["USER".into()])
            .build();

        let user = User::new("user".into(), "password".into()).roles(&["USER".into()]);
        let request = create_request_with_user(user);

        assert!(config.validate_upgrade(&request).is_ok());
    }
}