websoc_kit/
manager.rs

1use std::{collections::HashMap, sync::Arc};
2
3use axum::extract::ws::{Message, WebSocket};
4use futures::{
5    SinkExt, StreamExt,
6    stream::{SplitSink, SplitStream},
7};
8use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, mpsc::Sender};
9use tracing::{error, info, instrument, warn};
10
11use crate::{
12    connection_id::ConnectionId,
13    error::{WebsocKitError, WebsocKitResult},
14    message::WebsocKitMessage,
15    subscription::Subscription,
16};
17
18#[expect(clippy::module_name_repetitions)]
19pub struct WebsocKitManager {
20    connections: RwLock<HashMap<ConnectionId, RwLock<SplitSink<WebSocket, Message>>>>,
21    subscriptions: RwLock<HashMap<ConnectionId, HashMap<Subscription, usize>>>,
22    payload_listener_tx: Sender<WebsocKitMessage>,
23}
24
25impl WebsocKitManager {
26    #[must_use]
27    pub fn new(payload_listener_tx: Sender<WebsocKitMessage>) -> Self {
28        Self {
29            connections: RwLock::new(HashMap::new()),
30            subscriptions: RwLock::new(HashMap::new()),
31            payload_listener_tx,
32        }
33    }
34
35    /// Splits the given Websocket connection into a sender/receiver, tracks the connection via an ID,
36    /// listens for incoming socket packets, cleans up connection data once connection is lost.
37    #[instrument(skip_all)]
38    pub async fn handle_new_websocket(
39        self: &Arc<Self>,
40        socket: WebSocket,
41    ) -> WebsocKitResult<ConnectionId> {
42        // split the websocket connection into sender/receiver
43        let (websocket_sender, websocket_listener): (
44            SplitSink<WebSocket, Message>,
45            SplitStream<WebSocket>,
46        ) = socket.split();
47
48        // store the new websocket connection
49        let connection_id: ConnectionId = ConnectionId::new();
50        self.connections
51            .write()
52            .await
53            .insert(connection_id, RwLock::new(websocket_sender));
54        info!("websocket connection established: '{connection_id}'");
55
56        // receive packets from socket
57        self.clone()
58            .listen_to_websocket(websocket_listener, connection_id)
59            .await?;
60
61        // websocket cleanup
62        if self
63            .connections
64            .write()
65            .await
66            .remove(&connection_id)
67            .is_none()
68        {
69            error!(
70                "attempted to discard dead Connection, but none existed with the given ID: '{connection_id}'"
71            );
72            // TODO - should I return an error?
73        }
74        info!("websocket connection closed: '{connection_id}'");
75
76        // return connection_id so that the caller can handle their cleanup
77        Ok(connection_id)
78    }
79
80    /// Receive packet from websocket, pass to the listener.
81    #[instrument(skip_all)]
82    async fn listen_to_websocket(
83        self: Arc<Self>,
84        mut socket_receiver: SplitStream<WebSocket>,
85        connection_id: ConnectionId,
86    ) -> WebsocKitResult<()> {
87        // indefinitely listen to the websocket
88        let cloned_self: Arc<Self> = Arc::clone(&self);
89        while let Some(Ok(message)) = socket_receiver.next().await {
90            match message {
91                // valid
92                Message::Binary(binary) => {
93                    // forward the binary payload to the payload listener
94                    if let Err(_send_error) = cloned_self
95                        .payload_listener_tx
96                        .send(WebsocKitMessage {
97                            connection_id,
98                            payload: binary,
99                        })
100                        .await
101                    {
102                        // If sending the payload fails, this means that the receiver has been closed/dropped.
103                        // This means they don't want to receive any more payloads from this connection, so we can break the loop.
104                        break;
105                    }
106                }
107                Message::Close(close) => {
108                    close.map_or_else(
109                        || {
110                            info!("Websocket '{connection_id}' received close frame.");
111                        },
112                        |close_frame| {
113                            info!("Websocket '{connection_id}' received close frame: '{close_frame:?}'.");
114                        },
115                    );
116                    break;
117                }
118
119                // invalid
120                Message::Text(invalid_text_message) => {
121                    // terminate the connection for not sending binary
122                    return Err(WebsocKitError::TextMessagesNotAllowed(
123                        connection_id,
124                        invalid_text_message,
125                    ));
126                }
127                Message::Ping(_ping) => {
128                    // NOP - handled by Axum
129                    info!("Websocket '{connection_id}' received ping.");
130                }
131                Message::Pong(_pong) => {
132                    // NOP - handled by Axum
133                    info!("Websocket '{connection_id}' received pong.");
134                }
135            }
136        }
137        Ok(())
138    }
139
140    /// Send a payload to multiple websocket.
141    #[instrument(skip_all)]
142    pub async fn send_payload(
143        &self,
144        connection_ids: Vec<ConnectionId>,
145        payload: Vec<u8>,
146    ) -> WebsocKitResult<()> {
147        // make sure that at least one websocket session ID was given
148        if connection_ids.is_empty() {
149            warn!("attempted to send payload to zero websockets: {payload:?}");
150            return Ok(());
151        }
152
153        // loop over all the given Connection IDs
154        // TODO - create JoinHandles via tokio::spawn() to parallelize?
155        for connection_id in connection_ids {
156            // retrieve the Connection by ID, if it exists
157            let connections: RwLockReadGuard<
158                HashMap<ConnectionId, RwLock<SplitSink<WebSocket, Message>>>,
159            > = self.connections.read().await;
160            let Some(outgoing_payload_sender) = connections.get(&connection_id) else {
161                return Err(WebsocKitError::ConnectionDoesNotExist(connection_id));
162            };
163
164            // send the outgoing payload
165            match outgoing_payload_sender
166                .write()
167                .await
168                .send(Message::Binary(payload.clone()))
169                .await
170            {
171                Ok(()) => {
172                    info!("sent payload to websocket '{connection_id}': {payload:?}");
173                }
174                Err(error) => {
175                    error!(
176                        "failed to send payload to websocket '{connection_id}': {payload:?} -> Error: {error}"
177                    );
178                    break;
179                }
180            };
181        }
182
183        Ok(())
184    }
185
186    #[instrument(skip_all)]
187    pub async fn send_payload_to_all_connections(&self, payload: Vec<u8>) -> WebsocKitResult<()> {
188        let all_connection_ids: Vec<ConnectionId> =
189            self.connections.read().await.keys().copied().collect();
190        self.send_payload(all_connection_ids, payload).await
191    }
192
193    #[instrument(skip_all)]
194    pub async fn send_payload_to_subscribers(
195        &self,
196        subscription: Subscription,
197        payload: Vec<u8>,
198    ) -> WebsocKitResult<()> {
199        // retrieve all the Connection IDs that are subscribed to the given subscription
200        let mut connection_ids: Vec<ConnectionId> = Vec::new();
201        let subscriptions: RwLockReadGuard<HashMap<ConnectionId, HashMap<Subscription, usize>>> =
202            self.subscriptions.read().await;
203        for (connection_id, subscriptions) in subscriptions.iter() {
204            if subscriptions.contains_key(&subscription) {
205                connection_ids.push(*connection_id);
206            }
207        }
208        info!(
209            "found websockets subscribed to '{subscription}': {connection_ids:?} - sending payload: {payload:?}"
210        );
211
212        // send the payload to all the subscribers
213        self.send_payload(connection_ids, payload).await
214    }
215
216    /// Add a subscription for a websocket connection.
217    #[instrument(skip_all)]
218    pub async fn add_subscription(&self, connection_id: ConnectionId, subscription: Subscription) {
219        // retrieve the subscriptions by Connection ID
220        let mut subscriptions_lock: RwLockWriteGuard<
221            HashMap<ConnectionId, HashMap<Subscription, usize>>,
222        > = self.subscriptions.write().await;
223        let subscriptions: &mut HashMap<Subscription, usize> =
224            subscriptions_lock.entry(connection_id).or_default();
225        let subscription_count: &mut usize = subscriptions.entry(subscription.clone()).or_insert(0);
226
227        // add the subscription
228        *subscription_count += 1;
229        info!("subscribed websocket '{connection_id}' to '{subscription}'");
230    }
231
232    /// Remove a subscription for a websocket connection.
233    #[instrument(skip_all)]
234    pub async fn remove_subscription(
235        &self,
236        connection_id: ConnectionId,
237        subscription: Subscription,
238    ) {
239        // retrieve the subscriptions by Connection ID
240        let mut subscriptions_lock: RwLockWriteGuard<
241            HashMap<ConnectionId, HashMap<Subscription, usize>>,
242        > = self.subscriptions.write().await;
243        let Some(subscriptions) = subscriptions_lock.get_mut(&connection_id) else {
244            error!(
245                "attempted to unsubscribe from '{subscription}', but websocket '{connection_id}' had zero subscriptions at all"
246            );
247            return;
248        };
249
250        // remove the subscription
251        if let Some(subscription_count) = subscriptions.get_mut(&subscription) {
252            *subscription_count -= 1;
253            info!("unsubscribed '{connection_id}' from '{subscription}'");
254
255            // remove the subscription if the count is zero
256            if *subscription_count == 0 {
257                subscriptions.remove(&subscription);
258                info!("deleted subscription '{subscription}' from '{connection_id}'");
259
260                // remove the Connection ID if it has no subscriptions
261                if subscriptions.is_empty() {
262                    subscriptions_lock.remove(&connection_id);
263                    info!("deleted all subscriptions for '{connection_id}'");
264                }
265            }
266        } else {
267            error!(
268                "attempted to unsubscribe from '{subscription}', but websocket '{connection_id}' was not subscribed to it"
269            );
270        }
271    }
272
273    #[instrument(skip_all)]
274    pub async fn remove_all_subscriptions(&self, connection_id: ConnectionId) {
275        // retrieve the subscriptions by Connection ID
276        let mut subscriptions_lock: RwLockWriteGuard<
277            HashMap<ConnectionId, HashMap<Subscription, usize>>,
278        > = self.subscriptions.write().await;
279
280        // remove all subscriptions
281        match subscriptions_lock.remove(&connection_id) {
282            Some(subscriptions) => {
283                info!("unsubscribed '{connection_id}' from all subscriptions: {subscriptions:?}");
284            }
285            _ => {
286                error!(
287                    "attempted to unsubscribe from all subscriptions, but websocket '{connection_id}' had zero subscriptions at all"
288                );
289            }
290        }
291    }
292
293    #[instrument(skip_all)]
294    pub async fn get_subscriptions(
295        &self,
296        connection_id: ConnectionId,
297    ) -> Option<HashMap<Subscription, usize>> {
298        self.subscriptions.read().await.get(&connection_id).cloned()
299    }
300
301    #[instrument(skip_all)]
302    pub async fn get_subscribers(&self, subscription: Subscription) -> Vec<ConnectionId> {
303        let mut connection_ids: Vec<ConnectionId> = Vec::new();
304        let subscriptions: RwLockReadGuard<HashMap<ConnectionId, HashMap<Subscription, usize>>> =
305            self.subscriptions.read().await;
306        for (connection_id, subscriptions) in subscriptions.iter() {
307            if subscriptions.contains_key(&subscription) {
308                connection_ids.push(*connection_id);
309            }
310        }
311        connection_ids
312    }
313}