use super::{
connection_mode::{CallerState, ConnectionMode, ListenerState, RendezvousState},
packet::SrtPacket,
socket::{ConnectionState, SrtConfig},
};
use crate::error::{NetError, NetResult};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Configuration used to build an [`SrtConnectionFactory`].
///
/// Normally created through [`ConnectionParams::caller`],
/// [`ConnectionParams::listener`] or [`ConnectionParams::rendezvous`].
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Which handshake role the factory plays.
    pub mode: ConnectionMode,
    /// Local address; in listener mode it is handed to the listener state.
    pub local_addr: SocketAddr,
    /// Remote address; `None` for listeners, `Some` for caller/rendezvous.
    pub peer_addr: Option<SocketAddr>,
    /// SRT socket configuration cloned into the mode's state machine.
    pub config: SrtConfig,
    /// How long the handshake may run before the factory reports a timeout.
    pub connect_timeout: Duration,
    /// Caller: handshake retry limit. Listener: reused as the pending-connection limit.
    pub max_retries: u32,
    /// Interval between handshake retries.
    /// NOTE(review): not read by `SrtConnectionFactory` in this file — confirm who consumes it.
    pub retry_interval: Duration,
}
impl ConnectionParams {
    /// Parameters for caller mode: actively connect from `local_addr` to `peer_addr`.
    ///
    /// Defaults: 5 s connect timeout, 20 retries, 250 ms retry interval.
    #[must_use]
    pub fn caller(local_addr: SocketAddr, peer_addr: SocketAddr, config: SrtConfig) -> Self {
        Self {
            mode: ConnectionMode::Caller,
            local_addr,
            peer_addr: Some(peer_addr),
            config,
            connect_timeout: Duration::from_secs(5),
            max_retries: 20,
            retry_interval: Duration::from_millis(250),
        }
    }

    /// Parameters for listener mode: wait for callers on `local_addr`.
    ///
    /// Defaults: 30 s connect timeout, 128 retries (used as the pending limit),
    /// 500 ms retry interval. No peer address is set.
    #[must_use]
    pub fn listener(local_addr: SocketAddr, config: SrtConfig) -> Self {
        Self {
            mode: ConnectionMode::Listener,
            local_addr,
            peer_addr: None,
            config,
            connect_timeout: Duration::from_secs(30),
            max_retries: 128,
            retry_interval: Duration::from_millis(500),
        }
    }

    /// Parameters for rendezvous mode: both sides connect to each other simultaneously.
    ///
    /// Defaults: 10 s connect timeout, 40 retries, 250 ms retry interval.
    #[must_use]
    pub fn rendezvous(local_addr: SocketAddr, peer_addr: SocketAddr, config: SrtConfig) -> Self {
        Self {
            mode: ConnectionMode::Rendezvous,
            local_addr,
            peer_addr: Some(peer_addr),
            config,
            connect_timeout: Duration::from_secs(10),
            max_retries: 40,
            retry_interval: Duration::from_millis(250),
        }
    }

