1use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Static, instance-free description of a WebSocket endpoint type.
///
/// Both items are associated functions (no `self`), so a route can be described from the
/// type alone, as `WebSocketRouter::consumer` does.
pub trait WebSocketEndpointInfo {
    /// Optional route name; `None` leaves the route unnamed.
    fn name() -> Option<&'static str>;

    /// Path template for the endpoint, which may contain `{param}` placeholders.
    fn path() -> &'static str;
}
24
/// Static record describing one WebSocket endpoint.
///
/// Every field is a `&'static str`, so values can be built in `const`/`static` context.
/// This type is the item collected with `inventory::collect!` in this module.
pub struct WebSocketEndpointMetadata {
    /// Route path template of the endpoint.
    pub path: &'static str,
    /// Route name of the endpoint.
    pub name: &'static str,
    /// Name of the function that defines the endpoint (presumably the handler fn — confirm
    /// against whatever submits these records).
    pub fn_name: &'static str,
    /// Module path of the defining item (presumably produced by `module_path!()` — confirm).
    pub module_path: &'static str,
}
38
// Registers `WebSocketEndpointMetadata` as an `inventory`-collectable type. Values submitted
// for it (presumably via `inventory::submit!`, e.g. from an endpoint macro — confirm) can
// then be enumerated at runtime with `inventory::iter::<WebSocketEndpointMetadata>`.
inventory::collect!(WebSocketEndpointMetadata);
40
/// Fills `{name}` placeholders in a WebSocket path template with values from `params`.
///
/// The template is scanned once, left to right:
/// - `{name}` whose `name` appears in `params` is replaced by the first matching value;
/// - `{name}` with no matching entry in `params` is kept verbatim;
/// - a `{` that is not closed before the next `{` (or the end of the template) is copied
///   literally.
///
/// Substituted values are inserted as-is and are never rescanned. A value that itself looks
/// like a placeholder (e.g. `"{room_id}"`) therefore stays literal and is not expanded by a
/// later parameter. Values are not percent-encoded.
pub fn substitute_ws_params(path: &str, params: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(path.len());
    let mut rest = path;
    while let Some(open) = rest.find('{') {
        out.push_str(&rest[..open]);
        let tail = &rest[open + 1..];
        // The first brace after `{` decides: `}` closes a placeholder, another `{` means this
        // one was unclosed, so it is emitted literally and scanning resumes from the new `{`.
        match tail.find(['{', '}']) {
            Some(close) if tail[close..].starts_with('}') => {
                let name = &tail[..close];
                match params.iter().find(|(key, _)| *key == name) {
                    Some((_, value)) => out.push_str(value),
                    // Unknown placeholder: copy `{name}` through unchanged.
                    None => out.push_str(&rest[open..=open + close + 1]),
                }
                rest = &tail[close + 1..];
            }
            _ => {
                out.push('{');
                rest = tail;
            }
        }
    }
    out.push_str(rest);
    out
}
51
52pub type RouteResult = Result<(), RouteError>;
56
57#[derive(Debug, thiserror::Error)]
59pub enum RouteError {
60 #[error("Route not found: {0}")]
62 NotFound(String),
63 #[error("Route already exists: {0}")]
65 AlreadyExists(String),
66 #[error("Invalid route pattern: {0}")]
68 InvalidPattern(String),
69}
70
/// One WebSocket route: a path template, an optional name, and string metadata.
#[derive(Debug, Clone)]
pub struct WebSocketRoute {
    /// Path template, e.g. `/ws/chat/{room_id}/`.
    path: String,
    /// Optional name used for name-based lookups.
    name: Option<String>,
    /// Free-form key/value annotations added with `with_metadata`.
    metadata: HashMap<String, String>,
}

impl WebSocketRoute {
    /// Creates a route for `path`, optionally named `name`, with empty metadata.
    pub fn new(path: String, name: Option<String>) -> Self {
        let metadata = HashMap::default();
        Self {
            name,
            path,
            metadata,
        }
    }

    /// Borrows the route's path template.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// Borrows the route's name, or `None` for an unnamed route.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Builder-style setter: stores `value` under `key`, overwriting any earlier value,
    /// and hands the route back.
    pub fn with_metadata(self, key: String, value: String) -> Self {
        let mut route = self;
        route.metadata.insert(key, value);
        route
    }

