#![allow(dead_code)]
#![allow(clippy::too_many_arguments)]
use super::datachannel::{DataChannel, DataChannelConfig, DataChannelManager};
use super::dtls::{DtlsConfig, DtlsConnection, DtlsEndpoint, DtlsRole};
use super::ice::{IceCandidate, IceServer};
use super::ice_agent::{IceAgent, IceAgentConfig, IceConnectionState};
use super::rtcp::Packet as RtcpPacket;
use super::rtp::{Packet as RtpPacket, Session as RtpSession};
use super::sctp::Association;
use super::sdp::{Attribute, MediaDescription, MediaType, SessionDescription};
use crate::error::{NetError, NetResult};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
/// Lifecycle state reported by [`PeerConnection::state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PeerConnectionState {
    /// Initial state assigned by `PeerConnection::new`.
    New,
    /// Transport setup (ICE, then DTLS and SCTP) has started.
    Connecting,
    /// DTLS and SCTP have been set up and data channels can be created.
    Connected,
    /// Not assigned in this module; presumably reserved for lost connectivity.
    Disconnected,
    /// Reserved for a connection attempt that could not be completed.
    Failed,
    /// Assigned by `PeerConnection::close`.
    Closed,
}
/// Offer/answer negotiation state reported by [`PeerConnection::signaling_state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignalingState {
    /// No negotiation in flight: the initial state, and the state after an
    /// answer has been applied locally or remotely.
    Stable,
    /// A local offer has been created or applied.
    HaveLocalOffer,
    /// A remote offer has been applied.
    HaveRemoteOffer,
    /// A local answer has been created but not yet applied.
    HaveLocalAnswer,
    /// Not assigned anywhere in this module.
    HaveRemoteAnswer,
    /// Assigned by `PeerConnection::close`.
    Closed,
}
/// Type tag carried alongside an SDP body during signaling.
///
/// [`SdpType::as_str`] spells each variant as the lowercase words `offer`,
/// `answer`, `pranswer` and `rollback`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SdpType {
    /// An offer.
    Offer,
    /// A final answer.
    Answer,
    /// A provisional answer.
    Pranswer,
    /// A rollback request.
    Rollback,
}

impl SdpType {
    /// Returns the lowercase textual name of this SDP type.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            SdpType::Offer => "offer",
            SdpType::Answer => "answer",
            SdpType::Pranswer => "pranswer",
            SdpType::Rollback => "rollback",
        }
    }
}
/// A session description as exchanged over signaling: a [`SdpType`] tag plus
/// the raw, unparsed SDP text.
#[derive(Debug, Clone)]
pub struct SessionDescriptionInit {
    /// Offer, answer, provisional answer or rollback.
    pub sdp_type: SdpType,
    /// The SDP body exactly as supplied.
    pub sdp: String,
}

impl SessionDescriptionInit {
    /// Pairs `sdp_type` with the SDP text `sdp`, taking ownership of the text.
    #[must_use]
    pub fn new(sdp_type: SdpType, sdp: impl Into<String>) -> Self {
        let body: String = sdp.into();
        Self {
            sdp: body,
            sdp_type,
        }
    }
}
/// Construction-time settings for a [`PeerConnection`].
///
/// The derived `Default` has no ICE servers, [`BundlePolicy::MaxBundle`] and
/// [`RtcpMuxPolicy::Require`].
#[derive(Clone, Debug, Default)]
pub struct PeerConnectionConfig {
    /// ICE servers (e.g. built with `IceServer::stun`) passed to the ICE agent.
    pub ice_servers: Vec<IceServer>,
    /// Bundle policy setting.
    pub bundle_policy: BundlePolicy,
    /// RTCP multiplexing policy setting.
    pub rtcp_mux_policy: RtcpMuxPolicy,
}

