use std::fmt;
/// Version word placed in `udt_version` of every handshake packet this module builds.
pub const SRT_MAGIC: u32 = 4;
/// SRT 1.4, packed as `major << 16 | minor << 8`.
pub const SRT_VERSION_1_4: u32 = (1 << 16) | (4 << 8);
/// SRT 1.5, packed the same way as [`SRT_VERSION_1_4`].
pub const SRT_VERSION_1_5: u32 = (1 << 16) | (5 << 8);
/// Longest accepted stream ID, measured in bytes of its UTF-8 encoding.
pub const MAX_STREAM_ID_LEN: usize = 512;
/// Smallest TSBPD latency, in milliseconds, accepted by HSREQ validation.
pub const MIN_LATENCY_MS: u32 = 20;
/// Largest TSBPD latency, in milliseconds, accepted by HSREQ validation.
pub const MAX_LATENCY_MS: u32 = 30_000;
/// Kind of an SRT handshake packet, as carried in its handshake-type field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeType {
    /// Opening request sent by a caller.
    Induction,
    /// Listener's reply to an induction request.
    WaveaHand,
    /// Caller's conclusion request; callers also accept it as a final reply.
    Conclusion,
    /// Listener's acceptance of a conclusion request.
    Agreement,
    /// Refusal of the connection, carrying the reason.
    Rejection(RejectionReason),
}
impl HandshakeType {
    /// Top-nibble tag marking a rejection; the low 28 bits carry the reason code.
    const REJECTION_TAG: u32 = 0x1000_0000;
    /// Mask selecting the top nibble of a wire value.
    const TAG_MASK: u32 = 0xF000_0000;

    /// Encodes this handshake type as its 32-bit wire value.
    ///
    /// Rejections are encoded as `0x1000_0000 | code`, where `code` is the
    /// [`RejectionReason`] discriminant (at most 1000), so the top nibble of a
    /// rejection is always exactly `0x1`.
    #[must_use]
    pub fn to_wire(self) -> u32 {
        match self {
            Self::Induction => 1,
            Self::WaveaHand => 0,
            Self::Conclusion => 0xFFFF_FFFF,
            Self::Agreement => 0xFFFF_FFFE,
            Self::Rejection(r) => Self::REJECTION_TAG | (r as u32),
        }
    }

    /// Decodes a 32-bit wire value, returning `None` for unknown values.
    ///
    /// A value is a rejection only when its top nibble is exactly `0x1`,
    /// which is the range [`HandshakeType::to_wire`] emits. Any other value
    /// that merely has bit 28 set (e.g. `0xFFFF_FFFD`) is reported as unknown
    /// rather than being misread as an application-defined rejection.
    #[must_use]
    pub fn from_wire(v: u32) -> Option<Self> {
        match v {
            1 => Some(Self::Induction),
            0 => Some(Self::WaveaHand),
            0xFFFF_FFFF => Some(Self::Conclusion),
            0xFFFF_FFFE => Some(Self::Agreement),
            v if (v & Self::TAG_MASK) == Self::REJECTION_TAG => {
                let code = v & !Self::TAG_MASK;
                Some(Self::Rejection(RejectionReason::from_code(code)))
            }
            _ => None,
        }
    }
}
impl fmt::Display for HandshakeType {
    /// Writes the upper-case protocol name; a rejection also includes its
    /// reason, e.g. `REJECTION(FORBIDDEN)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Induction => "INDUCTION",
            Self::WaveaHand => "WAVEAHAND",
            Self::Conclusion => "CONCLUSION",
            Self::Agreement => "AGREEMENT",
            Self::Rejection(reason) => return write!(f, "REJECTION({reason})"),
        };
        f.write_str(name)
    }
}
/// Reason attached to a rejected handshake.
///
/// Each discriminant is the numeric code used by
/// [`RejectionReason::from_code`] and by the rejection wire encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum RejectionReason {
    Unknown = 0,
    System = 1,
    Peer = 2,
    Resource = 3,
    Forbidden = 4,
    Version = 5,
    Passphrase = 6,
    MediaType = 7,
    BadRequest = 8,
    Unauthorized = 9,
    Overloaded = 10,
    Conflict = 11,
    GeoBlocked = 12,
    ClosedSession = 13,
    Timeout = 14,
    ApplicationDefined = 1000,
}

impl RejectionReason {
    /// Reasons with codes `0..=14`, indexed by their numeric code.
    const BY_CODE: [Self; 15] = [
        Self::Unknown,
        Self::System,
        Self::Peer,
        Self::Resource,
        Self::Forbidden,
        Self::Version,
        Self::Passphrase,
        Self::MediaType,
        Self::BadRequest,
        Self::Unauthorized,
        Self::Overloaded,
        Self::Conflict,
        Self::GeoBlocked,
        Self::ClosedSession,
        Self::Timeout,
    ];

