async_datachannel/
lib.rs

//! Async wrapper for the [`datachannel-rs`] crate.
//!
//! [`datachannel-rs`]: https://crates.io/crates/datachannel
4use std::{sync::Arc, task::Poll};
5
6use anyhow::Context;
7pub use datachannel::{
8    ConnectionState, DataChannelInit, IceCandidate, Reliability, RtcConfig, SessionDescription,
9};
10use datachannel::{
11    DataChannelHandler, DataChannelInfo, PeerConnectionHandler, RtcDataChannel, RtcPeerConnection,
12};
13use futures::{
14    channel::mpsc,
15    io::{AsyncRead, AsyncWrite},
16    StreamExt,
17};
18use parking_lot::Mutex;
19#[cfg(feature = "derive")]
20use serde::{Deserialize, Serialize};
21use tokio::task::JoinHandle;
22use tracing::{debug, error};
23
24#[derive(Debug)]
25#[cfg_attr(feature = "derive", derive(Serialize, Deserialize))]
26#[cfg_attr(feature = "derive", serde(untagged))]
27/// Messages to be used for external signalling.
28#[allow(clippy::large_enum_variant)]
29pub enum Message {
30    RemoteDescription(SessionDescription),
31    RemoteCandidate(IceCandidate),
32}
33
34struct DataChannel {
35    tx_ready: mpsc::Sender<anyhow::Result<()>>,
36    tx_inbound: mpsc::Sender<anyhow::Result<Vec<u8>>>,
37}
38#[allow(clippy::type_complexity)]
39impl DataChannel {
40    fn new() -> (
41        mpsc::Receiver<anyhow::Result<()>>,
42        mpsc::Receiver<anyhow::Result<Vec<u8>>>,
43        Self,
44    ) {
45        let (tx_ready, rx_ready) = mpsc::channel(1);
46        let (tx_inbound, rx_inbound) = mpsc::channel(128);
47        (
48            rx_ready,
49            rx_inbound,
50            Self {
51                tx_ready,
52                tx_inbound,
53            },
54        )
55    }
56}
57
58impl DataChannelHandler for DataChannel {
59    fn on_open(&mut self) {
60        debug!("on_open");
61        // Signal open
62        let _ = self.tx_ready.try_send(Ok(()));
63    }
64
65    fn on_closed(&mut self) {
66        debug!("on_closed");
67        let _ = self.tx_inbound.try_send(Err(anyhow::anyhow!("Closed")));
68    }
69
70    fn on_error(&mut self, err: &str) {
71        let _ = self
72            .tx_ready
73            .try_send(Err(anyhow::anyhow!(err.to_string())));
74        let _ = self
75            .tx_inbound
76            .try_send(Err(anyhow::anyhow!(err.to_string())));
77    }
78
79    fn on_message(&mut self, msg: &[u8]) {
80        let s = String::from_utf8_lossy(msg);
81        debug!("on_message {}", s);
82        let _ = self.tx_inbound.try_send(Ok(msg.to_vec()));
83    }
84
85    // TODO?
86    fn on_buffered_amount_low(&mut self) {}
87
88    fn on_available(&mut self) {
89        debug!("on_available");
90    }
91}
92
93/// The opened data channel. This struct implements both [`AsyncRead`] and [`AsyncWrite`].
94pub struct DataStream {
95    /// The actual data channel
96    inner: Box<RtcDataChannel<DataChannel>>,
97    /// Receiver for inbound bytes from the data channel
98    rx_inbound: mpsc::Receiver<anyhow::Result<Vec<u8>>>,
99    /// Intermediate buffer of inbound bytes, to be polled by `poll_read`
100    buf_inbound: Vec<u8>,
101    /// Reference to the PeerConnection to keep around
102    peer_con: Option<Arc<Mutex<Box<RtcPeerConnection<ConnInternal>>>>>,
103    /// Reference to the outbound piper
104    handle: Option<JoinHandle<()>>,
105}
106
107impl DataStream {
108    pub fn buffered_amount(&self) -> usize {
109        self.inner.buffered_amount()
110    }
111
112    pub fn available_amount(&self) -> usize {
113        self.inner.available_amount()
114    }
115}
116
117impl AsyncRead for DataStream {
118    fn poll_read(
119        mut self: std::pin::Pin<&mut Self>,
120        cx: &mut std::task::Context<'_>,
121        buf: &mut [u8],
122    ) -> std::task::Poll<std::io::Result<usize>> {
123        if !self.buf_inbound.is_empty() {
124            let space = buf.len();
125            if self.buf_inbound.len() <= space {
126                let len = self.buf_inbound.len();
127                buf[..len].copy_from_slice(&self.buf_inbound[..]);
128                self.buf_inbound.drain(..);
129                Poll::Ready(Ok(len))
130            } else {
131                buf.copy_from_slice(&self.buf_inbound[..space]);
132                self.buf_inbound.drain(..space);
133                Poll::Ready(Ok(space))
134            }
135        } else {
136            match self.as_mut().rx_inbound.poll_next_unpin(cx) {
137                std::task::Poll::Ready(Some(Ok(x))) => {
138                    let space = buf.len();
139                    if x.len() <= space {
140                        buf[..x.len()].copy_from_slice(&x[..]);
141                        Poll::Ready(Ok(x.len()))
142                    } else {
143                        buf.copy_from_slice(&x[..space]);
144                        self.buf_inbound.extend_from_slice(&x[space..]);
145                        Poll::Ready(Ok(space))
146                    }
147                }
148                std::task::Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new(
149                    std::io::ErrorKind::Other,
150                    e.to_string(),
151                ))),
152                std::task::Poll::Ready(None) => Poll::Ready(Ok(0)),
153                Poll::Pending => Poll::Pending,
154            }
155        }
156    }
157}
158
159impl AsyncWrite for DataStream {
160    fn poll_write(
161        mut self: std::pin::Pin<&mut Self>,
162        _cx: &mut std::task::Context<'_>,
163        buf: &[u8],
164    ) -> std::task::Poll<Result<usize, std::io::Error>> {
165        // TODO: Maybe query the underlying buffer to signal backpressure
166        if let Err(e) = self.as_mut().inner.send(buf) {
167            Poll::Ready(Err(std::io::Error::new(
168                std::io::ErrorKind::Other,
169                e.to_string(),
170            )))
171        } else {
172            Poll::Ready(Ok(buf.len()))
173        }
174    }
175
176    fn poll_flush(
177        self: std::pin::Pin<&mut Self>,
178        _cx: &mut std::task::Context<'_>,
179    ) -> std::task::Poll<Result<(), std::io::Error>> {
180        Poll::Ready(Ok(()))
181    }
182
183    fn poll_close(
184        self: std::pin::Pin<&mut Self>,
185        _cx: &mut std::task::Context<'_>,
186    ) -> std::task::Poll<Result<(), std::io::Error>> {
187        Poll::Ready(Ok(()))
188    }
189}
190
191pub struct PeerConnection {
192    peer_con: Arc<Mutex<Box<RtcPeerConnection<ConnInternal>>>>,
193    rx_incoming: mpsc::Receiver<DataStream>,
194    handle: JoinHandle<()>,
195}
196
197impl PeerConnection {
198    /// Create a new [`PeerConnection`] to be used for either dialing or accepting an inbound
199    /// connection. The channel tuple is used to interface with an external signalling system.
200    pub fn new(
201        config: &RtcConfig,
202        (sig_tx, mut sig_rx): (mpsc::Sender<Message>, mpsc::Receiver<Message>),
203    ) -> anyhow::Result<Self> {
204        let (tx_incoming, rx_incoming) = mpsc::channel(8);
205        let peer_con = Arc::new(Mutex::new(RtcPeerConnection::new(
206            config,
207            ConnInternal {
208                tx_signal: sig_tx,
209                tx_incoming,
210                pending: None,
211            },
212        )?));
213        let pc = peer_con.clone();
214        let handle = tokio::spawn(async move {
215            while let Some(m) = sig_rx.next().await {
216                if let Err(err) = match m {
217                    Message::RemoteDescription(i) => pc.lock().set_remote_description(&i),
218                    Message::RemoteCandidate(i) => pc.lock().add_remote_candidate(&i),
219                } {
220                    error!(?err, "Error interacting with RtcPeerConnection");
221                }
222            }
223        });
224        Ok(Self {
225            peer_con,
226            rx_incoming,
227            handle,
228        })
229    }
230
231    /// Wait for an inbound connection.
232    pub async fn accept(mut self) -> anyhow::Result<DataStream> {
233        let mut s = self.rx_incoming.next().await.context("Tx dropped")?;
234        s.handle = Some(self.handle);
235        s.peer_con = Some(self.peer_con);
236        Ok(s)
237    }
238
239    /// Initiate an outbound dialing.
240    pub async fn dial(self, label: &str) -> anyhow::Result<DataStream> {
241        let (mut ready, rx_inbound, chan) = DataChannel::new();
242        let dc = self.peer_con.lock().create_data_channel(label, chan)?;
243        ready.next().await.context("Tx dropped")??;
244        Ok(DataStream {
245            inner: dc,
246            rx_inbound,
247            buf_inbound: vec![],
248            handle: Some(self.handle),
249            peer_con: Some(self.peer_con),
250        })
251    }
252
253    /// Initiate an outbound dialing with extra options.
254    pub async fn dial_ex(
255        self,
256        label: &str,
257        dc_init: &DataChannelInit,
258    ) -> anyhow::Result<DataStream> {
259        let (mut ready, rx_inbound, chan) = DataChannel::new();
260        let dc = self
261            .peer_con
262            .lock()
263            .create_data_channel_ex(label, chan, dc_init)?;
264        ready.next().await.context("Tx dropped")??;
265        Ok(DataStream {
266            inner: dc,
267            rx_inbound,
268            buf_inbound: vec![],
269            handle: Some(self.handle),
270            peer_con: Some(self.peer_con),
271        })
272    }
273}
274
275struct ConnInternal {
276    tx_incoming: mpsc::Sender<DataStream>,
277    tx_signal: mpsc::Sender<Message>,
278    pending: Option<mpsc::Receiver<anyhow::Result<Vec<u8>>>>,
279}
280
281impl PeerConnectionHandler for ConnInternal {
282    type DCH = DataChannel;
283
284    fn data_channel_handler(&mut self, _info: DataChannelInfo) -> Self::DCH {
285        let (_, rx, dc) = DataChannel::new();
286        self.pending.replace(rx);
287        dc
288    }
289
290    fn on_description(&mut self, sess_desc: SessionDescription) {
291        let _ = self
292            .tx_signal
293            .try_send(Message::RemoteDescription(sess_desc));
294    }
295
296    fn on_candidate(&mut self, cand: IceCandidate) {
297        let _ = self.tx_signal.try_send(Message::RemoteCandidate(cand));
298    }
299
300    fn on_connection_state_change(&mut self, _state: datachannel::ConnectionState) {
301        // TODO
302    }
303
304    fn on_data_channel(&mut self, data_channel: Box<RtcDataChannel<Self::DCH>>) {
305        debug!("new incoming data channel");
306        let _ = self.tx_incoming.try_send(DataStream {
307            inner: data_channel,
308            rx_inbound: self
309                .pending
310                .take()
311                .expect("`data_channel_handler` was just called synchronously in the same thread"),
312            buf_inbound: vec![],
313            handle: None,
314            peer_con: Default::default(),
315        });
316    }
317}