use super::packet::SrtPacket;
use super::socket::{ConnectionState, SrtConfig, SrtSocket};
use crate::error::{NetError, NetResult};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant, SystemTime};
/// Which side of the SRT handshake a socket plays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionMode {
    /// Initiates the handshake towards a known peer address.
    Caller,
    /// Accepts handshakes arriving on a bound local address.
    Listener,
    /// Both peers send handshakes to each other.
    Rendezvous,
}

impl ConnectionMode {
    /// Returns the lowercase label for this mode (e.g. `"caller"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match *self {
            ConnectionMode::Caller => "caller",
            ConnectionMode::Listener => "listener",
            ConnectionMode::Rendezvous => "rendezvous",
        }
    }
}

impl std::fmt::Display for ConnectionMode {
    // Writes the bare label; width/alignment flags are intentionally not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.name();
        f.write_str(label)
    }
}
/// Handshake driver for the caller side of an SRT connection.
///
/// Wraps an [`SrtSocket`] and tracks how many handshake retries have been sent
/// towards `peer_addr` and when the most recent handshake went out.
#[derive(Debug)]
pub struct CallerState {
    /// Socket that builds handshakes and processes the peer's responses.
    socket: SrtSocket,
    /// Always [`ConnectionMode::Caller`].
    mode: ConnectionMode,
    /// Address of the peer being called.
    peer_addr: SocketAddr,
    /// Retries produced so far by `retry_handshake`.
    retry_count: u32,
    /// Retry budget; `retry_handshake` returns `None` once it is spent.
    max_retries: u32,
    /// When the most recent handshake was generated, if any.
    last_send: Option<Instant>,
    /// Minimum spacing between handshake sends.
    retry_interval: Duration,
    /// When this state was created.
    started_at: Instant,
}

impl CallerState {
    /// Creates a caller targeting `peer_addr` with a fresh socket built from `config`.
    ///
    /// Defaults: up to 10 retries, spaced at least 250 ms apart.
    #[must_use]
    pub fn new(config: SrtConfig, peer_addr: SocketAddr) -> Self {
        Self {
            socket: SrtSocket::new(config),
            mode: ConnectionMode::Caller,
            peer_addr,
            retry_count: 0,
            max_retries: 10,
            last_send: None,
            retry_interval: Duration::from_millis(250),
            started_at: Instant::now(),
        }
    }

    /// Overrides the retry budget used by [`retry_handshake`](Self::retry_handshake)
    /// and [`needs_retry`](Self::needs_retry).
    pub fn set_max_retries(&mut self, max: u32) {
        self.max_retries = max;
    }

    /// Builds the first handshake packet and records the send time.
    ///
    /// This does not count against the retry budget.
    #[must_use]
    pub fn generate_initial_handshake(&mut self) -> SrtPacket {
        self.last_send = Some(Instant::now());
        self.socket.generate_caller_handshake()
    }

    /// Feeds a packet received from the peer into the socket and returns the
    /// packets the socket produced in response.
    ///
    /// # Errors
    ///
    /// Propagates any error from the socket's packet processing.
    pub fn process_response(&mut self, packet: SrtPacket) -> NetResult<Vec<SrtPacket>> {
        self.socket.process_packet(packet)
    }

