Skip to main content

reinhardt_websockets/
auth.rs

1//! WebSocket authentication and authorization
2//!
3//! This module provides authentication and authorization hooks for WebSocket connections,
4//! integrating with Reinhardt's auth system.
5
6use crate::connection::{WebSocketConnection, WebSocketError, WebSocketResult};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10/// Authentication result for WebSocket connections
11pub type AuthResult<T> = Result<T, AuthError>;
12
13/// Authentication errors
14#[derive(Debug, thiserror::Error)]
15pub enum AuthError {
16	#[error("Authentication failed: {0}")]
17	AuthenticationFailed(String),
18	#[error("Authorization denied: {0}")]
19	AuthorizationDenied(String),
20	#[error("Invalid credentials")]
21	InvalidCredentials,
22	#[error("Token expired")]
23	TokenExpired,
24	#[error("Missing authentication")]
25	MissingAuthentication,
26}
27
/// Identity of a user attached to a WebSocket connection.
///
/// The `Send + Sync + Debug` bounds allow implementors to be stored as
/// `Box<dyn AuthUser>`, which is how authenticators in this module return them.
pub trait AuthUser: Send + Sync + std::fmt::Debug {
	/// Get user identifier
	fn id(&self) -> &str;
	/// Get username
	fn username(&self) -> &str;
	/// Check if user is authenticated
	fn is_authenticated(&self) -> bool;
	/// Check if user has specific permission
	fn has_permission(&self, permission: &str) -> bool;
}

/// Simple user implementation for WebSocket authentication
///
/// Holds an identifier, a display name, and a flat list of granted
/// permission names.
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::auth::{SimpleAuthUser, AuthUser};
///
/// let user = SimpleAuthUser::new(
///     "user_123".to_string(),
///     "alice".to_string(),
///     vec!["chat.read".to_string(), "chat.write".to_string()],
/// );
///
/// assert_eq!(user.id(), "user_123");
/// assert_eq!(user.username(), "alice");
/// assert!(user.is_authenticated());
/// assert!(user.has_permission("chat.read"));
/// assert!(!user.has_permission("admin.access"));
/// ```
#[derive(Debug, Clone)]
pub struct SimpleAuthUser {
	id: String,
	username: String,
	permissions: Vec<String>,
}

impl SimpleAuthUser {
	/// Create a new authenticated user
	///
	/// `permissions` is the complete list of permission names granted to the
	/// user; permission checks are exact, case-sensitive string matches.
	pub fn new(id: String, username: String, permissions: Vec<String>) -> Self {
		Self {
			id,
			username,
			permissions,
		}
	}
}

impl AuthUser for SimpleAuthUser {
	fn id(&self) -> &str {
		&self.id
	}

	fn username(&self) -> &str {
		&self.username
	}

	/// A user counts as authenticated exactly when its identifier is non-empty.
	fn is_authenticated(&self) -> bool {
		!self.id.is_empty()
	}

