webrtc_p2p/
peer.rs

1use std::{
2  future::Future,
3  pin::Pin,
4  sync::{
5    atomic::{AtomicBool, Ordering},
6    Arc,
7  },
8};
9
10use serde::{Deserialize, Serialize};
11use tokio::{runtime::Handle, sync::Mutex, task::block_in_place};
12use uuid::Uuid;
13use webrtc::{
14  api::API,
15  data_channel::{
16    data_channel_init::RTCDataChannelInit, data_channel_message::DataChannelMessage, RTCDataChannel,
17  },
18  ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit},
19  peer_connection::{
20    configuration::RTCConfiguration,
21    offer_answer_options::{RTCAnswerOptions, RTCOfferOptions},
22    peer_connection_state::RTCPeerConnectionState,
23    sdp::{sdp_type::RTCSdpType, session_description::RTCSessionDescription},
24    RTCPeerConnection,
25  },
26  rtp_transceiver::{rtp_receiver::RTCRtpReceiver, RTCRtpTransceiver},
27  track::track_remote::TrackRemote,
28};
29
30use atomicoption::AtomicOption;
31
/// Construction options for [`Peer::new`]; every field is optional.
#[derive(Clone, Default)]
pub struct PeerOptions {
  /// Peer identifier; a random UUID v4 string is generated when `None`.
  pub id: Option<String>,
  // NOTE(review): not read by `Peer::new` in this file, so it currently has no effect.
  pub max_channel_message_size: Option<usize>,
  /// Label of the data channel the initiator creates; a random UUID v4 string when `None`.
  pub data_channel_name: Option<String>,
  // NOTE(review): not read by `Peer::new` in this file, so it currently has no effect.
  pub event_channel_size: Option<usize>,
  /// Configuration for the `RTCPeerConnection`; `RTCConfiguration::default()` when `None`.
  pub connection_config: Option<RTCConfiguration>,
  /// Options passed to `create_offer` (only the initiator creates offers).
  pub offer_config: Option<RTCOfferOptions>,
  /// Options passed to `create_answer` when answering a remote offer.
  pub answer_config: Option<RTCAnswerOptions>,
  /// Options passed to `create_data_channel` when this peer is the initiator.
  pub data_channel_config: Option<RTCDataChannelInit>,
}
43
/// Message exchanged between two peers over an external signaling channel.
///
/// Serialized as an internally tagged object keyed by `"type"`:
/// `{"type":"renegotiate"}` or `{"type":"candidate","candidate":{...}}`. The `SDP` variant is
/// untagged, so it uses the session description's own serialization (presumably carrying its
/// own `"type"` of offer/answer — confirm against `RTCSessionDescription`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum SignalMessage {
  /// Sent by a non-initiator that needs negotiation; the initiator reacts by creating a new offer.
  #[serde(rename = "renegotiate")]
  Renegotiate,
  /// A local ICE candidate to be added on the remote side.
  #[serde(rename = "candidate")]
  Candidate { candidate: RTCIceCandidateInit },
  /// An SDP offer or answer. Serde requires untagged variants to be declared last, so keep it here.
  #[serde(untagged)]
  SDP(RTCSessionDescription),
}
54
55pub type OnSignal = Box<
56  dyn (FnMut(SignalMessage) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync,
57>;
58pub type OnData =
59  Box<dyn (FnMut(Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
60pub type OnConnect =
61  Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
62pub type OnClose =
63  Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
64pub type OnNegotiated =
65  Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
66
/// Cheaply clonable handle to a WebRTC peer; all clones share one `PeerInner`.
#[derive(Clone)]
pub struct Peer {
  inner: Arc<PeerInner>,
}

// NOTE(review): these impls assert thread-safety without any `// SAFETY:` justification.
// If every field of `PeerInner` is already `Send + Sync`, the compiler derives both traits
// automatically and these impls are redundant. If some field is not (e.g. `AtomicOption`,
// whose implementation is not visible here), they may be unsound. Confirm, then either
// delete them or document the invariant that makes them sound.
unsafe impl Send for Peer {}
unsafe impl Sync for Peer {}
74
/// Shared state behind every clone of a [`Peer`].
pub struct PeerInner {
  // Identifier returned by `Peer::get_id`.
  id: String,
  // Factory used to create the peer connection.
  api: Arc<API>,
  // Set by `Peer::init`; the initiator creates offers and the data channel.
  initiator: Arc<AtomicBool>,
  // Cloned into every `new_peer_connection` call.
  connection_config: RTCConfiguration,
  // Created lazily, taken out again on close.
  connection: AtomicOption<RTCPeerConnection>,
  // Passed to `create_offer`.
  offer_config: Option<RTCOfferOptions>,
  // Passed to `create_answer`.
  answer_config: Option<RTCAnswerOptions>,
  // Label of the data channel the initiator creates.
  data_channel_name: String,
  // Created (initiator) or received (responder) channel; taken out on close.
  data_channel: AtomicOption<Arc<RTCDataChannel>>,
  // Passed to `create_data_channel`.
  data_channel_config: Option<RTCDataChannelInit>,
  // Remote ICE candidates that arrived before a remote description; added once one is set.
  pending_candidates: Mutex<Vec<RTCIceCandidateInit>>,
  // User callbacks; each is locked for the duration of its invocation.
  on_signal: Arc<Mutex<Option<OnSignal>>>,
  on_data: Arc<Mutex<Option<OnData>>>,
  on_connect: Arc<Mutex<Option<OnConnect>>>,
  on_close: Arc<Mutex<Option<OnClose>>>,
  on_negotiated: Arc<Mutex<Option<OnNegotiated>>>,
}
93
94impl Peer {
95  pub fn new(api: Arc<API>, options: PeerOptions) -> Self {
96    Self {
97      inner: Arc::new(PeerInner {
98        id: options.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
99        api,
100        initiator: Arc::new(AtomicBool::new(false)),
101        data_channel_name: options
102          .data_channel_name
103          .unwrap_or_else(|| Uuid::new_v4().to_string()),
104        connection: AtomicOption::none(),
105        connection_config: options.connection_config.unwrap_or_default(),
106        offer_config: options.offer_config,
107        answer_config: options.answer_config,
108        data_channel_config: options.data_channel_config,
109        data_channel: AtomicOption::none(),
110        pending_candidates: Mutex::new(Vec::new()),
111        on_signal: Arc::new(Mutex::new(None)),
112        on_data: Arc::new(Mutex::new(None)),
113        on_connect: Arc::new(Mutex::new(None)),
114        on_close: Arc::new(Mutex::new(None)),
115        on_negotiated: Arc::new(Mutex::new(None)),
116      }),
117    }
118  }
119
120  pub fn get_id(&self) -> &str {
121    &self.inner.id
122  }
123
124  pub fn get_data_channel(&self) -> Option<Arc<RTCDataChannel>> {
125    self
126      .inner
127      .data_channel
128      .as_ref(Ordering::Relaxed)
129      .map(Clone::clone)
130  }
131
132  pub fn get_connection(&self) -> Option<&RTCPeerConnection> {
133    self.inner.connection.as_ref(Ordering::Relaxed)
134  }
135
136  pub async fn init(&self) -> Result<(), webrtc::Error> {
137    self.inner.initiator.store(true, Ordering::SeqCst);
138    self.create_peer().await
139  }
140
141  pub fn on_signal(&self, callback: OnSignal) {
142    block_in_place(|| {
143      Handle::current()
144        .block_on(self.inner.on_signal.lock())
145        .replace(callback)
146    });
147  }
148
149  pub fn on_data(&self, callback: OnData) {
150    block_in_place(|| {
151      Handle::current()
152        .block_on(self.inner.on_data.lock())
153        .replace(callback)
154    });
155  }
156
157  pub fn on_connect(&self, callback: OnConnect) {
158    block_in_place(|| {
159      Handle::current()
160        .block_on(self.inner.on_connect.lock())
161        .replace(callback)
162    });
163  }
164
165  pub fn on_close(&self, callback: OnClose) {
166    block_in_place(|| {
167      Handle::current()
168        .block_on(self.inner.on_close.lock())
169        .replace(callback)
170    });
171  }
172
173  pub fn on_negotiated(&self, callback: OnNegotiated) {
174    block_in_place(|| {
175      Handle::current()
176        .block_on(self.inner.on_negotiated.lock())
177        .replace(callback)
178    });
179  }
180
181  async fn create_peer(&self) -> Result<(), webrtc::Error> {
182    let api = self.inner.api.clone();
183    let connection =
184      self
185        .inner
186        .connection
187        .load_or_store_with(Ordering::SeqCst, Ordering::SeqCst, || {
188          block_in_place(|| {
189            Handle::current()
190              .block_on(api.new_peer_connection(self.inner.connection_config.clone()))
191              .expect("failed to create peer connection")
192          })
193        });
194
195    let on_negotiation_needed_peer = self.clone();
196    connection.on_negotiation_needed(Box::new(move || {
197      let pinned_peer = on_negotiation_needed_peer.clone();
198      Box::pin(async move {
199        pinned_peer.on_negotiation_needed().await;
200      })
201    }));
202    let on_peer_connection_state_change_peer = self.clone();
203    connection.on_peer_connection_state_change(Box::new(move |connection_state| {
204      let pinned_peer = on_peer_connection_state_change_peer.clone();
205      Box::pin(async move {
206        pinned_peer
207          .on_peer_connection_state_change(connection_state)
208          .await;
209      })
210    }));
211    let on_ice_candidate_peer = self.clone();
212    connection.on_ice_candidate(Box::new(move |candidate| {
213      let pinned_peer = on_ice_candidate_peer.clone();
214      Box::pin(async move {
215        pinned_peer.on_ice_candidate(candidate).await;
216      })
217    }));
218    let on_track_peer = self.clone();
219    connection.on_track(Box::new(move |track, receiver, transceiver| {
220      let pinned_peer = on_track_peer.clone();
221      Box::pin(async move {
222        pinned_peer.on_track(track, receiver, transceiver).await;
223      })
224    }));
225
226    if self.inner.initiator.load(Ordering::Relaxed) {
227      self.on_data_channel(
228        connection
229          .create_data_channel(
230            &self.inner.data_channel_name,
231            self.inner.data_channel_config.clone(),
232          )
233          .await?,
234      );
235    } else {
236      let peer = self.clone();
237      connection.on_data_channel(Box::new(move |data_channel| {
238        peer.on_data_channel(data_channel);
239        Box::pin(async move {})
240      }));
241    }
242
243    Ok(())
244  }
245
246  pub async fn close(&self) -> Result<(), webrtc::Error> {
247    self.internal_close(true).await
248  }
249
250  async fn on_negotiation_needed(&self) {
251    match self.negotiate().await {
252      Ok(_) => {}
253      Err(error) => {
254        eprintln!("error negotiating: {}", error)
255      }
256    }
257  }
258
259  async fn on_peer_connection_state_change(&self, connection_state: RTCPeerConnectionState) {
260    match connection_state {
261      RTCPeerConnectionState::Closed
262      | RTCPeerConnectionState::Failed
263      | RTCPeerConnectionState::Disconnected => match self.internal_close(true).await {
264        Ok(_) => {}
265        Err(error) => {
266          eprintln!("error peer connection change: {}", error)
267        }
268      },
269      _state => {}
270    }
271  }
272  async fn on_ice_candidate(&self, candidate: Option<RTCIceCandidate>) {
273    if let Some(candidate) = candidate {
274      let candidate = match candidate.to_json() {
275        Ok(candidate) => candidate,
276        Err(error) => {
277          eprintln!("error ice candidate: {}", error);
278          return;
279        }
280      };
281      self
282        .internal_on_signal(SignalMessage::Candidate { candidate })
283        .await;
284    }
285  }
286  async fn on_track(
287    &self,
288    track: Arc<TrackRemote>,
289    receiver: Arc<RTCRtpReceiver>,
290    transceiver: Arc<RTCRtpTransceiver>,
291  ) {
292    println!(
293      "{}: track: {:?} {:?} {:?}",
294      self.get_id(),
295      track,
296      receiver,
297      transceiver
298    );
299  }
300
301  fn on_data_channel(&self, data_channel: Arc<RTCDataChannel>) {
302    let on_open_peer = self.clone();
303    data_channel.on_open(Box::new(move || {
304      let pinned_peer = on_open_peer.clone();
305      Box::pin(async move {
306        pinned_peer.on_data_channel_open().await;
307      })
308    }));
309    let on_message_peer = self.clone();
310    data_channel.on_message(Box::new(move |msg| {
311      let pinned_peer = on_message_peer.clone();
312      Box::pin(async move {
313        pinned_peer.on_data_channel_message(msg).await;
314      })
315    }));
316    let on_error_peer = self.clone();
317    data_channel.on_error(Box::new(move |error| {
318      let pinned_peer = on_error_peer.clone();
319      Box::pin(async move {
320        pinned_peer.on_data_channel_error(error).await;
321      })
322    }));
323    self
324      .inner
325      .data_channel
326      .store(Ordering::Relaxed, data_channel);
327  }
328
329  async fn on_data_channel_open(&self) {
330    self.internal_on_connect().await;
331  }
332
333  async fn on_data_channel_message(&self, msg: DataChannelMessage) {
334    self.internal_on_data(msg.data.to_vec()).await;
335  }
336
337  async fn on_data_channel_error(&self, error: webrtc::Error) {
338    eprintln!("data channel error: {}", error);
339  }
340
341  async fn internal_close(&self, emit: bool) -> Result<(), webrtc::Error> {
342    if let Some(channel) = self.inner.data_channel.take(Ordering::SeqCst) {
343      channel.close().await?;
344    }
345    if let Some(connection) = self.inner.connection.take(Ordering::SeqCst) {
346      connection.close().await?;
347    }
348    if emit {
349      self.internal_on_close().await;
350    }
351    Ok(())
352  }
353
354  pub async fn signal(&self, msg: SignalMessage) -> Result<(), webrtc::Error> {
355    if self.inner.connection.is_none(Ordering::Relaxed) {
356      self.create_peer().await?;
357    }
358
359    match msg {
360      SignalMessage::Renegotiate => self.negotiate().await,
361      SignalMessage::Candidate { candidate } => {
362        if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
363          if connection.remote_description().await.is_some() {
364            return connection.add_ice_candidate(candidate).await;
365          }
366        }
367        self.inner.pending_candidates.lock().await.push(candidate);
368        Ok(())
369      }
370      SignalMessage::SDP(sdp) => {
371        if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
372          let kind = sdp.sdp_type.clone();
373          connection.set_remote_description(sdp).await?;
374          for pending_candidate in self.inner.pending_candidates.lock().await.drain(..) {
375            connection.add_ice_candidate(pending_candidate).await?;
376          }
377          if kind == RTCSdpType::Offer {
378            self.create_answer().await?;
379          }
380          self.internal_on_negotiated().await;
381          Ok(())
382        } else {
383          Err(webrtc::Error::ErrConnectionClosed)
384        }
385      }
386    }
387  }
388
389  async fn create_offer(&self) -> Result<(), webrtc::Error> {
390    if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
391      let offer = connection
392        .create_offer(self.inner.offer_config.clone())
393        .await?;
394      connection.set_local_description(offer.clone()).await?;
395      self.internal_on_signal(SignalMessage::SDP(offer)).await;
396    }
397    Ok(())
398  }
399
400  async fn create_answer(&self) -> Result<(), webrtc::Error> {
401    if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
402      let answer = connection
403        .create_answer(self.inner.answer_config.clone())
404        .await?;
405      connection.set_local_description(answer.clone()).await?;
406      self.internal_on_signal(SignalMessage::SDP(answer)).await;
407    }
408    Ok(())
409  }
410
411  async fn negotiate(&self) -> Result<(), webrtc::Error> {
412    if self.inner.initiator.load(Ordering::Relaxed) {
413      return self.create_offer().await;
414    }
415    self.internal_on_signal(SignalMessage::Renegotiate).await;
416    Ok(())
417  }
418
419  async fn internal_on_signal(&self, signal: SignalMessage) {
420    if let Some(on_signal) = self.inner.on_signal.lock().await.as_mut() {
421      on_signal(signal).await;
422    }
423  }
424  async fn internal_on_data(&self, data: Vec<u8>) {
425    if let Some(on_data) = self.inner.on_data.lock().await.as_mut() {
426      on_data(data).await;
427    }
428  }
429  async fn internal_on_connect(&self) {
430    if let Some(on_connect) = self.inner.on_connect.lock().await.as_mut() {
431      on_connect().await;
432    }
433  }
434  async fn internal_on_close(&self) {
435    if let Some(on_close) = self.inner.on_close.lock().await.as_mut() {
436      on_close().await;
437    }
438  }
439  async fn internal_on_negotiated(&self) {
440    if let Some(on_negotiated) = self.inner.on_negotiated.lock().await.as_mut() {
441      on_negotiated().await;
442    }
443  }
444}
445
#[cfg(test)]
mod test {
  use webrtc::{
    api::{
      interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder,
    },
    ice_transport::ice_server::RTCIceServer,
    interceptor::registry::Registry,
  };

