use std::sync::Arc;
use std::time::Duration;
use axum::extract::ws::{Message, WebSocket};
use futures::stream::{SplitSink, StreamExt};
use futures::SinkExt;
use tokio::sync::{mpsc, Mutex};
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::{MediaEngine, MIME_TYPE_OPUS};
use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_message::DataChannelMessage;
use webrtc::data_channel::RTCDataChannel;
use webrtc::ice::udp_network::UDPNetwork;
use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit};
use webrtc::ice_transport::ice_candidate_type::RTCIceCandidateType;
use webrtc::ice_transport::ice_credential_type::RTCIceCredentialType;
use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::interceptor::registry::Registry;
use webrtc::media::Sample;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample;
use webrtc::track::track_local::TrackLocal;
use webrtc::track::track_remote::TrackRemote;
use crate::frames::{AudioRawData, Frame, FrameDirection, FrameProcessor};
use crate::transport::base::BaseTransport;
use crate::transport::incoming::dispatch_text_message;
use crate::transport::output::OutputMessage;
use super::codec::{OpusInbound, OpusOutbound};
use super::params::VaniWebRTCParams;
use super::signaling::{munge_answer_sdp, SignalMsg};
/// Capacity of the bounded channel that carries outbound messages (audio,
/// text and interruption markers) from the base transport to the session loop.
const AUDIO_OUT_CHANNEL_CAP: usize = 150;
/// Duration stamped on every outbound Opus sample written to the local track
/// (20 ms, expressed in microseconds).
const OPUS_SAMPLE_DURATION: Duration = Duration::from_micros(20_000);
/// Write half of the signaling WebSocket behind a shared async (tokio) mutex,
/// so both the main session loop and the ICE-candidate callback can send
/// signaling messages.
type SharedWsTx = Arc<Mutex<SplitSink<WebSocket, Message>>>;
/// WebRTC transport that negotiates SDP/ICE over a WebSocket signaling
/// channel, exchanges audio with the peer over an Opus RTP track, and
/// exchanges text over a data channel opened by the peer.
pub struct VaniWebRTCTransport {
    /// Shared base transport that provides the input/output frame processors
    /// and owns the sending half of the outbound message channel.
    base: Arc<BaseTransport>,
    /// Receiving half of the outbound message channel. It is `Some` until
    /// `run` takes it, so `run` can be called only once per instance.
    audio_out_rx: std::sync::Mutex<Option<mpsc::Receiver<OutputMessage>>>,
    /// Session configuration (ICE/TURN servers, NAT mapping, UDP mux,
    /// sample rates, optional denoiser factory).
    params: VaniWebRTCParams,
}
impl VaniWebRTCTransport {
    /// Creates a transport named `name` configured by `params`.
    ///
    /// Builds the outbound message channel (capacity [`AUDIO_OUT_CHANNEL_CAP`]).
    /// Its sender is installed on the base transport. Its receiver is parked
    /// until [`Self::run`] takes it.
    pub fn new(name: &str, params: VaniWebRTCParams) -> Self {
        let base = Arc::new(BaseTransport::new(name, params.transport.clone()));
        let (audio_out_tx, audio_out_rx) = mpsc::channel::<OutputMessage>(AUDIO_OUT_CHANNEL_CAP);
        base.set_audio_out_tx(audio_out_tx);
        Self {
            base,
            audio_out_rx: std::sync::Mutex::new(Some(audio_out_rx)),
            params,
        }
    }

    /// Returns the input-side frame processor of the underlying base transport.
    pub fn input(&self) -> FrameProcessor {
        self.base.input()
    }

    /// Returns the output-side frame processor of the underlying base transport.
    pub fn output(&self) -> FrameProcessor {
        self.base.output()
    }