impl PeerConnectionConfig {
    /// Returns the default configuration (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder-style helper that appends `server` to the ICE server list.
    #[must_use]
    pub fn with_ice_server(self, server: IceServer) -> Self {
        let mut config = self;
        config.ice_servers.push(server);
        config
    }
}
/// How media sections are bundled onto transports; defaults to `MaxBundle`.
///
/// NOTE(review): the variant names appear to mirror the W3C `RTCBundlePolicy`
/// values (`balanced`, `max-bundle`, `max-compat`) — confirm. `PeerConnection`
/// in this module does not currently read this setting.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BundlePolicy {
    /// Balanced bundling.
    Balanced,
    /// Bundle everything onto one transport (the default).
    #[default]
    MaxBundle,
    /// Maximise compatibility with non-bundle-aware peers.
    MaxCompat,
}

/// RTCP multiplexing policy; `Require` is the only (and default) option.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RtcpMuxPolicy {
    /// RTP and RTCP share a transport.
    #[default]
    Require,
}
/// A locally produced media track backed by its own RTP session.
///
/// RTCP packets passed to [`MediaTrack::send_rtcp`] are queued on an unbounded
/// in-process channel. Whoever drives the transport obtains the receiving end
/// once via [`MediaTrack::take_rtcp_receiver`].
pub struct MediaTrack {
    /// Caller-chosen identifier.
    id: String,
    /// Media kind, e.g. `MediaType::Audio`.
    kind: MediaType,
    /// RTP session created with this track's SSRC; builds outgoing packets.
    rtp_session: Arc<Mutex<RtpSession>>,
    /// Sending half of the RTCP queue.
    rtcp_tx: mpsc::UnboundedSender<RtcpPacket>,
    /// Receiving half of the RTCP queue. It must be kept alive: if it were
    /// dropped, every `send_rtcp` call would fail with "RTCP channel closed".
    rtcp_rx: Mutex<Option<mpsc::UnboundedReceiver<RtcpPacket>>>,
}

impl MediaTrack {
    /// Creates a track named `id` of media `kind` whose RTP session uses `ssrc`.
    #[must_use]
    pub fn new(id: impl Into<String>, kind: MediaType, ssrc: u32) -> Self {
        let (rtcp_tx, rtcp_rx) = mpsc::unbounded_channel();
        Self {
            id: id.into(),
            kind,
            rtp_session: Arc::new(Mutex::new(RtpSession::new(ssrc))),
            rtcp_tx,
            rtcp_rx: Mutex::new(Some(rtcp_rx)),
        }
    }

    /// The identifier given at construction.
    #[must_use]
    pub fn id(&self) -> &str {
        &self.id
    }

    /// The media kind given at construction.
    #[must_use]
    pub const fn kind(&self) -> MediaType {
        self.kind
    }

