#![allow(dead_code)]
use super::packet::ControlType;
/// Payload encryption mode negotiated for an SRT stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SrtEncryption {
    /// No payload encryption (the default).
    #[default]
    None,
    /// AES with a 16-byte key.
    Aes128,
    /// AES with a 32-byte key.
    Aes256,
}

impl SrtEncryption {
    /// Key length in bytes for this mode (0 when unencrypted).
    #[must_use]
    pub const fn key_len_bytes(self) -> usize {
        match self {
            Self::None => 0,
            Self::Aes128 => 16,
            Self::Aes256 => 32,
        }
    }

    /// Returns `true` for every mode except [`SrtEncryption::None`].
    #[must_use]
    pub const fn is_encrypted(self) -> bool {
        !matches!(self, Self::None)
    }

    /// Value advertised in the handshake "Encryption Field".
    ///
    /// SRT encodes this field as the key length in bytes divided by 8:
    /// 0 = none, 2 = AES-128, 3 = AES-192, 4 = AES-256.
    #[must_use]
    pub const fn handshake_field(self) -> u16 {
        match self {
            Self::None => 0,
            Self::Aes128 => 2,
            // 3 would advertise AES-192 (24-byte key) to the peer; AES-256 is 4.
            Self::Aes256 => 4,
        }
    }

    /// Strictly decodes a handshake Encryption Field value.
    ///
    /// Returns `None` (the `Option` variant) for values this type cannot represent,
    /// including 3 (AES-192). Callers can then reject the peer instead of silently
    /// falling back to plaintext.
    #[must_use]
    pub const fn try_from_handshake_field(v: u16) -> Option<Self> {
        match v {
            0 => Some(Self::None),
            2 => Some(Self::Aes128),
            4 => Some(Self::Aes256),
            _ => None,
        }
    }

    /// Leniently decodes a handshake Encryption Field value.
    ///
    /// Any value not accepted by [`SrtEncryption::try_from_handshake_field`] maps to
    /// [`SrtEncryption::None`].
    #[must_use]
    pub const fn from_handshake_field(v: u16) -> Self {
        match Self::try_from_handshake_field(v) {
            Some(enc) => enc,
            None => Self::None,
        }
    }
}
/// Coarse classification of an SRT packet.
///
/// `Data` is the only non-control variant; `Control` is the catch-all for control
/// types that have no dedicated variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SrtPacketType {
    Data,
    Control,
    Keepalive,
    Ack,
    Nak,
    Shutdown,
    Handshake,
    AckAck,
    DropReq,
}

impl SrtPacketType {
    /// Classifies a control-packet type.
    ///
    /// Control types without a dedicated variant map to [`SrtPacketType::Control`].
    #[must_use]
    pub fn from_control_type(control: &ControlType) -> Self {
        match control {
            ControlType::Handshake => Self::Handshake,
            ControlType::Keepalive => Self::Keepalive,
            ControlType::Ack => Self::Ack,
            ControlType::AckAck => Self::AckAck,
            ControlType::Nak => Self::Nak,
            ControlType::DropReq => Self::DropReq,
            ControlType::Shutdown => Self::Shutdown,
            // Remaining control types (e.g. user-defined ones) share the generic bucket.
            _ => Self::Control,
        }
    }

    /// `true` only for [`SrtPacketType::Data`].
    #[must_use]
    pub const fn is_data(self) -> bool {
        matches!(self, Self::Data)
    }

    /// `true` for every variant other than [`SrtPacketType::Data`].
    #[must_use]
    pub const fn is_control(self) -> bool {
        !self.is_data()
    }

