use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::net::SocketAddr;
use std::str::FromStr;
use thiserror::Error;
/// Errors describing signaling failures.
///
/// Every variant wraps a `String` that is interpolated into the `Display`
/// message generated by `thiserror` from the `#[error(...)]` attribute.
// NOTE(review): nothing in the visible part of this file constructs these
// variants; what the wrapped strings contain depends on callers elsewhere.
#[derive(Error, Debug)]
pub enum SignalingError {
/// Reports an invalid SDP; displayed as `Invalid SDP: <detail>`.
#[error("Invalid SDP: {0}")]
InvalidSdp(String),
/// Reports a missing session; displayed as `Session not found: <detail>`.
// NOTE(review): the wrapped string is presumably the session id — confirm.
#[error("Session not found: {0}")]
SessionNotFound(String),
/// Reports a transport failure; displayed as `Transport error: <detail>`.
#[error("Transport error: {0}")]
TransportError(String),
}
/// Abstraction over the channel used to exchange [`SignalingMessage`]s with peers.
///
/// Implementors must be `Send + Sync`. The methods are declared `async` and
/// rewritten by the `async_trait` attribute macro.
#[async_trait]
pub trait SignalingTransport: Send + Sync {
/// Identifier used to address a peer.
///
/// It must be cloneable, `Send + Sync`, printable through both `Debug` and
/// `Display`, and parseable from a string via `FromStr`.
type PeerId: Clone + Send + Sync + fmt::Debug + fmt::Display + FromStr;
/// Error type returned by every operation of this transport.
type Error: std::error::Error + Send + Sync + 'static;
/// Sends `message` to `peer`.
///
/// # Errors
///
/// Returns `Self::Error` when the implementation reports a failure.
async fn send_message(
&self,
peer: &Self::PeerId,
message: SignalingMessage,
) -> Result<(), Self::Error>;
/// Receives one message, returned together with a peer identifier.
///
/// # Errors
///
/// Returns `Self::Error` when the implementation reports a failure.
// NOTE(review): the returned `PeerId` is presumably the sender of the message,
// and whether this waits for a message or fails when none is pending is left
// to the implementation — confirm against the concrete transports.
async fn receive_message(&self) -> Result<(Self::PeerId, SignalingMessage), Self::Error>;
/// Looks up a socket address for `peer`.
///
/// The success value is optional; `Ok(None)` presumably means that no
/// endpoint is known for the peer (TODO confirm with implementors).
///
/// # Errors
///
/// Returns `Self::Error` when the implementation reports a failure.
async fn discover_peer_endpoint(
&self,
peer: &Self::PeerId,
) -> Result<Option<SocketAddr>, Self::Error>;
}
/// Messages exchanged over a [`SignalingTransport`].
///
/// Serialized with serde as an internally tagged enum: the variant is stored
/// in a `type` field whose value is the variant name in lowercase (`offer`,
/// `answer`, `icecandidate`, `icecomplete`, `bye`). Every variant carries a
/// `session_id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SignalingMessage {
/// Session offer carrying an SDP string.
Offer {
/// Identifier of the session this message belongs to.
session_id: String,
/// SDP payload, stored as an opaque string (not validated here).
sdp: String,
/// Optional socket address.
// NOTE(review): presumably the sender's QUIC endpoint — confirm with callers.
quic_endpoint: Option<SocketAddr>,
},
/// Session answer carrying an SDP string.
Answer {
/// Identifier of the session this message belongs to.
session_id: String,
/// SDP payload, stored as an opaque string (not validated here).
sdp: String,
/// Optional socket address.
// NOTE(review): presumably the sender's QUIC endpoint — confirm with callers.
quic_endpoint: Option<SocketAddr>,
},
/// A single ICE candidate.
IceCandidate {
/// Identifier of the session this message belongs to.
session_id: String,
/// Candidate description, stored as an opaque string.
candidate: String,
/// Optional `sdpMid`-style media identifier (assumed from the name — TODO confirm).
sdp_mid: Option<String>,
/// Optional `sdpMLineIndex`-style media line index (assumed from the name — TODO confirm).
sdp_mline_index: Option<u16>,
},
/// Marker message that carries only the session id.
// NOTE(review): presumably signals the end of ICE candidate gathering — confirm.
IceComplete {
/// Identifier of the session this message belongs to.
session_id: String,
},
/// Session termination notice with an optional reason.
Bye {
/// Identifier of the session this message belongs to.
session_id: String,
/// Optional free-form reason string.
reason: Option<String>,
},
}
impl SignalingMessage {
    /// Returns the session identifier carried by this message.
    ///
    /// Every variant has a `session_id` field, so the lookup cannot fail.
    #[must_use]
    pub fn session_id(&self) -> &str {
        // All variants bind the same field, so a single irrefutable
        // or-pattern covers the whole enum.
        let (Self::Offer { session_id, .. }
        | Self::Answer { session_id, .. }
        | Self::IceCandidate { session_id, .. }
        | Self::IceComplete { session_id }
        | Self::Bye { session_id, .. }) = self;
        session_id.as_str()
    }
}
/// Handler that owns a shared reference to a [`SignalingTransport`].
pub struct SignalingHandler<T>
where
    T: SignalingTransport,
{
    /// Reference-counted handle to the underlying transport.
    transport: std::sync::Arc<T>,
}
impl<T: SignalingTransport> SignalingHandler<T> {
#[must_use]
pub fn new(transport: std::sync::Arc<T>) -> Self {
Self { transport }
}
pub async fn send_message(
&self,
peer: &T::PeerId,
message: SignalingMessage,
) -> Result<(), T::Error> {
self.transport.send_message(peer, message).await
}
pub async fn receive_message(&self) -> Result<(T::PeerId, SignalingMessage), T::Error> {
self.transport.receive_message().await
}
pub async fn discover_peer_endpoint(
&self,
peer: &T::PeerId,
) -> Result<Option<std::net::SocketAddr>, T::Error> {
self.transport.discover_peer_endpoint(peer).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory transport backed by a single FIFO queue.
    ///
    /// Sent messages are appended to the queue and received messages are
    /// popped from its front, so tests can inspect or pre-load it directly.
    struct MockTransport {
        messages: Mutex<VecDeque<(String, SignalingMessage)>>,
    }

    /// Error produced by the mock when its queue is empty.
    #[derive(Debug)]
    struct MockError;

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("Mock error")
        }
    }

    impl std::error::Error for MockError {}

    impl MockTransport {
        /// Builds a mock with an empty queue.
        fn new() -> Self {
            let messages = Mutex::new(VecDeque::new());
            Self { messages }
        }

        /// Appends a `(peer, message)` pair to the back of the queue.
        fn add_message(&self, peer: String, message: SignalingMessage) {
            let mut queue = self.messages.lock().unwrap();
            queue.push_back((peer, message));
        }
    }

    #[async_trait]
    impl SignalingTransport for MockTransport {
        type PeerId = String;
        type Error = MockError;

        async fn send_message(
            &self,
            peer: &String,
            message: SignalingMessage,
        ) -> Result<(), MockError> {
            self.add_message(peer.clone(), message);
            Ok(())
        }

        async fn receive_message(&self) -> Result<(String, SignalingMessage), MockError> {
            // The guard is a temporary of this statement, so the lock is
            // released before the result is built.
            let next = self.messages.lock().unwrap().pop_front();
            next.ok_or(MockError)
        }

        async fn discover_peer_endpoint(
            &self,
            _peer: &String,
        ) -> Result<Option<std::net::SocketAddr>, MockError> {
            let endpoint: std::net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
            Ok(Some(endpoint))
        }
    }

    #[tokio::test]
    async fn test_signaling_handler_send_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(Arc::clone(&transport));
        let peer = String::from("peer1");
        let offer = SignalingMessage::Offer {
            session_id: String::from("test-session"),
            sdp: String::from("test-sdp"),
            quic_endpoint: None,
        };

        let outcome = handler.send_message(&peer, offer.clone()).await;
        assert!(outcome.is_ok());

        // The mock stores sent messages, so the queue must now hold the offer.
        let queued = transport.messages.lock().unwrap().pop_front();
        assert_eq!(queued, Some((peer, offer)));
    }

    #[tokio::test]
    async fn test_signaling_handler_receive_message() {
        let transport = Arc::new(MockTransport::new());
        let handler = SignalingHandler::new(Arc::clone(&transport));
        let answer = SignalingMessage::Answer {
            session_id: String::from("test-session"),
            sdp: String::from("test-sdp"),
            quic_endpoint: None,
        };
        transport.add_message(String::from("peer1"), answer.clone());

        let outcome = handler.receive_message().await;
        assert!(outcome.is_ok());

        let (peer, received_message) = outcome.unwrap();
        assert_eq!(peer, "peer1");
        assert_eq!(received_message, answer);
    }

    #[tokio::test]
    async fn test_signaling_handler_discover_endpoint() {
        let handler = SignalingHandler::new(Arc::new(MockTransport::new()));
        let peer = String::from("peer1");

        let outcome = handler.discover_peer_endpoint(&peer).await;
        assert!(outcome.is_ok());

        let expected: std::net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        assert_eq!(outcome.unwrap(), Some(expected));
    }
}