    /// Returns the metadata value stored under `key`, if present.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }
}
110
111#[derive(Clone)]
120pub struct WebSocketRouter {
121 routes: Arc<RwLock<HashMap<String, WebSocketRoute>>>,
122 names: Arc<RwLock<HashMap<String, String>>>,
123 pending_consumers: Vec<WebSocketRoute>,
125 namespace: Option<String>,
127}
128
129impl WebSocketRouter {
130 pub fn new() -> Self {
132 Self {
133 routes: Arc::new(RwLock::new(HashMap::new())),
134 names: Arc::new(RwLock::new(HashMap::new())),
135 pending_consumers: Vec::new(),
136 namespace: None,
137 }
138 }
139
140 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
148 self.namespace = Some(namespace.into());
149 self
150 }
151
152 pub fn namespace(&self) -> Option<&str> {
156 self.namespace.as_deref()
157 }
158
159 pub fn consumer<C, F>(mut self, _f: F) -> Self
164 where
165 F: Fn() -> C,
166 C: WebSocketEndpointInfo + 'static,
167 {
168 self.pending_consumers.push(WebSocketRoute::new(
169 C::path().to_string(),
170 C::name().map(|s| s.to_string()),
171 ));
172 self
173 }
174
175 pub fn find_pending(&self, name: &str) -> Option<&WebSocketRoute> {
177 self.pending_consumers
178 .iter()
179 .find(|r| r.name() == Some(name))
180 }
181
182 pub fn reverse(&self, name: &str, params: &[(&str, &str)]) -> Option<String> {
184 self.pending_consumers
185 .iter()
186 .find(|r| r.name() == Some(name))
187 .map(|r| substitute_ws_params(r.path(), params))
188 }
189
190 pub async fn register_route(&mut self, route: WebSocketRoute) -> RouteResult {
192 let mut routes = self.routes.write().await;
193 if routes.contains_key(&route.path) {
194 return Err(RouteError::AlreadyExists(route.path.clone()));
195 }
196 if let Some(name) = &route.name {
197 let mut names = self.names.write().await;
198 names.insert(name.clone(), route.path.clone());
199 }
200 routes.insert(route.path.clone(), route);
201 Ok(())
202 }
203
204 pub async fn find_route(&self, path: &str) -> Option<WebSocketRoute> {
206 let routes = self.routes.read().await;
207 routes.get(path).cloned()
208 }
209
210 pub async fn find_route_by_name(&self, name: &str) -> Option<WebSocketRoute> {
212 let names = self.names.read().await;
213 if let Some(path) = names.get(name) {
214 let routes = self.routes.read().await;
215 routes.get(path).cloned()
216 } else {
217 None
218 }
219 }
220
221 pub async fn remove_route(&mut self, path: &str) -> RouteResult {
223 let mut routes = self.routes.write().await;
224 let route = routes
225 .remove(path)
226 .ok_or_else(|| RouteError::NotFound(path.to_string()))?;
227 if let Some(name) = &route.name {
228 let mut names = self.names.write().await;
229 names.remove(name);
230 }
231 Ok(())
232 }
233
234 pub async fn all_routes(&self) -> Vec<WebSocketRoute> {
236 let routes = self.routes.read().await;
237 routes.values().cloned().collect()
238 }
239
240 pub async fn has_route(&self, path: &str) -> bool {
242 self.routes.read().await.contains_key(path)
243 }
244
245 pub async fn route_count(&self) -> usize {
247 self.routes.read().await.len()
248 }
249
250 pub async fn clear(&mut self) {
252 self.routes.write().await.clear();
253 self.names.write().await.clear();
254 }
255}
256
257impl Default for WebSocketRouter {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
263static GLOBAL_ROUTER: once_cell::sync::Lazy<Arc<RwLock<Option<WebSocketRouter>>>> =
266 once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(None)));
267
268pub async fn register_websocket_router(router: WebSocketRouter) {
270 *GLOBAL_ROUTER.write().await = Some(router);
271}
272
273pub async fn get_websocket_router() -> Option<WebSocketRouter> {
275 GLOBAL_ROUTER.read().await.clone()
276}
277
278pub async fn clear_websocket_router() {
280 *GLOBAL_ROUTER.write().await = None;
281}
282
283pub async fn reverse_websocket_url(router: &WebSocketRouter, name: &str) -> Option<String> {
285 let names = router.names.read().await;
286 if let Some(path) = names.get(name) {
287 let routes = router.routes.read().await;
288 routes.get(path).map(|r| r.path().to_string())
289 } else {
290 router.find_pending(name).map(|r| r.path().to_string())
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Fixture endpoint: one `{room_id}` placeholder, named `chat_ws`.
    struct TestConsumer;

    impl WebSocketEndpointInfo for TestConsumer {
        fn path() -> &'static str {
            "/ws/chat/{room_id}/"
        }

        fn name() -> Option<&'static str> {
            Some("chat_ws")
        }
    }

    #[rstest]
    fn test_substitute_no_params() {
        let out = substitute_ws_params("/ws/notif/", &[]);
        assert_eq!(out, "/ws/notif/");
    }

    #[rstest]
    fn test_substitute_one_param() {
        let params = [("room_id", "42")];
        let out = substitute_ws_params("/ws/chat/{room_id}/", &params);
        assert_eq!(out, "/ws/chat/42/");
    }

    #[rstest]
    fn test_consumer_builder() {
        let router = WebSocketRouter::new().consumer(|| TestConsumer);
        let route = router
            .find_pending("chat_ws")
            .expect("TestConsumer should be pending under `chat_ws`");
        assert_eq!(route.path(), "/ws/chat/{room_id}/");
    }

    #[rstest]
    fn test_with_namespace_stores_value_without_rewriting_paths() {
        let router = WebSocketRouter::new()
            .with_namespace("auth")
            .consumer(|| TestConsumer);
        assert_eq!(router.namespace(), Some("auth"));
        let pending = router
            .find_pending("chat_ws")
            .expect("TestConsumer should be pending under `chat_ws`");
        assert_eq!(pending.path(), "/ws/chat/{room_id}/");
    }

    #[rstest]
    fn test_reverse() {
        let router = WebSocketRouter::new().consumer(|| TestConsumer);
        let known = router.reverse("chat_ws", &[("room_id", "99")]);
        assert_eq!(known.as_deref(), Some("/ws/chat/99/"));
        assert!(router.reverse("unknown", &[]).is_none());
    }
}