    /// `true` only for [`SrtPacketType::Handshake`].
    #[must_use]
    pub const fn is_handshake(self) -> bool {
        matches!(self, Self::Handshake)
    }
}
/// Connection handshake record with a fixed [`SrtHandshake::WIRE_SIZE`]-byte encoding.
///
/// NOTE(review): `encode` writes `magic` first and only the low 16 bits of `version`.
/// This does not appear to match the SRT RFC draft handshake layout, so confirm that
/// both peers use this crate's encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SrtHandshake {
    /// Must equal [`SrtHandshake::SRT_MAGIC`] for `is_valid`/`decode` to accept the record.
    pub magic: u32,
    /// Protocol version; only the low 16 bits survive an `encode`/`decode` round trip.
    pub version: u32,
    /// Extension bits; the high byte holds the latency set by `with_latency_extension`.
    pub extension_field: u16,
    /// Initial packet sequence number.
    pub initial_seq: u32,
    /// MTU in bytes (default 1500).
    pub mtu: u32,
    /// Flow window (default 8192).
    pub flow_window: u32,
    /// One of the `TYPE_*` constants.
    pub handshake_type: i32,
    /// Sender's socket identifier.
    pub socket_id: u32,
    /// SYN cookie value.
    pub syn_cookie: u32,
    /// Peer IPv4 address packed into a `u32`.
    pub peer_ip_v4: u32,
}

impl SrtHandshake {
    /// Magic value written at offset 0 of every encoded handshake.
    pub const SRT_MAGIC: u32 = 0x4A17;
    /// Version 1.3.0 packed as `0x00MM_mmpp`.
    pub const VERSION_1_3_0: u32 = 0x0001_0300;
    pub const TYPE_WAVEAHAND: i32 = 0;
    pub const TYPE_INDUCTION: i32 = 1;
    pub const TYPE_CONCLUSION: i32 = -1;
    pub const TYPE_AGREEMENT: i32 = -2;
    /// Encoded size in bytes; `encode` zero-pads to this length.
    pub const WIRE_SIZE: usize = 48;

    /// Creates a wave-a-hand handshake with version 1.3.0, MTU 1500 and flow window 8192.
    #[must_use]
    pub fn new() -> Self {
        Self {
            magic: Self::SRT_MAGIC,
            version: Self::VERSION_1_3_0,
            extension_field: 0,
            initial_seq: 0,
            mtu: 1500,
            flow_window: 8192,
            handshake_type: Self::TYPE_WAVEAHAND,
            socket_id: 0,
            syn_cookie: 0,
            peer_ip_v4: 0,
        }
    }

    /// Sets the socket identifier.
    #[must_use]
    pub const fn with_socket_id(mut self, id: u32) -> Self {
        self.socket_id = id;
        self
    }

    /// Sets the SYN cookie.
    #[must_use]
    pub const fn with_syn_cookie(mut self, cookie: u32) -> Self {
        self.syn_cookie = cookie;
        self
    }

    /// Sets the handshake type (one of the `TYPE_*` constants).
    #[must_use]
    pub const fn with_handshake_type(mut self, ht: i32) -> Self {
        self.handshake_type = ht;
        self
    }

    /// Stores `latency_ms` in the high byte of `extension_field`, keeping the low byte.
    ///
    /// The slot is 8 bits wide, so latencies above 255 ms saturate to 255.
    #[must_use]
    pub fn with_latency_extension(mut self, latency_ms: u32) -> Self {
        // Saturate to the 8-bit slot. Clamping to u16::MAX and then shifting by 8 would
        // silently drop the high byte (e.g. 300 ms would become 44 ms).
        let clamped = u16::from(u8::try_from(latency_ms).unwrap_or(u8::MAX));
        self.extension_field = (self.extension_field & 0x00FF) | (clamped << 8);
        self
    }

    /// Returns `true` when `magic` equals [`SrtHandshake::SRT_MAGIC`].
    #[must_use]
    pub const fn is_valid(&self) -> bool {
        self.magic == Self::SRT_MAGIC
    }

