1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8};
9
10use serde::{Deserialize, Serialize};
11use tokio::{runtime::Handle, sync::Mutex, task::block_in_place};
12use uuid::Uuid;
13use webrtc::{
14 api::API,
15 data_channel::{
16 data_channel_init::RTCDataChannelInit, data_channel_message::DataChannelMessage, RTCDataChannel,
17 },
18 ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit},
19 peer_connection::{
20 configuration::RTCConfiguration,
21 offer_answer_options::{RTCAnswerOptions, RTCOfferOptions},
22 peer_connection_state::RTCPeerConnectionState,
23 sdp::{sdp_type::RTCSdpType, session_description::RTCSessionDescription},
24 RTCPeerConnection,
25 },
26 rtp_transceiver::{rtp_receiver::RTCRtpReceiver, RTCRtpTransceiver},
27 track::track_remote::TrackRemote,
28};
29
30use atomicoption::AtomicOption;
31
/// Construction options for [`Peer::new`]; every field is optional.
#[derive(Clone, Default)]
pub struct PeerOptions {
    /// Peer identifier returned by `Peer::get_id`; a random UUID v4 string when `None`.
    pub id: Option<String>,
    /// NOTE(review): not read by `Peer::new` in this file — confirm whether it is used elsewhere.
    pub max_channel_message_size: Option<usize>,
    /// Label for the data channel the initiator creates; a random UUID v4 string when `None`.
    pub data_channel_name: Option<String>,
    /// NOTE(review): not read by `Peer::new` in this file — confirm whether it is used elsewhere.
    pub event_channel_size: Option<usize>,
    /// Configuration for the `RTCPeerConnection`; `RTCConfiguration::default()` when `None`.
    pub connection_config: Option<RTCConfiguration>,
    /// Options passed through to `create_offer`.
    pub offer_config: Option<RTCOfferOptions>,
    /// Options passed through to `create_answer`.
    pub answer_config: Option<RTCAnswerOptions>,
    /// Init parameters passed through to `create_data_channel` (initiator side only).
    pub data_channel_config: Option<RTCDataChannelInit>,
}
43
/// Signaling message exchanged between two peers over some out-of-band transport.
///
/// Serialized by serde as an internally tagged enum keyed on `"type"`:
/// `Renegotiate` becomes `{"type":"renegotiate"}` and `Candidate` becomes
/// `{"type":"candidate","candidate":{...}}`. `SDP` is marked untagged, so it is written as the
/// bare `RTCSessionDescription` (presumably `{"type":"offer"|"answer"|...,"sdp":"..."}` —
/// confirm against webrtc's serde impl). When deserializing, the untagged `SDP` variant is the
/// fallback tried when the tag matches neither `renegotiate` nor `candidate`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum SignalMessage {
    /// Sent by a non-initiator from `negotiate` to ask the initiator for a fresh offer.
    #[serde(rename = "renegotiate")]
    Renegotiate,
    /// A single local ICE candidate, sent from the connection's ICE-candidate handler.
    #[serde(rename = "candidate")]
    Candidate { candidate: RTCIceCandidateInit },
    /// An SDP offer (from `create_offer`) or answer (from `create_answer`).
    #[serde(untagged)]
    SDP(RTCSessionDescription),
}
54
55pub type OnSignal = Box<
56 dyn (FnMut(SignalMessage) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync,
57>;
58pub type OnData =
59 Box<dyn (FnMut(Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
60pub type OnConnect =
61 Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
62pub type OnClose =
63 Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
64pub type OnNegotiated =
65 Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>;
66
/// Cheaply cloneable handle to a WebRTC peer.
///
/// All clones share one [`PeerInner`] through an `Arc`; the event handlers registered on the
/// connection and data channel each capture such a clone to call back into the peer.
#[derive(Clone)]
pub struct Peer {
    inner: Arc<PeerInner>,
}

// SAFETY: NOTE(review): these impls assert thread-safety the compiler did not derive by itself,
// presumably because `AtomicOption<...>` inside `PeerInner` is not auto-`Send`/`Sync`. They are
// only sound if `AtomicOption` is genuinely safe to share across threads AND no `&T` obtained
// through `AtomicOption::as_ref` (e.g. the reference returned by `Peer::get_connection`) is
// used after a concurrent `take` drops the value. TODO confirm both before relying on them.
unsafe impl Send for Peer {}
unsafe impl Sync for Peer {}
74
/// Shared state behind every clone of a [`Peer`].
pub struct PeerInner {
    /// Identifier returned by `Peer::get_id`.
    id: String,
    /// WebRTC API object used to build the `RTCPeerConnection` in `create_peer`.
    api: Arc<API>,
    /// Set by `Peer::init`; the initiator creates the data channel and produces offers.
    initiator: Arc<AtomicBool>,
    /// Configuration handed to `API::new_peer_connection`.
    connection_config: RTCConfiguration,
    /// The peer connection; created lazily and taken out (then closed) on close.
    connection: AtomicOption<RTCPeerConnection>,
    /// Options passed to `create_offer`.
    offer_config: Option<RTCOfferOptions>,
    /// Options passed to `create_answer`.
    answer_config: Option<RTCAnswerOptions>,
    /// Label used when the initiator creates its data channel.
    data_channel_name: String,
    /// The active data channel (created or received); taken out (then closed) on close.
    data_channel: AtomicOption<Arc<RTCDataChannel>>,
    /// Init parameters passed to `create_data_channel`.
    data_channel_config: Option<RTCDataChannelInit>,
    /// Remote ICE candidates received before a remote description was set; applied right
    /// after `set_remote_description` succeeds.
    pending_candidates: Mutex<Vec<RTCIceCandidateInit>>,
    /// User callback for outgoing signaling messages.
    on_signal: Arc<Mutex<Option<OnSignal>>>,
    /// User callback for incoming data-channel payloads.
    on_data: Arc<Mutex<Option<OnData>>>,
    /// User callback fired when the data channel opens.
    on_connect: Arc<Mutex<Option<OnConnect>>>,
    /// User callback fired when the peer is closed.
    on_close: Arc<Mutex<Option<OnClose>>>,
    /// User callback fired after a remote description has been applied.
    on_negotiated: Arc<Mutex<Option<OnNegotiated>>>,
}
93
94impl Peer {
95 pub fn new(api: Arc<API>, options: PeerOptions) -> Self {
96 Self {
97 inner: Arc::new(PeerInner {
98 id: options.id.unwrap_or_else(|| Uuid::new_v4().to_string()),
99 api,
100 initiator: Arc::new(AtomicBool::new(false)),
101 data_channel_name: options
102 .data_channel_name
103 .unwrap_or_else(|| Uuid::new_v4().to_string()),
104 connection: AtomicOption::none(),
105 connection_config: options.connection_config.unwrap_or_default(),
106 offer_config: options.offer_config,
107 answer_config: options.answer_config,
108 data_channel_config: options.data_channel_config,
109 data_channel: AtomicOption::none(),
110 pending_candidates: Mutex::new(Vec::new()),
111 on_signal: Arc::new(Mutex::new(None)),
112 on_data: Arc::new(Mutex::new(None)),
113 on_connect: Arc::new(Mutex::new(None)),
114 on_close: Arc::new(Mutex::new(None)),
115 on_negotiated: Arc::new(Mutex::new(None)),
116 }),
117 }
118 }
119
120 pub fn get_id(&self) -> &str {
121 &self.inner.id
122 }
123
124 pub fn get_data_channel(&self) -> Option<Arc<RTCDataChannel>> {
125 self
126 .inner
127 .data_channel
128 .as_ref(Ordering::Relaxed)
129 .map(Clone::clone)
130 }
131
132 pub fn get_connection(&self) -> Option<&RTCPeerConnection> {
133 self.inner.connection.as_ref(Ordering::Relaxed)
134 }
135
136 pub async fn init(&self) -> Result<(), webrtc::Error> {
137 self.inner.initiator.store(true, Ordering::SeqCst);
138 self.create_peer().await
139 }
140
141 pub fn on_signal(&self, callback: OnSignal) {
142 block_in_place(|| {
143 Handle::current()
144 .block_on(self.inner.on_signal.lock())
145 .replace(callback)
146 });
147 }
148
149 pub fn on_data(&self, callback: OnData) {
150 block_in_place(|| {
151 Handle::current()
152 .block_on(self.inner.on_data.lock())
153 .replace(callback)
154 });
155 }
156
157 pub fn on_connect(&self, callback: OnConnect) {
158 block_in_place(|| {
159 Handle::current()
160 .block_on(self.inner.on_connect.lock())
161 .replace(callback)
162 });
163 }
164
165 pub fn on_close(&self, callback: OnClose) {
166 block_in_place(|| {
167 Handle::current()
168 .block_on(self.inner.on_close.lock())
169 .replace(callback)
170 });
171 }
172
173 pub fn on_negotiated(&self, callback: OnNegotiated) {
174 block_in_place(|| {
175 Handle::current()
176 .block_on(self.inner.on_negotiated.lock())
177 .replace(callback)
178 });
179 }
180
181 async fn create_peer(&self) -> Result<(), webrtc::Error> {
182 let api = self.inner.api.clone();
183 let connection =
184 self
185 .inner
186 .connection
187 .load_or_store_with(Ordering::SeqCst, Ordering::SeqCst, || {
188 block_in_place(|| {
189 Handle::current()
190 .block_on(api.new_peer_connection(self.inner.connection_config.clone()))
191 .expect("failed to create peer connection")
192 })
193 });
194
195 let on_negotiation_needed_peer = self.clone();
196 connection.on_negotiation_needed(Box::new(move || {
197 let pinned_peer = on_negotiation_needed_peer.clone();
198 Box::pin(async move {
199 pinned_peer.on_negotiation_needed().await;
200 })
201 }));
202 let on_peer_connection_state_change_peer = self.clone();
203 connection.on_peer_connection_state_change(Box::new(move |connection_state| {
204 let pinned_peer = on_peer_connection_state_change_peer.clone();
205 Box::pin(async move {
206 pinned_peer
207 .on_peer_connection_state_change(connection_state)
208 .await;
209 })
210 }));
211 let on_ice_candidate_peer = self.clone();
212 connection.on_ice_candidate(Box::new(move |candidate| {
213 let pinned_peer = on_ice_candidate_peer.clone();
214 Box::pin(async move {
215 pinned_peer.on_ice_candidate(candidate).await;
216 })
217 }));
218 let on_track_peer = self.clone();
219 connection.on_track(Box::new(move |track, receiver, transceiver| {
220 let pinned_peer = on_track_peer.clone();
221 Box::pin(async move {
222 pinned_peer.on_track(track, receiver, transceiver).await;
223 })
224 }));
225
226 if self.inner.initiator.load(Ordering::Relaxed) {
227 self.on_data_channel(
228 connection
229 .create_data_channel(
230 &self.inner.data_channel_name,
231 self.inner.data_channel_config.clone(),
232 )
233 .await?,
234 );
235 } else {
236 let peer = self.clone();
237 connection.on_data_channel(Box::new(move |data_channel| {
238 peer.on_data_channel(data_channel);
239 Box::pin(async move {})
240 }));
241 }
242
243 Ok(())
244 }
245
246 pub async fn close(&self) -> Result<(), webrtc::Error> {
247 self.internal_close(true).await
248 }
249
250 async fn on_negotiation_needed(&self) {
251 match self.negotiate().await {
252 Ok(_) => {}
253 Err(error) => {
254 eprintln!("error negotiating: {}", error)
255 }
256 }
257 }
258
259 async fn on_peer_connection_state_change(&self, connection_state: RTCPeerConnectionState) {
260 match connection_state {
261 RTCPeerConnectionState::Closed
262 | RTCPeerConnectionState::Failed
263 | RTCPeerConnectionState::Disconnected => match self.internal_close(true).await {
264 Ok(_) => {}
265 Err(error) => {
266 eprintln!("error peer connection change: {}", error)
267 }
268 },
269 _state => {}
270 }
271 }
272 async fn on_ice_candidate(&self, candidate: Option<RTCIceCandidate>) {
273 if let Some(candidate) = candidate {
274 let candidate = match candidate.to_json() {
275 Ok(candidate) => candidate,
276 Err(error) => {
277 eprintln!("error ice candidate: {}", error);
278 return;
279 }
280 };
281 self
282 .internal_on_signal(SignalMessage::Candidate { candidate })
283 .await;
284 }
285 }
286 async fn on_track(
287 &self,
288 track: Arc<TrackRemote>,
289 receiver: Arc<RTCRtpReceiver>,
290 transceiver: Arc<RTCRtpTransceiver>,
291 ) {
292 println!(
293 "{}: track: {:?} {:?} {:?}",
294 self.get_id(),
295 track,
296 receiver,
297 transceiver
298 );
299 }
300
301 fn on_data_channel(&self, data_channel: Arc<RTCDataChannel>) {
302 let on_open_peer = self.clone();
303 data_channel.on_open(Box::new(move || {
304 let pinned_peer = on_open_peer.clone();
305 Box::pin(async move {
306 pinned_peer.on_data_channel_open().await;
307 })
308 }));
309 let on_message_peer = self.clone();
310 data_channel.on_message(Box::new(move |msg| {
311 let pinned_peer = on_message_peer.clone();
312 Box::pin(async move {
313 pinned_peer.on_data_channel_message(msg).await;
314 })
315 }));
316 let on_error_peer = self.clone();
317 data_channel.on_error(Box::new(move |error| {
318 let pinned_peer = on_error_peer.clone();
319 Box::pin(async move {
320 pinned_peer.on_data_channel_error(error).await;
321 })
322 }));
323 self
324 .inner
325 .data_channel
326 .store(Ordering::Relaxed, data_channel);
327 }
328
329 async fn on_data_channel_open(&self) {
330 self.internal_on_connect().await;
331 }
332
333 async fn on_data_channel_message(&self, msg: DataChannelMessage) {
334 self.internal_on_data(msg.data.to_vec()).await;
335 }
336
337 async fn on_data_channel_error(&self, error: webrtc::Error) {
338 eprintln!("data channel error: {}", error);
339 }
340
341 async fn internal_close(&self, emit: bool) -> Result<(), webrtc::Error> {
342 if let Some(channel) = self.inner.data_channel.take(Ordering::SeqCst) {
343 channel.close().await?;
344 }
345 if let Some(connection) = self.inner.connection.take(Ordering::SeqCst) {
346 connection.close().await?;
347 }
348 if emit {
349 self.internal_on_close().await;
350 }
351 Ok(())
352 }
353
354 pub async fn signal(&self, msg: SignalMessage) -> Result<(), webrtc::Error> {
355 if self.inner.connection.is_none(Ordering::Relaxed) {
356 self.create_peer().await?;
357 }
358
359 match msg {
360 SignalMessage::Renegotiate => self.negotiate().await,
361 SignalMessage::Candidate { candidate } => {
362 if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
363 if connection.remote_description().await.is_some() {
364 return connection.add_ice_candidate(candidate).await;
365 }
366 }
367 self.inner.pending_candidates.lock().await.push(candidate);
368 Ok(())
369 }
370 SignalMessage::SDP(sdp) => {
371 if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
372 let kind = sdp.sdp_type.clone();
373 connection.set_remote_description(sdp).await?;
374 for pending_candidate in self.inner.pending_candidates.lock().await.drain(..) {
375 connection.add_ice_candidate(pending_candidate).await?;
376 }
377 if kind == RTCSdpType::Offer {
378 self.create_answer().await?;
379 }
380 self.internal_on_negotiated().await;
381 Ok(())
382 } else {
383 Err(webrtc::Error::ErrConnectionClosed)
384 }
385 }
386 }
387 }
388
389 async fn create_offer(&self) -> Result<(), webrtc::Error> {
390 if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
391 let offer = connection
392 .create_offer(self.inner.offer_config.clone())
393 .await?;
394 connection.set_local_description(offer.clone()).await?;
395 self.internal_on_signal(SignalMessage::SDP(offer)).await;
396 }
397 Ok(())
398 }
399
400 async fn create_answer(&self) -> Result<(), webrtc::Error> {
401 if let Some(connection) = self.inner.connection.as_ref(Ordering::Relaxed) {
402 let answer = connection
403 .create_answer(self.inner.answer_config.clone())
404 .await?;
405 connection.set_local_description(answer.clone()).await?;
406 self.internal_on_signal(SignalMessage::SDP(answer)).await;
407 }
408 Ok(())
409 }
410
411 async fn negotiate(&self) -> Result<(), webrtc::Error> {
412 if self.inner.initiator.load(Ordering::Relaxed) {
413 return self.create_offer().await;
414 }
415 self.internal_on_signal(SignalMessage::Renegotiate).await;
416 Ok(())
417 }
418
419 async fn internal_on_signal(&self, signal: SignalMessage) {
420 if let Some(on_signal) = self.inner.on_signal.lock().await.as_mut() {
421 on_signal(signal).await;
422 }
423 }
424 async fn internal_on_data(&self, data: Vec<u8>) {
425 if let Some(on_data) = self.inner.on_data.lock().await.as_mut() {
426 on_data(data).await;
427 }
428 }
429 async fn internal_on_connect(&self) {
430 if let Some(on_connect) = self.inner.on_connect.lock().await.as_mut() {
431 on_connect().await;
432 }
433 }
434 async fn internal_on_close(&self) {
435 if let Some(on_close) = self.inner.on_close.lock().await.as_mut() {
436 on_close().await;
437 }
438 }
439 async fn internal_on_negotiated(&self) {
440 if let Some(on_negotiated) = self.inner.on_negotiated.lock().await.as_mut() {
441 on_negotiated().await;
442 }
443 }
444}
445
#[cfg(test)]
mod test {
    use webrtc::{
        api::{
            interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder,
        },
        ice_transport::ice_server::RTCIceServer,
        interceptor::registry::Registry,
    };

