use std::{
fmt::{self, Display, Formatter},
{collections::BTreeMap, convert::TryFrom, time::Duration},
};
use bitflags::bitflags;
use bytes::{Buf, BufMut};
use log::warn;
use crate::{options::SrtVersion, packet::PacketParseError};
/// An SRT extension control packet.
///
/// Each variant corresponds to an extension type id; see [`SrtControlPacket::type_id`]
/// for the mapping (0 = `Reject` through 8 = `Group`).
#[derive(Clone, Eq, PartialEq)]
pub enum SrtControlPacket {
    /// Type 0; carries no payload.
    Reject,
    /// Type 1; handshake extension request.
    HandshakeRequest(SrtHandshake),
    /// Type 2; handshake extension response.
    HandshakeResponse(SrtHandshake),
    /// Type 3; keying-material message sent as a request.
    KeyRefreshRequest(KeyingMaterialMessage),
    /// Type 4; keying-material message sent as a response.
    KeyRefreshResponse(KeyingMaterialMessage),
    /// Type 5; stream id string.
    StreamId(String),
    /// Type 6; congestion setting, carried as a string.
    Congestion(String),
    /// Type 7; packet filter configuration.
    Filter(FilterSpec),
    /// Type 8; group extension (one 32-bit word on the wire).
    Group {
        /// Group type byte.
        ty: GroupType,
        /// Group flags; unknown bits are dropped when parsing.
        flags: GroupFlags,
        /// Weight, encoded little-endian on the wire.
        weight: u16,
    },
}
/// Packet-filter configuration as a sorted key/value map.
///
/// Its `Display` form is `key1:value1,key2:value2` (keys in sorted order), which is the
/// string placed on the wire for a filter extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FilterSpec(pub BTreeMap<String, String>);
/// Group type carried in the group extension.
///
/// Wire bytes 0..=4 map to the named variants in declaration order; any other byte is
/// preserved in `Unrecognized` so it round-trips unchanged.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum GroupType {
    /// Wire value 0.
    Undefined,
    /// Wire value 1.
    Broadcast,
    /// Wire value 2.
    MainBackup,
    /// Wire value 3.
    Balancing,
    /// Wire value 4.
    Multicast,
    /// Any other wire value, kept verbatim.
    Unrecognized(u8),
}
bitflags! {
    /// Flags byte of the group extension.
    ///
    /// Parsing keeps only the bits defined here (`from_bits_truncate`).
    #[derive(Clone, Copy, Eq, PartialEq, Debug)]
    pub struct GroupFlags: u8 {
        /// Bit 6 of the flags byte.
        const MSG_SYNC = 1 << 6;
    }
}
/// A keying-material message (the payload of key refresh request/response extensions).
///
/// Its `Debug` output shows only the header fields; `salt` and `wrapped_keys` are omitted.
#[derive(Clone, Eq, PartialEq)]
pub struct KeyingMaterialMessage {
    /// Packet type, taken from the low nibble of the first header byte.
    pub pt: PacketType,
    /// Which keys (even/odd) are present; only the low two bits of the flags byte are used.
    pub key_flags: KeyFlags,
    /// 32-bit `keki` header field, carried verbatim.
    pub keki: u32,
    /// Cipher code.
    pub cipher: CipherType,
    /// Authentication code.
    pub auth: Auth,
    /// Raw salt bytes; the wire encodes its length in 4-byte words.
    pub salt: Vec<u8>,
    /// Wrapped key blob: per-key length times the number of set `key_flags`, plus 8 bytes.
    pub wrapped_keys: Vec<u8>,
}
impl From<GroupType> for u8 {
fn from(from: GroupType) -> u8 {
match from {
GroupType::Undefined => 0,
GroupType::Broadcast => 1,
GroupType::MainBackup => 2,
GroupType::Balancing => 3,
GroupType::Multicast => 4,
GroupType::Unrecognized(u) => u,
}
}
}
impl From<u8> for GroupType {
fn from(from: u8) -> GroupType {
match from {
0 => GroupType::Undefined,
1 => GroupType::Broadcast,
2 => GroupType::MainBackup,
3 => GroupType::Balancing,
4 => GroupType::Multicast,
u => GroupType::Unrecognized(u),
}
}
}
impl fmt::Debug for KeyingMaterialMessage {
    /// Formats only the header fields. `salt` and `wrapped_keys` are intentionally left
    /// out (presumably to keep key material out of logs).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            pt,
            key_flags,
            keki,
            cipher,
            auth,
            ..
        } = self;
        let mut out = f.debug_struct("KeyingMaterialMessage");
        out.field("pt", pt);
        out.field("key_flags", key_flags);
        out.field("keki", keki);
        out.field("cipher", cipher);
        out.field("auth", auth);
        out.finish()
    }
}
/// Authentication code of a keying-material message.
///
/// Only `None` (wire value 0) is accepted by the parser.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Auth {
    /// Wire value 0.
    None = 0,
}
impl TryFrom<u8> for Auth {
    type Error = PacketParseError;