    /// Overrides the handshake timeout.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Overrides the retry limit (the pending-connection limit in listener mode).
    #[must_use]
    pub const fn with_max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Overrides the interval between handshake retries.
    ///
    /// Provided for parity with the other `with_*` builders.
    #[must_use]
    pub const fn with_retry_interval(mut self, interval: Duration) -> Self {
        self.retry_interval = interval;
        self
    }
}
/// Outcome of feeding one packet to [`SrtConnectionFactory::process_packet`].
#[derive(Debug, Clone)]
pub enum HandshakeStep {
    /// Handshake not finished; send `packets_to_send` to the peer (may be empty).
    InProgress { packets_to_send: Vec<SrtPacket> },
    /// The handshake has completed (or had already completed before this packet).
    Connected,
    /// The mode's state machine rejected the packet; carries the error text.
    Failed(String),
}
pub struct SrtConnectionFactory {
params: ConnectionParams,
caller: Option<CallerState>,
listener: Option<ListenerState>,
rendezvous: Option<RendezvousState>,
started_at: Instant,
connected: bool,
}
impl SrtConnectionFactory {
#[must_use]
pub fn new(params: ConnectionParams) -> Self {
let mut factory = Self {
params: params.clone(),
caller: None,
listener: None,
rendezvous: None,
started_at: Instant::now(),
connected: false,
};
match params.mode {
ConnectionMode::Caller => {
let peer = params
.peer_addr
.unwrap_or_else(|| "0.0.0.0:0".parse().expect("zero addr"));
let mut caller = CallerState::new(params.config.clone(), peer);
caller.set_max_retries(params.max_retries);
factory.caller = Some(caller);
}
ConnectionMode::Listener => {
let mut listener = ListenerState::new(params.config.clone(), params.local_addr);
listener.set_max_pending(params.max_retries as usize);
factory.listener = Some(listener);
}
ConnectionMode::Rendezvous => {
let peer = params
.peer_addr
.unwrap_or_else(|| "0.0.0.0:0".parse().expect("zero addr"));
factory.rendezvous = Some(RendezvousState::new(params.config.clone(), peer));
}
}
factory
}
#[must_use]
pub fn mode(&self) -> ConnectionMode {
self.params.mode
}
#[must_use]
pub fn is_connected(&self) -> bool {
self.connected
}
#[must_use]
pub fn is_timed_out(&self) -> bool {
self.started_at.elapsed() > self.params.connect_timeout
}
#[must_use]
pub fn elapsed(&self) -> Duration {
self.started_at.elapsed()
}
pub fn start_handshake(&mut self) -> Vec<SrtPacket> {
match self.params.mode {
ConnectionMode::Caller => {
if let Some(ref mut caller) = self.caller {
vec![caller.generate_initial_handshake()]
} else {
Vec::new()
}
}
ConnectionMode::Rendezvous => {
if let Some(ref mut rdv) = self.rendezvous {
vec![rdv.generate_wave()]
} else {
Vec::new()
}
}
ConnectionMode::Listener => Vec::new(),
}
}
pub fn process_packet(&mut self, from: SocketAddr, packet: SrtPacket) -> HandshakeStep {
if self.connected {
return HandshakeStep::Connected;
}
match self.params.mode {
ConnectionMode::Caller => self.process_caller_packet(packet),
ConnectionMode::Listener => self.process_listener_packet(from, packet),
ConnectionMode::Rendezvous => self.process_rendezvous_packet(packet),
}
}
pub fn tick(&mut self) -> Vec<SrtPacket> {
if self.connected || self.is_timed_out() {
return Vec::new();
}
match self.params.mode {
ConnectionMode::Caller => {
if let Some(ref mut caller) = self.caller {
if caller.needs_retry() {
if let Some(pkt) = caller.retry_handshake() {
return vec![pkt];
}
}
}
Vec::new()
}
ConnectionMode::Rendezvous => {
if let Some(ref mut rdv) = self.rendezvous {
if rdv.needs_wave() {
return vec![rdv.generate_wave()];
}
}
Vec::new()
}
ConnectionMode::Listener => Vec::new(),
}
}
pub fn cleanup_pending(&mut self) {
if let Some(ref mut listener) = self.listener {
listener.cleanup_pending();
}
}
#[must_use]
pub fn established_count(&self) -> usize {
self.listener
.as_ref()
.map(|l| l.established_count())
.unwrap_or(0)
}
#[must_use]
pub fn total_accepted(&self) -> u64 {
self.listener
.as_ref()
.map(|l| l.total_accepted())
.unwrap_or(0)
}
fn process_caller_packet(&mut self, packet: SrtPacket) -> HandshakeStep {
if let Some(ref mut caller) = self.caller {
match caller.process_response(packet) {
Ok(responses) => {
if caller.is_connected() {
self.connected = true;
HandshakeStep::Connected
} else {
HandshakeStep::InProgress {
packets_to_send: responses,
}
}
}
Err(e) => HandshakeStep::Failed(e.to_string()),
}
} else {
HandshakeStep::Failed("Caller state not initialised".to_owned())
}
}
fn process_listener_packet(&mut self, from: SocketAddr, packet: SrtPacket) -> HandshakeStep {
if let Some(ref mut listener) = self.listener {
match listener.process_incoming(from, packet) {
Ok(responses) => {
if listener.established_count() > 0 {
self.connected = true;
HandshakeStep::Connected
} else {
HandshakeStep::InProgress {
packets_to_send: responses,
}
}
}
Err(e) => HandshakeStep::Failed(e.to_string()),
}
} else {
HandshakeStep::Failed("Listener state not initialised".to_owned())
}
}
fn process_rendezvous_packet(&mut self, packet: SrtPacket) -> HandshakeStep {
if let Some(ref mut rdv) = self.rendezvous {
match rdv.process_packet(packet) {
Ok(responses) => {
if rdv.is_connected() {
self.connected = true;
HandshakeStep::Connected
} else {
HandshakeStep::InProgress {
packets_to_send: responses,
}
}
}
Err(e) => HandshakeStep::Failed(e.to_string()),
}
} else {
HandshakeStep::Failed("Rendezvous state not initialised".to_owned())
}
}
}
/// Compact debug view: the mode, the connection flag and the milliseconds
/// since creation. The internal state machines are deliberately left out.
impl std::fmt::Debug for SrtConnectionFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let elapsed_ms = self.elapsed().as_millis();
        let mut builder = f.debug_struct("SrtConnectionFactory");
        builder.field("mode", &self.params.mode);
        builder.field("connected", &self.connected);
        builder.field("elapsed_ms", &elapsed_ms);
        builder.finish()
    }
}
pub fn simulate_handshake(
side_a: &mut SrtConnectionFactory,
side_b: &mut SrtConnectionFactory,
max_rounds: u32,
) -> NetResult<()> {
let a_addr = side_a.params.local_addr;
let b_addr = side_b.params.local_addr;
let mut a_to_b = side_a.start_handshake();
let mut b_to_a = side_b.start_handshake();
for _ in 0..max_rounds {
if side_a.is_connected() && side_b.is_connected() {
return Ok(());
}
for pkt in a_to_b.drain(..) {
match side_b.process_packet(a_addr, pkt) {
HandshakeStep::Connected => {
}
HandshakeStep::InProgress { packets_to_send } => {
b_to_a.extend(packets_to_send);
}
HandshakeStep::Failed(msg) => {
return Err(NetError::handshake(format!("Side B failed: {msg}")));
}
}
}
for pkt in b_to_a.drain(..) {
match side_a.process_packet(b_addr, pkt) {
HandshakeStep::Connected => {
}
HandshakeStep::InProgress { packets_to_send } => {
a_to_b.extend(packets_to_send);
}
HandshakeStep::Failed(msg) => {
return Err(NetError::handshake(format!("Side A failed: {msg}")));
}
}
}
a_to_b.extend(side_a.tick());
b_to_a.extend(side_b.tick());
}
if side_a.is_connected() && side_b.is_connected() {
Ok(())
} else {
Err(NetError::timeout(format!(
"Handshake did not complete in {max_rounds} rounds (mode={:?})",
side_a.mode()
)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address on `port`.
    fn addr(port: u16) -> SocketAddr {
        format!("127.0.0.1:{port}").parse().expect("valid addr")
    }

    fn default_config() -> SrtConfig {
        SrtConfig::default()
    }

    #[test]
    fn test_caller_params() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config());
        assert_eq!(p.mode, ConnectionMode::Caller);
        assert!(p.peer_addr.is_some());
    }

    #[test]
    fn test_listener_params() {
        let p = ConnectionParams::listener(addr(9000), default_config());
        assert_eq!(p.mode, ConnectionMode::Listener);
        assert!(p.peer_addr.is_none());
    }

    #[test]
    fn test_rendezvous_params() {
        let p = ConnectionParams::rendezvous(addr(9000), addr(9001), default_config());
        assert_eq!(p.mode, ConnectionMode::Rendezvous);
        assert_eq!(p.peer_addr, Some(addr(9001)));
    }

    #[test]
    fn test_params_with_timeout() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config())
            .with_timeout(Duration::from_secs(10));
        assert_eq!(p.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_params_with_max_retries() {
        let p =
            ConnectionParams::caller(addr(9000), addr(9001), default_config()).with_max_retries(5);
        assert_eq!(p.max_retries, 5);
    }

    #[test]
    fn test_factory_caller_new() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config());
        let factory = SrtConnectionFactory::new(p);
        assert_eq!(factory.mode(), ConnectionMode::Caller);
        assert!(!factory.is_connected());
    }

    #[test]
    fn test_factory_listener_new() {
        let p = ConnectionParams::listener(addr(9000), default_config());
        let factory = SrtConnectionFactory::new(p);
        assert_eq!(factory.mode(), ConnectionMode::Listener);
        assert_eq!(factory.established_count(), 0);
    }

    #[test]
    fn test_factory_rendezvous_new() {
        let p = ConnectionParams::rendezvous(addr(9000), addr(9001), default_config());
        let factory = SrtConnectionFactory::new(p);
        assert_eq!(factory.mode(), ConnectionMode::Rendezvous);
    }

    #[test]
    fn test_caller_start_handshake() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config());
        let mut factory = SrtConnectionFactory::new(p);
        let pkts = factory.start_handshake();
        assert_eq!(pkts.len(), 1);
        assert!(pkts[0].is_control());
    }

    #[test]
    fn test_listener_start_handshake_empty() {
        let p = ConnectionParams::listener(addr(9000), default_config());
        let mut factory = SrtConnectionFactory::new(p);
        let pkts = factory.start_handshake();
        assert!(pkts.is_empty());
    }

    #[test]
    fn test_rendezvous_start_handshake() {
        let p = ConnectionParams::rendezvous(addr(9000), addr(9001), default_config());
        let mut factory = SrtConnectionFactory::new(p);
        let pkts = factory.start_handshake();
        assert_eq!(pkts.len(), 1);
        assert!(pkts[0].is_control());
    }

    #[test]
    fn test_caller_tick_retries() {
        let p =
            ConnectionParams::caller(addr(9000), addr(9001), default_config()).with_max_retries(3);
        let mut factory = SrtConnectionFactory::new(p);
        let _initial = factory.start_handshake();
        // A caller that is still handshaking emits at most one retry per tick.
        let retries = factory.tick();
        assert!(retries.len() <= 1);
    }

    #[test]
    fn test_factory_timeout() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config())
            .with_timeout(Duration::from_nanos(1));
        let factory = SrtConnectionFactory::new(p);
        std::thread::sleep(Duration::from_micros(10));
        assert!(factory.is_timed_out());
    }

    #[test]
    fn test_factory_elapsed() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config());
        let factory = SrtConnectionFactory::new(p);
        // A coarse monotonic clock may legitimately report zero right after
        // construction, so check monotonicity rather than strict positivity.
        let first = factory.elapsed();
        let second = factory.elapsed();
        assert!(second >= first);
    }

    #[test]
    fn test_listener_process_incoming() {
        let lp = ConnectionParams::listener(addr(9000), default_config());
        let mut listener_factory = SrtConnectionFactory::new(lp);
        let cp = ConnectionParams::caller(addr(9001), addr(9000), default_config());
        let mut caller_factory = SrtConnectionFactory::new(cp);
        let pkts = caller_factory.start_handshake();
        assert!(!pkts.is_empty());
        let step =
            listener_factory.process_packet(addr(9001), pkts.into_iter().next().expect("pkt"));
        match step {
            HandshakeStep::Connected | HandshakeStep::InProgress { .. } => {}
            HandshakeStep::Failed(msg) => panic!("Should not fail: {msg}"),
        }
    }

    #[test]
    fn test_simulate_caller_listener() {
        let lp = ConnectionParams::listener(addr(9010), default_config());
        let cp = ConnectionParams::caller(addr(9011), addr(9010), default_config());
        let mut listener_factory = SrtConnectionFactory::new(lp);
        let mut caller_factory = SrtConnectionFactory::new(cp);
        let _ = simulate_handshake(&mut caller_factory, &mut listener_factory, 10);
    }

    #[test]
    fn test_simulate_rendezvous() {
        let p_a = ConnectionParams::rendezvous(addr(9020), addr(9021), default_config());
        let p_b = ConnectionParams::rendezvous(addr(9021), addr(9020), default_config());
        let mut factory_a = SrtConnectionFactory::new(p_a);
        let mut factory_b = SrtConnectionFactory::new(p_b);
        let _ = simulate_handshake(&mut factory_a, &mut factory_b, 10);
    }

    #[test]
    fn test_total_accepted_zero() {
        let p = ConnectionParams::listener(addr(9000), default_config());
        let factory = SrtConnectionFactory::new(p);
        assert_eq!(factory.total_accepted(), 0);
    }

    #[test]
    fn test_cleanup_pending_no_op() {
        let p = ConnectionParams::listener(addr(9000), default_config());
        let mut factory = SrtConnectionFactory::new(p);
        factory.cleanup_pending();
        assert_eq!(factory.established_count(), 0);
    }

    #[test]
    fn test_factory_debug() {
        let p = ConnectionParams::caller(addr(9000), addr(9001), default_config());
        let factory = SrtConnectionFactory::new(p);
        let debug = format!("{factory:?}");
        assert!(debug.contains("Caller") || debug.contains("caller"));
    }

    #[test]
    fn test_handshake_step_in_progress() {
        let step = HandshakeStep::InProgress {
            packets_to_send: Vec::new(),
        };
        assert!(!matches!(step, HandshakeStep::Connected));
    }

    #[test]
    fn test_handshake_step_failed() {
        let step = HandshakeStep::Failed("test error".to_owned());
        if let HandshakeStep::Failed(msg) = step {
            assert_eq!(msg, "test error");
        } else {
            panic!("Expected Failed variant");
        }
    }
}