1use anyhow::anyhow;
2use axum::extract::ws::{CloseFrame, Message, WebSocket};
3use ezrtc::protocol::{SessionId, SignalMessage, Status, UserId};
4use futures_util::{SinkExt, StreamExt};
5use log::{error, info, warn};
6use std::borrow::Cow;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::{mpsc, RwLock};
12use tokio::time;
13use tokio_stream::wrappers::UnboundedReceiverStream;
14
15#[derive(Default, Debug)]
16pub struct Session {
17 pub host: Option<UserId>,
18 pub users: HashSet<UserId>,
19}
20
21#[derive(Default, Debug)]
22pub struct Ping {
23 pub online: bool,
24 pub session_id: Option<SessionId>,
25 pub metadata: Option<serde_json::Value>,
26}
27
28pub type Connections = Arc<RwLock<HashMap<UserId, mpsc::UnboundedSender<Message>>>>;
29pub type Sessions = Arc<RwLock<HashMap<SessionId, Session>>>;
30pub type Pings = Arc<Mutex<HashMap<UserId, Arc<Ping>>>>;
31
32static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
33
34pub async fn user_connected(ws: WebSocket, connections: Connections, sessions: Sessions, pings: Pings) {
35 let user_id = UserId::new(NEXT_USER_ID.fetch_add(1, Ordering::Relaxed));
36 info!("new user connected: {:?}", user_id);
37
38 let (mut ws_send, mut ws_recv) = ws.split();
39
40 let (tx, rx) = mpsc::unbounded_channel();
42 let mut rx = UnboundedReceiverStream::new(rx);
43
44 let tx2 = tx.clone();
46 let user_id2 = user_id.clone();
47 let pings2 = pings.clone();
48
49 let mut ping_task = tokio::spawn(async move {
50 let mut interval = time::interval(Duration::from_secs(60));
51
52 loop {
53 interval.tick().await;
54
55 let status = {
56 let pings = pings2.lock().unwrap();
57 pings.get(&user_id2).cloned()
58 };
59
60 if let Some(ping) = status {
61 if ping.online {
62 pings2.lock().unwrap().insert(
63 user_id2.clone(),
64 Arc::new(Ping {
65 online: false,
66 session_id: ping.session_id.clone(),
67 metadata: ping.metadata.clone(),
68 }),
69 );
70 } else {
71 error!("User failed to respond, closing connection: {:?}", user_id2);
72 break;
73 }
74 }
75
76 warn!("Sending ping to user: {:?}", user_id2);
77
78 let response = SignalMessage::KeepAlive(user_id2.clone(), Status::default());
79 let response = serde_json::to_string(&response).unwrap();
80 if let Err(e) = tx2.send(Message::Text(response)) {
81 error!("Websocket ping error: {}", e);
82 break;
83 }
84 }
85 });
86
87 let mut send_task = tokio::spawn(async move {
89 while let Some(message) = rx.next().await {
90 if ws_send.send(message).await.is_err() {
91 break;
92 }
93 }
94
95 match ws_send
96 .send(Message::Close(Some(CloseFrame {
97 code: axum::extract::ws::close_code::NORMAL,
98 reason: Cow::from("Goodbye"),
99 })))
100 .await
101 {
102 Ok(_) => info!("Sent close to {user_id}"),
103 Err(e) => info!("Failed to close: {e}"),
104 }
105 });
106
107 let connections2 = connections.clone();
109 let sessions2 = sessions.clone();
110 let pings2 = pings.clone();
111
112 let mut recv_task = tokio::spawn(async move {
113 while let Some(msg) = ws_recv.next().await {
114 match msg {
115 Ok(msg) => {
116 if let Err(err) = user_message(user_id, msg, &connections2, &sessions2, &pings2).await {
117 error!("error while handling user message: {}", err);
118 }
119 }
120 Err(e) => {
121 error!("Websocket error: {:?} {}", user_id, e);
122 break;
123 }
124 }
125 }
126 });
127
128 connections.write().await.insert(user_id, tx);
129
130 tokio::select! {
132 t1 = (&mut send_task) => {
133 match t1 {
134 Ok(_) => info!("Sender task stopped"),
135 Err(a) => info!("Error sending messages {a:?}")
136 }
137 recv_task.abort();
138 ping_task.abort();
139 },
140 t2 = (&mut recv_task) => {
141 match t2 {
142 Ok(_) => info!("Receiver task stopped"),
143 Err(b) => info!("Error receiving messages {b:?}")
144 }
145 send_task.abort();
146 ping_task.abort();
147 }
148 t3 = (&mut ping_task) => {
149 match t3 {
150 Ok(_) => info!("Ping task stopped"),
151 Err(c) => info!("Error pinging {c:?}")
152 }
153 send_task.abort();
154 recv_task.abort();
155 }
156 }
157
158 error!("User disconnected: {:?}", user_id);
159 pings.lock().unwrap().remove(&user_id);
160 user_disconnected(user_id, &connections, &sessions).await;
161}
162
163async fn user_message(sender_id: UserId, msg: Message, connections: &Connections, sessions: &Sessions, pings: &Pings) -> crate::Result<()> {
164 if let Ok(msg) = msg.to_text() {
165 if msg.is_empty() || msg == "ping" {
166 return Ok(());
168 }
169
170 match serde_json::from_str::<SignalMessage>(msg) {
171 Ok(request) => {
172 info!("message received from user {:?}: {:?}", sender_id, request);
173 match request {
174 SignalMessage::SessionJoin(session_id, is_host) => {
175 let mut sessions_writer = sessions.write().await;
176 let session = sessions_writer.entry(session_id.clone()).or_insert_with(Session::default);
177 let connections_reader = connections.read().await;
178
179 if is_host && session.host.is_none() {
180 session.host = Some(sender_id);
181 for client_id in &session.users {
183 {
184 let host_tx = connections_reader.get(&sender_id).expect("host not in connections");
185 let host_response = SignalMessage::SessionReady(session_id.clone(), *client_id);
186 let host_response = serde_json::to_string(&host_response)?;
187 host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
188 }
189 }
190 } else if is_host && session.host.is_some() {
191 warn!("connecting user wants to be a host, but host is already present, closing connection soon");
192
193 let connections2 = connections.clone();
194
195 tokio::task::spawn(async move {
196 let new_host_tx = {
197 let connections_reader2 = connections2.read().await;
198 connections_reader2.get(&sender_id).cloned()
199 };
200
201 tokio::time::sleep(Duration::from_secs(60)).await;
202 if let Some(new_host_tx) = new_host_tx {
203 new_host_tx
204 .send(Message::Close(Some(CloseFrame {
205 code: 3001,
206 reason: "Multiple hosts".into(),
207 })))
208 .expect("failed to send close message to host");
209 }
210 });
211 } else {
212 session.users.insert(sender_id);
214
215 if let Some(host_id) = session.host {
216 let host_tx = connections_reader.get(&host_id).expect("host not in connections");
217 let host_response = SignalMessage::SessionReady(session_id.clone(), sender_id);
218 let host_response = serde_json::to_string(&host_response)?;
219 host_tx.send(Message::Text(host_response)).expect("failed to send SessionReady message to host");
220 }
221 }
222 }
223 SignalMessage::SdpOffer(session_id, recipient_id, offer) => {
225 let response = SignalMessage::SdpOffer(session_id, sender_id, offer);
226 let response = serde_json::to_string(&response)?;
227 let connections_reader = connections.read().await;
228 if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
229 recipient_tx.send(Message::Text(response))?;
230 } else {
231 warn!("tried to send offer to non existing user");
232 }
233 }
234 SignalMessage::SdpAnswer(session_id, recipient_id, answer) => {
236 let response = SignalMessage::SdpAnswer(session_id, sender_id, answer);
237 let response = serde_json::to_string(&response)?;
238 let connections_reader = connections.read().await;
239 if let Some(recipient_tx) = connections_reader.get(&recipient_id) {
240 recipient_tx.send(Message::Text(response))?;
241 } else {
242 warn!("tried to send offer to non existing user");
243 }
244 }
245 SignalMessage::IceCandidate(session_id, recipient_id, candidate) => {
246 let response = SignalMessage::IceCandidate(session_id, sender_id, candidate);
247 let response = serde_json::to_string(&response)?;
248 let connections_reader = connections.read().await;
249 let recipient_tx = connections_reader.get(&recipient_id).ok_or_else(|| anyhow!("no sender for given id"))?;
250
251 recipient_tx.send(Message::Text(response))?;
252 }
253 SignalMessage::KeepAlive(user_id, status) => {
254 if status.is_host.is_some() {
255 warn!("Received ping from user {:?}", status.session_id);
256 pings.lock().unwrap().insert(user_id, Arc::new(Ping { online: true, session_id: status.session_id, metadata: status.metadata }));
257 }
258 }
259 _ => {}
260 }
261 }
262 Err(error) => {
263 error!("An error occurred: {:?} {:?}", error, msg);
264 }
265 }
266 }
267 Ok(())
268}
269
270async fn user_disconnected(user_id: UserId, connections: &Connections, sessions: &Sessions) {
271 connections.write().await.remove(&user_id);
272
273 let mut session_to_delete = None;
274 for (session_id, session) in sessions.write().await.iter_mut() {
275 if session.host == Some(user_id) {
276 session.host = None;
277 } else if session.users.contains(&user_id) {
278 session.users.remove(&user_id);
279 }
280 if session.host.is_none() && session.users.is_empty() {
281 session_to_delete = Some(session_id.clone());
282 break;
283 }
284 }
285 if let Some(session_id) = session_to_delete {
287 sessions.write().await.remove(&session_id);
288 }
289}