    /// Returns `true` if the underlying socket reports a connection.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.socket.is_connected()
    }

    /// Always [`ConnectionMode::Caller`].
    #[must_use]
    pub const fn mode(&self) -> ConnectionMode {
        self.mode
    }

    /// Address of the peer being called.
    #[must_use]
    pub const fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Returns `true` when a handshake should be (re)sent now.
    ///
    /// This is `false` once the socket is connected or the retry budget is
    /// spent, so it never reports `true` when [`retry_handshake`](Self::retry_handshake)
    /// would return `None`. Otherwise it is `true` if nothing has been sent yet,
    /// or if more than `retry_interval` has passed since the last send.
    #[must_use]
    pub fn needs_retry(&self) -> bool {
        // Check the budget before the "nothing sent yet" case (mirroring
        // `RendezvousState::needs_wave`); previously a spent budget still
        // yielded `true` here, letting a retry loop spin forever.
        if self.socket.is_connected() || self.retry_count >= self.max_retries {
            return false;
        }
        match self.last_send {
            Some(t) => t.elapsed() > self.retry_interval,
            None => true,
        }
    }

    /// Builds another handshake packet if the retry budget allows.
    ///
    /// On success the retry counter is incremented and the send time recorded;
    /// returns `None` once `max_retries` retries have been produced.
    pub fn retry_handshake(&mut self) -> Option<SrtPacket> {
        if self.retry_count >= self.max_retries {
            return None;
        }
        self.retry_count += 1;
        self.last_send = Some(Instant::now());
        Some(self.socket.generate_caller_handshake())
    }

    /// Time elapsed since this state was created.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Shared access to the underlying socket.
    #[must_use]
    pub fn socket(&self) -> &SrtSocket {
        &self.socket
    }

    /// Mutable access to the underlying socket.
    pub fn socket_mut(&mut self) -> &mut SrtSocket {
        &mut self.socket
    }
}
/// A peer tracked by [`ListenerState`], keyed by its source address.
///
/// Despite the name, entries remain in the listener's table after the handshake
/// completes (`stage == 2`); `cleanup_pending` only evicts entries that are
/// still handshaking and older than the pending timeout.
#[derive(Debug)]
pub struct PendingConnection {
    /// Source address of the peer.
    pub addr: SocketAddr,
    /// Per-peer socket that processes this peer's packets.
    pub socket: SrtSocket,
    /// Handshake progress: `1` while handshaking, `2` once the socket reports connected.
    pub stage: u8,
    /// When this entry was created; used for pending-timeout eviction.
    pub created_at: Instant,
    /// Cookie derived from the peer address and the time of creation.
    /// NOTE(review): stored but not validated anywhere in this module.
    pub syn_cookie: u32,
}
/// Server-side state for an SRT listener bound to a local address.
///
/// Every peer gets its own [`SrtSocket`], stored in a table keyed by the peer's
/// source address. Entries stay in that table after their handshake completes.
#[derive(Debug)]
pub struct ListenerState {
    /// Template configuration cloned into each new per-peer socket.
    config: SrtConfig,
    /// Always [`ConnectionMode::Listener`].
    mode: ConnectionMode,
    /// Local address the listener is bound to.
    bind_addr: SocketAddr,
    /// Per-peer state, both in-progress handshakes and established connections.
    pending: HashMap<SocketAddr, PendingConnection>,
    /// Peers whose handshake completed, in completion order (one entry per peer).
    established: Vec<SocketAddr>,
    /// Cap on handshakes that may be in progress at the same time.
    max_pending: usize,
    /// Age after which an unfinished handshake is dropped by `cleanup_pending`.
    pending_timeout: Duration,
    /// Number of handshakes that have completed.
    total_accepted: u64,
}

impl ListenerState {
    /// Creates a listener on `bind_addr` with no known peers.
    ///
    /// Defaults: at most 128 in-progress handshakes, 5 s pending timeout.
    #[must_use]
    pub fn new(config: SrtConfig, bind_addr: SocketAddr) -> Self {
        Self {
            config,
            mode: ConnectionMode::Listener,
            bind_addr,
            pending: HashMap::new(),
            established: Vec::new(),
            max_pending: 128,
            pending_timeout: Duration::from_secs(5),
            total_accepted: 0,
        }
    }

    /// Sets the maximum number of handshakes that may be in progress at once.
    pub fn set_max_pending(&mut self, max: usize) {
        self.max_pending = max;
    }

    /// Feeds a packet received from `from` into that peer's socket.
    ///
    /// Packets from a known peer go to its existing socket. A packet from an
    /// unknown peer creates a new socket (cloned from the listener config) and
    /// registers the peer. A peer is recorded as established, and counted in
    /// [`total_accepted`](Self::total_accepted), exactly once: on the packet that
    /// brings its socket into the connected state.
    ///
    /// # Errors
    ///
    /// Returns an error if the peer's socket rejects the packet, or if `from` is
    /// a new peer while `max_pending` handshakes are already in progress.
    pub fn process_incoming(
        &mut self,
        from: SocketAddr,
        packet: SrtPacket,
    ) -> NetResult<Vec<SrtPacket>> {
        if let Some(pending) = self.pending.get_mut(&from) {
            let responses = pending.socket.process_packet(packet)?;
            // Only the packet that completes the handshake registers the peer;
            // later packets on an established connection must not push a
            // duplicate address or inflate `total_accepted`.
            if pending.stage < 2 && pending.socket.is_connected() {
                pending.stage = 2;
                self.established.push(from);
                self.total_accepted += 1;
            }
            Ok(responses)
        } else {
            // Established peers also live in `self.pending` and are never evicted,
            // so count only handshakes still in progress; otherwise the cap would
            // eventually reject every new peer.
            if self.pending_count() >= self.max_pending {
                return Err(NetError::connection("Max pending connections reached"));
            }
            let mut socket = SrtSocket::new(self.config.clone());
            let responses = socket.process_packet(packet)?;
            let syn_cookie = generate_listener_cookie(&from);
            let stage = if socket.is_connected() { 2 } else { 1 };
            let conn = PendingConnection {
                addr: from,
                socket,
                stage,
                created_at: Instant::now(),
                syn_cookie,
            };
            if stage == 2 {
                self.established.push(from);
                self.total_accepted += 1;
            }
            self.pending.insert(from, conn);
            Ok(responses)
        }
    }

