use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast};
use super::super::identity::ConversationId;
use super::session::WebRtcSessionId;
/// A WebRTC signaling message exchanged with a peer for one session.
///
/// Every variant carries the [`WebRtcSessionId`] it belongs to (retrievable via
/// [`SignalingMessage::session_id`]). No `#[serde(...)]` attributes are set, so the
/// enum uses serde's default externally-tagged representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SignalingMessage {
    /// An SDP offer; `sdp` holds the session description text.
    Offer {
        session_id: WebRtcSessionId,
        sdp: String,
    },
    /// An SDP answer; `sdp` holds the session description text.
    Answer {
        session_id: WebRtcSessionId,
        sdp: String,
    },
    /// A single ICE candidate line.
    ///
    /// NOTE(review): `sdp_mid` / `sdp_mline_index` presumably mirror the `sdpMid` /
    /// `sdpMLineIndex` fields of a browser `RTCIceCandidate` — confirm with the peer side.
    IceCandidate {
        session_id: WebRtcSessionId,
        candidate: String,
        sdp_mid: Option<String>,
        sdp_mline_index: Option<u16>,
    },
    /// Indicates the sender has finished gathering ICE candidates for the session.
    IceGatheringComplete { session_id: WebRtcSessionId },
}
impl SignalingMessage {
pub fn session_id(&self) -> &WebRtcSessionId {
match self {
SignalingMessage::Offer { session_id, .. } => session_id,
SignalingMessage::Answer { session_id, .. } => session_id,
SignalingMessage::IceCandidate { session_id, .. } => session_id,
SignalingMessage::IceGatheringComplete { session_id } => session_id,
}
}
}
/// Transport for exchanging [`SignalingMessage`]s with the peer behind a conversation.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait WebRtcSignaling: Send + Sync {
    /// Sends `message` toward the peer(s) reachable through `target`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; an implementation may refuse outbound sending entirely.
    async fn send_signaling(
        &self,
        target: &ConversationId,
        message: SignalingMessage,
    ) -> Result<()>;
    /// Waits for the next signaling message addressed to `target`.
    ///
    /// Returns `Ok(None)` when the implementation cannot deliver a message (for the
    /// implementations in this module, e.g. when the underlying channel is closed).
    async fn receive_signaling(&self, target: &ConversationId) -> Result<Option<SignalingMessage>>;
}
/// In-process signaling transport backed by a single [`broadcast`] channel.
///
/// Each message is published together with its target [`ConversationId`];
/// receivers filter for the conversation they are interested in. Clones share
/// the same underlying channel.
#[derive(Clone)]
pub struct BroadcastSignaling {
    /// Sending half shared by all clones; receivers are created on demand.
    tx: broadcast::Sender<(ConversationId, SignalingMessage)>,
}

impl BroadcastSignaling {
    /// Creates a transport whose channel retains up to `capacity` messages.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero, as required by [`broadcast::channel`].
    pub fn new(capacity: usize) -> Self {
        // The initial receiver is discarded; consumers obtain their own via `subscribe`.
        let tx = broadcast::channel(capacity).0;
        BroadcastSignaling { tx }
    }

