use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Compile-time description of a WebSocket endpoint.
///
/// Implemented by consumer types so that [`WebSocketRouter::consumer`] can read
/// their path and name without needing an instance.
pub trait WebSocketEndpointInfo {
    /// Optional name used for reverse lookup (see [`WebSocketRouter::reverse`]).
    fn name() -> Option<&'static str>;

    /// Path pattern served by the endpoint; may contain `{param}` placeholders
    /// that [`substitute_ws_params`] fills in.
    fn path() -> &'static str;
}
/// Static description of a WebSocket endpoint.
///
/// Every field is a `&'static str`, so the type is trivially `Copy` and can be
/// compared and debug-printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WebSocketEndpointMetadata {
    /// Path pattern of the endpoint, e.g. `/ws/chat/{room_id}/`.
    pub path: &'static str,
    /// Name of the endpoint (required here, unlike `WebSocketEndpointInfo::name`).
    pub name: &'static str,
    /// NOTE(review): presumably the name of the handler function — confirm
    /// against the code that constructs these values.
    pub fn_name: &'static str,
    /// NOTE(review): presumably the `module_path!()` of the handler — confirm.
    pub module_path: &'static str,
}
// Declares `WebSocketEndpointMetadata` as a type collected by the `inventory`
// crate. NOTE(review): entries are presumably submitted elsewhere (e.g. by a
// route macro via `inventory::submit!`) — confirm.
inventory::collect!(WebSocketEndpointMetadata);
/// Expands `{name}` placeholders in a WebSocket path pattern.
///
/// The pattern is scanned once, left to right. Each `{name}` whose `name`
/// appears in `params` is replaced by the first matching value. Placeholders
/// without a matching parameter, and stray braces, are kept verbatim.
///
/// Substituted values are inserted literally and never re-scanned, so a value
/// that itself contains `{...}` cannot trigger a further substitution.
pub fn substitute_ws_params(path: &str, params: &[(&str, &str)]) -> String {
    let mut result = String::with_capacity(path.len());
    let mut rest = path;
    while let Some(open) = rest.find('{') {
        result.push_str(&rest[..open]);
        // `{` is ASCII, so `open + 1` is always a char boundary.
        let after_open = &rest[open + 1..];
        let substitution = after_open.find('}').and_then(|close| {
            let name = &after_open[..close];
            params
                .iter()
                .find(|(param, _)| *param == name)
                .map(|(_, value)| (*value, close))
        });
        match substitution {
            Some((value, close)) => {
                result.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the brace and resume scanning
                // right after it, so e.g. `{{a}}` still expands the inner `{a}`.
                result.push('{');
                rest = after_open;
            }
        }
    }
    result.push_str(rest);
    result
}
/// Errors produced by [`WebSocketRouter`] route-table operations.
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
    /// No route is registered under the given path (returned by
    /// `WebSocketRouter::remove_route`).
    #[error("Route not found: {0}")]
    NotFound(String),
    /// A route is already registered under the given path (returned by
    /// `WebSocketRouter::register_route`).
    #[error("Route already exists: {0}")]
    AlreadyExists(String),
    /// An invalid route pattern. NOTE(review): not constructed anywhere in this
    /// module — confirm where it is produced.
    #[error("Invalid route pattern: {0}")]
    InvalidPattern(String),
}

/// Outcome of router operations that produce no value on success.
pub type RouteResult = Result<(), RouteError>;
/// One WebSocket route: a path pattern, an optional name for reverse lookup,
/// and free-form string metadata.
#[derive(Debug, Clone)]
pub struct WebSocketRoute {
    /// Path pattern, possibly containing `{param}` placeholders.
    path: String,
    /// Optional route name.
    name: Option<String>,
    /// Arbitrary key/value annotations added with `with_metadata`.
    metadata: HashMap<String, String>,
}

impl WebSocketRoute {
    /// Creates a route for `path` with an optional `name` and no metadata.
    pub fn new(path: String, name: Option<String>) -> Self {
        let metadata = HashMap::default();
        WebSocketRoute {
            path,
            name,
            metadata,
        }
    }

    /// Path pattern of this route.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// Name of this route, or `None` if it is anonymous.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Builder-style setter.
    ///
    /// Stores `value` under `key`, overwriting any previous value for that key,
    /// and returns the route.
    pub fn with_metadata(self, key: String, value: String) -> Self {
        let mut route = self;
        route.metadata.insert(key, value);
        route
    }

