use std::collections::VecDeque;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use super::auth::{Challenge, Credentials};
use super::message::{try_parse_response, Method, ParseStatus, Request, Response};
use super::sdp::SessionDescription;
use super::transport::{next_frame, FrameStatus, InterleavedPacket};
use super::url::RtspUrl;
use crate::error::NetError;
/// Well-known RTSP port (554, per RFC 2326 §3.2).
pub const DEFAULT_RTSP_PORT: u16 = 554;
/// Default `User-Agent` value: `oximedia-net/` followed by this crate's version.
pub const USER_AGENT: &str = concat!("oximedia-net/", env!("CARGO_PKG_VERSION"));
/// Tunables for an [`RtspClient`] connection.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Deadline for the TCP connect and for each individual socket read and write.
    pub io_timeout: Duration,
    /// Value sent in the `User-Agent` header of every request.
    pub user_agent: String,
    /// Credentials used to answer authentication challenges; when `None`, the
    /// userinfo embedded in the URL (if any) is used instead.
    pub credentials: Option<Credentials>,
}

impl Default for ClientConfig {
    /// A 15-second I/O timeout, the crate's [`USER_AGENT`], and no explicit credentials.
    fn default() -> Self {
        let io_timeout = Duration::from_secs(15);
        let user_agent = String::from(USER_AGENT);
        Self {
            credentials: None,
            io_timeout,
            user_agent,
        }
    }
}
/// Transport requested in a SETUP: RTP/RTCP interleaved over the RTSP TCP connection.
#[derive(Debug, Clone)]
pub struct SetupTransport {
    /// Interleaved channel number requested for RTP.
    pub interleaved_rtp: u8,
    /// Interleaved channel number requested for RTCP.
    pub interleaved_rtcp: u8,
}

impl SetupTransport {
    /// Requests RTP on `rtp_channel` and RTCP on the channel right after it.
    ///
    /// The RTCP channel is computed with wrapping arithmetic, so `255` pairs with `0`.
    #[must_use]
    pub fn tcp_interleaved(rtp_channel: u8) -> Self {
        let interleaved_rtcp = rtp_channel.wrapping_add(1);
        Self {
            interleaved_rtp: rtp_channel,
            interleaved_rtcp,
        }
    }

