use std::net::SocketAddr;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::inbound::{InboundProtocol, IngestContext, PublishSession};
use crate::protocol::tsdemux::{TsDemuxer, TsTrackKind};
use crate::{CodecId, MediaFrame, Result};
/// Inbound handler that receives MPEG-TS carried in UDP datagrams and
/// publishes the demuxed tracks under one fixed stream key.
pub struct UdpTsHandler {
    /// Stream key used when opening the publish session.
    key: crate::StreamKey,
    /// Local address the UDP socket is bound to.
    bind: SocketAddr,
    /// Length in bytes of the datagram receive buffer.
    recv_buf: usize,
}
impl UdpTsHandler {
    /// Creates a handler that binds to `bind` and publishes under `key`,
    /// using a 2048-byte datagram receive buffer.
    pub fn new(bind: SocketAddr, key: crate::StreamKey) -> Self {
        const DEFAULT_RECV_BUF: usize = 2048;
        Self {
            recv_buf: DEFAULT_RECV_BUF,
            bind,
            key,
        }
    }

    /// Returns the handler with its receive buffer set to `bytes`.
    ///
    /// Requests smaller than one 188-byte MPEG-TS packet are raised to 188, so
    /// the buffer can always hold at least one complete packet.
    ///
    /// NOTE(review): there is no upper cap and no detection of datagrams longer
    /// than the buffer — confirm the sender's payload size fits.
    pub fn recv_buffer(self, bytes: usize) -> Self {
        const TS_PACKET_LEN: usize = 188;
        Self {
            recv_buf: std::cmp::max(bytes, TS_PACKET_LEN),
            ..self
        }
    }
}
#[async_trait]
impl InboundProtocol for UdpTsHandler {
fn name(&self) -> &'static str {
"udp"
}
async fn serve(&self, ctx: IngestContext, shutdown: CancellationToken) -> Result<()> {
use tokio::net::UdpSocket;
let socket = UdpSocket::bind(self.bind).await?;
info!(bind = %self.bind, "udp ts listener bound");
let mut buf = vec![0u8; self.recv_buf];
let mut demux = TsDemuxer::new();
let mut session: Option<PublishSession> = None;
loop {
let n = tokio::select! {
_ = shutdown.cancelled() => break,
r = socket.recv_from(&mut buf) => match r {
Ok((n, _from)) => n,
Err(e) => {
warn!(error = %e, "udp recv failed");
continue;
}
}
};
for au in demux.push(&buf[..n]) {
if au.codec == CodecId::Unknown {
continue;
}
if session.is_none() {
session = Some(ctx.open_publish(self.key.clone()).await?);
}
let sess = session.as_ref().unwrap();
let pts = au.pts_ms;
let mut frame = match au.kind {
TsTrackKind::Video => {
MediaFrame::new_video(pts, pts, au.data, au.codec, au.keyframe)
}
TsTrackKind::Audio => MediaFrame::new_audio(pts, au.data, au.codec),
};
if au.is_config {
frame.flags |= crate::FrameFlags::CONFIG;
}
let _ = sess.publish_frame(frame)?;
}
}
if let Some(sess) = session {
sess.finish().await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Engine;
    use std::sync::Arc;
    use std::time::Duration;

    /// Cancelling the shutdown token makes `serve` return `Ok(())` within the
    /// timeout.
    #[tokio::test]
    async fn binds_and_shuts_down_cleanly() {
        let engine: Arc<Engine> = Engine::builder()
            .application(crate::AppSpec::new("live"))
            .build();
        let ingest = IngestContext::new(engine);
        let bind_addr = "127.0.0.1:0".parse().unwrap();
        let handler = UdpTsHandler::new(bind_addr, crate::StreamKey::new("live", "feed"));

        let stop = CancellationToken::new();
        let serve_task = {
            let task_stop = stop.clone();
            tokio::spawn(async move { handler.serve(ingest, task_stop).await })
        };

        // Let the spawned task get scheduled before signalling shutdown.
        tokio::task::yield_now().await;
        stop.cancel();

        let joined = tokio::time::timeout(Duration::from_secs(5), serve_task)
            .await
            .expect("serve returned after cancel");
        let outcome = joined.expect("task joined");
        assert!(outcome.is_ok(), "clean shutdown");
    }
}