use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use super::registry::MountPointRegistry;
use super::state::{RtspServerConfig, RtspSession, RtspSessionState};
use crate::rtsp::message::{try_parse_request, Headers, RequestParseStatus, Response};
use crate::rtsp::transport::encode_interleaved;
use crate::rtsp::Method;
const READ_BUF_CAPACITY: usize = 65536;
const CHUNK_SIZE: usize = 4096;
const SERVER_NAME: &str = concat!("oximedia-net-rtsp/", env!("CARGO_PKG_VERSION"));
pub struct ServerConnection {
stream: TcpStream,
config: Arc<RtspServerConfig>,
registry: MountPointRegistry,
session: Option<RtspSession>,
read_buf: Vec<u8>,
rtp_stop: Option<Arc<AtomicBool>>,
}
impl ServerConnection {
#[must_use]
pub fn new(
stream: TcpStream,
config: Arc<RtspServerConfig>,
registry: MountPointRegistry,
) -> Self {
Self {
stream,
config,
registry,
session: None,
read_buf: Vec::with_capacity(READ_BUF_CAPACITY),
rtp_stop: None,
}
}
pub async fn run(mut self) {
let _ = self.run_inner().await;
if let Some(flag) = self.rtp_stop.take() {
flag.store(true, Ordering::Relaxed);
}
}
async fn run_inner(&mut self) -> Result<(), ()> {
loop {
match try_parse_request(&self.read_buf).map_err(|_| ())? {
RequestParseStatus::Complete { consumed, request } => {
self.read_buf.drain(..consumed);
let cseq = request
.headers
.get("CSeq")
.and_then(|v| v.trim().parse::<u32>().ok())
.unwrap_or(0);
let response = self.dispatch(&request, cseq).await;
let wire = response.encode();
self.stream.write_all(&wire).await.map_err(|_| ())?;
self.stream.flush().await.map_err(|_| ())?;
}
RequestParseStatus::NeedMore => {
let mut chunk = [0u8; CHUNK_SIZE];
let n = self.stream.read(&mut chunk).await.map_err(|_| ())?;
if n == 0 {
return Err(()); }
self.read_buf.extend_from_slice(&chunk[..n]);
}
}
}
}
async fn dispatch(&mut self, req: &crate::rtsp::message::Request, cseq: u32) -> Response {
match req.method {
Method::Options => self.handle_options(cseq),
Method::Describe => self.handle_describe(&req.uri, cseq),
Method::Setup => self.handle_setup(&req.uri, req.headers.get("Transport"), cseq),
Method::Play => self.handle_play(cseq).await,
Method::Pause => self.handle_pause(cseq),
Method::Teardown => self.handle_teardown(cseq),
Method::GetParameter => self.handle_get_parameter(cseq),
_ => {
let mut r = Response::build(501, cseq);
add_server_header(&mut r.headers);
r
}
}
}
fn handle_options(&self, cseq: u32) -> Response {
let mut r = Response::build(200, cseq);
r.headers.insert(
"Public",
"OPTIONS, DESCRIBE, SETUP, PLAY, PAUSE, TEARDOWN, GET_PARAMETER",
);
add_server_header(&mut r.headers);
r
}
fn handle_describe(&self, uri: &str, cseq: u32) -> Response {
let path = extract_path(uri);
match self.registry.lookup(&path) {
Some(mp) => {
let sdp_bytes = mp.sdp.as_bytes().to_vec();
let sdp_len = sdp_bytes.len();
let mut r = Response::build(200, cseq);
r.headers.insert("Content-Type", "application/sdp");
r.headers.insert("Content-Base", uri);
r.headers.insert("Content-Length", sdp_len.to_string());
add_server_header(&mut r.headers);
r.body = sdp_bytes;
r
}
None => {
let mut r = Response::build(404, cseq);
add_server_header(&mut r.headers);
r
}
}
}
fn handle_setup(&mut self, uri: &str, transport_header: Option<&str>, cseq: u32) -> Response {
let (rtp_ch, rtcp_ch) = transport_header
.and_then(parse_interleaved_channels)
.unwrap_or((0, 1));
let path = extract_path(uri);
if self.registry.lookup(&path).is_none() {
let parent = parent_path(&path);
if self.registry.lookup(&parent).is_none() {
let mut r = Response::build(404, cseq);
add_server_header(&mut r.headers);
return r;
}
self.create_session(parent, rtp_ch);
} else {
self.create_session(path, rtp_ch);
}
let session_id = self
.session
.as_ref()
.map(|s| s.id.clone())
.unwrap_or_default();
let timeout_secs = self.config.session_timeout.as_secs();
let mut r = Response::build(200, cseq);
r.headers.insert(
"Transport",
format!("RTP/AVP/TCP;unicast;interleaved={rtp_ch}-{rtcp_ch}"),
);
r.headers
.insert("Session", format!("{session_id};timeout={timeout_secs}"));
add_server_header(&mut r.headers);
r
}
async fn handle_play(&mut self, cseq: u32) -> Response {
let (session_id, mount_path, channel_id) = match self.session.as_mut() {
Some(s)
if s.state == RtspSessionState::Ready || s.state == RtspSessionState::Paused =>
{
let id = s.id.clone();
let path = s.mount_path.clone();
let ch = s.channel_id;
s.state = RtspSessionState::Playing;
s.refresh();
(id, path, ch)
}
Some(s) if s.state == RtspSessionState::Playing => {
let id = s.id.clone();
let timeout_secs = self.config.session_timeout.as_secs();
s.refresh();
let mut r = Response::build(200, cseq);
r.headers
.insert("Session", format!("{id};timeout={timeout_secs}"));
add_server_header(&mut r.headers);
return r;
}
_ => {
let mut r = Response::build(455, cseq); add_server_header(&mut r.headers);
return r;
}
};
if let Some(flag) = self.rtp_stop.take() {
flag.store(true, Ordering::Relaxed);
}
if let Some(mp) = self.registry.lookup(&mount_path) {
let stop = Arc::new(AtomicBool::new(false));
self.rtp_stop = Some(Arc::clone(&stop));
let rx = mp.subscribe();
let (reader_stream, mut writer) = self.stream.split();
let _ = reader_stream;
let (tx, mut rtp_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(64);
let stop_clone = Arc::clone(&stop);
let mut broadcast_rx = rx;
tokio::spawn(async move {
loop {
if stop_clone.load(Ordering::Relaxed) {
break;
}
match broadcast_rx.recv().await {
Ok(rtp_bytes) => {
let framed = encode_interleaved(channel_id, &rtp_bytes);
if tx.send(framed).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Closed) => break,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
}
}
});
while let Ok(frame) = rtp_rx.try_recv() {
if writer.write_all(&frame).await.is_err() {
break;
}
}
let _ = rtp_rx; }
let timeout_secs = self.config.session_timeout.as_secs();
let mut r = Response::build(200, cseq);
r.headers
.insert("Session", format!("{session_id};timeout={timeout_secs}"));
r.headers.insert("Range", "npt=0.000-");
add_server_header(&mut r.headers);
r
}
fn handle_pause(&mut self, cseq: u32) -> Response {
if let Some(flag) = self.rtp_stop.take() {
flag.store(true, Ordering::Relaxed);
}
match self.session.as_mut() {
Some(s) if s.state == RtspSessionState::Playing => {
let id = s.id.clone();
let timeout_secs = self.config.session_timeout.as_secs();
s.state = RtspSessionState::Paused;
s.refresh();
let mut r = Response::build(200, cseq);
r.headers
.insert("Session", format!("{id};timeout={timeout_secs}"));
add_server_header(&mut r.headers);
r
}
_ => {
let mut r = Response::build(455, cseq);
add_server_header(&mut r.headers);
r
}
}
}
fn handle_teardown(&mut self, cseq: u32) -> Response {
if let Some(flag) = self.rtp_stop.take() {
flag.store(true, Ordering::Relaxed);
}
self.session = None;
let mut r = Response::build(200, cseq);
add_server_header(&mut r.headers);
r
}
fn handle_get_parameter(&mut self, cseq: u32) -> Response {
let session_line = match self.session.as_mut() {
Some(s) => {
s.refresh();
let timeout_secs = self.config.session_timeout.as_secs();
format!("{};timeout={timeout_secs}", s.id)
}
None => String::new(),
};
let mut r = Response::build(200, cseq);
if !session_line.is_empty() {
r.headers.insert("Session", session_line);
}
add_server_header(&mut r.headers);
r
}
fn create_session(&mut self, mount_path: String, channel_id: u8) {
let id = generate_session_id();
let timeout = self.config.session_timeout;
let mut session = RtspSession::new(id, mount_path, channel_id, timeout);
session.state = RtspSessionState::Ready;
self.session = Some(session);
}
}
fn add_server_header(headers: &mut Headers) {
headers.insert("Server", SERVER_NAME);
}
fn extract_path(uri: &str) -> String {
if let Some(rest) = uri
.strip_prefix("rtsp://")
.or_else(|| uri.strip_prefix("rtsps://"))
{
if let Some(idx) = rest.find('/') {
return rest[idx..].to_string();
}
return "/".to_string();
}
if uri.starts_with('/') {
uri.to_string()
} else {
format!("/{uri}")
}
}
fn parent_path(path: &str) -> String {
match path.rfind('/') {
Some(0) | None => "/".to_string(),
Some(idx) => path[..idx].to_string(),
}
}
fn parse_interleaved_channels(transport: &str) -> Option<(u8, u8)> {
for part in transport.split(';') {
let part = part.trim();
if let Some(rest) = part.strip_prefix("interleaved=") {
let mut nums = rest.splitn(2, '-');
let a = nums.next()?.trim().parse::<u8>().ok()?;
let b = nums.next()?.trim().parse::<u8>().ok()?;
return Some((a, b));
}
}
None
}
fn generate_session_id() -> String {
static SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_micros() as u64)
.unwrap_or(0);
let seq = SEQ.fetch_add(1, Ordering::Relaxed);
format!("{:012x}{:04x}", ts & 0xFFFF_FFFF_FFFF, seq & 0xFFFF)
}
async fn write_response(stream: &mut TcpStream, resp: &Response) -> Result<(), ()> {
let wire = resp.encode();
stream.write_all(&wire).await.map_err(|_| ())?;
stream.flush().await.map_err(|_| ())?;
Ok(())
}
pub async fn forward_rtp_frames(
stream: &mut TcpStream,
receiver: &mut broadcast::Receiver<Arc<Vec<u8>>>,
channel_id: u8,
count: usize,
) -> Result<usize, ()> {
let mut forwarded = 0;
for _ in 0..count {
match receiver.recv().await {
Ok(rtp_bytes) => {
let framed = encode_interleaved(channel_id, &rtp_bytes);
stream.write_all(&framed).await.map_err(|_| ())?;
stream.flush().await.map_err(|_| ())?;
forwarded += 1;
}
Err(broadcast::error::RecvError::Closed) => break,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
}
}
Ok(forwarded)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_path_from_rtsp_uri() {
assert_eq!(
extract_path("rtsp://cam.local:554/live/stream"),
"/live/stream"
);
assert_eq!(extract_path("rtsp://cam/"), "/");
assert_eq!(extract_path("rtsp://cam"), "/");
assert_eq!(extract_path("/stream"), "/stream");
}
#[test]
fn parent_path_strips_last_segment() {
assert_eq!(parent_path("/stream/trackID=1"), "/stream");
assert_eq!(parent_path("/stream"), "/");
assert_eq!(parent_path("/"), "/");
}
#[test]
fn parse_interleaved_channels_from_transport() {
assert_eq!(
parse_interleaved_channels("RTP/AVP/TCP;unicast;interleaved=0-1"),
Some((0, 1))
);
assert_eq!(
parse_interleaved_channels("RTP/AVP/TCP;interleaved=2-3;unicast"),
Some((2, 3))
);
assert_eq!(parse_interleaved_channels("RTP/AVP/UDP;unicast"), None);
}
#[test]
fn generate_session_id_is_unique() {
let id1 = generate_session_id();
let id2 = generate_session_id();
assert_ne!(id1, id2);
assert!(!id1.is_empty());
}
#[test]
fn options_response_contains_public_header() {
let _registry = MountPointRegistry::new();
let _config = Arc::new(RtspServerConfig::default());
let mut headers = Headers::new();
add_server_header(&mut headers);
assert_eq!(headers.get("server"), Some(SERVER_NAME));
}
}