    /// Maps a numeric reason code back to a [`RejectionReason`].
    ///
    /// Codes `0..=14` map to their variant, any code `>= 1000` collapses to
    /// [`RejectionReason::ApplicationDefined`] (the exact code is not kept),
    /// and every other code becomes [`RejectionReason::Unknown`].
    #[must_use]
    pub fn from_code(code: u32) -> Self {
        if code >= 1000 {
            return Self::ApplicationDefined;
        }
        usize::try_from(code)
            .ok()
            .and_then(|idx| Self::BY_CODE.get(idx).copied())
            .unwrap_or(Self::Unknown)
    }
}

impl fmt::Display for RejectionReason {
    /// Writes the upper-snake-case name of the reason.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Unknown => "UNKNOWN",
            Self::System => "SYSTEM",
            Self::Peer => "PEER",
            Self::Resource => "RESOURCE",
            Self::Forbidden => "FORBIDDEN",
            Self::Version => "VERSION",
            Self::Passphrase => "PASSPHRASE",
            Self::MediaType => "MEDIA_TYPE",
            Self::BadRequest => "BAD_REQUEST",
            Self::Unauthorized => "UNAUTHORIZED",
            Self::Overloaded => "OVERLOADED",
            Self::Conflict => "CONFLICT",
            Self::GeoBlocked => "GEO_BLOCKED",
            Self::ClosedSession => "CLOSED_SESSION",
            Self::Timeout => "TIMEOUT",
            Self::ApplicationDefined => "APP_DEFINED",
        };
        f.write_str(label)
    }
}
/// Bit set of SRT option flags carried in HSREQ/HSRSP blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SrtFlags(pub u32);

impl SrtFlags {
    pub const TSBPD_SND: u32 = 0x01;
    pub const TSBPD_RCV: u32 = 0x02;
    pub const HAICRYPT_OFF: u32 = 0x04;
    pub const TLPKT_DROP: u32 = 0x08;
    pub const NAK_REPORT: u32 = 0x10;
    pub const REXMIT_FLGS: u32 = 0x20;
    pub const STREAM_ID: u32 = 0x40;

    /// Wraps a raw flag word.
    #[must_use]
    pub const fn new(bits: u32) -> Self {
        Self(bits)
    }

    /// Returns `true` when *any* bit of `flag` is set (not necessarily all).
    #[must_use]
    pub fn has(self, flag: u32) -> bool {
        (self.0 & flag) != 0
    }

    /// Turns on every bit of `flag`.
    pub fn set(&mut self, flag: u32) {
        self.0 |= flag;
    }

    /// Turns off every bit of `flag`.
    pub fn clear(&mut self, flag: u32) {
        self.0 &= !flag;
    }

    /// Returns the raw flag word.
    #[must_use]
    pub const fn bits(self) -> u32 {
        self.0
    }
}
/// SRT handshake-request extension block describing the caller's settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HsreqBlock {
    /// Packed SRT version advertised by the caller.
    pub srt_version: u32,
    /// Option flags requested by the caller.
    pub srt_flags: SrtFlags,
    /// Receive-side TSBPD delay, in milliseconds.
    pub recv_tsbpd_delay_ms: u16,
    /// Send-side TSBPD delay, in milliseconds.
    pub snd_tsbpd_delay_ms: u16,
}

impl HsreqBlock {
    /// Builds an HSREQ advertising SRT 1.4 with TSBPD in both directions and
    /// NAK reports enabled.
    ///
    /// The latencies are stored unchecked; use [`HsreqBlock::validate`].
    #[must_use]
    pub fn new(recv_latency_ms: u16, snd_latency_ms: u16) -> Self {
        let requested = SrtFlags::TSBPD_SND | SrtFlags::TSBPD_RCV | SrtFlags::NAK_REPORT;
        Self {
            srt_version: SRT_VERSION_1_4,
            srt_flags: SrtFlags::new(requested),
            recv_tsbpd_delay_ms: recv_latency_ms,
            snd_tsbpd_delay_ms: snd_latency_ms,
        }
    }

    /// Checks the block against this module's protocol limits.
    ///
    /// # Errors
    ///
    /// Returns a description of the first problem found. The checks run in
    /// this order: a version older than [`SRT_VERSION_1_4`], then the receive
    /// latency, then the send latency. A latency fails when it is outside
    /// `MIN_LATENCY_MS..=MAX_LATENCY_MS`.
    pub fn validate(&self) -> Result<(), String> {
        if self.srt_version < SRT_VERSION_1_4 {
            return Err(format!(
                "unsupported SRT version 0x{:08X} (minimum 0x{:08X})",
                self.srt_version, SRT_VERSION_1_4
            ));
        }
        let allowed = MIN_LATENCY_MS..=MAX_LATENCY_MS;
        let latencies = [
            ("recv_tsbpd_delay_ms", self.recv_tsbpd_delay_ms),
            ("snd_tsbpd_delay_ms", self.snd_tsbpd_delay_ms),
        ];
        let offender = latencies
            .iter()
            .map(|&(name, ms)| (name, u32::from(ms)))
            .find(|(_, val)| !allowed.contains(val));
        match offender {
            Some((name, val)) => Err(format!(
                "{name} = {val} ms is out of [{MIN_LATENCY_MS}, {MAX_LATENCY_MS}] range"
            )),
            None => Ok(()),
        }
    }
}
/// SRT handshake-response extension block that the listener derives from a
/// caller's [`HsreqBlock`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HsrspBlock {
    /// Negotiated SRT version (the lower of the two sides' versions).
    pub srt_version: u32,
    /// Flags accepted from the caller's request.
    pub srt_flags: SrtFlags,
    /// Negotiated receive-side TSBPD delay, in milliseconds.
    pub recv_tsbpd_delay_ms: u16,
    /// Negotiated send-side TSBPD delay, in milliseconds.
    pub snd_tsbpd_delay_ms: u16,
}