	fn has_permission(&self, permission: &str) -> bool {
		// Compare borrowed `&str`s so a permission check never allocates.
		self.permissions
			.iter()
			.any(|granted| granted.as_str() == permission)
	}
}
94
95/// WebSocket authenticator trait
96///
97/// Implementors define how to authenticate WebSocket connections.
98#[async_trait]
99pub trait WebSocketAuthenticator: Send + Sync {
100	/// Authenticate a WebSocket connection
101	///
102	/// # Arguments
103	///
104	/// * `connection` - The WebSocket connection to authenticate
105	/// * `credentials` - Authentication credentials (e.g., token, cookie)
106	///
107	/// # Returns
108	///
109	/// Returns the authenticated user on success, or an error on failure.
110	async fn authenticate(
111		&self,
112		connection: &Arc<WebSocketConnection>,
113		credentials: &str,
114	) -> AuthResult<Box<dyn AuthUser>>;
115}
116
117/// Token-based WebSocket authenticator
118///
119/// # Examples
120///
121/// ```
122/// use reinhardt_websockets::auth::{TokenAuthenticator, WebSocketAuthenticator, SimpleAuthUser};
123/// use reinhardt_websockets::WebSocketConnection;
124/// use tokio::sync::mpsc;
125/// use std::sync::Arc;
126///
127/// # tokio_test::block_on(async {
128/// let authenticator = TokenAuthenticator::new(vec![
129///     ("valid_token".to_string(), SimpleAuthUser::new(
130///         "user_1".to_string(),
131///         "alice".to_string(),
132///         vec!["chat.read".to_string()],
133///     )),
134/// ]);
135///
136/// let (tx, _rx) = mpsc::unbounded_channel();
137/// let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
138///
139/// let user = authenticator.authenticate(&conn, "valid_token").await.unwrap();
140/// assert_eq!(user.username(), "alice");
141/// # });
142/// ```
143pub struct TokenAuthenticator {
144	tokens: std::collections::HashMap<String, SimpleAuthUser>,
145}
146
147impl TokenAuthenticator {
148	/// Create a new token authenticator with predefined tokens
149	pub fn new(tokens: Vec<(String, SimpleAuthUser)>) -> Self {
150		Self {
151			tokens: tokens.into_iter().collect(),
152		}
153	}
154
155	/// Add a token to the authenticator
156	pub fn add_token(&mut self, token: String, user: SimpleAuthUser) {
157		self.tokens.insert(token, user);
158	}
159
160	/// Remove a token from the authenticator
161	pub fn remove_token(&mut self, token: &str) -> Option<SimpleAuthUser> {
162		self.tokens.remove(token)
163	}
164}
165
166#[async_trait]
167impl WebSocketAuthenticator for TokenAuthenticator {
168	async fn authenticate(
169		&self,
170		_connection: &Arc<WebSocketConnection>,
171		credentials: &str,
172	) -> AuthResult<Box<dyn AuthUser>> {
173		self.tokens
174			.get(credentials)
175			.map(|user| Box::new(user.clone()) as Box<dyn AuthUser>)
176			.ok_or(AuthError::InvalidCredentials)
177	}
178}
179
180/// Authorization policy for WebSocket messages
181#[async_trait]
182pub trait AuthorizationPolicy: Send + Sync {
183	/// Check if a user is authorized to perform an action
184	///
185	/// # Arguments
186	///
187	/// * `user` - The authenticated user
188	/// * `action` - The action to authorize (e.g., "send_message", "join_room")
189	/// * `resource` - Optional resource identifier (e.g., room ID)
190	///
191	/// # Returns
192	///
193	/// Returns `Ok(())` if authorized, or an error if denied.
194	async fn authorize(
195		&self,
196		user: &dyn AuthUser,
197		action: &str,
198		resource: Option<&str>,
199	) -> AuthResult<()>;
200}
201
/// Authorization policy that maps each action name to one required permission.
///
/// An action is allowed only when it has a registered mapping and the user
/// holds the mapped permission.
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::auth::{
///     PermissionBasedPolicy, AuthorizationPolicy, SimpleAuthUser
/// };
///
/// # tokio_test::block_on(async {
/// let policy = PermissionBasedPolicy::new(vec![
///     ("send_message".to_string(), "chat.write".to_string()),
///     ("delete_message".to_string(), "chat.admin".to_string()),
/// ]);
///
/// let user = SimpleAuthUser::new(
///     "user_1".to_string(),
///     "alice".to_string(),
///     vec!["chat.write".to_string()],
/// );
///
/// // User can send messages
/// assert!(policy.authorize(&user, "send_message", None).await.is_ok());
///
/// // User cannot delete messages (lacks chat.admin permission)
/// assert!(policy.authorize(&user, "delete_message", None).await.is_err());
/// # });
/// ```
pub struct PermissionBasedPolicy {
	action_permissions: std::collections::HashMap<String, String>,
}

impl PermissionBasedPolicy {
	/// Build a policy from `(action, required_permission)` pairs.
	///
	/// When the same action appears more than once, the last pair wins.
	pub fn new(action_permissions: Vec<(String, String)>) -> Self {
		let mut mapping = std::collections::HashMap::with_capacity(action_permissions.len());
		mapping.extend(action_permissions);
		Self {
			action_permissions: mapping,
		}
	}

	/// Require `permission` for `action`, replacing any earlier requirement.
	pub fn add_permission(&mut self, action: String, permission: String) {
		self.action_permissions.insert(action, permission);
	}
}
247
248#[async_trait]
249impl AuthorizationPolicy for PermissionBasedPolicy {
250	async fn authorize(
251		&self,
252		user: &dyn AuthUser,
253		action: &str,
254		_resource: Option<&str>,
255	) -> AuthResult<()> {
256		let required_permission = self
257			.action_permissions
258			.get(action)
259			.ok_or_else(|| AuthError::AuthorizationDenied(format!("Unknown action: {}", action)))?;
260
261		if user.has_permission(required_permission) {
262			Ok(())
263		} else {
264			Err(AuthError::AuthorizationDenied(format!(
265				"Missing permission: {}",
266				required_permission
267			)))
268		}
269	}
270}
271
272/// Authenticated WebSocket connection wrapper
273///
274/// # Examples
275///
276/// ```
277/// use reinhardt_websockets::auth::{AuthenticatedConnection, SimpleAuthUser};
278/// use reinhardt_websockets::WebSocketConnection;
279/// use tokio::sync::mpsc;
280/// use std::sync::Arc;
281///
282/// let (tx, _rx) = mpsc::unbounded_channel();
283/// let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
284/// let user = SimpleAuthUser::new(
285///     "user_1".to_string(),
286///     "alice".to_string(),
287///     vec!["chat.read".to_string()],
288/// );
289///
290/// let auth_conn = AuthenticatedConnection::new(conn, Box::new(user));
291/// assert_eq!(auth_conn.user().username(), "alice");
292/// ```
293pub struct AuthenticatedConnection {
294	connection: Arc<WebSocketConnection>,
295	user: Box<dyn AuthUser>,
296}
297
298impl AuthenticatedConnection {
299	/// Create a new authenticated connection
300	pub fn new(connection: Arc<WebSocketConnection>, user: Box<dyn AuthUser>) -> Self {
301		Self { connection, user }
302	}
303
304	/// Get the underlying WebSocket connection
305	pub fn connection(&self) -> &Arc<WebSocketConnection> {
306		&self.connection
307	}
308
309	/// Get the authenticated user
310	pub fn user(&self) -> &dyn AuthUser {
311		self.user.as_ref()
312	}
313
314	/// Send a message with authorization check
315	pub async fn send_with_auth<P: AuthorizationPolicy>(
316		&self,
317		message: crate::connection::Message,
318		policy: &P,
319	) -> WebSocketResult<()> {
320		policy
321			.authorize(self.user.as_ref(), "send_message", None)
322			.await
323			.map_err(|_| WebSocketError::Protocol("authorization failed".to_string()))?;
324
325		self.connection.send(message).await
326	}
327}
328
#[cfg(test)]
mod tests {
	use super::*;
	use crate::connection::Message;
	use tokio::sync::mpsc;

	#[test]
	fn test_simple_auth_user() {
		let user = SimpleAuthUser::new(
			"user_123".to_string(),
			"alice".to_string(),
			vec!["read".to_string(), "write".to_string()],
		);

		assert_eq!(user.id(), "user_123");
		assert_eq!(user.username(), "alice");
		assert!(user.is_authenticated());
		assert!(user.has_permission("read"));
		assert!(user.has_permission("write"));
		assert!(!user.has_permission("admin"));
	}

	#[test]
	fn test_simple_auth_user_empty_id_is_not_authenticated() {
		let user = SimpleAuthUser::new(String::new(), "anonymous".to_string(), vec![]);

		assert!(!user.is_authenticated());
		assert!(!user.has_permission("read"));
	}

	#[tokio::test]
	async fn test_token_authenticator_valid() {
		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.read".to_string()],
		);

		let authenticator = TokenAuthenticator::new(vec![("token123".to_string(), user)]);

		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));

		let auth_user = authenticator.authenticate(&conn, "token123").await.unwrap();
		assert_eq!(auth_user.username(), "alice");
	}

	#[tokio::test]
	async fn test_token_authenticator_invalid() {
		let authenticator = TokenAuthenticator::new(vec![]);

		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));

		let result = authenticator.authenticate(&conn, "invalid_token").await;
		assert!(result.is_err());
		assert!(matches!(result.unwrap_err(), AuthError::InvalidCredentials));
	}

	#[tokio::test]
	async fn test_token_authenticator_remove_token() {
		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.read".to_string()],
		);

		let mut authenticator = TokenAuthenticator::new(vec![("token123".to_string(), user)]);

		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));

		assert!(authenticator.authenticate(&conn, "token123").await.is_ok());

		// Removing a token hands back the user and revokes access.
		let removed = authenticator.remove_token("token123");
		assert_eq!(
			removed.map(|u| u.username().to_string()),
			Some("alice".to_string())
		);
		assert!(authenticator.remove_token("token123").is_none());

		let result = authenticator.authenticate(&conn, "token123").await;
		assert!(matches!(result, Err(AuthError::InvalidCredentials)));
	}

	#[tokio::test]
	async fn test_permission_based_policy_authorized() {
		let policy = PermissionBasedPolicy::new(vec![(
			"send_message".to_string(),
			"chat.write".to_string(),
		)]);

		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.write".to_string()],
		);

		let result = policy.authorize(&user, "send_message", None).await;
		assert!(result.is_ok());
	}

	#[tokio::test]
	async fn test_permission_based_policy_denied() {
		let policy = PermissionBasedPolicy::new(vec![(
			"delete_message".to_string(),
			"chat.admin".to_string(),
		)]);

		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.write".to_string()],
		);

		let result = policy.authorize(&user, "delete_message", None).await;
		assert!(result.is_err());
		assert!(matches!(
			result.unwrap_err(),
			AuthError::AuthorizationDenied(_)
		));
	}

	#[tokio::test]
	async fn test_permission_based_policy_unknown_action() {
		// An action with no registered mapping must be denied, not allowed.
		let policy = PermissionBasedPolicy::new(vec![]);

		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.write".to_string()],
		);

		let result = policy.authorize(&user, "send_message", None).await;
		assert!(matches!(result, Err(AuthError::AuthorizationDenied(_))));
	}

	#[tokio::test]
	async fn test_authenticated_connection_send_with_auth() {
		let policy = PermissionBasedPolicy::new(vec![(
			"send_message".to_string(),
			"chat.write".to_string(),
		)]);

		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.write".to_string()],
		);

		let (tx, mut rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
		let auth_conn = AuthenticatedConnection::new(conn, Box::new(user));

		let msg = Message::text("Hello".to_string());
		auth_conn.send_with_auth(msg, &policy).await.unwrap();

		assert!(matches!(rx.try_recv(), Ok(Message::Text { .. })));
	}

	#[tokio::test]
	async fn test_authenticated_connection_send_with_auth_denied() {
		let policy = PermissionBasedPolicy::new(vec![(
			"send_message".to_string(),
			"chat.admin".to_string(),
		)]);

		let user = SimpleAuthUser::new(
			"user_1".to_string(),
			"alice".to_string(),
			vec!["chat.write".to_string()],
		);

		let (tx, mut rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
		let auth_conn = AuthenticatedConnection::new(conn, Box::new(user));

		let msg = Message::text("Hello".to_string());
		let result = auth_conn.send_with_auth(msg, &policy).await;

		assert!(result.is_err());
		assert!(matches!(result, Err(WebSocketError::Protocol(_))));
		// A denied message must never reach the outbound channel.
		assert!(rx.try_recv().is_err());
	}
}