axum_realtime_kit/ws/
types.rs

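// axum-realtime-kit/src/ws/types.rs

//! Internal types used by the `WebsocketService`: connection and topic identifiers, the shared sink handle, Redis listener commands, and the per-node connection/subscription state.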
4
5use axum::extract::ws::{Message, WebSocket};
6use dashmap::DashMap;
7use futures_util::stream::SplitSink;
8use std::collections::HashSet;
9use std::fmt;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12use uuid::Uuid;
13
14/// A unique identifier for a single WebSocket connection.
15pub type ConnectionId = Uuid;
16
/// A topic identifier.
///
/// The same string serves as the Redis Pub/Sub channel name and as the key of
/// the local subscription map (`WsState::subscriptions`).
pub type Topic = String;
19
20/// A type alias for the WebSocket's writing half (the "Sink"),
21/// protected by a Mutex for safe concurrent access from multiple tasks.
22pub type WsSink = Arc<Mutex<SplitSink<WebSocket, Message>>>;
23
24/// Defines the commands that can be sent to the background Redis listener task.
25/// This allows the main service to dynamically change the listener's subscriptions.
26#[derive(Debug, Clone)]
27pub(crate) enum RedisCommand {
28    /// Command to subscribe the Redis listener to a new topic.
29    Subscribe(Topic),
30    /// Command to unsubscribe the Redis listener from a topic.
31    Unsubscribe(Topic),
32}
33
34/// Holds the shared state for the WebSocket service on a single server instance.
35///
36/// `DashMap` is used for high-performance, concurrent access without `async` locks.
37#[derive(Default)]
38pub(crate) struct WsState {
39    /// Maps a `ConnectionId` to its corresponding `WsSink`.
40    /// This allows for sending messages to a specific client.
41    pub(crate) connections: DashMap<ConnectionId, WsSink>,
42
43    /// Maps a `Topic` to a set of `ConnectionId`s that are subscribed to it.
44    /// This is the core of the pub/sub logic on a single node.
45    pub(crate) subscriptions: DashMap<Topic, HashSet<ConnectionId>>,
46}
47
48impl fmt::Debug for WsState {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("WsState")
51            .field("connections_count", &self.connections.len())
52            .field("subscriptions_count", &self.subscriptions.len())
53            .finish()
54    }
55}