use crate::webrtc::server::WebRTCSFU;
use crate::webrtc::signaling::{
IceCandidateParams, JoinRoomParams, PublishParams, SignalingMessage, SignalingMethod,
SignalingParams, SubscribeParams,
};
use crate::HttpRequest;
use crate::Websocket;
use crate::WsFrame;
use anyhow::Result;
use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
pub struct WebRtcSignalingHandler {
sfu: Arc<WebRTCSFU>,
peer_senders: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
impl WebRtcSignalingHandler {
pub fn new(sfu: Arc<WebRTCSFU>) -> Self {
let peer_senders = Arc::new(Mutex::new(HashMap::new()));
Self { sfu, peer_senders }
}
pub async fn handle_websocket(&self, req: &mut HttpRequest) -> Result<()> {
let mut ws = req.upgrade_websocket().await?;
{
let peer_senders = self.peer_senders.clone();
let sfu_peer_senders = self.sfu.get_peer_senders();
let peer_senders_guard = sfu_peer_senders.read().await;
if peer_senders_guard.is_none() {
drop(peer_senders_guard);
self.sfu.set_peer_senders(peer_senders);
println!("Peer senders已注册到SFU");
} else {
drop(peer_senders_guard);
println!("Peer senders已经注册,跳过");
}
}
self.handle_ws_messages(&mut ws).await
}
async fn handle_ws_messages(&self, ws: &mut Websocket) -> Result<()> {
println!("WebRTC WebSocket连接建立");
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
loop {
tokio::select! {
result = ws.recv() => {
match result? {
WsFrame::Text(text) => {
if let Err(e) = self.handle_signaling_message(ws, &text, &tx).await {
eprintln!("处理信令消息失败: {e}");
}
}
WsFrame::Binary(_) => {
continue;
}
}
}
Some(message) = rx.recv() => {
if let Err(e) = ws.send_text(&message).await {
eprintln!("发送WebSocket消息失败: {e}");
return Ok(());
}
}
}
}
}
async fn handle_signaling_message(
&self,
ws: &mut Websocket,
text: &str,
ws_tx: &mpsc::UnboundedSender<String>,
) -> Result<()> {
let msg: SignalingMessage = match serde_json::from_str(text) {
Ok(m) => m,
Err(e) => {
eprintln!("解析信令消息失败: {e}");
return Err(anyhow::anyhow!("Invalid signaling message"));
}
};
let request_id = msg.id;
match msg.method {
SignalingMethod::JoinRoom => {
if let SignalingParams::JoinRoom(params) = msg.params {
self.handle_join_room(ws, params, request_id, ws_tx).await?;
}
}
SignalingMethod::Publish => {
if let SignalingParams::Publish(params) = msg.params {
self.handle_publish(ws, params, request_id).await?;
}
}
SignalingMethod::Subscribe => {
if let SignalingParams::Subscribe(params) = msg.params {
self.handle_subscribe(ws, params, request_id, ws_tx).await?;
}
}
SignalingMethod::IceCandidate => {
if let SignalingParams::IceCandidate(params) = msg.params {
self.handle_ice_candidate(¶ms).await?;
}
}
SignalingMethod::LeaveRoom => {
if let SignalingParams::LeaveRoom(params) = msg.params {
println!("Peer {} 离开房间 {}", params.peer_id, params.room_id);
if let Some(room) = self.sfu.get_room(¶ms.room_id).await {
if let Err(e) = room.broadcast_peer_left(¶ms.peer_id).await {
eprintln!("广播peer离开事件失败: {e}");
}
if let Err(e) = room.remove_peer(¶ms.peer_id).await {
eprintln!("从房间移除peer失败: {e}");
}
let mut forwarders = room.rtp_forwarders.lock().await;
forwarders.remove(&format!("{}_rtp", params.peer_id));
drop(forwarders);
let mut tracks = room.tracks.lock().await;
tracks.remove(¶ms.peer_id);
drop(tracks);
}
let mut senders = self.peer_senders.lock().await;
senders.remove(¶ms.peer_id);
}
}
_ => {
println!("未处理的信令方法: {:?}", msg.method);
}
}
Ok(())
}
async fn handle_join_room(
&self,
ws: &mut Websocket,
params: JoinRoomParams,
request_id: Option<u64>,
ws_tx: &mpsc::UnboundedSender<String>,
) -> Result<()> {
println!("Peer {} 加入房间 {}", params.peer_id, params.room_id);
{
let mut senders = self.peer_senders.lock().await;
senders.insert(params.peer_id.clone(), ws_tx.clone());
println!("Peer {} 的消息通道已注册", params.peer_id);
}
if self.sfu.get_room(¶ms.room_id).await.is_none() {
self.sfu.create_room(¶ms.room_id, None).await?;
}
if let Some(room) = self.sfu.get_room(¶ms.room_id).await {
let peers = room.get_peers_info().await;
let response = serde_json::json!({
"jsonrpc": "2.0",
"method": "join_room",
"params": {
"room_id": params.room_id,
"peer_id": params.peer_id,
"peers": peers,
},
"id": request_id,
});
ws.send_text(&response.to_string()).await?;
room.broadcast_peer_joined(¶ms.peer_id).await?;
}
Ok(())
}
async fn handle_publish(
&self,
ws: &mut Websocket,
params: PublishParams,
request_id: Option<u64>,
) -> Result<()> {
println!("Peer {} 开始推流到房间 {}", params.peer_id, params.room_id);
let answer_sdp = self
.sfu
.handle_offer(¶ms.peer_id, ¶ms.room_id, ¶ms.sdp)
.await?;
let response = serde_json::json!({
"jsonrpc": "2.0",
"method": "answer",
"params": {
"room_id": params.room_id,
"peer_id": params.peer_id,
"sdp": answer_sdp,
},
"id": request_id,
});
ws.send_text(&response.to_string()).await?;
Ok(())
}
async fn handle_subscribe(
&self,
ws: &mut Websocket,
params: SubscribeParams,
request_id: Option<u64>,
ws_tx: &mpsc::UnboundedSender<String>,
) -> Result<()> {
println!(
"Peer {} 订阅 {} 的流 (房间: {})",
params.subscriber_id, params.publisher_id, params.room_id
);
{
let mut senders = self.peer_senders.lock().await;
senders.insert(params.subscriber_id.clone(), ws_tx.clone());
println!("订阅者 {} 的消息通道已注册", params.subscriber_id);
}
if let Some(offer_sdp) = ¶ms.sdp {
let answer_sdp = self
.sfu
.handle_subscribe(
¶ms.subscriber_id,
¶ms.room_id,
¶ms.publisher_id,
offer_sdp,
)
.await?;
let response = serde_json::json!({
"jsonrpc": "2.0",
"method": "answer",
"params": {
"room_id": params.room_id,
"peer_id": params.subscriber_id,
"sdp": answer_sdp,
"type": "answer",
},
"id": request_id,
});
ws.send_text(&response.to_string()).await?;
} else {
let response = serde_json::json!({
"jsonrpc": "2.0",
"error": {
"code": -32602,
"message": "Subscribe requires SDP offer in params.sdp"
},
"id": request_id,
});
ws.send_text(&response.to_string()).await?;
}
Ok(())
}
async fn handle_ice_candidate(&self, params: &IceCandidateParams) -> Result<()> {
println!(
"收到ICE候选: Peer {} 在房间 {}",
params.peer_id, params.room_id
);
if let Some(room) = self.sfu.get_room(¶ms.room_id).await {
let peers = room.peers.lock().await;
if let Some(peer) = peers.get(¶ms.peer_id) {
let ice_candidate = webrtc::ice_transport::ice_candidate::RTCIceCandidateInit {
candidate: params.candidate.clone(),
sdp_mid: if params.sdp_mid.is_empty() {
None
} else {
Some(params.sdp_mid.clone())
},
sdp_mline_index: if params.sdp_mline_index == 0 {
None
} else {
Some(params.sdp_mline_index)
},
username_fragment: None,
};
if let Err(e) = peer.peer_connection.add_ice_candidate(ice_candidate).await {
eprintln!("添加ICE候选失败: {e}");
}
}
drop(peers);
}
Ok(())
}
pub fn get_peer_senders(&self) -> Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>> {
self.peer_senders.clone()
}
}
/// One-shot entry point: builds a fresh [`WebRtcSignalingHandler`] around `sfu` and
/// serves signaling on the WebSocket upgraded from `req`.
///
/// NOTE(review): each call creates a new handler with its own peer-sender map, but
/// `handle_websocket` only hands a map to the SFU when none is registered yet. Maps
/// from later calls are therefore presumably never seen by the SFU. Sharing one
/// handler across connections looks safer — confirm before using this helper.
///
/// # Errors
/// Whatever `WebRtcSignalingHandler::handle_websocket` returns.
// The allow suggests this helper has no in-crate callers at the moment.
#[allow(dead_code)]
pub async fn create_webrtc_websocket_handler(
    req: &mut HttpRequest,
    sfu: Arc<WebRTCSFU>,
) -> Result<()> {
    WebRtcSignalingHandler::new(sfu)
        .handle_websocket(req)
        .await
}