use anyhow::{Context, Result};
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
response::Response,
routing::get,
Router,
};
use futures_util::{SinkExt, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tower_http::cors::CorsLayer;
use tracing::{debug, error, info, warn};
use crate::webrtc::{SignalingMessage, WebRTCConnection, WebRTCServer};
/// Number of signaling clients the server is meant to serve at once.
///
/// In the code visible here this value is only interpolated into log
/// messages; the single-connection policy itself is applied in
/// `handle_socket`, which signals the previously registered connection to
/// shut down whenever a new client arrives.
const MAX_CONNECTIONS: usize = 1;
/// Handle stored for the currently registered signaling connection.
///
/// A message sent on `shutdown_tx` is received by that connection's
/// message loop in `handle_socket`, which then breaks out and tears the
/// connection down.
struct ConnectionHandle {
// Sending half of the connection's shutdown channel; `handle_socket` uses
// `try_send(())` on it when a newer client replaces this one.
shutdown_tx: mpsc::Sender<()>,
}
/// Shared state handed to every request handler through axum's `State`
/// extractor.
///
/// Every field is an `Arc`, so the derived `Clone` yields another handle to
/// the same underlying server, counter and connection slot.
#[derive(Clone)]
pub struct SignalingServerState {
/// WebRTC server used to create one connection per signaling client.
pub webrtc_server: Arc<WebRTCServer>,
/// Count of signaling sockets currently being handled; incremented when a
/// socket handler starts and decremented when it finishes.
pub active_connections: Arc<AtomicUsize>,
// Slot holding the shutdown handle of the connection that is currently
// registered, or `None` when no connection is registered.
current_connection: Arc<Mutex<Option<ConnectionHandle>>>,
}
/// Binds the signaling HTTP server on `0.0.0.0:<port>` and serves it until
/// the server future completes.
///
/// The router exposes a single `GET /webrtc` route that upgrades to a
/// WebSocket, wrapped in a permissive CORS layer. A fresh
/// [`SignalingServerState`] (zero active connections, no registered
/// connection) is created around `webrtc_server`.
///
/// # Errors
///
/// Returns an error with context `"Failed to bind signaling server"` if the
/// listener cannot be bound, or `"Signaling server error"` if serving fails.
pub async fn start_signaling_server(port: u16, webrtc_server: Arc<WebRTCServer>) -> Result<()> {
    info!(
        "🔧 Starting WebRTC signaling server on port {} (single-connection mode)",
        port
    );

    // Build the router with its shared state constructed in place.
    let router = Router::new()
        .route("/webrtc", get(websocket_handler))
        .layer(CorsLayer::permissive())
        .with_state(SignalingServerState {
            webrtc_server,
            active_connections: Arc::new(AtomicUsize::new(0)),
            current_connection: Arc::new(Mutex::new(None)),
        });

    let bind_addr = format!("0.0.0.0:{}", port);
    let tcp_listener = tokio::net::TcpListener::bind(&bind_addr)
        .await
        .context("Failed to bind signaling server")?;

    info!(
        "✅ WebRTC signaling server listening on {} (max connections: {})",
        bind_addr, MAX_CONNECTIONS
    );

    // Serving only returns once the server stops; surface its result directly.
    axum::serve(tcp_listener, router)
        .await
        .context("Signaling server error")
}
async fn websocket_handler(ws: WebSocketUpgrade, State(state): State<SignalingServerState>) -> Response {
let current_connections = state.active_connections.load(Ordering::Relaxed);
info!(
"🔗 New WebRTC signaling connection (current active: {}, max: {})",
current_connections, MAX_CONNECTIONS
);
ws.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(socket: WebSocket, state: SignalingServerState) {
{
let mut current = state.current_connection.lock().await;
if let Some(old_handle) = current.take() {
warn!("⚠️ Disconnecting previous WebRTC connection (new connection arrived)");
let _ = old_handle.shutdown_tx.try_send(());
}
}
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
let count = state.active_connections.fetch_add(1, Ordering::Relaxed) + 1;
info!("✅ WebRTC signaling client connected (active: {})", count);
let (mut sender, mut receiver) = socket.split();
if let Err(e) = sender
.send(Message::Text(
r#"{"type":"welcome","message":"WebRTC signaling server ready"}"#.to_string(),
))
.await
{
error!("Failed to send welcome message: {}", e);
return;
}
info!("🔗 Creating new WebRTC connection for client");
let connection = match state.webrtc_server.create_connection().await {
Ok(conn) => Arc::new(conn),
Err(e) => {
error!("❌ Failed to create WebRTC connection: {}", e);
let error_msg = format!(r#"{{"type":"error","message":"Failed to create connection: {}"}}"#, e);
let _ = sender.send(Message::Text(error_msg)).await;
return;
}
};
info!("⏳ Waiting for 'ready' signal from client before sending offer...");
{
let mut current = state.current_connection.lock().await;
*current = Some(ConnectionHandle {
shutdown_tx: shutdown_tx.clone(),
});
}
let conn_for_streaming = Arc::clone(&connection);
tokio::spawn(async move {
if let Err(e) = conn_for_streaming.run_streaming_loop().await {
error!("❌ Streaming loop error: {}", e);
}
});
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("🛑 Received shutdown signal, closing connection");
break;
}
msg = receiver.next() => {
let Some(msg) = msg else {
break;
};
match handle_signaling_message(msg, &connection, &mut sender).await {
Ok(true) => continue,
Ok(false) => break,
Err(e) => {
error!("Message handling error: {}", e);
break;
}
}
}
}
}
{
let mut current = state.current_connection.lock().await;
*current = None;
}
info!("🧹 Cleaning up WebRTC connection");
if let Err(e) = connection.peer_connection().close().await {
error!("❌ Error closing peer connection: {}", e);
}
let count = state.active_connections.fetch_sub(1, Ordering::Relaxed) - 1;
info!("👋 WebRTC signaling connection closed (active: {})", count);
}
async fn handle_signaling_message(
msg: Result<Message, axum::Error>,
connection: &Arc<WebRTCConnection>,
sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
) -> Result<bool> {
match msg {
Ok(Message::Text(text)) => {
debug!("📨 Received signaling message: {}", text);
match serde_json::from_str::<SignalingMessage>(&text) {
Ok(SignalingMessage::Ready) => {
info!("📨 Received 'ready' signal from browser - sending offer");
match connection.create_offer().await {
Ok(offer_sdp) => {
let offer_msg = SignalingMessage::Offer { sdp: offer_sdp };
match serde_json::to_string(&offer_msg) {
Ok(json) => {
if let Err(e) = sender.send(Message::Text(json)).await {
error!("❌ Failed to send SDP offer: {}", e);
} else {
info!("✅ Sent SDP offer to browser");
}
}
Err(e) => {
error!("Failed to serialize offer: {}", e);
}
}
}
Err(e) => {
error!("❌ Failed to create SDP offer: {}", e);
let error_msg = format!(r#"{{"type":"error","message":"Failed to create offer: {}"}}"#, e);
let _ = sender.send(Message::Text(error_msg)).await;
}
}
}
Ok(SignalingMessage::Answer { sdp }) => {
info!("📨 Received SDP answer from browser");
match connection.handle_answer(sdp).await {
Ok(()) => {
info!("✅ WebRTC connection established");
}
Err(e) => {
error!("❌ Failed to handle SDP answer: {}", e);
let error_msg =
format!(r#"{{"type":"error","message":"Failed to process answer: {}"}}"#, e);
let _ = sender.send(Message::Text(error_msg)).await;
}
}
}
Ok(SignalingMessage::Offer { .. }) => {
warn!("⚠️ Received offer from browser but only server-initiated flow is supported");
let error_msg = r#"{"type":"error","message":"Client-initiated SDP offers not supported. Use server-initiated flow."}"#;
let _ = sender.send(Message::Text(error_msg.to_string())).await;
}
Ok(SignalingMessage::IceCandidate {
candidate,
sdp_mid,
sdp_mline_index,
}) => {
debug!(
"🧊 Received ICE candidate: {} ({}:{})",
candidate, sdp_mid, sdp_mline_index
);
if let Err(e) = connection.add_ice_candidate(candidate, sdp_mid, sdp_mline_index).await {
error!("❌ Failed to add ICE candidate: {}", e);
} else {
debug!("✅ ICE candidate added successfully");
}
}
Err(e) => {
warn!("⚠️ Invalid signaling message: {}", e);
let error_msg = format!(r#"{{"type":"error","message":"Invalid message format: {}"}}"#, e);
let _ = sender.send(Message::Text(error_msg)).await;
}
}
Ok(true) }
Ok(Message::Close(_)) => {
info!("👋 WebRTC signaling client disconnected");
Ok(false) }
Ok(Message::Ping(data)) => {
if let Err(e) = sender.send(Message::Pong(data)).await {
error!("Failed to send pong: {}", e);
return Ok(false); }
Ok(true) }
Ok(_) => {
debug!("Received non-text WebSocket message");
Ok(true) }
Err(e) => {
error!("WebSocket error: {}", e);
Ok(false) }
}
}