Skip to main content

reinhardt_core/
ws.rs

1//! WebSocket routing primitives shared across reinhardt crates.
2//!
3//! This module provides foundational WebSocket types used by both
4//! `reinhardt-websockets` (connection handling) and `reinhardt-urls`
5//! (`UnifiedRouter::websocket()` builder). Placing them here avoids a
6//! circular dependency between those two crates.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12// ── Endpoint metadata ─────────────────────────────────────────────────────
13
/// Static, compile-time description of a WebSocket endpoint.
///
/// The `#[websocket]` attribute implements this on the consumer struct it
/// generates, mirroring the role `EndpointInfo` plays for HTTP views.
pub trait WebSocketEndpointInfo {
	/// URL path pattern of the endpoint (e.g. `"/ws/chat/{room_id}/"`).
	fn path() -> &'static str;

	/// Route name used for URL reversal, or `None` if the endpoint is unnamed.
	fn name() -> Option<&'static str>;
}
24
/// Inventory metadata submitted by `#[websocket]` at compile time.
///
/// Used by `impl WebSocketUrlResolver for ResolvedUrls` to resolve route names.
///
/// Every field is a `&'static str`, so the struct is trivially copyable;
/// `Debug` is derived so collected entries can be printed in diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketEndpointMetadata {
	/// URL path pattern (e.g. `"/ws/chat/{room_id}/"`).
	pub path: &'static str,
	/// Route name used for URL reversal.
	pub name: &'static str,
	/// Handler function name for diagnostics.
	pub fn_name: &'static str,
	/// Rust module path of the handler for diagnostics.
	pub module_path: &'static str,
}
38
39inventory::collect!(WebSocketEndpointMetadata);
40
/// Substitute path parameters in a WebSocket URL pattern.
///
/// `"/ws/chat/{room_id}/"` + `[("room_id", "42")]` → `"/ws/chat/42/"`
///
/// The pattern is scanned once from left to right. Each `{name}`
/// placeholder whose `name` appears in `params` is replaced by the
/// corresponding value; if a name is listed more than once, the first entry
/// wins. Placeholders without a matching parameter are left as-is.
///
/// Substituted values are inserted verbatim (no percent-encoding) and are
/// never re-scanned, so a value that itself contains `{other}` is not
/// expanded further.
pub fn substitute_ws_params(path: &str, params: &[(&str, &str)]) -> String {
	let mut result = String::with_capacity(path.len());
	let mut rest = path;
	while let Some(open) = rest.find('{') {
		result.push_str(&rest[..open]);
		let after_open = &rest[open + 1..];
		// `close` is relative to `after_open`; look up the enclosed name.
		let replacement = after_open.find('}').and_then(|close| {
			let name = &after_open[..close];
			params
				.iter()
				.find(|(key, _)| *key == name)
				.map(|(_, value)| (*value, close))
		});
		match replacement {
			Some((value, close)) => {
				result.push_str(value);
				rest = &after_open[close + 1..];
			}
			None => {
				// Not a known placeholder: keep the brace literally and resume
				// scanning right after it, so an inner `{name}` can still match.
				result.push('{');
				rest = after_open;
			}
		}
	}
	result.push_str(rest);
	result
}
51
52// ── Routing types ─────────────────────────────────────────────────────────
53
54/// Routing result type
55pub type RouteResult = Result<(), RouteError>;
56
57/// Routing errors for WebSocket routes
58#[derive(Debug, thiserror::Error)]
59pub enum RouteError {
60	/// No route registered for the given path.
61	#[error("Route not found: {0}")]
62	NotFound(String),
63	/// A route with the given path is already registered.
64	#[error("Route already exists: {0}")]
65	AlreadyExists(String),
66	/// The provided route pattern is syntactically invalid.
67	#[error("Invalid route pattern: {0}")]
68	InvalidPattern(String),
69}
70
/// One WebSocket route: a path pattern, an optional name, and free-form
/// string key/value metadata.
#[derive(Debug, Clone)]
pub struct WebSocketRoute {
	path: String,
	name: Option<String>,
	metadata: HashMap<String, String>,
}

impl WebSocketRoute {
	/// Builds a route from a path pattern and an optional name, with an
	/// empty metadata map.
	pub fn new(path: String, name: Option<String>) -> Self {
		let metadata = HashMap::default();
		Self {
			name,
			path,
			metadata,
		}
	}

	/// The URL path pattern this route was created with.
	pub fn path(&self) -> &str {
		self.path.as_str()
	}

	/// The route's name, or `None` when it was created without one.
	pub fn name(&self) -> Option<&str> {
		self.name.as_ref().map(String::as_str)
	}

