use crate::webrtc::peer::Peer;
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use webrtc::track::track_local::TrackLocal;
use webrtc::track::track_local::TrackLocalWriter;
/// A media room: the peers currently connected to it plus the per-publisher
/// state used to fan RTP packets out to subscribers.
///
/// Each collection is wrapped in its own `Arc<tokio::sync::Mutex<..>>`, so the
/// maps can be shared with spawned tasks and locked independently of each other.
pub struct Room {
    /// Identifier of this room.
    pub room_id: String,
    /// Maximum number of peers that may be registered in `peers`.
    pub max_peers: u32,
    /// Connected peers, keyed by peer id.
    pub peers: Arc<Mutex<HashMap<String, Arc<Peer>>>>,
    /// Local tracks registered per peer, keyed by peer id.
    pub tracks: Arc<Mutex<HashMap<String, Vec<Arc<dyn TrackLocal + Send + Sync>>>>>,
    /// Broadcast senders that fan a publisher's RTP packets out to its
    /// subscribers, keyed by `"{peer_id}_rtp"`.
    pub rtp_forwarders: Arc<
        Mutex<
            HashMap<
                String,
                Arc<tokio::sync::broadcast::Sender<webrtc::rtp::packet::Packet>>,
            >,
        >,
    >,
}
impl Room {
pub fn new(room_id: &str, max_peers: u32) -> Self {
Self {
room_id: room_id.to_string(),
max_peers,
peers: Arc::new(Mutex::new(HashMap::new())),
tracks: Arc::new(Mutex::new(HashMap::new())),
rtp_forwarders: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn create_rtp_forwarder(&self, peer_id: &str) -> Result<()> {
let forwarder_key = format!("{peer_id}_rtp");
let mut forwarders = self.rtp_forwarders.lock().await;
if forwarders.contains_key(&forwarder_key) {
println!("RTP转发器已存在: {peer_id}");
return Ok(());
}
let (tx, _) = tokio::sync::broadcast::channel::<webrtc::rtp::packet::Packet>(1000);
forwarders.insert(forwarder_key, Arc::new(tx));
println!("为发布者 {peer_id} 创建RTP转发器");
Ok(())
}
pub async fn add_peer(&self, peer_id: &str, peer: Arc<Peer>) -> Result<()> {
let mut peers = self.peers.lock().await;
if peers.len() >= self.max_peers as usize {
return Err(anyhow!("Room is full"));
}
if peers.contains_key(peer_id) {
return Err(anyhow!("Peer already exists"));
}
peers.insert(peer_id.to_string(), peer);
Ok(())
}
pub async fn remove_peer(&self, peer_id: &str) -> Result<()> {
let mut peers = self.peers.lock().await;
if let Some(peer) = peers.remove(peer_id) {
if let Err(e) = peer.peer_connection.close().await {
eprintln!("关闭Peer {peer_id} 的PeerConnection失败: {e}");
}
} else {
return Err(anyhow!("Peer not found"));
}
let mut tracks = self.tracks.lock().await;
tracks.remove(peer_id);
let mut forwarders = self.rtp_forwarders.lock().await;
forwarders.remove(&format!("{peer_id}_rtp"));
Ok(())
}
pub async fn forward_track(&self, from_peer: &str, to_peer: &str) -> Result<()> {
let (to_peer_pc, broadcast_sender) = {
let peers = self.peers.lock().await;
let _from_peer_obj = peers
.get(from_peer)
.ok_or_else(|| anyhow::anyhow!("Publisher peer not found: {from_peer}"))?;
let to_peer_obj = peers
.get(to_peer)
.ok_or_else(|| anyhow::anyhow!("Subscriber peer not found: {to_peer}"))?;
println!("准备转发 {from_peer} 的轨道给 {to_peer}");
let forwarders = self.rtp_forwarders.lock().await;
let forwarder_key = format!("{from_peer}_rtp");
if !forwarders.contains_key(&forwarder_key) {
drop(forwarders);
drop(peers);
return Err(anyhow::anyhow!(
"RTP forwarder not found for peer: {from_peer}"
));
}
let broadcast_sender = forwarders
.get(&forwarder_key)
.ok_or_else(|| anyhow::anyhow!("RTP forwarder not found for peer: {}", from_peer))?
.clone();
let to_peer_pc = to_peer_obj.peer_connection.clone();
(to_peer_pc, broadcast_sender)
};
let codec = webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability {
mime_type: "video/VP8".to_string(),
..Default::default()
};
let track_local =
webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP::new(
codec,
format!("{from_peer}_video_track"), format!("{from_peer}_video"), );
let track_local_arc = Arc::new(track_local);
let _rtp_sender = to_peer_pc.add_track(track_local_arc.clone()).await?;
println!("已为订阅者 {to_peer} 添加TrackLocal");
let track_local_for_task = track_local_arc.clone();
let subscriber_id = to_peer.to_string();
let from_peer_id = from_peer.to_string();
tokio::spawn(async move {
let mut receiver = broadcast_sender.subscribe();
println!("启动RTP转发任务: {from_peer_id} -> {subscriber_id}");
loop {
match receiver.recv().await {
Ok(rtp_packet) => {
if let Err(e) = track_local_for_task.write(&rtp_packet.payload).await {
eprintln!("写入RTP包失败: {e}");
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
eprintln!("RTP转发滞后,丢失了 {n} 个包");
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
println!("RTP转发器关闭: {from_peer_id} -> {subscriber_id}");
break;
}
}
}
});
Ok(())
}
pub async fn add_tracks(
&self,
peer_id: &str,
tracks: Vec<Arc<dyn TrackLocal + Send + Sync>>,
) -> Result<()> {
let mut track_map = self.tracks.lock().await;
println!("Peer {} 添加了 {} 个轨道", peer_id, tracks.len());
track_map.insert(peer_id.to_string(), tracks);
Ok(())
}
pub async fn get_peers_info(&self) -> Vec<PeerInfo> {
let peers = self.peers.lock().await;
peers
.iter()
.map(|(id, peer)| PeerInfo {
id: id.clone(),
is_publishing: peer.publishing,
})
.collect()
}
pub async fn broadcast_peer_joined(&self, new_peer_id: &str) -> Result<()> {
let peers = self.peers.lock().await;
for (peer_id, peer) in peers.iter() {
if peer_id != new_peer_id {
peer.notify_peer_joined(new_peer_id).await?;
}
}
Ok(())
}
pub async fn broadcast_peer_left(&self, left_peer_id: &str) -> Result<()> {
let peers = self.peers.lock().await;
for (peer_id, peer) in peers.iter() {
if peer_id != left_peer_id {
peer.notify_peer_left(left_peer_id).await?;
}
}
Ok(())
}
pub async fn get_peer(&self, peer_id: &str) -> Option<Arc<Peer>> {
let peers = self.peers.lock().await;
peers.get(peer_id).cloned()
}
pub async fn is_empty(&self) -> bool {
let peers = self.peers.lock().await;
peers.is_empty()
}
}
/// Serializable summary of one peer, as produced by `Room::get_peers_info`.
///
/// Field declaration order is also the serialized field order.
#[derive(Clone, Debug, serde::Serialize)]
pub struct PeerInfo {
    /// The peer's id within the room.
    pub id: String,
    /// Copy of the peer's `publishing` flag at snapshot time.
    pub is_publishing: bool,
}