    /// Drops handshakes that are still in progress and older than the pending
    /// timeout. Established entries are kept.
    pub fn cleanup_pending(&mut self) {
        let timeout = self.pending_timeout;
        self.pending
            .retain(|_, conn| conn.stage >= 2 || conn.created_at.elapsed() < timeout);
    }

    /// Number of handshakes still in progress (entries with `stage < 2`).
    #[must_use]
    pub fn pending_count(&self) -> usize {
        self.pending.iter().filter(|(_, c)| c.stage < 2).count()
    }

    /// Number of peers whose handshake has completed.
    #[must_use]
    pub fn established_count(&self) -> usize {
        self.established.len()
    }

    /// Total number of handshakes that have completed.
    #[must_use]
    pub fn total_accepted(&self) -> u64 {
        self.total_accepted
    }

    /// Local address the listener is bound to.
    #[must_use]
    pub const fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// Always [`ConnectionMode::Listener`].
    #[must_use]
    pub const fn mode(&self) -> ConnectionMode {
        self.mode
    }

    /// Looks up the tracked state for `addr`, whether in progress or established.
    #[must_use]
    pub fn get_pending(&self, addr: &SocketAddr) -> Option<&PendingConnection> {
        self.pending.get(addr)
    }

    /// Addresses of peers whose handshake has completed, in completion order.
    #[must_use]
    pub fn established_addrs(&self) -> &[SocketAddr] {
        &self.established
    }
}
/// Progress of a rendezvous handshake, as tracked by `RendezvousState`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RendezvousPhase {
    /// Initial phase: waves are being sent and no packet has been processed yet.
    Waving,
    /// A packet has been processed while the socket is still handshaking.
    Attention,
    /// A further packet has been processed while the socket is still handshaking.
    Fine,
    /// The underlying socket reports a connection.
    Connected,
    /// Failure phase. NOTE(review): nothing in this module currently enters it.
    Failed,
}

impl RendezvousPhase {
    /// Returns the lowercase label for this phase (e.g. `"waving"`).
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match *self {
            RendezvousPhase::Waving => "waving",
            RendezvousPhase::Attention => "attention",
            RendezvousPhase::Fine => "fine",
            RendezvousPhase::Connected => "connected",
            RendezvousPhase::Failed => "failed",
        }
    }
}
/// Handshake driver for SRT rendezvous mode, in which both peers send
/// handshakes ("waves") towards each other.
#[derive(Debug)]
pub struct RendezvousState {
    /// Socket that builds outgoing waves and processes the peer's packets.
    socket: SrtSocket,
    /// Always [`ConnectionMode::Rendezvous`].
    mode: ConnectionMode,
    /// Address of the remote rendezvous peer.
    peer_addr: SocketAddr,
    /// Current handshake phase; advanced by `process_packet`.
    phase: RendezvousPhase,
    /// Number of waves generated so far.
    wave_count: u32,
    /// Wave budget; reaching it without connecting counts as a timeout.
    max_wave_count: u32,
    /// Minimum spacing between waves.
    wave_interval: Duration,
    /// When the most recent wave was generated, if any.
    last_wave: Option<Instant>,
    /// When this state was created.
    started_at: Instant,
    /// NOTE(review): initialised to `None` and never updated in this module.
    peer_socket_id: Option<u32>,
}

