use crate::core::GunCore;
use crate::dam::Mesh;
use crate::error::{GunError, GunResult};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::ice_transport::ice_credential_type::RTCIceCredentialType;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;
/// Configuration for the WebRTC transport.
#[derive(Clone, Debug)]
pub struct WebRTCOptions {
/// ICE (STUN/TURN) servers passed to every peer connection that is created.
pub ice_servers: Vec<RTCIceServer>,
/// Settings used when creating the data channel of each peer.
pub data_channel: RTCDataChannelInit,
/// Upper bound on the number of peers tracked before new outbound
/// connections are skipped.
pub max_connections: usize,
/// Optional room name.
/// NOTE(review): not read anywhere in this file — confirm how callers use it.
pub room: Option<String>,
/// Whether the transport is enabled.
/// NOTE(review): not checked anywhere in this file — presumably enforced by
/// the caller; confirm.
pub enabled: bool,
}
impl Default for WebRTCOptions {
    /// Builds the default options: two public STUN servers, an unordered data
    /// channel limited to two retransmits, at most 55 connections, no room,
    /// and the transport enabled.
    fn default() -> Self {
        // Both default servers are credential-less STUN endpoints, so they only
        // differ by URL.
        let stun = |url: &str| RTCIceServer {
            urls: vec![url.to_string()],
            username: String::new(),
            credential: String::new(),
            credential_type: RTCIceCredentialType::Password,
        };

        Self {
            ice_servers: vec![
                stun("stun:stun.l.google.com:19302"),
                stun("stun:stun.cloudflare.com:3478"),
            ],
            data_channel: RTCDataChannelInit {
                ordered: Some(false),
                max_retransmits: Some(2u16),
                ..Default::default()
            },
            max_connections: 55,
            room: None,
            enabled: true,
        }
    }
}
/// One WebRTC connection to a remote peer, together with its data channel.
pub struct WebRTCPeer {
/// Identifier of the remote peer this connection belongs to.
pub peer_id: String,
/// Underlying peer connection.
pc: Arc<RTCPeerConnection>,
/// Data channel created locally (label "dc") when the peer is constructed;
/// outgoing messages are written to it.
data_channel: Arc<webrtc::data_channel::RTCDataChannel>,
/// Sending half of the channel on which received messages are forwarded.
/// It is stored but never read here (hence the `dead_code` allowance);
/// presumably kept so the receiver stays open while the peer is alive —
/// TODO confirm.
#[allow(dead_code)] message_sender: tokio::sync::mpsc::UnboundedSender<String>,
}
impl std::fmt::Debug for WebRTCPeer {
    /// Formats the peer showing only `peer_id`; the remaining fields are
    /// elided and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WebRTCPeer");
        out.field("peer_id", &self.peer_id);
        out.finish_non_exhaustive()
    }
}
impl WebRTCPeer {
/// Creates a peer connection for `peer_id` together with a local data channel
/// labelled "dc".
///
/// Returns the peer and the receiving half of an unbounded channel. Every
/// UTF-8 payload that arrives on the local data channel, or on any data
/// channel opened by the remote side, is delivered on that receiver.
///
/// Connection-state changes and locally gathered ICE candidates are only
/// logged here.
///
/// # Errors
/// * `GunError::WebRTC` if codec or interceptor registration fails.
/// * `GunError::Network` if the peer connection or data channel cannot be
///   created.
pub async fn new(
    peer_id: String,
    config: &WebRTCOptions,
) -> GunResult<(Self, tokio::sync::mpsc::UnboundedReceiver<String>)> {
    // Shared by the local-channel and remote-channel message handlers.
    fn forward_payload(
        tx: &tokio::sync::mpsc::UnboundedSender<String>,
        peer_id: &str,
        msg: DataChannelMessage,
    ) {
        // Text and binary frames are treated alike: anything that decodes as
        // UTF-8 is forwarded.
        match String::from_utf8(msg.data.to_vec()) {
            Ok(text) => {
                // A send error only means the receiver is gone; nothing is
                // left to deliver to.
                let _ = tx.send(text);
            }
            Err(_) => {
                tracing::warn!("Dropping non-UTF-8 WebRTC message from peer {}", peer_id);
            }
        }
    }

    let mut media_engine = MediaEngine::default();
    media_engine
        .register_default_codecs()
        .map_err(|e| GunError::WebRTC(format!("Failed to register codecs: {}", e)))?;
    let registry = register_default_interceptors(Registry::new(), &mut media_engine)
        .map_err(|e| GunError::WebRTC(format!("Failed to register interceptors: {}", e)))?;
    let api = APIBuilder::new()
        .with_media_engine(media_engine)
        .with_interceptor_registry(registry)
        .build();

    let rtc_config = RTCConfiguration {
        ice_servers: config.ice_servers.clone(),
        ..Default::default()
    };
    let pc = Arc::new(api.new_peer_connection(rtc_config).await.map_err(|e| {
        GunError::Network(format!("Failed to create RTCPeerConnection: {}", e))
    })?);
    let data_channel = pc
        .create_data_channel("dc", Some(config.data_channel.clone()))
        .await
        .map_err(|e| GunError::Network(format!("Failed to create data channel: {}", e)))?;

    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    // Messages arriving on the channel this side created.
    let tx_local = tx.clone();
    let id_local = peer_id.clone();
    data_channel.on_message(Box::new(move |msg: DataChannelMessage| {
        forward_payload(&tx_local, &id_local, msg);
        Box::pin(async {})
    }));

    // Messages arriving on channels the remote side created. Both ends create
    // their own "dc" channel and send on it, so without this handler nothing
    // the remote sends would ever be read.
    let tx_remote = tx.clone();
    let id_remote = peer_id.clone();
    pc.on_data_channel(Box::new(
        move |channel: Arc<webrtc::data_channel::RTCDataChannel>| {
            let tx = tx_remote.clone();
            let id = id_remote.clone();
            Box::pin(async move {
                channel.on_message(Box::new(move |msg: DataChannelMessage| {
                    forward_payload(&tx, &id, msg);
                    Box::pin(async {})
                }));
            })
        },
    ));

    let id_for_state = peer_id.clone();
    pc.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
        tracing::info!("WebRTC peer {} connection state: {:?}", id_for_state, s);
        Box::pin(async {})
    }));

    let id_for_candidates = peer_id.clone();
    pc.on_ice_candidate(Box::new(
        move |candidate: Option<webrtc::ice_transport::ice_candidate::RTCIceCandidate>| {
            let id = id_for_candidates.clone();
            Box::pin(async move {
                if let Some(candidate) = candidate {
                    tracing::debug!("ICE candidate for peer {}: {:?}", id, candidate);
                }
            })
        },
    ));

    Ok((
        Self {
            peer_id,
            pc,
            data_channel,
            message_sender: tx,
        },
        rx,
    ))
}
/// Sends `message` to the remote peer over the local data channel.
///
/// The payload is sent as a text frame rather than a binary one, so that
/// receivers which distinguish the two (for example browsers) get a string.
///
/// # Errors
/// Returns `GunError::Network` if the data channel rejects the message.
pub async fn send(&self, message: &str) -> GunResult<()> {
    self.data_channel
        .send_text(message.to_string())
        .await
        .map_err(|e| GunError::Network(format!("Failed to send WebRTC message: {}", e)))?;
    Ok(())
}
/// Creates an SDP offer, installs it as the local description and returns the
/// description to signal to the remote peer.
///
/// Locally gathered ICE candidates are only logged by this type, never
/// signalled one by one. This method therefore waits for ICE gathering to
/// finish and returns the local description as it stands afterwards, which
/// carries the gathered candidates. If no local description is available
/// after gathering, the freshly created offer is returned instead.
///
/// # Errors
/// Returns `GunError::Network` if the offer cannot be created or applied.
pub async fn create_offer(&self) -> GunResult<RTCSessionDescription> {
    let offer = self
        .pc
        .create_offer(None)
        .await
        .map_err(|e| GunError::Network(format!("Failed to create offer: {}", e)))?;

    // Subscribe before applying the description so completion cannot be missed.
    let mut gathering_done = self.pc.gathering_complete_promise().await;
    self.pc
        .set_local_description(offer.clone())
        .await
        .map_err(|e| GunError::Network(format!("Failed to set local description: {}", e)))?;
    // `None` just means the notifier was dropped; either way gathering is over.
    let _ = gathering_done.recv().await;

    Ok(self.pc.local_description().await.unwrap_or(offer))
}
/// Creates an SDP answer, installs it as the local description and returns
/// the description to signal to the remote peer.
///
/// Locally gathered ICE candidates are only logged by this type, never
/// signalled one by one. This method therefore waits for ICE gathering to
/// finish and returns the local description as it stands afterwards, which
/// carries the gathered candidates. If no local description is available
/// after gathering, the freshly created answer is returned instead.
///
/// # Errors
/// Returns `GunError::Network` if the answer cannot be created or applied.
pub async fn create_answer(&self) -> GunResult<RTCSessionDescription> {
    let answer = self
        .pc
        .create_answer(None)
        .await
        .map_err(|e| GunError::Network(format!("Failed to create answer: {}", e)))?;

    // Subscribe before applying the description so completion cannot be missed.
    let mut gathering_done = self.pc.gathering_complete_promise().await;
    self.pc
        .set_local_description(answer.clone())
        .await
        .map_err(|e| GunError::Network(format!("Failed to set local description: {}", e)))?;
    // `None` just means the notifier was dropped; either way gathering is over.
    let _ = gathering_done.recv().await;

    Ok(self.pc.local_description().await.unwrap_or(answer))
}
/// Applies the remote peer's session description to the connection.
///
/// # Errors
/// Returns `GunError::Network` if the connection rejects the description.
pub async fn set_remote_description(&self, desc: RTCSessionDescription) -> GunResult<()> {
    let outcome = self.pc.set_remote_description(desc).await;
    match outcome {
        Ok(()) => Ok(()),
        Err(e) => Err(GunError::Network(format!(
            "Failed to set remote description: {}",
            e
        ))),
    }
}
/// Hands a remote ICE candidate to the connection.
///
/// # Errors
/// Returns `GunError::Network` if the connection rejects the candidate.
pub async fn add_ice_candidate(&self, candidate: RTCIceCandidateInit) -> GunResult<()> {
    let outcome = self.pc.add_ice_candidate(candidate).await;
    match outcome {
        Ok(()) => Ok(()),
        Err(e) => Err(GunError::Network(format!(
            "Failed to add ICE candidate: {}",
            e
        ))),
    }
}
/// Closes the data channel and then the peer connection.
///
/// Both closes are always attempted, so a failure to close the data channel
/// does not leave the peer connection open.
///
/// # Errors
/// Returns `GunError::Network` describing the first close that failed (the
/// data channel's error takes precedence).
pub async fn close(&self) -> GunResult<()> {
    let channel_closed = self
        .data_channel
        .close()
        .await
        .map_err(|e| GunError::Network(format!("Failed to close data channel: {}", e)));
    let connection_closed = self
        .pc
        .close()
        .await
        .map_err(|e| GunError::Network(format!("Failed to close peer connection: {}", e)));
    channel_closed.and(connection_closed)
}
pub async fn connection_state(&self) -> RTCPeerConnectionState {
self.pc.connection_state()
}
}
/// Tracks WebRTC peers and handles the signalling messages ("rtc" messages)
/// that set them up.
pub struct WebRTCManager {
/// Core handle; only used to generate `pid` at construction time.
#[allow(dead_code)] core: Arc<GunCore>,
/// Mesh used to deliver signalling messages and to receive messages that
/// arrive over WebRTC.
mesh: Arc<Mesh>,
/// Options applied to every peer created by this manager.
options: WebRTCOptions,
/// Peers keyed by remote peer id.
peers: Arc<RwLock<HashMap<String, WebRTCPeer>>>,
/// This node's own id; messages carrying it are ignored as self-originated.
pub(crate) pid: String, }
impl WebRTCManager {
/// Creates a manager with no peers and a freshly generated 9-character id
/// obtained from `core`.
pub fn new(core: Arc<GunCore>, mesh: Arc<Mesh>, options: WebRTCOptions) -> Self {
    let peers = Arc::new(RwLock::new(HashMap::new()));
    Self {
        // Evaluated before `core` is moved into the struct below.
        pid: core.random_id(9),
        core,
        mesh,
        options,
        peers,
    }
}
/// Returns this node's own peer id.
pub fn pid(&self) -> &str {
    self.pid.as_str()
}
/// Dispatches an incoming signalling message of the shape
/// `{"ok": {"rtc": {"id": ..., "candidate" | "answer" | "offer": ...}}}`.
///
/// * Messages without an `ok.rtc` object are ignored.
/// * Messages whose `id` equals this node's own id are ignored.
/// * A `candidate`, `answer` or `offer` field (checked in that order) selects
///   the matching handler. A field that is present but JSON `null` counts as
///   absent.
/// * A message carrying only an `id` starts an outbound connection to that
///   peer.
///
/// # Errors
/// Returns `GunError::InvalidData` when `id` is missing or not a string, and
/// propagates any error from the selected handler.
pub async fn handle_rtc_message(&self, msg: &Value) -> GunResult<()> {
    let rtc = match msg.get("ok").and_then(|ok| ok.get("rtc")) {
        Some(rtc) => rtc,
        None => return Ok(()),
    };
    let peer_id = rtc
        .get("id")
        .and_then(|v| v.as_str())
        .ok_or_else(|| GunError::InvalidData("Missing RTC peer ID".to_string()))?;
    if peer_id == self.pid {
        return Ok(());
    }

    // Serialising an `Option` as `null` must not be mistaken for a payload.
    let has = |key: &str| rtc.get(key).map_or(false, |v| !v.is_null());

    if has("candidate") {
        self.handle_ice_candidate(peer_id, rtc).await
    } else if has("answer") {
        self.handle_answer(peer_id, rtc).await
    } else if has("offer") {
        self.handle_offer(peer_id, rtc).await
    } else {
        // Only an id was announced: the peer is inviting a connection.
        self.initiate_connection(peer_id).await
    }
}
async fn handle_ice_candidate(&self, peer_id: &str, rtc: &Value) -> GunResult<()> {
let peers = self.peers.read().await;
if peers.get(peer_id).is_some() {
let _candidate_json = rtc.get("candidate").unwrap();
tracing::debug!("Received ICE candidate for peer {}", peer_id);
}
Ok(())
}
/// Applies the SDP answer in `rtc["answer"]["sdp"]` to the tracked peer.
///
/// An answer for a peer that is not tracked is ignored. Literal `\r\n` escape
/// sequences in the SDP are converted to real CRLF line endings, matching the
/// treatment of incoming offers.
///
/// # Errors
/// * `GunError::InvalidData` when the answer carries no SDP string.
/// * `GunError::WebRTC` when the SDP cannot be parsed.
/// * Any error from applying the description to the connection.
async fn handle_answer(&self, peer_id: &str, rtc: &Value) -> GunResult<()> {
    let peers = self.peers.read().await;
    let peer = match peers.get(peer_id) {
        Some(peer) => peer,
        None => return Ok(()),
    };

    let sdp = rtc
        .get("answer")
        .and_then(|answer| answer.get("sdp"))
        .and_then(|v| v.as_str())
        .map(|s| s.replace("\\r\\n", "\r\n"))
        .ok_or_else(|| GunError::InvalidData("Missing answer SDP".to_string()))?;
    let desc = RTCSessionDescription::answer(sdp)
        .map_err(|e| GunError::WebRTC(format!("Failed to parse answer SDP: {}", e)))?;

    peer.set_remote_description(desc).await
}
async fn handle_offer(&self, peer_id: &str, rtc: &Value) -> GunResult<()> {
let should_create = {
let peers = self.peers.read().await;
!peers.contains_key(peer_id)
};
if should_create {
let peer_id_for_task = peer_id.to_string();
let options_clone = self.options.clone();
let (peer, mut rx) = WebRTCPeer::new(peer_id_for_task.clone(), &options_clone).await?;
let mesh_clone = self.mesh.clone();
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Err(e) = mesh_clone.hear(&msg, None).await {
tracing::error!("Error forwarding WebRTC message to mesh: {}", e);
}
}
});
let mut peers = self.peers.write().await;
peers.insert(peer_id_for_task, peer);
}
let peer_exists = {
let peers = self.peers.read().await;
peers.get(peer_id).is_some()
};
if peer_exists {
let offer_json = rtc.get("offer").unwrap();
let sdp_str = offer_json
.get("sdp")
.and_then(|v| v.as_str())
.map(|s| s.replace("\\r\\n", "\r\n"))
.unwrap_or_default();
let desc = RTCSessionDescription::offer(sdp_str)
.map_err(|e| GunError::WebRTC(format!("Failed to parse offer SDP: {}", e)))?;
let peer_id_clone = peer_id.to_string();
let peers = self.peers.read().await;
if let Some(peer) = peers.get(peer_id) {
peer.set_remote_description(desc).await?;
let answer = peer.create_answer().await?;
drop(peers); self.send_rtc_message(&peer_id_clone, "answer", &answer)
.await?;
}
}
Ok(())
}
async fn initiate_connection(&self, peer_id: &str) -> GunResult<()> {
let should_create = {
let peers = self.peers.read().await;
!peers.contains_key(peer_id) && peers.len() < self.options.max_connections
};
if !should_create {
let peers = self.peers.read().await;
if peers.contains_key(peer_id) {
return Ok(());
}
if peers.len() >= self.options.max_connections {
tracing::warn!("WebRTC connection limit reached, skipping peer {}", peer_id);
return Ok(());
}
}
let (peer, mut rx) = WebRTCPeer::new(peer_id.to_string(), &self.options).await?;
let mesh_clone = self.mesh.clone();
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Err(e) = mesh_clone.hear(&msg, None).await {
tracing::error!("Error forwarding WebRTC message to mesh: {}", e);
}
}
});
let offer = peer.create_offer().await?;
{
let mut peers = self.peers.write().await;
peers.insert(peer_id.to_string(), peer);
}
self.send_rtc_message(peer_id, "offer", &offer).await?;
Ok(())
}
/// Sends an offer or answer to `peer_id` through the mesh.
///
/// The message has the shape
/// `{"ok": {"rtc": {"id": <own id>, <msg_type>: {"type": <msg_type>, "sdp": ...}}}}`.
/// It is only sent when the mesh knows `peer_id`; otherwise the call is a
/// no-op.
///
/// # Errors
/// * `GunError::InvalidData` when `msg_type` is neither "offer" nor "answer".
/// * `GunError::Serialization` when the message cannot be serialised.
/// * Any error from the mesh while sending.
async fn send_rtc_message(
    &self,
    peer_id: &str,
    msg_type: &str,
    sdp: &RTCSessionDescription,
) -> GunResult<()> {
    let kind = match msg_type {
        "offer" | "answer" => msg_type,
        _ => {
            return Err(GunError::InvalidData(format!(
                "Unknown RTC message type: {}",
                msg_type
            )))
        }
    };

    let mut envelope = serde_json::json!({
        "ok": {
            "rtc": {
                "id": self.pid,
            }
        }
    });
    // The payload key and its "type" field are both the message kind.
    envelope["ok"]["rtc"][kind] = serde_json::json!({
        "type": kind,
        "sdp": sdp.sdp
    });

    let payload = serde_json::to_string(&envelope).map_err(GunError::Serialization)?;
    if self.mesh.get_peer(peer_id).await.is_none() {
        return Ok(());
    }
    self.mesh.send_to_peer_by_id(&payload, peer_id).await
}
/// Sends `message` to `peer_id`, using the WebRTC connection when it is in
/// the `Connected` state and the mesh otherwise.
///
/// # Errors
/// Propagates the error of whichever transport was used; a failed WebRTC send
/// is not retried over the mesh.
pub async fn send_message(&self, peer_id: &str, message: &str) -> GunResult<()> {
    let registry = self.peers.read().await;

    let direct = match registry.get(peer_id) {
        Some(peer) => {
            let state = peer.connection_state().await;
            if matches!(state, RTCPeerConnectionState::Connected) {
                Some(peer)
            } else {
                None
            }
        }
        None => None,
    };

    match direct {
        Some(peer) => peer.send(message).await,
        None => self.mesh.send_to_peer_by_id(message, peer_id).await,
    }
}
}
/// Typed form of a signalling message: `{"ok": {"rtc": {...}}}`.
#[derive(Serialize, Deserialize, Debug)]
pub struct RTCMessage {
/// The `ok` wrapper object.
pub ok: RTCMessageOk,
}
/// The `ok` object of a signalling message.
#[derive(Serialize, Deserialize, Debug)]
pub struct RTCMessageOk {
/// The signalling payload.
pub rtc: RTCMessageRTC,
}
/// The `rtc` payload of a signalling message: the sender's id plus an
/// optional offer, answer, or ICE candidate.
#[derive(Serialize, Deserialize, Debug)]
pub struct RTCMessageRTC {
/// Id of the sending peer.
pub id: String,
/// SDP offer, if any.
// NOTE(review): a `None` here serialises as an explicit JSON `null`, not as
// an absent key — confirm that receivers treat the two alike.
pub offer: Option<RTCMessageSDP>,
/// SDP answer, if any.
pub answer: Option<RTCMessageSDP>,
/// ICE candidate, if any.
pub candidate: Option<RTCMessageCandidate>,
}
/// A session description as carried in a signalling message.
#[derive(Serialize, Deserialize, Debug)]
pub struct RTCMessageSDP {
/// Description kind; serialised under the JSON key `type`.
#[serde(rename = "type")]
pub sdp_type: String,
/// The SDP text.
pub sdp: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct RTCMessageCandidate {
pub candidate: String,
pub sdp_mid: Option<String>,
pub sdp_m_line_index: Option<u16>,
}