    /// Accepts only 0 (`Auth::None`); any other byte yields `BadAuth` with that byte.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        if value == 0 {
            Ok(Self::None)
        } else {
            Err(PacketParseError::BadAuth(value))
        }
    }
}
/// Stream-encapsulation code of a keying-material message.
///
/// The keying-material parser rejects anything other than `Srt`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamEncapsulation {
    /// Wire value 1.
    Udp = 1,
    /// Wire value 2.
    Srt = 2,
}
impl TryFrom<u8> for StreamEncapsulation {
    type Error = PacketParseError;

    /// Maps wire values 1 and 2 to `Udp` and `Srt`; anything else is
    /// `BadStreamEncapsulation`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            1 => Ok(Self::Udp),
            2 => Ok(Self::Srt),
            other => Err(PacketParseError::BadStreamEncapsulation(other)),
        }
    }
}
/// Packet-type nibble of a keying-material message header.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum PacketType {
    /// Wire value 1.
    MediaStream = 1,
    /// Wire value 2.
    KeyingMaterial = 2,
}
bitflags! {
    /// Which keys a keying-material message carries.
    ///
    /// The parser keeps only the low two bits of the flags byte.
    #[derive(Clone, Copy, Eq, PartialEq, Debug)]
    pub struct KeyFlags : u8 {
        /// Even key present.
        const EVEN = 0b01;
        /// Odd key present.
        const ODD = 0b10;
    }
}
impl TryFrom<u8> for PacketType {
    type Error = PacketParseError;

