use crate::{Result, StreamError};
use bytes::Bytes;
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Initial chunk size used by both `ChunkReader::new` and `ChunkWriter::new`;
/// change it per direction with `set_chunk_size`.
pub const DEFAULT_CHUNK_SIZE: usize = 128;
/// Value of the 24-bit timestamp / timestamp-delta header field which signals that
/// a 4-byte extended timestamp follows the message header. The writer switches to
/// this encoding for any timestamp greater than or equal to this value.
const EXTENDED_TIMESTAMP: u32 = 0x00FF_FFFF;
/// Upper bound on the declared length of a message being reassembled by `ChunkReader`.
/// NOTE(review): the length field is decoded as 24 bits (max 0xFF_FFFF), which is below
/// this limit, so the check in `read_message` cannot currently trigger — confirm intent.
const MAX_MESSAGE_LEN: usize = 16 * 1024 * 1024;
/// One fully reassembled RTMP message, as returned by `ChunkReader::read_message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RtmpMessage {
    /// Message type id taken from the chunk message header.
    pub type_id: u8,
    /// Absolute message timestamp; type 1-3 headers contribute deltas that the reader
    /// accumulates (with wrapping arithmetic) onto the last type-0 timestamp.
    pub timestamp: u32,
    /// Message stream id; carried little-endian in type-0 chunk headers.
    pub msg_stream_id: u32,
    /// Complete message payload, concatenated from all of the message's chunks.
    pub payload: Bytes,
}
/// Per-chunk-stream-id decoder state kept by `ChunkReader` between chunks and messages.
#[derive(Default, Clone)]
struct ChunkStream {
    /// Absolute timestamp of the current (or most recent) message on this chunk stream.
    timestamp: u32,
    /// Most recent timestamp delta; reset to 0 by a type-0 header and re-applied by type-3 headers.
    delta: u32,
    /// Declared message length from the last type-0 or type-1 header.
    length: usize,
    /// Message type id from the last type-0 or type-1 header.
    type_id: u8,
    /// Message stream id from the last type-0 header.
    msg_stream_id: u32,
    /// Whether the last timestamp field used the extended form; if so, type-3 chunks
    /// also carry a 4-byte extended timestamp that must be consumed.
    extended: bool,
    /// Payload bytes received so far for the message in progress; empty between messages.
    payload: Vec<u8>,
}
/// Decodes RTMP chunks from an async byte source and reassembles them into messages.
pub struct ChunkReader<R> {
    /// Underlying byte source.
    inner: R,
    /// Maximum number of payload bytes carried by one incoming chunk.
    chunk_size: usize,
    /// Reassembly state keyed by chunk stream id.
    streams: HashMap<u32, ChunkStream>,
    /// `fmt` bits (0-3) of the most recently read basic header.
    last_fmt: u8,
}
impl<R: AsyncRead + Unpin> ChunkReader<R> {
pub fn new(inner: R) -> Self {
Self {
inner,
chunk_size: DEFAULT_CHUNK_SIZE,
streams: HashMap::new(),
last_fmt: 0,
}
}
pub fn set_chunk_size(&mut self, size: usize) {
self.chunk_size = size.clamp(1, 0x7FFF_FFFF);
}
pub async fn read_message(&mut self) -> Result<RtmpMessage> {
loop {
let csid = self.read_basic_header().await?;
let fmt = self.last_fmt;
let mut st = self.streams.remove(&csid).unwrap_or_default();
let new_message = st.payload.is_empty();
if new_message {
self.read_message_header(fmt, &mut st).await?;
} else if fmt != 3 {
return Err(StreamError::protocol("interleaved RTMP message header"));
} else if st.extended {
let _ = self.read_u32().await?;
}
if st.length > MAX_MESSAGE_LEN {
return Err(StreamError::protocol("RTMP message_length too large"));
}
let want = (st.length - st.payload.len()).min(self.chunk_size);
let mut buf = vec![0u8; want];
self.inner.read_exact(&mut buf).await?;
st.payload.extend_from_slice(&buf);
if st.payload.len() >= st.length {
let payload = Bytes::from(std::mem::take(&mut st.payload));
let msg = RtmpMessage {
type_id: st.type_id,
timestamp: st.timestamp,
msg_stream_id: st.msg_stream_id,
payload,
};
self.streams.insert(csid, st); return Ok(msg);
}
self.streams.insert(csid, st);
}
}
async fn read_basic_header(&mut self) -> Result<u32> {
let b0 = self.read_u8().await?;
self.last_fmt = b0 >> 6;
let csid = match b0 & 0x3F {
0 => 64 + self.read_u8().await? as u32,
1 => {
let lo = self.read_u8().await? as u32;
let hi = self.read_u8().await? as u32;
64 + lo + hi * 256
}
n => n as u32,
};
Ok(csid)
}
async fn read_message_header(&mut self, fmt: u8, st: &mut ChunkStream) -> Result<()> {
match fmt {
0 => {
let ts = self.read_u24().await?;
st.length = self.read_u24().await? as usize;
st.type_id = self.read_u8().await?;
st.msg_stream_id = self.read_u32_le().await?;
st.timestamp = self.resolve_timestamp(ts, st).await?;
st.delta = 0;
}
1 => {
let delta = self.read_u24().await?;
st.length = self.read_u24().await? as usize;
st.type_id = self.read_u8().await?;
let d = self.resolve_timestamp(delta, st).await?;
st.delta = d;
st.timestamp = st.timestamp.wrapping_add(d);
}
2 => {
let delta = self.read_u24().await?;
let d = self.resolve_timestamp(delta, st).await?;
st.delta = d;
st.timestamp = st.timestamp.wrapping_add(d);
}
3 => {
if st.extended {
let _ = self.read_u32().await?;
}
st.timestamp = st.timestamp.wrapping_add(st.delta);
}
_ => unreachable!("fmt is 2 bits"),
}
Ok(())
}
async fn resolve_timestamp(&mut self, field: u32, st: &mut ChunkStream) -> Result<u32> {
if field == EXTENDED_TIMESTAMP {
st.extended = true;
self.read_u32().await
} else {
st.extended = false;
Ok(field)
}
}
async fn read_u8(&mut self) -> Result<u8> {
Ok(self.inner.read_u8().await?)
}
async fn read_u24(&mut self) -> Result<u32> {
let mut b = [0u8; 3];
self.inner.read_exact(&mut b).await?;
Ok((b[0] as u32) << 16 | (b[1] as u32) << 8 | b[2] as u32)
}
async fn read_u32(&mut self) -> Result<u32> {
Ok(self.inner.read_u32().await?)
}
async fn read_u32_le(&mut self) -> Result<u32> {
Ok(self.inner.read_u32_le().await?)
}
}
/// Splits RTMP messages into chunks and writes them to an async byte sink.
pub struct ChunkWriter<W> {
    /// Underlying byte sink.
    inner: W,
    /// Maximum number of payload bytes emitted per outgoing chunk.
    chunk_size: usize,
}
impl<W: AsyncWrite + Unpin> ChunkWriter<W> {
pub fn new(inner: W) -> Self {
Self {
inner,
chunk_size: DEFAULT_CHUNK_SIZE,
}
}
pub fn set_chunk_size(&mut self, size: usize) {
self.chunk_size = size.clamp(1, 0x7FFF_FFFF);
}
pub async fn write_message(
&mut self,
csid: u8,
type_id: u8,
timestamp: u32,
msg_stream_id: u32,
payload: &[u8],
) -> Result<()> {
let bytes = self.encode_message(csid, type_id, timestamp, msg_stream_id, payload);
self.inner.write_all(&bytes).await?;
self.inner.flush().await?;
Ok(())
}
fn encode_message(
&self,
csid: u8,
type_id: u8,
timestamp: u32,
msg_stream_id: u32,
payload: &[u8],
) -> Vec<u8> {
let mut out = Vec::with_capacity(payload.len() + 16);
let extended = timestamp >= EXTENDED_TIMESTAMP;
out.push(csid & 0x3F);
let ts_field = if extended {
EXTENDED_TIMESTAMP
} else {
timestamp
};
out.extend_from_slice(&ts_field.to_be_bytes()[1..]); out.extend_from_slice(&(payload.len() as u32).to_be_bytes()[1..]); out.push(type_id);
out.extend_from_slice(&msg_stream_id.to_le_bytes()); if extended {
out.extend_from_slice(×tamp.to_be_bytes());
}
let mut pos = 0;
let mut first = true;
while pos < payload.len() || (first && payload.is_empty()) {
if !first {
out.push(0xC0 | (csid & 0x3F)); if extended {
out.extend_from_slice(×tamp.to_be_bytes());
}
}
let take = (payload.len() - pos).min(self.chunk_size);
out.extend_from_slice(&payload[pos..pos + take]);
pos += take;
first = false;
if payload.is_empty() {
break;
}
}
out
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BufMut;

    /// Writes one message with `chunk_size`, then reads it back with a matching reader.
    async fn roundtrip(payload: &[u8], chunk_size: usize) -> RtmpMessage {
        let (mut a, b) = tokio::io::duplex(1 << 20);
        let mut writer = ChunkWriter::new(&mut a);
        writer.set_chunk_size(chunk_size);
        writer.write_message(4, 9, 1234, 1, payload).await.unwrap();
        drop(a);
        let mut reader = ChunkReader::new(b);
        reader.set_chunk_size(chunk_size);
        reader.read_message().await.unwrap()
    }

    #[tokio::test]
    async fn single_chunk_message_roundtrips() {
        let msg = roundtrip(b"hello rtmp", 128).await;
        assert_eq!(msg.type_id, 9);
        assert_eq!(msg.timestamp, 1234);
        assert_eq!(msg.msg_stream_id, 1);
        assert_eq!(&msg.payload[..], b"hello rtmp");
    }

    #[tokio::test]
    async fn multi_chunk_message_reassembles() {
        let mut payload = Vec::new();
        for i in 0..1000u32 {
            payload.put_u32(i);
        }
        let msg = roundtrip(&payload, 128).await;
        assert_eq!(msg.payload.len(), payload.len());
        assert_eq!(&msg.payload[..], &payload[..]);
    }

    #[tokio::test]
    async fn empty_payload_roundtrips() {
        let msg = roundtrip(b"", 128).await;
        assert_eq!(msg.type_id, 9);
        assert_eq!(msg.timestamp, 1234);
        assert!(msg.payload.is_empty());
    }

    #[tokio::test]
    async fn one_byte_chunks_reassemble() {
        let msg = roundtrip(b"abcdef", 1).await;
        assert_eq!(&msg.payload[..], b"abcdef");
        assert_eq!(msg.timestamp, 1234);
    }

    #[tokio::test]
    async fn extended_timestamp_roundtrips() {
        let (mut a, b) = tokio::io::duplex(4096);
        let mut writer = ChunkWriter::new(&mut a);
        writer
            .write_message(4, 8, 0x0100_0000, 1, b"x")
            .await
            .unwrap();
        drop(a);
        let mut reader = ChunkReader::new(b);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(msg.timestamp, 0x0100_0000);
        assert_eq!(&msg.payload[..], b"x");
    }

    #[tokio::test]
    async fn extended_timestamp_survives_continuation_chunks() {
        // 300 bytes at the default chunk size spans three chunks, each carrying the
        // repeated extended timestamp.
        let payload: Vec<u8> = (0..=255u8).cycle().take(300).collect();
        let (mut a, b) = tokio::io::duplex(4096);
        let mut writer = ChunkWriter::new(&mut a);
        writer
            .write_message(4, 8, 0x0100_0000, 1, &payload)
            .await
            .unwrap();
        drop(a);
        let mut reader = ChunkReader::new(b);
        let msg = reader.read_message().await.unwrap();
        assert_eq!(msg.timestamp, 0x0100_0000);
        assert_eq!(&msg.payload[..], &payload[..]);
    }

    #[tokio::test]
    async fn consecutive_messages_on_same_chunk_stream() {
        let (mut a, b) = tokio::io::duplex(4096);
        let mut writer = ChunkWriter::new(&mut a);
        writer.write_message(4, 9, 10, 1, b"first").await.unwrap();
        writer.write_message(4, 8, 20, 1, b"second").await.unwrap();
        drop(a);
        let mut reader = ChunkReader::new(b);
        let first = reader.read_message().await.unwrap();
        let second = reader.read_message().await.unwrap();
        assert_eq!((first.type_id, first.timestamp), (9, 10));
        assert_eq!(&first.payload[..], b"first");
        assert_eq!((second.type_id, second.timestamp), (8, 20));
        assert_eq!(&second.payload[..], b"second");
    }
}