1use std::{collections::HashMap, sync::Arc};
2
3use axum::extract::ws::{Message, WebSocket};
4use futures::{
5 SinkExt, StreamExt,
6 stream::{SplitSink, SplitStream},
7};
8use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, mpsc::Sender};
9use tracing::{error, info, instrument, warn};
10
11use crate::{
12 connection_id::ConnectionId,
13 error::{WebsocKitError, WebsocKitResult},
14 message::WebsocKitMessage,
15 subscription::Subscription,
16};
17
18#[expect(clippy::module_name_repetitions)]
19pub struct WebsocKitManager {
20 connections: RwLock<HashMap<ConnectionId, RwLock<SplitSink<WebSocket, Message>>>>,
21 subscriptions: RwLock<HashMap<ConnectionId, HashMap<Subscription, usize>>>,
22 payload_listener_tx: Sender<WebsocKitMessage>,
23}
24
25impl WebsocKitManager {
26 #[must_use]
27 pub fn new(payload_listener_tx: Sender<WebsocKitMessage>) -> Self {
28 Self {
29 connections: RwLock::new(HashMap::new()),
30 subscriptions: RwLock::new(HashMap::new()),
31 payload_listener_tx,
32 }
33 }
34
35 #[instrument(skip_all)]
38 pub async fn handle_new_websocket(
39 self: &Arc<Self>,
40 socket: WebSocket,
41 ) -> WebsocKitResult<ConnectionId> {
42 let (websocket_sender, websocket_listener): (
44 SplitSink<WebSocket, Message>,
45 SplitStream<WebSocket>,
46 ) = socket.split();
47
48 let connection_id: ConnectionId = ConnectionId::new();
50 self.connections
51 .write()
52 .await
53 .insert(connection_id, RwLock::new(websocket_sender));
54 info!("websocket connection established: '{connection_id}'");
55
56 self.clone()
58 .listen_to_websocket(websocket_listener, connection_id)
59 .await?;
60
61 if self
63 .connections
64 .write()
65 .await
66 .remove(&connection_id)
67 .is_none()
68 {
69 error!(
70 "attempted to discard dead Connection, but none existed with the given ID: '{connection_id}'"
71 );
72 }
74 info!("websocket connection closed: '{connection_id}'");
75
76 Ok(connection_id)
78 }
79
80 #[instrument(skip_all)]
82 async fn listen_to_websocket(
83 self: Arc<Self>,
84 mut socket_receiver: SplitStream<WebSocket>,
85 connection_id: ConnectionId,
86 ) -> WebsocKitResult<()> {
87 let cloned_self: Arc<Self> = Arc::clone(&self);
89 while let Some(Ok(message)) = socket_receiver.next().await {
90 match message {
91 Message::Binary(binary) => {
93 if let Err(_send_error) = cloned_self
95 .payload_listener_tx
96 .send(WebsocKitMessage {
97 connection_id,
98 payload: binary,
99 })
100 .await
101 {
102 break;
105 }
106 }
107 Message::Close(close) => {
108 close.map_or_else(
109 || {
110 info!("Websocket '{connection_id}' received close frame.");
111 },
112 |close_frame| {
113 info!("Websocket '{connection_id}' received close frame: '{close_frame:?}'.");
114 },
115 );
116 break;
117 }
118
119 Message::Text(invalid_text_message) => {
121 return Err(WebsocKitError::TextMessagesNotAllowed(
123 connection_id,
124 invalid_text_message,
125 ));
126 }
127 Message::Ping(_ping) => {
128 info!("Websocket '{connection_id}' received ping.");
130 }
131 Message::Pong(_pong) => {
132 info!("Websocket '{connection_id}' received pong.");
134 }
135 }
136 }
137 Ok(())
138 }
139
140 #[instrument(skip_all)]
142 pub async fn send_payload(
143 &self,
144 connection_ids: Vec<ConnectionId>,
145 payload: Vec<u8>,
146 ) -> WebsocKitResult<()> {
147 if connection_ids.is_empty() {
149 warn!("attempted to send payload to zero websockets: {payload:?}");
150 return Ok(());
151 }
152
153 for connection_id in connection_ids {
156 let connections: RwLockReadGuard<
158 HashMap<ConnectionId, RwLock<SplitSink<WebSocket, Message>>>,
159 > = self.connections.read().await;
160 let Some(outgoing_payload_sender) = connections.get(&connection_id) else {
161 return Err(WebsocKitError::ConnectionDoesNotExist(connection_id));
162 };
163
164 match outgoing_payload_sender
166 .write()
167 .await
168 .send(Message::Binary(payload.clone()))
169 .await
170 {
171 Ok(()) => {
172 info!("sent payload to websocket '{connection_id}': {payload:?}");
173 }
174 Err(error) => {
175 error!(
176 "failed to send payload to websocket '{connection_id}': {payload:?} -> Error: {error}"
177 );
178 break;
179 }
180 };
181 }
182
183 Ok(())
184 }
185
186 #[instrument(skip_all)]
187 pub async fn send_payload_to_all_connections(&self, payload: Vec<u8>) -> WebsocKitResult<()> {
188 let all_connection_ids: Vec<ConnectionId> =
189 self.connections.read().await.keys().copied().collect();
190 self.send_payload(all_connection_ids, payload).await
191 }
192
193 #[instrument(skip_all)]
194 pub async fn send_payload_to_subscribers(
195 &self,
196 subscription: Subscription,
197 payload: Vec<u8>,
198 ) -> WebsocKitResult<()> {
199 let mut connection_ids: Vec<ConnectionId> = Vec::new();
201 let subscriptions: RwLockReadGuard<HashMap<ConnectionId, HashMap<Subscription, usize>>> =
202 self.subscriptions.read().await;
203 for (connection_id, subscriptions) in subscriptions.iter() {
204 if subscriptions.contains_key(&subscription) {
205 connection_ids.push(*connection_id);
206 }
207 }
208 info!(
209 "found websockets subscribed to '{subscription}': {connection_ids:?} - sending payload: {payload:?}"
210 );
211
212 self.send_payload(connection_ids, payload).await
214 }
215
216 #[instrument(skip_all)]
218 pub async fn add_subscription(&self, connection_id: ConnectionId, subscription: Subscription) {
219 let mut subscriptions_lock: RwLockWriteGuard<
221 HashMap<ConnectionId, HashMap<Subscription, usize>>,
222 > = self.subscriptions.write().await;
223 let subscriptions: &mut HashMap<Subscription, usize> =
224 subscriptions_lock.entry(connection_id).or_default();
225 let subscription_count: &mut usize = subscriptions.entry(subscription.clone()).or_insert(0);
226
227 *subscription_count += 1;
229 info!("subscribed websocket '{connection_id}' to '{subscription}'");
230 }
231
232 #[instrument(skip_all)]
234 pub async fn remove_subscription(
235 &self,
236 connection_id: ConnectionId,
237 subscription: Subscription,
238 ) {
239 let mut subscriptions_lock: RwLockWriteGuard<
241 HashMap<ConnectionId, HashMap<Subscription, usize>>,
242 > = self.subscriptions.write().await;
243 let Some(subscriptions) = subscriptions_lock.get_mut(&connection_id) else {
244 error!(
245 "attempted to unsubscribe from '{subscription}', but websocket '{connection_id}' had zero subscriptions at all"
246 );
247 return;
248 };
249
250 if let Some(subscription_count) = subscriptions.get_mut(&subscription) {
252 *subscription_count -= 1;
253 info!("unsubscribed '{connection_id}' from '{subscription}'");
254
255 if *subscription_count == 0 {
257 subscriptions.remove(&subscription);
258 info!("deleted subscription '{subscription}' from '{connection_id}'");
259
260 if subscriptions.is_empty() {
262 subscriptions_lock.remove(&connection_id);
263 info!("deleted all subscriptions for '{connection_id}'");
264 }
265 }
266 } else {
267 error!(
268 "attempted to unsubscribe from '{subscription}', but websocket '{connection_id}' was not subscribed to it"
269 );
270 }
271 }
272
273 #[instrument(skip_all)]
274 pub async fn remove_all_subscriptions(&self, connection_id: ConnectionId) {
275 let mut subscriptions_lock: RwLockWriteGuard<
277 HashMap<ConnectionId, HashMap<Subscription, usize>>,
278 > = self.subscriptions.write().await;
279
280 match subscriptions_lock.remove(&connection_id) {
282 Some(subscriptions) => {
283 info!("unsubscribed '{connection_id}' from all subscriptions: {subscriptions:?}");
284 }
285 _ => {
286 error!(
287 "attempted to unsubscribe from all subscriptions, but websocket '{connection_id}' had zero subscriptions at all"
288 );
289 }
290 }
291 }
292
293 #[instrument(skip_all)]
294 pub async fn get_subscriptions(
295 &self,
296 connection_id: ConnectionId,
297 ) -> Option<HashMap<Subscription, usize>> {
298 self.subscriptions.read().await.get(&connection_id).cloned()
299 }
300
301 #[instrument(skip_all)]
302 pub async fn get_subscribers(&self, subscription: Subscription) -> Vec<ConnectionId> {
303 let mut connection_ids: Vec<ConnectionId> = Vec::new();
304 let subscriptions: RwLockReadGuard<HashMap<ConnectionId, HashMap<Subscription, usize>>> =
305 self.subscriptions.read().await;
306 for (connection_id, subscriptions) in subscriptions.iter() {
307 if subscriptions.contains_key(&subscription) {
308 connection_ids.push(*connection_id);
309 }
310 }
311 connection_ids
312 }
313}