impl HsrspBlock {
    /// Caller flags that the listener echoes back; all other requested bits are dropped.
    const NEGOTIABLE_FLAGS: u32 =
        SrtFlags::TSBPD_SND | SrtFlags::TSBPD_RCV | SrtFlags::NAK_REPORT;

    /// Builds the listener's response to `caller_req`.
    ///
    /// - version: the smaller of `listener_version` and the caller's version;
    /// - flags: the caller's flags restricted to TSBPD_SND, TSBPD_RCV and NAK_REPORT;
    /// - latencies: the larger of the caller's and the listener's value, per field.
    ///
    /// NOTE(review): latencies are paired field-by-field (caller recv with
    /// listener recv, caller snd with listener snd). If one side's send delay
    /// is meant to pair with the other side's receive delay, the pairing
    /// should be crossed. Confirm the intent.
    #[must_use]
    pub fn negotiate(
        caller_req: &HsreqBlock,
        listener_recv_latency_ms: u16,
        listener_snd_latency_ms: u16,
        listener_version: u32,
    ) -> Self {
        Self {
            srt_version: caller_req.srt_version.min(listener_version),
            srt_flags: SrtFlags::new(caller_req.srt_flags.bits() & Self::NEGOTIABLE_FLAGS),
            recv_tsbpd_delay_ms: listener_recv_latency_ms.max(caller_req.recv_tsbpd_delay_ms),
            snd_tsbpd_delay_ms: listener_snd_latency_ms.max(caller_req.snd_tsbpd_delay_ms),
        }
    }
}
/// Stream ID sent by a caller, limited to [`MAX_STREAM_ID_LEN`] bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamIdBlock {
    /// The stream identifier text.
    pub stream_id: String,
}

impl StreamIdBlock {
    /// Wraps `stream_id` after checking its length.
    ///
    /// # Errors
    ///
    /// Returns an error when the string is longer than
    /// [`MAX_STREAM_ID_LEN`] bytes (UTF-8 length, not character count).
    pub fn new(stream_id: impl Into<String>) -> Result<Self, String> {
        let stream_id: String = stream_id.into();
        let len = stream_id.len();
        if len <= MAX_STREAM_ID_LEN {
            Ok(Self { stream_id })
        } else {
            Err(format!(
                "stream ID length {len} exceeds maximum {MAX_STREAM_ID_LEN}"
            ))
        }
    }
}
/// A single SRT handshake packet as modelled by this module.
#[derive(Debug, Clone)]
pub struct HandshakePacket {
    /// Version word; [`HandshakePacket::validate`] requires [`SRT_MAGIC`].
    pub udt_version: u32,
    /// Encryption field; every constructor here sets it to 0.
    pub encryption_field: u16,
    /// `0x4A17` on induction packets; extension bits on conclusion packets.
    pub extension_field: u16,
    /// Initial packet sequence number.
    pub initial_packet_seq_no: u32,
    /// Maximum segment size.
    pub mss: u32,
    /// Maximum flow window size.
    pub max_flow_window_size: u32,
    /// Which handshake step this packet represents.
    pub handshake_type: HandshakeType,
    /// SRT socket ID of the side that built the packet.
    pub srt_socket_id: u32,
    /// SYN cookie echoed between induction and conclusion (0 when unused).
    pub syn_cookie: u32,
    /// Peer address; an IPv4 address occupies the first 4 bytes, the rest stay zero.
    pub peer_addr: [u8; 16],
}

impl HandshakePacket {
    /// Flow window advertised by every non-rejection constructor.
    const DEFAULT_FLOW_WINDOW: u32 = 8192;
    /// Extension-field value carried by induction requests and responses.
    const INDUCTION_EXTENSION: u16 = 0x4A17;
    /// Smallest MSS accepted by [`HandshakePacket::validate`].
    const MIN_MSS: u32 = 76;

    /// Common packet layout: [`SRT_MAGIC`] version, no encryption, the
    /// default flow window and an all-zero peer address.
    fn skeleton(
        handshake_type: HandshakeType,
        socket_id: u32,
        initial_seq: u32,
        mss: u32,
        syn_cookie: u32,
        extension_field: u16,
    ) -> Self {
        Self {
            udt_version: SRT_MAGIC,
            encryption_field: 0,
            extension_field,
            initial_packet_seq_no: initial_seq,
            mss,
            max_flow_window_size: Self::DEFAULT_FLOW_WINDOW,
            handshake_type,
            srt_socket_id: socket_id,
            syn_cookie,
            peer_addr: [0u8; 16],
        }
    }

