Skip to main content

gst_plugin_webrtc_signalling/server/
mod.rs

1// SPDX-License-Identifier: MPL-2.0
2
3pub use async_tungstenite;
4
5use anyhow::Error;
6use async_tungstenite::tungstenite::{
7    Message as WsMessage, Utf8Bytes,
8    handshake::server::{Callback, NoCallback},
9};
10use futures::channel::mpsc;
11use futures::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::task;
18use tracing::{debug, error, info, instrument, trace, warn};
19
20struct Peer {
21    receive_task_handle: task::JoinHandle<()>,
22    send_task_handle: task::JoinHandle<Result<(), Error>>,
23    sender: mpsc::Sender<String>,
24}
25
26struct State {
27    tx: Option<mpsc::Sender<(String, Option<Utf8Bytes>)>>,
28    peers: HashMap<String, Peer>,
29}
30
31#[derive(Clone)]
32pub struct Server {
33    state: Arc<Mutex<State>>,
34}
35
36#[derive(thiserror::Error, Debug)]
37pub enum ServerError {
38    #[error("error during handshake {0}")]
39    Handshake(#[from] async_tungstenite::tungstenite::Error),
40    #[error("error during TLS handshake {0}")]
41    TLSHandshake(#[from] std::io::Error),
42    #[error("timeout during TLS handshake {0}")]
43    TLSHandshakeTimeout(#[from] tokio::time::error::Elapsed),
44}
45
46impl Server {
47    #[instrument(level = "debug", skip(factory))]
48    pub fn spawn<
49        I: for<'a> Deserialize<'a>,
50        O: Serialize + std::fmt::Debug + Send + Sync,
51        Factory: FnOnce(Pin<Box<dyn Stream<Item = (String, Option<I>)> + Send>>) -> St,
52        St: Stream<Item = (String, O)> + Send + Unpin + 'static,
53    >(
54        factory: Factory,
55    ) -> Self {
56        let (tx, rx) = mpsc::channel::<(String, Option<Utf8Bytes>)>(1000);
57        let mut handler = factory(Box::pin(rx.filter_map(|(peer_id, msg)| async move {
58            if let Some(msg) = msg {
59                match serde_json::from_str::<I>(&msg) {
60                    Ok(msg) => Some((peer_id, Some(msg))),
61                    Err(err) => {
62                        warn!("Failed to parse incoming message: {} ({})", err, msg);
63                        None
64                    }
65                }
66            } else {
67                Some((peer_id, None))
68            }
69        })));
70
71        let state = Arc::new(Mutex::new(State {
72            tx: Some(tx),
73            peers: HashMap::new(),
74        }));
75
76        let state_clone = state.clone();
77        task::spawn(async move {
78            while let Some((peer_id, msg)) = handler.next().await {
79                match serde_json::to_string(&msg) {
80                    Ok(msg_str) => {
81                        let sender = {
82                            let mut state = state_clone.lock().unwrap();
83                            if let Some(peer) = state.peers.get_mut(&peer_id) {
84                                Some(peer.sender.clone())
85                            } else {
86                                None
87                            }
88                        };
89
90                        if let Some(mut sender) = sender {
91                            trace!("Sending {}", msg_str);
92                            let _ = sender.send(msg_str).await;
93                        }
94                    }
95                    Err(err) => {
96                        warn!("Failed to serialize outgoing message: {}", err);
97                    }
98                }
99            }
100        });
101
102        Self { state }
103    }
104
105    #[instrument(level = "debug", skip(state))]
106    fn remove_peer(state: Arc<Mutex<State>>, peer_id: &str) {
107        if let Some(mut peer) = state.lock().unwrap().peers.remove(peer_id) {
108            let peer_id = peer_id.to_string();
109            task::spawn(async move {
110                peer.sender.close_channel();
111                if let Err(err) = peer.send_task_handle.await {
112                    trace!(peer_id = %peer_id, "Error while joining send task: {}", err);
113                }
114
115                if let Err(err) = peer.receive_task_handle.await {
116                    trace!(peer_id = %peer_id, "Error while joining receive task: {}", err);
117                }
118            });
119        }
120    }
121
122    #[instrument(level = "debug", skip(self, stream, callback))]
123    pub async fn accept_hdr_async<
124        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
125        C: Callback + Unpin,
126    >(
127        &mut self,
128        stream: S,
129        callback: C,
130    ) -> Result<String, ServerError> {
131        let ws = match async_tungstenite::tokio::accept_hdr_async(stream, callback).await {
132            Ok(ws) => ws,
133            Err(err) => {
134                warn!("Error during the websocket handshake: {}", err);
135                return Err(ServerError::Handshake(err));
136            }
137        };
138
139        let this_id = uuid::Uuid::new_v4().to_string();
140        info!(this_id = %this_id, "New WebSocket connection");
141
142        // 1000 is completely arbitrary, we simply don't want infinite piling
143        // up of messages as with unbounded
144        let (websocket_sender, mut websocket_receiver) = mpsc::channel::<String>(1000);
145
146        let this_id_clone = this_id.clone();
147        let (mut ws_sink, mut ws_stream) = ws.split();
148        let send_task_handle = task::spawn(async move {
149            let mut res = Ok(());
150            loop {
151                match tokio::time::timeout(
152                    std::time::Duration::from_secs(30),
153                    websocket_receiver.next(),
154                )
155                .await
156                {
157                    Ok(Some(msg)) => {
158                        trace!(this_id = %this_id_clone, "sending {}", msg);
159                        res = ws_sink.send(WsMessage::text(msg)).await;
160                    }
161                    Ok(None) => {
162                        break;
163                    }
164                    Err(_) => {
165                        trace!(this_id = %this_id_clone, "timeout, sending ping");
166                        res = ws_sink.send(WsMessage::Ping(Default::default())).await;
167                    }
168                }
169
170                if let Err(ref err) = res {
171                    error!(this_id = %this_id_clone, "Quitting send loop: {err}");
172                    break;
173                }
174            }
175
176            debug!(this_id = %this_id_clone, "Done sending");
177
178            let _ = ws_sink.close(None).await;
179
180            res.map_err(Into::into)
181        });
182
183        let mut tx = self.state.lock().unwrap().tx.clone();
184        let this_id_clone = this_id.clone();
185        let state_clone = self.state.clone();
186        let receive_task_handle = task::spawn(async move {
187            if let Some(tx) = tx.as_mut()
188                && let Err(err) = tx
189                    .send((
190                        this_id_clone.clone(),
191                        Some(
192                            serde_json::json!({
193                                "type": "newPeer",
194                            })
195                            .to_string()
196                            .into(),
197                        ),
198                    ))
199                    .await
200            {
201                warn!(this = %this_id_clone, "Error handling message: {:?}", err);
202            }
203            while let Some(msg) = ws_stream.next().await {
204                info!("Received message {msg:?}");
205                match msg {
206                    Ok(WsMessage::Text(msg)) => {
207                        if let Some(tx) = tx.as_mut()
208                            && let Err(err) = tx.send((this_id_clone.clone(), Some(msg))).await
209                        {
210                            warn!(this = %this_id_clone, "Error handling message: {:?}", err);
211                        }
212                    }
213                    Ok(WsMessage::Close(reason)) => {
214                        info!(this_id = %this_id_clone, "connection closed: {:?}", reason);
215                        break;
216                    }
217                    Ok(WsMessage::Pong(_)) => {
218                        continue;
219                    }
220                    Ok(_) => warn!(this_id = %this_id_clone, "Unsupported message type"),
221                    Err(err) => {
222                        warn!(this_id = %this_id_clone, "recv error: {}", err);
223                        break;
224                    }
225                }
226            }
227
228            if let Some(tx) = tx.as_mut() {
229                let _ = tx.send((this_id_clone.clone(), None)).await;
230            }
231
232            Self::remove_peer(state_clone, &this_id_clone);
233        });
234
235        self.state.lock().unwrap().peers.insert(
236            this_id.clone(),
237            Peer {
238                receive_task_handle,
239                send_task_handle,
240                sender: websocket_sender,
241            },
242        );
243
244        Ok(this_id)
245    }
246
247    #[instrument(level = "debug", skip(self, stream))]
248    pub async fn accept_async<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
249        &mut self,
250        stream: S,
251    ) -> Result<String, ServerError> {
252        self.accept_hdr_async(stream, NoCallback).await
253    }
254}