#![allow(dead_code)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::similar_names)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::range_plus_one)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::manual_div_ceil)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::unused_self)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::if_not_else)]
#![allow(clippy::format_push_string)]
#![allow(clippy::single_match_else)]
#![allow(clippy::redundant_slicing)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::map_unwrap_or)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::format_collect)]
#![allow(clippy::useless_conversion)]
#![allow(clippy::unused_async)]
#![allow(clippy::identity_op)]
use crate::error::{NetError, NetResult};
use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Protocol version byte written as the first byte of C0/S0 and required when parsing them.
pub const RTMP_VERSION: u8 = 3;
/// Length in bytes of each fixed-size handshake packet (C1, S1, C2, S2): two `u32` fields
/// followed by `HANDSHAKE_SIZE - 8` bytes of payload.
pub const HANDSHAKE_SIZE: usize = 1536;
/// Length in bytes of the C0/S0 packet, which carries only the version byte.
pub const C0_SIZE: usize = 1;
/// Progress marker for a [`Handshake`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeState {
/// Initial state assigned by `Handshake::new`; nothing has been generated yet.
Uninitialized,
/// Set by `Handshake::generate_c0c1` once the C0+C1 packet has been built.
VersionSent,
/// Set by `Handshake::generate_c2` once the C2 packet has been built.
AckSent,
/// Set by `Handshake::parse_s2` / `Handshake::parse_c2` when the final packet was accepted.
Done,
}
impl HandshakeState {
    /// Reports whether this state is [`HandshakeState::Done`].
    ///
    /// Every other variant yields `false`.
    #[must_use]
    pub const fn is_done(&self) -> bool {
        match self {
            Self::Done => true,
            Self::Uninitialized | Self::VersionSent | Self::AckSent => false,
        }
    }
}
/// State for one RTMP handshake, usable from either the client side
/// (`generate_c0c1` / `parse_s0s1` / `generate_c2` / `parse_s2`) or the server side
/// (`parse_c0c1` / `generate_s0s1s2` / `parse_c2`).
#[derive(Debug)]
pub struct Handshake {
// Current progress; starts at `Uninitialized` and ends at `Done`.
state: HandshakeState,
// Timestamp field of C1: copied from `epoch` when generating, read from the wire when parsing.
client_timestamp: u32,
// Timestamp field of S1: copied from `epoch` when generating, read from the wire when parsing.
server_timestamp: u32,
// Payload of C1 (everything after the two leading `u32` fields).
client_random: [u8; HANDSHAKE_SIZE - 8],
// Payload of S1 (everything after the two leading `u32` fields).
server_random: [u8; HANDSHAKE_SIZE - 8],
// Caller-supplied value (see `set_epoch`) used as the local timestamp and as the payload seed.
epoch: u32,
}
impl Handshake {
    /// Number of payload bytes that follow the two leading `u32` fields of C1/S1/C2/S2.
    const RANDOM_LEN: usize = HANDSHAKE_SIZE - 8;

