gst_plugin_webrtc_signalling/server/
mod.rs1pub use async_tungstenite;
4
5use anyhow::Error;
6use async_tungstenite::tungstenite::{
7 Message as WsMessage, Utf8Bytes,
8 handshake::server::{Callback, NoCallback},
9};
10use futures::channel::mpsc;
11use futures::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::task;
18use tracing::{debug, error, info, instrument, trace, warn};
19
20struct Peer {
21 receive_task_handle: task::JoinHandle<()>,
22 send_task_handle: task::JoinHandle<Result<(), Error>>,
23 sender: mpsc::Sender<String>,
24}
25
26struct State {
27 tx: Option<mpsc::Sender<(String, Option<Utf8Bytes>)>>,
28 peers: HashMap<String, Peer>,
29}
30
31#[derive(Clone)]
32pub struct Server {
33 state: Arc<Mutex<State>>,
34}
35
36#[derive(thiserror::Error, Debug)]
37pub enum ServerError {
38 #[error("error during handshake {0}")]
39 Handshake(#[from] async_tungstenite::tungstenite::Error),
40 #[error("error during TLS handshake {0}")]
41 TLSHandshake(#[from] std::io::Error),
42 #[error("timeout during TLS handshake {0}")]
43 TLSHandshakeTimeout(#[from] tokio::time::error::Elapsed),
44}
45
46impl Server {
47 #[instrument(level = "debug", skip(factory))]
48 pub fn spawn<
49 I: for<'a> Deserialize<'a>,
50 O: Serialize + std::fmt::Debug + Send + Sync,
51 Factory: FnOnce(Pin<Box<dyn Stream<Item = (String, Option<I>)> + Send>>) -> St,
52 St: Stream<Item = (String, O)> + Send + Unpin + 'static,
53 >(
54 factory: Factory,
55 ) -> Self {
56 let (tx, rx) = mpsc::channel::<(String, Option<Utf8Bytes>)>(1000);
57 let mut handler = factory(Box::pin(rx.filter_map(|(peer_id, msg)| async move {
58 if let Some(msg) = msg {
59 match serde_json::from_str::<I>(&msg) {
60 Ok(msg) => Some((peer_id, Some(msg))),
61 Err(err) => {
62 warn!("Failed to parse incoming message: {} ({})", err, msg);
63 None
64 }
65 }
66 } else {
67 Some((peer_id, None))
68 }
69 })));
70
71 let state = Arc::new(Mutex::new(State {
72 tx: Some(tx),
73 peers: HashMap::new(),
74 }));
75
76 let state_clone = state.clone();
77 task::spawn(async move {
78 while let Some((peer_id, msg)) = handler.next().await {
79 match serde_json::to_string(&msg) {
80 Ok(msg_str) => {
81 let sender = {
82 let mut state = state_clone.lock().unwrap();
83 if let Some(peer) = state.peers.get_mut(&peer_id) {
84 Some(peer.sender.clone())
85 } else {
86 None
87 }
88 };
89
90 if let Some(mut sender) = sender {
91 trace!("Sending {}", msg_str);
92 let _ = sender.send(msg_str).await;
93 }
94 }
95 Err(err) => {
96 warn!("Failed to serialize outgoing message: {}", err);
97 }
98 }
99 }
100 });
101
102 Self { state }
103 }
104
105 #[instrument(level = "debug", skip(state))]
106 fn remove_peer(state: Arc<Mutex<State>>, peer_id: &str) {
107 if let Some(mut peer) = state.lock().unwrap().peers.remove(peer_id) {
108 let peer_id = peer_id.to_string();
109 task::spawn(async move {
110 peer.sender.close_channel();
111 if let Err(err) = peer.send_task_handle.await {
112 trace!(peer_id = %peer_id, "Error while joining send task: {}", err);
113 }
114
115 if let Err(err) = peer.receive_task_handle.await {
116 trace!(peer_id = %peer_id, "Error while joining receive task: {}", err);
117 }
118 });
119 }
120 }
121
122 #[instrument(level = "debug", skip(self, stream, callback))]
123 pub async fn accept_hdr_async<
124 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
125 C: Callback + Unpin,
126 >(
127 &mut self,
128 stream: S,
129 callback: C,
130 ) -> Result<String, ServerError> {
131 let ws = match async_tungstenite::tokio::accept_hdr_async(stream, callback).await {
132 Ok(ws) => ws,
133 Err(err) => {
134 warn!("Error during the websocket handshake: {}", err);
135 return Err(ServerError::Handshake(err));
136 }
137 };
138
139 let this_id = uuid::Uuid::new_v4().to_string();
140 info!(this_id = %this_id, "New WebSocket connection");
141
142 let (websocket_sender, mut websocket_receiver) = mpsc::channel::<String>(1000);
145
146 let this_id_clone = this_id.clone();
147 let (mut ws_sink, mut ws_stream) = ws.split();
148 let send_task_handle = task::spawn(async move {
149 let mut res = Ok(());
150 loop {
151 match tokio::time::timeout(
152 std::time::Duration::from_secs(30),
153 websocket_receiver.next(),
154 )
155 .await
156 {
157 Ok(Some(msg)) => {
158 trace!(this_id = %this_id_clone, "sending {}", msg);
159 res = ws_sink.send(WsMessage::text(msg)).await;
160 }
161 Ok(None) => {
162 break;
163 }
164 Err(_) => {
165 trace!(this_id = %this_id_clone, "timeout, sending ping");
166 res = ws_sink.send(WsMessage::Ping(Default::default())).await;
167 }
168 }
169
170 if let Err(ref err) = res {
171 error!(this_id = %this_id_clone, "Quitting send loop: {err}");
172 break;
173 }
174 }
175
176 debug!(this_id = %this_id_clone, "Done sending");
177
178 let _ = ws_sink.close(None).await;
179
180 res.map_err(Into::into)
181 });
182
183 let mut tx = self.state.lock().unwrap().tx.clone();
184 let this_id_clone = this_id.clone();
185 let state_clone = self.state.clone();
186 let receive_task_handle = task::spawn(async move {
187 if let Some(tx) = tx.as_mut()
188 && let Err(err) = tx
189 .send((
190 this_id_clone.clone(),
191 Some(
192 serde_json::json!({
193 "type": "newPeer",
194 })
195 .to_string()
196 .into(),
197 ),
198 ))
199 .await
200 {
201 warn!(this = %this_id_clone, "Error handling message: {:?}", err);
202 }
203 while let Some(msg) = ws_stream.next().await {
204 info!("Received message {msg:?}");
205 match msg {
206 Ok(WsMessage::Text(msg)) => {
207 if let Some(tx) = tx.as_mut()
208 && let Err(err) = tx.send((this_id_clone.clone(), Some(msg))).await
209 {
210 warn!(this = %this_id_clone, "Error handling message: {:?}", err);
211 }
212 }
213 Ok(WsMessage::Close(reason)) => {
214 info!(this_id = %this_id_clone, "connection closed: {:?}", reason);
215 break;
216 }
217 Ok(WsMessage::Pong(_)) => {
218 continue;
219 }
220 Ok(_) => warn!(this_id = %this_id_clone, "Unsupported message type"),
221 Err(err) => {
222 warn!(this_id = %this_id_clone, "recv error: {}", err);
223 break;
224 }
225 }
226 }
227
228 if let Some(tx) = tx.as_mut() {
229 let _ = tx.send((this_id_clone.clone(), None)).await;
230 }
231
232 Self::remove_peer(state_clone, &this_id_clone);
233 });
234
235 self.state.lock().unwrap().peers.insert(
236 this_id.clone(),
237 Peer {
238 receive_task_handle,
239 send_task_handle,
240 sender: websocket_sender,
241 },
242 );
243
244 Ok(this_id)
245 }
246
247 #[instrument(level = "debug", skip(self, stream))]
248 pub async fn accept_async<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
249 &mut self,
250 stream: S,
251 ) -> Result<String, ServerError> {
252 self.accept_hdr_async(stream, NoCallback).await
253 }
254}