    use super::*;

    /// End-to-end check: two in-process peers signal each other directly, peer2 sends a text
    /// message over the data channel once it reports a connection, and peer1 must receive it.
    #[tokio::test(flavor = "multi_thread")]
    async fn basic() -> Result<(), webrtc::Error> {
        let mut m = MediaEngine::default();
        let registry = register_default_interceptors(Registry::new(), &mut m)?;

        let api = Arc::new(
            APIBuilder::new()
                .with_media_engine(m)
                .with_interceptor_registry(registry)
                .build(),
        );

        let options = PeerOptions {
            connection_config: Some(RTCConfiguration {
                ice_servers: vec![RTCIceServer {
                    ..Default::default()
                }],
                ..Default::default()
            }),
            ..Default::default()
        };

        let peer1 = Peer::new(
            api.clone(),
            PeerOptions {
                id: Some("peer1".to_string()),
                ..options.clone()
            },
        );
        let peer2 = Peer::new(
            api,
            PeerOptions {
                id: Some("peer2".to_string()),
                ..options
            },
        );

        // Wire the two peers' signaling directly to each other.
        let on_signal_peer2 = peer2.clone();
        peer1.on_signal(Box::new(move |signal| {
            let pinned_peer2 = on_signal_peer2.clone();
            Box::pin(async move {
                pinned_peer2
                    .signal(signal)
                    .await
                    .expect("failed to signal peer2");
            })
        }));

        let on_signal_peer1 = peer1.clone();
        peer2.on_signal(Box::new(move |signal| {
            let pinned_peer1 = on_signal_peer1.clone();
            Box::pin(async move {
                pinned_peer1
                    .signal(signal)
                    .await
                    .expect("failed to signal peer1");
            })
        }));

        let (connect_sender, mut connect_receiver) = tokio::sync::mpsc::channel::<()>(1);
        peer2.on_connect(Box::new(move || {
            let pinned_connect_sender = connect_sender.clone();
            Box::pin(async move {
                pinned_connect_sender
                    .send(())
                    .await
                    .expect("failed to send connect");
            })
        }));

        let (message_sender, mut message_receiver) = tokio::sync::mpsc::channel::<Vec<u8>>(1);
        peer1.on_data(Box::new(move |data| {
            let pinned_message_sender = message_sender.clone();
            Box::pin(async move {
                pinned_message_sender
                    .send(data)
                    .await
                    .expect("failed to send data");
            })
        }));

        peer1.init().await?;

        connect_receiver
            .recv()
            .await
            .expect("peer2 never reported a connection");
        // Fail loudly instead of silently skipping the send, which would leave the receive
        // below waiting forever.
        let data_channel = peer2
            .get_data_channel()
            .expect("peer2 should have a data channel once connected");
        data_channel.send_text("Hello, world!").await?;

        let data = message_receiver
            .recv()
            .await
            .expect("failed to receive message from peer2");

        assert_eq!(String::from_utf8_lossy(data.as_ref()), "Hello, world!");

        peer1.close().await?;
        peer2.close().await?;

        Ok(())
    }
}