    /// Caller's opening INDUCTION packet (no cookie yet).
    #[must_use]
    pub fn induction_request(socket_id: u32, initial_seq: u32, mss: u32) -> Self {
        Self::skeleton(
            HandshakeType::Induction,
            socket_id,
            initial_seq,
            mss,
            0,
            Self::INDUCTION_EXTENSION,
        )
    }

    /// Listener's WAVEAHAND reply carrying `syn_cookie` and the caller's IPv4 address.
    #[must_use]
    pub fn induction_response(
        socket_id: u32,
        initial_seq: u32,
        mss: u32,
        syn_cookie: u32,
        peer_addr_v4: [u8; 4],
    ) -> Self {
        let mut packet = Self::skeleton(
            HandshakeType::WaveaHand,
            socket_id,
            initial_seq,
            mss,
            syn_cookie,
            Self::INDUCTION_EXTENSION,
        );
        packet.peer_addr[..4].copy_from_slice(&peer_addr_v4);
        packet
    }

    /// CONCLUSION packet echoing `syn_cookie`, with the given extension bits.
    #[must_use]
    pub fn conclusion_request(
        socket_id: u32,
        initial_seq: u32,
        mss: u32,
        syn_cookie: u32,
        extension_field: u16,
    ) -> Self {
        Self::skeleton(
            HandshakeType::Conclusion,
            socket_id,
            initial_seq,
            mss,
            syn_cookie,
            extension_field,
        )
    }

    /// Rejection packet; every numeric field except the socket ID is zero.
    #[must_use]
    pub fn rejection(socket_id: u32, reason: RejectionReason) -> Self {
        Self {
            max_flow_window_size: 0,
            ..Self::skeleton(HandshakeType::Rejection(reason), socket_id, 0, 0, 0, 0)
        }
    }

    /// Performs basic sanity checks on the packet.
    ///
    /// # Errors
    ///
    /// Returns an error when `udt_version` is not [`SRT_MAGIC`]. Also returns
    /// an error when the MSS is below 76, except for WAVEAHAND and rejection
    /// packets, whose MSS is not checked.
    pub fn validate(&self) -> Result<(), String> {
        if self.udt_version != SRT_MAGIC {
            return Err(format!(
                "invalid UDT version 0x{:08X} (expected 0x{:08X})",
                self.udt_version, SRT_MAGIC
            ));
        }
        let mss_exempt = matches!(
            self.handshake_type,
            HandshakeType::WaveaHand | HandshakeType::Rejection(_)
        );
        if mss_exempt || self.mss >= Self::MIN_MSS {
            Ok(())
        } else {
            Err(format!("MSS {} is below minimum {}", self.mss, Self::MIN_MSS))
        }
    }
}
/// Progress of a [`CallerHandshake`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallerPhase {
    /// Nothing sent yet; only [`CallerHandshake::start`] is accepted.
    Idle,
    /// Induction request sent; waiting for the listener's WAVEAHAND.
    AwaitingInduction,
    /// Conclusion request sent; waiting for agreement, conclusion or rejection.
    AwaitingConclusion,
    /// Handshake completed successfully.
    Connected,
    /// The listener rejected the connection for the given reason.
    Failed(RejectionReason),
}
/// Caller-side state machine for the SRT induction/conclusion handshake.
#[derive(Debug)]
pub struct CallerHandshake {
    /// Local SRT socket ID placed in every outgoing packet.
    pub socket_id: u32,
    /// Initial packet sequence number placed in every outgoing packet.
    pub initial_seq: u32,
    /// MSS advertised in every outgoing packet.
    pub mss: u32,
    /// HSREQ built from the configured latencies.
    /// NOTE(review): no method of this type reads it; confirm who sends it.
    pub hsreq: HsreqBlock,
    /// Optional stream ID; its presence sets an extra extension bit in the conclusion.
    pub stream_id: Option<StreamIdBlock>,
    /// Current handshake phase.
    pub phase: CallerPhase,
    /// Cookie learned from the induction response and echoed in the conclusion.
    pub syn_cookie: u32,
}

impl CallerHandshake {
    /// Extension bit always set in the conclusion request (HSREQ present).
    const EXT_HSREQ: u16 = 1;
    /// Extension bit added when a stream ID accompanies the conclusion request.
    const EXT_STREAM_ID: u16 = 1 << 2;

    /// Creates an idle caller with an SRT 1.4 HSREQ for the given latencies.
    #[must_use]
    pub fn new(
        socket_id: u32,
        initial_seq: u32,
        mss: u32,
        recv_latency_ms: u16,
        snd_latency_ms: u16,
    ) -> Self {
        Self {
            socket_id,
            initial_seq,
            mss,
            hsreq: HsreqBlock::new(recv_latency_ms, snd_latency_ms),
            stream_id: None,
            phase: CallerPhase::Idle,
            syn_cookie: 0,
        }
    }

    /// Attaches a stream ID (builder style).
    #[must_use]
    pub fn with_stream_id(mut self, sid: StreamIdBlock) -> Self {
        self.stream_id = Some(sid);
        self
    }