    /// Looks up the metadata value stored under `key`.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        let metadata = &self.metadata;
        metadata.get(key)
    }
}
/// Registry of WebSocket routes.
///
/// `routes` and `names` sit behind `Arc`, so every clone of a router shares the
/// same tables. `namespace` and `pending_consumers` are per-value and are
/// copied on clone.
#[derive(Clone)]
pub struct WebSocketRouter {
    /// Optional namespace; only stored, never applied to paths.
    namespace: Option<String>,
    /// Routes recorded by the `consumer` builder (local to this value).
    pending_consumers: Vec<WebSocketRoute>,
    /// Registered routes, keyed by path.
    routes: Arc<RwLock<HashMap<String, WebSocketRoute>>>,
    /// Name -> path index for the named entries in `routes`.
    names: Arc<RwLock<HashMap<String, String>>>,
}
impl WebSocketRouter {
    /// Creates an empty router: no registered routes, no pending consumers and
    /// no namespace.
    pub fn new() -> Self {
        Self {
            routes: Arc::new(RwLock::new(HashMap::new())),
            names: Arc::new(RwLock::new(HashMap::new())),
            pending_consumers: Vec::new(),
            namespace: None,
        }
    }

    /// Sets the router's namespace.
    ///
    /// The value is only stored and returned by [`Self::namespace`]. It is not
    /// applied to any route path or name.
    #[must_use]
    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
        self.namespace = Some(namespace.into());
        self
    }

    /// Returns the namespace set with [`Self::with_namespace`], if any.
    pub fn namespace(&self) -> Option<&str> {
        self.namespace.as_deref()
    }

    /// Records the consumer type `C` as a pending route built from its
    /// [`WebSocketEndpointInfo`] path and name.
    ///
    /// `_f` is never called; it only lets callers name `C` with a closure such
    /// as `|| MyConsumer`. Pending routes are stored in this router value only,
    /// not in the shared route table. They are what [`Self::find_pending`] and
    /// [`Self::reverse`] search.
    #[must_use]
    pub fn consumer<C, F>(mut self, _f: F) -> Self
    where
        F: Fn() -> C,
        C: WebSocketEndpointInfo + 'static,
    {
        self.pending_consumers.push(WebSocketRoute::new(
            C::path().to_string(),
            C::name().map(str::to_string),
        ));
        self
    }

    /// Returns the first pending consumer route whose name is `name`.
    pub fn find_pending(&self, name: &str) -> Option<&WebSocketRoute> {
        self.pending_consumers
            .iter()
            .find(|route| route.name() == Some(name))
    }

    /// Builds a URL for the pending consumer named `name`, filling its
    /// `{param}` placeholders from `params` via [`substitute_ws_params`].
    ///
    /// Only pending consumers are searched; routes added with
    /// [`Self::register_route`] are not. Returns `None` for an unknown name.
    pub fn reverse(&self, name: &str, params: &[(&str, &str)]) -> Option<String> {
        self.find_pending(name)
            .map(|route| substitute_ws_params(route.path(), params))
    }

    /// Adds `route` to the shared route table and indexes it by name if it has
    /// one.
    ///
    /// If another route already uses the same name, the name is re-pointed at
    /// this route, so the most recent registration wins.
    ///
    /// # Errors
    ///
    /// Returns [`RouteError::AlreadyExists`] if a route with the same path is
    /// already registered.
    pub async fn register_route(&mut self, route: WebSocketRoute) -> RouteResult {
        // Lock order is `routes` before `names`. Every method in this impl
        // either follows that order or never holds both guards at once; mixing
        // orders across tasks can deadlock.
        let mut routes = self.routes.write().await;
        if routes.contains_key(route.path()) {
            return Err(RouteError::AlreadyExists(route.path().to_string()));
        }
        if let Some(name) = route.name() {
            let mut names = self.names.write().await;
            names.insert(name.to_string(), route.path().to_string());
        }
        routes.insert(route.path().to_string(), route);
        Ok(())
    }

    /// Returns a clone of the route registered under `path`, if any.
    pub async fn find_route(&self, path: &str) -> Option<WebSocketRoute> {
        let routes = self.routes.read().await;
        routes.get(path).cloned()
    }

    /// Returns a clone of the route registered under `name`.
    ///
    /// Returns `None` if the name is unknown or its route has been removed.
    pub async fn find_route_by_name(&self, name: &str) -> Option<WebSocketRoute> {
        // Copy the path out so the `names` guard is released (end of this
        // statement) before `routes` is locked. Waiting on `routes` while
        // holding `names` inverts the lock order used by `register_route` and
        // `remove_route` and can deadlock against them.
        let path = self.names.read().await.get(name).cloned()?;
        let routes = self.routes.read().await;
        routes.get(&path).cloned()
    }

    /// Removes the route registered under `path` together with its name
    /// mapping.
    ///
    /// The name mapping is removed only if it still points at `path`. A later
    /// registration may have re-used the name for a different route, and that
    /// mapping must survive.
    ///
    /// # Errors
    ///
    /// Returns [`RouteError::NotFound`] if no route is registered under `path`.
    pub async fn remove_route(&mut self, path: &str) -> RouteResult {
        let mut routes = self.routes.write().await;
        let route = routes
            .remove(path)
            .ok_or_else(|| RouteError::NotFound(path.to_string()))?;
        if let Some(name) = route.name() {
            let mut names = self.names.write().await;
            if names.get(name).map(String::as_str) == Some(path) {
                names.remove(name);
            }
        }
        Ok(())
    }

    /// Returns clones of all registered routes, in unspecified (hash map)
    /// order.
    pub async fn all_routes(&self) -> Vec<WebSocketRoute> {
        let routes = self.routes.read().await;
        routes.values().cloned().collect()
    }

    /// Returns `true` if a route is registered under `path`.
    pub async fn has_route(&self, path: &str) -> bool {
        self.routes.read().await.contains_key(path)
    }

    /// Returns the number of registered routes.
    pub async fn route_count(&self) -> usize {
        self.routes.read().await.len()
    }

    /// Removes every registered route and name mapping.
    ///
    /// Both guards are held together, in `routes` -> `names` order, so a
    /// concurrent `register_route` cannot slip in between the two clears and
    /// leave a route without its name entry.
    pub async fn clear(&mut self) {
        let mut routes = self.routes.write().await;
        let mut names = self.names.write().await;
        routes.clear();
        names.clear();
    }
}
impl Default for WebSocketRouter {
fn default() -> Self {
Self::new()
}
}
/// Process-wide slot holding the installed router, or `None` when none is
/// installed.
///
/// The lock is stored inline. A `static` is already shared for the lifetime of
/// the program, so an `Arc` around it would only add an allocation and an
/// indirection.
static GLOBAL_ROUTER: once_cell::sync::Lazy<RwLock<Option<WebSocketRouter>>> =
    once_cell::sync::Lazy::new(|| RwLock::new(None));
