use anyhow::Result;
use bytes::Bytes;
use rustrtc::media::track::MediaStreamTrack;
use rustrtc::{MediaKind, PeerConnection, RtcConfiguration, TransceiverDirection};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use webrtc::api::APIBuilder;
use webrtc::api::media_engine::MediaEngine;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration as WebrtcConfiguration;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp::header::{Extension, Header};
use webrtc::rtp::packet::Packet;
use webrtc::rtp_transceiver::RTCRtpEncodingParameters;
use webrtc::rtp_transceiver::RTCRtpTransceiverInit;
use webrtc::rtp_transceiver::rtp_codec::RTCRtpHeaderExtensionCapability;
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
use webrtc::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection;
use webrtc::track::track_local::TrackLocalWriter;
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
#[tokio::test]
async fn test_simulcast_ingest_and_switch() -> Result<()> {
rustls::crypto::CryptoProvider::install_default(rustls::crypto::ring::default_provider()).ok();
let _ = env_logger::builder().is_test(true).try_init();
let rust_config = RtcConfiguration::default();
let rust_pc = PeerConnection::new(rust_config);
let transceiver = rust_pc.add_transceiver(MediaKind::Video, TransceiverDirection::RecvOnly);
let mut m = MediaEngine::default();
m.register_default_codecs()?;
m.register_header_extension(
RTCRtpHeaderExtensionCapability {
uri: "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id".to_owned(),
},
RTPCodecType::Video,
Some(RTCRtpTransceiverDirection::Sendrecv),
)?;
m.register_header_extension(
RTCRtpHeaderExtensionCapability {
uri: "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id".to_owned(),
},
RTPCodecType::Video,
Some(RTCRtpTransceiverDirection::Sendrecv),
)?;
m.register_header_extension(
RTCRtpHeaderExtensionCapability {
uri: "urn:ietf:params:rtp-hdrext:sdes:mid".to_owned(),
},
RTPCodecType::Video,
Some(RTCRtpTransceiverDirection::Sendrecv),
)?;
let registry = Registry::new();
let api = APIBuilder::new()
.with_media_engine(m)
.with_interceptor_registry(registry)
.build();
let webrtc_config = WebrtcConfiguration::default();
let webrtc_pc = api.new_peer_connection(webrtc_config).await?;
let codec = RTCRtpCodecCapability {
mime_type: webrtc::api::media_engine::MIME_TYPE_VP8.to_owned(),
clock_rate: 90000,
channels: 0,
..Default::default()
};
let video_track = Arc::new(TrackLocalStaticRTP::new(
codec,
"video".to_string(),
"webrtc_stream".to_string(),
));
let _webrtc_transceiver = webrtc_pc
.add_transceiver_from_track(
Arc::clone(&video_track)
as Arc<dyn webrtc::track::track_local::TrackLocal + Send + Sync>,
Some(RTCRtpTransceiverInit {
direction: RTCRtpTransceiverDirection::Sendonly,
send_encodings: vec![
RTCRtpEncodingParameters {
rid: "hi".into(),
..Default::default()
},
RTCRtpEncodingParameters {
rid: "mid".into(),
..Default::default()
},
RTCRtpEncodingParameters {
rid: "lo".into(),
..Default::default()
},
],
}),
)
.await?;
let offer = webrtc_pc.create_offer(None).await?;
let mut gather_complete = webrtc_pc.gathering_complete_promise().await;
webrtc_pc.set_local_description(offer.clone()).await?;
let _ = gather_complete.recv().await;
let offer = webrtc_pc.local_description().await.unwrap();
println!("WebRTC Offer SDP:\n{}", offer.sdp);
let mut offer_sdp = offer.sdp;
if offer_sdp.contains("m=video") {
offer_sdp = offer_sdp.replace("a=sendonly", "a=sendonly\r\na=rid:hi send\r\na=rid:mid send\r\na=rid:lo send\r\na=simulcast:send hi;mid;lo");
}
let rust_offer = rustrtc::SessionDescription::parse(rustrtc::SdpType::Offer, &offer_sdp)?;
rust_pc.set_remote_description(rust_offer).await?;
let answer = rust_pc.create_answer().await?;
rust_pc.set_local_description(answer.clone())?;
let answer_sdp = answer.to_sdp_string();
println!("RustRTC Answer SDP:\n{}", answer_sdp);
let webrtc_answer = RTCSessionDescription::answer(answer_sdp)?;
webrtc_pc.set_remote_description(webrtc_answer).await?;
rust_pc.wait_for_connected().await?;
let video_track_clone = video_track.clone();
tokio::spawn(async move {
let mut sequence_number = 0u16;
loop {
tokio::time::sleep(Duration::from_millis(33)).await;
for rid in &["hi", "mid", "lo"] {
let packet = Packet {
header: Header {
version: 2,
payload_type: 96,
sequence_number,
timestamp: sequence_number as u32 * 3000,
ssrc: 0,
extensions: vec![Extension {
id: 1, payload: Bytes::from(rid.as_bytes().to_vec()),
}],
..Default::default()
},
payload: vec![0u8; 100].into(),
};
if let Err(e) = video_track_clone.write_rtp(&packet).await {
println!("Failed to write RTP: {}", e);
return;
}
}
sequence_number = sequence_number.wrapping_add(1);
}
});
let receiver = transceiver.receiver().unwrap();
tokio::time::sleep(Duration::from_secs(2)).await;
let simulcast_rids = receiver.get_simulcast_rids();
println!("Received Simulcast RIDs: {:?}", simulcast_rids);
assert_eq!(simulcast_rids.len(), 3);
assert!(simulcast_rids.contains(&"hi".to_string()));
assert!(simulcast_rids.contains(&"mid".to_string()));
assert!(simulcast_rids.contains(&"lo".to_string()));
let hi_track = receiver.simulcast_track("hi").expect("hi track not found");
let mid_track = receiver
.simulcast_track("mid")
.expect("mid track not found");
let lo_track = receiver.simulcast_track("lo").expect("lo track not found");
println!(
"Track IDs: hi={}, mid={}, lo={}",
hi_track.id(),
mid_track.id(),
lo_track.id()
);
let selector = rustrtc::media::track::SelectorTrack::new(lo_track.clone());
println!("Switching to mid");
selector.switch_to(mid_track).await.unwrap();
println!("Waiting for packet on mid...");
match timeout(Duration::from_secs(2), selector.recv()).await {
Ok(Ok(_)) => println!("Received packet on mid"),
Ok(Err(e)) => println!("Error receiving on mid: {}", e),
Err(_) => println!("Timeout receiving on mid"),
}
println!("Switching to hi");
selector.switch_to(hi_track).await.unwrap();
println!("Waiting for packet on hi...");
match timeout(Duration::from_secs(2), selector.recv()).await {
Ok(Ok(_)) => println!("Received packet on hi"),
Ok(Err(e)) => println!("Error receiving on hi: {}", e),
Err(_) => println!("Timeout receiving on hi"),
}
rust_pc.close();
webrtc_pc.close().await?;
Ok(())
}