use crate::{Result, StreamError};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Version byte exchanged as C0/S0. `accept` and `initiate` send this value
/// and reject any peer whose version byte differs from it.
const RTMP_VERSION: u8 = 0x03;
/// Length in bytes of each of the C1, C2, S1 and S2 handshake packets.
const HANDSHAKE_SIZE: usize = 1536;
/// Runs the accepting (server) side of the handshake over `stream`.
///
/// Sequence, in order:
/// 1. read the one-byte C0 and check it against `RTMP_VERSION`;
/// 2. read the `HANDSHAKE_SIZE`-byte C1;
/// 3. send S0, S1 and S2 with a single write, where S2 is a verbatim copy of
///    C1, then flush;
/// 4. read the `HANDSHAKE_SIZE`-byte C2 (its contents are not inspected).
///
/// # Errors
///
/// Returns `StreamError::Handshake` when the C0 version byte is not
/// `RTMP_VERSION`. Any read, write or flush failure is propagated via `?`.
pub async fn accept<S>(stream: &mut S) -> Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // C0: a single version byte.
    let mut version_buf = [0u8; 1];
    stream.read_exact(&mut version_buf).await?;
    let [client_version] = version_buf;
    if client_version != RTMP_VERSION {
        let reason = format!("unsupported RTMP version {}", client_version);
        return Err(StreamError::Handshake(reason));
    }

    // C1: kept so it can be echoed back as S2.
    let mut client_hello = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut client_hello).await?;

    // S0 | S1 | S2 assembled up front so they leave in one `write_all`.
    let reply = [&[RTMP_VERSION][..], &server_s1()[..], &client_hello[..]].concat();
    stream.write_all(&reply).await?;
    stream.flush().await?;

    // C2: drained from the stream; the payload itself is discarded.
    let mut client_ack = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut client_ack).await?;

    Ok(())
}
/// Runs the initiating (client) side of the handshake over `stream`.
///
/// Sequence, in order:
/// 1. send C0 (`RTMP_VERSION`) followed by C1 in a single write, then flush;
///    the C1 payload comes from the same generator used for S1;
/// 2. read the one-byte S0 and check it against `RTMP_VERSION`;
/// 3. read the `HANDSHAKE_SIZE`-byte S1 and then the `HANDSHAKE_SIZE`-byte S2
///    (S2 is read but not inspected);
/// 4. send S1 back unchanged as C2 and flush.
///
/// # Errors
///
/// Returns `StreamError::Handshake` when the S0 version byte is not
/// `RTMP_VERSION`. Any read, write or flush failure is propagated via `?`.
pub async fn initiate<S>(stream: &mut S) -> Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // C0 | C1 assembled up front so they leave in one `write_all`.
    let hello = [&[RTMP_VERSION][..], &server_s1()[..]].concat();
    stream.write_all(&hello).await?;
    stream.flush().await?;

    // S0: a single version byte.
    let mut version_buf = [0u8; 1];
    stream.read_exact(&mut version_buf).await?;
    let [server_version] = version_buf;
    if server_version != RTMP_VERSION {
        let reason = format!("unsupported RTMP version {}", server_version);
        return Err(StreamError::Handshake(reason));
    }

    // S1 is kept for the echo; S2 is only drained from the stream.
    let mut server_hello = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut server_hello).await?;
    let mut server_ack = [0u8; HANDSHAKE_SIZE];
    stream.read_exact(&mut server_ack).await?;

    // C2: echo S1 back to the peer.
    stream.write_all(&server_hello).await?;
    stream.flush().await?;

    Ok(())
}
/// Builds a `HANDSHAKE_SIZE`-byte handshake packet.
///
/// The first eight bytes are left as zero. Every following byte is the low
/// byte of successive xorshift32 states, starting from a fixed seed, so the
/// returned packet is identical on every call.
fn server_s1() -> [u8; HANDSHAKE_SIZE] {
    /// Advances a xorshift32 state by one step (13 / 17 / 5 shift triple).
    fn next_state(mut x: u32) -> u32 {
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        x
    }

    let mut packet = [0u8; HANDSHAKE_SIZE];
    let mut rng: u32 = 0x9E37_79B9;
    // Skip the zeroed eight-byte header and fill only the remaining payload.
    for slot in packet.iter_mut().skip(8) {
        rng = next_state(rng);
        // Little-endian byte 0 is the least-significant byte of the state.
        *slot = rng.to_le_bytes()[0];
    }
    packet
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::duplex;

    /// Plays the client half by hand against `accept`, checking S0 and that
    /// S2 carries back exactly the C1 bytes that were sent.
    #[tokio::test]
    async fn completes_simple_handshake_and_echoes_c1() {
        let (mut client_end, mut server_end) = duplex(8192);
        let server_task = tokio::spawn(async move {
            accept(&mut server_end).await.unwrap();
        });

        // Recognisable C1 payload: 0, 1, ..., 250 repeating.
        let mut c1 = [0u8; HANDSHAKE_SIZE];
        for (slot, value) in c1.iter_mut().zip((0u8..251).cycle()) {
            *slot = value;
        }

        client_end.write_u8(RTMP_VERSION).await.unwrap();
        client_end.write_all(&c1).await.unwrap();

        let mut s0 = [0u8; 1];
        client_end.read_exact(&mut s0).await.unwrap();
        assert_eq!(s0[0], RTMP_VERSION);

        let mut s1 = [0u8; HANDSHAKE_SIZE];
        client_end.read_exact(&mut s1).await.unwrap();

        let mut s2 = [0u8; HANDSHAKE_SIZE];
        client_end.read_exact(&mut s2).await.unwrap();
        assert_eq!(&s2[..], &c1[..], "S2 must echo C1 in the simple handshake");

        // C2 is S1 sent back; the server task finishes once it has read it.
        client_end.write_all(&s1).await.unwrap();
        server_task.await.unwrap();
    }

    /// `initiate` and `accept` must complete when wired to each other.
    #[tokio::test]
    async fn client_initiate_interops_with_server_accept() {
        let (mut client_end, mut server_end) = duplex(8192);
        let server_task = tokio::spawn(async move { accept(&mut server_end).await });

        initiate(&mut client_end).await.unwrap();

        // Outer unwrap: the task joined; inner unwrap: `accept` returned Ok.
        server_task.await.unwrap().unwrap();
    }

    /// A C0 byte other than `RTMP_VERSION` yields a handshake error.
    #[tokio::test]
    async fn rejects_wrong_version() {
        let (mut peer, mut server_end) = duplex(4096);
        peer.write_u8(0xFF).await.unwrap();
        peer.write_all(&[0u8; HANDSHAKE_SIZE]).await.unwrap();

        let err = accept(&mut server_end).await.unwrap_err();
        assert!(matches!(err, StreamError::Handshake(_)));
    }
}