    /// Maps the header nibble to a [`PacketType`]; unknown values are `BadKeyPacketType`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Ok(match value {
            1 => Self::MediaStream,
            2 => Self::KeyingMaterial,
            other => return Err(PacketParseError::BadKeyPacketType(other)),
        })
    }
}
/// Cipher code of a keying-material message (wire values 0..=3).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CipherType {
    /// Wire value 0.
    None = 0,
    /// Wire value 1.
    Ecb = 1,
    /// Wire value 2.
    Ctr = 2,
    /// Wire value 3.
    Cbc = 3,
}
/// SRT handshake extension body (12 bytes on the wire).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct SrtHandshake {
    /// SRT version, encoded as the first big-endian u32.
    pub version: SrtVersion,
    /// Handshake flags, encoded as the second big-endian u32.
    pub flags: SrtShakeFlags,
    /// Latency from the first u16 of the last word, in whole milliseconds.
    pub send_latency: Duration,
    /// Latency from the second u16 of the last word, in whole milliseconds.
    pub recv_latency: Duration,
}
bitflags! {
    /// Flags word of the SRT handshake extension.
    ///
    /// Meanings below follow the SRT handshake-extension flag definitions; unknown bits
    /// are logged and dropped by `SrtHandshake::parse`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    pub struct SrtShakeFlags: u32 {
        /// Party sends in timestamp-based packet delivery (TSBPD) mode.
        const TSBPDSND = 0x1;
        /// Party expects to receive in TSBPD mode.
        const TSBPDRCV = 0x2;
        /// Encryption support flag.
        const HAICRYPT = 0x4;
        /// Too-late packet drop.
        const TLPKTDROP = 0x8;
        /// Periodic NAK reports.
        const NAKREPORT = 0x10;
        /// Retransmission flag in data packets.
        const REXMITFLG = 0x20;
        /// Stream (rather than message) mode.
        const STREAM = 0x40;
        /// Packet filter support.
        const PACKET_FILTER = 0x80;
        /// Union of TSBPDSND, TSBPDRCV, HAICRYPT and REXMITFLG.
        const SUPPORTED = Self::TSBPDSND.bits() | Self::TSBPDRCV.bits() | Self::HAICRYPT.bits() | Self::REXMITFLG.bits();
    }
}
/// Decodes a string from the word-swapped layout written by `string_to_le_bytes`.
///
/// The payload is a sequence of 32-bit words. Each word is read little-endian, and its
/// big-endian byte representation yields four consecutive string bytes. Up to three
/// trailing zero bytes in the final word are treated as padding and dropped.
///
/// An empty payload decodes to the empty string; that is what `string_to_le_bytes`
/// writes for `""`.
///
/// # Errors
/// - `NotEnoughData` if the payload length is not a multiple of 4.
/// - `StreamTypeNotUtf8` if the decoded bytes are not valid UTF-8.
fn le_bytes_to_string(le_bytes: &mut impl Buf) -> Result<String, PacketParseError> {
    if le_bytes.remaining() % 4 != 0 {
        return Err(PacketParseError::NotEnoughData);
    }
    // With no data there is no final word to strip padding from; reading one
    // unconditionally (as below) would panic.
    if !le_bytes.has_remaining() {
        return Ok(String::new());
    }

    let mut str_bytes = Vec::with_capacity(le_bytes.remaining());

    // Every word but the last carries four string bytes and no padding.
    while le_bytes.remaining() > 4 {
        str_bytes.extend(le_bytes.get_u32_le().to_be_bytes());
    }

    // The final word may end in up to three zero padding bytes.
    match le_bytes.get_u32_le().to_be_bytes() {
        [a, 0, 0, 0] => str_bytes.push(a),
        [a, b, 0, 0] => str_bytes.extend([a, b]),
        [a, b, c, 0] => str_bytes.extend([a, b, c]),
        [a, b, c, d] => str_bytes.extend([a, b, c, d]),
    }

    String::from_utf8(str_bytes).map_err(|e| PacketParseError::StreamTypeNotUtf8(e.utf8_error()))
}
/// Writes `str` in SRT's word-swapped string layout.
///
/// The bytes are split into groups of four. A short final group is zero-padded on the
/// right, and each group is emitted as a little-endian u32 whose big-endian bytes are the
/// group. This is the inverse of `le_bytes_to_string`. An empty string writes nothing.
fn string_to_le_bytes(str: &str, into: &mut impl BufMut) {
    for group in str.as_bytes().chunks(4) {
        let mut word = [0u8; 4];
        word[..group.len()].copy_from_slice(group);
        into.put_u32_le(u32::from_be_bytes(word));
    }
}
impl Display for FilterSpec {
    /// Renders the spec as `key1:value1,key2:value2` in map (sorted key) order; an empty
    /// spec renders as the empty string.
    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
        let mut separator = "";
        for (key, value) in &self.0 {
            write!(f, "{separator}{key}:{value}")?;
            separator = ",";
        }
        Ok(())
    }
}
impl SrtControlPacket {
pub fn parse<T: Buf>(
packet_type: u16,
buf: &mut T,
) -> Result<SrtControlPacket, PacketParseError> {
use self::SrtControlPacket::*;
match packet_type {
0 => Ok(Reject),
1 => Ok(HandshakeRequest(SrtHandshake::parse(buf)?)),
2 => Ok(HandshakeResponse(SrtHandshake::parse(buf)?)),
3 => Ok(KeyRefreshRequest(KeyingMaterialMessage::parse(buf)?)),
4 => Ok(KeyRefreshResponse(KeyingMaterialMessage::parse(buf)?)),
5 => {
le_bytes_to_string(buf).map(StreamId)
}
6 => le_bytes_to_string(buf).map(Congestion),
7 => {
let filter_str = le_bytes_to_string(buf)?;
Ok(Filter(FilterSpec(
filter_str
.split(',')
.map(|kv| {
let mut colon_split_iter = kv.split(':');
let k = colon_split_iter
.next()
.ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
let v = colon_split_iter
.next()
.ok_or_else(|| PacketParseError::BadFilter(filter_str.clone()))?;
if colon_split_iter.next().is_some() {
return Err(PacketParseError::BadFilter(filter_str.clone()));
}
Ok((k.to_string(), v.to_string()))
})
.collect::<Result<_, _>>()?,
)))
}
8 => {
let ty = buf.get_u8().into();
let flags = GroupFlags::from_bits_truncate(buf.get_u8());
let weight = buf.get_u16_le();
Ok(Group { ty, flags, weight })
}
_ => Err(PacketParseError::UnsupportedSrtExtensionType(packet_type)),
}
}
pub fn type_id(&self) -> u16 {
use self::SrtControlPacket::*;
match self {
Reject => 0,
HandshakeRequest(_) => 1,
HandshakeResponse(_) => 2,
KeyRefreshRequest(_) => 3,
KeyRefreshResponse(_) => 4,
StreamId(_) => 5,
Congestion(_) => 6,
Filter(_) => 7,
Group { .. } => 8,
}
}
pub fn serialize<T: BufMut>(&self, into: &mut T) {
use self::SrtControlPacket::*;
match self {
HandshakeRequest(s) | HandshakeResponse(s) => {
s.serialize(into);
}
KeyRefreshRequest(k) | KeyRefreshResponse(k) => {
k.serialize(into);
}
Filter(filter) => {
string_to_le_bytes(&format!("{filter}"), into);
}
Group { ty, flags, weight } => {
into.put_u8((*ty).into());
into.put_u8(flags.bits());
into.put_u16_le(*weight);
}
Reject => {}
StreamId(str) | Congestion(str) => {
string_to_le_bytes(str, into);
}
}
}
pub fn size_words(&self) -> u16 {
use self::SrtControlPacket::*;
match self {
HandshakeRequest(_) | HandshakeResponse(_) => 3,
KeyRefreshRequest(ref k) | KeyRefreshResponse(ref k) => {
4 + k.salt.len() as u16 / 4 + k.wrapped_keys.len() as u16 / 4
}
Congestion(str) | StreamId(str) => ((str.len() + 3) / 4) as u16, Group { .. } => 1,
Filter(filter) => ((format!("{filter}").len() + 3) / 4) as u16, _ => unimplemented!("{:?}", self),
}
}
}
impl SrtHandshake {
pub fn parse<T: Buf>(buf: &mut T) -> Result<SrtHandshake, PacketParseError> {
if buf.remaining() < 12 {
return Err(PacketParseError::NotEnoughData);
}
let version = SrtVersion::parse(buf.get_u32());
let shake_flags = buf.get_u32();
let flags = match SrtShakeFlags::from_bits(shake_flags) {
Some(i) => i,
None => {
warn!("Unrecognized SRT flags: 0b{:b}", shake_flags);
SrtShakeFlags::from_bits_truncate(shake_flags)
}
};
let peer_latency = buf.get_u16();
let latency = buf.get_u16();
Ok(SrtHandshake {
version,
flags,
send_latency: Duration::from_millis(u64::from(peer_latency)),
recv_latency: Duration::from_millis(u64::from(latency)),
})
}
pub fn serialize<T: BufMut>(&self, into: &mut T) {
into.put_u32(self.version.to_u32());
into.put_u32(self.flags.bits());
into.put_u16(self.send_latency.as_millis() as u16);
into.put_u16(self.recv_latency.as_millis() as u16); }
}
impl KeyingMaterialMessage {
    /// Header signature: the letters "HAI" packed as three 5-bit values (each letter
    /// minus `'@'`), i.e. `0x2029`.
    const SIGN: u16 =
        (((b'H' - b'@') as u16) << 10) | (((b'A' - b'@') as u16) << 5) | (b'I' - b'@') as u16;

