1use crate::connection::{WebSocketConnection, WebSocketError, WebSocketResult};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10pub type AuthResult<T> = Result<T, AuthError>;
12
13#[derive(Debug, thiserror::Error)]
15pub enum AuthError {
16 #[error("Authentication failed: {0}")]
17 AuthenticationFailed(String),
18 #[error("Authorization denied: {0}")]
19 AuthorizationDenied(String),
20 #[error("Invalid credentials")]
21 InvalidCredentials,
22 #[error("Token expired")]
23 TokenExpired,
24 #[error("Missing authentication")]
25 MissingAuthentication,
26}
27
/// An identity that authentication produces and authorization inspects.
///
/// Implementors must be `Send + Sync + Debug`, so a user can be stored as a
/// `Box<dyn AuthUser>` and shared across threads.
pub trait AuthUser: std::fmt::Debug + Send + Sync {
    /// Identifier of this user.
    fn id(&self) -> &str;

    /// Name of this user.
    fn username(&self) -> &str;

    /// Whether this value represents an authenticated user.
    fn is_authenticated(&self) -> bool;

    /// Whether this user holds the permission named `permission`.
    fn has_permission(&self, permission: &str) -> bool;
}
39
/// A plain, in-memory user record: an id, a username and a list of
/// permission names.
#[derive(Debug, Clone)]
pub struct SimpleAuthUser {
    /// User identifier.
    id: String,
    /// User name.
    username: String,
    /// Permission names granted to this user.
    permissions: Vec<String>,
}

impl SimpleAuthUser {
    /// Builds a user from its id, username and granted permission names.
    ///
    /// The values are stored as given; nothing is validated or normalized.
    pub fn new(id: String, username: String, permissions: Vec<String>) -> Self {
        SimpleAuthUser { id, username, permissions }
    }
}
76
77impl AuthUser for SimpleAuthUser {
78 fn id(&self) -> &str {
79 &self.id
80 }
81
82 fn username(&self) -> &str {
83 &self.username
84 }
85
86 fn is_authenticated(&self) -> bool {
87 !self.id.is_empty()
88 }
89
90 fn has_permission(&self, permission: &str) -> bool {
91 self.permissions.contains(&permission.to_string())
92 }
93}
94
95#[async_trait]
99pub trait WebSocketAuthenticator: Send + Sync {
100 async fn authenticate(
111 &self,
112 connection: &Arc<WebSocketConnection>,
113 credentials: &str,
114 ) -> AuthResult<Box<dyn AuthUser>>;
115}
116
117pub struct TokenAuthenticator {
144 tokens: std::collections::HashMap<String, SimpleAuthUser>,
145}
146
147impl TokenAuthenticator {
148 pub fn new(tokens: Vec<(String, SimpleAuthUser)>) -> Self {
150 Self {
151 tokens: tokens.into_iter().collect(),
152 }
153 }
154
155 pub fn add_token(&mut self, token: String, user: SimpleAuthUser) {
157 self.tokens.insert(token, user);
158 }
159
160 pub fn remove_token(&mut self, token: &str) -> Option<SimpleAuthUser> {
162 self.tokens.remove(token)
163 }
164}
165
166#[async_trait]
167impl WebSocketAuthenticator for TokenAuthenticator {
168 async fn authenticate(
169 &self,
170 _connection: &Arc<WebSocketConnection>,
171 credentials: &str,
172 ) -> AuthResult<Box<dyn AuthUser>> {
173 self.tokens
174 .get(credentials)
175 .map(|user| Box::new(user.clone()) as Box<dyn AuthUser>)
176 .ok_or(AuthError::InvalidCredentials)
177 }
178}
179
180#[async_trait]
182pub trait AuthorizationPolicy: Send + Sync {
183 async fn authorize(
195 &self,
196 user: &dyn AuthUser,
197 action: &str,
198 resource: Option<&str>,
199 ) -> AuthResult<()>;
200}
201
/// Policy that maps each action name to the single permission required to
/// perform it.
pub struct PermissionBasedPolicy {
    /// Action name -> name of the permission that action requires.
    action_permissions: std::collections::HashMap<String, String>,
}

impl PermissionBasedPolicy {
    /// Creates a policy from `(action, required_permission)` pairs.
    ///
    /// When the same action appears more than once, the last pair wins.
    pub fn new(action_permissions: Vec<(String, String)>) -> Self {
        let mut table = std::collections::HashMap::with_capacity(action_permissions.len());
        table.extend(action_permissions);
        Self {
            action_permissions: table,
        }
    }