  use super::*;

  /// End-to-end check over real peer connections. The two peers signal each other directly;
  /// once connected, peer2 sends a text message over the data channel and peer1 must receive
  /// it unchanged.
  ///
  /// Runs on a multi-threaded runtime because the callback setters may use `block_in_place`.
  #[tokio::test(flavor = "multi_thread")]
  async fn basic() -> Result<(), webrtc::Error> {
    let mut media_engine = MediaEngine::default();
    let registry = register_default_interceptors(Registry::new(), &mut media_engine)?;

    let api = Arc::new(
      APIBuilder::new()
        .with_media_engine(media_engine)
        .with_interceptor_registry(registry)
        .build(),
    );

    let options = PeerOptions {
      connection_config: Some(RTCConfiguration {
        ice_servers: vec![RTCIceServer {
          ..Default::default()
        }],
        ..Default::default()
      }),
      ..Default::default()
    };

    let peer1 = Peer::new(
      api.clone(),
      PeerOptions {
        id: Some("peer1".to_string()),
        ..options.clone()
      },
    );
    let peer2 = Peer::new(
      api,
      PeerOptions {
        id: Some("peer2".to_string()),
        ..options
      },
    );

    // Wire the peers' signaling directly to each other.
    let on_signal_peer2 = peer2.clone();
    peer1.on_signal(Box::new(move |signal| {
      let pinned_peer2 = on_signal_peer2.clone();
      Box::pin(async move {
        pinned_peer2
          .signal(signal)
          .await
          .expect("failed to signal peer2");
      })
    }));

    let on_signal_peer1 = peer1.clone();
    peer2.on_signal(Box::new(move |signal| {
      let pinned_peer1 = on_signal_peer1.clone();
      Box::pin(async move {
        pinned_peer1
          .signal(signal)
          .await
          .expect("failed to signal peer1");
      })
    }));

    let (connect_sender, mut connect_receiver) = tokio::sync::mpsc::channel::<()>(1);
    peer2.on_connect(Box::new(move || {
      let pinned_connect_sender = connect_sender.clone();
      Box::pin(async move {
        pinned_connect_sender
          .send(())
          .await
          .expect("failed to send connect");
      })
    }));

    let (message_sender, mut message_receiver) = tokio::sync::mpsc::channel::<Vec<u8>>(1);
    peer1.on_data(Box::new(move |data| {
      let pinned_message_sender = message_sender.clone();
      Box::pin(async move {
        pinned_message_sender
          .send(data)
          .await
          .expect("failed to send message");
      })
    }));

    peer1.init().await?;

    connect_receiver
      .recv()
      .await
      .expect("connect channel closed before peer2 connected");
    // Fail loudly instead of silently skipping the send, which would leave the receive
    // below waiting forever.
    let data_channel = peer2
      .get_data_channel()
      .expect("peer2 has no data channel after connecting");
    data_channel.send_text("Hello, world!").await?;

    let data = message_receiver
      .recv()
      .await
      .expect("failed to receive message from peer2");

    assert_eq!(String::from_utf8_lossy(data.as_ref()), "Hello, world!");

    peer1.close().await?;
    peer2.close().await?;

    Ok(())
  }
}