1use anyhow::anyhow;
2use axum::extract::ws::{CloseFrame, Message, WebSocket};
3use ezrtc::protocol::{SessionId, SignalMessage, Status, UserId};
4use futures_util::{SinkExt, StreamExt};
5use log::{error, info, warn};
6use std::borrow::Cow;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_stream::wrappers::UnboundedReceiverStream;
14
15#[derive(Default, Debug)]
16pub struct Session {
17 pub host: Option<UserId>,
18 pub users: HashSet<UserId>,
19}
20
21#[derive(Debug)]
22pub struct Ping {
23 pub last_seen: Instant,
24 pub session_id: Option<SessionId>,
25 pub metadata: Option<serde_json::Value>,
26 pub version: Option<String>,
27}
28
29pub type Connections = Arc<RwLock<HashMap<UserId, mpsc::UnboundedSender<Message>>>>;
30pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
31pub type Pings = Arc<Mutex<HashMap<UserId, Arc<Ping>>>>;
32
33static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
34
35pub async fn user_connected(ws: WebSocket, connections: Connections, sessions: Sessions, pings: Pings) {
36 let user_id = UserId::new(NEXT_USER_ID.fetch_add(1, Ordering::Relaxed));
37 info!("new user connected: {:?}", user_id);
38
39 let (mut ws_send, mut ws_recv) = ws.split();
40
41 let (tx, rx) = mpsc::unbounded_channel();
43 let mut rx = UnboundedReceiverStream::new(rx);
44
45 let tx2 = tx.clone();
47 let user_id2 = user_id.clone();
48 let pings2 = pings.clone();
49
50 let mut ping_task = tokio::spawn(async move {
51 let mut interval = time::interval(Duration::from_secs(50));
52
53 loop {
54 interval.tick().await;
55
56 let status = {
57 let pings = pings2.lock().unwrap();
58 pings.get(&user_id2).cloned()
59 };
60
61 if let Some(ping) = status {
62 let elapsed = ping.last_seen.elapsed();
63 if elapsed > Duration::from_secs(60) {
64 error!("User failed to respond, closing connection: {:?}", user_id2);
65 break;
66 }
67 }
68
69 warn!("Sending ping to user: {:?}", user_id2);
70
71 let response = SignalMessage::KeepAlive(user_id2.clone(), Status::default());
72 let response = serde_json::to_string(&response).unwrap();
73 if let Err(e) = tx2.send(Message::Text(response)) {
74 error!("Websocket ping error: {}", e);
75 break;
76 }
77 }
78 });
79
80 let mut send_task = tokio::spawn(async move {
82 while let Some(message) = rx.next().await {
83 if ws_send.send(message).await.is_err() {
84 break;
85 }
86 }
87
88 match ws_send
89 .send(Message::Close(Some(CloseFrame {
90 code: axum::extract::ws::close_code::NORMAL,
91 reason: Cow::from("Goodbye"),
92 })))
93 .await
94 {
95 Ok(_) => info!("Sent close to {user_id}"),
96 Err(e) => info!("Failed to close: {e}"),
97 }
98 });
99
100 let connections2 = connections.clone();
102 let sessions2 = sessions.clone();
103 let pings2 = pings.clone();
104
105 let mut recv_task = tokio::spawn(async move {
106 while let Some(msg) = ws_recv.next().await {
107 match msg {
108 Ok(msg) => {
109 if let Err(err) = user_message(user_id, msg, &connections2, &sessions2, &pings2).await {
110 error!("error while handling user message: {}", err);
111 }
112 }
113 Err(e) => {
114 error!("Websocket error: {:?} {}", user_id, e);
115 break;
116 }
117 }
118 }
119 });
120
121 connections.write().await.insert(user_id, tx);
122
123 tokio::select! {
125 t1 = (&mut send_task) => {
126 match t1 {
127 Ok(_) => info!("Sender task stopped"),
128 Err(a) => info!("Error sending messages {a:?}")
129 }
130 recv_task.abort();
131 ping_task.abort();
132 },
133 t2 = (&mut recv_task) => {
134 match t2 {
135 Ok(_) => info!("Receiver task stopped"),
136 Err(b) => info!("Error receiving messages {b:?}")
137 }
138 send_task.abort();
139 ping_task.abort();
140 }
141 t3 = (&mut ping_task) => {
142 match t3 {
143 Ok(_) => info!("Ping task stopped"),
144 Err(c) => info!("Error pinging {c:?}")
145 }
146 send_task.abort();
147 recv_task.abort();
148 }
149 }
150
151 error!("User disconnected: {:?}", user_id);
152 pings.lock().unwrap().remove(&user_id);
153 user_disconnected(user_id, &connections, &sessions).await;
154}
155
156async fn user_message(sender_id: UserId, msg: Message, connections: &Connections, sessions: &Sessions, pings: &Pings) -> crate::Result<()> {
157 if let Ok(msg) = msg.to_text() {
158 if msg.is_empty() || msg == "ping" {
159 return Ok(());
161 }
162
163 match serde_json::from_str::<SignalMessage>(msg) {
164 Ok(request) => {
165 info!("message received from user {:?}: {:?}", sender_id, request);
166 match request {
167 SignalMessage::SessionJoin(session_id, is_host) => {
168 let mut sessions_writer = sessions.write().await;
169 let session = sessions_writer.entry(session_id.clone()).or_insert_with(Session::default);
170 let connections_reader = connections.read().await;
171
172 if is_host && session.host.is_none() {
173 session.host = Some(sender_id);
174 for client_id in &session.users {
176 {
177 let host_tx = connections_reader.get(&sender_id).expect("host not in connections");
178 let host_response = SignalMessage::SessionReady(session_id.clone(), *client_id);
179 let host_response = serde_json::to_string(&host_response)?;
180 host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
181 }
182 }
183 } else if is_host && session.host.is_some() {
184 warn!("connecting user wants to be a host, but host is already present, closing connection soon");
185
186 let connections2 = connections.clone();
187
188 tokio::task::spawn(async move {
189 let new_host_tx = {
190 let connections_reader2 = connections2.read().await;
191 connections_reader2.get(&sender_id).cloned()
192 };
193
194 tokio::time::sleep(Duration::from_secs(60)).await;
195 if let Some(new_host_tx) = new_host_tx {
196 new_host_tx
197 .send(Message::Close(Some(CloseFrame {
198 code: 3001,
199 reason: "Multiple hosts".into(),
200 })))
201 .expect("failed to send close message to host");
202 }
203 });
204 } else {
205 session.users.insert(sender_id);
207
208 if let Some(host_id) = session.host {
209 let host_tx = connections_reader.get(&host_id).expect("host not in connections");
210 let host_response = SignalMessage::SessionReady(session_id.clone(), sender_id);
211 let host_response = serde_json::to_string(&host_response)?;
212 host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
213 }
214 }
215 }
216 SignalMessage::SdpOffer(session_id, recipient_id, offer) => {
218 let response = SignalMessage::SdpOffer(session_id, sender_id, offer);
219 let response = serde_json::to_string(&response)?;
220 let connections_reader = connections.read().await;
221 if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
222 recipient_tx.send(Message::Text(response))?;
223 } else {
224 warn!("tried to send offer to non existing user");
225 }
226 }
227 SignalMessage::SdpAnswer(session_id, recipient_id, answer) => {
229 let response = SignalMessage::SdpAnswer(session_id, sender_id, answer);
230 let response = serde_json::to_string(&response)?;
231 let connections_reader = connections.read().await;
232 if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
233 recipient_tx.send(Message::Text(response))?;
234 } else {
235 warn!("tried to send offer to non existing user");
236 }
237 }
238 SignalMessage::IceCandidate(session_id, recipient_id, candidate) => {
239 let response = SignalMessage::IceCandidate(session_id, sender_id, candidate);
240 let response = serde_json::to_string(&response)?;
241 let connections_reader = connections.read().await;
242 let recipient_tx = connections_reader.get(&recipient_id).ok_or_else(|| anyhow!("no sender for given id"))?;
243
244 recipient_tx.send(Message::Text(response))?;
245 }
246 SignalMessage::KeepAlive(user_id, status) => {
247 if status.is_host.is_some() {
248 warn!("Received ping from user {:?}", status.session_id);
249 pings.lock().unwrap().insert(
250 user_id,
251 Arc::new(Ping {
252 last_seen: Instant::now(),
253 session_id: status.session_id,
254 metadata: status.metadata,
255 version: status.version,
256 }),
257 );
258 }
259 }
260 _ => {}
261 }
262 }
263 Err(error) => {
264 error!("An error occurred: {:?} {:?}", error, msg);
265 }
266 }
267 }
268 Ok(())
269}
270
271async fn user_disconnected(user_id: UserId, connections: &Connections, sessions: &Sessions) {
272 connections.write().await.remove(&user_id);
273
274 let mut session_to_delete = None;
275 for (session_id, session) in sessions.write().await.iter_mut() {
276 if session.host == Some(user_id) {
277 session.host = None;
278 } else if session.users.contains(&user_id) {
279 session.users.remove(&user_id);
280 }
281 if session.host.is_none() && session.users.is_empty() {
282 session_to_delete = Some(session_id.clone());
283 break;
284 }
285 }
286 if let Some(session_id) = session_to_delete {
288 sessions.write().await.remove(&session_id);
289 }
290}