    /// Sets the permission required for `action`, replacing any previous
    /// requirement for that action.
    pub fn add_permission(&mut self, action: String, permission: String) {
        self.action_permissions.insert(action, permission);
    }
}
247
248#[async_trait]
249impl AuthorizationPolicy for PermissionBasedPolicy {
250 async fn authorize(
251 &self,
252 user: &dyn AuthUser,
253 action: &str,
254 _resource: Option<&str>,
255 ) -> AuthResult<()> {
256 let required_permission = self
257 .action_permissions
258 .get(action)
259 .ok_or_else(|| AuthError::AuthorizationDenied(format!("Unknown action: {}", action)))?;
260
261 if user.has_permission(required_permission) {
262 Ok(())
263 } else {
264 Err(AuthError::AuthorizationDenied(format!(
265 "Missing permission: {}",
266 required_permission
267 )))
268 }
269 }
270}
271
272pub struct AuthenticatedConnection {
294 connection: Arc<WebSocketConnection>,
295 user: Box<dyn AuthUser>,
296}
297
298impl AuthenticatedConnection {
299 pub fn new(connection: Arc<WebSocketConnection>, user: Box<dyn AuthUser>) -> Self {
301 Self { connection, user }
302 }
303
304 pub fn connection(&self) -> &Arc<WebSocketConnection> {
306 &self.connection
307 }
308
309 pub fn user(&self) -> &dyn AuthUser {
311 self.user.as_ref()
312 }
313
314 pub async fn send_with_auth<P: AuthorizationPolicy>(
316 &self,
317 message: crate::connection::Message,
318 policy: &P,
319 ) -> WebSocketResult<()> {
320 policy
321 .authorize(self.user.as_ref(), "send_message", None)
322 .await
323 .map_err(|_| WebSocketError::Protocol("authorization failed".to_string()))?;
324
325 self.connection.send(message).await
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::Message;
    use tokio::sync::mpsc;

    /// Builds the user "alice" with the given id and permission names.
    fn alice(id: &str, permissions: &[&str]) -> SimpleAuthUser {
        SimpleAuthUser::new(
            id.to_string(),
            "alice".to_string(),
            permissions.iter().map(|name| name.to_string()).collect(),
        )
    }

    /// Builds a policy with a single `action -> permission` requirement.
    fn policy_requiring(action: &str, permission: &str) -> PermissionBasedPolicy {
        PermissionBasedPolicy::new(vec![(action.to_string(), permission.to_string())])
    }

    /// Creates connection "conn_1" and returns it together with the receiving
    /// half of its channel, which the caller keeps alive for the whole test.
    fn open_connection() -> (Arc<WebSocketConnection>, mpsc::UnboundedReceiver<Message>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
        (conn, rx)
    }

    #[test]
    fn test_simple_auth_user() {
        let user = alice("user_123", &["read", "write"]);

        assert_eq!(user.id(), "user_123");
        assert_eq!(user.username(), "alice");
        assert!(user.is_authenticated());
        assert!(user.has_permission("read"));
        assert!(user.has_permission("write"));
        assert!(!user.has_permission("admin"));
    }

    #[tokio::test]
    async fn test_token_authenticator_valid() {
        let authenticator = TokenAuthenticator::new(vec![(
            "token123".to_string(),
            alice("user_1", &["chat.read"]),
        )]);
        let (conn, _rx) = open_connection();

        let auth_user = authenticator.authenticate(&conn, "token123").await.unwrap();

        assert_eq!(auth_user.username(), "alice");
    }

    #[tokio::test]
    async fn test_token_authenticator_invalid() {
        let authenticator = TokenAuthenticator::new(vec![]);
        let (conn, _rx) = open_connection();

        let outcome = authenticator.authenticate(&conn, "invalid_token").await;

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(AuthError::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_permission_based_policy_authorized() {
        let policy = policy_requiring("send_message", "chat.write");
        let user = alice("user_1", &["chat.write"]);

        let outcome = policy.authorize(&user, "send_message", None).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_permission_based_policy_denied() {
        let policy = policy_requiring("delete_message", "chat.admin");
        let user = alice("user_1", &["chat.write"]);

        let outcome = policy.authorize(&user, "delete_message", None).await;

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(AuthError::AuthorizationDenied(_))));
    }

    #[tokio::test]
    async fn test_authenticated_connection_send_with_auth() {
        let policy = policy_requiring("send_message", "chat.write");
        let (conn, mut rx) = open_connection();
        let auth_conn =
            AuthenticatedConnection::new(conn, Box::new(alice("user_1", &["chat.write"])));

        let sent = auth_conn
            .send_with_auth(Message::text("Hello".to_string()), &policy)
            .await;

        assert!(sent.is_ok());
        assert!(matches!(rx.try_recv(), Ok(Message::Text { .. })));
    }

    #[tokio::test]
    async fn test_authenticated_connection_send_with_auth_denied() {
        let policy = policy_requiring("send_message", "chat.admin");
        let (conn, _rx) = open_connection();
        let auth_conn =
            AuthenticatedConnection::new(conn, Box::new(alice("user_1", &["chat.write"])));

        let sent = auth_conn
            .send_with_auth(Message::text("Hello".to_string()), &policy)
            .await;

        assert!(sent.is_err());
    }
}