// datachannel/peerconnection.rs

1use std::ffi::{c_void, CStr, CString};
2use std::os::raw::c_char;
3use std::ptr;
4
5use datachannel_sys as sys;
6use derive_more::Debug;
7use parking_lot::ReentrantMutex;
8use serde::{Deserialize, Serialize};
9use webrtc_sdp::{media_type::SdpMedia, parse_sdp, SdpSession};
10
11use crate::config::RtcConfig;
12use crate::datachannel::{DataChannelHandler, DataChannelInit, RtcDataChannel};
13use crate::error::{check, Error, Result};
14use crate::track::{RtcTrack, TrackHandler, TrackInit};
15use crate::{logger, DataChannelId, DataChannelInfo};
16
/// Overall state of a peer connection, as delivered to
/// `PeerConnectionHandler::on_connection_state_change`.
///
/// Fieldless, so it is `Copy` and can be stored in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    New,
    Connecting,
    Connected,
    Disconnected,
    Failed,
    Closed,
}
26
27impl ConnectionState {
28    fn from_raw(state: sys::rtcState) -> Self {
29        match state {
30            sys::rtcState_RTC_NEW => Self::New,
31            sys::rtcState_RTC_CONNECTING => Self::Connecting,
32            sys::rtcState_RTC_CONNECTED => Self::Connected,
33            sys::rtcState_RTC_DISCONNECTED => Self::Disconnected,
34            sys::rtcState_RTC_FAILED => Self::Failed,
35            sys::rtcState_RTC_CLOSED => Self::Closed,
36            _ => panic!("Unknown rtcState: {}", state),
37        }
38    }
39}
40
/// ICE candidate gathering state, as delivered to
/// `PeerConnectionHandler::on_gathering_state_change`.
///
/// Fieldless, so it is `Copy` and can be stored in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GatheringState {
    New,
    InProgress,
    Complete,
}
47
48impl GatheringState {
49    fn from_raw(state: sys::rtcGatheringState) -> Self {
50        match state {
51            sys::rtcGatheringState_RTC_GATHERING_NEW => Self::New,
52            sys::rtcGatheringState_RTC_GATHERING_INPROGRESS => Self::InProgress,
53            sys::rtcGatheringState_RTC_GATHERING_COMPLETE => Self::Complete,
54            _ => panic!("Unknown rtcGatheringState: {}", state),
55        }
56    }
57}
58
/// Signaling state, as delivered to
/// `PeerConnectionHandler::on_signaling_state_change`.
///
/// Fieldless, so it is `Copy` and can be stored in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SignalingState {
    Stable,
    HaveLocalOffer,
    HaveRemoteOffer,
    HaveLocalPranswer,
    HaveRemotePranswer,
}
67
68impl SignalingState {
69    fn from_raw(state: sys::rtcSignalingState) -> Self {
70        match state {
71            sys::rtcSignalingState_RTC_SIGNALING_STABLE => Self::Stable,
72            sys::rtcSignalingState_RTC_SIGNALING_HAVE_LOCAL_OFFER => Self::HaveLocalOffer,
73            sys::rtcSignalingState_RTC_SIGNALING_HAVE_REMOTE_OFFER => Self::HaveRemoteOffer,
74            sys::rtcSignalingState_RTC_SIGNALING_HAVE_LOCAL_PRANSWER => Self::HaveLocalPranswer,
75            sys::rtcSignalingState_RTC_SIGNALING_HAVE_REMOTE_PRANSWER => Self::HaveRemotePranswer,
76            _ => panic!("Unknown rtcSignalingState: {}", state),
77        }
78    }
79}
80
/// ICE transport state, as delivered to
/// `PeerConnectionHandler::on_ice_state_change`.
///
/// Fieldless, so it is `Copy` and can be stored in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IceState {
    New,
    Checking,
    Connected,
    Completed,
    Failed,
    Disconnected,
    Closed,
}
91
92impl IceState {
93    fn from_raw(state: sys::rtcIceState) -> Self {
94        match state {
95            sys::rtcIceState_RTC_ICE_NEW => Self::New,
96            sys::rtcIceState_RTC_ICE_CHECKING => Self::Checking,
97            sys::rtcIceState_RTC_ICE_CONNECTED => Self::Connected,
98            sys::rtcIceState_RTC_ICE_COMPLETED => Self::Completed,
99            sys::rtcIceState_RTC_ICE_FAILED => Self::Failed,
100            sys::rtcIceState_RTC_ICE_DISCONNECTED => Self::Disconnected,
101            sys::rtcIceState_RTC_ICE_CLOSED => Self::Closed,
102            _ => panic!("Unknown rtcIceState: {}", state),
103        }
104    }
105}
106
/// The local and remote candidates of the currently selected ICE pair, as
/// strings read from libdatachannel.
///
/// Ordering compares `local` first, then `remote`. `Ord` is derived alongside
/// `PartialOrd` so the two orderings agree.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CandidatePair {
    pub local: String,
    pub remote: String,
}
112
113#[derive(Debug, Serialize, Deserialize)]
114pub struct SessionDescription {
115    #[debug("{}", fmt_sdp(sdp))]
116    #[serde(with = "serde_sdp")]
117    pub sdp: SdpSession,
118    #[serde(rename = "type")]
119    pub sdp_type: SdpType,
120}
121
122pub fn fmt_sdp(sdp: &SdpSession) -> String {
123    sdp.to_string()
124        .trim_end()
125        .split("\r\n")
126        .collect::<Vec<_>>()
127        .join("; ")
128}
129
130pub mod serde_sdp {
131    use super::SdpSession;
132    use serde::{de, Deserialize, Deserializer, Serializer};
133
134    pub fn serialize<S>(sdp: &SdpSession, serializer: S) -> Result<S::Ok, S::Error>
135    where
136        S: Serializer,
137    {
138        serializer.serialize_str(&sdp.to_string())
139    }
140
141    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<SdpSession, D::Error>
142    where
143        D: Deserializer<'de>,
144    {
145        let sdp = String::deserialize(deserializer)?;
146        webrtc_sdp::parse_sdp(&sdp, false).map_err(de::Error::custom)
147    }
148}
149
150#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
151#[serde(rename_all = "lowercase")]
152pub enum SdpType {
153    Answer,
154    Offer,
155    Pranswer,
156    Rollback,
157}
158
159impl SdpType {
160    fn from(val: &str) -> Result<Self> {
161        match val {
162            "answer" => Ok(Self::Answer),
163            "offer" => Ok(Self::Offer),
164            "pranswer" => Ok(Self::Pranswer),
165            "rollback" => Ok(Self::Rollback),
166            _ => Err(Error::InvalidArg),
167        }
168    }
169
170    fn val(&self) -> &'static str {
171        match self {
172            Self::Answer => "answer",
173            Self::Offer => "offer",
174            Self::Pranswer => "pranswer",
175            Self::Rollback => "rollback",
176        }
177    }
178}
179
180#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
181pub struct IceCandidate {
182    pub candidate: String,
183    #[serde(rename = "sdpMid")]
184    pub mid: String,
185}
186
187#[allow(unused_variables)]
188#[allow(clippy::boxed_local)]
189pub trait PeerConnectionHandler {
190    type DCH;
191
192    fn data_channel_handler(&mut self, info: DataChannelInfo) -> Self::DCH;
193
194    fn on_description(&mut self, sess_desc: SessionDescription) {}
195    fn on_candidate(&mut self, cand: IceCandidate) {}
196    fn on_connection_state_change(&mut self, state: ConnectionState) {}
197    fn on_gathering_state_change(&mut self, state: GatheringState) {}
198    fn on_signaling_state_change(&mut self, state: SignalingState) {}
199    fn on_ice_state_change(&mut self, state: IceState) {}
200    fn on_data_channel(&mut self, data_channel: Box<RtcDataChannel<Self::DCH>>) {}
201}
202
/// Opaque identifier of a native libdatachannel peer connection.
///
/// `Ord` is derived alongside `PartialOrd` so ids can be sorted or used as
/// `BTreeMap`/`BTreeSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PeerConnectionId(i32);
205
206pub struct RtcPeerConnection<P> {
207    lock: ReentrantMutex<()>,
208    id: PeerConnectionId,
209    pc_handler: P,
210}
211
212impl<P> RtcPeerConnection<P>
213where
214    P: PeerConnectionHandler + Send,
215    P::DCH: DataChannelHandler + Send,
216{
217    pub fn new(config: &RtcConfig, pc_handler: P) -> Result<Box<Self>> {
218        #[cfg(feature = "log")]
219        crate::ensure_logging();
220
221        unsafe {
222            let id = check(sys::rtcCreatePeerConnection(&config.as_raw()))?;
223            let mut rtc_pc = Box::new(RtcPeerConnection {
224                lock: ReentrantMutex::new(()),
225                id: PeerConnectionId(id),
226                pc_handler,
227            });
228            let ptr = &mut *rtc_pc;
229
230            sys::rtcSetUserPointer(id, ptr as *mut _ as *mut c_void);
231
232            check(sys::rtcSetLocalDescriptionCallback(
233                id,
234                Some(RtcPeerConnection::<P>::local_description_cb),
235            ))?;
236
237            check(sys::rtcSetLocalCandidateCallback(
238                id,
239                Some(RtcPeerConnection::<P>::local_candidate_cb),
240            ))?;
241
242            check(sys::rtcSetStateChangeCallback(
243                id,
244                Some(RtcPeerConnection::<P>::state_change_cb),
245            ))?;
246
247            check(sys::rtcSetGatheringStateChangeCallback(
248                id,
249                Some(RtcPeerConnection::<P>::gathering_state_cb),
250            ))?;
251
252            check(sys::rtcSetSignalingStateChangeCallback(
253                id,
254                Some(RtcPeerConnection::<P>::signaling_state_cb),
255            ))?;
256
257            check(sys::rtcSetIceStateChangeCallback(
258                id,
259                Some(RtcPeerConnection::<P>::ice_state_cb),
260            ))?;
261
262            check(sys::rtcSetDataChannelCallback(
263                id,
264                Some(RtcPeerConnection::<P>::data_channel_cb),
265            ))?;
266
267            Ok(rtc_pc)
268        }
269    }
270
271    unsafe extern "C" fn local_description_cb(
272        _: i32,
273        sdp: *const c_char,
274        sdp_type: *const c_char,
275        ptr: *mut c_void,
276    ) {
277        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
278
279        let sdp = CStr::from_ptr(sdp).to_string_lossy();
280        let sdp = match parse_sdp(&sdp, false) {
281            Ok(sdp) => sdp,
282            Err(err) => {
283                logger::warn!("Ignoring invalid SDP: {}", err);
284                logger::debug!("{}", sdp);
285                return;
286            }
287        };
288
289        let sdp_type = CStr::from_ptr(sdp_type).to_string_lossy();
290        let sdp_type = match SdpType::from(&sdp_type) {
291            Ok(sdp_type) => sdp_type,
292            Err(_) => {
293                logger::warn!("Ignoring session with invalid SdpType: {}", sdp_type);
294                logger::debug!("{}", sdp);
295                return;
296            }
297        };
298
299        let sess_desc = SessionDescription { sdp, sdp_type };
300
301        let _guard = rtc_pc.lock.lock();
302        rtc_pc.pc_handler.on_description(sess_desc);
303    }
304
305    unsafe extern "C" fn local_candidate_cb(
306        _: i32,
307        cand: *const c_char,
308        mid: *const c_char,
309        ptr: *mut c_void,
310    ) {
311        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
312
313        let candidate = CStr::from_ptr(cand).to_string_lossy().to_string();
314        let mid = CStr::from_ptr(mid).to_string_lossy().to_string();
315        let cand = IceCandidate { candidate, mid };
316
317        let _guard = rtc_pc.lock.lock();
318        rtc_pc.pc_handler.on_candidate(cand);
319    }
320
321    unsafe extern "C" fn state_change_cb(_: i32, state: sys::rtcState, ptr: *mut c_void) {
322        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
323
324        let state = ConnectionState::from_raw(state);
325
326        let _guard = rtc_pc.lock.lock();
327        rtc_pc.pc_handler.on_connection_state_change(state);
328    }
329
330    unsafe extern "C" fn gathering_state_cb(_: i32, state: sys::rtcState, ptr: *mut c_void) {
331        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
332
333        let state = GatheringState::from_raw(state);
334
335        let _guard = rtc_pc.lock.lock();
336        rtc_pc.pc_handler.on_gathering_state_change(state);
337    }
338
339    unsafe extern "C" fn signaling_state_cb(_: i32, state: sys::rtcState, ptr: *mut c_void) {
340        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
341
342        let state = SignalingState::from_raw(state);
343
344        let _guard = rtc_pc.lock.lock();
345        rtc_pc.pc_handler.on_signaling_state_change(state);
346    }
347
348    unsafe extern "C" fn ice_state_cb(_: i32, state: sys::rtcIceState, ptr: *mut c_void) {
349        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
350
351        let state = IceState::from_raw(state);
352
353        let _guard = rtc_pc.lock.lock();
354        rtc_pc.pc_handler.on_ice_state_change(state);
355    }
356
357    unsafe extern "C" fn data_channel_cb(_: i32, id: i32, ptr: *mut c_void) {
358        let rtc_pc = &mut *(ptr as *mut RtcPeerConnection<P>);
359
360        let id = DataChannelId(id);
361        let info = DataChannelInfo {
362            id,
363            label: DataChannelInfo::label(id),
364            protocol: DataChannelInfo::protocol(id),
365            reliability: DataChannelInfo::reliability(id),
366            stream: DataChannelInfo::stream(id),
367        };
368
369        let guard = rtc_pc.lock.lock();
370        let dc = rtc_pc.pc_handler.data_channel_handler(info);
371        drop(guard);
372
373        match RtcDataChannel::new(id, dc) {
374            Ok(dc) => {
375                let _guard = rtc_pc.lock.lock();
376                rtc_pc.pc_handler.on_data_channel(dc);
377            }
378            Err(err) => logger::error!(
379                "Couldn't create RtcDataChannel with id={:?} from RtcPeerConnection {:p}: {}",
380                id,
381                ptr,
382                err
383            ),
384        }
385    }
386
387    pub fn id(&self) -> PeerConnectionId {
388        self.id
389    }
390
391    /// Creates a boxed [`RtcDataChannel`].
392    pub fn create_data_channel<C>(
393        &mut self,
394        label: &str,
395        dc_handler: C,
396    ) -> Result<Box<RtcDataChannel<C>>>
397    where
398        C: DataChannelHandler + Send,
399    {
400        let label = CString::new(label)?;
401        let id = DataChannelId(check(unsafe {
402            sys::rtcCreateDataChannel(self.id.0, label.as_ptr())
403        })?);
404        RtcDataChannel::new(id, dc_handler)
405    }
406
407    pub fn create_data_channel_ex<C>(
408        &mut self,
409        label: &str,
410        dc_handler: C,
411        dc_init: &DataChannelInit,
412    ) -> Result<Box<RtcDataChannel<C>>>
413    where
414        C: DataChannelHandler + Send,
415    {
416        let label = CString::new(label)?;
417        let id = DataChannelId(check(unsafe {
418            sys::rtcCreateDataChannelEx(self.id.0, label.as_ptr(), &dc_init.as_raw()?)
419        })?);
420        RtcDataChannel::new(id, dc_handler)
421    }
422
423    /// Creates a boxed [`RtcTrack`].
424    pub fn add_track<C>(&mut self, sdp_media: &SdpMedia, t_handler: C) -> Result<Box<RtcTrack<C>>>
425    where
426        C: TrackHandler + Send,
427    {
428        let desc = sdp_media.to_string();
429        let desc = CString::new(desc.strip_prefix("m=").unwrap_or(&desc))?;
430        let id = check(unsafe { sys::rtcAddTrack(self.id.0, desc.as_ptr()) })?;
431        RtcTrack::new(id, t_handler)
432    }
433
434    pub fn add_track_ex<C>(&mut self, t_init: &TrackInit, t_handler: C) -> Result<Box<RtcTrack<C>>>
435    where
436        C: TrackHandler + Send,
437    {
438        let id = check(unsafe { sys::rtcAddTrackEx(self.id.0, &t_init.as_raw()) })?;
439        RtcTrack::new(id, t_handler)
440    }
441
442    pub fn set_local_description(&mut self, sdp_type: SdpType) -> Result<()> {
443        let sdp_type = CString::new(sdp_type.val())?;
444        check(unsafe { sys::rtcSetLocalDescription(self.id.0, sdp_type.as_ptr()) })?;
445        Ok(())
446    }
447
448    pub fn set_remote_description(&mut self, sess_desc: &SessionDescription) -> Result<()> {
449        let sdp = CString::new(sess_desc.sdp.to_string())?;
450        let sdp_type = CString::new(sess_desc.sdp_type.val())?;
451        check(unsafe { sys::rtcSetRemoteDescription(self.id.0, sdp.as_ptr(), sdp_type.as_ptr()) })?;
452        Ok(())
453    }
454
455    pub fn add_remote_candidate(&mut self, cand: &IceCandidate) -> Result<()> {
456        let mid = CString::new(cand.mid.clone())?;
457        let cand = CString::new(cand.candidate.clone())?;
458        unsafe { sys::rtcAddRemoteCandidate(self.id.0, cand.as_ptr(), mid.as_ptr()) };
459        Ok(())
460    }
461
462    pub fn local_description(&self) -> Option<SessionDescription> {
463        let sdp = self
464            .read_string_ffi(sys::rtcGetLocalDescription, "local_description")
465            .map(|sdp| webrtc_sdp::parse_sdp(&sdp, false).map_err(|e| e.to_string()));
466
467        let sdp_type = self
468            .read_string_ffi(sys::rtcGetLocalDescriptionType, "local_description_type")
469            .map(|sdp_type| SdpType::from(&sdp_type).map_err(|e| e.to_string()));
470
471        match (sdp, sdp_type) {
472            (Some(Ok(sdp)), Some(Ok(sdp_type))) => Some(SessionDescription { sdp, sdp_type }),
473            (Some(Err(e)), _) | (None, Some(Err(e))) => {
474                logger::error!("Got an invalid Sessiondescription: {}", e);
475                None
476            }
477            _ => None,
478        }
479    }
480
481    pub fn remote_description(&self) -> Option<SessionDescription> {
482        let sdp = self
483            .read_string_ffi(sys::rtcGetRemoteDescription, "remote_description")
484            .map(|sdp| webrtc_sdp::parse_sdp(&sdp, false).map_err(|e| e.to_string()));
485
486        let sdp_type = self
487            .read_string_ffi(sys::rtcGetRemoteDescriptionType, "remote_description_type")
488            .map(|sdp_type| SdpType::from(&sdp_type).map_err(|e| e.to_string()));
489
490        match (sdp, sdp_type) {
491            (Some(Ok(sdp)), Some(Ok(sdp_type))) => Some(SessionDescription { sdp, sdp_type }),
492            (Some(Err(e)), _) | (None, Some(Err(e))) => {
493                logger::error!("Got an invalid Sessiondescription: {}", e);
494                None
495            }
496            _ => None,
497        }
498    }
499
500    pub fn local_address(&self) -> Option<String> {
501        self.read_string_ffi(sys::rtcGetLocalAddress, "local_address")
502    }
503
504    pub fn remote_address(&self) -> Option<String> {
505        self.read_string_ffi(sys::rtcGetRemoteAddress, "remote_address")
506    }
507
508    pub fn selected_candidate_pair(&self) -> Option<CandidatePair> {
509        let buf_size = check(unsafe {
510            sys::rtcGetSelectedCandidatePair(
511                self.id.0,
512                ptr::null_mut() as *mut c_char,
513                0,
514                ptr::null_mut() as *mut c_char,
515                0,
516            )
517        });
518
519        let buf_size = match buf_size {
520            Ok(buf_size) => buf_size as usize,
521            Err(err) => {
522                logger::error!("Couldn't get buffer size: {}", err);
523                return None;
524            }
525        };
526
527        let mut local_buf = vec![0; buf_size];
528        let mut remote_buf = vec![0; buf_size];
529        match check(unsafe {
530            sys::rtcGetSelectedCandidatePair(
531                self.id.0,
532                local_buf.as_mut_ptr() as *mut c_char,
533                buf_size as i32,
534                remote_buf.as_mut_ptr() as *mut c_char,
535                buf_size as i32,
536            )
537        }) {
538            Ok(_) => {
539                let local = crate::ffi_string(&local_buf);
540                let remote = crate::ffi_string(&remote_buf);
541                match (local, remote) {
542                    (Ok(local), Ok(remote)) => Some(CandidatePair { local, remote }),
543                    (Ok(_), Err(err)) | (Err(err), Ok(_)) | (Err(err), Err(_)) => {
544                        logger::error!(
545                            "Couldn't get RtcPeerConnection {:p} candidate_pair: {}",
546                            self,
547                            err
548                        );
549                        None
550                    }
551                }
552            }
553            Err(Error::NotAvailable) => None,
554            Err(err) => {
555                logger::warn!(
556                    "Couldn't get RtcPeerConnection {:p} candidate_pair: {}",
557                    self,
558                    err
559                );
560                None
561            }
562        }
563    }
564
565    fn read_string_ffi(
566        &self,
567        str_fn: unsafe extern "C" fn(i32, *mut c_char, i32) -> i32,
568        prop: &str,
569    ) -> Option<String> {
570        let buf_size = match check(unsafe { str_fn(self.id.0, ptr::null_mut() as *mut c_char, 0) })
571        {
572            Ok(buf_size) => buf_size as usize,
573            Err(err) => {
574                logger::error!("Couldn't get buffer size: {}", err);
575                return None;
576            }
577        };
578
579        let mut buf = vec![0; buf_size];
580        match check(unsafe { str_fn(self.id.0, buf.as_mut_ptr() as *mut c_char, buf_size as i32) })
581        {
582            Ok(_) => match String::from_utf8(buf) {
583                Ok(local) => Some(local.trim_matches(char::from(0)).to_string()),
584                Err(err) => {
585                    logger::error!(
586                        "Couldn't get RtcPeerConnection {:p} {}: {}",
587                        self,
588                        prop,
589                        err
590                    );
591                    None
592                }
593            },
594            Err(Error::NotAvailable) => None,
595            Err(err) => {
596                logger::warn!(
597                    "Couldn't get RtcPeerConnection {:p} {}: {}",
598                    self,
599                    prop,
600                    err
601                );
602                None
603            }
604        }
605    }
606}
607
608impl<P> Drop for RtcPeerConnection<P> {
609    fn drop(&mut self) {
610        if let Err(err) = check(unsafe { sys::rtcDeletePeerConnection(self.id.0) }) {
611            logger::error!(
612                "Error while dropping RtcPeerConnection id={:?} {:p}: {}",
613                self.id,
614                self,
615                err
616            )
617        }
618    }
619}