    /// Renders the value for the request's `Transport` header.
    fn header_value(&self) -> String {
        let Self {
            interleaved_rtp: rtp,
            interleaved_rtcp: rtcp,
        } = self;
        format!("RTP/AVP/TCP;unicast;interleaved={rtp}-{rtcp}")
    }
}
/// What [`RtspClient::setup`] extracted from a successful SETUP response.
#[derive(Debug, Clone)]
pub struct SetupResponse {
    /// Session id from the `Session` header, with any `;param` suffix removed.
    pub session: String,
    /// Session timeout in seconds; 60 when the server did not send a usable `timeout=`.
    pub timeout: u64,
    /// The server's `Transport` header verbatim, or an empty string if it was absent.
    pub transport: String,
}
/// One unit of data received on the RTSP control connection.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// A `$`-framed interleaved binary packet (e.g. RTP/RTCP carried over the TCP connection).
    Packet(InterleavedPacket),
    /// A parsed RTSP response.
    Message(Response),
}
/// RTSP client bound to a single TCP control connection.
///
/// RTSP responses and `$`-framed interleaved packets share the connection; use
/// [`RtspClient::next_event`] to receive the packets.
pub struct RtspClient {
    /// Connection carrying requests, responses and interleaved packets.
    stream: TcpStream,
    /// URL passed at connect time; used for request URIs and to resolve control URLs.
    url: RtspUrl,
    /// Timeouts, user agent and credentials.
    cfg: ClientConfig,
    /// CSeq of the most recently sent request (incremented before each send).
    cseq: u32,
    /// Session id from the last successful SETUP; attached to later requests and
    /// cleared by `teardown`.
    session: Option<String>,
    /// Session timeout in seconds from SETUP; 60 until a SETUP succeeds.
    session_timeout: u64,
    /// Most recent authentication challenge taken from a 401 response.
    challenge: Option<Challenge>,
    /// Nonce count for the current challenge: reset on a new challenge, incremented
    /// for every `Authorization` header built.
    nc: u32,
    /// Bytes read from the socket that have not yet been parsed into a frame.
    rx_buf: Vec<u8>,
    /// Interleaved packets that arrived while a request was waiting for its response.
    pending_events: VecDeque<ServerEvent>,
}
impl RtspClient {
pub async fn connect(url: &str) -> Result<Self, NetError> {
Self::connect_with(url, ClientConfig::default()).await
}
pub async fn connect_with(url: &str, cfg: ClientConfig) -> Result<Self, NetError> {
let parsed = RtspUrl::parse(url)?;
let stream = tokio::time::timeout(cfg.io_timeout, TcpStream::connect(parsed.authority()))
.await
.map_err(|_| NetError::Timeout(format!("connect to {}", parsed.authority())))?
.map_err(NetError::Io)?;
Ok(Self {
stream,
url: parsed,
cfg,
cseq: 0,
session: None,
session_timeout: 60,
challenge: None,
nc: 0,
rx_buf: Vec::with_capacity(8192),
pending_events: VecDeque::new(),
})
}
#[must_use]
pub fn url(&self) -> &RtspUrl {
&self.url
}
#[must_use]
pub fn session(&self) -> Option<&str> {
self.session.as_deref()
}
#[must_use]
pub fn session_timeout(&self) -> u64 {
self.session_timeout
}
pub async fn options(&mut self) -> Result<Vec<Method>, NetError> {
let resp = self
.request(Method::Options, &self.url.request_uri(), None, &[])
.await?;
if !resp.is_success() {
return Err(resp.into_http_error());
}
let methods = resp
.headers
.get("Public")
.map(|s| {
s.split(',')
.filter_map(|m| Method::parse(m.trim()).ok())
.collect()
})
.unwrap_or_default();
Ok(methods)
}
pub async fn describe(&mut self) -> Result<SessionDescription, NetError> {
let resp = self
.request(
Method::Describe,
&self.url.request_uri(),
None,
&[("Accept", "application/sdp")],
)
.await?;
if !resp.is_success() {
return Err(resp.into_http_error());
}
let text = std::str::from_utf8(&resp.body)
.map_err(|e| NetError::Protocol(format!("non-UTF-8 SDP body: {e}")))?;
SessionDescription::parse(text)
}
pub async fn setup(
&mut self,
control_url: &str,
transport: &SetupTransport,
) -> Result<SetupResponse, NetError> {
let target = self.url.resolve_control(control_url);
let resp = self
.request(
Method::Setup,
&target,
None,
&[("Transport", transport.header_value().as_str())],
)
.await?;
if !resp.is_success() {
return Err(resp.into_http_error());
}
let session_header = resp
.headers
.get("Session")
.ok_or_else(|| NetError::Protocol("SETUP response missing Session header".into()))?
.to_string();
let (session_id, timeout) = parse_session_header(&session_header);
self.session = Some(session_id.clone());
self.session_timeout = timeout;
Ok(SetupResponse {
session: session_id,
timeout,
transport: resp.headers.get("Transport").unwrap_or("").to_string(),
})
}
pub async fn play(&mut self) -> Result<(), NetError> {
let target = self.url.request_uri();
let resp = self.request(Method::Play, &target, None, &[]).await?;
if !resp.is_success() {
return Err(resp.into_http_error());
}
Ok(())
}
pub async fn pause(&mut self) -> Result<(), NetError> {
let target = self.url.request_uri();
let resp = self.request(Method::Pause, &target, None, &[]).await?;
if !resp.is_success() {
return Err(resp.into_http_error());
}
Ok(())
}
pub async fn keepalive(&mut self) -> Result<(), NetError> {
let target = self.url.request_uri();
let resp = self
.request(Method::GetParameter, &target, None, &[])
.await?;
if resp.is_success() {
return Ok(());
}
if resp.status == 405 {
let _ = self.options().await?;
return Ok(());
}
Err(resp.into_http_error())
}
pub async fn teardown(&mut self) -> Result<(), NetError> {
if self.session.is_none() {
return Ok(());
}
let target = self.url.request_uri();
let resp = self.request(Method::Teardown, &target, None, &[]).await?;
self.session = None;
if !resp.is_success() {
return Err(resp.into_http_error());
}
Ok(())
}
pub async fn next_event(&mut self) -> Result<ServerEvent, NetError> {
if let Some(ev) = self.pending_events.pop_front() {
return Ok(ev);
}
loop {
if let Some(ev) = self.try_drain_frame()? {
return Ok(ev);
}
self.read_more().await?;
}
}
async fn request(
&mut self,
method: Method,
target: &str,
body: Option<Vec<u8>>,
extra_headers: &[(&str, &str)],
) -> Result<Response, NetError> {
let resp = self
.send_once(method, target, body.clone(), extra_headers)
.await?;
if resp.is_unauthorized() && self.try_pick_up_challenge(&resp).is_some() {
return self.send_once(method, target, body, extra_headers).await;
}
Ok(resp)
}
async fn send_once(
&mut self,
method: Method,
target: &str,
body: Option<Vec<u8>>,
extra_headers: &[(&str, &str)],
) -> Result<Response, NetError> {
self.cseq += 1;
let mut req = Request::new(method, target, self.cseq)
.with_header("User-Agent", self.cfg.user_agent.clone());
for (k, v) in extra_headers {
req = req.with_header(k, *v);
}
if let Some(body) = body {
req = req.with_body(body);
}
if let Some(auth) = self.build_auth_header(method, target) {
req.headers.insert("Authorization", auth);
}
if let Some(s) = &self.session {
if req.headers.get("Session").is_none() {
req.headers.insert("Session", s.clone());
}
}
let wire = req.encode();
self.write_all(&wire).await?;
loop {
if let Some(ev) = self.try_drain_frame()? {
match ev {
ServerEvent::Packet(_) => self.pending_events.push_back(ev),
ServerEvent::Message(resp) => return Ok(resp),
}
continue;
}
self.read_more().await?;
}
}
fn try_pick_up_challenge(&mut self, resp: &Response) -> Option<()> {
let challenge_header = resp.headers.get("WWW-Authenticate")?;
let challenge = Challenge::parse(challenge_header).ok()?;
self.nc = 0;
self.challenge = Some(challenge);
Some(())
}
fn build_auth_header(&mut self, method: Method, target: &str) -> Option<String> {
let challenge = self.challenge.as_ref()?;
let creds = self.cfg.credentials.clone().or_else(|| {
self.url.userinfo.clone().map(|(u, p)| Credentials {
username: u,
password: p,
})
})?;
self.nc += 1;
let cnonce = generate_cnonce(self.cseq);
Some(challenge.build_authorization(&creds, method.as_str(), target, self.nc, &cnonce))
}
async fn read_more(&mut self) -> Result<(), NetError> {
let mut chunk = [0u8; 4096];
let n = tokio::time::timeout(self.cfg.io_timeout, self.stream.read(&mut chunk))
.await
.map_err(|_| NetError::Timeout("RTSP read".into()))?
.map_err(NetError::Io)?;
if n == 0 {
return Err(NetError::Connection("server closed connection".into()));
}
self.rx_buf.extend_from_slice(&chunk[..n]);
Ok(())
}
async fn write_all(&mut self, data: &[u8]) -> Result<(), NetError> {
tokio::time::timeout(self.cfg.io_timeout, self.stream.write_all(data))
.await
.map_err(|_| NetError::Timeout("RTSP write".into()))?
.map_err(NetError::Io)?;
self.stream.flush().await.map_err(NetError::Io)?;
Ok(())
}
fn try_drain_frame(&mut self) -> Result<Option<ServerEvent>, NetError> {
match next_frame(&self.rx_buf) {
FrameStatus::NeedMore => Ok(None),
FrameStatus::Interleaved { consumed, packet } => {
self.rx_buf.drain(..consumed);
Ok(Some(ServerEvent::Packet(packet)))
}
FrameStatus::RtspMessage => match try_parse_response(&self.rx_buf)? {
ParseStatus::NeedMore => Ok(None),
ParseStatus::Parsed { consumed, response } => {
self.rx_buf.drain(..consumed);
Ok(Some(ServerEvent::Message(response)))
}
},
}
}
}
/// Splits an RTSP `Session` header into `(session_id, timeout_seconds)`.
///
/// The id is the trimmed text before the first `;`. The `timeout` parameter is
/// handled as follows:
/// - its name is matched case-insensitively, with whitespace allowed around `=`,
///   since RTSP's grammar follows HTTP in treating literal tokens case-insensitively;
/// - when it appears more than once, the last value that parses as `u64` wins;
/// - unparseable values are ignored;
/// - when no usable value is present, the result is 60 seconds (the RFC 2326
///   §12.37 default).
fn parse_session_header(value: &str) -> (String, u64) {
    const DEFAULT_TIMEOUT_SECS: u64 = 60;

    let mut parts = value.split(';');
    let id = parts.next().unwrap_or("").trim().to_string();
    let timeout = parts
        .filter_map(|param| {
            let (name, val) = param.split_once('=')?;
            if name.trim().eq_ignore_ascii_case("timeout") {
                val.trim().parse::<u64>().ok()
            } else {
                None
            }
        })
        .last()
        .unwrap_or(DEFAULT_TIMEOUT_SECS);
    (id, timeout)
}
/// Builds a client nonce (`cnonce`) for a Digest `Authorization` header.
///
/// The result is 24 lowercase hex digits: `seed` as 8 digits, followed by 16 digits
/// from hashing the full 128-bit nanosecond timestamp (and `seed`) with a
/// [`std::collections::hash_map::RandomState`] hasher. `RandomState::new` is
/// documented as initialised with random keys, so the value can no longer be
/// predicted from the wall clock alone. The timestamp is no longer truncated to 64 bits.
///
/// NOTE(review): std gives no cryptographic guarantee for `RandomState`. If the crate
/// gains a CSPRNG dependency, prefer it here.
fn generate_cnonce(seed: u32) -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    hasher.write_u32(seed);
    format!("{:08x}{:016x}", seed, hasher.finish())
}
#[cfg(test)]
mod tests {
    use super::super::transport::encode_interleaved;
    use super::*;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;