    /// Begins the handshake and returns the INDUCTION packet to send.
    ///
    /// # Errors
    ///
    /// Returns an error if the caller is not in [`CallerPhase::Idle`].
    pub fn start(&mut self) -> Result<HandshakePacket, String> {
        if self.phase != CallerPhase::Idle {
            return Err(format!("cannot start in phase {:?}", self.phase));
        }
        self.phase = CallerPhase::AwaitingInduction;
        Ok(HandshakePacket::induction_request(
            self.socket_id,
            self.initial_seq,
            self.mss,
        ))
    }

    /// Handles the listener's WAVEAHAND and returns the CONCLUSION packet to send.
    ///
    /// Stores the listener's cookie and moves to
    /// [`CallerPhase::AwaitingConclusion`].
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the state unchanged, in any of these cases:
    /// the caller is not awaiting induction, the packet is not WAVEAHAND, or
    /// the packet fails [`HandshakePacket::validate`].
    pub fn on_induction_response(
        &mut self,
        response: &HandshakePacket,
    ) -> Result<HandshakePacket, String> {
        if self.phase != CallerPhase::AwaitingInduction {
            return Err(format!(
                "unexpected induction response in phase {:?}",
                self.phase
            ));
        }
        if response.handshake_type != HandshakeType::WaveaHand {
            return Err(format!(
                "expected WAVEAHAND, got {}",
                response.handshake_type
            ));
        }
        response.validate()?;
        self.syn_cookie = response.syn_cookie;
        self.phase = CallerPhase::AwaitingConclusion;
        let ext_field = if self.stream_id.is_some() {
            Self::EXT_HSREQ | Self::EXT_STREAM_ID
        } else {
            Self::EXT_HSREQ
        };
        Ok(HandshakePacket::conclusion_request(
            self.socket_id,
            self.initial_seq,
            self.mss,
            self.syn_cookie,
            ext_field,
        ))
    }

    /// Handles the listener's reply to the conclusion request.
    ///
    /// AGREEMENT or CONCLUSION moves the caller to [`CallerPhase::Connected`].
    /// No HSRSP is parsed from the reply, so success is always `Ok(None)`.
    ///
    /// # Errors
    ///
    /// - A rejection sets [`CallerPhase::Failed`] and returns an error naming
    ///   the reason.
    /// - Other cases return an error and leave the state unchanged: the caller
    ///   is not awaiting conclusion, the packet fails
    ///   [`HandshakePacket::validate`], or the handshake type is unexpected.
    pub fn on_conclusion_response(
        &mut self,
        response: &HandshakePacket,
    ) -> Result<Option<HsrspBlock>, String> {
        if self.phase != CallerPhase::AwaitingConclusion {
            return Err(format!(
                "unexpected conclusion response in phase {:?}",
                self.phase
            ));
        }
        // Same sanity check as the induction step; malformed replies are ignored.
        response.validate()?;
        match response.handshake_type {
            HandshakeType::Rejection(r) => {
                self.phase = CallerPhase::Failed(r);
                Err(format!("connection rejected: {r}"))
            }
            HandshakeType::Agreement | HandshakeType::Conclusion => {
                self.phase = CallerPhase::Connected;
                Ok(None)
            }
            other => Err(format!(
                "unexpected handshake type in conclusion response: {other}"
            )),
        }
    }