    /// Serializes the handshake as big-endian fields, zero-padded to `WIRE_SIZE` bytes.
    ///
    /// Layout: magic (4), version low 16 bits (2), extension (2), initial seq (4),
    /// MTU (4), flow window (4), type (4), socket id (4), cookie (4), peer IPv4 (4),
    /// padding (12).
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(Self::WIRE_SIZE);
        buf.extend_from_slice(&self.magic.to_be_bytes());
        // Only 16 bits of version fit in this layout; the high bits are dropped.
        let version_lo = (self.version & 0xFFFF) as u16;
        buf.extend_from_slice(&version_lo.to_be_bytes());
        buf.extend_from_slice(&self.extension_field.to_be_bytes());
        buf.extend_from_slice(&self.initial_seq.to_be_bytes());
        buf.extend_from_slice(&self.mtu.to_be_bytes());
        buf.extend_from_slice(&self.flow_window.to_be_bytes());
        buf.extend_from_slice(&self.handshake_type.to_be_bytes());
        buf.extend_from_slice(&self.socket_id.to_be_bytes());
        buf.extend_from_slice(&self.syn_cookie.to_be_bytes());
        buf.extend_from_slice(&self.peer_ip_v4.to_be_bytes());
        buf.resize(Self::WIRE_SIZE, 0);
        buf
    }

    /// Parses a handshake produced by [`SrtHandshake::encode`].
    ///
    /// The version is restored zero-extended from its 16 encoded bits. Padding bytes
    /// after offset 36 are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error message if `data` is shorter than `WIRE_SIZE` or the magic
    /// value does not match [`SrtHandshake::SRT_MAGIC`].
    pub fn decode(data: &[u8]) -> Result<Self, String> {
        if data.len() < Self::WIRE_SIZE {
            return Err(format!(
                "SrtHandshake::decode: buffer too short ({} < {})",
                data.len(),
                Self::WIRE_SIZE
            ));
        }
        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        if magic != Self::SRT_MAGIC {
            return Err(format!(
                "SrtHandshake::decode: invalid magic 0x{magic:04X} (expected 0x{:04X})",
                Self::SRT_MAGIC
            ));
        }
        let version_lo = u16::from_be_bytes([data[4], data[5]]);
        let extension_field = u16::from_be_bytes([data[6], data[7]]);
        let initial_seq = u32::from_be_bytes([data[8], data[9], data[10], data[11]]);
        let mtu = u32::from_be_bytes([data[12], data[13], data[14], data[15]]);
        let flow_window = u32::from_be_bytes([data[16], data[17], data[18], data[19]]);
        let handshake_type = i32::from_be_bytes([data[20], data[21], data[22], data[23]]);
        let socket_id = u32::from_be_bytes([data[24], data[25], data[26], data[27]]);
        let syn_cookie = u32::from_be_bytes([data[28], data[29], data[30], data[31]]);
        let peer_ip_v4 = u32::from_be_bytes([data[32], data[33], data[34], data[35]]);
        Ok(Self {
            magic,
            version: u32::from(version_lo),
            extension_field,
            initial_seq,
            mtu,
            flow_window,
            handshake_type,
            socket_id,
            syn_cookie,
            peer_ip_v4,
        })
    }
}

impl Default for SrtHandshake {
    /// Same as [`SrtHandshake::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Per-stream SRT settings, built with the `with_*` methods.
///
/// `Debug` is implemented by hand so the passphrase is never written to logs.
#[derive(Clone)]
pub struct SrtStreamConfig {
    /// Latency in milliseconds (default 120).
    pub latency_ms: u32,
    /// Bandwidth cap (default 0).
    /// NOTE(review): presumably bits/s with 0 meaning "no cap" — confirm.
    pub max_bandwidth: u64,
    /// Payload encryption mode (default `SrtEncryption::None`).
    pub encryption: SrtEncryption,
    /// Optional passphrase; redacted in `Debug` output.
    pub passphrase: Option<String>,
    /// Optional stream identifier.
    pub stream_id: Option<String>,
    /// Maximum payload size in bytes (default 1316 = 7 × 188).
    pub max_payload_size: u16,
}

impl std::fmt::Debug for SrtStreamConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a passphrase is configured, never its contents.
        let passphrase = self.passphrase.as_ref().map(|_| "<redacted>");
        f.debug_struct("SrtStreamConfig")
            .field("latency_ms", &self.latency_ms)
            .field("max_bandwidth", &self.max_bandwidth)
            .field("encryption", &self.encryption)
            .field("passphrase", &passphrase)
            .field("stream_id", &self.stream_id)
            .field("max_payload_size", &self.max_payload_size)
            .finish()
    }
}

impl Default for SrtStreamConfig {
    /// 120 ms latency, bandwidth 0, no encryption, no passphrase or stream id,
    /// 1316-byte payloads.
    fn default() -> Self {
        Self {
            latency_ms: 120,
            max_bandwidth: 0,
            encryption: SrtEncryption::None,
            passphrase: None,
            stream_id: None,
            max_payload_size: 1316,
        }
    }
}