    /// Parses a keying-material message: a 16-byte header, then salt, then wrapped keys.
    ///
    /// # Errors
    /// - `NotEnoughData` if the header or the variable-length tail is truncated.
    /// - `BadSrtExtensionMessage` if the first byte's high bit is set, the version is not
    ///   1, or no key flag is set.
    /// - `BadKeySign`, `BadKeyPacketType`, `BadCipherKind`, `BadAuth`,
    ///   `BadStreamEncapsulation` / `StreamEncapsulationNotSrt` for invalid header fields.
    /// - `BadCryptoLength` if the key length is not 16, 24 or 32 bytes.
    pub fn parse(buf: &mut impl Buf) -> Result<KeyingMaterialMessage, PacketParseError> {
        // Fixed 16-byte header.
        if buf.remaining() < 4 * 4 {
            return Err(PacketParseError::NotEnoughData);
        }

        // First byte: high bit must be 0, bits 4..=6 hold the version (must be 1),
        // and the low nibble is the packet type.
        let vers_pt = buf.get_u8();
        if (vers_pt & 0b1000_0000) != 0 {
            return Err(PacketParseError::BadSrtExtensionMessage);
        }
        let version = vers_pt >> 4;
        if version != 1 {
            return Err(PacketParseError::BadSrtExtensionMessage);
        }
        let pt = PacketType::try_from(vers_pt & 0b0000_1111)?;

        let sign = buf.get_u16();
        if sign != Self::SIGN {
            return Err(PacketParseError::BadKeySign(sign));
        }

        // Only the low two bits carry key flags.
        let key_flags = KeyFlags::from_bits_truncate(buf.get_u8() & 0b0000_0011);
        // A message with no keys is invalid (KK = 00), and `serialize` derives the
        // per-key length by dividing by the key count, so reject it here.
        if key_flags.is_empty() {
            return Err(PacketParseError::BadSrtExtensionMessage);
        }

        let keki = buf.get_u32();
        let cipher = CipherType::try_from(buf.get_u8())?;
        let auth = Auth::try_from(buf.get_u8())?;

        let se = StreamEncapsulation::try_from(buf.get_u8())?;
        if se != StreamEncapsulation::Srt {
            return Err(PacketParseError::StreamEncapsulationNotSrt);
        }

        // Reserved: one u8 and one u16.
        buf.advance(3);

        // Both lengths are encoded in 4-byte words.
        let salt_len = usize::from(buf.get_u8()) * 4;
        let key_len = usize::from(buf.get_u8()) * 4;
        match key_len {
            16 | 24 | 32 => {}
            e => return Err(PacketParseError::BadCryptoLength(e as u32)),
        }

        // One key per set flag, plus 8 extra bytes (presumably the key-wrap integrity
        // block — confirm against the spec).
        let key_count = key_flags.bits().count_ones() as usize;
        let wrapped_len = key_len * key_count + 8;
        if buf.remaining() < salt_len + wrapped_len {
            return Err(PacketParseError::NotEnoughData);
        }

        let mut salt = vec![0; salt_len];
        buf.copy_to_slice(&mut salt);
        let mut wrapped_keys = vec![0; wrapped_len];
        buf.copy_to_slice(&mut wrapped_keys);

        Ok(KeyingMaterialMessage {
            pt,
            key_flags,
            keki,
            cipher,
            auth,
            salt,
            wrapped_keys,
        })
    }