	/// Builder-style setter: stores `value` under `key` and returns the route.
	///
	/// A previous value for the same key is overwritten.
	pub fn with_metadata(self, key: String, value: String) -> Self {
		let mut route = self;
		route.metadata.insert(key, value);
		route
	}

	/// Looks up the metadata value stored under `key`.
	pub fn get_metadata(&self, key: &str) -> Option<&String> {
		self.metadata.get(key)
	}
}
110
111// ── WebSocketRouter ───────────────────────────────────────────────────────
112
113/// WebSocket router: build-time registration + runtime lookup.
114///
115/// The build-time API (`consumer()`, `reverse()`, `find_pending()`) is used
116/// by `#[url_patterns(mode = ws)]` and `UnifiedRouter::websocket()`.
117/// The async API (`register_route()`, `find_route()`, etc.) is used at
118/// connection-handling time in `reinhardt-websockets`.
119#[derive(Clone)]
120pub struct WebSocketRouter {
121	routes: Arc<RwLock<HashMap<String, WebSocketRoute>>>,
122	names: Arc<RwLock<HashMap<String, String>>>,
123	/// Build-time consumer registrations (added by `consumer()` builder).
124	pending_consumers: Vec<WebSocketRoute>,
125	/// Optional namespace (app label) for this router.
126	namespace: Option<String>,
127}
128
129impl WebSocketRouter {
130	/// Creates a new empty router.
131	pub fn new() -> Self {
132		Self {
133			routes: Arc::new(RwLock::new(HashMap::new())),
134			names: Arc::new(RwLock::new(HashMap::new())),
135			pending_consumers: Vec::new(),
136			namespace: None,
137		}
138	}
139
140	/// Set the namespace for this router.
141	///
142	/// Parallel to `ServerRouter::with_namespace`, emitted by
143	/// `#[url_patterns(mode = ws)]` to pass the `AppLabel::path(...)`.
144	/// WebSocket route paths are absolute today and are not rewritten
145	/// with this namespace; the value is stored for parity with other
146	/// routers and future use. See reinhardt-web#3829.
147	pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
148		self.namespace = Some(namespace.into());
149		self
150	}
151
152	/// Returns the namespace set via [`with_namespace`], if any.
153	///
154	/// [`with_namespace`]: Self::with_namespace
155	pub fn namespace(&self) -> Option<&str> {
156		self.namespace.as_deref()
157	}
158
159	/// Register a WebSocket consumer by its factory function.
160	///
161	/// Parallel to `ServerRouter::endpoint()`. Path and name are derived
162	/// from `C`'s `WebSocketEndpointInfo` impl at compile time.
163	pub fn consumer<C, F>(mut self, _f: F) -> Self
164	where
165		F: Fn() -> C,
166		C: WebSocketEndpointInfo + 'static,
167	{
168		self.pending_consumers.push(WebSocketRoute::new(
169			C::path().to_string(),
170			C::name().map(|s| s.to_string()),
171		));
172		self
173	}
174
175	/// Find a pending consumer route by name.
176	pub fn find_pending(&self, name: &str) -> Option<&WebSocketRoute> {
177		self.pending_consumers
178			.iter()
179			.find(|r| r.name() == Some(name))
180	}
181
182	/// Resolve a WebSocket URL by route name, substituting path parameters.
183	pub fn reverse(&self, name: &str, params: &[(&str, &str)]) -> Option<String> {
184		self.pending_consumers
185			.iter()
186			.find(|r| r.name() == Some(name))
187			.map(|r| substitute_ws_params(r.path(), params))
188	}
189
190	/// Register a route at runtime (async, used by the connection handler).
191	pub async fn register_route(&mut self, route: WebSocketRoute) -> RouteResult {
192		let mut routes = self.routes.write().await;
193		if routes.contains_key(&route.path) {
194			return Err(RouteError::AlreadyExists(route.path.clone()));
195		}
196		if let Some(name) = &route.name {
197			let mut names = self.names.write().await;
198			names.insert(name.clone(), route.path.clone());
199		}
200		routes.insert(route.path.clone(), route);
201		Ok(())
202	}
203
204	/// Looks up a registered route by its exact path.
205	pub async fn find_route(&self, path: &str) -> Option<WebSocketRoute> {
206		let routes = self.routes.read().await;
207		routes.get(path).cloned()
208	}
209
210	/// Looks up a registered route by its name.
211	pub async fn find_route_by_name(&self, name: &str) -> Option<WebSocketRoute> {
212		let names = self.names.read().await;
213		if let Some(path) = names.get(name) {
214			let routes = self.routes.read().await;
215			routes.get(path).cloned()
216		} else {
217			None
218		}
219	}
220
221	/// Removes the registered route for the given path.
222	pub async fn remove_route(&mut self, path: &str) -> RouteResult {
223		let mut routes = self.routes.write().await;
224		let route = routes
225			.remove(path)
226			.ok_or_else(|| RouteError::NotFound(path.to_string()))?;
227		if let Some(name) = &route.name {
228			let mut names = self.names.write().await;
229			names.remove(name);
230		}
231		Ok(())
232	}
233
234	/// Returns all currently registered routes.
235	pub async fn all_routes(&self) -> Vec<WebSocketRoute> {
236		let routes = self.routes.read().await;
237		routes.values().cloned().collect()
238	}
239
240	/// Returns `true` if a route is registered for the given path.
241	pub async fn has_route(&self, path: &str) -> bool {
242		self.routes.read().await.contains_key(path)
243	}
244
245	/// Returns the number of currently registered routes.
246	pub async fn route_count(&self) -> usize {
247		self.routes.read().await.len()
248	}
249
250	/// Removes all registered routes and name mappings.
251	pub async fn clear(&mut self) {
252		self.routes.write().await.clear();
253		self.names.write().await.clear();
254	}
255}
256
257impl Default for WebSocketRouter {
258	fn default() -> Self {
259		Self::new()
260	}
261}
262
263// ── Global registry ───────────────────────────────────────────────────────
264
265static GLOBAL_ROUTER: once_cell::sync::Lazy<Arc<RwLock<Option<WebSocketRouter>>>> =
266	once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(None)));
267
268/// Installs `router` as the process-wide WebSocket router.
269pub async fn register_websocket_router(router: WebSocketRouter) {
270	*GLOBAL_ROUTER.write().await = Some(router);
271}
272
273/// Returns a clone of the current process-wide WebSocket router, if set.
274pub async fn get_websocket_router() -> Option<WebSocketRouter> {
275	GLOBAL_ROUTER.read().await.clone()
276}
277
278/// Clears the process-wide WebSocket router (primarily for tests).
279pub async fn clear_websocket_router() {
280	*GLOBAL_ROUTER.write().await = None;
281}
282
283/// Resolves a registered or pending WebSocket URL by route name.
284pub async fn reverse_websocket_url(router: &WebSocketRouter, name: &str) -> Option<String> {
285	let names = router.names.read().await;
286	if let Some(path) = names.get(name) {
287		let routes = router.routes.read().await;
288		routes.get(path).map(|r| r.path().to_string())
289	} else {
290		router.find_pending(name).map(|r| r.path().to_string())
291	}
292}
293
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;

	/// Consumer fixture with a parameterised path and a route name.
	struct TestConsumer;

	impl WebSocketEndpointInfo for TestConsumer {
		fn path() -> &'static str {
			"/ws/chat/{room_id}/"
		}

		fn name() -> Option<&'static str> {
			Some("chat_ws")
		}
	}

	/// Builds a router whose only pending consumer is `TestConsumer`.
	fn chat_router() -> WebSocketRouter {
		WebSocketRouter::new().consumer(|| TestConsumer)
	}

	#[rstest]
	fn test_substitute_no_params() {
		let url = substitute_ws_params("/ws/notif/", &[]);
		assert_eq!(url, "/ws/notif/");
	}

	#[rstest]
	fn test_substitute_one_param() {
		let params = [("room_id", "42")];
		let url = substitute_ws_params("/ws/chat/{room_id}/", &params);
		assert_eq!(url, "/ws/chat/42/");
	}

	#[rstest]
	fn test_consumer_builder() {
		let router = chat_router();
		let pending_path = router.find_pending("chat_ws").map(WebSocketRoute::path);
		assert_eq!(pending_path, Some("/ws/chat/{room_id}/"));
	}

	#[rstest]
	fn test_with_namespace_stores_value_without_rewriting_paths() {
		let router = WebSocketRouter::new()
			.with_namespace("auth")
			.consumer(|| TestConsumer);
		assert_eq!(router.namespace(), Some("auth"));
		// The namespace must not leak into the stored path pattern.
		let pending_path = router.find_pending("chat_ws").map(WebSocketRoute::path);
		assert_eq!(pending_path, Some("/ws/chat/{room_id}/"));
	}

	#[rstest]
	fn test_reverse() {
		let router = chat_router();
		let resolved = router.reverse("chat_ws", &[("room_id", "99")]);
		assert_eq!(resolved.as_deref(), Some("/ws/chat/99/"));
		assert!(router.reverse("unknown", &[]).is_none());
	}
}