impl SrtStreamConfig {
    /// Same as [`SrtStreamConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the latency in milliseconds.
    #[must_use]
    pub const fn with_latency(mut self, ms: u32) -> Self {
        self.latency_ms = ms;
        self
    }

    /// Sets the bandwidth cap.
    #[must_use]
    pub const fn with_max_bandwidth(mut self, bps: u64) -> Self {
        self.max_bandwidth = bps;
        self
    }

    /// Sets the encryption mode.
    #[must_use]
    pub const fn with_encryption(mut self, enc: SrtEncryption) -> Self {
        self.encryption = enc;
        self
    }

    /// Stores an owned copy of `passphrase`.
    #[must_use]
    pub fn with_passphrase(mut self, passphrase: &str) -> Self {
        self.passphrase = Some(passphrase.to_owned());
        self
    }

    /// Stores an owned copy of the stream identifier.
    #[must_use]
    pub fn with_stream_id(mut self, id: &str) -> Self {
        self.stream_id = Some(id.to_owned());
        self
    }

    /// Sets the maximum payload size in bytes.
    #[must_use]
    pub const fn with_max_payload_size(mut self, size: u16) -> Self {
        self.max_payload_size = size;
        self
    }

    /// Returns `true` when the configured mode is not `SrtEncryption::None`.
    #[must_use]
    pub const fn is_encrypted(&self) -> bool {
        self.encryption.is_encrypted()
    }

    /// Returns `true` when a stream identifier is set.
    #[must_use]
    pub fn has_stream_id(&self) -> bool {
        self.stream_id.is_some()
    }
}
/// One handshake extension: 16-bit type, 16-bit length in 32-bit words, then payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SrtExtensionBlock {
    /// Extension type (one of the `EXT_*` constants).
    pub ext_type: u16,
    /// Payload length in 32-bit words, not bytes.
    pub ext_len: u16,
    /// Payload bytes. After `decode` this includes zero padding up to `ext_len * 4`.
    pub data: Vec<u8>,
}

impl SrtExtensionBlock {
    pub const EXT_HSREQ: u16 = 1;
    pub const EXT_HSRSP: u16 = 2;
    pub const EXT_KMREQ: u16 = 3;
    pub const EXT_KMRSP: u16 = 4;
    pub const EXT_SID: u16 = 5;

    /// Builds a block whose `ext_len` is `data.len()` rounded up to whole 32-bit words.
    ///
    /// `ext_len` saturates at `u16::MAX`. Payload bytes beyond `u16::MAX * 4` are not
    /// emitted by [`SrtExtensionBlock::encode`].
    #[must_use]
    pub fn new(ext_type: u16, data: Vec<u8>) -> Self {
        // A plain `as u16` would wrap and advertise a length unrelated to the payload.
        let ext_len = u16::try_from(data.len().div_ceil(4)).unwrap_or(u16::MAX);
        Self {
            ext_type,
            ext_len,
            data,
        }
    }

    /// Builds an HSREQ block: version word, flags word, then `latency_ms` written into
    /// both 16-bit halves of the third word.
    #[must_use]
    pub fn hsreq(srt_version: u32, srt_flags: u32, latency_ms: u16) -> Self {
        let mut data = Vec::with_capacity(12);
        data.extend_from_slice(&srt_version.to_be_bytes());
        data.extend_from_slice(&srt_flags.to_be_bytes());
        let latency_field = (u32::from(latency_ms) << 16) | u32::from(latency_ms);
        data.extend_from_slice(&latency_field.to_be_bytes());
        Self::new(Self::EXT_HSREQ, data)
    }

    /// Builds a stream-id block from the raw UTF-8 bytes of `sid`.
    #[must_use]
    pub fn stream_id(sid: &str) -> Self {
        Self::new(Self::EXT_SID, sid.as_bytes().to_vec())
    }