    /// Returns a receiver observing every message sent after this call, across
    /// all conversations.
    pub fn subscribe(&self) -> broadcast::Receiver<(ConversationId, SignalingMessage)> {
        let sender = &self.tx;
        sender.subscribe()
    }
}
#[async_trait]
impl WebRtcSignaling for BroadcastSignaling {
async fn send_signaling(
&self,
target: &ConversationId,
message: SignalingMessage,
) -> Result<()> {
let _ = self.tx.send((target.clone(), message));
Ok(())
}
async fn receive_signaling(&self, target: &ConversationId) -> Result<Option<SignalingMessage>> {
let mut rx = self.tx.subscribe();
loop {
match rx.recv().await {
Ok((conv, msg)) if &conv == target => return Ok(Some(msg)),
Ok(_) => continue, Err(broadcast::error::RecvError::Closed) => return Ok(None),
Err(broadcast::error::RecvError::Lagged(_)) => continue,
}
}
}
}
/// Signaling transport fed by payloads that arrive through a channel adapter.
///
/// Inbound JSON payloads are pushed in with [`ChannelMessageSignaling::inject`] and
/// handed to tasks waiting in `receive_signaling`. Outbound sending is not supported
/// by this type; the channel adapter is expected to send serialized messages itself.
pub struct ChannelMessageSignaling {
    // Per-conversation broadcast senders keyed by "<platform>::<channel_id>".
    // Entries are created lazily by `receive_signaling`; nothing in this module removes them.
    queues: Arc<RwLock<HashMap<String, broadcast::Sender<SignalingMessage>>>>,
}
/// Metadata key reserved for WebRTC signaling payloads.
///
/// NOTE(review): not referenced in this module; presumably channel adapters use it to
/// tag or detect messages whose body is a JSON `SignalingMessage` — confirm with callers.
pub const SIGNALING_METADATA_KEY: &str = "_bw_webrtc_signaling";
impl ChannelMessageSignaling {
pub fn new() -> Self {
Self {
queues: Arc::new(RwLock::new(HashMap::new())),
}
}
fn conv_key(conv: &ConversationId) -> String {
format!("{}::{}", conv.platform, conv.channel_id)
}
pub async fn inject(&self, conv: &ConversationId, json_payload: &str) -> Result<()> {
let msg: SignalingMessage = serde_json::from_str(json_payload)?;
let key = Self::conv_key(conv);
let queues: tokio::sync::RwLockReadGuard<
'_,
HashMap<String, broadcast::Sender<SignalingMessage>>,
> = self.queues.read().await;
if let Some(tx) = queues.get(&key) {
let _ = tx.send(msg);
}
Ok(())
}
}
impl Default for ChannelMessageSignaling {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl WebRtcSignaling for ChannelMessageSignaling {
async fn send_signaling(
&self,
_target: &ConversationId,
_message: SignalingMessage,
) -> Result<()> {
Err(anyhow::anyhow!(
"ChannelMessageSignaling outbound sending must be handled by the channel adapter; \
serialize the SignalingMessage as JSON and send it via Channel::send_message"
))
}
async fn receive_signaling(&self, target: &ConversationId) -> Result<Option<SignalingMessage>> {
let key = Self::conv_key(target);
let tx = {
let mut queues = self.queues.write().await;
queues
.entry(key)
.or_insert_with(|| broadcast::channel(64).0)
.clone()
};
let mut rx = tx.subscribe();
match rx.recv().await {
Ok(msg) => Ok(Some(msg)),
Err(broadcast::error::RecvError::Closed) => Ok(None),
Err(broadcast::error::RecvError::Lagged(_)) => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    fn test_conv() -> ConversationId {
        ConversationId {
            platform: "test".to_string(),
            channel_id: "chan-1".to_string(),
            server_id: None,
        }
    }

    fn test_session_id() -> WebRtcSessionId {
        WebRtcSessionId(Uuid::new_v4())
    }

    #[test]
    fn signaling_message_serde_roundtrip() {
        let msg = SignalingMessage::IceCandidate {
            session_id: test_session_id(),
            candidate: "candidate:1 1 UDP 2130706431 192.168.1.1 5000 typ host".to_string(),
            sdp_mid: Some("audio".to_string()),
            sdp_mline_index: Some(0),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let rt: SignalingMessage = serde_json::from_str(&json).unwrap();
        match rt {
            SignalingMessage::IceCandidate { sdp_mid, .. } => {
                assert_eq!(sdp_mid, Some("audio".to_string()));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn session_id_is_exposed_for_every_variant() {
        let sid = test_session_id();
        let messages = [
            SignalingMessage::Offer {
                session_id: sid.clone(),
                sdp: String::new(),
            },
            SignalingMessage::Answer {
                session_id: sid.clone(),
                sdp: String::new(),
            },
            SignalingMessage::IceCandidate {
                session_id: sid.clone(),
                candidate: String::new(),
                sdp_mid: None,
                sdp_mline_index: None,
            },
            SignalingMessage::IceGatheringComplete {
                session_id: sid.clone(),
            },
        ];
        for msg in &messages {
            assert_eq!(msg.session_id(), &sid);
        }
    }

    #[tokio::test]
    async fn broadcast_signaling_loopback() {
        let sig = BroadcastSignaling::new(16);
        let conv = test_conv();
        let sid = test_session_id();
        let sig2 = sig.clone();
        let conv2 = conv.clone();
        let sid2 = sid.clone();
        let handle =
            tokio::spawn(async move { sig2.receive_signaling(&conv2).await.unwrap().unwrap() });
        // Let the spawned task subscribe before anything is sent.
        tokio::task::yield_now().await;
        sig.send_signaling(
            &conv,
            SignalingMessage::Offer {
                session_id: sid,
                sdp: "v=0\r\n...".to_string(),
            },
        )
        .await
        .unwrap();
        let received = handle.await.unwrap();
        assert_eq!(received.session_id(), &sid2);
    }

    #[tokio::test]
    async fn channel_message_signaling_inject_reaches_waiting_receiver() {
        let sig = Arc::new(ChannelMessageSignaling::new());
        let conv = test_conv();
        let sid = test_session_id();
        let receiver = Arc::clone(&sig);
        let conv2 = conv.clone();
        let handle = tokio::spawn(async move { receiver.receive_signaling(&conv2).await.unwrap() });
        // Let the spawned task register its queue and subscribe before injecting.
        tokio::task::yield_now().await;
        let payload = serde_json::to_string(&SignalingMessage::IceGatheringComplete {
            session_id: sid.clone(),
        })
        .unwrap();
        sig.inject(&conv, &payload).await.unwrap();
        let received = handle.await.unwrap().expect("injected message should be delivered");
        assert_eq!(received.session_id(), &sid);
    }

    #[tokio::test]
    async fn channel_message_signaling_rejects_invalid_payload() {
        let sig = ChannelMessageSignaling::new();
        assert!(sig.inject(&test_conv(), "not json").await.is_err());
    }

    #[tokio::test]
    async fn channel_message_signaling_refuses_outbound_send() {
        let sig = ChannelMessageSignaling::new();
        let msg = SignalingMessage::IceGatheringComplete {
            session_id: test_session_id(),
        };
        assert!(sig.send_signaling(&test_conv(), msg).await.is_err());
    }
}