    /// Writes the message in the layout read by [`KeyingMaterialMessage::parse`].
    ///
    /// The per-key length field is derived from `wrapped_keys`. Malformed in-memory
    /// messages (no key flags, or fewer than 8 wrapped bytes) no longer panic; they
    /// produce a length field of 0 or a best-effort value.
    fn serialize<T: BufMut>(&self, into: &mut T) {
        // Version 1 in the high nibble, packet type in the low nibble.
        into.put_u8((1 << 4) | self.pt as u8);
        into.put_u16(Self::SIGN);
        into.put_u8(self.key_flags.bits());
        into.put_u32(self.keki);
        into.put_u8(self.cipher as u8);
        into.put_u8(self.auth as u8);
        into.put_u8(StreamEncapsulation::Srt as u8);
        // Reserved.
        into.put_u8(0);
        into.put_u16(0);

        into.put_u8((self.salt.len() / 4) as u8);

        // Guard against division by zero and subtraction underflow for messages built
        // by hand rather than by `parse`.
        let key_count = (self.key_flags.bits().count_ones() as usize).max(1);
        let key_len = self.wrapped_keys.len().saturating_sub(8) / key_count;
        into.put_u8((key_len / 4) as u8);

        into.put_slice(&self.salt);
        into.put_slice(&self.wrapped_keys);
    }
}
impl fmt::Debug for SrtControlPacket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SrtControlPacket::Reject => write!(f, "reject"),
SrtControlPacket::HandshakeRequest(req) => write!(f, "hsreq={req:?}"),
SrtControlPacket::HandshakeResponse(resp) => write!(f, "hsresp={resp:?}"),
SrtControlPacket::KeyRefreshRequest(req) => write!(f, "kmreq={req:?}"),
SrtControlPacket::KeyRefreshResponse(resp) => write!(f, "kmresp={resp:?}"),
SrtControlPacket::StreamId(sid) => write!(f, "streamid={sid}"),
SrtControlPacket::Congestion(ctype) => write!(f, "congestion={ctype}"),
SrtControlPacket::Filter(filter) => write!(f, "filter={filter:?}"),
SrtControlPacket::Group { ty, flags, weight } => {
write!(f, "group=({ty:?}, {flags:?}, {weight:?})")
}
}
}
}
impl TryFrom<u8> for CipherType {
    type Error = PacketParseError;

