ezrtc_server/
one_to_many.rs

1use anyhow::anyhow;
2use axum::extract::ws::{CloseFrame, Message, WebSocket};
3use ezrtc::protocol::{SessionId, SignalMessage, Status, UserId};
4use futures_util::{SinkExt, StreamExt};
5use log::{error, info, warn};
6use std::borrow::Cow;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_stream::wrappers::UnboundedReceiverStream;
14
15#[derive(Default, Debug)]
16pub struct Session {
17    pub host: Option<UserId>,
18    pub users: HashSet<UserId>,
19}
20
21#[derive(Debug)]
22pub struct Ping {
23    pub last_seen: Instant,
24    pub session_id: Option<SessionId>,
25    pub metadata: Option<serde_json::Value>,
26    pub version: Option<String>,
27}
28
29pub type Connections = Arc<RwLock<HashMap<UserId, mpsc::UnboundedSender<Message>>>>;
30pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
31pub type Pings = Arc<Mutex<HashMap<UserId, Arc<Ping>>>>;
32
33static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
34
35pub async fn user_connected(ws: WebSocket, connections: Connections, sessions: Sessions, pings: Pings) {
36    let user_id = UserId::new(NEXT_USER_ID.fetch_add(1, Ordering::Relaxed));
37    info!("new user connected: {:?}", user_id);
38
39    let (mut ws_send, mut ws_recv) = ws.split();
40
41    // Create a channel for sending and receiving ws messages
42    let (tx, rx) = mpsc::unbounded_channel();
43    let mut rx = UnboundedReceiverStream::new(rx);
44
45    // Ping client every 60 seconds
46    let tx2 = tx.clone();
47    let user_id2 = user_id.clone();
48    let pings2 = pings.clone();
49
50    let mut ping_task = tokio::spawn(async move {
51        let mut interval = time::interval(Duration::from_secs(50));
52
53        loop {
54            interval.tick().await;
55
56            let status = {
57                let pings = pings2.lock().unwrap();
58                pings.get(&user_id2).cloned()
59            };
60
61            if let Some(ping) = status {
62                let elapsed = ping.last_seen.elapsed();
63                if elapsed > Duration::from_secs(60) {
64                    error!("User failed to respond, closing connection: {:?}", user_id2);
65                    break;
66                }
67            }
68
69            warn!("Sending ping to user: {:?}", user_id2);
70
71            let response = SignalMessage::KeepAlive(user_id2.clone(), Status::default());
72            let response = serde_json::to_string(&response).unwrap();
73            if let Err(e) = tx2.send(Message::Text(response)) {
74                error!("Websocket ping error: {}", e);
75                break;
76            }
77        }
78    });
79
80    // Send messages to websocket from channel
81    let mut send_task = tokio::spawn(async move {
82        while let Some(message) = rx.next().await {
83            if ws_send.send(message).await.is_err() {
84                break;
85            }
86        }
87
88        match ws_send
89            .send(Message::Close(Some(CloseFrame {
90                code: axum::extract::ws::close_code::NORMAL,
91                reason: Cow::from("Goodbye"),
92            })))
93            .await
94        {
95            Ok(_) => info!("Sent close to {user_id}"),
96            Err(e) => info!("Failed to close: {e}"),
97        }
98    });
99
100    // Receive messages from websocket
101    let connections2 = connections.clone();
102    let sessions2 = sessions.clone();
103    let pings2 = pings.clone();
104
105    let mut recv_task = tokio::spawn(async move {
106        while let Some(msg) = ws_recv.next().await {
107            match msg {
108                Ok(msg) => {
109                    if let Err(err) = user_message(user_id, msg, &connections2, &sessions2, &pings2).await {
110                        error!("error while handling user message: {}", err);
111                    }
112                }
113                Err(e) => {
114                    error!("Websocket error: {:?} {}", user_id, e);
115                    break;
116                }
117            }
118        }
119    });
120
121    connections.write().await.insert(user_id, tx);
122
123    // Run all tasks and abort if any of them fails
124    tokio::select! {
125        t1 = (&mut send_task) => {
126            match t1 {
127                Ok(_) => info!("Sender task stopped"),
128                Err(a) => info!("Error sending messages {a:?}")
129            }
130            recv_task.abort();
131            ping_task.abort();
132        },
133        t2 = (&mut recv_task) => {
134            match t2 {
135                Ok(_) => info!("Receiver task stopped"),
136                Err(b) => info!("Error receiving messages {b:?}")
137            }
138            send_task.abort();
139            ping_task.abort();
140        }
141        t3 = (&mut ping_task) => {
142            match t3 {
143                Ok(_) => info!("Ping task stopped"),
144                Err(c) => info!("Error pinging {c:?}")
145            }
146            send_task.abort();
147            recv_task.abort();
148        }
149    }
150
151    error!("User disconnected: {:?}", user_id);
152    pings.lock().unwrap().remove(&user_id);
153    user_disconnected(user_id, &connections, &sessions).await;
154}
155
156async fn user_message(sender_id: UserId, msg: Message, connections: &Connections, sessions: &Sessions, pings: &Pings) -> crate::Result<()> {
157    if let Ok(msg) = msg.to_text() {
158        if msg.is_empty() || msg == "ping" {
159            // warn!("empty message from user {:?}", sender_id);
160            return Ok(());
161        }
162
163        match serde_json::from_str::<SignalMessage>(msg) {
164            Ok(request) => {
165                info!("message received from user {:?}: {:?}", sender_id, request);
166                match request {
167                    SignalMessage::SessionJoin(session_id, is_host) => {
168                        let mut sessions_writer = sessions.write().await;
169                        let session = sessions_writer.entry(session_id.clone()).or_insert_with(Session::default);
170                        let connections_reader = connections.read().await;
171
172                        if is_host && session.host.is_none() {
173                            session.host = Some(sender_id);
174                            // start connections with all already present users
175                            for client_id in &session.users {
176                                {
177                                    let host_tx = connections_reader.get(&sender_id).expect("host not in connections");
178                                    let host_response = SignalMessage::SessionReady(session_id.clone(), *client_id);
179                                    let host_response = serde_json::to_string(&host_response)?;
180                                    host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
181                                }
182                            }
183                        } else if is_host && session.host.is_some() {
184                            warn!("connecting user wants to be a host, but host is already present, closing connection soon");
185
186                            let connections2 = connections.clone();
187
188                            tokio::task::spawn(async move {
189                                let new_host_tx = {
190                                    let connections_reader2 = connections2.read().await;
191                                    connections_reader2.get(&sender_id).cloned()
192                                };
193
194                                tokio::time::sleep(Duration::from_secs(60)).await;
195                                if let Some(new_host_tx) = new_host_tx {
196                                    new_host_tx
197                                        .send(Message::Close(Some(CloseFrame {
198                                            code: 3001,
199                                            reason: "Multiple hosts".into(),
200                                        })))
201                                        .expect("failed to send close message to host");
202                                }
203                            });
204                        } else {
205                            // connect new user with host
206                            session.users.insert(sender_id);
207
208                            if let Some(host_id) = session.host {
209                                let host_tx = connections_reader.get(&host_id).expect("host not in connections");
210                                let host_response = SignalMessage::SessionReady(session_id.clone(), sender_id);
211                                let host_response = serde_json::to_string(&host_response)?;
212                                host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
213                            }
214                        }
215                    }
216                    // pass offer to the other user in session without changing anything
217                    SignalMessage::SdpOffer(session_id, recipient_id, offer) => {
218                        let response = SignalMessage::SdpOffer(session_id, sender_id, offer);
219                        let response = serde_json::to_string(&response)?;
220                        let connections_reader = connections.read().await;
221                        if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
222                            recipient_tx.send(Message::Text(response))?;
223                        } else {
224                            warn!("tried to send offer to non existing user");
225                        }
226                    }
227                    // pass answer to the other user in session without changing anything
228                    SignalMessage::SdpAnswer(session_id, recipient_id, answer) => {
229                        let response = SignalMessage::SdpAnswer(session_id, sender_id, answer);
230                        let response = serde_json::to_string(&response)?;
231                        let connections_reader = connections.read().await;
232                        if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
233                            recipient_tx.send(Message::Text(response))?;
234                        } else {
235                            warn!("tried to send offer to non existing user");
236                        }
237                    }
238                    SignalMessage::IceCandidate(session_id, recipient_id, candidate) => {
239                        let response = SignalMessage::IceCandidate(session_id, sender_id, candidate);
240                        let response = serde_json::to_string(&response)?;
241                        let connections_reader = connections.read().await;
242                        let recipient_tx = connections_reader.get(&recipient_id).ok_or_else(|| anyhow!("no sender for given id"))?;
243
244                        recipient_tx.send(Message::Text(response))?;
245                    }
246                    SignalMessage::KeepAlive(user_id, status) => {
247                        if status.is_host.is_some() {
248                            warn!("Received ping from user {:?}", status.session_id);
249                            pings.lock().unwrap().insert(
250                                user_id,
251                                Arc::new(Ping {
252                                    last_seen: Instant::now(),
253                                    session_id: status.session_id,
254                                    metadata: status.metadata,
255                                    version: status.version,
256                                }),
257                            );
258                        }
259                    }
260                    _ => {}
261                }
262            }
263            Err(error) => {
264                error!("An error occurred: {:?} {:?}", error, msg);
265            }
266        }
267    }
268    Ok(())
269}
270
271async fn user_disconnected(user_id: UserId, connections: &Connections, sessions: &Sessions) {
272    connections.write().await.remove(&user_id);
273
274    let mut session_to_delete = None;
275    for (session_id, session) in sessions.write().await.iter_mut() {
276        if session.host == Some(user_id) {
277            session.host = None;
278        } else if session.users.contains(&user_id) {
279            session.users.remove(&user_id);
280        }
281        if session.host.is_none() && session.users.is_empty() {
282            session_to_delete = Some(session_id.clone());
283            break;
284        }
285    }
286    // remove session if it's empty
287    if let Some(session_id) = session_to_delete {
288        sessions.write().await.remove(&session_id);
289    }
290}