use anyhow::Result;
use bytes::BytesMut;
use sansio::Protocol;
use shared::{TaggedBytesMut, TransportContext, TransportProtocol};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::time::timeout;
use rtc::peer_connection::RTCPeerConnectionBuilder;
use rtc::peer_connection::configuration::RTCConfigurationBuilder;
use rtc::peer_connection::configuration::setting_engine::SettingEngine;
use rtc::peer_connection::event::RTCDataChannelEvent;
use rtc::peer_connection::event::RTCPeerConnectionEvent;
use rtc::peer_connection::state::RTCIceConnectionState;
use rtc::peer_connection::state::RTCPeerConnectionState;
use rtc::peer_connection::transport::RTCDtlsRole;
use rtc::peer_connection::transport::RTCIceServer;
use rtc::peer_connection::transport::{CandidateConfig, CandidateHostConfig};
use webrtc::api::APIBuilder;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MediaEngine;
use webrtc::ice_transport::ice_server::RTCIceServer as WebrtcIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::RTCPeerConnection as WebrtcPeerConnection;
use webrtc::peer_connection::configuration::RTCConfiguration as WebrtcRTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState as WebrtcRTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription as WebrtcRTCSessionDescription;
/// Fallback wake-up horizon for the RTC event loop: when `poll_timeout()` on the
/// RTC peer connection yields no deadline, the loop waits at most this long
/// (measured from "now") before calling `handle_timeout` again.
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(30);
/// Interop test: the sans-IO `rtc` peer creates a data channel, sends a fixed
/// number of text messages to a `webrtc` peer, then closes the channel.
///
/// The test passes once the `webrtc` side's `on_close` callback has fired and
/// it has received at least `TOTAL_MESSAGES` messages.
///
/// # Errors
///
/// Returns an error if signaling or I/O fails, if the RTC ICE/peer connection
/// reports `Failed`, or if the close is not observed by the `webrtc` peer
/// within the 30 second test timeout.
#[tokio::test]
async fn test_data_channel_close_interop() -> Result<()> {
    // Number of text messages the RTC side sends before closing the channel.
    const TOTAL_MESSAGES: usize = 3;

    env_logger::builder()
        .filter_level(log::LevelFilter::Info)
        .is_test(true)
        .try_init()
        .ok();
    log::info!("Starting data channel close interop test: RTC sends and closes");

    // State shared with the webrtc callbacks.
    let webrtc_received_messages = Arc::new(Mutex::new(Vec::<String>::new()));
    let webrtc_channel_closed = Arc::new(Mutex::new(false));

    let mut messages_sent: usize = 0;
    // Set once `close()` has been issued so the loop never closes the channel twice.
    let mut close_requested = false;

    // The RTC peer is sans-IO: this test owns the UDP socket and shuttles packets.
    let socket = UdpSocket::bind("127.0.0.1:0").await?;
    let local_addr = socket.local_addr()?;
    log::info!("RTC peer bound to {}", local_addr);

    let mut setting_engine = SettingEngine::default();
    setting_engine.set_answering_dtls_role(RTCDtlsRole::Server)?;

    let config = RTCConfigurationBuilder::new()
        .with_ice_servers(vec![RTCIceServer {
            urls: vec!["stun:stun.l.google.com:19302".to_owned()],
            ..Default::default()
        }])
        .build();

    let mut rtc_pc = RTCPeerConnectionBuilder::new()
        .with_configuration(config)
        .with_setting_engine(setting_engine)
        .build()?;
    log::info!("Created RTC peer connection");

    let dc_label = "test-channel";
    let _rtc_dc = rtc_pc.create_data_channel(dc_label, None)?;
    log::info!("RTC created data channel: {}", dc_label);

    // Advertise the locally bound socket as the RTC peer's host candidate.
    let candidate = CandidateHostConfig {
        base_config: CandidateConfig {
            network: "udp".to_owned(),
            address: local_addr.ip().to_string(),
            port: local_addr.port(),
            component: 1,
            ..Default::default()
        },
        ..Default::default()
    }
    .new_candidate_host()?;
    let local_candidate_init =
        rtc::peer_connection::transport::RTCIceCandidate::from(&candidate).to_json()?;
    rtc_pc.add_local_candidate(local_candidate_init)?;

    // --- Offer (RTC) -------------------------------------------------------
    let offer = rtc_pc.create_offer(None)?;
    log::info!("RTC created offer");
    rtc_pc.set_local_description(offer.clone())?;
    log::info!("RTC set local description");
    let webrtc_offer = WebrtcRTCSessionDescription::offer(offer.sdp.clone())?;

    // --- Answerer (webrtc) -------------------------------------------------
    let webrtc_pc = create_webrtc_peer().await?;
    log::info!("Created webrtc peer connection");

    let webrtc_received_messages_clone = Arc::clone(&webrtc_received_messages);
    let webrtc_channel_closed_clone = Arc::clone(&webrtc_channel_closed);
    webrtc_pc.on_data_channel(Box::new(move |dc| {
        let messages = Arc::clone(&webrtc_received_messages_clone);
        let closed = Arc::clone(&webrtc_channel_closed_clone);
        Box::pin(async move {
            let label = dc.label();
            log::info!("WebRTC received data channel: {}", label);

            dc.on_open(Box::new(|| {
                log::info!("WebRTC data channel opened");
                Box::pin(async {})
            }));

            // Record every received message as text for the final assertion.
            dc.on_message(Box::new(move |msg| {
                let messages = Arc::clone(&messages);
                Box::pin(async move {
                    let data = String::from_utf8(msg.data.to_vec()).unwrap_or_default();
                    log::info!("WebRTC received message: '{}'", data);
                    let mut msgs = messages.lock().await;
                    msgs.push(data);
                })
            }));

            // Flip the shared flag that the main loop polls for completion.
            dc.on_close(Box::new(move || {
                let closed = Arc::clone(&closed);
                Box::pin(async move {
                    log::info!("WebRTC data channel closed");
                    let mut is_closed = closed.lock().await;
                    *is_closed = true;
                })
            }));
        })
    }));

    webrtc_pc.set_remote_description(webrtc_offer).await?;
    log::info!("WebRTC set remote description");
    let answer = webrtc_pc.create_answer(None).await?;
    log::info!("WebRTC created answer");
    webrtc_pc.set_local_description(answer.clone()).await?;
    log::info!("WebRTC set local description");

    // Wait (bounded) for ICE gathering so the answer SDP carries candidates.
    let mut gathering_done = webrtc_pc.gathering_complete_promise().await;
    let _ = timeout(Duration::from_secs(5), gathering_done.recv()).await;
    let answer_with_candidates = webrtc_pc
        .local_description()
        .await
        .expect("local description should be set");
    log::info!("WebRTC answer with candidates ready");

    let rtc_answer = rtc::peer_connection::sdp::RTCSessionDescription::answer(
        answer_with_candidates.sdp.clone(),
    )?;
    rtc_pc.set_remote_description(rtc_answer)?;
    log::info!("RTC set remote description");

    // --- Event loop --------------------------------------------------------
    let mut buf = vec![0u8; 2000];
    let mut rtc_connected = false;
    let mut webrtc_connected = false;
    let mut rtc_data_channel_opened = false;
    let mut rtc_dc_id: Option<u16> = None;
    let mut last_message_time = Instant::now();
    let message_interval = Duration::from_millis(500);
    let start_time = Instant::now();
    let test_timeout = Duration::from_secs(30);

    while start_time.elapsed() < test_timeout {
        // 1. Flush everything the RTC peer wants to put on the wire.
        while let Some(msg) = rtc_pc.poll_write() {
            match socket.send_to(&msg.message, msg.transport.peer_addr).await {
                Ok(n) => {
                    log::trace!("RTC sent {} bytes to {}", n, msg.transport.peer_addr);
                }
                Err(err) => {
                    log::error!("RTC socket write error: {}", err);
                }
            }
        }

        // 2. Drain RTC events and update local state.
        while let Some(event) = rtc_pc.poll_event() {
            match event {
                RTCPeerConnectionEvent::OnIceConnectionStateChangeEvent(state) => {
                    log::info!("RTC ICE connection state: {}", state);
                    if state == RTCIceConnectionState::Failed {
                        return Err(anyhow::anyhow!("RTC ICE connection failed"));
                    }
                    if state == RTCIceConnectionState::Connected {
                        log::info!("RTC ICE connected!");
                    }
                }
                RTCPeerConnectionEvent::OnConnectionStateChangeEvent(state) => {
                    log::info!("RTC peer connection state: {}", state);
                    if state == RTCPeerConnectionState::Failed {
                        return Err(anyhow::anyhow!("RTC peer connection failed"));
                    }
                    if state == RTCPeerConnectionState::Connected {
                        log::info!("RTC peer connection connected!");
                        rtc_connected = true;
                    }
                }
                RTCPeerConnectionEvent::OnDataChannel(dc_event) => {
                    log::info!("RTC data channel event: {:?}", dc_event);
                    match dc_event {
                        RTCDataChannelEvent::OnOpen(channel_id) => {
                            let dc = rtc_pc
                                .data_channel(channel_id)
                                .expect("data channel should exist");
                            log::info!(
                                "RTC data channel opened: {} (id: {})",
                                dc.label(),
                                channel_id
                            );
                            rtc_data_channel_opened = true;
                            rtc_dc_id = Some(channel_id);
                            // Start the send cadence from the moment the channel opens.
                            last_message_time = Instant::now();
                        }
                        RTCDataChannelEvent::OnClose(channel_id) => {
                            log::info!("RTC data channel {} closed", channel_id);
                            rtc_data_channel_opened = false;
                        }
                        _ => {}
                    }
                }
                _ => {}
            }
        }

        // 3. Track the webrtc side's connection state (polled, not callback-driven).
        if !webrtc_connected
            && webrtc_pc.connection_state() == WebrtcRTCPeerConnectionState::Connected
        {
            log::info!("WebRTC peer connection connected!");
            webrtc_connected = true;
        }

        // 4. Send one message per interval; after the last one, close exactly once.
        let ready_to_send =
            rtc_connected && webrtc_connected && rtc_data_channel_opened && !close_requested;
        if ready_to_send && last_message_time.elapsed() >= message_interval {
            if let Some(dc_id) = rtc_dc_id {
                let mut dc = rtc_pc
                    .data_channel(dc_id)
                    .expect("data channel should exist");
                if messages_sent < TOTAL_MESSAGES {
                    let message = format!("Message #{}", messages_sent + 1);
                    log::info!("RTC sending: '{}'", message);
                    dc.send_text(message)?;
                    last_message_time = Instant::now();
                    messages_sent += 1;
                } else {
                    log::info!("RTC finished sending messages, exiting to close connection");
                    dc.close()?;
                    close_requested = true;
                }
            }
        }

        // 5. Check for completion. Copy the flag out so the guard is released
        //    immediately; holding it across the awaits below would block the
        //    webrtc `on_close` callback that needs the same lock.
        let webrtc_closed = *webrtc_channel_closed.lock().await;
        if webrtc_closed {
            log::info!("✅ Test completed successfully!");
            log::info!("   Data channel closed (detected by WebRTC)");
            let received_count = {
                let webrtc_msgs = webrtc_received_messages.lock().await;
                log::info!(
                    "   WebRTC received {} messages: {:?}",
                    webrtc_msgs.len(),
                    webrtc_msgs
                );
                webrtc_msgs.len()
            };
            assert!(
                received_count >= TOTAL_MESSAGES,
                "WebRTC should have received at least 3 messages before close"
            );
            webrtc_pc.close().await?;
            rtc_pc.close()?;
            return Ok(());
        }

        // 6. Wait for the next RTC deadline or an incoming packet, whichever is first.
        let eto = rtc_pc
            .poll_timeout()
            .unwrap_or_else(|| Instant::now() + DEFAULT_TIMEOUT_DURATION);
        let delay_from_now = eto.saturating_duration_since(Instant::now());
        if delay_from_now.is_zero() {
            // Deadline already passed: service it right away and re-poll.
            rtc_pc.handle_timeout(Instant::now())?;
            continue;
        }

        let timer = tokio::time::sleep(delay_from_now);
        tokio::pin!(timer);
        tokio::select! {
            _ = timer.as_mut() => {
                rtc_pc.handle_timeout(Instant::now())?;
            }
            res = socket.recv_from(&mut buf) => {
                match res {
                    Ok((n, peer_addr)) => {
                        log::trace!("RTC received {} bytes from {}", n, peer_addr);
                        rtc_pc.handle_read(TaggedBytesMut {
                            now: Instant::now(),
                            transport: TransportContext {
                                local_addr,
                                peer_addr,
                                ecn: None,
                                transport_protocol: TransportProtocol::UDP,
                            },
                            message: BytesMut::from(&buf[..n]),
                        })?;
                    }
                    Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    }
                    Err(err) => {
                        log::error!("RTC socket read error: {}", err);
                        return Err(err.into());
                    }
                }
            }
        }
    }

    // Timed out. Cleanup is best-effort so a close failure cannot mask the
    // timeout error, which is the actual reason the test failed.
    if let Err(err) = webrtc_pc.close().await {
        log::warn!("WebRTC peer connection close failed after timeout: {}", err);
    }
    if let Err(err) = rtc_pc.close() {
        log::warn!("RTC peer connection close failed after timeout: {}", err);
    }
    Err(anyhow::anyhow!(
        "Test timeout - data channel close not detected by WebRTC in time"
    ))
}
/// Builds the `webrtc`-crate peer connection used as the remote side of the test.
///
/// The API is assembled with the default codecs and default interceptors, and
/// the connection is configured with a single STUN server entry.
///
/// # Errors
///
/// Propagates any failure from codec registration, interceptor registration,
/// or peer connection construction.
async fn create_webrtc_peer() -> Result<Arc<WebrtcPeerConnection>> {
    let mut codecs = MediaEngine::default();
    codecs.register_default_codecs()?;

    // Interceptors are registered against the media engine before it is moved into the API.
    let interceptors = register_default_interceptors(Registry::new(), &mut codecs)?;

    let api = APIBuilder::new()
        .with_media_engine(codecs)
        .with_interceptor_registry(interceptors)
        .build();

    let stun_server = WebrtcIceServer {
        urls: vec!["stun:stun.l.google.com:19302".to_owned()],
        ..Default::default()
    };
    let configuration = WebrtcRTCConfiguration {
        ice_servers: vec![stun_server],
        ..Default::default()
    };

    let connection = api.new_peer_connection(configuration).await?;
    Ok(Arc::new(connection))
}