use super::packet::{ControlPacket, ControlType, DataPacket, HandshakeInfo, SrtPacket};
use crate::error::{NetError, NetResult};
use bytes::Bytes;
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Lifecycle state of an [`SrtSocket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
/// Freshly constructed; no handshake has been sent or received yet.
Initial,
/// A handshake has been sent (caller) or answered (listener), but the
/// connection is not yet established.
Handshaking,
/// Handshake complete; incoming data packets are accepted.
Connected,
/// `close()` was called on a connected socket and a shutdown packet was produced.
Closing,
/// Terminated, either by a peer shutdown or by closing a socket that was not connected.
Closed,
/// Terminal failure state.
/// NOTE(review): not assigned anywhere in this module — confirm where it is set.
Broken,
}
impl ConnectionState {
#[must_use]
pub const fn is_connected(&self) -> bool {
matches!(self, Self::Connected)
}
#[must_use]
pub const fn is_finished(&self) -> bool {
matches!(self, Self::Closed | Self::Broken)
}
}
/// Tunable parameters for an [`SrtSocket`].
///
/// `Debug` is implemented by hand so that the passphrase is never printed:
/// `SrtSocket` derives `Debug` and embeds this config, so a derived impl would
/// leak the secret into any debug log of the socket.
#[derive(Clone)]
pub struct SrtConfig {
    /// MTU advertised in the handshake (presumably bytes); lowered to the
    /// peer's value when the handshake is processed.
    pub mtu: u32,
    /// Flow-control window advertised in the handshake (presumably packets);
    /// lowered to the peer's value when the handshake is processed.
    pub flow_window: u32,
    /// Receiver latency in milliseconds.
    pub latency_ms: u32,
    /// Latency requested for the peer, in milliseconds.
    pub peer_latency_ms: u32,
    /// Whether packets that arrive too late should be dropped.
    /// NOTE(review): not consulted by `SrtSocket` in this module.
    pub too_late_drop: bool,
    /// Time allowed for connection establishment.
    /// NOTE(review): not consulted by `SrtSocket` in this module.
    pub connect_timeout: Duration,
    /// Idle period after which `SrtSocket::check_timeout` reports a timeout.
    pub peer_idle_timeout: Duration,
    /// Bandwidth limit; presumably `0` means "unlimited" — confirm.
    pub max_bandwidth: u64,
    /// Encryption key size; `0` when encryption is not configured.
    pub key_size: u8,
    /// Optional stream identifier.
    pub stream_id: Option<String>,
    /// Optional encryption passphrase (redacted in `Debug` output).
    pub passphrase: Option<String>,
}

impl std::fmt::Debug for SrtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SrtConfig")
            .field("mtu", &self.mtu)
            .field("flow_window", &self.flow_window)
            .field("latency_ms", &self.latency_ms)
            .field("peer_latency_ms", &self.peer_latency_ms)
            .field("too_late_drop", &self.too_late_drop)
            .field("connect_timeout", &self.connect_timeout)
            .field("peer_idle_timeout", &self.peer_idle_timeout)
            .field("max_bandwidth", &self.max_bandwidth)
            .field("key_size", &self.key_size)
            .field("stream_id", &self.stream_id)
            // Only reveal whether a passphrase is set, never its value.
            .field("passphrase", &self.passphrase.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Default for SrtConfig {
fn default() -> Self {
Self {
mtu: 1500,
flow_window: 8192,
latency_ms: 120,
peer_latency_ms: 0,
too_late_drop: true,
connect_timeout: Duration::from_secs(3),
peer_idle_timeout: Duration::from_secs(5),
max_bandwidth: 0,
key_size: 0,
stream_id: None,
passphrase: None,
}
}
}
impl SrtConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the receiver latency in milliseconds.
    #[must_use]
    pub const fn with_latency(mut self, latency_ms: u32) -> Self {
        self.latency_ms = latency_ms;
        self
    }

    /// Sets the MTU advertised in the handshake (it may be lowered to the
    /// peer's value during negotiation).
    #[must_use]
    pub const fn with_mtu(mut self, mtu: u32) -> Self {
        self.mtu = mtu;
        self
    }

    /// Sets the stream identifier.
    #[must_use]
    pub fn with_stream_id(mut self, stream_id: impl Into<String>) -> Self {
        self.stream_id = Some(stream_id.into());
        self
    }

    /// Enables encryption with the given passphrase.
    ///
    /// If no key size has been chosen yet (`key_size == 0`), a key size of 16
    /// is filled in. An explicit size set earlier through
    /// [`Self::with_key_size`] is preserved, so the order of builder calls
    /// does not matter.
    #[must_use]
    pub fn with_passphrase(mut self, passphrase: impl Into<String>) -> Self {
        self.passphrase = Some(passphrase.into());
        if self.key_size == 0 {
            self.key_size = 16;
        }
        self
    }

    /// Sets the encryption key size.
    ///
    /// NOTE(review): the value is not validated here; presumably a byte length
    /// (16/24/32 for AES-128/192/256) — confirm against the crypto layer.
    #[must_use]
    pub const fn with_key_size(mut self, key_size: u8) -> Self {
        self.key_size = key_size;
        self
    }
}
/// A sent data packet retained until an ACK from the peer covers its sequence number.
#[derive(Debug, Clone)]
struct UnackedPacket {
/// The packet as it was sent.
packet: DataPacket,
/// Send time of the packet (presumably for retransmission timing).
/// NOTE(review): not read anywhere in this module.
sent_at: Instant,
/// How many times `SrtSocket::mark_for_retransmit` has flagged this packet.
retransmit_count: u32,
}
/// Activity counters for an `SrtSocket`, returned by `SrtSocket::stats`.
#[derive(Debug, Clone, Default)]
pub struct SrtStats {
/// Data packets produced by `SrtSocket::create_data_packet`.
pub packets_sent: u64,
/// Data packets received, including duplicates and late arrivals.
pub packets_received: u64,
/// Retransmissions recorded by `SrtSocket::mark_for_retransmit`.
pub packets_retransmitted: u64,
/// Lost packets. NOTE(review): not incremented by `SrtSocket` in this module.
pub packets_lost: u64,
/// Payload bytes of data packets produced for sending.
pub bytes_sent: u64,
/// Payload bytes of received data packets.
pub bytes_received: u64,
/// Number of packets held on the send side. NOTE(review): confirm where this is maintained.
pub send_buffer_size: usize,
/// Number of packets held in the receive buffer, refreshed after each data packet.
pub recv_buffer_size: usize,
}
/// Protocol state for one SRT connection endpoint (caller or listener).
///
/// The socket performs no I/O itself: incoming packets are passed to
/// `process_packet`, which returns the packets to send in reply.
#[derive(Debug)]
pub struct SrtSocket {
/// Locally generated socket id.
socket_id: u32,
/// Peer's socket id, learned from the handshake; `0` until then.
peer_socket_id: u32,
/// Current lifecycle state.
state: ConnectionState,
/// Configuration; `mtu` and `flow_window` are overwritten with negotiated values during the handshake.
config: SrtConfig,
/// Sequence number to assign to the next outgoing data packet.
pub(crate) send_seq: u32,
/// Next expected incoming sequence number (seeded from the peer's initial sequence number).
recv_seq: u32,
/// `recv_seq` value carried by the most recent ACK sent; a new ACK is emitted when they differ.
last_ack_sent: u32,
/// Sequence number carried by the most recent ACK received from the peer.
last_ack_recv: u32,
/// Sent data packets awaiting acknowledgement; incoming ACKs remove covered entries from the front.
unacked_packets: VecDeque<UnackedPacket>,
/// Received data packets held by `process_data_packet`.
recv_buffer: VecDeque<DataPacket>,
/// Smoothed round-trip time; presumably microseconds (initialised to 100_000) — confirm.
rtt: u32,
/// Smoothed RTT variation, in the same units as `rtt`.
rtt_var: u32,
/// When the last packet was processed; used by `check_timeout`.
last_activity: Instant,
/// Creation time; the base for `current_timestamp`.
start_time: Instant,
/// Handshake parameters recorded by `generate_caller_handshake`.
handshake: HandshakeInfo,
/// Activity counters.
stats: SrtStats,
}
impl SrtSocket {
#[must_use]
pub fn new(config: SrtConfig) -> Self {
let now = Instant::now();
Self {
socket_id: rand_socket_id(),
peer_socket_id: 0,
state: ConnectionState::Initial,
config,
send_seq: rand_initial_seq(),
recv_seq: 0,
last_ack_sent: 0,
last_ack_recv: 0,
unacked_packets: VecDeque::new(),
recv_buffer: VecDeque::new(),
rtt: 100_000, rtt_var: 50_000,
last_activity: now,
start_time: now,
handshake: HandshakeInfo::new(),
stats: SrtStats::default(),
}
}
#[must_use]
pub const fn socket_id(&self) -> u32 {
self.socket_id
}
#[must_use]
pub const fn peer_socket_id(&self) -> u32 {
self.peer_socket_id
}
#[must_use]
pub const fn state(&self) -> ConnectionState {
self.state
}
#[must_use]
pub const fn is_connected(&self) -> bool {
self.state.is_connected()
}
#[must_use]
pub const fn config(&self) -> &SrtConfig {
&self.config
}
#[must_use]
pub const fn rtt(&self) -> u32 {
self.rtt
}
#[must_use]
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
#[must_use]
pub fn current_timestamp(&self) -> u32 {
self.start_time.elapsed().as_micros() as u32
}
#[must_use]
pub fn generate_caller_handshake(&mut self) -> SrtPacket {
self.handshake = HandshakeInfo {
version: 5,
mtu: self.config.mtu,
flow_window: self.config.flow_window,
handshake_type: HandshakeInfo::TYPE_WAVEAHAND,
socket_id: self.socket_id,
initial_seq: self.send_seq,
..Default::default()
};
self.state = ConnectionState::Handshaking;
SrtPacket::Control(ControlPacket::handshake(&self.handshake, 0))
}
pub fn process_packet(&mut self, packet: SrtPacket) -> NetResult<Vec<SrtPacket>> {
self.last_activity = Instant::now();
let mut responses = Vec::new();
match packet {
SrtPacket::Data(data) => {
if !self.is_connected() {
return Err(NetError::invalid_state("Not connected"));
}
self.process_data_packet(data, &mut responses)?;
}
SrtPacket::Control(ctrl) => {
self.process_control_packet(ctrl, &mut responses)?;
}
}
Ok(responses)
}
fn process_data_packet(
&mut self,
packet: DataPacket,
responses: &mut Vec<SrtPacket>,
) -> NetResult<()> {
let seq = packet.sequence_number;
let payload_len = packet.payload.len() as u64;
self.stats.packets_received += 1;
self.stats.bytes_received += payload_len;
if seq == self.recv_seq {
self.recv_seq = seq.wrapping_add(1);
self.recv_buffer.push_back(packet);
while let Some(buffered) = self.recv_buffer.front() {
if buffered.sequence_number == self.recv_seq {
self.recv_seq = self.recv_seq.wrapping_add(1);
self.recv_buffer.pop_front();
} else {
break;
}
}
} else if seq_after(seq, self.recv_seq) {
self.recv_buffer.push_back(packet);
}
self.stats.recv_buffer_size = self.recv_buffer.len();
if self.recv_seq != self.last_ack_sent {
let ack = ControlPacket::ack(self.recv_seq, self.peer_socket_id)
.with_timestamp(self.current_timestamp());
responses.push(SrtPacket::Control(ack));
self.last_ack_sent = self.recv_seq;
}
Ok(())
}
fn process_control_packet(
&mut self,
packet: ControlPacket,
responses: &mut Vec<SrtPacket>,
) -> NetResult<()> {
match packet.control_type {
ControlType::Handshake => {
self.process_handshake(&packet, responses)?;
}
ControlType::Keepalive => {
let keepalive = ControlPacket::keepalive(self.peer_socket_id)
.with_timestamp(self.current_timestamp());
responses.push(SrtPacket::Control(keepalive));
}
ControlType::Ack => {
let ack_seq = packet.type_info;
self.last_ack_recv = ack_seq;
while let Some(front) = self.unacked_packets.front() {
if seq_after(ack_seq, front.packet.sequence_number) {
self.unacked_packets.pop_front();
} else {
break;
}
}
let ack_ack = ControlPacket::new(ControlType::AckAck)
.with_timestamp(self.current_timestamp());
responses.push(SrtPacket::Control(ack_ack));
}
ControlType::Nak => {
self.handle_nak(&packet)?;
}
ControlType::Shutdown => {
self.state = ConnectionState::Closed;
}
_ => {
}
}
Ok(())
}
fn process_handshake(
&mut self,
packet: &ControlPacket,
responses: &mut Vec<SrtPacket>,
) -> NetResult<()> {
let hs = HandshakeInfo::decode(&packet.payload)?;
match self.state {
ConnectionState::Initial => {
self.peer_socket_id = hs.socket_id;
self.recv_seq = hs.initial_seq;
let response = HandshakeInfo {
version: 5,
mtu: self.config.mtu.min(hs.mtu),
flow_window: self.config.flow_window.min(hs.flow_window),
handshake_type: HandshakeInfo::TYPE_INDUCTION,
socket_id: self.socket_id,
initial_seq: self.send_seq,
syn_cookie: generate_cookie(),
..Default::default()
};
responses.push(SrtPacket::Control(ControlPacket::handshake(
&response,
self.peer_socket_id,
)));
self.state = ConnectionState::Handshaking;
}
ConnectionState::Handshaking => {
if hs.handshake_type == HandshakeInfo::TYPE_INDUCTION
|| hs.handshake_type == HandshakeInfo::TYPE_CONCLUSION
{
self.peer_socket_id = hs.socket_id;
self.recv_seq = hs.initial_seq;
self.config.mtu = self.config.mtu.min(hs.mtu);
self.config.flow_window = self.config.flow_window.min(hs.flow_window);
if hs.handshake_type == HandshakeInfo::TYPE_INDUCTION {
let conclusion = HandshakeInfo {
version: 5,
mtu: self.config.mtu,
flow_window: self.config.flow_window,
handshake_type: HandshakeInfo::TYPE_CONCLUSION,
socket_id: self.socket_id,
initial_seq: self.send_seq,
syn_cookie: hs.syn_cookie,
..Default::default()
};
responses.push(SrtPacket::Control(ControlPacket::handshake(
&conclusion,
self.peer_socket_id,
)));
}
self.state = ConnectionState::Connected;
} else if hs.handshake_type == HandshakeInfo::TYPE_AGREEMENT {
self.state = ConnectionState::Connected;
}
}
_ => {}
}
Ok(())
}
fn handle_nak(&mut self, _packet: &ControlPacket) -> NetResult<()> {
Ok(())
}
#[must_use]
pub fn create_data_packet(&mut self, payload: Bytes) -> DataPacket {
let seq = self.send_seq;
self.send_seq = self.send_seq.wrapping_add(1);
self.stats.packets_sent += 1;
self.stats.bytes_sent += payload.len() as u64;
DataPacket::new(seq, payload)
.with_timestamp(self.current_timestamp())
.with_dst_socket(self.peer_socket_id)
}
pub fn close(&mut self) -> Option<SrtPacket> {
if self.state.is_connected() {
self.state = ConnectionState::Closing;
Some(SrtPacket::Control(ControlPacket::shutdown(
self.peer_socket_id,
)))
} else {
self.state = ConnectionState::Closed;
None
}
}
#[must_use]
pub fn check_timeout(&self) -> bool {
self.last_activity.elapsed() > self.config.peer_idle_timeout
}
#[must_use]
pub fn stats(&self) -> &SrtStats {
&self.stats
}
pub fn update_rtt(&mut self, sample: u32) {
if self.rtt == 0 {
self.rtt = sample;
self.rtt_var = sample / 2;
} else {
let diff = if sample > self.rtt {
sample - self.rtt
} else {
self.rtt - sample
};
self.rtt_var = (3 * self.rtt_var + diff) / 4;
self.rtt = (7 * self.rtt + sample) / 8;
}
}
pub fn mark_for_retransmit(&mut self, seq: u32) {
for entry in &mut self.unacked_packets {
if entry.packet.sequence_number == seq {
entry.retransmit_count += 1;
self.stats.packets_retransmitted += 1;
break;
}
}
}
}
/// Wrap-aware sequence comparison: `true` when `a` comes strictly after `b`.
///
/// `a` is "after" `b` when the forward distance from `b` to `a` (mod 2^32) is
/// non-zero and at most half of the 32-bit sequence space.
const fn seq_after(a: u32, b: u32) -> bool {
    let distance = a.wrapping_sub(b);
    distance != 0 && distance <= 0x7FFF_FFFF
}
/// Returns a pseudo-random `u32` using only the standard library.
///
/// - Each `RandomState::new()` is documented to use random keys, so successive
///   calls differ even when the clock has not advanced.
/// - This replaces the previous `time ^ constant` scheme, which produced
///   identical values within one clock tick and was trivially predictable.
/// - The current time is mixed in as extra input.
/// - NOTE: this is not a cryptographically secure generator.
fn random_u32() -> u32 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let hash = hasher.finish();
    // Fold both halves of the 64-bit hash into the 32-bit result.
    (hash ^ (hash >> 32)) as u32
}

/// Generates a local socket id.
///
/// Never returns `0`: that value is used as "peer not yet known" and as the
/// destination of the caller's initial handshake.
fn rand_socket_id() -> u32 {
    loop {
        let id = random_u32();
        if id != 0 {
            return id;
        }
    }
}

/// Generates an initial send sequence number restricted to 31 bits.
fn rand_initial_seq() -> u32 {
    random_u32() & 0x7FFF_FFFF
}

/// Generates the SYN cookie sent in the listener's induction response.
fn generate_cookie() -> u32 {
    random_u32()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state() {
        assert!(ConnectionState::Connected.is_connected());
        assert!(!ConnectionState::Initial.is_connected());
        assert!(ConnectionState::Closed.is_finished());
        assert!(ConnectionState::Broken.is_finished());
    }

    #[test]
    fn test_srt_config() {
        let config = SrtConfig::new()
            .with_latency(200)
            .with_mtu(1400)
            .with_stream_id("mystream");
        assert_eq!(config.latency_ms, 200);
        assert_eq!(config.mtu, 1400);
        assert_eq!(config.stream_id, Some("mystream".to_string()));
    }

    #[test]
    fn test_passphrase_sets_default_key_size() {
        let config = SrtConfig::new().with_passphrase("secret");
        assert_eq!(config.passphrase, Some("secret".to_string()));
        assert_eq!(config.key_size, 16);
    }

    #[test]
    fn test_srt_socket_new() {
        let socket = SrtSocket::new(SrtConfig::default());
        assert_eq!(socket.state(), ConnectionState::Initial);
        assert!(!socket.is_connected());
    }

    #[test]
    fn test_caller_handshake() {
        let mut socket = SrtSocket::new(SrtConfig::default());
        let packet = socket.generate_caller_handshake();
        assert_eq!(socket.state(), ConnectionState::Handshaking);
        assert!(packet.is_control());
    }

    #[test]
    fn test_create_data_packet() {
        let mut socket = SrtSocket::new(SrtConfig::default());
        socket.state = ConnectionState::Connected;
        socket.peer_socket_id = 100;
        let packet1 = socket.create_data_packet(Bytes::from(vec![1, 2, 3]));
        let packet2 = socket.create_data_packet(Bytes::from(vec![4, 5, 6]));
        // Sequence numbers wrap, so compare with wrapping arithmetic.
        assert_eq!(
            packet2.sequence_number,
            packet1.sequence_number.wrapping_add(1)
        );
        assert_eq!(packet1.dst_socket_id, 100);
    }

    #[test]
    fn test_update_rtt_smoothing() {
        let mut socket = SrtSocket::new(SrtConfig::default());
        // Starting from rtt = 100_000, rtt_var = 50_000.
        socket.update_rtt(20_000);
        assert_eq!(socket.rtt_var, 57_500);
        assert_eq!(socket.rtt(), 90_000);
    }

    #[test]
    fn test_seq_after() {
        assert!(seq_after(10, 5));
        assert!(!seq_after(5, 10));
        assert!(!seq_after(5, 5));
        assert!(seq_after(0, 0xFFFF_FFFF));
    }

    #[test]
    fn test_close() {
        let mut socket = SrtSocket::new(SrtConfig::default());
        socket.state = ConnectionState::Connected;
        socket.peer_socket_id = 42;
        let packet = socket.close();
        assert_eq!(socket.state(), ConnectionState::Closing);
        // A non-control packet must fail the test instead of being skipped.
        match packet {
            Some(SrtPacket::Control(ctrl)) => {
                assert_eq!(ctrl.control_type, ControlType::Shutdown);
            }
            _ => panic!("close() on a connected socket must return a shutdown control packet"),
        }
    }

    #[test]
    fn test_close_when_not_connected() {
        let mut socket = SrtSocket::new(SrtConfig::default());
        assert!(socket.close().is_none());
        assert_eq!(socket.state(), ConnectionState::Closed);
    }
}