1use tokio::sync::broadcast::channel as unbounded;
4use tokio::sync::broadcast::{Receiver, Sender};
5
6use futures::stream::FuturesUnordered;
7use futures_util::sink::SinkExt;
8use futures_util::stream::SplitSink;
9use log::{debug, info, trace};
10use std::boxed::Box;
11use std::collections::HashMap;
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::RwLock;
16use tokio_stream::StreamExt;
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::WebSocketStream;
19
20use crate::rooms::{
21 get_sockets_for_room, join_channel_to_room, remove_socket_from_room, ChannelPair,
22};
23use crate::socketio_message::SocketIOMessage;
24
25pub type SocketIOHandler =
26 fn(SocketIOSocket, String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>;
27
28pub const SOCKETIO_PING: &str = "2";
29pub const SOCKETIO_PONG: &str = "3";
30pub const SOCKETIO_EVENT_OPEN: &str = "40"; pub const SOCKETIO_EVENT_MESSAGE: &str = "42"; lazy_static! {
34 static ref ADAPTER: RwLock<Option<Box<dyn SocketIOAdapter>>> = RwLock::new(None);
35}
36
37pub async fn broadcast(room_id: &str, event: &str, message: &str) {
41 if let Some(adapter) = &*ADAPTER.read().unwrap() {
43 adapter.incoming(
44 room_id,
45 &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
46 );
47 }
48
49 match get_sockets_for_room(room_id) {
50 Some(channels) => {
51 for channel in &*channels {
52 channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
53 event.to_string(),
54 message.to_string(),
55 )));
56 debug!(
57 "Found socketid {} in room {}, sending message = {}",
58 channel.sid(),
59 room_id,
60 message
61 );
62 }
63 }
64 None => {
65 trace!(
66 "Found no socketid in room {}, not sending message = {}",
67 room_id,
68 message
69 );
70 }
71 }
72}
73
74pub async fn broadcast_binary(room_id: &str, event: &str, message: Vec<u8>) {
78 if let Some(adapter) = &*ADAPTER.read().unwrap() {
80 adapter.incoming(
81 room_id,
82 &SocketIOMessage::SendBinaryMessage(event.to_string(), message.clone()),
83 );
84 }
85
86 match get_sockets_for_room(room_id) {
87 Some(channels) => {
88 for channel in &*channels {
89 channel.send(InternalMessage::IO(SocketIOMessage::SendBinaryMessage(
90 event.to_string(),
91 message.clone(),
92 )));
93 debug!(
94 "Found socketid {} in room {}, sending message = {:?}",
95 channel.sid(),
96 room_id,
97 message
98 );
99 }
100 }
101 None => {
102 trace!(
103 "Found no socketid in room {}, not sending message = {:?}",
104 room_id,
105 message
106 );
107 }
108 }
109}
110
111pub fn adapter(new_adapter: impl SocketIOAdapter + 'static) {
112 let mut adapter = ADAPTER.write().unwrap();
113 adapter.replace(Box::new(new_adapter));
114}
115
/// Splits an event packet such as `42["chat","hello"]` into `(event, payload)`.
///
/// The first two bytes (the packet prefix) are skipped. Anything between them and the
/// opening `[` (e.g. a message number) is ignored. The event name is taken to be the
/// quoted string between `[` and the first `,`. The payload is everything after that
/// comma minus the closing `]`. If the payload is a quoted string, the surrounding quotes
/// are removed; JSON objects and arrays are returned verbatim.
///
/// NOTE(review): escape sequences inside a quoted payload (e.g. `\"`) are not decoded,
/// and an event name containing a comma is split at the wrong place.
///
/// # Panics
/// Panics when the payload is shorter than two bytes, has no `[`, has no `,`, or the
/// event name cannot be located between them.
pub fn parse_raw_message(payload: &str) -> (String, String) {
    let message = payload
        .get(2..)
        .unwrap_or_else(|| panic!("Received a message that is too short: '{}'", payload));
    let leading_bracket = message
        .find('[')
        .unwrap_or_else(|| panic!("Found a message with no leading bracket: '{}'", message));
    let event_split = message.find(',').unwrap_or_else(|| {
        panic!(
            "Received a message without a comma separator: '{}'",
            message
        )
    });

    // Skip the `[` and the opening quote; stop before the closing quote. `get` turns a
    // comma placed before the bracket into the explicit panic below instead of a slice
    // panic (or a `usize` underflow when the comma is the first byte).
    let event = message
        .get(leading_bracket + 2..event_split.saturating_sub(1))
        .unwrap_or_else(|| panic!("Received a message with a malformed event name: '{}'", message));

    // `,` is a single byte, so `event_split + 1` is always a valid char boundary.
    let rest = &message[event_split + 1..];
    let mut content = rest.strip_suffix(']').unwrap_or(rest);

    // `strip_prefix`/`strip_suffix` instead of byte slicing: `&content[0..1]` panicked on
    // an empty payload and on a multi-byte first character.
    if let Some(unquoted) = content
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
    {
        content = unquoted;
    }

    (event.to_string(), content.to_string())
}
137
138pub trait SocketIOAdapter: Send + Sync {
139 fn incoming(&self, room_id: &str, message: &SocketIOMessage);
140 fn outgoing(&self, room_id: &str, message: &SocketIOMessage);
141}
142
143#[derive(Clone, Debug)]
144pub enum InternalMessage {
145 IO(SocketIOMessage),
146 WS(WSSocketMessage),
147}
148
/// Transport-level command understood by `SocketIOWrapper::listen`.
#[derive(Clone, Debug)]
pub enum WSSocketMessage {
    /// Raw text payload from the client; forwarded to `SocketIOWrapper::handle`.
    RawMessage(String),
    /// Removes the socket from its rooms, closes the WebSocket and ends `listen`.
    Close,
    /// `listen` answers by sending the `SOCKETIO_PONG` text payload.
    Ping,
    /// `listen` answers by sending the `SOCKETIO_PING` text payload.
    Pong,
    /// `listen` answers by sending an empty WebSocket Pong frame.
    WsPing,
    /// `listen` answers by sending an empty WebSocket Ping frame.
    WsPong,
}
158
159pub struct SocketIOSocket {
160 id: String,
161 sender: Sender<InternalMessage>,
162 rooms: Vec<String>,
163}
164
165impl Clone for SocketIOSocket {
166 fn clone(&self) -> Self {
167 SocketIOSocket {
168 id: self.id.clone(),
169 sender: self.sender.clone(),
170 rooms: self.rooms.clone(),
171 }
172 }
173}
174
175impl SocketIOSocket {
176 pub fn new(id: String, sender: Sender<InternalMessage>) -> Self {
177 SocketIOSocket {
178 id,
179 sender,
180 rooms: Vec::new(),
181 }
182 }
183 pub fn id(&self) -> &str {
187 &self.id
188 }
189
190 pub fn use_handler(&self, _handler: SocketIOHandler) {
194 unimplemented!("use_handler isn't implemented yet.")
195 }
196
197 pub fn on(&mut self, event: &str, handler: SocketIOHandler) {
201 let _ = self
202 .sender
203 .send(InternalMessage::IO(SocketIOMessage::AddListener(
204 event.to_string(),
205 handler,
206 )));
207 }
208
209 pub async fn join(&mut self, room_id: &str) {
214 let _ = self.sender.send(InternalMessage::IO(SocketIOMessage::Join(
215 room_id.to_string(),
216 )));
217 }
218
219 pub async fn leave(&mut self, room_id: &str) {
225 let _ = self.sender.send(InternalMessage::IO(SocketIOMessage::Leave(
226 room_id.to_string(),
227 )));
228 }
229
230 pub async fn send(&self, event: &str, message: &str) {
234 let _ = self
235 .sender
236 .send(InternalMessage::IO(SocketIOMessage::SendMessage(
237 event.to_string(),
238 message.to_string(),
239 )));
240 }
241
242 pub async fn emit_to(&self, room_id: &str, event: &str, message: &str) {
247 if let Some(adapter) = &*ADAPTER.read().unwrap() {
249 adapter.incoming(
250 room_id,
251 &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
252 );
253 }
254
255 if let Some(channels) = get_sockets_for_room(room_id) {
256 for channel in &*channels {
257 channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
258 event.to_string(),
259 message.to_string(),
260 )));
261 }
262 }
263 }
264
265 pub async fn broadcast_to(&self, room_id: &str, event: &str, message: &str) {
270 if let Some(adapter) = &*ADAPTER.read().unwrap() {
272 adapter.incoming(
273 room_id,
274 &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
275 );
276 }
277
278 if let Some(channels) = get_sockets_for_room(room_id) {
279 for channel in &*channels {
280 if channel.sid() != self.id {
281 channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
282 event.to_string(),
283 message.to_string(),
284 )));
285 }
286 }
287 }
288 }
289
290 pub fn rooms(&self) -> &Vec<String> {
294 &self.rooms
295 }
296}
297
298impl fmt::Display for InternalMessage {
299 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300 match self {
301 InternalMessage::IO(v) => write!(f, "Message::IO({})", v),
302 InternalMessage::WS(v) => write!(f, "Message::WS({})", v),
303 }
304 }
305}
306
307impl fmt::Display for WSSocketMessage {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 match self {
310 WSSocketMessage::RawMessage(val) => write!(f, "WSSocketMessage::RawMessage({})", val),
311 WSSocketMessage::Ping => write!(f, "WSSocketMessage::Ping"),
312 WSSocketMessage::Pong => write!(f, "WSSocketMessage::Pong"),
313 WSSocketMessage::WsPing => write!(f, "WSSocketMessage::WsPing"),
314 WSSocketMessage::WsPong => write!(f, "WSSocketMessage::WsPong"),
315 WSSocketMessage::Close => write!(f, "WSSocketMessage::Close"),
316 }
317 }
318}
319
320pub struct SocketIOWrapper {
321 sid: String,
322 message_number: usize,
323 socket: SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>,
324 rooms: Vec<String>,
325 event_handlers: HashMap<String, Vec<SocketIOHandler>>,
326 sender: Sender<InternalMessage>,
327 receiver: Receiver<InternalMessage>,
328}
329
330impl SocketIOWrapper {
331 pub fn new(
332 sid: String,
333 socket: SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>,
334 message_capacity: usize,
335 ) -> Self {
336 let (sender, receiver) = unbounded(message_capacity);
337 SocketIOWrapper {
338 sid,
339 message_number: 0,
340 socket,
341 rooms: Vec::new(),
342 event_handlers: HashMap::new(),
343 sender,
344 receiver,
345 }
346 }
347
348 pub async fn close(mut self) {
349 for room in &self.rooms {
351 remove_socket_from_room(room, &self.sid);
352 debug!(
353 "SocketIOMessage socketid {} closed, leave room {}",
354 self.sid, room
355 );
356 }
357
358 let _res = self.socket.close().await;
359 }
360
361 pub async fn handle(&mut self, payload: String) {
366 if payload == SOCKETIO_PING {
367 let _ = self.sender.send(InternalMessage::WS(WSSocketMessage::Pong));
368 return;
369 }
370
371 if payload == SOCKETIO_PONG {
372 return;
374 }
375
376 match &payload[0..2] {
377 "42" => {
378 if !payload.is_empty() {
379 let (event, message) = parse_raw_message(&payload);
380
381 match self.event_handlers.get(&event) {
383 Some(handlers) => {
384 let unordered_future = FuturesUnordered::new();
386
387 for handler in handlers {
388 unordered_future.push((handler)(
389 SocketIOSocket {
390 id: self.sid.clone(),
391 sender: self.sender.clone(),
392 rooms: self.rooms.clone(),
393 },
394 message.clone(),
395 ));
396 }
397
398 tokio::spawn(async move {
401 let _ = unordered_future.collect::<Result<(), ()>>().await;
402 });
403 }
404 None => {
405 info!("No handler found for message: {:#?}", event);
406 } }
408 }
409 }
410 "41" => {
411 debug!("{}: Socket closed...", self.sid);
412 }
413 "40" => {
414 debug!("{}: Socket opened...", self.sid);
415 }
416 _ => panic!("Attempted to handle a non-message payload: '{}'", payload),
417 }
418 }
419
420 pub async fn listen(mut self) {
421 while let Ok(val) = self.receiver.recv().await {
422 match val {
423 InternalMessage::IO(val) => {
424 match val {
425 SocketIOMessage::SendMessage(event, message) => {
426 self.message_number += 1;
427
428 let message = match &message[0..1] {
429 "{" | "[" => message,
430 _ => format!("\"{}\"", message),
431 };
432
433 let content = format!(
435 "{}{}[\"{}\",{}]",
436 SOCKETIO_EVENT_MESSAGE, self.message_number, event, message
437 );
438
439 let _ = self.socket.send(Message::Text(content)).await;
440 }
441
442 SocketIOMessage::SendBinaryMessage(_event, message) => {
443 let _ = self.socket.send(Message::Binary(message)).await;
444 }
445
446 SocketIOMessage::Join(room_id) => {
447 if !self.rooms.contains(&room_id) {
449 self.rooms.push(room_id.to_string());
450 debug!("SocketIOMessage socketid {} joined room {}. Rooms = {:?}, rooms len = {}", self.sid, room_id, self.rooms, self.rooms.len());
451
452 join_channel_to_room(
454 &room_id,
455 ChannelPair::new(&self.sid, self.sender()),
456 );
457 } else {
458 debug!("SocketIOMessage socketid {} is already in room {}. Not joining.", self.sid, room_id);
459 }
460 }
461
462 SocketIOMessage::Leave(room_id) => {
463 for (i, room) in self.rooms.iter().enumerate() {
464 if room == &room_id {
465 self.rooms.remove(i);
466 debug!("SocketIOMessage socketid {} leaved room {}. Rooms = {:?}, rooms len = {}", self.sid, room_id, self.rooms, self.rooms.len());
467
468 remove_socket_from_room(&room_id, &self.sid);
470 break;
471 }
472 }
473 }
474
475 SocketIOMessage::AddListener(event, handler) => {
476 let mut existing_handlers =
477 self.event_handlers.remove(&event).unwrap_or_default();
478
479 existing_handlers.push(handler);
480
481 self.event_handlers
482 .insert(event.to_string(), existing_handlers);
483 }
484 _ => (),
485 }
486 }
487 InternalMessage::WS(val) => match val {
488 WSSocketMessage::RawMessage(message) => self.handle(message).await,
489 WSSocketMessage::Ping => {
490 let _ = self
491 .socket
492 .send(Message::Text(SOCKETIO_PONG.to_string()))
493 .await;
494 }
495 WSSocketMessage::Pong => {
496 let _ = self
497 .socket
498 .send(Message::Text(SOCKETIO_PING.to_string()))
499 .await;
500 }
501 WSSocketMessage::WsPing => {
502 let _ = self.socket.send(Message::Pong([].to_vec())).await;
503 }
504 WSSocketMessage::WsPong => {
505 let _ = self.socket.send(Message::Ping([].to_vec())).await;
506 }
507
508 WSSocketMessage::Close => {
509 self.close().await;
510 return;
511 }
512 },
513 }
514 }
515 }
516
517 pub fn sender(&self) -> Sender<InternalMessage> {
518 self.sender.clone()
519 }
520}