    /// Maps the wire cipher code (0..=3) to a [`CipherType`]; other codes are
    /// `BadCipherKind`.
    fn try_from(from: u8) -> Result<CipherType, PacketParseError> {
        Ok(match from {
            0 => Self::None,
            1 => Self::Ecb,
            2 => Self::Ctr,
            3 => Self::Cbc,
            unknown => return Err(PacketParseError::BadCipherKind(unknown)),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::{KeyingMaterialMessage, SrtControlPacket, SrtHandshake, SrtShakeFlags};
    use crate::{options::*, packet::*};
    use std::{io::Cursor, time::Duration};

    /// Serializes `packet`, parses the bytes back, and asserts the result equals the input.
    fn assert_roundtrip(packet: Packet) {
        let mut wire = Vec::new();
        packet.serialize(&mut wire);
        let parsed = Packet::parse(&mut Cursor::new(wire), false).unwrap();
        assert_eq!(packet, parsed);
    }

    #[test]
    fn deser_ser_shake() {
        let request = SrtControlPacket::HandshakeRequest(SrtHandshake {
            version: SrtVersion::CURRENT,
            flags: SrtShakeFlags::empty(),
            send_latency: Duration::from_millis(4000),
            recv_latency: Duration::from_millis(3000),
        });
        assert_roundtrip(Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123_141),
            dest_sockid: SocketId(123),
            control_type: ControlTypes::Srt(request),
        }));
    }

    #[test]
    fn ser_deser_sid() {
        let stream_id = SrtControlPacket::StreamId("Hellohelloheloo".into());
        assert_roundtrip(Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(123),
            dest_sockid: SocketId(1234),
            control_type: ControlTypes::Srt(stream_id),
        }));
    }

    #[test]
    fn srt_key_message_debug() {
        let salt = b"\x00\x00\x00\x00\x00\x00\x00\x00\x85\x2c\x3c\xcd\x02\x65\x1a\x22";
        let wrapped = b"U\x06\xe9\xfd\xdfd\xf1'nr\xf4\xe9f\x81#(\xb7\xb5D\x19{\x9b\xcdx";
        let message = KeyingMaterialMessage {
            pt: PacketType::KeyingMaterial,
            key_flags: KeyFlags::EVEN,
            keki: 0,
            cipher: CipherType::Ctr,
            auth: Auth::None,
            salt: salt[..].into(),
            wrapped_keys: wrapped[..].into(),
        };
        let expected = "KeyingMaterialMessage { pt: KeyingMaterial, key_flags: KeyFlags(EVEN), keki: 0, cipher: Ctr, auth: None }";
        assert_eq!(format!("{message:?}"), expected)
    }
}