    /// Builds a new peer connection for one session.
    ///
    /// The connection uses:
    /// - the default codecs and interceptors;
    /// - an optional 1:1 NAT mapping for host candidates;
    /// - an optional shared UDP mux;
    /// - an ICE server list made of the plain STUN URLs followed by the
    ///   password-authenticated TURN servers from `params`.
    ///
    /// # Errors
    /// Propagates any error from codec/interceptor registration or from
    /// creating the peer connection.
    async fn build_peer_connection(&self) -> webrtc::error::Result<Arc<RTCPeerConnection>> {
        let mut media = MediaEngine::default();
        media.register_default_codecs()?;
        let mut registry = Registry::new();
        registry = register_default_interceptors(registry, &mut media)?;

        let mut setting = SettingEngine::default();
        if !self.params.nat_1to1_ips.is_empty() {
            setting.set_nat_1to1_ips(
                self.params.nat_1to1_ips.clone(),
                RTCIceCandidateType::Host,
            );
        }
        if let Some(mux) = &self.params.udp_mux {
            setting.set_udp_network(UDPNetwork::Muxed(mux.clone()));
        }

        let api = APIBuilder::new()
            .with_media_engine(media)
            .with_interceptor_registry(registry)
            .with_setting_engine(setting)
            .build();

        let mut ice_servers: Vec<RTCIceServer> = self
            .params
            .ice_servers
            .iter()
            .map(|url| RTCIceServer {
                urls: vec![url.clone()],
                ..Default::default()
            })
            .collect();
        ice_servers.extend(self.params.turn_servers.iter().map(|t| RTCIceServer {
            urls: t.urls.clone(),
            username: t.username.clone(),
            credential: t.credential.clone(),
            credential_type: RTCIceCredentialType::Password,
        }));

        let config = RTCConfiguration {
            ice_servers,
            ..Default::default()
        };
        Ok(Arc::new(api.new_peer_connection(config).await?))
    }