impl RendezvousState {
    /// Creates a rendezvous driver aimed at `peer_addr` with a fresh socket.
    ///
    /// Starts in [`RendezvousPhase::Waving`] with a budget of 25 waves spaced at
    /// least 250 ms apart.
    #[must_use]
    pub fn new(config: SrtConfig, peer_addr: SocketAddr) -> Self {
        let socket = SrtSocket::new(config);
        Self {
            socket,
            mode: ConnectionMode::Rendezvous,
            peer_addr,
            phase: RendezvousPhase::Waving,
            wave_count: 0,
            max_wave_count: 25,
            wave_interval: Duration::from_millis(250),
            last_wave: None,
            started_at: Instant::now(),
            peer_socket_id: None,
        }
    }

    /// Builds the next wave (a caller-style handshake), counting it against the
    /// budget and recording when it was generated.
    #[must_use]
    pub fn generate_wave(&mut self) -> SrtPacket {
        self.last_wave = Some(Instant::now());
        self.wave_count += 1;
        self.socket.generate_caller_handshake()
    }

    /// Feeds a packet from the peer into the socket and advances the phase.
    ///
    /// If the socket is now connected, the phase becomes `Connected`. Otherwise,
    /// while the socket is still handshaking, the phase steps
    /// `Waving -> Attention -> Fine`, one step per packet. Other phases are left
    /// as they are.
    ///
    /// # Errors
    ///
    /// Propagates any error from the socket's packet processing; the phase is
    /// not touched in that case.
    pub fn process_packet(&mut self, packet: SrtPacket) -> NetResult<Vec<SrtPacket>> {
        let outgoing = self.socket.process_packet(packet)?;
        if self.socket.is_connected() {
            self.phase = RendezvousPhase::Connected;
        } else if self.socket.state() == ConnectionState::Handshaking {
            self.phase = match self.phase {
                RendezvousPhase::Waving => RendezvousPhase::Attention,
                RendezvousPhase::Attention => RendezvousPhase::Fine,
                unchanged @ (RendezvousPhase::Fine
                | RendezvousPhase::Connected
                | RendezvousPhase::Failed) => unchanged,
            };
        }
        Ok(outgoing)
    }

    /// Returns `true` once the phase has reached `Connected`.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        matches!(self.phase, RendezvousPhase::Connected)
    }

    /// Current handshake phase.
    #[must_use]
    pub const fn phase(&self) -> RendezvousPhase {
        self.phase
    }

    /// Always [`ConnectionMode::Rendezvous`].
    #[must_use]
    pub const fn mode(&self) -> ConnectionMode {
        self.mode
    }

    /// Address of the remote rendezvous peer.
    #[must_use]
    pub const fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Returns `true` when another wave should be sent now.
    ///
    /// This is `false` once connected or once the wave budget is spent. Otherwise
    /// it is `true` if no wave has been sent yet, or if more than `wave_interval`
    /// has passed since the last one.
    #[must_use]
    pub fn needs_wave(&self) -> bool {
        let budget_spent = self.wave_count >= self.max_wave_count;
        if budget_spent || self.is_connected() {
            return false;
        }
        if let Some(sent_at) = self.last_wave {
            sent_at.elapsed() > self.wave_interval
        } else {
            true
        }
    }

    /// Time elapsed since this state was created.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Returns `true` if the wave budget is spent without having connected.
    #[must_use]
    pub fn is_timed_out(&self) -> bool {
        !self.is_connected() && self.wave_count >= self.max_wave_count
    }

    /// Shared access to the underlying socket.
    #[must_use]
    pub fn socket(&self) -> &SrtSocket {
        &self.socket
    }

    /// Mutable access to the underlying socket.
    pub fn socket_mut(&mut self) -> &mut SrtSocket {
        &mut self.socket
    }
}
/// Derives the SYN cookie a listener records for a newly seen peer.
///
/// Mixes a hash of the peer's IP address and port with the current Unix time in
/// whole seconds and a fixed constant. The result is therefore stable only
/// within a single wall-clock second. If the system clock is before the Unix
/// epoch, the time component is `0`.
fn generate_listener_cookie(addr: &SocketAddr) -> u32 {
    let port = u32::from(addr.port());
    let ip_hash = match addr {
        SocketAddr::V4(a) => u32::from_be_bytes(a.ip().octets()),
        // Fold every 32-bit word of the address. Using only the first word made
        // all peers sharing a /32 prefix (e.g. every `::x` address) share a seed.
        // Rotating between words keeps identical words at different positions
        // from cancelling each other out.
        SocketAddr::V6(a) => a
            .ip()
            .octets()
            .chunks_exact(4)
            .map(|w| u32::from_be_bytes([w[0], w[1], w[2], w[3]]))
            .fold(0u32, |acc, word| acc.rotate_left(8) ^ word),
    };
    let seed = ip_hash ^ (port << 16) ^ port;
    let now = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        // Truncating to 32 bits is intentional: the value is only mixed into a hash.
        .map(|d| d.as_secs() as u32)
        .unwrap_or(0);
    seed ^ now ^ 0xBEEF_CAFE
}
#[cfg(test)]
mod tests {
    use super::*;