    /// Returns `true` once the handshake has completed successfully.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.phase == CallerPhase::Connected
    }
}
/// Progress of a [`ListenerHandshake`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ListenerPhase {
    /// Waiting for a caller's induction request.
    AwaitingInduction,
    /// Induction answered with a cookie; waiting for the conclusion request.
    AwaitingConclusion,
    /// Agreement issued; the handshake succeeded.
    Connected,
    /// The handshake was rejected for the given reason.
    Rejected(RejectionReason),
}
#[derive(Debug)]
pub struct ListenerHandshake {
pub socket_id: u32,
pub mss: u32,
pub recv_latency_ms: u16,
pub snd_latency_ms: u16,
pub srt_version: u32,
pub syn_cookie: u32,
pub phase: ListenerPhase,
pub negotiated: Option<HsrspBlock>,
}
impl ListenerHandshake {
#[must_use]
pub fn new(socket_id: u32, mss: u32, recv_latency_ms: u16, snd_latency_ms: u16) -> Self {
Self {
socket_id,
mss,
recv_latency_ms,
snd_latency_ms,
srt_version: SRT_VERSION_1_4,
syn_cookie: 0,
phase: ListenerPhase::AwaitingInduction,
negotiated: None,
}
}
pub fn on_induction_request(
&mut self,
request: &HandshakePacket,
cookie: u32,
caller_ipv4: [u8; 4],
) -> Result<HandshakePacket, String> {
if self.phase != ListenerPhase::AwaitingInduction {
return Err(format!("unexpected induction in phase {:?}", self.phase));
}
if request.handshake_type != HandshakeType::Induction {
return Err(format!(
"expected INDUCTION, got {}",
request.handshake_type
));
}
request.validate()?;
self.syn_cookie = cookie;
self.phase = ListenerPhase::AwaitingConclusion;
Ok(HandshakePacket::induction_response(
self.socket_id,
request.initial_packet_seq_no,
self.mss.min(request.mss),
cookie,
caller_ipv4,
))
}
pub fn on_conclusion_request(
&mut self,
request: &HandshakePacket,
caller_hsreq: &HsreqBlock,
) -> Result<HandshakePacket, String> {
if self.phase != ListenerPhase::AwaitingConclusion {
return Err(format!("unexpected conclusion in phase {:?}", self.phase));
}
if request.handshake_type != HandshakeType::Conclusion {
return Err(format!(
"expected CONCLUSION, got {}",
request.handshake_type
));
}
if request.syn_cookie != self.syn_cookie {
self.phase = ListenerPhase::Rejected(RejectionReason::BadRequest);
return Ok(HandshakePacket::rejection(
self.socket_id,
RejectionReason::BadRequest,
));
}
if caller_hsreq.srt_version < SRT_VERSION_1_4 {
self.phase = ListenerPhase::Rejected(RejectionReason::Version);
return Ok(HandshakePacket::rejection(
self.socket_id,
RejectionReason::Version,
));
}
if let Err(e) = caller_hsreq.validate() {
self.phase = ListenerPhase::Rejected(RejectionReason::BadRequest);
return Err(format!("caller HSREQ invalid: {e}"));
}
let hsrsp = HsrspBlock::negotiate(
caller_hsreq,
self.recv_latency_ms,
self.snd_latency_ms,
self.srt_version,
);
self.negotiated = Some(hsrsp);
self.phase = ListenerPhase::Connected;
let mut agreement = HandshakePacket::conclusion_request(
self.socket_id,
request.initial_packet_seq_no,
self.mss.min(request.mss),
self.syn_cookie,
0,
);
agreement.handshake_type = HandshakeType::Agreement;
Ok(agreement)
}
#[must_use]
pub fn is_connected(&self) -> bool {
self.phase == ListenerPhase::Connected
}
pub fn reject(&mut self, reason: RejectionReason) -> HandshakePacket {
self.phase = ListenerPhase::Rejected(reason);
HandshakePacket::rejection(self.socket_id, reason)
}
}
/// Derives a SYN cookie by hashing the four address bytes and then the
/// little-endian bytes of `epoch` with 32-bit FNV-1a.
///
/// NOTE(review): no secret is mixed into the hash. Anyone who knows the
/// caller's address and the epoch can compute a valid cookie, so the cookie
/// does not prove that the caller received the listener's reply. Consider
/// keying the hash with a server-side secret.
#[must_use]
pub fn compute_syn_cookie(caller_ipv4: [u8; 4], epoch: u32) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811C_9DC5;
    const FNV_PRIME: u32 = 0x0100_0193;
    let epoch_bytes = epoch.to_le_bytes();
    caller_ipv4
        .iter()
        .chain(epoch_bytes.iter())
        .fold(FNV_OFFSET_BASIS, |acc, &byte| {
            (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handshake_type_wire_roundtrip() {
        let types = [
            HandshakeType::Induction,
            HandshakeType::WaveaHand,
            HandshakeType::Conclusion,
            HandshakeType::Agreement,
            HandshakeType::Rejection(RejectionReason::Passphrase),
        ];
        for t in types {
            let wire = t.to_wire();
            let decoded = HandshakeType::from_wire(wire).expect("should decode");
            assert_eq!(t, decoded, "roundtrip failed for {t}");
        }
    }

    #[test]
    fn test_rejection_type_carries_reason() {
        let t = HandshakeType::Rejection(RejectionReason::Forbidden);
        let wire = t.to_wire();
        let decoded = HandshakeType::from_wire(wire).expect("should decode rejection");
        assert_eq!(decoded, HandshakeType::Rejection(RejectionReason::Forbidden));
    }

    #[test]
    fn test_unknown_wire_value_returns_none() {
        assert!(HandshakeType::from_wire(2).is_none());
    }

    #[test]
    fn test_rejection_reason_from_code() {
        assert_eq!(RejectionReason::from_code(6), RejectionReason::Passphrase);
        assert_eq!(RejectionReason::from_code(9), RejectionReason::Unauthorized);
        assert_eq!(
            RejectionReason::from_code(1000),
            RejectionReason::ApplicationDefined
        );
        assert_eq!(
            RejectionReason::from_code(9999),
            RejectionReason::ApplicationDefined
        );
    }

    #[test]
    fn test_rejection_reason_display() {
        assert_eq!(RejectionReason::Passphrase.to_string(), "PASSPHRASE");
        assert_eq!(RejectionReason::Version.to_string(), "VERSION");
    }

    #[test]
    fn test_srt_flags_set_and_has() {
        let mut flags = SrtFlags::new(0);
        assert!(!flags.has(SrtFlags::TSBPD_SND));
        flags.set(SrtFlags::TSBPD_SND);
        assert!(flags.has(SrtFlags::TSBPD_SND));
        flags.clear(SrtFlags::TSBPD_SND);
        assert!(!flags.has(SrtFlags::TSBPD_SND));
    }

    #[test]
    fn test_srt_flags_multiple_bits() {
        let flags = SrtFlags::new(SrtFlags::TSBPD_SND | SrtFlags::NAK_REPORT);
        assert!(flags.has(SrtFlags::TSBPD_SND));
        assert!(flags.has(SrtFlags::NAK_REPORT));
        assert!(!flags.has(SrtFlags::HAICRYPT_OFF));
    }

    #[test]
    fn test_hsreq_validate_ok() {
        let hsreq = HsreqBlock::new(120, 120);
        assert!(hsreq.validate().is_ok());
    }

    #[test]
    fn test_hsreq_validate_latency_too_low() {
        let hsreq = HsreqBlock::new(10, 120);
        assert!(hsreq.validate().is_err());
    }

    #[test]
    fn test_hsreq_validate_latency_too_high() {
        let hsreq = HsreqBlock::new(120, 30_001);
        assert!(hsreq.validate().is_err());
    }

    #[test]
    fn test_hsreq_validate_latency_bounds_inclusive() {
        assert!(HsreqBlock::new(20, 30_000).validate().is_ok());
    }

    #[test]
    fn test_hsreq_validate_old_version() {
        let mut hsreq = HsreqBlock::new(120, 120);
        hsreq.srt_version = 0x0001_0300;
        assert!(hsreq.validate().is_err());
    }

    #[test]
    fn test_hsrsp_negotiation_takes_max_latency() {
        let caller_req = HsreqBlock::new(200, 100);
        let hsrsp = HsrspBlock::negotiate(&caller_req, 300, 150, SRT_VERSION_1_4);
        assert_eq!(hsrsp.recv_tsbpd_delay_ms, 300);
        assert_eq!(hsrsp.snd_tsbpd_delay_ms, 150);
    }

    #[test]
    fn test_hsrsp_negotiation_uses_min_version() {
        let caller_req = HsreqBlock {
            srt_version: SRT_VERSION_1_5,
            ..HsreqBlock::new(120, 120)
        };
        let hsrsp = HsrspBlock::negotiate(&caller_req, 120, 120, SRT_VERSION_1_4);
        assert_eq!(hsrsp.srt_version, SRT_VERSION_1_4);
    }

    #[test]
    fn test_stream_id_valid() {
        let block = StreamIdBlock::new("my-camera-feed").expect("should construct");
        assert_eq!(block.stream_id, "my-camera-feed");
    }

    #[test]
    fn test_stream_id_too_long() {
        let too_long = "x".repeat(MAX_STREAM_ID_LEN + 1);
        assert!(StreamIdBlock::new(too_long).is_err());
    }

    #[test]
    fn test_induction_request_validates() {
        let pkt = HandshakePacket::induction_request(42, 1000, 1500);
        assert!(pkt.validate().is_ok());
        assert_eq!(pkt.handshake_type, HandshakeType::Induction);
        assert_eq!(pkt.udt_version, SRT_MAGIC);
    }

    #[test]
    fn test_induction_response_carries_cookie() {
        let pkt =
            HandshakePacket::induction_response(99, 1000, 1500, 0xDEAD_BEEF, [192, 168, 1, 1]);
        assert_eq!(pkt.syn_cookie, 0xDEAD_BEEF);
        assert_eq!(pkt.peer_addr[..4], [192, 168, 1, 1]);
        assert_eq!(pkt.handshake_type, HandshakeType::WaveaHand);
    }

    #[test]
    fn test_rejection_packet_valid_type() {
        let pkt = HandshakePacket::rejection(7, RejectionReason::Overloaded);
        assert_eq!(
            pkt.handshake_type,
            HandshakeType::Rejection(RejectionReason::Overloaded)
        );
    }

    #[test]
    fn test_full_handshake_roundtrip() {
        // Arguments: socket ID, initial seq, MSS, recv latency, snd latency.
        let mut caller = CallerHandshake::new(0xAABB_CCDD, 12345, 1500, 120, 120);
        // Arguments: socket ID, MSS, recv latency, snd latency.
        let mut listener = ListenerHandshake::new(0x1122_3344, 1500, 200, 150);

        let induction_req = caller.start().expect("start should succeed");
        assert_eq!(caller.phase, CallerPhase::AwaitingInduction);

        let cookie = compute_syn_cookie([10, 0, 0, 1], 42);
        let induction_resp = listener
            .on_induction_request(&induction_req, cookie, [10, 0, 0, 1])
            .expect("listener should accept induction");
        assert_eq!(listener.phase, ListenerPhase::AwaitingConclusion);
        assert_eq!(induction_resp.syn_cookie, cookie);

        let conclusion_req = caller
            .on_induction_response(&induction_resp)
            .expect("caller should accept induction response");
        assert_eq!(caller.phase, CallerPhase::AwaitingConclusion);
        assert_eq!(conclusion_req.syn_cookie, cookie);

        let hsreq = HsreqBlock::new(120, 120);
        let agreement = listener
            .on_conclusion_request(&conclusion_req, &hsreq)
            .expect("listener should send agreement");
        assert_eq!(agreement.handshake_type, HandshakeType::Agreement);
        assert!(listener.is_connected());
        let negotiated = listener.negotiated.as_ref().expect("HSRSP should be stored");
        assert_eq!(negotiated.recv_tsbpd_delay_ms, 200);
        assert_eq!(negotiated.snd_tsbpd_delay_ms, 150);

        caller
            .on_conclusion_response(&agreement)
            .expect("caller should accept agreement");
        assert!(caller.is_connected());
    }

    #[test]
    fn test_handshake_rejection_on_bad_cookie() {
        let mut listener = ListenerHandshake::new(1, 1500, 120, 120);
        let induction_req = HandshakePacket::induction_request(2, 1000, 1500);
        let real_cookie = 0xCAFE_BABE;
        listener
            .on_induction_request(&induction_req, real_cookie, [127, 0, 0, 1])
            .expect("induction ok");
        let bad_conclusion = HandshakePacket::conclusion_request(2, 1000, 1500, 0xDEAD_BEEF, 1);
        let hsreq = HsreqBlock::new(120, 120);
        let rejection = listener
            .on_conclusion_request(&bad_conclusion, &hsreq)
            .expect("should return rejection packet");
        assert_eq!(
            rejection.handshake_type,
            HandshakeType::Rejection(RejectionReason::BadRequest)
        );
        assert_eq!(
            listener.phase,
            ListenerPhase::Rejected(RejectionReason::BadRequest)
        );
    }

    #[test]
    fn test_listener_rejects_old_version() {
        let mut listener = ListenerHandshake::new(1, 1500, 120, 120);
        let induction_req = HandshakePacket::induction_request(2, 1000, 1500);
        let cookie = 0x0BAD_F00D;
        listener
            .on_induction_request(&induction_req, cookie, [127, 0, 0, 1])
            .expect("induction ok");
        let conclusion = HandshakePacket::conclusion_request(2, 1000, 1500, cookie, 1);
        let hsreq = HsreqBlock {
            srt_version: 0x0001_0300,
            ..HsreqBlock::new(120, 120)
        };
        let reply = listener
            .on_conclusion_request(&conclusion, &hsreq)
            .expect("should return rejection packet");
        assert_eq!(
            reply.handshake_type,
            HandshakeType::Rejection(RejectionReason::Version)
        );
        assert_eq!(listener.phase, ListenerPhase::Rejected(RejectionReason::Version));
    }

    #[test]
    fn test_caller_start_twice_fails() {
        let mut caller = CallerHandshake::new(1, 0, 1500, 120, 120);
        caller.start().expect("first start should succeed");
        assert!(caller.start().is_err());
    }

    #[test]
    fn test_caller_records_rejection() {
        let mut caller = CallerHandshake::new(5, 100, 1500, 120, 120);
        caller.start().expect("start");
        let resp = HandshakePacket::induction_response(9, 100, 1500, 77, [127, 0, 0, 1]);
        caller
            .on_induction_response(&resp)
            .expect("induction response ok");
        let rej = HandshakePacket::rejection(9, RejectionReason::Forbidden);
        assert!(caller.on_conclusion_response(&rej).is_err());
        assert_eq!(caller.phase, CallerPhase::Failed(RejectionReason::Forbidden));
        assert!(!caller.is_connected());
    }

    #[test]
    fn test_conclusion_advertises_stream_id_extension() {
        let sid = StreamIdBlock::new("cam").expect("valid stream ID");
        let mut caller = CallerHandshake::new(5, 100, 1500, 120, 120).with_stream_id(sid);
        caller.start().expect("start");
        let resp = HandshakePacket::induction_response(9, 100, 1500, 77, [127, 0, 0, 1]);
        let conclusion = caller
            .on_induction_response(&resp)
            .expect("induction response ok");
        assert_eq!(conclusion.extension_field, 0b101);
    }

    #[test]
    fn test_syn_cookie_deterministic() {
        let ip = [192, 168, 1, 100];
        let c1 = compute_syn_cookie(ip, 100);
        let c2 = compute_syn_cookie(ip, 100);
        assert_eq!(c1, c2);
    }

    #[test]
    fn test_syn_cookie_different_ip_different_result() {
        let c1 = compute_syn_cookie([10, 0, 0, 1], 50);
        let c2 = compute_syn_cookie([10, 0, 0, 2], 50);
        assert_ne!(c1, c2);
    }

    #[test]
    fn test_syn_cookie_different_epoch_different_result() {
        let ip = [172, 16, 0, 1];
        let c1 = compute_syn_cookie(ip, 0);
        let c2 = compute_syn_cookie(ip, 1);
        assert_ne!(c1, c2);
    }

    #[test]
    fn test_handshake_type_display() {
        assert_eq!(HandshakeType::Induction.to_string(), "INDUCTION");
        assert_eq!(HandshakeType::WaveaHand.to_string(), "WAVEAHAND");
        assert_eq!(HandshakeType::Conclusion.to_string(), "CONCLUSION");
        assert_eq!(HandshakeType::Agreement.to_string(), "AGREEMENT");
        let rej = HandshakeType::Rejection(RejectionReason::Forbidden);
        assert_eq!(rej.to_string(), "REJECTION(FORBIDDEN)");
    }
}