    /// Drives one WebRTC session over the signaling `socket` until it ends.
    ///
    /// Setup:
    /// - builds the peer connection and adds an outbound Opus track;
    /// - decodes inbound audio tracks into the base transport;
    /// - forwards data-channel text messages to `push_tx`;
    /// - trickles local ICE candidates to the client.
    ///
    /// Main loop: multiplexes signaling messages with outbound messages from
    /// the base transport. The session ends on `Bye`, socket close or error,
    /// or when the outbound channel closes. On exit, including a failed
    /// setup, an end frame is pushed downstream.
    ///
    /// # Panics
    /// Panics if called more than once on the same transport, because the
    /// outbound receiver can only be taken once.
    pub async fn run(
        &self,
        socket: WebSocket,
        push_tx: mpsc::Sender<(Frame, FrameDirection)>,
    ) {
        let mut audio_out_rx = self
            .audio_out_rx
            .lock()
            .unwrap()
            .take()
            .expect("run called more than once on the same VaniWebRTCTransport");

        let pc = match self.build_peer_connection().await {
            Ok(pc) => pc,
            Err(e) => {
                log::error!("vaniwebrtc: failed to build peer connection: {}", e);
                // Mirror the normal exit path so the downstream pipeline learns
                // the session is over instead of waiting on a transport that
                // never started.
                let _ = push_tx
                    .send((Frame::end(), FrameDirection::Downstream))
                    .await;
                return;
            }
        };

        let (ws_tx, mut ws_rx) = socket.split();
        let ws_tx: SharedWsTx = Arc::new(Mutex::new(ws_tx));

        // Outbound Opus track carrying audio to the peer.
        let local_track = Arc::new(TrackLocalStaticSample::new(
            RTCRtpCodecCapability {
                mime_type: MIME_TYPE_OPUS.to_owned(),
                ..Default::default()
            },
            "audio".to_owned(),
            "rustvani".to_owned(),
        ));
        match pc
            .add_track(Arc::clone(&local_track) as Arc<dyn TrackLocal + Send + Sync>)
            .await
        {
            Ok(rtp_sender) => {
                // Keep reading incoming RTCP for this sender until the read fails.
                tokio::spawn(async move {
                    let mut rtcp_buf = vec![0u8; 1500];
                    while rtp_sender.read(&mut rtcp_buf).await.is_ok() {}
                });
            }
            Err(e) => log::error!("vaniwebrtc: add_track failed: {}", e),
        }

        // Inbound audio: decode each remote audio track to PCM at the configured
        // input rate (optionally denoised) and push it into the base transport.
        {
            let base = self.base.clone();
            let out_rate = self.params.transport.audio_in_sample_rate.unwrap_or(16_000);
            let denoiser_factory = self.params.denoiser_factory.clone();
            pc.on_track(Box::new(move |track: Arc<TrackRemote>, _recv, _trans| {
                let base = base.clone();
                let denoiser_factory = denoiser_factory.clone();
                Box::pin(async move {
                    if track.kind() != RTPCodecType::Audio {
                        return;
                    }
                    let denoiser = denoiser_factory.as_ref().map(|f| f());
                    let mut inbound = OpusInbound::new(out_rate, denoiser);
                    tokio::spawn(async move {
                        loop {
                            match track.read_rtp().await {
                                Ok((packet, _)) => {
                                    let pcm = inbound.push_rtp(&packet.payload);
                                    if !pcm.is_empty() {
                                        let data = AudioRawData::new(pcm, inbound.out_rate(), 1);
                                        base.push_audio_frame(data).await;
                                    }
                                }
                                Err(e) => {
                                    log::debug!("vaniwebrtc: inbound track ended: {}", e);
                                    break;
                                }
                            }
                        }
                    });
                })
            }));
        }

        // Data channel opened by the peer: remember it for outbound text and
        // dispatch inbound UTF-8 text messages into the pipeline.
        let dc_slot: Arc<Mutex<Option<Arc<RTCDataChannel>>>> = Arc::new(Mutex::new(None));
        {
            let dc_slot = dc_slot.clone();
            let push_tx = push_tx.clone();
            pc.on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
                let dc_slot = dc_slot.clone();
                let push_tx = push_tx.clone();
                Box::pin(async move {
                    *dc_slot.lock().await = Some(dc.clone());
                    dc.on_message(Box::new(move |msg: DataChannelMessage| {
                        let push_tx = push_tx.clone();
                        Box::pin(async move {
                            if msg.is_string {
                                if let Ok(text) = String::from_utf8(msg.data.to_vec()) {
                                    dispatch_text_message(&text, &push_tx).await;
                                }
                            }
                        })
                    }));
                })
            }));
        }

        // Trickle ICE: forward each locally gathered candidate to the client.
        // `None` (end of gathering) is not forwarded.
        {
            let ws_tx = ws_tx.clone();
            pc.on_ice_candidate(Box::new(move |c: Option<RTCIceCandidate>| {
                let ws_tx = ws_tx.clone();
                Box::pin(async move {
                    if let Some(c) = c {
                        if let Ok(init) = c.to_json() {
                            let msg = SignalMsg::Ice {
                                candidate: init.candidate,
                                sdp_mid: init.sdp_mid,
                                sdp_mline_index: init.sdp_mline_index,
                            };
                            send_signal(&ws_tx, msg).await;
                        }
                    }
                })
            }));
        }

        let mut outbound = OpusOutbound::new();
        let audio_out_rate = self.params.transport.audio_out_sample_rate.unwrap_or(16_000);
        loop {
            tokio::select! {
                maybe_msg = ws_rx.next() => {
                    match maybe_msg {
                        Some(Ok(Message::Text(text))) => {
                            match serde_json::from_str::<SignalMsg>(&text) {
                                Ok(SignalMsg::Offer { sdp }) => {
                                    if let Err(e) = self.handle_offer(&pc, &ws_tx, sdp).await {
                                        log::warn!("vaniwebrtc: offer handling failed: {}", e);
                                    }
                                }
                                Ok(SignalMsg::Ice { candidate, sdp_mid, sdp_mline_index }) => {
                                    let init = RTCIceCandidateInit {
                                        candidate,
                                        sdp_mid,
                                        sdp_mline_index,
                                        username_fragment: None,
                                    };
                                    if let Err(e) = pc.add_ice_candidate(init).await {
                                        log::warn!("vaniwebrtc: add_ice_candidate failed: {}", e);
                                    }
                                }
                                Ok(SignalMsg::Bye) => break,
                                // This side only answers offers, so an incoming Answer is ignored.
                                Ok(SignalMsg::Answer { .. }) => {}
                                Err(e) => log::warn!("vaniwebrtc: bad signaling message: {}", e),
                            }
                        }
                        Some(Ok(Message::Close(_))) | None => {
                            log::debug!("vaniwebrtc: signaling socket closed");
                            break;
                        }
                        // Non-text frames carry no signaling and are ignored.
                        Some(Ok(_)) => {}
                        Some(Err(e)) => {
                            log::warn!("vaniwebrtc: signaling error: {}", e);
                            break;
                        }
                    }
                }
                output_msg = audio_out_rx.recv() => {
                    match output_msg {
                        Some(OutputMessage::Audio(pcm)) => {
                            // NOTE(review): encoded samples are written immediately with no
                            // pacing here; real-time pacing is presumably done upstream. Confirm.
                            for packet in outbound.push_pcm(&pcm, audio_out_rate) {
                                let sample = Sample {
                                    data: bytes::Bytes::from(packet),
                                    duration: OPUS_SAMPLE_DURATION,
                                    ..Default::default()
                                };
                                if local_track.write_sample(&sample).await.is_err() {
                                    log::warn!("vaniwebrtc: write_sample failed");
                                }
                            }
                        }
                        Some(OutputMessage::Text(json)) => {
                            Self::send_data_channel_text(&dc_slot, json).await;
                        }
                        Some(OutputMessage::Interruption) => {
                            // Discard everything already queued, up to and including the
                            // next interruption marker.
                            // NOTE(review): queued Text messages are discarded as well as
                            // audio; confirm that is intended.
                            while let Ok(queued) = audio_out_rx.try_recv() {
                                match queued {
                                    OutputMessage::Interruption => break,
                                    OutputMessage::Audio(_) | OutputMessage::Text(_) => {}
                                }
                            }
                            outbound.reset();
                            Self::send_data_channel_text(&dc_slot, r#"{"type":"interruption"}"#)
                                .await;
                            log::debug!("vaniwebrtc: sent interruption to client");
                        }
                        None => break,
                    }
                }
            }
        }

        let _ = pc.close().await;
        let _ = push_tx
            .send((Frame::end(), FrameDirection::Downstream))
            .await;
    }

    /// Sends `text` on the peer's data channel, if one has been opened.
    ///
    /// The channel handle is cloned out and the slot lock is released before
    /// the send is awaited. This keeps the `on_data_channel` callback from
    /// being blocked behind an in-flight send. Send failures are logged and
    /// otherwise ignored (best-effort delivery).
    async fn send_data_channel_text(
        dc_slot: &Mutex<Option<Arc<RTCDataChannel>>>,
        text: impl Into<String>,
    ) {
        let dc = dc_slot.lock().await.clone();
        if let Some(dc) = dc {
            let text: String = text.into();
            if let Err(e) = dc.send_text(text).await {
                log::debug!("vaniwebrtc: data channel send failed: {}", e);
            }
        }
    }

    /// Handles a remote SDP offer.
    ///
    /// Applies the offer as the remote description, creates an answer and
    /// sets it (un-munged) as the local description. The answer SDP, after
    /// `munge_answer_sdp`, is then sent back over signaling.
    ///
    /// # Errors
    /// Propagates failures from parsing the offer, setting either description,
    /// or creating the answer. A failed signaling send is not reported.
    async fn handle_offer(
        &self,
        pc: &Arc<RTCPeerConnection>,
        ws_tx: &SharedWsTx,
        sdp: String,
    ) -> webrtc::error::Result<()> {
        pc.set_remote_description(RTCSessionDescription::offer(sdp)?).await?;
        let answer = pc.create_answer(None).await?;
        pc.set_local_description(answer.clone()).await?;
        let munged = munge_answer_sdp(&answer.sdp, &self.params);
        send_signal(ws_tx, SignalMsg::Answer { sdp: munged }).await;
        Ok(())
    }
}
async fn send_signal(ws_tx: &SharedWsTx, msg: SignalMsg) {
if let Ok(json) = serde_json::to_string(&msg) {
let mut guard = ws_tx.lock().await;
let _ = guard.send(Message::Text(json)).await;
}
}