ezrtc_server/
one_to_many.rs

1use anyhow::anyhow;
2use axum::extract::ws::{CloseFrame, Message, WebSocket};
3use ezrtc::protocol::{SessionId, SignalMessage, Status, UserId};
4use futures_util::{SinkExt, StreamExt};
5use log::{error, info, warn};
6use std::borrow::Cow;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_stream::wrappers::UnboundedReceiverStream;
14
15#[derive(Default, Debug)]
16pub struct Session {
17    pub host: Option<UserId>,
18    pub users: HashSet<UserId>,
19}
20
21#[derive(Default, Debug)]
22pub struct Ping {
23    pub online: bool,
24    pub session_id: Option<SessionId>,
25    pub metadata: Option<serde_json::Value>,
26}
27
28pub type Connections = Arc<RwLock<HashMap<UserId, mpsc::UnboundedSender<Message>>>>;
29pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
30pub type Pings = Arc<Mutex<HashMap<UserId, Arc<Ping>>>>;
31
/// Process-wide counter used to hand out user ids; the first id issued is 1.
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
33
34pub async fn user_connected(ws: WebSocket, connections: Connections, sessions: Sessions, pings: Pings) {
35    let user_id = UserId::new(NEXT_USER_ID.fetch_add(1, Ordering::Relaxed));
36    info!("new user connected: {:?}", user_id);
37
38    let (mut ws_send, mut ws_recv) = ws.split();
39
40    // Create a channel for sending and receiving ws messages
41    let (tx, rx) = mpsc::unbounded_channel();
42    let mut rx = UnboundedReceiverStream::new(rx);
43
44    // Ping client every 60 seconds
45    let tx2 = tx.clone();
46    let user_id2 = user_id.clone();
47    let pings2 = pings.clone();
48
49    let mut ping_task = tokio::spawn(async move {
50        let mut interval = time::interval(Duration::from_secs(60));
51
52        loop {
53            interval.tick().await;
54
55            let status = {
56                let pings = pings2.lock().unwrap();
57                pings.get(&user_id2).cloned()
58            };
59
60            if let Some(ping) = status {
61                if ping.online {
62                    pings2.lock().unwrap().insert(
63                        user_id2.clone(),
64                        Arc::new(Ping {
65                            online: false,
66                            session_id: ping.session_id.clone(),
67                            metadata: ping.metadata.clone(),
68                        }),
69                    );
70                } else {
71                    error!("User failed to respond, closing connection: {:?}", user_id2);
72                    break;
73                }
74            }
75
76            warn!("Sending ping to user: {:?}", user_id2);
77
78            let response = SignalMessage::KeepAlive(user_id2.clone(), Status::default());
79            let response = serde_json::to_string(&response).unwrap();
80            if let Err(e) = tx2.send(Message::Text(response)) {
81                error!("Websocket ping error: {}", e);
82                break;
83            }
84        }
85    });
86
87    // Send messages to websocket from channel
88    let mut send_task = tokio::spawn(async move {
89        while let Some(message) = rx.next().await {
90            if ws_send.send(message).await.is_err() {
91                break;
92            }
93        }
94
95        match ws_send
96            .send(Message::Close(Some(CloseFrame {
97                code: axum::extract::ws::close_code::NORMAL,
98                reason: Cow::from("Goodbye"),
99            })))
100            .await
101        {
102            Ok(_) => info!("Sent close to {user_id}"),
103            Err(e) => info!("Failed to close: {e}"),
104        }
105    });
106
107    // Receive messages from websocket
108    let connections2 = connections.clone();
109    let sessions2 = sessions.clone();
110    let pings2 = pings.clone();
111
112    let mut recv_task = tokio::spawn(async move {
113        while let Some(msg) = ws_recv.next().await {
114            match msg {
115                Ok(msg) => {
116                    if let Err(err) = user_message(user_id, msg, &connections2, &sessions2, &pings2).await {
117                        error!("error while handling user message: {}", err);
118                    }
119                }
120                Err(e) => {
121                    error!("Websocket error: {:?} {}", user_id, e);
122                    break;
123                }
124            }
125        }
126    });
127
128    connections.write().await.insert(user_id, tx);
129
130    // Run all tasks and abort if any of them fails
131    tokio::select! {
132        t1 = (&mut send_task) => {
133            match t1 {
134                Ok(_) => info!("Sender task stopped"),
135                Err(a) => info!("Error sending messages {a:?}")
136            }
137            recv_task.abort();
138            ping_task.abort();
139        },
140        t2 = (&mut recv_task) => {
141            match t2 {
142                Ok(_) => info!("Receiver task stopped"),
143                Err(b) => info!("Error receiving messages {b:?}")
144            }
145            send_task.abort();
146            ping_task.abort();
147        }
148        t3 = (&mut ping_task) => {
149            match t3 {
150                Ok(_) => info!("Ping task stopped"),
151                Err(c) => info!("Error pinging {c:?}")
152            }
153            send_task.abort();
154            recv_task.abort();
155        }
156    }
157
158    error!("User disconnected: {:?}", user_id);
159    pings.lock().unwrap().remove(&user_id);
160    user_disconnected(user_id, &connections, &sessions).await;
161}
162
163async fn user_message(sender_id: UserId, msg: Message, connections: &Connections, sessions: &Sessions, pings: &Pings) -> crate::Result<()> {
164    if let Ok(msg) = msg.to_text() {
165        if msg.is_empty() || msg == "ping" {
166            // warn!("empty message from user {:?}", sender_id);
167            return Ok(());
168        }
169
170        match serde_json::from_str::<SignalMessage>(msg) {
171            Ok(request) => {
172                info!("message received from user {:?}: {:?}", sender_id, request);
173                match request {
174                    SignalMessage::SessionJoin(session_id, is_host) => {
175                        let mut sessions_writer = sessions.write().await;
176                        let session = sessions_writer.entry(session_id.clone()).or_insert_with(Session::default);
177                        let connections_reader = connections.read().await;
178
179                        if is_host && session.host.is_none() {
180                            session.host = Some(sender_id);
181                            // start connections with all already present users
182                            for client_id in &session.users {
183                                {
184                                    let host_tx = connections_reader.get(&sender_id).expect("host not in connections");
185                                    let host_response = SignalMessage::SessionReady(session_id.clone(), *client_id);
186                                    let host_response = serde_json::to_string(&host_response)?;
187                                    host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
188                                }
189                            }
190                        } else if is_host && session.host.is_some() {
191                            warn!("connecting user wants to be a host, but host is already present, closing connection soon");
192
193                            let connections2 = connections.clone();
194
195                            tokio::task::spawn(async move {
196                                let new_host_tx = {
197                                    let connections_reader2 = connections2.read().await;
198                                    connections_reader2.get(&sender_id).cloned()
199                                };
200
201                                tokio::time::sleep(Duration::from_secs(60)).await;
202                                if let Some(new_host_tx) = new_host_tx {
203                                    new_host_tx
204                                        .send(Message::Close(Some(CloseFrame {
205                                            code: 3001,
206                                            reason: "Multiple hosts".into(),
207                                        })))
208                                        .expect("failed to send close message to host");
209                                }
210                            });
211                        } else {
212                            // connect new user with host
213                            session.users.insert(sender_id);
214
215                            if let Some(host_id) = session.host {
216                                let host_tx = connections_reader.get(&host_id).expect("host not in connections");
217                                let host_response = SignalMessage::SessionReady(session_id.clone(), sender_id);
218                                let host_response = serde_json::to_string(&host_response)?;
219                                host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
220                            }
221                        }
222                    }
223                    // pass offer to the other user in session without changing anything
224                    SignalMessage::SdpOffer(session_id, recipient_id, offer) => {
225                        let response = SignalMessage::SdpOffer(session_id, sender_id, offer);
226                        let response = serde_json::to_string(&response)?;
227                        let connections_reader = connections.read().await;
228                        if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
229                            recipient_tx.send(Message::Text(response))?;
230                        } else {
231                            warn!("tried to send offer to non existing user");
232                        }
233                    }
234                    // pass answer to the other user in session without changing anything
235                    SignalMessage::SdpAnswer(session_id, recipient_id, answer) => {
236                        let response = SignalMessage::SdpAnswer(session_id, sender_id, answer);
237                        let response = serde_json::to_string(&response)?;
238                        let connections_reader = connections.read().await;
239                        if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
240                            recipient_tx.send(Message::Text(response))?;
241                        } else {
242                            warn!("tried to send offer to non existing user");
243                        }
244                    }
245                    SignalMessage::IceCandidate(session_id, recipient_id, candidate) => {
246                        let response = SignalMessage::IceCandidate(session_id, sender_id, candidate);
247                        let response = serde_json::to_string(&response)?;
248                        let connections_reader = connections.read().await;
249                        let recipient_tx = connections_reader.get(&recipient_id).ok_or_else(|| anyhow!("no sender for given id"))?;
250
251                        recipient_tx.send(Message::Text(response))?;
252                    }
253                    SignalMessage::KeepAlive(user_id, status) => {
254                        if status.is_host.is_some() {
255                            warn!("Received ping from user {:?}", status.session_id);
256                            pings.lock().unwrap().insert(user_id, Arc::new(Ping { online: true, session_id: status.session_id, metadata: status.metadata }));
257                        }
258                    }
259                    _ => {}
260                }
261            }
262            Err(error) => {
263                error!("An error occurred: {:?} {:?}", error, msg);
264            }
265        }
266    }
267    Ok(())
268}
269
270async fn user_disconnected(user_id: UserId, connections: &Connections, sessions: &Sessions) {
271    connections.write().await.remove(&user_id);
272
273    let mut session_to_delete = None;
274    for (session_id, session) in sessions.write().await.iter_mut() {
275        if session.host == Some(user_id) {
276            session.host = None;
277        } else if session.users.contains(&user_id) {
278            session.users.remove(&user_id);
279        }
280        if session.host.is_none() && session.users.is_empty() {
281            session_to_delete = Some(session_id.clone());
282            break;
283        }
284    }
285    // remove session if it's empty
286    if let Some(session_id) = session_to_delete {
287        sessions.write().await.remove(&session_id);
288    }
289}