use crate::network::http::{
server::HFactory,
session::{HAsyncService, Session},
ws::OpCode,
};
use bytes::{Bytes, BytesMut};
use glib::{prelude::ObjectExt, types::StaticType};
use gst::prelude::*;
use gstreamer as gst;
use gstreamer_app as gst_app;
use gstreamer_video as gst_video;
use serde::{Deserialize, Serialize};
use std::sync::{
Arc, OnceLock,
atomic::{AtomicU64, Ordering},
};
use tokio::sync::{RwLock, broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use webrtc::{
api::{APIBuilder, media_engine::MediaEngine},
data_channel::data_channel_message::DataChannelMessage,
ice_transport::ice_server::RTCIceServer,
interceptor::registry::Registry,
peer_connection::{RTCPeerConnection, configuration::RTCConfiguration},
rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType},
track::track_local::track_local_static_sample::TrackLocalStaticSample,
};
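/// Network configuration for the WebRTC server: the ephemeral UDP port range
/// used by ICE and the STUN servers handed to each new peer connection.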
#[derive(Debug, Clone)]
pub struct ServerConfig {
pub udp_min: u16,
pub udp_max: u16,
pub stun_urls: Vec<String>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
udp_min: 50000,
udp_max: 50100,
stun_urls: vec!["stun:stun.l.google.com:19302".into()],
}
}
}
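/// Target encoding parameters. Width/height/fps changes force a pipeline
/// restart (see `ctrl_needs_restart`); a bitrate change can be applied live
/// via `apply_ctrl`.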
#[derive(Debug, Clone)]
pub struct StreamCtrl {
pub width: i32,
pub height: i32,
pub fps: i32,
pub bitrate_kbps: i32,
}
impl Default for StreamCtrl {
fn default() -> Self {
Self {
width: 1280,
height: 720,
fps: 60,
bitrate_kbps: 6000,
}
}
}
fn ctrl_needs_restart(prev: &StreamCtrl, next: &StreamCtrl) -> bool {
prev.width != next.width || prev.height != next.height || prev.fps != next.fps
}
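/// ICE candidate as exchanged over the WebSocket; field names mirror the
/// browser's `RTCIceCandidateInit` JSON.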
#[derive(Debug, Serialize, Deserialize, Clone)]
struct IceCandidateWire {
candidate: String,
#[serde(rename = "sdpMid")]
sdp_mid: Option<String>,
#[serde(rename = "sdpMLineIndex")]
sdp_mline_index: Option<u16>,
#[serde(rename = "usernameFragment")]
username_fragment: Option<String>,
}
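/// Signaling envelope exchanged over the WebSocket, serialized as
/// `{"type": "...", "data": ...}` via serde's adjacent tagging.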
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
enum WsMsg {
Offer(String),
Answer(String),
Ice(IceCandidateWire),
ClientStats(ClientStats),
ServerStats(ServerStats),
Ctrl {
width: i32,
height: i32,
fps: i32,
bitrate_kbps: i32,
},
Info(String),
Error(String),
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientStats {
pub rtt_ms: Option<f64>,
pub jitter_ms: Option<f64>,
pub loss: Option<f64>,
pub fps: Option<f64>,
pub available_in_bps: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ServerStats {
pub fps: f64,
pub dropped_samples: u64,
}
#[derive(Clone)]
struct CapturedFrame {
    data: Bytes,
    dur: std::time::Duration,
}
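/// Process-wide screen-capture pipeline. Raw frames are fanned out through a
/// broadcast channel so every WebSocket session shares one capture source.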
struct CaptureHub {
pipeline: gst::Pipeline,
stop: CancellationToken,
bus_handle: tokio::task::JoinHandle<()>,
tx: broadcast::Sender<CapturedFrame>,
}
#[derive(Debug, Clone)]
pub struct RtmpBroadcaster {
pub ingest_url: String,
pub stream_key: String,
pub bitrate_kbps: Option<u32>,
pub gop_seconds: Option<u32>,
}
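// Process-wide singletons: the shared capture hub and the number of live
// WebSocket sessions (the hub is torn down when the last session ends).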
static CAPTURE_HUB: OnceLock<Arc<RwLock<Option<Arc<CaptureHub>>>>> = OnceLock::new();
static ACTIVE_WS: OnceLock<AtomicU64> = OnceLock::new();
fn capture_hub_slot() -> &'static Arc<RwLock<Option<Arc<CaptureHub>>>> {
CAPTURE_HUB.get_or_init(|| Arc::new(RwLock::new(None)))
}
fn active_ws() -> &'static AtomicU64 {
ACTIVE_WS.get_or_init(|| AtomicU64::new(0))
}
#[inline]
fn utc_ms_now() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
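/// A GStreamer pipeline plus the handles needed to retune it at runtime (the
/// caps filter and encoder elements) and its frame/drop counters.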
struct GstStream {
pipeline: gst::Pipeline,
capsfilter: Option<gst::Element>,
encoder: Option<gst::Element>,
frame_counter: Arc<AtomicU64>,
dropped_counter: Arc<AtomicU64>,
}
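/// Everything owned by one viewer's active stream: encoder pipelines and the
/// Tokio tasks that push, pump, and report. `Drop` is a best-effort abort;
/// `stop_stream_runtime` performs the orderly shutdown.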
struct StreamRuntime {
stream: GstStream,
push_stop: CancellationToken,
push_handle: tokio::task::JoinHandle<()>,
pump_stop: CancellationToken,
pump_handle: tokio::task::JoinHandle<()>,
audio_stream: Option<GstStream>,
audio_pump_stop: Option<CancellationToken>,
audio_pump_handle: Option<tokio::task::JoinHandle<()>>,
fps_stop: CancellationToken,
fps_handle: tokio::task::JoinHandle<()>,
bus_stop: CancellationToken,
bus_handle: tokio::task::JoinHandle<()>,
}
impl Drop for StreamRuntime {
fn drop(&mut self) {
self.push_stop.cancel();
self.pump_stop.cancel();
if let Some(s) = self.audio_pump_stop.as_ref() {
s.cancel();
}
self.fps_stop.cancel();
self.bus_stop.cancel();
self.push_handle.abort();
self.pump_handle.abort();
if let Some(h) = self.audio_pump_handle.as_ref() {
h.abort();
}
self.fps_handle.abort();
self.bus_handle.abort();
gst_stop_pipeline_graceful(&self.stream.pipeline, 500);
if let Some(a) = self.audio_stream.as_ref() {
gst_stop_pipeline_graceful(&a.pipeline, 500);
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Codec {
H264,
Opus,
}
impl Codec {
fn mime(self) -> &'static str {
match self {
Codec::H264 => "video/H264",
Codec::Opus => "audio/opus",
}
}
fn offer_rtpmap_token(self) -> &'static str {
match self {
Codec::H264 => "H264/90000",
Codec::Opus => "opus/48000/2",
}
}
fn default_pt(self) -> u8 {
match self {
Codec::H264 => 96,
Codec::Opus => 111,
}
}
}
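// Codec selection is currently fixed: H264 video and Opus audio, regardless
// of what else the offer advertises.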
fn choose_video_codec_from_offer(_offer_sdp: &str) -> Codec {
Codec::H264
}
fn choose_audio_codec_from_offer(_offer_sdp: &str) -> Codec {
Codec::Opus
}
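/// Scan the offer's `a=rtpmap:` lines for the payload type negotiated for a
/// codec token such as `H264/90000`.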
fn find_pt_in_offer(offer_sdp: &str, rtpmap_token: &str) -> Option<u8> {
for line in offer_sdp.lines() {
if let Some(rest) = line.strip_prefix("a=rtpmap:")
&& let Some((pt_str, codec_part)) = rest.split_once(' ')
&& codec_part.trim() == rtpmap_token
&& let Ok(pt) = pt_str.trim().parse::<u16>()
&& pt <= 255
{
return Some(pt as u8);
}
}
None
}
fn find_fmtp_in_offer(offer_sdp: &str, pt: u8) -> Option<String> {
let prefix = format!("a=fmtp:{pt} ");
for line in offer_sdp.lines() {
if let Some(rest) = line.strip_prefix(&prefix) {
return Some(rest.trim().to_string());
}
}
None
}
fn codec_cap(codec: Codec, fmtp_from_offer: Option<&str>) -> RTCRtpCodecCapability {
match codec {
Codec::H264 => RTCRtpCodecCapability {
mime_type: codec.mime().to_string(),
clock_rate: 90000,
channels: 0,
sdp_fmtp_line: fmtp_from_offer
.unwrap_or("level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f")
.to_string(),
rtcp_feedback: vec![],
},
Codec::Opus => RTCRtpCodecCapability {
mime_type: codec.mime().to_string(),
clock_rate: 48000,
channels: 2,
sdp_fmtp_line: fmtp_from_offer.unwrap_or("").to_string(),
rtcp_feedback: vec![],
},
}
}
fn prop_type(elem: &gst::Element, prop: &str) -> Option<glib::Type> {
elem.find_property(prop).map(|ps| ps.value_type())
}
fn set_prop_int(elem: &gst::Element, prop: &str, value: u64) {
let Some(ps) = elem.find_property(prop) else {
return;
};
let t = ps.value_type();
if t == u64::static_type() {
elem.set_property(prop, value);
} else if t == i64::static_type() {
elem.set_property(prop, value as i64);
} else if t == u32::static_type() {
elem.set_property(prop, value as u32);
} else if t == i32::static_type() {
elem.set_property(prop, value as i32);
}
}
fn set_prop_from_str(elem: &gst::Element, prop: &str, value: &str) {
if elem.find_property(prop).is_some() {
elem.set_property_from_str(prop, value);
}
}
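/// Send EOS, wait briefly for it (or an error) on the bus, then force the
/// pipeline to NULL so muxers and sinks get a chance to finalize.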
fn gst_stop_pipeline_graceful(p: &gst::Pipeline, timeout_ms: u64) {
use gst::MessageView;
let _ = p.send_event(gst::event::Eos::new());
if let Some(bus) = p.bus() {
let deadline = std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms);
while std::time::Instant::now() < deadline {
match bus.timed_pop(gst::ClockTime::from_mseconds(50)) {
None => {}
Some(msg) => match msg.view() {
MessageView::Eos(..) => break,
MessageView::Error(e) => {
warn!(
"gst error while stopping from {:?}: {} (debug={:?})",
e.src().map(|s| s.path_string()),
e.error(),
e.debug()
);
break;
}
_ => {}
},
}
}
}
if let Err(e) = p.set_state(gst::State::Null) {
warn!("gst set_state(NULL) failed: {e:?}");
}
let _ = p.state(gst::ClockTime::from_mseconds(timeout_ms));
}
fn gst_has_element(name: &str) -> bool {
gst::ElementFactory::find(name).is_some()
}
fn spawn_gst_bus_logger(
pipeline: &gst::Pipeline,
stop: CancellationToken,
) -> tokio::task::JoinHandle<()> {
let bus = pipeline.bus().expect("pipeline has no bus");
tokio::spawn(async move {
use gst::MessageView;
loop {
tokio::select! {
_ = stop.cancelled() => break,
_ = tokio::time::sleep(std::time::Duration::from_millis(50)) => {
while let Some(msg) = bus.pop() {
match msg.view() {
MessageView::Error(e) => {
error!(
"gst error from {:?}: {} (debug: {:?})",
e.src().map(|s| s.path_string()),
e.error(),
e.debug()
);
}
MessageView::Warning(w) => {
warn!(
"gst warning from {:?}: {} (debug: {:?})",
w.src().map(|s| s.path_string()),
w.error(),
w.debug()
);
}
MessageView::StateChanged(s) => {
if let Some(src) = msg.src()
&& src.type_().name() == "GstPipeline" {
info!("gst state changed: {:?} -> {:?}", s.old(), s.current());
}
}
MessageView::Eos(..) => warn!("gst EOS"),
_ => {}
}
}
}
}
}
info!("gst bus logger stopped");
})
}
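/// Build the shared capture pipeline: platform screen source -> 1280x720
/// NV12 at `fps` -> tee, with a raw-frame appsink branch and an optional
/// H264/AAC flv/RTMP branch.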
fn build_capture_pipeline(
fps: i32,
rtmp: Option<Arc<RtmpBroadcaster>>,
) -> std::io::Result<(gst::Pipeline, gst_app::AppSink)> {
let fps = fps.max(1);
let w = 1280;
let h = 720;
let (src, src_factory) = if cfg!(target_os = "macos") {
("avfvideosrc capture-screen=true", "avfvideosrc")
} else if cfg!(target_os = "windows") {
(
"d3d11screencapturesrc show-cursor=true ! d3d11convert ! d3d11download",
"d3d11screencapturesrc",
)
} else {
return Err(std::io::Error::other("Unsupported platform"));
};
if !gst_has_element(src_factory) {
return Err(std::io::Error::other(format!(
"Missing {src_factory}. Install GStreamer."
)));
}
let (enc_name, enc_is_nv, enc_is_amf, enc_is_x264) = if gst_has_element("nvh264enc") {
("nvh264enc", true, false, false)
} else if gst_has_element("amfh264enc") {
("amfh264enc", false, true, false)
} else if gst_has_element("x264enc") {
("x264enc", false, false, true)
} else {
("", false, false, false)
};
let rtmp_branch = if let Some(r) = rtmp {
if enc_name.is_empty()
|| !gst_has_element("h264parse")
|| !gst_has_element("flvmux")
|| !gst_has_element("rtmpsink")
{
warn!("RTMP requested but missing encoder/parse/mux/sink; RTMP disabled.");
"".to_string()
} else {
let kbps = r.bitrate_kbps.unwrap_or(4500).max(300);
let gop_s = r.gop_seconds.unwrap_or(2).max(1);
            let gop = (fps as u32).saturating_mul(gop_s).max(1);
            let location = format!(
"{}/{} live=1",
r.ingest_url.trim_end_matches('/'),
r.stream_key
);
let enc_props = if enc_is_nv {
format!(
"{enc} name=rtmpenc rc-mode=cbr bitrate={kbps} bframes=0 gop-size={gop} preset=low-latency-hq tune=ultra-low-latency",
enc = enc_name
)
} else if enc_is_amf {
format!(
"{enc} name=rtmpenc usage=ultralowlatency rate-control=cbr bitrate={kbps} b-frames=0 gop-size={gop}",
enc = enc_name
)
} else if enc_is_x264 {
format!(
"{enc} name=rtmpenc bitrate={kbps} speed-preset=veryfast tune=zerolatency key-int-max={gop} bframes=0",
enc = enc_name
)
} else {
format!("{enc} name=rtmpenc", enc = enc_name)
};
let aac_enc = if gst_has_element("fdkaacenc") {
Some("fdkaacenc bitrate=128000")
} else if gst_has_element("faac") {
Some("faac bitrate=128000")
} else if gst_has_element("voaacenc") {
Some("voaacenc bitrate=128000")
} else if gst_has_element("avenc_aac") {
Some("avenc_aac bitrate=128000")
} else {
None
};
let has_aacparse = gst_has_element("aacparse");
let (audio_src, has_audio_src) = if cfg!(target_os = "windows") {
("wasapi2src loopback=true", gst_has_element("wasapi2src"))
} else {
("autoaudiosrc", gst_has_element("autoaudiosrc"))
};
let audio_branch = if let (Some(aac_enc), true) = (aac_enc, has_aacparse) {
if has_audio_src {
format!(
r#"
{audio_src} !
queue leaky=downstream max-size-buffers=8 max-size-bytes=0 max-size-time=0 !
audioconvert !
audioresample !
audio/x-raw,rate=48000,channels=2 !
{aac_enc} !
aacparse !
mux.
"#,
audio_src = audio_src,
aac_enc = aac_enc
)
} else {
format!(
r#"
audiotestsrc wave=silence is-live=true !
queue leaky=downstream max-size-buffers=8 max-size-bytes=0 max-size-time=0 !
audioconvert !
audioresample !
audio/x-raw,rate=48000,channels=2 !
{aac_enc} !
aacparse !
mux.
"#,
aac_enc = aac_enc
)
}
} else {
warn!(
"RTMP: No AAC encoder/aacparse available (fdkaacenc/faac/voaacenc/avenc_aac + aacparse). Telegram may not show the stream."
);
"".to_string()
};
format!(
r#"
t. ! queue leaky=downstream max-size-buffers=2 max-size-bytes=0 max-size-time=0 !
videoconvert !
{enc_props} !
h264parse config-interval=1 !
video/x-h264,stream-format=avc,alignment=au !
mux.
{audio_branch}
flvmux name=mux streamable=true !
rtmpsink location="{location}" sync=false async=false
"#
)
}
} else {
"".to_string()
};
let desc = format!(
r#"{src} !
videoconvert !
videoscale !
videorate drop-only=true !
video/x-raw,format=NV12,width={w},height={h},framerate={fps}/1 !
tee name=t
t. ! queue leaky=downstream max-size-buffers=1 max-size-bytes=0 max-size-time=0 !
appsink name=rawsink emit-signals=true sync=false max-buffers=1 drop=true
{rtmp_branch}
"#,
src = src,
w = w,
h = h,
fps = fps,
rtmp_branch = rtmp_branch,
);
let pipeline = gst::parse::launch(&desc)
.map_err(|e| std::io::Error::other(format!("parse_launch failed: {e:?}\nDESC:\n{desc}")))?
.downcast::<gst::Pipeline>()
.map_err(|e| std::io::Error::other(format!("not a pipeline: {e:?}")))?;
let appsink = pipeline
.by_name("rawsink")
.ok_or_else(|| std::io::Error::other("appsink rawsink not found"))?
.downcast::<gst_app::AppSink>()
.map_err(|e| std::io::Error::other(format!("rawsink not AppSink: {e:?}")))?;
Ok((pipeline, appsink))
}
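/// Return the running capture hub, starting it on first use. Later callers
/// reuse the existing hub, so `initial_fps` and `rtmp` only take effect for
/// the first subscriber.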
async fn ensure_capture_hub(
initial_fps: i32,
rtmp: Option<Arc<RtmpBroadcaster>>,
) -> std::io::Result<Arc<CaptureHub>> {
let slot = capture_hub_slot();
{
if let Some(h) = slot.read().await.as_ref() {
return Ok(h.clone());
}
}
let (pipeline, appsink) = build_capture_pipeline(initial_fps, rtmp)?;
let (tx, _rx) = broadcast::channel::<CapturedFrame>(16);
{
let tx = tx.clone();
appsink.set_callbacks(
gst_app::AppSinkCallbacks::builder()
.new_sample(move |sink| {
let sample = sink.pull_sample().map_err(|_| gst::FlowError::Eos)?;
let buffer = sample.buffer().ok_or(gst::FlowError::Error)?;
let map = buffer.map_readable().map_err(|_| gst::FlowError::Error)?;
let data = Bytes::copy_from_slice(map.as_slice());
let dur = buffer
.duration()
.map(|d| std::time::Duration::from_nanos(d.nseconds()))
.filter(|d| d.as_nanos() > 0)
.unwrap_or_else(|| std::time::Duration::from_millis(16));
let _ = tx.send(CapturedFrame { data, dur });
Ok(gst::FlowSuccess::Ok)
})
.build(),
);
}
pipeline
.set_state(gst::State::Playing)
.map_err(|e| std::io::Error::other(format!("capture set_state(Playing) failed: {e:?}")))?;
let stop = CancellationToken::new();
let bus_handle = spawn_gst_bus_logger(&pipeline, stop.child_token());
let hub = Arc::new(CaptureHub {
pipeline,
stop,
bus_handle,
tx,
});
*slot.write().await = Some(hub.clone());
info!("capture hub started");
Ok(hub)
}
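/// Tear the capture hub down, but only if no WebSocket sessions remain.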
async fn maybe_stop_capture_hub() {
if active_ws().load(Ordering::Relaxed) != 0 {
return;
}
let slot = capture_hub_slot();
let hub = slot.write().await.take();
if let Some(hub) = hub {
info!("stopping capture hub (no active WS)");
hub.stop.cancel();
gst_stop_pipeline_graceful(&hub.pipeline, 1500);
hub.bus_handle.abort();
}
}
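/// Per-viewer encoder pipeline: appsrc (raw NV12 from the hub) -> scale/rate
/// to the requested caps -> hardware H264 encoder where available
/// (vtenc/nvenc/amf), else x264enc -> byte-stream appsink.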
fn build_encoder_pipeline_h264_from_appsrc(
ctrl: &StreamCtrl,
) -> std::io::Result<(GstStream, gst_app::AppSrc, mpsc::Receiver<gst::Sample>)> {
let fps = ctrl.fps.max(1);
let w = ctrl.width.max(1);
let h = ctrl.height.max(1);
let (enc_name, enc_is_vt, enc_is_nv, enc_is_amf, enc_is_x264) =
if cfg!(target_os = "macos") && gst_has_element("vtenc_h264") {
("vtenc_h264", true, false, false, false)
} else if gst_has_element("nvh264enc") {
("nvh264enc", false, true, false, false)
} else if gst_has_element("amfh264enc") {
("amfh264enc", false, false, true, false)
} else if gst_has_element("x264enc") {
("x264enc", false, false, false, true)
} else {
return Err(std::io::Error::other(
"No H264 encoder found (vtenc_h264/nvh264enc/amfh264enc/x264enc).",
));
};
let desc = format!(
"appsrc name=rawsrc is-live=true format=time do-timestamp=true block=false max-bytes=0 !
videoconvert !
videoscale !
videorate drop-only=true !
capsfilter name=vcaps caps=video/x-raw,format=NV12,width={w},height={h},framerate={fps}/1 !
identity name=ftap signal-handoffs=true silent=true !
{enc} name=venc !
h264parse config-interval=1 !
capsfilter caps=video/x-h264,stream-format=byte-stream,alignment=au !
identity name=keyreq silent=true !
appsink name=hsink emit-signals=true sync=false max-buffers=1 drop=true",
w = w,
h = h,
fps = fps,
enc = enc_name,
);
let pipeline = gst::parse::launch(&desc)
.map_err(|e| std::io::Error::other(format!("parse_launch failed: {e:?}\nDESC:\n{desc}")))?
.downcast::<gst::Pipeline>()
.map_err(|e| std::io::Error::other(format!("not a pipeline: {e:?}")))?;
let appsrc = pipeline
.by_name("rawsrc")
.ok_or_else(|| std::io::Error::other("appsrc rawsrc not found"))?
.downcast::<gst_app::AppSrc>()
.map_err(|e| std::io::Error::other(format!("rawsrc not AppSrc: {e:?}")))?;
    // The appsrc caps must describe the buffers actually pushed into it. The
    // shared capture hub always delivers 1280x720 NV12 (see
    // build_capture_pipeline); videoscale plus the `vcaps` capsfilter then
    // rescale to the requested output size.
    let caps_in = gst::Caps::builder("video/x-raw")
        .field("format", "NV12")
        .field("width", 1280i32)
        .field("height", 720i32)
        .field("framerate", gst::Fraction::new(fps, 1))
        .build();
appsrc.set_caps(Some(&caps_in));
if appsrc.find_property("block").is_some() {
appsrc.set_property("block", false);
}
set_prop_int(appsrc.upcast_ref::<gst::Element>(), "max-bytes", 0u64);
set_prop_int(appsrc.upcast_ref::<gst::Element>(), "min-latency", 0u64);
set_prop_int(appsrc.upcast_ref::<gst::Element>(), "max-latency", 0u64);
let capsfilter = pipeline
.by_name("vcaps")
.ok_or_else(|| std::io::Error::other("capsfilter vcaps not found"))?;
let encoder = pipeline
.by_name("venc")
.ok_or_else(|| std::io::Error::other("encoder venc not found"))?;
let set_str = |el: &gst::Element, k: &str, v: &str| {
if el.find_property(k).is_some() {
let _ = std::panic::catch_unwind(|| set_prop_from_str(el, k, v));
}
};
let set_u32 = |el: &gst::Element, k: &str, v: u32| {
if el.find_property(k).is_some() {
let _ = std::panic::catch_unwind(|| set_prop_int(el, k, v as u64));
}
};
let set_bool = |el: &gst::Element, k: &str, v: bool| {
if el.find_property(k).is_some() {
let _ = std::panic::catch_unwind(|| {
set_prop_from_str(el, k, if v { "true" } else { "false" })
});
}
};
let kbps = ctrl.bitrate_kbps.max(300) as u32;
let gop_frames = (fps as u32).max(1);
    if enc_is_vt {
        // vtenc_h264 takes its bitrate in kbps (the same unit apply_ctrl
        // re-applies later), not bps.
        set_u32(&encoder, "bitrate", kbps);
        set_bool(&encoder, "allow-frame-reordering", false);
        set_bool(&encoder, "realtime", true);
        set_u32(&encoder, "max-keyframe-interval", gop_frames);
        // max-keyframe-interval-duration is a GstClockTime in nanoseconds.
        set_u32(&encoder, "max-keyframe-interval-duration", 2_000_000_000);
} else if enc_is_nv {
set_str(&encoder, "preset", "low-latency-hq");
set_str(&encoder, "tune", "ultra-low-latency");
set_str(&encoder, "rc-mode", "cbr");
set_u32(&encoder, "bitrate", kbps);
set_u32(&encoder, "max-bitrate", kbps);
set_u32(&encoder, "vbv-buffer-size", kbps);
set_u32(&encoder, "gop-size", gop_frames);
set_u32(&encoder, "bframes", 0);
set_bool(&encoder, "repeat-sequence-header", true);
set_bool(&encoder, "zerolatency", true);
set_u32(&encoder, "iframeinterval", gop_frames);
set_u32(&encoder, "idrinterval", gop_frames);
} else if enc_is_amf {
set_str(&encoder, "usage", "ultralowlatency");
set_str(&encoder, "rate-control", "cbr");
set_u32(&encoder, "bitrate", kbps);
set_u32(&encoder, "b-frames", 0);
set_u32(&encoder, "gop-size", gop_frames);
} else if enc_is_x264 {
set_u32(&encoder, "bitrate", kbps);
set_str(&encoder, "speed-preset", "veryfast");
set_str(&encoder, "tune", "zerolatency");
set_u32(&encoder, "key-int-max", gop_frames);
set_u32(&encoder, "bframes", 0);
set_bool(&encoder, "byte-stream", true);
set_u32(&encoder, "rc-lookahead", 0);
set_bool(&encoder, "sync-lookahead", false);
}
let (sample_tx, sample_rx) = mpsc::channel::<gst::Sample>(32);
let dropped_counter = Arc::new(AtomicU64::new(0));
let dropped_counter_cb = dropped_counter.clone();
let appsink = pipeline
.by_name("hsink")
.ok_or_else(|| std::io::Error::other("appsink hsink not found"))?
.downcast::<gst_app::AppSink>()
.map_err(|e| std::io::Error::other(format!("hsink not AppSink: {e:?}")))?;
appsink.set_callbacks(
gst_app::AppSinkCallbacks::builder()
.new_sample(move |sink| {
let sample = sink.pull_sample().map_err(|_| gst::FlowError::Eos)?;
if sample_tx.try_send(sample).is_err() {
dropped_counter_cb.fetch_add(1, Ordering::Relaxed);
}
Ok(gst::FlowSuccess::Ok)
})
.build(),
);
let ftap = pipeline
.by_name("ftap")
.ok_or_else(|| std::io::Error::other("identity ftap not found"))?;
let frame_counter = Arc::new(AtomicU64::new(0));
{
let fc = frame_counter.clone();
let _ = ftap.connect("handoff", false, move |_values| {
fc.fetch_add(1, Ordering::Relaxed);
None
});
}
Ok((
GstStream {
pipeline,
capsfilter: Some(capsfilter),
encoder: Some(encoder),
frame_counter,
dropped_counter,
},
appsrc,
sample_rx,
))
}
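/// Ask the encoder for an IDR frame by sending an upstream ForceKeyUnit
/// event from the `keyreq` identity's src pad.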
fn request_keyframe(pipeline: &gst::Pipeline) {
let Some(keyreq) = pipeline.by_name("keyreq") else {
return;
};
let Some(srcpad) = keyreq.static_pad("src") else {
return;
};
let ev = gst_video::UpstreamForceKeyUnitEvent::builder()
.all_headers(true)
.build();
if !srcpad.send_event(ev) {
warn!("request_keyframe: send_event returned false");
}
}
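/// Re-apply `ctrl` without a restart: update the raw caps and push the new
/// bitrate under whichever property names the encoder actually exposes.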
fn apply_ctrl(stream: &GstStream, ctrl: &StreamCtrl) -> std::io::Result<()> {
let Some(capsfilter) = stream.capsfilter.as_ref() else {
return Ok(());
};
let Some(encoder) = stream.encoder.as_ref() else {
return Ok(());
};
let caps = gst::Caps::builder("video/x-raw")
.field("format", "NV12")
.field("width", ctrl.width.max(1))
.field("height", ctrl.height.max(1))
.field("framerate", gst::Fraction::new(ctrl.fps.max(1), 1))
.build();
capsfilter.set_property("caps", &caps);
let kbps = ctrl.bitrate_kbps.max(300) as u64;
set_prop_int(encoder, "bitrate", kbps);
set_prop_int(encoder, "max-bitrate", kbps);
set_prop_int(encoder, "target-bitrate", kbps);
Ok(())
}
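/// System-audio pipeline: WASAPI loopback on Windows (autoaudiosrc
/// elsewhere) -> 48 kHz stereo -> Opus at 64 kbps in 20 ms frames -> appsink.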
fn build_pipeline_opus() -> std::io::Result<(GstStream, mpsc::Receiver<gst::Sample>)> {
if cfg!(target_os = "windows") && !gst_has_element("wasapi2src") {
return Err(std::io::Error::other(
"Missing wasapi2src. Install GStreamer WASAPI plugins.",
));
}
if !cfg!(target_os = "windows") && !gst_has_element("autoaudiosrc") {
return Err(std::io::Error::other(
"Missing autoaudiosrc. Install GStreamer audio plugins.",
));
}
    if !gst_has_element("opusenc") {
        return Err(std::io::Error::other(
            "Missing opusenc. Install gst-plugins-base (or your Opus plugin set).",
        ));
    }
    if !gst_has_element("opusparse") {
        return Err(std::io::Error::other(
            "Missing opusparse. Install gst-plugins-bad.",
        ));
    }
let audio_src = if cfg!(target_os = "windows") {
"wasapi2src loopback=true"
} else {
"autoaudiosrc"
};
let pipeline_desc = format!(
"{audio_src} !
queue leaky=downstream max-size-buffers=8 max-size-bytes=0 max-size-time=0 !
audioconvert !
audioresample !
audio/x-raw,rate=48000,channels=2 !
opusenc bitrate=64000 frame-size=20 !
opusparse !
appsink name=asink emit-signals=true sync=false max-buffers=2 drop=true"
);
let pipeline = gst::parse::launch(&pipeline_desc)
.map_err(|e| std::io::Error::other(format!("parse_launch failed: {e:?}")))?
.downcast::<gst::Pipeline>()
.map_err(|e| std::io::Error::other(format!("not a pipeline: {e:?}")))?;
let (sample_tx, sample_rx) = mpsc::channel::<gst::Sample>(32);
let dropped_counter = Arc::new(AtomicU64::new(0));
let dropped_counter_cb = dropped_counter.clone();
let appsink = pipeline
.by_name("asink")
.ok_or_else(|| std::io::Error::other("appsink asink not found"))?
.downcast::<gst_app::AppSink>()
.map_err(|e| std::io::Error::other(format!("asink not AppSink: {e:?}")))?;
appsink.set_callbacks(
gst_app::AppSinkCallbacks::builder()
.new_sample(move |sink| {
let sample = sink.pull_sample().map_err(|_| gst::FlowError::Eos)?;
if sample_tx.try_send(sample).is_err() {
dropped_counter_cb.fetch_add(1, Ordering::Relaxed);
}
Ok(gst::FlowSuccess::Ok)
})
.build(),
);
Ok((
GstStream {
pipeline,
capsfilter: None,
encoder: None,
frame_counter: Arc::new(AtomicU64::new(0)),
dropped_counter,
},
sample_rx,
))
}
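/// Forward encoded H264 access units from the appsink channel into the
/// WebRTC video track, deriving each sample's duration from the buffer
/// metadata or, failing that, the current fps.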
async fn pump_h264_samples(
mut sample_rx: mpsc::Receiver<gst::Sample>,
track: Arc<TrackLocalStaticSample>,
ctrl_state: Arc<RwLock<StreamCtrl>>,
stop: CancellationToken,
) -> std::io::Result<()> {
loop {
tokio::select! {
_ = stop.cancelled() => {
info!("pump_h264_samples cancelled");
break;
}
opt = sample_rx.recv() => {
let Some(sample) = opt else {
info!("sample_rx closed");
break;
};
let buffer = sample.buffer().ok_or_else(|| std::io::Error::other("no buffer"))?;
let map = buffer.map_readable()
.map_err(|e| std::io::Error::other(format!("map buffer: {e}")))?;
let data = map.as_slice();
let fps = ctrl_state.read().await.fps.max(1) as u64;
let dur = buffer.duration()
.map(|d| std::time::Duration::from_nanos(d.nseconds()))
.filter(|d| d.as_nanos() > 0)
.unwrap_or_else(|| std::time::Duration::from_nanos(1_000_000_000u64 / fps));
let s = webrtc::media::Sample {
data: Bytes::copy_from_slice(data),
duration: dur,
..Default::default()
};
if let Err(e) = track.write_sample(&s).await {
warn!("track.write_sample failed: {e}");
break;
}
}
}
}
Ok(())
}
async fn pump_opus_samples(
mut sample_rx: mpsc::Receiver<gst::Sample>,
track: Arc<TrackLocalStaticSample>,
stop: CancellationToken,
) -> std::io::Result<()> {
let dur = std::time::Duration::from_millis(20);
loop {
tokio::select! {
_ = stop.cancelled() => {
info!("pump_opus_samples cancelled");
break;
}
opt = sample_rx.recv() => {
let Some(sample) = opt else {
info!("audio sample_rx closed");
break;
};
let buffer = sample.buffer().ok_or_else(|| std::io::Error::other("audio: no buffer"))?;
let map = buffer.map_readable()
.map_err(|e| std::io::Error::other(format!("audio map buffer: {e}")))?;
let data = map.as_slice();
let s = webrtc::media::Sample {
data: Bytes::copy_from_slice(data),
duration: dur,
..Default::default()
};
if let Err(e) = track.write_sample(&s).await {
warn!("audio track.write_sample failed: {e}");
break;
}
}
}
}
Ok(())
}
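/// Wire a viewer up end to end: subscribe to the capture hub, start the
/// H264 and Opus pipelines, and spawn the push, pump, stats, and bus tasks.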
async fn start_stream_runtime(
ctrl: StreamCtrl,
ctrl_state: Arc<RwLock<StreamCtrl>>,
video_track: Arc<TrackLocalStaticSample>,
audio_track: Arc<TrackLocalStaticSample>,
rtmp: Option<Arc<RtmpBroadcaster>>,
out_tx: mpsc::Sender<WsMsg>,
) -> std::io::Result<StreamRuntime> {
let hub = ensure_capture_hub(ctrl.fps, rtmp).await?;
let mut cap_rx = hub.tx.subscribe();
let (stream, appsrc, sample_rx) = build_encoder_pipeline_h264_from_appsrc(&ctrl)?;
stream
.pipeline
.set_state(gst::State::Playing)
.map_err(|e| std::io::Error::other(format!("encoder set_state(Playing) failed: {e:?}")))?;
let (audio_stream, audio_rx) = build_pipeline_opus()?;
audio_stream
.pipeline
.set_state(gst::State::Playing)
.map_err(|e| {
std::io::Error::other(format!("audio gst set_state(Playing) failed: {e:?}"))
})?;
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
request_keyframe(&stream.pipeline);
let bus_stop = CancellationToken::new();
let bus_handle = spawn_gst_bus_logger(&stream.pipeline, bus_stop.child_token());
let fps_stop = CancellationToken::new();
let fps_stop_child = fps_stop.child_token();
let fc = stream.frame_counter.clone();
let dropped = stream.dropped_counter.clone();
let out_tx_fps = out_tx.clone();
let fps_handle = tokio::spawn(async move {
let mut tick = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
_ = fps_stop_child.cancelled() => break,
_ = tick.tick() => {
let frames = fc.swap(0, Ordering::Relaxed);
let dropped_samples = dropped.swap(0, Ordering::Relaxed);
let _ = out_tx_fps.send(WsMsg::ServerStats(ServerStats {
fps: frames as f64,
dropped_samples,
})).await;
}
}
}
info!("fps reporter stopped");
});
let push_stop = CancellationToken::new();
let push_stop_child = push_stop.child_token();
let push_handle = tokio::spawn(async move {
let mut pts_ns: u64 = 0;
let result: std::io::Result<()> = async {
loop {
tokio::select! {
_ = push_stop_child.cancelled() => break,
msg = cap_rx.recv() => {
                        let frame = match msg {
                            Ok(f) => f,
                            // Dropped frames are acceptable for live video.
                            Err(broadcast::error::RecvError::Lagged(_)) => continue,
                            // Capture hub gone; exit instead of busy-looping.
                            Err(broadcast::error::RecvError::Closed) => break,
                        };
let dur_ns = (frame.dur.as_nanos() as u64).max(1);
let mut buf = gst::Buffer::with_size(frame.data.len())
.map_err(|_| std::io::Error::other("gst::Buffer::with_size failed"))?;
{
let bm = buf.get_mut().ok_or_else(|| std::io::Error::other("buffer not writable"))?;
{
let mut map = bm.map_writable().map_err(|_| std::io::Error::other("map_writable failed"))?;
map.as_mut_slice().copy_from_slice(&frame.data);
}
bm.set_duration(gst::ClockTime::from_nseconds(dur_ns));
bm.set_pts(gst::ClockTime::from_nseconds(pts_ns));
bm.set_dts(gst::ClockTime::from_nseconds(pts_ns));
}
pts_ns = pts_ns.saturating_add(dur_ns);
if let Err(e) = appsrc.push_buffer(buf) {
warn!("appsrc.push_buffer failed: {e:?}");
}
}
}
}
Ok(())
}.await;
if let Err(e) = result {
warn!("raw->appsrc push task error: {e}");
}
info!("raw->appsrc push task ended");
});
let pump_stop = CancellationToken::new();
let pump_stop_child = pump_stop.child_token();
let pump_handle = tokio::spawn(async move {
if let Err(e) = pump_h264_samples(sample_rx, video_track, ctrl_state, pump_stop_child).await
{
warn!("pump_h264_samples error: {e}");
}
info!("video pump task ended");
});
let audio_pump_stop = CancellationToken::new();
let audio_pump_child = audio_pump_stop.child_token();
let audio_pump_handle = tokio::spawn(async move {
if let Err(e) = pump_opus_samples(audio_rx, audio_track, audio_pump_child).await {
warn!("pump_opus_samples error: {e}");
}
info!("audio pump task ended");
});
Ok(StreamRuntime {
stream,
push_stop,
push_handle,
pump_stop,
pump_handle,
audio_stream: Some(audio_stream),
audio_pump_stop: Some(audio_pump_stop),
audio_pump_handle: Some(audio_pump_handle),
fps_stop,
fps_handle,
bus_stop,
bus_handle,
})
}
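/// Orderly shutdown: cancel every task, drain both pipelines to NULL, then
/// await the task handles so nothing outlives the runtime.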
async fn stop_stream_runtime(mut rt: StreamRuntime) {
rt.push_stop.cancel();
rt.pump_stop.cancel();
if let Some(s) = rt.audio_pump_stop.as_ref() {
s.cancel();
}
rt.fps_stop.cancel();
rt.bus_stop.cancel();
gst_stop_pipeline_graceful(&rt.stream.pipeline, 1500);
if let Some(a) = rt.audio_stream.as_ref() {
gst_stop_pipeline_graceful(&a.pipeline, 1500);
}
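    // Swap the join handles out for no-op tasks so that `rt`'s Drop (which
    // aborts whatever handles it still holds) cannot abort the real tasks we
    // await below.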
let push_handle = std::mem::replace(&mut rt.push_handle, tokio::spawn(async {}));
let pump_handle = std::mem::replace(&mut rt.pump_handle, tokio::spawn(async {}));
let audio_pump_handle = rt.audio_pump_handle.take();
let fps_handle = std::mem::replace(&mut rt.fps_handle, tokio::spawn(async {}));
let bus_handle = std::mem::replace(&mut rt.bus_handle, tokio::spawn(async {}));
let _ = push_handle.await;
let _ = pump_handle.await;
if let Some(h) = audio_pump_handle {
let _ = h.await;
}
let _ = fps_handle.await;
let _ = bus_handle.await;
info!("stream runtime fully stopped");
}
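/// Lifecycle notifications surfaced to the embedding application through a
/// `SessionEventCallback`.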
#[derive(Debug, Clone)]
pub enum SessionEvent {
WsAccepted {
ws_id: u64,
},
WsClosed {
ws_id: u64,
reason: Option<String>,
},
WsProtocolError {
ws_id: u64,
message: String,
},
PcCreated {
ws_id: u64,
pc_id: u64,
},
PcClosed {
ws_id: u64,
pc_id: u64,
reason: Option<String>,
},
PcState {
ws_id: u64,
pc_id: u64,
state: webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState,
},
IceState {
ws_id: u64,
pc_id: u64,
state: webrtc::ice_transport::ice_connection_state::RTCIceConnectionState,
},
IceGatheringState {
ws_id: u64,
pc_id: u64,
state: webrtc::ice_transport::ice_gatherer_state::RTCIceGathererState,
},
DataChannelOpen {
ws_id: u64,
pc_id: u64,
dc_id: u64,
label: String,
},
StreamCtrlApplied {
ws_id: u64,
pc_id: Option<u64>,
prev: StreamCtrl,
next: StreamCtrl,
restarted: bool,
},
}
pub type SessionEventCallback = Arc<dyn Fn(SessionEvent) + Send + Sync + 'static>;
#[derive(Debug, Clone)]
pub enum DataChannelPayload {
Text(String),
Binary(Bytes),
}
pub type DataChannelMessageCallback =
Arc<dyn Fn(u64 /*id*/, DataChannelPayload /*payload*/) + Send + Sync + 'static>;
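/// Per-WebSocket state shared with the signaling handler.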
#[derive(Clone)]
struct WsCtx {
ws_id: u64,
cfg: Arc<ServerConfig>,
out_tx: mpsc::Sender<WsMsg>,
ctrl_state: Arc<RwLock<StreamCtrl>>,
pc: Arc<RwLock<Option<Arc<RTCPeerConnection>>>>,
runtime: Arc<RwLock<Option<StreamRuntime>>>,
track_slot: Arc<RwLock<Option<Arc<TrackLocalStaticSample>>>>,
audio_track_slot: Arc<RwLock<Option<Arc<TrackLocalStaticSample>>>>,
pc_id_slot: Arc<RwLock<Option<u64>>>,
pc_next_id: Arc<AtomicU64>,
on_event: Option<SessionEventCallback>,
dc_next_id: Arc<AtomicU64>,
on_dc_message: Option<DataChannelMessageCallback>,
last_keyreq_ms: Arc<AtomicU64>,
last_bitrate_change_ms: Arc<AtomicU64>,
rtmp: Option<Arc<RtmpBroadcaster>>,
}
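/// Dispatch one parsed signaling message. `Offer` tears down any previous
/// peer/stream, builds a fresh RTCPeerConnection with the negotiated payload
/// types, answers, and starts the stream runtime; `Ice` feeds remote
/// candidates; `ClientStats` drives keyframe requests and bitrate
/// adaptation; `Ctrl` applies new encoding parameters, restarting only when
/// the size or fps changed.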
async fn handle_ws_json(ctx: WsCtx, text: &str) -> std::io::Result<()> {
let m: WsMsg =
serde_json::from_str(text).map_err(|e| std::io::Error::other(format!("Bad JSON: {e}")))?;
match m {
WsMsg::Offer(offer_sdp) => {
if let Some(old) = ctx.runtime.write().await.take() {
stop_stream_runtime(old).await;
}
if let Some(old_peer) = ctx.pc.write().await.take() {
if let Some(old_pc_id) = *ctx.pc_id_slot.read().await {
if let Some(cb) = ctx.on_event.as_ref() {
cb(SessionEvent::PcClosed {
ws_id: ctx.ws_id,
pc_id: old_pc_id,
reason: Some("replaced by new Offer".into()),
});
}
}
let _ = old_peer.close().await;
}
*ctx.pc_id_slot.write().await = None;
*ctx.track_slot.write().await = None;
*ctx.audio_track_slot.write().await = None;
let video_codec = choose_video_codec_from_offer(&offer_sdp);
let audio_codec = choose_audio_codec_from_offer(&offer_sdp);
let video_pt = find_pt_in_offer(&offer_sdp, video_codec.offer_rtpmap_token())
.unwrap_or(video_codec.default_pt());
let audio_pt = find_pt_in_offer(&offer_sdp, audio_codec.offer_rtpmap_token())
.unwrap_or(audio_codec.default_pt());
let video_fmtp = find_fmtp_in_offer(&offer_sdp, video_pt);
let audio_fmtp = find_fmtp_in_offer(&offer_sdp, audio_pt);
let _ = ctx
.out_tx
.send(WsMsg::Info(format!(
"Selected video codec: {video_codec:?}"
)))
.await;
let _ = ctx
.out_tx
.send(WsMsg::Info(format!(
"Selected audio codec: {audio_codec:?}"
)))
.await;
let _ = ctx
.out_tx
.send(WsMsg::Info(format!("Negotiated video PT: {video_pt}")))
.await;
let _ = ctx
.out_tx
.send(WsMsg::Info(format!("Negotiated audio PT: {audio_pt}")))
.await;
let mut me = MediaEngine::default();
me.register_codec(
RTCRtpCodecParameters {
capability: codec_cap(video_codec, video_fmtp.as_deref()),
payload_type: video_pt,
..Default::default()
},
RTPCodecType::Video,
)
.map_err(|e| std::io::Error::other(format!("register video codec: {e}")))?;
me.register_codec(
RTCRtpCodecParameters {
capability: codec_cap(audio_codec, audio_fmtp.as_deref()),
payload_type: audio_pt,
..Default::default()
},
RTPCodecType::Audio,
)
.map_err(|e| std::io::Error::other(format!("register audio codec: {e}")))?;
let mut registry = Registry::new();
registry =
webrtc::api::interceptor_registry::register_default_interceptors(registry, &mut me)
.map_err(|e| std::io::Error::other(format!("register interceptors: {e}")))?;
use webrtc::api::setting_engine::SettingEngine;
let udp_network = webrtc::ice::udp_network::UDPNetwork::Ephemeral(
webrtc::ice::udp_network::EphemeralUDP::new(ctx.cfg.udp_min, ctx.cfg.udp_max)
.map_err(|e| std::io::Error::other(format!("EphemeralUDP: {e}")))?,
);
let mut se = SettingEngine::default();
se.set_udp_network(udp_network);
let api = APIBuilder::new()
.with_setting_engine(se)
.with_media_engine(me)
.with_interceptor_registry(registry)
.build();
let config = RTCConfiguration {
ice_servers: vec![RTCIceServer {
urls: ctx.cfg.stun_urls.clone(),
..Default::default()
}],
..Default::default()
};
let peer = Arc::new(
api.new_peer_connection(config)
.await
.map_err(|e| std::io::Error::other(format!("new_peer_connection: {e}")))?,
);
let pc_id = ctx.pc_next_id.fetch_add(1, Ordering::Relaxed);
*ctx.pc_id_slot.write().await = Some(pc_id);
*ctx.pc.write().await = Some(peer.clone());
if let Some(cb) = ctx.on_event.as_ref() {
cb(SessionEvent::PcCreated {
ws_id: ctx.ws_id,
pc_id,
});
}
{
let out_tx = ctx.out_tx.clone();
let on_event = ctx.on_event.clone();
let ws_id = ctx.ws_id;
peer.on_peer_connection_state_change(Box::new(move |s| {
let out_tx = out_tx.clone();
let on_event = on_event.clone();
Box::pin(async move {
info!("pc state: {s:?}");
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::PcState {
ws_id,
pc_id,
state: s,
});
}
let _ = out_tx.send(WsMsg::Info(format!("PC state: {s:?}"))).await;
})
}));
}
{
let out_tx = ctx.out_tx.clone();
let on_event = ctx.on_event.clone();
let ws_id = ctx.ws_id;
peer.on_ice_connection_state_change(Box::new(move |s| {
let out_tx = out_tx.clone();
let on_event = on_event.clone();
Box::pin(async move {
info!("ice conn state: {s:?}");
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::IceState {
ws_id,
pc_id,
state: s,
});
}
let _ = out_tx
.send(WsMsg::Info(format!("ICE conn state: {s:?}")))
.await;
})
}));
}
{
let out_tx = ctx.out_tx.clone();
let on_event = ctx.on_event.clone();
let ws_id = ctx.ws_id;
peer.on_ice_gathering_state_change(Box::new(move |s| {
let out_tx = out_tx.clone();
let on_event = on_event.clone();
Box::pin(async move {
info!("ice gathering state: {s:?}");
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::IceGatheringState {
ws_id,
pc_id,
state: s,
});
}
let _ = out_tx
.send(WsMsg::Info(format!("ICE gathering: {s:?}")))
.await;
})
}));
}
{
let out_tx = ctx.out_tx.clone();
peer.on_ice_candidate(Box::new(move |c| {
let out_tx = out_tx.clone();
Box::pin(async move {
let Some(c) = c else { return };
match c.to_json() {
Ok(ice_init) => {
let wire = IceCandidateWire {
candidate: ice_init.candidate,
sdp_mid: ice_init.sdp_mid,
sdp_mline_index: ice_init.sdp_mline_index,
username_fragment: ice_init.username_fragment,
};
let _ = out_tx.send(WsMsg::Ice(wire)).await;
}
Err(e) => {
let _ = out_tx
.send(WsMsg::Error(format!("ICE to_json failed: {e}")))
.await;
}
}
})
}));
}
{
let dc_next_id = ctx.dc_next_id.clone();
let on_dc_message = ctx.on_dc_message.clone();
let on_event_outer = ctx.on_event.clone();
let ws_id = ctx.ws_id;
peer.on_data_channel(Box::new(move |dc| {
let dc_next_id = dc_next_id.clone();
let on_dc_message = on_dc_message.clone();
let on_event = on_event_outer.clone();
Box::pin(async move {
let dc_id = dc_next_id.fetch_add(1, Ordering::Relaxed);
let label = dc.label().to_string();
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::DataChannelOpen {
ws_id,
pc_id,
dc_id,
label: label.clone(),
});
}
info!("[dc#{dc_id}] opened label={label}");
let cb = on_dc_message.clone();
dc.on_message(Box::new(move |msg: DataChannelMessage| {
let cb = cb.clone();
Box::pin(async move {
let Some(cb) = cb.as_ref() else { return };
if msg.is_string {
match String::from_utf8(msg.data.to_vec()) {
Ok(s) => cb(dc_id, DataChannelPayload::Text(s)),
Err(_) => cb(
dc_id,
DataChannelPayload::Binary(Bytes::copy_from_slice(
&msg.data,
)),
),
}
} else {
cb(
dc_id,
DataChannelPayload::Binary(Bytes::copy_from_slice(
&msg.data,
)),
);
}
})
}));
})
}));
}
let video_track = Arc::new(TrackLocalStaticSample::new(
codec_cap(video_codec, video_fmtp.as_deref()),
"video".to_string(),
"desktop".to_string(),
));
peer.add_track(video_track.clone())
.await
.map_err(|e| std::io::Error::other(format!("add video track: {e}")))?;
let audio_track = Arc::new(TrackLocalStaticSample::new(
codec_cap(audio_codec, audio_fmtp.as_deref()),
"audio".to_string(),
"default".to_string(),
));
peer.add_track(audio_track.clone())
.await
.map_err(|e| std::io::Error::other(format!("add audio track: {e}")))?;
peer.set_remote_description(
webrtc::peer_connection::sdp::session_description::RTCSessionDescription::offer(
offer_sdp,
)
.map_err(|e| std::io::Error::other(format!("offer parse: {e}")))?,
)
.await
.map_err(|e| std::io::Error::other(format!("set_remote_description: {e}")))?;
let answer = peer
.create_answer(None)
.await
.map_err(|e| std::io::Error::other(format!("create_answer: {e}")))?;
peer.set_local_description(answer)
.await
.map_err(|e| std::io::Error::other(format!("set_local_description: {e}")))?;
if let Some(local) = peer.local_description().await {
let _ = ctx.out_tx.send(WsMsg::Answer(local.sdp)).await;
}
let ctrl_now = ctx.ctrl_state.read().await.clone();
let rt = start_stream_runtime(
ctrl_now,
ctx.ctrl_state.clone(),
video_track.clone(),
audio_track.clone(),
ctx.rtmp.clone(),
ctx.out_tx.clone(),
)
.await?;
*ctx.track_slot.write().await = Some(video_track);
*ctx.audio_track_slot.write().await = Some(audio_track);
*ctx.runtime.write().await = Some(rt);
let _ = ctx
.out_tx
.send(WsMsg::Info("Streaming started".into()))
.await;
}
WsMsg::Ice(cand) => {
if let Some(peer) = ctx.pc.read().await.as_ref() {
let c = webrtc::ice_transport::ice_candidate::RTCIceCandidateInit {
candidate: cand.candidate,
sdp_mid: cand.sdp_mid,
sdp_mline_index: cand.sdp_mline_index,
username_fragment: cand.username_fragment,
};
if let Err(e) = peer.add_ice_candidate(c).await {
let _ = ctx
.out_tx
.send(WsMsg::Error(format!("add_ice_candidate failed: {e}")))
.await;
}
} else {
let _ = ctx
.out_tx
.send(WsMsg::Error("ICE received but peer is not ready".into()))
.await;
}
}
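        // Lightweight adaptation: loss above 2% requests a keyframe (at most
        // one per 800 ms); the bandwidth estimate shrinks the bitrate by
        // 15-25% or grows it by 10%, clamped to 600-12000 kbps, at most once
        // per 1.2 s.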
WsMsg::ClientStats(st) => {
let loss = st.loss.unwrap_or(0.0);
let avail = st.available_in_bps.unwrap_or(f64::INFINITY);
if loss > 0.02 {
let now_ms = utc_ms_now();
let last = ctx.last_keyreq_ms.load(Ordering::Relaxed);
if now_ms.saturating_sub(last) >= 800 {
ctx.last_keyreq_ms.store(now_ms, Ordering::Relaxed);
if let Some(rt) = ctx.runtime.read().await.as_ref() {
request_keyframe(&rt.stream.pipeline);
}
}
}
let now_ms = utc_ms_now();
let last = ctx.last_bitrate_change_ms.load(Ordering::Relaxed);
if now_ms.saturating_sub(last) >= 1200 {
let target_kbps = ctx.ctrl_state.read().await.bitrate_kbps.max(300);
let target_bps = (target_kbps as f64) * 1000.0;
let low = avail < 1.15 * target_bps;
let critical = avail < 0.90 * target_bps;
if critical || low {
let mut stc = ctx.ctrl_state.write().await;
let old = stc.bitrate_kbps;
let factor = if critical { 0.75 } else { 0.85 };
stc.bitrate_kbps = ((stc.bitrate_kbps as f64) * factor).round() as i32;
stc.bitrate_kbps = stc.bitrate_kbps.clamp(600, 12_000);
let new = stc.bitrate_kbps;
drop(stc);
if new != old {
ctx.last_bitrate_change_ms.store(now_ms, Ordering::Relaxed);
if let Some(rt) = ctx.runtime.read().await.as_ref() {
let _ = apply_ctrl(&rt.stream, &*ctx.ctrl_state.read().await);
}
}
} else if avail > 1.6 * target_bps {
let mut stc = ctx.ctrl_state.write().await;
let old = stc.bitrate_kbps;
stc.bitrate_kbps = ((stc.bitrate_kbps as f64) * 1.10).round() as i32;
stc.bitrate_kbps = stc.bitrate_kbps.clamp(600, 12_000);
let new = stc.bitrate_kbps;
drop(stc);
if new != old {
ctx.last_bitrate_change_ms.store(now_ms, Ordering::Relaxed);
if let Some(rt) = ctx.runtime.read().await.as_ref() {
let _ = apply_ctrl(&rt.stream, &*ctx.ctrl_state.read().await);
}
}
}
}
}
WsMsg::Ctrl {
width,
height,
fps,
bitrate_kbps,
} => {
let prev = ctx.ctrl_state.read().await.clone();
{
let mut st = ctx.ctrl_state.write().await;
st.width = width;
st.height = height;
st.fps = fps;
st.bitrate_kbps = bitrate_kbps;
}
let next = ctx.ctrl_state.read().await.clone();
let need_restart = ctrl_needs_restart(&prev, &next);
if let Some(cb) = ctx.on_event.as_ref() {
cb(SessionEvent::StreamCtrlApplied {
ws_id: ctx.ws_id,
pc_id: *ctx.pc_id_slot.read().await,
prev: prev.clone(),
next: next.clone(),
restarted: need_restart,
});
}
if need_restart {
let _ = ctx
.out_tx
.send(WsMsg::Info(
"CTRL requires restart (size/fps changed)".into(),
))
.await;
if let Some(old) = ctx.runtime.write().await.take() {
stop_stream_runtime(old).await;
}
let Some(video_track) = ctx.track_slot.read().await.clone() else {
let _ = ctx
.out_tx
.send(WsMsg::Error(
"Cannot restart: track not initialized. Send Offer first.".into(),
))
.await;
return Ok(());
};
let Some(audio_track) = ctx.audio_track_slot.read().await.clone() else {
let _ = ctx
.out_tx
.send(WsMsg::Error(
"Cannot restart: audio track not initialized. Send Offer first.".into(),
))
.await;
return Ok(());
};
let rt = start_stream_runtime(
next.clone(),
ctx.ctrl_state.clone(),
video_track,
audio_track,
ctx.rtmp.clone(),
ctx.out_tx.clone(),
)
.await?;
*ctx.runtime.write().await = Some(rt);
let _ = ctx
.out_tx
.send(WsMsg::Info(format!(
"CTRL applied with restart: {}x{}@{} bitrate={}kbps",
next.width, next.height, next.fps, next.bitrate_kbps
)))
.await;
} else if let Some(rt) = ctx.runtime.read().await.as_ref() {
if let Err(e) = apply_ctrl(&rt.stream, &next) {
let _ = ctx
.out_tx
.send(WsMsg::Error(format!("apply_ctrl failed: {e}")))
.await;
} else {
request_keyframe(&rt.stream.pipeline);
let _ = ctx
.out_tx
.send(WsMsg::Info(format!(
"CTRL applied: {}x{}@{} bitrate={}kbps",
next.width, next.height, next.fps, next.bitrate_kbps
)))
.await;
}
} else {
let _ = ctx
.out_tx
.send(WsMsg::Error(
"CTRL received but stream is not running".into(),
))
.await;
}
}
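        // The remaining variants (Answer/ServerStats/Info/Error) are
        // server -> client only; ignore them if a client echoes one back.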
_ => {}
}
Ok(())
}
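/// Signaling server: serves the optional index page over plain HTTP and runs
/// the WebRTC signaling/streaming session over WebSocket.
///
/// A minimal construction sketch (defaults only; see the test module below
/// for a full mTLS/HTTP2 setup):
///
/// ```ignore
/// let mut server = Server::new(ServerConfig::default(), StreamCtrl::default(), None, None);
/// server.set_on_event(std::sync::Arc::new(|ev| tracing::info!("{ev:?}")));
/// ```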
pub struct Server {
cfg: ServerConfig,
initial_ctrl: StreamCtrl,
index: Option<Bytes>,
dc_next_id: Arc<AtomicU64>,
on_dc_message: Option<DataChannelMessageCallback>,
ws_next_id: Arc<AtomicU64>,
pc_next_id: Arc<AtomicU64>,
on_event: Option<SessionEventCallback>,
rtmp: Option<Arc<RtmpBroadcaster>>,
}
impl Default for Server {
fn default() -> Self {
Server::new(ServerConfig::default(), StreamCtrl::default(), None, None)
}
}
impl Server {
pub fn new(
cfg: ServerConfig,
initial_ctrl: StreamCtrl,
index: Option<Bytes>,
rtmp: Option<RtmpBroadcaster>,
) -> Self {
Self {
cfg,
initial_ctrl,
index,
dc_next_id: Arc::new(AtomicU64::new(1)),
on_dc_message: None,
ws_next_id: Arc::new(AtomicU64::new(1)),
pc_next_id: Arc::new(AtomicU64::new(1)),
on_event: None,
rtmp: rtmp.map(Arc::new),
}
}
pub fn set_on_dc_message(&mut self, cb: DataChannelMessageCallback) {
self.on_dc_message = Some(cb);
}
pub fn set_on_event(&mut self, cb: SessionEventCallback) {
self.on_event = Some(cb);
}
}
#[async_trait::async_trait(?Send)]
impl HAsyncService for Server {
async fn call<S: Session>(&mut self, session: &mut S) -> std::io::Result<()> {
if !session.is_ws() {
if let Some(index) = self.index.clone() {
return session
.status_code(http::StatusCode::OK)
.header(
http::header::CONTENT_LENGTH,
http::HeaderValue::from_str(&index.len().to_string()).map_err(|e| {
std::io::Error::other(format!("Invalid header value: {e}"))
})?,
)?
.body(index)
.eom();
}
session
.status_code(http::StatusCode::NOT_FOUND)
.body(bytes::Bytes::from_static(b"WebRTC index page not found"))
.eom()?;
return Ok(());
}
if let Err(e) = session.ws_accept_async().await {
session
.status_code(http::StatusCode::BAD_REQUEST)
.header_str("Connection", "close")?
.eom()?;
return Err(e);
}
let ws_id = self.ws_next_id.fetch_add(1, Ordering::Relaxed);
active_ws().fetch_add(1, Ordering::Relaxed);
let on_event = self.on_event.clone();
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsAccepted { ws_id });
}
let (out_tx, mut out_rx) = mpsc::channel::<WsMsg>(64);
let pc: Arc<RwLock<Option<Arc<RTCPeerConnection>>>> = Arc::new(RwLock::new(None));
let runtime: Arc<RwLock<Option<StreamRuntime>>> = Arc::new(RwLock::new(None));
let track_slot: Arc<RwLock<Option<Arc<TrackLocalStaticSample>>>> =
Arc::new(RwLock::new(None));
let audio_track_slot: Arc<RwLock<Option<Arc<TrackLocalStaticSample>>>> =
Arc::new(RwLock::new(None));
let ctrl_state: Arc<RwLock<StreamCtrl>> = Arc::new(RwLock::new(self.initial_ctrl.clone()));
let pc_id_slot: Arc<RwLock<Option<u64>>> = Arc::new(RwLock::new(None));
let cfg = Arc::new(self.cfg.clone());
let ctx = WsCtx {
ws_id,
cfg,
out_tx: out_tx.clone(),
ctrl_state: ctrl_state.clone(),
pc: pc.clone(),
runtime: runtime.clone(),
track_slot: track_slot.clone(),
audio_track_slot: audio_track_slot.clone(),
pc_id_slot: pc_id_slot.clone(),
pc_next_id: self.pc_next_id.clone(),
on_event: on_event.clone(),
dc_next_id: self.dc_next_id.clone(),
on_dc_message: self.on_dc_message.clone(),
last_keyreq_ms: Arc::new(AtomicU64::new(0)),
last_bitrate_change_ms: Arc::new(AtomicU64::new(0)),
rtmp: self.rtmp.clone(),
};
let _ = out_tx.send(WsMsg::Info("WS connected".into())).await;
let mut frag_buf = BytesMut::new();
let mut expecting_continuation = false;
let mut initial_is_text = false;
let err_protocol = Bytes::from_static(b"protocol error");
let err_unexpected = Bytes::from_static(b"unexpected continue");
let err_utf8 = Bytes::from_static(b"invalid utf8");
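        // WS loop: interleave inbound frames (with manual continuation
        // reassembly) and outbound JSON drained from the session channel.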
loop {
tokio::select! {
incoming = session.ws_read_async() => {
let (code, payload, fin) = incoming?;
match code {
OpCode::Ping => session.ws_write_async(OpCode::Pong, payload, true).await?,
OpCode::Pong => {}
OpCode::Close => {
session.ws_write_async(OpCode::Close, payload, true).await?;
break;
}
OpCode::Text | OpCode::Binary => {
if expecting_continuation {
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsProtocolError { ws_id, message: "expected continuation frame".into() });
}
session.ws_close_async(Some(err_protocol)).await?;
break;
}
if !fin {
frag_buf.clear();
frag_buf.extend_from_slice(payload.as_ref());
expecting_continuation = true;
initial_is_text = matches!(code, OpCode::Text);
continue;
}
if matches!(code, OpCode::Binary) {
continue;
}
let text = match std::str::from_utf8(payload.as_ref()) {
Ok(s) => s,
Err(_) => {
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsProtocolError { ws_id, message: "invalid utf8".into() });
}
session.ws_close_async(Some(err_utf8)).await?;
break;
}
};
if let Err(e) = handle_ws_json(ctx.clone(), text).await {
let _ = out_tx.send(WsMsg::Error(format!("{e}"))).await;
}
}
OpCode::Continue => {
if !expecting_continuation {
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsProtocolError { ws_id, message: "unexpected continuation".into() });
}
session.ws_close_async(Some(err_unexpected)).await?;
break;
}
frag_buf.extend_from_slice(payload.as_ref());
if fin {
let whole = frag_buf.as_ref();
if initial_is_text {
let text = match std::str::from_utf8(whole) {
Ok(s) => s,
Err(_) => {
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsProtocolError { ws_id, message: "invalid utf8".into() });
}
session.ws_close_async(Some(err_utf8)).await?;
break;
}
};
if let Err(e) = handle_ws_json(ctx.clone(), text).await {
let _ = out_tx.send(WsMsg::Error(format!("{e}"))).await;
}
}
frag_buf.clear();
expecting_continuation = false;
initial_is_text = false;
}
}
}
}
opt = out_rx.recv() => {
let Some(m) = opt else { break };
let bytes = match serde_json::to_vec(&m) {
Ok(v) => v,
Err(e) => {
format!(r#"{{"type":"Error","data":"json encode failed: {e}"}}"#).into_bytes()
}
};
session.ws_write_async(OpCode::Text, Bytes::from(bytes), true).await?;
}
}
}
if let Some(cb) = on_event.as_ref() {
cb(SessionEvent::WsClosed {
ws_id,
reason: Some("ws loop ended".into()),
});
}
if let Some(rt) = runtime.write().await.take() {
stop_stream_runtime(rt).await;
}
if let Some(peer) = pc.write().await.take() {
if let Some(cb) = on_event.as_ref() {
if let Some(pc_id) = *pc_id_slot.read().await {
cb(SessionEvent::PcClosed {
ws_id,
pc_id,
reason: Some("call() cleanup".into()),
});
}
}
if let Err(e) = peer.close().await {
warn!("peer.close failed: {e}");
}
}
if active_ws().fetch_sub(1, Ordering::Relaxed) == 1 {
maybe_stop_capture_hub().await;
}
Ok(())
}
}
impl HFactory for Server {
#[cfg(any(feature = "net-h2-server", feature = "net-h3-server"))]
type HAsyncService = Self;
#[cfg(any(feature = "net-h2-server", feature = "net-h3-server"))]
fn async_service(&self, _id: usize) -> Self::HAsyncService {
Server {
cfg: self.cfg.clone(),
initial_ctrl: self.initial_ctrl.clone(),
index: self.index.clone(),
dc_next_id: self.dc_next_id.clone(),
on_dc_message: self.on_dc_message.clone(),
ws_next_id: self.ws_next_id.clone(),
pc_next_id: self.pc_next_id.clone(),
on_event: self.on_event.clone(),
rtmp: self.rtmp.clone(),
}
}
}
#[cfg(test)]
pub mod tests {
use crate::network::http::server::{H2Config, HFactory};
use crate::stream::webrtc::{DataChannelPayload, Server};
use bytes::Bytes;
use tracing::info;
#[test]
fn test_webrtc() {
let cancel_token = tokio_util::sync::CancellationToken::new();
let mtls = crate::MtlsIdentity::generate(&[], &[], false);
let html_file = std::path::Path::new(file!())
.parent()
.unwrap()
.join("webrtc.html");
        crate::stream::init().expect("WebRTC init failed");
const ADDRESS_PORT: &str = "127.0.0.1:8080";
let mut webrtc_server = Server::new(
Default::default(),
Default::default(),
std::fs::read(html_file).ok().map(Bytes::from),
            None,
        );
webrtc_server.set_on_dc_message(std::sync::Arc::new(|dc_id, payload| match payload {
DataChannelPayload::Text(s) => info!("[dc#{dc_id}] TEXT: {}", s),
DataChannelPayload::Binary(b) => info!("[dc#{dc_id}] BIN: {} bytes", b.len()),
}));
webrtc_server.set_on_event(std::sync::Arc::new(|ev| {
info!("[event] {:?}", ev);
}));
webrtc_server
.start_h2_tls(
ADDRESS_PORT,
(
Some(mtls.ca_cert_pem.as_bytes()),
mtls.server_cert_pem.as_bytes(),
mtls.server_key_pem.as_bytes(),
),
H2Config::default(),
cancel_token,
)
.expect("start_webrtc_server failed");
}
}