    /// Starts a one-shot fake RTSP server on an ephemeral localhost port.
    ///
    /// It accepts a single connection. For each entry in `script`, it does one read
    /// from the client (contents ignored) and then writes that entry back verbatim.
    /// When the script is done it shuts the stream down. Returns the bound port and
    /// the server task's handle.
    async fn spawn_fake_server(script: Vec<Vec<u8>>) -> (u16, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 2048];
            for chunk in script {
                // One read per scripted reply. A request larger than `buf`, or one split
                // across segments, could desynchronise the script; that is acceptable for
                // these small single-request tests.
                let _ = stream.read(&mut buf).await;
                stream.write_all(&chunk).await.unwrap();
            }
            let _ = stream.shutdown().await;
        });
        (port, handle)
    }

    /// `timeout=` is extracted; without it the id is trimmed and the default 60 applies.
    #[test]
    fn session_header_parses_timeout() {
        assert_eq!(
            parse_session_header("12345678;timeout=30"),
            ("12345678".to_string(), 30)
        );
        assert_eq!(parse_session_header(" abcd "), ("abcd".to_string(), 60));
    }

    /// RTP on channel 0 pairs with RTCP on channel 1 in the `Transport` header.
    #[test]
    fn setup_transport_header() {
        let t = SetupTransport::tcp_interleaved(0);
        assert_eq!(t.header_value(), "RTP/AVP/TCP;unicast;interleaved=0-1");
    }

    /// The `Public` header of an OPTIONS response is parsed into `Method`s.
    #[tokio::test]
    async fn options_round_trip_against_fake_server() {
        let response = b"RTSP/1.0 200 OK\r\nCSeq: 1\r\nPublic: OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN\r\n\r\n".to_vec();
        let (port, _handle) = spawn_fake_server(vec![response]).await;
        let url = format!("rtsp://127.0.0.1:{port}/test");
        let mut c = RtspClient::connect(&url).await.unwrap();
        let methods = c.options().await.unwrap();
        assert!(methods.contains(&Method::Describe));
        assert!(methods.contains(&Method::Setup));
        assert!(methods.contains(&Method::Play));
    }

    /// A DESCRIBE body (sized by `Content-Length`) is parsed as SDP, exposing the
    /// video track's rtpmap and control attribute.
    #[tokio::test]
    async fn describe_parses_returned_sdp() {
        let body =
            "v=0\r\nm=video 0 RTP/AVP 96\r\na=rtpmap:96 H264/90000\r\na=control:trackID=1\r\n";
        let response = format!(
            "RTSP/1.0 200 OK\r\nCSeq: 1\r\nContent-Type: application/sdp\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        );
        let (port, _h) = spawn_fake_server(vec![response.into_bytes()]).await;
        let url = format!("rtsp://127.0.0.1:{port}/test");
        let mut c = RtspClient::connect(&url).await.unwrap();
        let sdp = c.describe().await.unwrap();
        let v = sdp.video().unwrap();
        assert_eq!(v.primary_rtpmap().unwrap().encoding, "H264");
        assert_eq!(v.control.as_deref(), Some("trackID=1"));
    }

    /// SETUP stores the session id and timeout on the client and returns them.
    #[tokio::test]
    async fn setup_stores_session_and_timeout() {
        let response =
            b"RTSP/1.0 200 OK\r\nCSeq: 1\r\nSession: ABCD1234;timeout=30\r\nTransport: RTP/AVP/TCP;unicast;interleaved=0-1\r\n\r\n";
        let (port, _h) = spawn_fake_server(vec![response.to_vec()]).await;
        let url = format!("rtsp://127.0.0.1:{port}/test");
        let mut c = RtspClient::connect(&url).await.unwrap();
        let r = c
            .setup("trackID=1", &SetupTransport::tcp_interleaved(0))
            .await
            .unwrap();
        assert_eq!(r.session, "ABCD1234");
        assert_eq!(r.timeout, 30);
        assert_eq!(c.session(), Some("ABCD1234"));
        assert_eq!(c.session_timeout(), 30);
    }

    /// A SETUP response followed, in the same write, by an interleaved frame
    /// (`$`, channel byte, big-endian u16 length, payload). The frame must surface
    /// from `next_event` after `setup` has consumed the response.
    #[tokio::test]
    async fn interleaved_packet_delivered_via_next_event() {
        let payload = b"FAKE-RTP-PAYLOAD";
        let mut script = Vec::new();
        script.extend_from_slice(b"RTSP/1.0 200 OK\r\nCSeq: 1\r\nSession: S\r\n\r\n");
        script.push(b'$');
        script.push(0);
        script.extend_from_slice(&(payload.len() as u16).to_be_bytes());
        script.extend_from_slice(payload);
        let (port, _h) = spawn_fake_server(vec![script]).await;
        let url = format!("rtsp://127.0.0.1:{port}/test");
        let mut c = RtspClient::connect(&url).await.unwrap();
        let _ = c
            .setup("trackID=1", &SetupTransport::tcp_interleaved(0))
            .await
            .unwrap();
        let ev = c.next_event().await.unwrap();
        match ev {
            ServerEvent::Packet(p) => {
                assert_eq!(p.channel, 0);
                assert_eq!(p.data, payload);
            }
            _ => panic!("expected packet"),
        }
    }

    /// Keeps the `encode_interleaved` import in use; it only checks that the call
    /// compiles and does not panic.
    #[test]
    fn encode_interleaved_helper_visible() {
        let _ = encode_interleaved(0, b"x");
    }
}