1use std::{sync::Arc, task::Poll};
5
6use anyhow::Context;
7pub use datachannel::{
8 ConnectionState, DataChannelInit, IceCandidate, Reliability, RtcConfig, SessionDescription,
9};
10use datachannel::{
11 DataChannelHandler, DataChannelInfo, PeerConnectionHandler, RtcDataChannel, RtcPeerConnection,
12};
13use futures::{
14 channel::mpsc,
15 io::{AsyncRead, AsyncWrite},
16 StreamExt,
17};
18use parking_lot::Mutex;
19#[cfg(feature = "derive")]
20use serde::{Deserialize, Serialize};
21use tokio::task::JoinHandle;
22use tracing::{debug, error};
23
24#[derive(Debug)]
25#[cfg_attr(feature = "derive", derive(Serialize, Deserialize))]
26#[cfg_attr(feature = "derive", serde(untagged))]
27#[allow(clippy::large_enum_variant)]
29pub enum Message {
30 RemoteDescription(SessionDescription),
31 RemoteCandidate(IceCandidate),
32}
33
34struct DataChannel {
35 tx_ready: mpsc::Sender<anyhow::Result<()>>,
36 tx_inbound: mpsc::Sender<anyhow::Result<Vec<u8>>>,
37}
38#[allow(clippy::type_complexity)]
39impl DataChannel {
40 fn new() -> (
41 mpsc::Receiver<anyhow::Result<()>>,
42 mpsc::Receiver<anyhow::Result<Vec<u8>>>,
43 Self,
44 ) {
45 let (tx_ready, rx_ready) = mpsc::channel(1);
46 let (tx_inbound, rx_inbound) = mpsc::channel(128);
47 (
48 rx_ready,
49 rx_inbound,
50 Self {
51 tx_ready,
52 tx_inbound,
53 },
54 )
55 }
56}
57
58impl DataChannelHandler for DataChannel {
59 fn on_open(&mut self) {
60 debug!("on_open");
61 let _ = self.tx_ready.try_send(Ok(()));
63 }
64
65 fn on_closed(&mut self) {
66 debug!("on_closed");
67 let _ = self.tx_inbound.try_send(Err(anyhow::anyhow!("Closed")));
68 }
69
70 fn on_error(&mut self, err: &str) {
71 let _ = self
72 .tx_ready
73 .try_send(Err(anyhow::anyhow!(err.to_string())));
74 let _ = self
75 .tx_inbound
76 .try_send(Err(anyhow::anyhow!(err.to_string())));
77 }
78
79 fn on_message(&mut self, msg: &[u8]) {
80 let s = String::from_utf8_lossy(msg);
81 debug!("on_message {}", s);
82 let _ = self.tx_inbound.try_send(Ok(msg.to_vec()));
83 }
84
85 fn on_buffered_amount_low(&mut self) {}
87
88 fn on_available(&mut self) {
89 debug!("on_available");
90 }
91}
92
93pub struct DataStream {
95 inner: Box<RtcDataChannel<DataChannel>>,
97 rx_inbound: mpsc::Receiver<anyhow::Result<Vec<u8>>>,
99 buf_inbound: Vec<u8>,
101 peer_con: Option<Arc<Mutex<Box<RtcPeerConnection<ConnInternal>>>>>,
103 handle: Option<JoinHandle<()>>,
105}
106
107impl DataStream {
108 pub fn buffered_amount(&self) -> usize {
109 self.inner.buffered_amount()
110 }
111
112 pub fn available_amount(&self) -> usize {
113 self.inner.available_amount()
114 }
115}
116
117impl AsyncRead for DataStream {
118 fn poll_read(
119 mut self: std::pin::Pin<&mut Self>,
120 cx: &mut std::task::Context<'_>,
121 buf: &mut [u8],
122 ) -> std::task::Poll<std::io::Result<usize>> {
123 if !self.buf_inbound.is_empty() {
124 let space = buf.len();
125 if self.buf_inbound.len() <= space {
126 let len = self.buf_inbound.len();
127 buf[..len].copy_from_slice(&self.buf_inbound[..]);
128 self.buf_inbound.drain(..);
129 Poll::Ready(Ok(len))
130 } else {
131 buf.copy_from_slice(&self.buf_inbound[..space]);
132 self.buf_inbound.drain(..space);
133 Poll::Ready(Ok(space))
134 }
135 } else {
136 match self.as_mut().rx_inbound.poll_next_unpin(cx) {
137 std::task::Poll::Ready(Some(Ok(x))) => {
138 let space = buf.len();
139 if x.len() <= space {
140 buf[..x.len()].copy_from_slice(&x[..]);
141 Poll::Ready(Ok(x.len()))
142 } else {
143 buf.copy_from_slice(&x[..space]);
144 self.buf_inbound.extend_from_slice(&x[space..]);
145 Poll::Ready(Ok(space))
146 }
147 }
148 std::task::Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new(
149 std::io::ErrorKind::Other,
150 e.to_string(),
151 ))),
152 std::task::Poll::Ready(None) => Poll::Ready(Ok(0)),
153 Poll::Pending => Poll::Pending,
154 }
155 }
156 }
157}
158
159impl AsyncWrite for DataStream {
160 fn poll_write(
161 mut self: std::pin::Pin<&mut Self>,
162 _cx: &mut std::task::Context<'_>,
163 buf: &[u8],
164 ) -> std::task::Poll<Result<usize, std::io::Error>> {
165 if let Err(e) = self.as_mut().inner.send(buf) {
167 Poll::Ready(Err(std::io::Error::new(
168 std::io::ErrorKind::Other,
169 e.to_string(),
170 )))
171 } else {
172 Poll::Ready(Ok(buf.len()))
173 }
174 }
175
176 fn poll_flush(
177 self: std::pin::Pin<&mut Self>,
178 _cx: &mut std::task::Context<'_>,
179 ) -> std::task::Poll<Result<(), std::io::Error>> {
180 Poll::Ready(Ok(()))
181 }
182
183 fn poll_close(
184 self: std::pin::Pin<&mut Self>,
185 _cx: &mut std::task::Context<'_>,
186 ) -> std::task::Poll<Result<(), std::io::Error>> {
187 Poll::Ready(Ok(()))
188 }
189}
190
191pub struct PeerConnection {
192 peer_con: Arc<Mutex<Box<RtcPeerConnection<ConnInternal>>>>,
193 rx_incoming: mpsc::Receiver<DataStream>,
194 handle: JoinHandle<()>,
195}
196
197impl PeerConnection {
198 pub fn new(
201 config: &RtcConfig,
202 (sig_tx, mut sig_rx): (mpsc::Sender<Message>, mpsc::Receiver<Message>),
203 ) -> anyhow::Result<Self> {
204 let (tx_incoming, rx_incoming) = mpsc::channel(8);
205 let peer_con = Arc::new(Mutex::new(RtcPeerConnection::new(
206 config,
207 ConnInternal {
208 tx_signal: sig_tx,
209 tx_incoming,
210 pending: None,
211 },
212 )?));
213 let pc = peer_con.clone();
214 let handle = tokio::spawn(async move {
215 while let Some(m) = sig_rx.next().await {
216 if let Err(err) = match m {
217 Message::RemoteDescription(i) => pc.lock().set_remote_description(&i),
218 Message::RemoteCandidate(i) => pc.lock().add_remote_candidate(&i),
219 } {
220 error!(?err, "Error interacting with RtcPeerConnection");
221 }
222 }
223 });
224 Ok(Self {
225 peer_con,
226 rx_incoming,
227 handle,
228 })
229 }
230
231 pub async fn accept(mut self) -> anyhow::Result<DataStream> {
233 let mut s = self.rx_incoming.next().await.context("Tx dropped")?;
234 s.handle = Some(self.handle);
235 s.peer_con = Some(self.peer_con);
236 Ok(s)
237 }
238
239 pub async fn dial(self, label: &str) -> anyhow::Result<DataStream> {
241 let (mut ready, rx_inbound, chan) = DataChannel::new();
242 let dc = self.peer_con.lock().create_data_channel(label, chan)?;
243 ready.next().await.context("Tx dropped")??;
244 Ok(DataStream {
245 inner: dc,
246 rx_inbound,
247 buf_inbound: vec![],
248 handle: Some(self.handle),
249 peer_con: Some(self.peer_con),
250 })
251 }
252
253 pub async fn dial_ex(
255 self,
256 label: &str,
257 dc_init: &DataChannelInit,
258 ) -> anyhow::Result<DataStream> {
259 let (mut ready, rx_inbound, chan) = DataChannel::new();
260 let dc = self
261 .peer_con
262 .lock()
263 .create_data_channel_ex(label, chan, dc_init)?;
264 ready.next().await.context("Tx dropped")??;
265 Ok(DataStream {
266 inner: dc,
267 rx_inbound,
268 buf_inbound: vec![],
269 handle: Some(self.handle),
270 peer_con: Some(self.peer_con),
271 })
272 }
273}
274
275struct ConnInternal {
276 tx_incoming: mpsc::Sender<DataStream>,
277 tx_signal: mpsc::Sender<Message>,
278 pending: Option<mpsc::Receiver<anyhow::Result<Vec<u8>>>>,
279}
280
281impl PeerConnectionHandler for ConnInternal {
282 type DCH = DataChannel;
283
284 fn data_channel_handler(&mut self, _info: DataChannelInfo) -> Self::DCH {
285 let (_, rx, dc) = DataChannel::new();
286 self.pending.replace(rx);
287 dc
288 }
289
290 fn on_description(&mut self, sess_desc: SessionDescription) {
291 let _ = self
292 .tx_signal
293 .try_send(Message::RemoteDescription(sess_desc));
294 }
295
296 fn on_candidate(&mut self, cand: IceCandidate) {
297 let _ = self.tx_signal.try_send(Message::RemoteCandidate(cand));
298 }
299
300 fn on_connection_state_change(&mut self, _state: datachannel::ConnectionState) {
301 }
303
304 fn on_data_channel(&mut self, data_channel: Box<RtcDataChannel<Self::DCH>>) {
305 debug!("new incoming data channel");
306 let _ = self.tx_incoming.try_send(DataStream {
307 inner: data_channel,
308 rx_inbound: self
309 .pending
310 .take()
311 .expect("`data_channel_handler` was just called synchronously in the same thread"),
312 buf_inbound: vec![],
313 handle: None,
314 peer_con: Default::default(),
315 });
316 }
317}