    /// Serializes the block as exactly [`SrtExtensionBlock::wire_size`] bytes.
    ///
    /// The payload is zero-padded up to `ext_len * 4` bytes. If `data` is longer than
    /// that, it is truncated.
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let padded_len = usize::from(self.ext_len) * 4;
        let mut buf = Vec::with_capacity(4 + padded_len);
        buf.extend_from_slice(&self.ext_type.to_be_bytes());
        buf.extend_from_slice(&self.ext_len.to_be_bytes());
        // Never emit more payload than the header advertises. Receivers frame blocks
        // by `ext_len`, so surplus bytes would be parsed as the next block.
        let body_len = self.data.len().min(padded_len);
        buf.extend_from_slice(&self.data[..body_len]);
        buf.resize(4 + padded_len, 0);
        buf
    }

    /// Parses one block from the start of `data`. Any trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error message if `data` is shorter than the 4-byte header, or
    /// shorter than the header plus the `ext_len * 4` payload bytes it advertises.
    pub fn decode(data: &[u8]) -> Result<Self, String> {
        if data.len() < 4 {
            return Err(format!(
                "SrtExtensionBlock::decode: buffer too short ({} < 4)",
                data.len()
            ));
        }
        let ext_type = u16::from_be_bytes([data[0], data[1]]);
        let ext_len = u16::from_be_bytes([data[2], data[3]]);
        let payload_len = usize::from(ext_len) * 4;
        if data.len() < 4 + payload_len {
            return Err(format!(
                "SrtExtensionBlock::decode: payload truncated (need {} got {})",
                4 + payload_len,
                data.len()
            ));
        }
        let payload = data[4..4 + payload_len].to_vec();
        Ok(Self {
            ext_type,
            ext_len,
            data: payload,
        })
    }

    /// Encoded size in bytes: 4-byte header plus `ext_len` words.
    #[must_use]
    pub fn wire_size(&self) -> usize {
        4 + usize::from(self.ext_len) * 4
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_key_len() {
        assert_eq!(SrtEncryption::None.key_len_bytes(), 0);
        assert_eq!(SrtEncryption::Aes128.key_len_bytes(), 16);
        assert_eq!(SrtEncryption::Aes256.key_len_bytes(), 32);
    }

    #[test]
    fn test_encryption_is_encrypted() {
        assert!(!SrtEncryption::None.is_encrypted());
        assert!(SrtEncryption::Aes128.is_encrypted());
        assert!(SrtEncryption::Aes256.is_encrypted());
    }

    #[test]
    fn test_encryption_handshake_field_roundtrip() {
        for enc in [
            SrtEncryption::None,
            SrtEncryption::Aes128,
            SrtEncryption::Aes256,
        ] {
            let field = enc.handshake_field();
            assert_eq!(SrtEncryption::from_handshake_field(field), enc);
        }
    }

    #[test]
    fn test_encryption_unknown_field_falls_back_to_none() {
        // Unrecognised values decode leniently to `None`.
        assert_eq!(SrtEncryption::from_handshake_field(1), SrtEncryption::None);
        assert_eq!(
            SrtEncryption::from_handshake_field(0xFFFF),
            SrtEncryption::None
        );
    }

    #[test]
    fn test_encryption_default_is_none() {
        assert_eq!(SrtEncryption::default(), SrtEncryption::None);
    }

    #[test]
    fn test_packet_type_data_flags() {
        assert!(SrtPacketType::Data.is_data());
        assert!(!SrtPacketType::Data.is_control());
        assert!(!SrtPacketType::Data.is_handshake());
    }

    #[test]
    fn test_packet_type_control_flags() {
        for pt in [
            SrtPacketType::Control,
            SrtPacketType::Keepalive,
            SrtPacketType::Ack,
            SrtPacketType::Nak,
            SrtPacketType::Shutdown,
            SrtPacketType::Handshake,
            SrtPacketType::AckAck,
            SrtPacketType::DropReq,
        ] {
            assert!(!pt.is_data(), "{pt:?} should not be data");
            assert!(pt.is_control(), "{pt:?} should be control");
        }
    }

    #[test]
    fn test_packet_type_handshake_flag() {
        assert!(SrtPacketType::Handshake.is_handshake());
        assert!(!SrtPacketType::Keepalive.is_handshake());
    }

    #[test]
    fn test_packet_type_from_control_type() {
        let cases = [
            (ControlType::Handshake, SrtPacketType::Handshake),
            (ControlType::Keepalive, SrtPacketType::Keepalive),
            (ControlType::Ack, SrtPacketType::Ack),
            (ControlType::Nak, SrtPacketType::Nak),
            (ControlType::Shutdown, SrtPacketType::Shutdown),
            (ControlType::AckAck, SrtPacketType::AckAck),
            (ControlType::DropReq, SrtPacketType::DropReq),
            (ControlType::UserDefined, SrtPacketType::Control),
        ];
        for (ct, expected) in &cases {
            assert_eq!(SrtPacketType::from_control_type(ct), *expected);
        }
    }

    #[test]
    fn test_handshake_new_defaults() {
        let hs = SrtHandshake::new();
        assert_eq!(hs.magic, SrtHandshake::SRT_MAGIC);
        assert_eq!(hs.version, SrtHandshake::VERSION_1_3_0);
        assert_eq!(hs.mtu, 1500);
        assert_eq!(hs.flow_window, 8192);
        assert!(hs.is_valid());
    }

    #[test]
    fn test_handshake_encode_decode_roundtrip() {
        let hs = SrtHandshake {
            initial_seq: 0x1234_5678,
            peer_ip_v4: 0x7F00_0001,
            extension_field: 0x0A05,
            ..SrtHandshake::new()
        }
        .with_socket_id(42)
        .with_syn_cookie(0xDEAD_BEEF)
        .with_handshake_type(SrtHandshake::TYPE_INDUCTION);
        let encoded = hs.encode();
        assert_eq!(encoded.len(), SrtHandshake::WIRE_SIZE);
        let decoded = SrtHandshake::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.magic, SrtHandshake::SRT_MAGIC);
        assert_eq!(decoded.extension_field, 0x0A05);
        assert_eq!(decoded.initial_seq, 0x1234_5678);
        assert_eq!(decoded.mtu, hs.mtu);
        assert_eq!(decoded.flow_window, hs.flow_window);
        assert_eq!(decoded.socket_id, 42);
        assert_eq!(decoded.syn_cookie, 0xDEAD_BEEF);
        assert_eq!(decoded.handshake_type, SrtHandshake::TYPE_INDUCTION);
        assert_eq!(decoded.peer_ip_v4, 0x7F00_0001);
        assert!(decoded.is_valid());
    }

    #[test]
    fn test_handshake_negative_type_roundtrip() {
        let hs = SrtHandshake::new().with_handshake_type(SrtHandshake::TYPE_AGREEMENT);
        let decoded = SrtHandshake::decode(&hs.encode()).expect("decode should succeed");
        assert_eq!(decoded.handshake_type, SrtHandshake::TYPE_AGREEMENT);
    }

    #[test]
    fn test_handshake_decode_too_short() {
        assert!(SrtHandshake::decode(&[0u8; 10]).is_err());
        // One byte short of the fixed wire size must also be rejected.
        let encoded = SrtHandshake::new().encode();
        assert!(SrtHandshake::decode(&encoded[..SrtHandshake::WIRE_SIZE - 1]).is_err());
    }

    #[test]
    fn test_handshake_decode_wrong_magic() {
        let mut buf = SrtHandshake::new().encode();
        buf[..4].fill(0xFF);
        assert!(SrtHandshake::decode(&buf).is_err());
    }

    #[test]
    fn test_handshake_with_latency_extension() {
        let hs = SrtHandshake::new().with_latency_extension(200);
        assert_eq!(u32::from(hs.extension_field >> 8), 200);
    }

    #[test]
    fn test_handshake_is_valid_wrong_magic() {
        let hs = SrtHandshake {
            magic: 0x0000,
            ..SrtHandshake::new()
        };
        assert!(!hs.is_valid());
    }

    #[test]
    fn test_stream_config_defaults() {
        let cfg = SrtStreamConfig::new();
        assert_eq!(cfg.latency_ms, 120);
        assert_eq!(cfg.max_bandwidth, 0);
        assert_eq!(cfg.encryption, SrtEncryption::None);
        assert!(!cfg.is_encrypted());
        assert!(!cfg.has_stream_id());
        assert_eq!(cfg.max_payload_size, 1316);
    }

    #[test]
    fn test_stream_config_builder_chain() {
        let cfg = SrtStreamConfig::new()
            .with_latency(500)
            .with_max_bandwidth(10_000_000)
            .with_encryption(SrtEncryption::Aes256)
            .with_passphrase("s3cr3t")
            .with_stream_id("live/stream1")
            .with_max_payload_size(188);
        assert_eq!(cfg.latency_ms, 500);
        assert_eq!(cfg.max_bandwidth, 10_000_000);
        assert_eq!(cfg.encryption, SrtEncryption::Aes256);
        assert!(cfg.is_encrypted());
        assert_eq!(cfg.passphrase.as_deref(), Some("s3cr3t"));
        assert_eq!(cfg.stream_id.as_deref(), Some("live/stream1"));
        assert!(cfg.has_stream_id());
        assert_eq!(cfg.max_payload_size, 188);
    }

    #[test]
    fn test_extension_block_new_computes_ext_len() {
        let block = SrtExtensionBlock::new(SrtExtensionBlock::EXT_HSREQ, vec![0u8; 12]);
        // 12 payload bytes = 3 words; wire size = 4-byte header + 12 bytes.
        assert_eq!(block.ext_len, 3);
        assert_eq!(block.wire_size(), 16);
    }

    #[test]
    fn test_extension_block_new_rounds_up_partial_word() {
        let block = SrtExtensionBlock::new(SrtExtensionBlock::EXT_SID, vec![1u8; 5]);
        assert_eq!(block.ext_len, 2);
        let empty = SrtExtensionBlock::new(SrtExtensionBlock::EXT_SID, Vec::new());
        assert_eq!(empty.ext_len, 0);
    }

    #[test]
    fn test_extension_block_encode_decode_roundtrip() {
        let original = SrtExtensionBlock::new(SrtExtensionBlock::EXT_SID, b"live/test".to_vec());
        let encoded = original.encode();
        assert_eq!(encoded.len(), original.wire_size());
        let decoded = SrtExtensionBlock::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.ext_type, SrtExtensionBlock::EXT_SID);
        assert_eq!(decoded.ext_len, original.ext_len);
        // The decoded payload keeps the zero padding up to the next word boundary.
        assert_eq!(decoded.data, b"live/test\0\0\0".to_vec());
    }

    #[test]
    fn test_extension_block_hsreq_factory() {
        let block = SrtExtensionBlock::hsreq(SrtHandshake::VERSION_1_3_0, 0x00BF, 120);
        assert_eq!(block.ext_type, SrtExtensionBlock::EXT_HSREQ);
        // Three 32-bit words: version, flags, latency (120 = 0x78 in both halves).
        assert_eq!(block.ext_len, 3);
        assert_eq!(
            block.data,
            vec![0x00, 0x01, 0x03, 0x00, 0x00, 0x00, 0x00, 0xBF, 0x00, 0x78, 0x00, 0x78]
        );
    }

    #[test]
    fn test_extension_block_stream_id_factory() {
        let block = SrtExtensionBlock::stream_id("media/channel1");
        assert_eq!(block.ext_type, SrtExtensionBlock::EXT_SID);
        assert_eq!(block.data, b"media/channel1".to_vec());
    }

    #[test]
    fn test_extension_block_decode_too_short() {
        assert!(SrtExtensionBlock::decode(&[0u8; 2]).is_err());
    }

    #[test]
    fn test_extension_block_decode_truncated_payload() {
        // Header claims 4 words (16 bytes) but only 4 payload bytes ("abcd") follow.
        let buf = [0x00, 0x05, 0x00, 0x04, 0x61, 0x62, 0x63, 0x64];
        assert!(SrtExtensionBlock::decode(&buf).is_err());
    }

    #[test]
    fn test_extension_type_constants() {
        assert_eq!(SrtExtensionBlock::EXT_HSREQ, 1);
        assert_eq!(SrtExtensionBlock::EXT_HSRSP, 2);
        assert_eq!(SrtExtensionBlock::EXT_KMREQ, 3);
        assert_eq!(SrtExtensionBlock::EXT_KMRSP, 4);
        assert_eq!(SrtExtensionBlock::EXT_SID, 5);
    }
}