thruster_socketio/
socketio.rs

1// use crossbeam::channel::unbounded;
2// use crossbeam::channel::{Receiver, Sender};
3use tokio::sync::broadcast::channel as unbounded;
4use tokio::sync::broadcast::{Receiver, Sender};
5
6use futures::stream::FuturesUnordered;
7use futures_util::sink::SinkExt;
8use futures_util::stream::SplitSink;
9use log::{debug, info, trace};
10use std::boxed::Box;
11use std::collections::HashMap;
12use std::fmt;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::RwLock;
16use tokio_stream::StreamExt;
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::WebSocketStream;
19
20use crate::rooms::{
21    get_sockets_for_room, join_channel_to_room, remove_socket_from_room, ChannelPair,
22};
23use crate::socketio_message::SocketIOMessage;
24
25pub type SocketIOHandler =
26    fn(SocketIOSocket, String) -> Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>;
27
28pub const SOCKETIO_PING: &str = "2";
29pub const SOCKETIO_PONG: &str = "3";
30pub const SOCKETIO_EVENT_OPEN: &str = "40"; // Message, then open
31pub const SOCKETIO_EVENT_MESSAGE: &str = "42"; // Message, then event
32
33lazy_static! {
34    static ref ADAPTER: RwLock<Option<Box<dyn SocketIOAdapter>>> = RwLock::new(None);
35}
36
37///
38/// Broadcast a message to all clients connected to a room.
39///
40pub async fn broadcast(room_id: &str, event: &str, message: &str) {
41    // Send out via adapter
42    if let Some(adapter) = &*ADAPTER.read().unwrap() {
43        adapter.incoming(
44            room_id,
45            &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
46        );
47    }
48
49    match get_sockets_for_room(room_id) {
50        Some(channels) => {
51            for channel in &*channels {
52                channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
53                    event.to_string(),
54                    message.to_string(),
55                )));
56                debug!(
57                    "Found socketid {} in room {}, sending message = {}",
58                    channel.sid(),
59                    room_id,
60                    message
61                );
62            }
63        }
64        None => {
65            trace!(
66                "Found no socketid in room {}, not sending message = {}",
67                room_id,
68                message
69            );
70        }
71    }
72}
73
74///
75/// Broadcast a binary message to all clients connected to a room.
76///
77pub async fn broadcast_binary(room_id: &str, event: &str, message: Vec<u8>) {
78    // Send out via adapter
79    if let Some(adapter) = &*ADAPTER.read().unwrap() {
80        adapter.incoming(
81            room_id,
82            &SocketIOMessage::SendBinaryMessage(event.to_string(), message.clone()),
83        );
84    }
85
86    match get_sockets_for_room(room_id) {
87        Some(channels) => {
88            for channel in &*channels {
89                channel.send(InternalMessage::IO(SocketIOMessage::SendBinaryMessage(
90                    event.to_string(),
91                    message.clone(),
92                )));
93                debug!(
94                    "Found socketid {} in room {}, sending message = {:?}",
95                    channel.sid(),
96                    room_id,
97                    message
98                );
99            }
100        }
101        None => {
102            trace!(
103                "Found no socketid in room {}, not sending message = {:?}",
104                room_id,
105                message
106            );
107        }
108    }
109}
110
111pub fn adapter(new_adapter: impl SocketIOAdapter + 'static) {
112    let mut adapter = ADAPTER.write().unwrap();
113    adapter.replace(Box::new(new_adapter));
114}
115
///
/// Split a raw event packet such as `42["event",<data>]` into `(event, data)`.
///
/// The first two characters (the packet type) are skipped; anything between
/// them and the first `[` (e.g. a numeric id) is ignored. The event name is
/// the text after `["` and before the character that precedes the first
/// comma. The data is everything after that comma except the final character
/// (the closing `]`). When the data is a quoted string (`"..."`), the
/// surrounding quotes are removed. No other JSON unescaping is done.
/// Objects, arrays, numbers and an empty payload are returned as-is.
///
/// # Panics
///
/// Panics on structurally malformed packets: fewer than two characters, no
/// `[`, no `,`, or an event name / payload that cannot be sliced out.
///
pub fn parse_raw_message(payload: &str) -> (String, String) {
    let message = payload.get(2..).unwrap_or_else(|| {
        panic!("Received a message that is too short to parse: '{}'", payload)
    });
    let leading_bracket = message
        .find('[')
        .unwrap_or_else(|| panic!("Found a message with no leading bracket: '{}'", message));
    let event_split = message.find(',').unwrap_or_else(|| {
        panic!(
            "Received a message without a comma separator: '{}'",
            message
        )
    });

    // Skip the `["` before the event name and the `"` just before the comma.
    // `checked_sub` + `get` turn underflow / inverted ranges into a clear panic.
    let event = event_split
        .checked_sub(1)
        .and_then(|end| message.get(leading_bracket + 2..end))
        .unwrap_or_else(|| {
            panic!(
                "Received a message with a malformed event name: '{}'",
                message
            )
        });

    // Everything after the comma, minus the closing `]`. `message` contains
    // `[`, so `len() - 1` cannot underflow.
    let content = message
        .get(event_split + 1..message.len() - 1)
        .unwrap_or_else(|| panic!("Received a message with a malformed payload: '{}'", message));

    // Unwrap a JSON string payload only when both quotes are present; this
    // never slices out of range, unlike indexing `content[0..1]` on an empty
    // or one-character payload.
    let content = content
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(content);

    (event.to_string(), content.to_string())
}
137
138pub trait SocketIOAdapter: Send + Sync {
139    fn incoming(&self, room_id: &str, message: &SocketIOMessage);
140    fn outgoing(&self, room_id: &str, message: &SocketIOMessage);
141}
142
143#[derive(Clone, Debug)]
144pub enum InternalMessage {
145    IO(SocketIOMessage),
146    WS(WSSocketMessage),
147}
148
/// Websocket-level actions handled by `SocketIOWrapper::listen`.
#[derive(Clone, Debug)]
pub enum WSSocketMessage {
    /// A text payload that `listen` passes to `SocketIOWrapper::handle`.
    RawMessage(String),
    /// `listen` writes the `SOCKETIO_PONG` packet to the client.
    Ping,
    /// `listen` writes the `SOCKETIO_PING` packet to the client.
    Pong,
    /// `listen` writes an empty websocket Pong frame.
    WsPing,
    /// `listen` writes an empty websocket Ping frame.
    WsPong,
    /// `listen` closes the connection and stops processing.
    Close,
}
158
159pub struct SocketIOSocket {
160    id: String,
161    sender: Sender<InternalMessage>,
162    rooms: Vec<String>,
163}
164
165impl Clone for SocketIOSocket {
166    fn clone(&self) -> Self {
167        SocketIOSocket {
168            id: self.id.clone(),
169            sender: self.sender.clone(),
170            rooms: self.rooms.clone(),
171        }
172    }
173}
174
175impl SocketIOSocket {
176    pub fn new(id: String, sender: Sender<InternalMessage>) -> Self {
177        SocketIOSocket {
178            id,
179            sender,
180            rooms: Vec::new(),
181        }
182    }
183    ///
184    /// id returns the id for this particular socket.
185    ///
186    pub fn id(&self) -> &str {
187        &self.id
188    }
189
190    ///
191    /// use_handler isn't implemented yet.
192    ///
193    pub fn use_handler(&self, _handler: SocketIOHandler) {
194        unimplemented!("use_handler isn't implemented yet.")
195    }
196
197    ///
198    /// on adds a listener for a particular event
199    ///
200    pub fn on(&mut self, event: &str, handler: SocketIOHandler) {
201        let _ = self
202            .sender
203            .send(InternalMessage::IO(SocketIOMessage::AddListener(
204                event.to_string(),
205                handler,
206            )));
207    }
208
209    ///
210    /// join joins a socket into a room. This makes every message sent
211    /// by that socket go to the room rather than globally.
212    ///
213    pub async fn join(&mut self, room_id: &str) {
214        let _ = self.sender.send(InternalMessage::IO(SocketIOMessage::Join(
215            room_id.to_string(),
216        )));
217    }
218
219    ///
220    /// leave removes a socket from a room. Note that you cannot remove
221    /// a socket from its default room, i.e. its SID. This will result
222    /// in a noop.
223    ///
224    pub async fn leave(&mut self, room_id: &str) {
225        let _ = self.sender.send(InternalMessage::IO(SocketIOMessage::Leave(
226            room_id.to_string(),
227        )));
228    }
229
230    ///
231    /// send sends a message to this socket
232    ///
233    pub async fn send(&self, event: &str, message: &str) {
234        let _ = self
235            .sender
236            .send(InternalMessage::IO(SocketIOMessage::SendMessage(
237                event.to_string(),
238                message.to_string(),
239            )));
240    }
241
242    ///
243    /// emit_to sends a message to all sockets connected to the given
244    /// room_id, including the sending socket.
245    ///
246    pub async fn emit_to(&self, room_id: &str, event: &str, message: &str) {
247        // Send out via adapter
248        if let Some(adapter) = &*ADAPTER.read().unwrap() {
249            adapter.incoming(
250                room_id,
251                &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
252            );
253        }
254
255        if let Some(channels) = get_sockets_for_room(room_id) {
256            for channel in &*channels {
257                channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
258                    event.to_string(),
259                    message.to_string(),
260                )));
261            }
262        }
263    }
264
265    ///
266    /// broadcast_to sends a message to all the sockets connected to
267    /// the given room_id, excluding the sending socket.
268    ///
269    pub async fn broadcast_to(&self, room_id: &str, event: &str, message: &str) {
270        // Send out via adapter
271        if let Some(adapter) = &*ADAPTER.read().unwrap() {
272            adapter.incoming(
273                room_id,
274                &SocketIOMessage::SendMessage(event.to_string(), message.to_string()),
275            );
276        }
277
278        if let Some(channels) = get_sockets_for_room(room_id) {
279            for channel in &*channels {
280                if channel.sid() != self.id {
281                    channel.send(InternalMessage::IO(SocketIOMessage::SendMessage(
282                        event.to_string(),
283                        message.to_string(),
284                    )));
285                }
286            }
287        }
288    }
289
290    ///
291    /// rooms returns all of the rooms this socket is currently in
292    ///
293    pub fn rooms(&self) -> &Vec<String> {
294        &self.rooms
295    }
296}
297
298impl fmt::Display for InternalMessage {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        match self {
301            InternalMessage::IO(v) => write!(f, "Message::IO({})", v),
302            InternalMessage::WS(v) => write!(f, "Message::WS({})", v),
303        }
304    }
305}
306
307impl fmt::Display for WSSocketMessage {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        match self {
310            WSSocketMessage::RawMessage(val) => write!(f, "WSSocketMessage::RawMessage({})", val),
311            WSSocketMessage::Ping => write!(f, "WSSocketMessage::Ping"),
312            WSSocketMessage::Pong => write!(f, "WSSocketMessage::Pong"),
313            WSSocketMessage::WsPing => write!(f, "WSSocketMessage::WsPing"),
314            WSSocketMessage::WsPong => write!(f, "WSSocketMessage::WsPong"),
315            WSSocketMessage::Close => write!(f, "WSSocketMessage::Close"),
316        }
317    }
318}
319
320pub struct SocketIOWrapper {
321    sid: String,
322    message_number: usize,
323    socket: SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>,
324    rooms: Vec<String>,
325    event_handlers: HashMap<String, Vec<SocketIOHandler>>,
326    sender: Sender<InternalMessage>,
327    receiver: Receiver<InternalMessage>,
328}
329
330impl SocketIOWrapper {
331    pub fn new(
332        sid: String,
333        socket: SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>,
334        message_capacity: usize,
335    ) -> Self {
336        let (sender, receiver) = unbounded(message_capacity);
337        SocketIOWrapper {
338            sid,
339            message_number: 0,
340            socket,
341            rooms: Vec::new(),
342            event_handlers: HashMap::new(),
343            sender,
344            receiver,
345        }
346    }
347
348    pub async fn close(mut self) {
349        // remove the socket from all joined rooms
350        for room in &self.rooms {
351            remove_socket_from_room(room, &self.sid);
352            debug!(
353                "SocketIOMessage socketid {} closed, leave room {}",
354                self.sid, room
355            );
356        }
357
358        let _res = self.socket.close().await;
359    }
360
361    ///
362    /// Handle an incoming payload. This parses the string into the correct parts and calls
363    /// self.handler on them
364    ///
365    pub async fn handle(&mut self, payload: String) {
366        if payload == SOCKETIO_PING {
367            let _ = self.sender.send(InternalMessage::WS(WSSocketMessage::Pong));
368            return;
369        }
370
371        if payload == SOCKETIO_PONG {
372            // Probably should set a timer and send a ping, but really... eh?
373            return;
374        }
375
376        match &payload[0..2] {
377            "42" => {
378                if !payload.is_empty() {
379                    let (event, message) = parse_raw_message(&payload);
380
381                    // Run handlers
382                    match self.event_handlers.get(&event) {
383                        Some(handlers) => {
384                            // Run with each handler -- should they be async and waited for?
385                            let unordered_future = FuturesUnordered::new();
386
387                            for handler in handlers {
388                                unordered_future.push((handler)(
389                                    SocketIOSocket {
390                                        id: self.sid.clone(),
391                                        sender: self.sender.clone(),
392                                        rooms: self.rooms.clone(),
393                                    },
394                                    message.clone(),
395                                ));
396                            }
397
398                            // Dev note -- this must be spawned in a separate task, otherwise
399                            // it can block the receive loop and queue up too many transactions.
400                            tokio::spawn(async move {
401                                let _ = unordered_future.collect::<Result<(), ()>>().await;
402                            });
403                        }
404                        None => {
405                            info!("No handler found for message: {:#?}", event);
406                        } // Ignore
407                    }
408                }
409            }
410            "41" => {
411                debug!("{}: Socket closed...", self.sid);
412            }
413            "40" => {
414                debug!("{}: Socket opened...", self.sid);
415            }
416            _ => panic!("Attempted to handle a non-message payload: '{}'", payload),
417        }
418    }
419
420    pub async fn listen(mut self) {
421        while let Ok(val) = self.receiver.recv().await {
422            match val {
423                InternalMessage::IO(val) => {
424                    match val {
425                        SocketIOMessage::SendMessage(event, message) => {
426                            self.message_number += 1;
427
428                            let message = match &message[0..1] {
429                                "{" | "[" => message,
430                                _ => format!("\"{}\"", message),
431                            };
432
433                            // TODO(trezm): Payload needs to be quoted if just a string, not if it's json
434                            let content = format!(
435                                "{}{}[\"{}\",{}]",
436                                SOCKETIO_EVENT_MESSAGE, self.message_number, event, message
437                            );
438
439                            let _ = self.socket.send(Message::Text(content)).await;
440                        }
441
442                        SocketIOMessage::SendBinaryMessage(_event, message) => {
443                            let _ = self.socket.send(Message::Binary(message)).await;
444                        }
445
446                        SocketIOMessage::Join(room_id) => {
447                            // check if room_id exist. Don't use return because of the following process such as PING/PONG.
448                            if !self.rooms.contains(&room_id) {
449                                self.rooms.push(room_id.to_string());
450                                debug!("SocketIOMessage socketid {} joined room {}. Rooms = {:?}, rooms len = {}", self.sid, room_id, self.rooms, self.rooms.len());
451
452                                //Call rooms::join_channel_to_room
453                                join_channel_to_room(
454                                    &room_id,
455                                    ChannelPair::new(&self.sid, self.sender()),
456                                );
457                            } else {
458                                debug!("SocketIOMessage socketid {} is already in room {}. Not joining.", self.sid, room_id);
459                            }
460                        }
461
462                        SocketIOMessage::Leave(room_id) => {
463                            for (i, room) in self.rooms.iter().enumerate() {
464                                if room == &room_id {
465                                    self.rooms.remove(i);
466                                    debug!("SocketIOMessage socketid {} leaved room {}. Rooms = {:?}, rooms len = {}", self.sid, room_id, self.rooms, self.rooms.len());
467
468                                    //Call rooms::remove_socket_from_room
469                                    remove_socket_from_room(&room_id, &self.sid);
470                                    break;
471                                }
472                            }
473                        }
474
475                        SocketIOMessage::AddListener(event, handler) => {
476                            let mut existing_handlers =
477                                self.event_handlers.remove(&event).unwrap_or_default();
478
479                            existing_handlers.push(handler);
480
481                            self.event_handlers
482                                .insert(event.to_string(), existing_handlers);
483                        }
484                        _ => (),
485                    }
486                }
487                InternalMessage::WS(val) => match val {
488                    WSSocketMessage::RawMessage(message) => self.handle(message).await,
489                    WSSocketMessage::Ping => {
490                        let _ = self
491                            .socket
492                            .send(Message::Text(SOCKETIO_PONG.to_string()))
493                            .await;
494                    }
495                    WSSocketMessage::Pong => {
496                        let _ = self
497                            .socket
498                            .send(Message::Text(SOCKETIO_PING.to_string()))
499                            .await;
500                    }
501                    WSSocketMessage::WsPing => {
502                        let _ = self.socket.send(Message::Pong([].to_vec())).await;
503                    }
504                    WSSocketMessage::WsPong => {
505                        let _ = self.socket.send(Message::Ping([].to_vec())).await;
506                    }
507
508                    WSSocketMessage::Close => {
509                        self.close().await;
510                        return;
511                    }
512                },
513            }
514        }
515    }
516
517    pub fn sender(&self) -> Sender<InternalMessage> {
518        self.sender.clone()
519    }
520}