    /// Creates a handshake in [`HandshakeState::Uninitialized`] with zeroed timestamps,
    /// zeroed payload buffers and an epoch of `0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: HandshakeState::Uninitialized,
            client_timestamp: 0,
            server_timestamp: 0,
            client_random: [0u8; Self::RANDOM_LEN],
            server_random: [0u8; Self::RANDOM_LEN],
            epoch: 0,
        }
    }

    /// Returns the current handshake state.
    #[must_use]
    pub const fn state(&self) -> HandshakeState {
        self.state
    }

    /// Returns `true` once the state is [`HandshakeState::Done`].
    #[must_use]
    pub const fn is_done(&self) -> bool {
        self.state.is_done()
    }

    /// Sets the epoch used as the local timestamp and as the seed for the local payload.
    pub fn set_epoch(&mut self, epoch: u32) {
        self.epoch = epoch;
    }

    /// Builds the client's C0+C1 packet (`C0_SIZE + HANDSHAKE_SIZE` bytes).
    ///
    /// Layout: version byte, `epoch` as a big-endian `u32`, a zero `u32`, then the payload.
    /// The payload is derived deterministically from `epoch` by `fill_random_buffer`, so two
    /// handshakes with the same epoch emit identical packets.
    ///
    /// Moves the state to [`HandshakeState::VersionSent`].
    #[must_use]
    pub fn generate_c0c1(&mut self) -> Bytes {
        self.client_timestamp = self.epoch;
        fill_random_buffer(&mut self.client_random, self.epoch);

        let mut packet = BytesMut::with_capacity(C0_SIZE + HANDSHAKE_SIZE);
        packet.put_u8(RTMP_VERSION);
        packet.put_u32(self.client_timestamp);
        packet.put_u32(0);
        packet.put_slice(&self.client_random);

        self.state = HandshakeState::VersionSent;
        packet.freeze()
    }

    /// Parses the server's S0+S1 packet and records its timestamp and payload.
    ///
    /// Bytes beyond `C0_SIZE + HANDSHAKE_SIZE` are ignored, so a buffer holding S0+S1+S2 may
    /// be passed as-is. The state is not changed.
    ///
    /// # Errors
    ///
    /// Returns a handshake error when `data` is shorter than `C0_SIZE + HANDSHAKE_SIZE` or
    /// when the version byte is not [`RTMP_VERSION`]. Nothing is stored in that case.
    pub fn parse_s0s1(&mut self, data: &[u8]) -> NetResult<()> {
        let (timestamp, random) = Self::split_versioned_packet("S0+S1", data)?;
        self.server_timestamp = timestamp;
        self.server_random.copy_from_slice(random);
        Ok(())
    }

    /// Builds the client's C2 packet: the stored server timestamp, the local epoch, and the
    /// stored server payload.
    ///
    /// Moves the state to [`HandshakeState::AckSent`].
    #[must_use]
    pub fn generate_c2(&mut self) -> Bytes {
        let mut packet = BytesMut::with_capacity(HANDSHAKE_SIZE);
        packet.put_u32(self.server_timestamp);
        packet.put_u32(self.epoch);
        packet.put_slice(&self.server_random);

        self.state = HandshakeState::AckSent;
        packet.freeze()
    }

    /// Accepts the server's S2 packet and moves the state to [`HandshakeState::Done`].
    ///
    /// Only the length is validated.
    ///
    /// NOTE(review): the echoed timestamp and payload are not checked against what was sent
    /// in C1. The previous implementation evaluated both comparisons but took no action on a
    /// mismatch, so that lenient behaviour is preserved here; confirm whether strict
    /// validation is wanted.
    ///
    /// # Errors
    ///
    /// Returns a handshake error when `data` is shorter than [`HANDSHAKE_SIZE`].
    pub fn parse_s2(&mut self, data: &[u8]) -> NetResult<()> {
        Self::ensure_echo_len("S2", data)?;
        self.state = HandshakeState::Done;
        Ok(())
    }

    /// Parses the client's C0+C1 packet and records its timestamp and payload.
    ///
    /// Bytes beyond `C0_SIZE + HANDSHAKE_SIZE` are ignored. The state is not changed.
    ///
    /// # Errors
    ///
    /// Returns a handshake error when `data` is shorter than `C0_SIZE + HANDSHAKE_SIZE` or
    /// when the version byte is not [`RTMP_VERSION`]. Nothing is stored in that case.
    pub fn parse_c0c1(&mut self, data: &[u8]) -> NetResult<()> {
        let (timestamp, random) = Self::split_versioned_packet("C0+C1", data)?;
        self.client_timestamp = timestamp;
        self.client_random.copy_from_slice(random);
        Ok(())
    }

    /// Builds the server's combined S0+S1+S2 response (`C0_SIZE + 2 * HANDSHAKE_SIZE` bytes).
    ///
    /// S1 carries `epoch` as its timestamp and a payload derived deterministically from
    /// `epoch`; S2 carries the stored client timestamp, the local epoch and the stored client
    /// payload. The state is not changed by this method.
    #[must_use]
    pub fn generate_s0s1s2(&mut self) -> Bytes {
        self.server_timestamp = self.epoch;
        fill_random_buffer(&mut self.server_random, self.epoch);

        let mut packet = BytesMut::with_capacity(C0_SIZE + HANDSHAKE_SIZE * 2);
        // S0
        packet.put_u8(RTMP_VERSION);
        // S1
        packet.put_u32(self.server_timestamp);
        packet.put_u32(0);
        packet.put_slice(&self.server_random);
        // S2
        packet.put_u32(self.client_timestamp);
        packet.put_u32(self.epoch);
        packet.put_slice(&self.client_random);
        packet.freeze()
    }

    /// Accepts the client's C2 packet and moves the state to [`HandshakeState::Done`].
    ///
    /// Only the length is validated.
    ///
    /// NOTE(review): the echoed timestamp and payload are not checked against what was sent
    /// in S1. The previous implementation evaluated both comparisons but took no action on a
    /// mismatch, so that lenient behaviour is preserved here; confirm whether strict
    /// validation is wanted.
    ///
    /// # Errors
    ///
    /// Returns a handshake error when `data` is shorter than [`HANDSHAKE_SIZE`].
    pub fn parse_c2(&mut self, data: &[u8]) -> NetResult<()> {
        Self::ensure_echo_len("C2", data)?;
        self.state = HandshakeState::Done;
        Ok(())
    }

    /// Returns a payload-sized buffer filled deterministically from `seed`.
    ///
    /// Delegates to `fill_random_buffer`, so the bytes are identical to the payload a
    /// handshake with `epoch == seed` would generate. The sequence is reproducible and is
    /// not cryptographically random.
    #[must_use]
    pub fn create_random_data(seed: u32) -> [u8; HANDSHAKE_SIZE - 8] {
        let mut data = [0u8; HANDSHAKE_SIZE - 8];
        fill_random_buffer(&mut data, seed);
        data
    }

    /// Validates a version byte followed by a 1536-byte packet (C0+C1 or S0+S1) and returns
    /// the packet's timestamp together with its payload slice.
    ///
    /// `label` names the packet in the "too short" error message.
    fn split_versioned_packet<'a>(label: &str, data: &'a [u8]) -> NetResult<(u32, &'a [u8])> {
        if data.len() < C0_SIZE + HANDSHAKE_SIZE {
            return Err(NetError::handshake(format!(
                "{label} too short: {} bytes",
                data.len()
            )));
        }

        let mut buf = data;
        let version = buf.get_u8();
        if version != RTMP_VERSION {
            return Err(NetError::handshake(format!(
                "Unsupported RTMP version: {version}"
            )));
        }

        let timestamp = buf.get_u32();
        // Skip the second `u32` field; its value is not used.
        buf.advance(4);
        // The length check above guarantees at least `RANDOM_LEN` bytes remain.
        Ok((timestamp, &buf[..Self::RANDOM_LEN]))
    }

    /// Checks that an echo packet (C2 or S2) is at least [`HANDSHAKE_SIZE`] bytes long.
    ///
    /// `label` names the packet in the error message.
    fn ensure_echo_len(label: &str, data: &[u8]) -> NetResult<()> {
        if data.len() < HANDSHAKE_SIZE {
            return Err(NetError::handshake(format!(
                "{label} too short: {} bytes",
                data.len()
            )));
        }
        Ok(())
    }
}
impl Default for Handshake {
fn default() -> Self {
Self::new()
}
}
/// Overwrites every byte of `buf` with output from a linear congruential generator seeded
/// with `seed`.
///
/// The generator state is a wrapping `u64`; each step multiplies by 1103515245, adds 12345,
/// and emits bits 16..24 of the new state. The output is fully determined by `seed`, so it
/// is reproducible and not suitable where unpredictability is required.
fn fill_random_buffer(buf: &mut [u8], seed: u32) {
    const MULTIPLIER: u64 = 1_103_515_245;
    const INCREMENT: u64 = 12_345;

    let mut state = u64::from(seed);
    buf.fill_with(|| {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        // Truncation to the low 8 bits of the shifted state is intentional.
        (state >> 16) as u8
    });
}
/// Placeholder digest check that accepts every input.
///
/// WARNING: no verification is performed. `data` and `key` are ignored and the function
/// always returns `true`, so it must not be relied on to authenticate a peer.
#[must_use]
#[allow(dead_code)]
pub fn validate_digest(data: &[u8], key: &[u8]) -> bool {
    // Consume both arguments explicitly so the stub compiles without unused warnings.
    let _unused_data = data;
    let _unused_key = key;
    true
}
/// Computes the position of the digest inside a handshake packet for the given `scheme`.
///
/// Each scheme has a 4-byte offset field; the digest position is the sum of those four
/// bytes modulo 728, plus the index of the byte immediately after the field:
///
/// | scheme | offset field | base |
/// |--------|--------------|------|
/// | `0`    | `8..12`      | 12   |
/// | `1`    | `772..776`   | 776  |
///
/// This matches the layout used by librtmp's `GetDigestOffset1` / `GetDigestOffset2`.
/// Scheme `1` previously read bytes `764..768`, which did not agree with its base of 776.
///
/// When `data` is too short to contain the offset field, the base alone is returned. Any
/// other `scheme` value yields `12` without inspecting `data`.
#[must_use]
#[allow(dead_code)]
pub fn compute_digest_offset(data: &[u8], scheme: u8) -> usize {
    // Width of the range in which the digest may start, relative to the base.
    const OFFSET_MODULUS: usize = 728;
    const FIELD_LEN: usize = 4;

    // (start of the 4-byte offset field, first index after that field)
    let (field_start, base): (usize, usize) = match scheme {
        0 => (8, 12),
        1 => (772, 776),
        _ => return 12,
    };

    // A missing field contributes 0, which leaves just the base.
    let byte_sum: usize = data
        .get(field_start..field_start + FIELD_LEN)
        .map_or(0, |field| field.iter().map(|&b| usize::from(b)).sum());

    byte_sum % OFFSET_MODULUS + base
}
// Unit tests for the handshake state machine and the deterministic payload generator.
#[cfg(test)]
mod tests {
use super::*;
// `is_done` is false for the initial state and true for `Done`.
#[test]
fn test_handshake_state() {
let state = HandshakeState::Uninitialized;
assert!(!state.is_done());
let state = HandshakeState::Done;
assert!(state.is_done());
}
// A freshly constructed handshake starts uninitialized.
#[test]
fn test_handshake_new() {
let hs = Handshake::new();
assert_eq!(hs.state(), HandshakeState::Uninitialized);
assert!(!hs.is_done());
}
// C0+C1 has the expected length, starts with the version byte, and advances the state.
#[test]
fn test_generate_c0c1() {
let mut hs = Handshake::new();
hs.set_epoch(1000);
let data = hs.generate_c0c1();
assert_eq!(data.len(), C0_SIZE + HANDSHAKE_SIZE);
assert_eq!(data[0], RTMP_VERSION);
assert_eq!(hs.state(), HandshakeState::VersionSent);
}
// Full round trip: client and server exchange every packet and both end in `Done`.
#[test]
fn test_client_server_handshake() {
let mut client = Handshake::new();
let mut server = Handshake::new();
client.set_epoch(1000);
server.set_epoch(2000);
// Client -> server: C0+C1.
let c0c1 = client.generate_c0c1();
assert_eq!(client.state(), HandshakeState::VersionSent);
server.parse_c0c1(&c0c1).expect("should succeed in test");
// Server -> client: S0+S1+S2 in one buffer; parse_s0s1 reads only the leading S0+S1 part.
let s0s1s2 = server.generate_s0s1s2();
client.parse_s0s1(&s0s1s2).expect("should succeed in test");
// Client -> server: C2.
let c2 = client.generate_c2();
assert_eq!(client.state(), HandshakeState::AckSent);
// S2 is the tail of the combined server response.
client
.parse_s2(&s0s1s2[C0_SIZE + HANDSHAKE_SIZE..])
.expect("should succeed in test");
assert_eq!(client.state(), HandshakeState::Done);
server.parse_c2(&c2).expect("should succeed in test");
assert_eq!(server.state(), HandshakeState::Done);
}
// A version byte other than RTMP_VERSION is rejected.
#[test]
fn test_invalid_version() {
let mut hs = Handshake::new();
let mut data = vec![0u8; C0_SIZE + HANDSHAKE_SIZE];
data[0] = 4;
let result = hs.parse_c0c1(&data);
assert!(result.is_err());
}
// A buffer shorter than C0+C1 is rejected.
#[test]
fn test_short_packet() {
let mut hs = Handshake::new();
let short_data = vec![0u8; 100];
let result = hs.parse_c0c1(&short_data);
assert!(result.is_err());
}
// The payload generator is deterministic per seed and differs between seeds 42 and 43.
#[test]
fn test_create_random_data() {
let data1 = Handshake::create_random_data(42);
let data2 = Handshake::create_random_data(42);
assert_eq!(data1, data2);
let data3 = Handshake::create_random_data(43);
assert_ne!(data1, data3);
}
}