    /// Builds the next RTP packet for this track via its RTP session.
    ///
    /// Nothing is transmitted here; the caller is responsible for sending the
    /// returned packet.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub async fn send_rtp(
        &self,
        payload_type: u8,
        timestamp: u32,
        payload: impl Into<bytes::Bytes>,
    ) -> NetResult<RtpPacket> {
        let packet = self
            .rtp_session
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .create_packet(payload_type, timestamp, payload);
        Ok(packet)
    }

    /// Queues `packet` on this track's RTCP channel.
    ///
    /// The channel is unbounded, so packets accumulate in memory until the
    /// receiver from [`MediaTrack::take_rtcp_receiver`] drains them.
    ///
    /// # Errors
    ///
    /// Returns a connection error once that receiver has been taken and dropped.
    pub fn send_rtcp(&self, packet: RtcpPacket) -> NetResult<()> {
        self.rtcp_tx
            .send(packet)
            .map_err(|_| NetError::connection("RTCP channel closed"))?;
        Ok(())
    }

    /// Hands out the receiving end of this track's RTCP queue.
    ///
    /// Returns `Some` on the first call and `None` afterwards. Dropping the
    /// returned receiver closes the channel, which makes later `send_rtcp`
    /// calls fail.
    #[must_use]
    pub fn take_rtcp_receiver(&self) -> Option<mpsc::UnboundedReceiver<RtcpPacket>> {
        self.rtcp_rx
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .take()
    }

    /// Returns a snapshot (clone) of the RTP session statistics.
    #[must_use]
    pub fn stats(&self) -> super::rtp::Statistics {
        self.rtp_session
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .stats()
            .clone()
    }
}
pub struct PeerConnection {
config: PeerConnectionConfig,
state: Arc<Mutex<PeerConnectionState>>,
signaling_state: Arc<Mutex<SignalingState>>,
ice_agent: Arc<Mutex<Option<IceAgent>>>,
dtls_endpoint: Arc<Mutex<Option<DtlsEndpoint>>>,
dtls_connection: Arc<Mutex<Option<Arc<DtlsConnection>>>>,
sctp_association: Arc<Mutex<Option<Arc<Association>>>>,
dc_manager: Arc<Mutex<Option<DataChannelManager>>>,
tracks: Arc<Mutex<Vec<Arc<MediaTrack>>>>,
local_description: Arc<Mutex<Option<SessionDescription>>>,
remote_description: Arc<Mutex<Option<SessionDescription>>>,
pending_local_candidates: Arc<Mutex<Vec<IceCandidate>>>,
}
impl PeerConnection {
pub fn new(config: PeerConnectionConfig) -> NetResult<Self> {
Ok(Self {
config,
state: Arc::new(Mutex::new(PeerConnectionState::New)),
signaling_state: Arc::new(Mutex::new(SignalingState::Stable)),
ice_agent: Arc::new(Mutex::new(None)),
dtls_endpoint: Arc::new(Mutex::new(None)),
dtls_connection: Arc::new(Mutex::new(None)),
sctp_association: Arc::new(Mutex::new(None)),
dc_manager: Arc::new(Mutex::new(None)),
tracks: Arc::new(Mutex::new(Vec::new())),
local_description: Arc::new(Mutex::new(None)),
remote_description: Arc::new(Mutex::new(None)),
pending_local_candidates: Arc::new(Mutex::new(Vec::new())),
})
}
pub async fn create_offer(&self) -> NetResult<SessionDescriptionInit> {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::HaveLocalOffer;
let ice_config = IceAgentConfig {
ice_servers: self.config.ice_servers.clone(),
controlling: true,
..Default::default()
};
let ice_agent = IceAgent::new(ice_config.clone());
let local_candidates = ice_agent.gather_candidates().await?;
*self.ice_agent.lock().unwrap_or_else(|e| e.into_inner()) = Some(ice_agent);
*self
.pending_local_candidates
.lock()
.unwrap_or_else(|e| e.into_inner()) = local_candidates.clone();
let dtls_config = DtlsConfig::new_self_signed(DtlsRole::Server)?;
let fingerprint = dtls_config.fingerprint();
let mut sdp = SessionDescription::new()
.with_origin(format!("- {} 0 IN IP4 0.0.0.0", get_timestamp()))
.with_session_name("WebRTC Session")
.with_attribute(Attribute::new("group", "BUNDLE 0"));
let mut media = MediaDescription::data_channel(9)
.with_format("webrtc-datachannel")
.with_mid("0")
.with_ice(ice_config.local_ufrag.clone(), ice_config.local_pwd.clone())
.with_fingerprint(super::sdp::Fingerprint::new(
fingerprint.algorithm,
fingerprint.value,
))
.with_rtcp_mux();
media.setup = Some("actpass".to_string());
for candidate in &local_candidates {
media
.attributes
.push(Attribute::new("candidate", candidate.to_sdp()));
}
sdp = sdp.with_media(media);
let sdp_string = sdp.to_sdp();
Ok(SessionDescriptionInit::new(SdpType::Offer, sdp_string))
}
pub async fn create_answer(&self) -> NetResult<SessionDescriptionInit> {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::HaveLocalAnswer;
let remote_desc = self
.remote_description
.lock()
.unwrap_or_else(|e| e.into_inner());
let remote_desc = remote_desc
.as_ref()
.ok_or_else(|| NetError::invalid_state("No remote description"))?;
let remote_media = remote_desc
.media
.first()
.ok_or_else(|| NetError::protocol("No media in remote description"))?;
let remote_ufrag = remote_media
.ice_ufrag
.clone()
.ok_or_else(|| NetError::protocol("No ICE ufrag"))?;
let remote_pwd = remote_media
.ice_pwd
.clone()
.ok_or_else(|| NetError::protocol("No ICE pwd"))?;
let ice_config = IceAgentConfig {
ice_servers: self.config.ice_servers.clone(),
controlling: false,
remote_ufrag: Some(remote_ufrag),
remote_pwd: Some(remote_pwd),
..Default::default()
};
let ice_agent = IceAgent::new(ice_config.clone());
let local_candidates = ice_agent.gather_candidates().await?;
*self.ice_agent.lock().unwrap_or_else(|e| e.into_inner()) = Some(ice_agent);
*self
.pending_local_candidates
.lock()
.unwrap_or_else(|e| e.into_inner()) = local_candidates.clone();
let dtls_config = DtlsConfig::new_self_signed(DtlsRole::Client)?;
let fingerprint = dtls_config.fingerprint();
let mut sdp = SessionDescription::new()
.with_origin(format!("- {} 0 IN IP4 0.0.0.0", get_timestamp()))
.with_session_name("WebRTC Session")
.with_attribute(Attribute::new("group", "BUNDLE 0"));
let mut media = MediaDescription::data_channel(9)
.with_format("webrtc-datachannel")
.with_mid("0")
.with_ice(ice_config.local_ufrag.clone(), ice_config.local_pwd.clone())
.with_fingerprint(super::sdp::Fingerprint::new(
fingerprint.algorithm,
fingerprint.value,
))
.with_rtcp_mux();
media.setup = Some("active".to_string());
for candidate in &local_candidates {
media
.attributes
.push(Attribute::new("candidate", candidate.to_sdp()));
}
sdp = sdp.with_media(media);
let sdp_string = sdp.to_sdp();
Ok(SessionDescriptionInit::new(SdpType::Answer, sdp_string))
}
pub async fn set_local_description(&self, desc: SessionDescriptionInit) -> NetResult<()> {
let sdp = SessionDescription::parse(&desc.sdp)?;
*self
.local_description
.lock()
.unwrap_or_else(|e| e.into_inner()) = Some(sdp);
match desc.sdp_type {
SdpType::Offer => {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::HaveLocalOffer;
}
SdpType::Answer => {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::Stable;
}
_ => {}
}
Ok(())
}
pub async fn set_remote_description(&self, desc: SessionDescriptionInit) -> NetResult<()> {
let sdp = SessionDescription::parse(&desc.sdp)?;
for media in &sdp.media {
for attr in &media.attributes {
if attr.name == "candidate" {
if let Some(ref value) = attr.value {
if let Ok(candidate) = IceCandidate::parse(value) {
if let Some(ref ice_agent) =
*self.ice_agent.lock().unwrap_or_else(|e| e.into_inner())
{
ice_agent.add_remote_candidate(candidate);
}
}
}
}
}
}
*self
.remote_description
.lock()
.unwrap_or_else(|e| e.into_inner()) = Some(sdp);
match desc.sdp_type {
SdpType::Offer => {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::HaveRemoteOffer;
}
SdpType::Answer => {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::Stable;
self.start_connection().await?;
}
_ => {}
}
Ok(())
}
async fn start_connection(&self) -> NetResult<()> {
*self.state.lock().unwrap_or_else(|e| e.into_inner()) = PeerConnectionState::Connecting;
if let Some(ref ice_agent) = *self.ice_agent.lock().unwrap_or_else(|e| e.into_inner()) {
ice_agent.check_connectivity().await?;
if ice_agent.state() == IceConnectionState::Connected {
if let Some(socket) = ice_agent.socket() {
let dtls_config = DtlsConfig::new_self_signed(DtlsRole::Client)?;
let dtls_endpoint = DtlsEndpoint::new(dtls_config, socket);
let dtls_conn = dtls_endpoint.handshake().await?;
let dtls_conn = Arc::new(dtls_conn);
*self
.dtls_connection
.lock()
.unwrap_or_else(|e| e.into_inner()) = Some(dtls_conn.clone());
let sctp_assoc = Arc::new(Association::new(5000, 5000));
*self
.sctp_association
.lock()
.unwrap_or_else(|e| e.into_inner()) = Some(sctp_assoc.clone());
let dc_manager = DataChannelManager::new(sctp_assoc, dtls_conn);
*self.dc_manager.lock().unwrap_or_else(|e| e.into_inner()) = Some(dc_manager);
*self.state.lock().unwrap_or_else(|e| e.into_inner()) =
PeerConnectionState::Connected;
}
}
}
Ok(())
}
pub async fn add_ice_candidate(&self, candidate: IceCandidate) -> NetResult<()> {
if let Some(ref ice_agent) = *self.ice_agent.lock().unwrap_or_else(|e| e.into_inner()) {
ice_agent.add_remote_candidate(candidate);
}
Ok(())
}
pub async fn create_data_channel(
&self,
label: impl Into<String>,
) -> NetResult<Arc<DataChannel>> {
let config = DataChannelConfig::new(label);
let dc_manager = self.dc_manager.lock().unwrap_or_else(|e| e.into_inner());
let dc_manager = dc_manager
.as_ref()
.ok_or_else(|| NetError::invalid_state("Connection not established"))?;
dc_manager.create_channel(config).await
}
pub fn add_track(&self, track: Arc<MediaTrack>) {
self.tracks
.lock()
.unwrap_or_else(|e| e.into_inner())
.push(track);
}
#[must_use]
pub fn tracks(&self) -> Vec<Arc<MediaTrack>> {
self.tracks
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
#[must_use]
pub fn state(&self) -> PeerConnectionState {
*self.state.lock().unwrap_or_else(|e| e.into_inner())
}
#[must_use]
pub fn signaling_state(&self) -> SignalingState {
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner())
}
#[must_use]
pub fn local_description(&self) -> Option<SessionDescription> {
self.local_description
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
#[must_use]
pub fn remote_description(&self) -> Option<SessionDescription> {
self.remote_description
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
pub async fn close(&self) -> NetResult<()> {
*self.state.lock().unwrap_or_else(|e| e.into_inner()) = PeerConnectionState::Closed;
*self
.signaling_state
.lock()
.unwrap_or_else(|e| e.into_inner()) = SignalingState::Closed;
Ok(())
}
}
/// Whole seconds since the Unix epoch, or `0` if the system clock reads
/// earlier than the epoch. Used as the session id in generated SDP origins.
fn get_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the synchronous parts of the peer-connection API.
    use super::*;

    #[test]
    fn test_sdp_type() {
        let cases = [(SdpType::Offer, "offer"), (SdpType::Answer, "answer")];
        for &(sdp_type, expected) in &cases {
            assert_eq!(sdp_type.as_str(), expected);
        }
    }

    #[test]
    fn test_peer_connection_new() {
        let pc =
            PeerConnection::new(PeerConnectionConfig::new()).expect("should succeed in test");
        assert_eq!(pc.state(), PeerConnectionState::New);
        assert_eq!(pc.signaling_state(), SignalingState::Stable);
    }

    #[test]
    fn test_peer_connection_config() {
        let stun = IceServer::stun("stun:stun.example.com:3478");
        let config = PeerConnectionConfig::new().with_ice_server(stun);
        assert_eq!(config.ice_servers.len(), 1);
    }

    #[test]
    fn test_media_track() {
        let track = MediaTrack::new("track1", MediaType::Audio, 12345);
        assert_eq!(track.id(), "track1");
        assert_eq!(track.kind(), MediaType::Audio);
    }

    #[test]
    fn test_session_description_init() {
        let body = "v=0\r\n";
        let desc = SessionDescriptionInit::new(SdpType::Offer, body);
        assert_eq!(desc.sdp_type, SdpType::Offer);
        assert_eq!(desc.sdp, body);
    }
}