    fn test_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().expect("valid addr")
    }

    fn test_addr2() -> SocketAddr {
        "127.0.0.1:9001".parse().expect("valid addr")
    }

    #[test]
    fn test_connection_mode_display() {
        assert_eq!(ConnectionMode::Caller.name(), "caller");
        assert_eq!(ConnectionMode::Listener.name(), "listener");
        assert_eq!(ConnectionMode::Rendezvous.name(), "rendezvous");
        assert_eq!(format!("{}", ConnectionMode::Caller), "caller");
    }

    #[test]
    fn test_caller_state_new() {
        let state = CallerState::new(SrtConfig::default(), test_addr());
        assert_eq!(state.mode(), ConnectionMode::Caller);
        assert_eq!(state.peer_addr(), test_addr());
        assert!(!state.is_connected());
    }

    #[test]
    fn test_caller_initial_handshake() {
        let mut state = CallerState::new(SrtConfig::default(), test_addr());
        let pkt = state.generate_initial_handshake();
        assert!(pkt.is_control());
    }

    #[test]
    fn test_caller_needs_retry() {
        let state = CallerState::new(SrtConfig::default(), test_addr());
        assert!(state.needs_retry());
    }

    #[test]
    fn test_caller_retry_limit() {
        let mut state = CallerState::new(SrtConfig::default(), test_addr());
        state.set_max_retries(2);
        state.retry_handshake();
        state.retry_handshake();
        assert!(state.retry_handshake().is_none());
    }

    #[test]
    fn test_caller_elapsed() {
        let state = CallerState::new(SrtConfig::default(), test_addr());
        // Freshly created, so only a moment should have passed.
        assert!(state.elapsed() < Duration::from_secs(60));
    }

    #[test]
    fn test_listener_state_new() {
        let state = ListenerState::new(SrtConfig::default(), test_addr());
        assert_eq!(state.mode(), ConnectionMode::Listener);
        assert_eq!(state.bind_addr(), test_addr());
        assert_eq!(state.pending_count(), 0);
        assert_eq!(state.established_count(), 0);
    }

    #[test]
    fn test_listener_process_incoming() {
        let mut listener = ListenerState::new(SrtConfig::default(), test_addr());
        let mut caller_socket = SrtSocket::new(SrtConfig::default());
        let handshake = caller_socket.generate_caller_handshake();
        let responses = listener
            .process_incoming(test_addr2(), handshake)
            .expect("should process");
        assert!(listener.get_pending(&test_addr2()).is_some());
        assert!(!responses.is_empty());
    }

    #[test]
    fn test_listener_max_pending() {
        let mut listener = ListenerState::new(SrtConfig::default(), test_addr());
        listener.set_max_pending(1);
        let mut s1 = SrtSocket::new(SrtConfig::default());
        let h1 = s1.generate_caller_handshake();
        listener
            .process_incoming(test_addr2(), h1)
            .expect("should work");
        let mut s2 = SrtSocket::new(SrtConfig::default());
        let h2 = s2.generate_caller_handshake();
        let addr3: SocketAddr = "127.0.0.1:9002".parse().expect("valid");
        let result = listener.process_incoming(addr3, h2);
        assert!(result.is_err());
    }

    #[test]
    fn test_listener_cleanup() {
        let mut listener = ListenerState::new(SrtConfig::default(), test_addr());
        listener.cleanup_pending();
        assert_eq!(listener.pending_count(), 0);
    }

    #[test]
    fn test_listener_total_accepted() {
        let listener = ListenerState::new(SrtConfig::default(), test_addr());
        assert_eq!(listener.total_accepted(), 0);
    }

    #[test]
    fn test_rendezvous_phase_names() {
        assert_eq!(RendezvousPhase::Waving.name(), "waving");
        assert_eq!(RendezvousPhase::Attention.name(), "attention");
        assert_eq!(RendezvousPhase::Fine.name(), "fine");
        assert_eq!(RendezvousPhase::Connected.name(), "connected");
        assert_eq!(RendezvousPhase::Failed.name(), "failed");
    }

    #[test]
    fn test_rendezvous_state_new() {
        let state = RendezvousState::new(SrtConfig::default(), test_addr());
        assert_eq!(state.mode(), ConnectionMode::Rendezvous);
        assert_eq!(state.phase(), RendezvousPhase::Waving);
        assert!(!state.is_connected());
    }

    #[test]
    fn test_rendezvous_wave() {
        let mut state = RendezvousState::new(SrtConfig::default(), test_addr());
        let pkt = state.generate_wave();
        assert!(pkt.is_control());
    }

    #[test]
    fn test_rendezvous_needs_wave() {
        let state = RendezvousState::new(SrtConfig::default(), test_addr());
        assert!(state.needs_wave());
    }

    #[test]
    fn test_rendezvous_timeout() {
        let mut state = RendezvousState::new(SrtConfig::default(), test_addr());
        state.max_wave_count = 3;
        for _ in 0..3 {
            let _ = state.generate_wave();
        }
        assert!(state.is_timed_out());
    }

    #[test]
    fn test_rendezvous_elapsed() {
        let state = RendezvousState::new(SrtConfig::default(), test_addr());
        // Freshly created, so only a moment should have passed.
        assert!(state.elapsed() < Duration::from_secs(60));
    }

    #[test]
    fn test_listener_cookie() {
        let addr = test_addr();
        // The cookie mixes in wall-clock seconds, so two samples may legitimately
        // differ if they straddle a second boundary. Three back-to-back samples
        // straddle at most one boundary, so at least one adjacent pair must match.
        let c1 = generate_listener_cookie(&addr);
        let c2 = generate_listener_cookie(&addr);
        let c3 = generate_listener_cookie(&addr);
        assert!(c1 == c2 || c2 == c3);
    }

    #[test]
    fn test_listener_cookie_different_addrs() {
        // The port seeds differ by 0x0001_0001 (9000 vs 9001). A one-second tick
        // between the calls changes the time component by a value of the form
        // 2^k - 1, which cannot cancel that difference.
        let c1 = generate_listener_cookie(&test_addr());
        let c2 = generate_listener_cookie(&test_addr2());
        assert_ne!(c1, c2);
        assert_ne!(c1, 0);
        assert_ne!(c2, 0);
    }

    #[test]
    fn test_ipv6_cookie() {
        let addr: SocketAddr = "[::1]:9000".parse().expect("valid");
        let cookie = generate_listener_cookie(&addr);
        assert_ne!(cookie, 0);
    }

    #[test]
    fn test_caller_socket_access() {
        let state = CallerState::new(SrtConfig::default(), test_addr());
        assert_eq!(state.socket().state(), ConnectionState::Initial);
    }

    #[test]
    fn test_rendezvous_socket_access() {
        let state = RendezvousState::new(SrtConfig::default(), test_addr());
        let rtt = state.socket().rtt();
        assert!(rtt > 0);
    }

    #[test]
    fn test_listener_established_addrs() {
        let listener = ListenerState::new(SrtConfig::default(), test_addr());
        assert!(listener.established_addrs().is_empty());
    }

    #[test]
    fn test_rendezvous_cross_connection() {
        let addr_a = test_addr();
        let addr_b = test_addr2();
        let mut side_a = RendezvousState::new(SrtConfig::default(), addr_b);
        let mut side_b = RendezvousState::new(SrtConfig::default(), addr_a);
        let wave_a = side_a.generate_wave();
        let wave_b = side_b.generate_wave();
        let resp_a = side_a.process_packet(wave_b);
        let resp_b = side_b.process_packet(wave_a);
        assert!(resp_a.is_ok());
        assert!(resp_b.is_ok());
        if let Ok(responses) = resp_a {
            for r in responses {
                let _ = side_b.process_packet(r);
            }
        }
        if let Ok(responses) = resp_b {
            for r in responses {
                let _ = side_a.process_packet(r);
            }
        }
        let progressed =
            side_a.phase() != RendezvousPhase::Waving || side_b.phase() != RendezvousPhase::Waving;
        assert!(progressed);
    }
}