/// Installs `router` as the process-wide router, replacing any previously
/// installed one.
pub async fn register_websocket_router(router: WebSocketRouter) {
    let mut slot = GLOBAL_ROUTER.write().await;
    *slot = Some(router);
}
/// Returns a clone of the process-wide router, or `None` if none is installed.
///
/// The clone shares the `Arc`-backed route tables with the installed router.
/// Routes registered through it with `register_route` are therefore visible to
/// every other clone.
pub async fn get_websocket_router() -> Option<WebSocketRouter> {
    let slot = GLOBAL_ROUTER.read().await;
    slot.as_ref().cloned()
}
/// Uninstalls the process-wide router, if any.
///
/// Afterwards, `get_websocket_router` returns `None` until a new router is
/// registered.
pub async fn clear_websocket_router() {
    let mut slot = GLOBAL_ROUTER.write().await;
    drop(slot.take());
}
/// Resolves a route name to its path pattern.
///
/// Registered routes (added via `WebSocketRouter::register_route`) are checked
/// first. If `name` is not registered, the router's pending consumers are
/// searched instead. The returned path is the raw pattern: `{param}`
/// placeholders are not substituted.
///
/// Returns `None` if the name is unknown, or if it is registered but its route
/// was removed concurrently.
pub async fn reverse_websocket_url(router: &WebSocketRouter, name: &str) -> Option<String> {
    // Copy the path out so the `names` guard is dropped at the end of this
    // statement, before `routes` is locked. Waiting on `routes` while holding
    // `names` inverts the `routes` -> `names` order that the router's
    // register/remove methods use, and can deadlock against them.
    let registered_path = router.names.read().await.get(name).cloned();
    match registered_path {
        Some(path) => {
            let routes = router.routes.read().await;
            routes.get(&path).map(|route| route.path().to_string())
        }
        None => router.find_pending(name).map(|route| route.path().to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Minimal endpoint used to drive the builder-style router API.
    struct TestConsumer;

    impl WebSocketEndpointInfo for TestConsumer {
        fn path() -> &'static str {
            "/ws/chat/{room_id}/"
        }

        fn name() -> Option<&'static str> {
            Some("chat_ws")
        }
    }

    #[rstest]
    fn test_substitute_no_params() {
        let substituted = substitute_ws_params("/ws/notif/", &[]);
        assert_eq!(substituted, "/ws/notif/");
    }

    #[rstest]
    fn test_substitute_one_param() {
        let params = [("room_id", "42")];
        let substituted = substitute_ws_params("/ws/chat/{room_id}/", &params);
        assert_eq!(substituted, "/ws/chat/42/");
    }

    #[rstest]
    fn test_consumer_builder() {
        let router = WebSocketRouter::new().consumer(|| TestConsumer);
        let route = router
            .find_pending("chat_ws")
            .expect("chat_ws should be recorded as a pending consumer");
        assert_eq!(route.path(), "/ws/chat/{room_id}/");
    }

    #[rstest]
    fn test_with_namespace_stores_value_without_rewriting_paths() {
        let router = WebSocketRouter::new()
            .with_namespace("auth")
            .consumer(|| TestConsumer);
        assert_eq!(router.namespace(), Some("auth"));
        let pending = router.find_pending("chat_ws").unwrap();
        assert_eq!(pending.path(), "/ws/chat/{room_id}/");
    }

    #[rstest]
    fn test_reverse() {
        let router = WebSocketRouter::new().consumer(|| TestConsumer);
        let expected = Some(String::from("/ws/chat/99/"));
        assert_eq!(router.reverse("chat_ws", &[("room_id", "99")]), expected);
        assert!(router.reverse("unknown", &[]).is_none());
    }
}