use crate::rtmp::poller::{Interest, Poller, RawHandle};
use crate::rtmp::rtmp_scheduler::{RtmpScheduler, ServerResult};
use crate::rtmp::write_queue::{BackpressureLevel, FlushResult, WriteQueue};
use bytes::Bytes;
use log::{debug, error, info};
use rml_rtmp::chunk_io::ChunkSerializer;
use rml_rtmp::handshake::{Handshake, HandshakeProcessResult, PeerType};
use rml_rtmp::messages::RtmpMessage;
use rml_rtmp::rml_amf0::Amf0Value;
use rml_rtmp::time::RtmpTimestamp;
use std::collections::{HashMap, HashSet};
use std::io::{self, Read};
use std::net::{Shutdown, TcpStream};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Size in bytes of each connection's reusable read scratch buffer.
const READ_BUFFER_SIZE: usize = 8192;
/// Maximum time a single poll call may block waiting for events, in ms.
const POLL_TIMEOUT_MS: u64 = 100;
/// Connections with no read or write activity for this long are reaped.
const CONNECTION_TIMEOUT_SECS: u64 = 60;
/// Upper bound on time spent draining write queues at shutdown.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
/// Per-connection cap on bytes read in one poll tick (fairness guard).
const MAX_READ_PER_POLL: usize = 512 * 1024;
/// Fallback connection cap when neither config nor OS limit applies.
const DEFAULT_MAX_CONNECTIONS: usize = 10000;
/// Windows has no RLIMIT_NOFILE equivalent, so a fixed cap is used there.
#[cfg(windows)]
const DEFAULT_MAX_CONNECTIONS_WINDOWS: usize = 8000;
/// Extra channel capacity slack — consumers of this constant live outside
/// this file (presumably channel sizing; confirm at call sites).
pub const CHANNEL_HEADROOM: usize = 256;
/// Queries the OS soft limit on open file descriptors.
///
/// Unix: RLIMIT_NOFILE via `getrlimit`. Windows: no direct equivalent, so a
/// fixed conservative constant is returned. Other targets: unknown (`None`).
fn get_fd_limit() -> Option<usize> {
    #[cfg(unix)]
    {
        use std::mem::MaybeUninit;
        let mut rlim = MaybeUninit::<libc::rlimit>::uninit();
        // SAFETY: getrlimit writes a libc::rlimit through the valid pointer.
        let rc = unsafe { libc::getrlimit(libc::RLIMIT_NOFILE, rlim.as_mut_ptr()) };
        if rc != 0 {
            return None;
        }
        // SAFETY: rc == 0 guarantees the struct was fully initialized.
        let rlim = unsafe { rlim.assume_init() };
        Some(rlim.rlim_cur as usize)
    }
    #[cfg(windows)]
    {
        Some(DEFAULT_MAX_CONNECTIONS_WINDOWS)
    }
    #[cfg(not(any(unix, windows)))]
    {
        None
    }
}
/// Resolves the effective connection cap: the configured value (or the
/// default), capped at 80% of the process fd limit when it is known, and
/// never less than 1.
pub fn effective_max_connections(config_max: Option<usize>) -> usize {
    let configured = config_max.unwrap_or(DEFAULT_MAX_CONNECTIONS);
    let limited = match get_fd_limit() {
        // Keep 20% of descriptors in reserve for files, the poller, etc.
        Some(fd_limit) => configured.min((fd_limit as f64 * 0.8) as usize),
        None => configured,
    };
    limited.max(1)
}
/// Identity of a connection: the slab slot id plus a reuse generation that
/// lets stale poller events for a recycled slot be detected and dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionToken {
    pub id: usize,
    pub generation: u32,
}
impl ConnectionToken {
    /// Builds a token from a slab slot id and its reuse generation.
    fn new(id: usize, generation: u32) -> Self {
        ConnectionToken { id, generation }
    }
    /// Packs generation (high 32 bits) and id (low 32 bits) into one usize.
    #[cfg(target_pointer_width = "64")]
    fn to_poller_token(&self) -> usize {
        let low = self.id & (u32::MAX as usize);
        let high = (self.generation as usize) << 32;
        high | low
    }
    /// Splits a packed 64-bit poller token back into id and generation.
    #[cfg(target_pointer_width = "64")]
    fn from_poller_token(token: usize) -> Self {
        ConnectionToken {
            id: token & (u32::MAX as usize),
            generation: (token >> 32) as u32,
        }
    }
    /// 32-bit targets have no room for the generation; the raw id is used.
    #[cfg(target_pointer_width = "32")]
    fn to_poller_token(&self) -> usize {
        self.id
    }
    /// 32-bit targets cannot recover the generation; it defaults to 0.
    #[cfg(target_pointer_width = "32")]
    fn from_poller_token(token: usize) -> Self {
        ConnectionToken {
            id: token,
            generation: 0,
        }
    }
}
/// Lifecycle of a reactor connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// RTMP handshake still in progress.
    Handshaking,
    /// Handshake done; normal traffic flows.
    Active,
    /// Backpressure detected; still serviced but flagged slow.
    SlowClient,
    /// Being torn down; queued writes may still drain.
    Closing,
    /// Fully closed; no further I/O.
    Closed,
}
impl ConnectionState {
    /// True for states representing a live post-handshake session.
    #[cfg(test)]
    pub fn is_active(&self) -> bool {
        match self {
            ConnectionState::Active | ConnectionState::SlowClient => true,
            ConnectionState::Handshaking | ConnectionState::Closing | ConnectionState::Closed => {
                false
            }
        }
    }
    /// Reads are allowed until the connection starts closing.
    pub fn can_read(&self) -> bool {
        !matches!(self, ConnectionState::Closing | ConnectionState::Closed)
    }
    /// Writes stay allowed while closing so queued data can drain.
    pub fn can_write(&self) -> bool {
        !matches!(self, ConnectionState::Closed)
    }
}
/// Per-connection state owned by the reactor: socket, queued writes,
/// handshake progress, and activity timestamps used for timeout checks.
pub struct ReactorConnection {
    // Stable identity (slab id + generation) used to reject stale events.
    token: ConnectionToken,
    socket: TcpStream,
    // Cached OS handle (fd on Unix, SOCKET on Windows) for the poller.
    raw_handle: RawHandle,
    state: ConnectionState,
    // Outbound bytes awaiting a writable socket; drives backpressure state.
    write_queue: WriteQueue,
    // Reusable scratch buffer for reads (READ_BUFFER_SIZE bytes).
    read_buffer: Vec<u8>,
    // Present only while the RTMP handshake runs; dropped on completion.
    handshake: Option<Handshake>,
    last_read_activity: Instant,
    last_write_activity: Instant,
    // Interest currently registered with the poller; compared against
    // desired_interest() to avoid redundant modify() calls.
    current_interest: Interest,
}
impl ReactorConnection {
/// Wraps an accepted socket: switches it to non-blocking mode, caches the
/// raw OS handle, and initializes the server-side handshake state.
pub fn new(token: ConnectionToken, socket: TcpStream) -> io::Result<Self> {
    // Non-blocking is mandatory: the reactor multiplexes many sockets.
    socket.set_nonblocking(true)?;
    #[cfg(unix)]
    let raw_handle = {
        use std::os::unix::io::AsRawFd;
        socket.as_raw_fd()
    };
    #[cfg(windows)]
    let raw_handle = {
        use std::os::windows::io::AsRawSocket;
        socket.as_raw_socket()
    };
    let now = Instant::now();
    Ok(Self {
        token,
        socket,
        raw_handle,
        state: ConnectionState::Handshaking,
        write_queue: WriteQueue::new(),
        read_buffer: vec![0u8; READ_BUFFER_SIZE],
        // Server side of the RTMP handshake; cleared once it completes.
        handshake: Some(Handshake::new(PeerType::Server)),
        last_read_activity: now,
        last_write_activity: now,
        // Registration starts read-only; write interest is added on demand.
        current_interest: Interest::READABLE,
    })
}
/// The cached OS-level socket handle used for poller (de)registration.
pub fn raw_handle(&self) -> RawHandle {
    self.raw_handle
}
/// Timestamp of the most recent I/O in either direction.
pub fn last_activity(&self) -> Instant {
    if self.last_read_activity >= self.last_write_activity {
        self.last_read_activity
    } else {
        self.last_write_activity
    }
}
pub fn is_timed_out(&self, timeout: Duration) -> bool {
self.last_activity().elapsed() > timeout
}
/// Queues a media payload for this connection and adjusts its state from
/// the queue's backpressure level.
///
/// Returns whether the payload was accepted; at Critical backpressure the
/// connection is marked Closing and `false` is returned regardless.
pub fn enqueue_data(
    &mut self,
    data: Bytes,
    is_keyframe: bool,
    is_sequence_header: bool,
    is_video: bool,
) -> bool {
    let accepted = self
        .write_queue
        .enqueue(data, is_keyframe, is_sequence_header, is_video);
    match self.write_queue.backpressure_level() {
        BackpressureLevel::Critical => {
            // Hopelessly behind: give up on this client.
            self.state = ConnectionState::Closing;
            false
        }
        BackpressureLevel::High | BackpressureLevel::Warning => {
            if self.state == ConnectionState::Active {
                self.state = ConnectionState::SlowClient;
            }
            accepted
        }
        BackpressureLevel::Normal => {
            // Recovered: promote a slow client back to active.
            if self.state == ConnectionState::SlowClient {
                self.state = ConnectionState::Active;
            }
            accepted
        }
    }
}
/// Queues raw protocol bytes (e.g. handshake responses) with no media
/// flags. A rejected enqueue marks the connection Closing.
pub fn enqueue_raw(&mut self, data: Vec<u8>) -> bool {
    let accepted = self
        .write_queue
        .enqueue(Bytes::from(data), false, false, false);
    if !accepted {
        self.state = ConnectionState::Closing;
    }
    accepted
}
/// Attempts to flush the write queue to the socket.
///
/// Returns `Ok(true)` when the peer closed the connection, `Ok(false)`
/// otherwise (including partial writes that hit WouldBlock), and `Err`
/// on a hard write error.
pub fn try_flush(&mut self) -> io::Result<bool> {
    if self.write_queue.is_empty() {
        return Ok(false);
    }
    match self.write_queue.try_flush(&mut self.socket) {
        // Complete and WouldBlock are handled identically: record activity
        // if anything went out and keep the connection open. (Previously two
        // duplicated arms.)
        Ok(FlushResult::Complete { bytes_written })
        | Ok(FlushResult::WouldBlock { bytes_written }) => {
            if bytes_written > 0 {
                self.last_write_activity = Instant::now();
            }
            Ok(false)
        }
        Ok(FlushResult::Closed) => Ok(true),
        Err(e) => {
            debug!(
                "Connection {} write error: {:?}",
                self.token.id, e
            );
            Err(e)
        }
    }
}
/// Drains readable bytes from the socket.
///
/// Returns `(data, should_close)`; `should_close` is true on EOF (read of
/// 0 bytes). Reading stops early once MAX_READ_PER_POLL bytes have been
/// collected so one busy connection cannot starve the poll loop.
/// NOTE(review): the early cap relies on the poller reporting the socket
/// readable again next tick (level-triggered behavior) — confirm.
pub fn try_read(&mut self) -> io::Result<(Vec<u8>, bool)> {
    let mut all_data = Vec::new();
    loop {
        if all_data.len() >= MAX_READ_PER_POLL {
            // Fairness cap reached; leave the rest for the next poll.
            return Ok((all_data, false));
        }
        match self.socket.read(&mut self.read_buffer) {
            Ok(0) => {
                // Peer closed the connection.
                return Ok((all_data, true));
            }
            Ok(n) => {
                self.last_read_activity = Instant::now();
                all_data.extend_from_slice(&self.read_buffer[..n]);
            }
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                // Socket drained for now.
                return Ok((all_data, false));
            }
            Err(e) => {
                debug!("Connection {} read error: {:?}", self.token.id, e);
                return Err(e);
            }
        }
    }
}
/// Feeds `data` through the RTMP handshake state machine.
///
/// Returns `(remaining, response, completed, error)`:
/// - `remaining`: leftover bytes belonging to the RTMP session proper,
/// - `response`: bytes to send back to the peer,
/// - `completed`: true when the handshake (just) finished,
/// - `error`: true when the handshake failed and the peer should be dropped.
pub fn process_handshake(&mut self, data: &[u8]) -> (Option<Vec<u8>>, Option<Vec<u8>>, bool, bool) {
    let handshake = match self.handshake.as_mut() {
        Some(h) => h,
        // Handshake already done: the input is plain session data.
        None => return (Some(data.to_vec()), None, true, false),
    };
    match handshake.process_bytes(data) {
        Ok(HandshakeProcessResult::InProgress { response_bytes }) => {
            let response = if response_bytes.is_empty() {
                None
            } else {
                Some(response_bytes)
            };
            (None, response, false, false)
        }
        Ok(HandshakeProcessResult::Completed {
            response_bytes,
            remaining_bytes,
        }) => {
            let response = if response_bytes.is_empty() {
                None
            } else {
                Some(response_bytes)
            };
            let remaining = if remaining_bytes.is_empty() {
                None
            } else {
                Some(remaining_bytes)
            };
            // Drop the handshake machinery and promote the connection.
            self.handshake = None;
            self.state = ConnectionState::Active;
            (remaining, response, true, false)
        }
        Err(e) => {
            debug!("Connection {} handshake error: {:?}", self.token.id, e);
            (None, None, false, true)
        }
    }
}
/// True when bytes are still queued for this connection's socket.
pub fn has_pending_writes(&self) -> bool {
    !self.write_queue.is_empty()
}
/// Computes the poller interest this connection should be registered with:
/// readable while the state permits reads, plus writable whenever there is
/// queued output to flush.
pub fn desired_interest(&self) -> Interest {
    let base = if self.state.can_read() {
        Interest::READABLE
    } else {
        Interest {
            readable: false,
            writable: false,
        }
    };
    if self.has_pending_writes() {
        base.add_writable()
    } else {
        base
    }
}
/// Transitions the connection to Closing (writes may still drain).
pub fn mark_closing(&mut self) {
    self.state = ConnectionState::Closing;
}
/// Transitions the connection to Closed (no further I/O allowed).
pub fn mark_closed(&mut self) {
    self.state = ConnectionState::Closed;
}
/// Shuts down both directions of the socket and marks the state Closed.
/// A shutdown error is expected when the peer already closed, so it is
/// only logged at debug level.
pub fn shutdown(&mut self) {
    match self.socket.shutdown(Shutdown::Both) {
        Ok(()) => {}
        Err(e) => {
            debug!("Socket shutdown error (expected if already closed): {:?}", e);
        }
    }
    self.mark_closed();
}
}
/// A registered publisher: its stream key plus the channel through which
/// the ingest side hands raw RTMP bytes to the reactor thread.
pub struct PublisherState {
    pub stream_key: String,
    // Raw RTMP chunk bytes produced by the publisher's ingest connection.
    pub receiver: crossbeam_channel::Receiver<Vec<u8>>,
}
/// Outcome of servicing a poller event that the event loop must act on.
pub enum HandleResult {
    /// The identified connection should be removed from the reactor.
    Disconnect(usize),
}
/// Single-threaded event loop that owns every socket, publisher, and the
/// RTMP scheduler; driven by `run()`.
pub struct Reactor {
    poller: Poller,
    // Slab of live connections; slab keys double as connection ids.
    connections: slab::Slab<ReactorConnection>,
    // Per-slot reuse counter so stale poller events can be rejected.
    generations: HashMap<usize, u32>,
    scheduler: RtmpScheduler,
    publishers: slab::Slab<PublisherState>,
    // Set of active stream keys shared with other threads.
    stream_keys: dashmap::DashSet<String>,
    // Shared run/stop flag checked once per loop iteration.
    status: Arc<AtomicUsize>,
    max_connections: usize,
    // Connections with queued writes to flush at the end of the tick.
    pending_flush: HashSet<usize>,
    // Connections whose poller interest must be recomputed.
    interest_dirty: HashSet<usize>,
    #[allow(dead_code)]
    conn_ids_buffer: Vec<usize>,
    // Reusable scratch buffers that avoid per-tick allocations.
    packets_buffer: Vec<(usize, Vec<u8>, bool, bool, bool)>,
    ids_to_close_buffer: Vec<usize>,
    results_buffer: Vec<HandleResult>,
}
// Values for the shared `status` flag: keep running / stop requested.
const STATUS_RUN: usize = 1;
const STATUS_END: usize = 2;
impl Reactor {
/// Creates a reactor with its own poller.
///
/// `gop_limit` is forwarded to the scheduler. The connection cap derives
/// from `max_connections` and the process fd limit (see
/// `effective_max_connections`). Fails only if the poller cannot be built.
pub fn new(
    gop_limit: usize,
    max_connections: Option<usize>,
    stream_keys: dashmap::DashSet<String>,
    status: Arc<AtomicUsize>,
) -> io::Result<Self> {
    let poller = Poller::new()?;
    let effective_max = effective_max_connections(max_connections);
    Ok(Self {
        poller,
        connections: slab::Slab::with_capacity(1024),
        generations: HashMap::new(),
        scheduler: RtmpScheduler::new(gop_limit),
        publishers: slab::Slab::with_capacity(64),
        stream_keys,
        status,
        max_connections: effective_max,
        pending_flush: HashSet::with_capacity(256),
        interest_dirty: HashSet::with_capacity(256),
        conn_ids_buffer: Vec::with_capacity(1024),
        packets_buffer: Vec::with_capacity(64),
        ids_to_close_buffer: Vec::with_capacity(16),
        results_buffer: Vec::with_capacity(16),
    })
}
/// Registers an accepted socket with the poller and returns its token.
///
/// Refuses new sockets once the connection cap is reached. The slot's
/// generation counter is bumped so events belonging to a previous occupant
/// of the same slab slot can be detected as stale.
pub fn add_connection(&mut self, socket: TcpStream) -> io::Result<ConnectionToken> {
    if self.connections.len() >= self.max_connections {
        return Err(io::Error::new(
            io::ErrorKind::ConnectionRefused,
            format!(
                "max connections limit reached ({}/{})",
                self.connections.len(),
                self.max_connections
            ),
        ));
    }
    let entry = self.connections.vacant_entry();
    let id = entry.key();
    // wrapping_add: the counter may legitimately wrap after ~4G reuses.
    let generation = self.generations.entry(id).or_insert(0);
    *generation = generation.wrapping_add(1);
    let token = ConnectionToken::new(id, *generation);
    let conn = ReactorConnection::new(token, socket)?;
    let poller_token = token.to_poller_token();
    // Start read-only; write interest is added once output is queued.
    self.poller
        .register(conn.raw_handle(), poller_token, Interest::READABLE)?;
    entry.insert(conn);
    debug!("Connection {} added (generation {})", id, token.generation);
    Ok(token)
}
/// Removes a connection, deregisters it from the poller, and notifies the
/// scheduler so its subscriptions are cleaned up. No-op for unknown ids.
pub fn remove_connection(&mut self, id: usize) {
    if let Some(conn) = self.connections.try_remove(id) {
        if let Err(e) = self.poller.deregister(conn.raw_handle()) {
            debug!("Failed to deregister connection {} from poller: {:?}", id, e);
        }
        self.scheduler.notify_connection_closed(id);
        debug!(
            "Connection {} removed (generation {})",
            id, conn.token.generation
        );
    }
}
/// Registers a publisher for `stream_key`, returning its id, or `None`
/// when the scheduler rejects the key (rejection policy — e.g. duplicate
/// keys — is decided by the scheduler, not here).
pub fn add_publisher(
    &mut self,
    stream_key: String,
    receiver: crossbeam_channel::Receiver<Vec<u8>>,
) -> Option<usize> {
    let entry = self.publishers.vacant_entry();
    let id = entry.key();
    if self.scheduler.new_channel(stream_key.clone(), id) {
        // Advertise the key to other threads via the shared set.
        self.stream_keys.insert(stream_key.clone());
        entry.insert(PublisherState {
            stream_key,
            receiver,
        });
        debug!("Publisher {} added", id);
        Some(id)
    } else {
        // Vacant entry is dropped unused; the slab slot is not consumed.
        None
    }
}
/// Drops a publisher, notifies the scheduler, and releases its stream key
/// from the shared set. No-op for unknown ids.
/// NOTE(review): the key is removed unconditionally — assumes one
/// publisher per stream key; confirm against scheduler semantics.
pub fn remove_publisher(&mut self, id: usize) {
    if let Some(pub_state) = self.publishers.try_remove(id) {
        self.scheduler.notify_publisher_closed(id);
        self.stream_keys.remove(&pub_state.stream_key);
        debug!("Publisher {} removed", id);
    }
}
/// Re-registers the connection with the poller if its desired read/write
/// interest differs from what is currently registered (avoids redundant
/// modify syscalls).
fn update_interest(&mut self, id: usize) -> io::Result<()> {
    if let Some(conn) = self.connections.get_mut(id) {
        let desired = conn.desired_interest();
        if desired != conn.current_interest {
            self.poller
                .modify(conn.raw_handle(), conn.token.to_poller_token(), desired)?;
            conn.current_interest = desired;
        }
    }
    Ok(())
}
/// Maps a raw poller token back to a live connection id, rejecting events
/// whose generation does not match the slot's current occupant (stale
/// events for a recycled slab slot).
fn validate_connection(&self, poller_token: usize) -> Option<usize> {
    let token = ConnectionToken::from_poller_token(poller_token);
    match self.connections.get(token.id) {
        Some(conn) if conn.token.generation == token.generation => Some(token.id),
        Some(conn) => {
            debug!(
                "Stale event for connection {}: expected gen {}, got {}",
                token.id, conn.token.generation, token.generation
            );
            None
        }
        None => None,
    }
}
/// Services a readable event: reads available bytes, runs them through the
/// handshake or scheduler, enqueues resulting packets, and returns any
/// disconnects produced along the way.
///
/// Shared scratch buffers are cleared up front and handed back via
/// `mem::take` so their allocations are reused across calls.
fn handle_readable(&mut self, id: usize) -> Vec<HandleResult> {
    self.results_buffer.clear();
    self.packets_buffer.clear();
    self.ids_to_close_buffer.clear();
    let (data, should_close) = match self.read_connection_data(id) {
        Some(result) => result,
        // Nothing to process; any disconnect was already recorded.
        None => return std::mem::take(&mut self.results_buffer),
    };
    self.process_connection_data(id, &data);
    self.write_pending_packets();
    for close_id in self.ids_to_close_buffer.drain(..) {
        self.results_buffer.push(HandleResult::Disconnect(close_id));
    }
    if should_close {
        // EOF after the data was processed: close this connection too.
        self.results_buffer.push(HandleResult::Disconnect(id));
    }
    std::mem::take(&mut self.results_buffer)
}
/// Reads bytes for a connection, recording a disconnect in
/// `results_buffer` on an empty EOF or a read error.
///
/// Returns `Some((data, eof))` when there are bytes to process, `None`
/// when there is nothing to do (state forbids reads, empty read, error).
fn read_connection_data(
    &mut self,
    id: usize,
) -> Option<(Vec<u8>, bool)> {
    let conn = match self.connections.get_mut(id) {
        Some(c) if c.state.can_read() => c,
        _ => return None,
    };
    match conn.try_read() {
        Ok((data, close)) => {
            if data.is_empty() {
                if close {
                    // EOF with no payload: disconnect immediately.
                    self.results_buffer.push(HandleResult::Disconnect(id));
                }
                return None;
            }
            Some((data, close))
        }
        Err(_) => {
            self.results_buffer.push(HandleResult::Disconnect(id));
            None
        }
    }
}
/// Routes freshly read bytes either through the handshake path or straight
/// to the scheduler, depending on the connection's current state.
fn process_connection_data(
    &mut self,
    id: usize,
    data: &[u8],
) {
    // Only the state is needed here, so an immutable lookup suffices
    // (previously used get_mut unnecessarily).
    let state = match self.connections.get(id) {
        Some(c) => c.state,
        None => return,
    };
    if state == ConnectionState::Handshaking {
        self.process_handshake_data(id, data);
    } else {
        self.process_normal_data(id, data);
    }
}
/// Runs handshake bytes for a connection: queues any handshake response
/// and forwards post-handshake leftovers to the scheduler.
fn process_handshake_data(
    &mut self,
    id: usize,
    data: &[u8],
) {
    let conn = match self.connections.get_mut(id) {
        Some(c) => c,
        None => return,
    };
    let (remaining, response, completed, error) = conn.process_handshake(data);
    if error {
        self.results_buffer.push(HandleResult::Disconnect(id));
        return;
    }
    if let Some(resp) = response {
        if !conn.enqueue_raw(resp) {
            self.results_buffer.push(HandleResult::Disconnect(id));
            return;
        }
        // Response queued: flush it and pick up write interest this tick.
        self.pending_flush.insert(id);
        self.interest_dirty.insert(id);
    }
    if completed {
        debug!("Connection {} handshake completed", id);
    }
    if let Some(remaining_data) = remaining {
        // Defensive re-check; process_handshake only returns Some for
        // non-empty leftovers.
        if !remaining_data.is_empty() {
            self.process_scheduler_results(id, &remaining_data);
        }
    }
}
/// Post-handshake path: session bytes go straight to the scheduler.
fn process_normal_data(
    &mut self,
    id: usize,
    data: &[u8],
) {
    self.process_scheduler_results(id, data);
}
/// Feeds session bytes to the scheduler and records its directives:
/// outbound packets go into `packets_buffer`, requested disconnects into
/// `ids_to_close_buffer`. A scheduler error disconnects the source.
fn process_scheduler_results(
    &mut self,
    id: usize,
    data: &[u8],
) {
    match self.scheduler.bytes_received(id, data) {
        Ok(server_results) => {
            for result in server_results {
                match result {
                    ServerResult::OutboundPacket {
                        target_connection_id,
                        packet,
                        is_keyframe,
                        is_sequence_header,
                        is_video,
                    } => {
                        self.packets_buffer.push((
                            target_connection_id,
                            packet.bytes,
                            is_keyframe,
                            is_sequence_header,
                            is_video,
                        ));
                    }
                    ServerResult::DisconnectConnection {
                        connection_id: close_id,
                    } => {
                        self.ids_to_close_buffer.push(close_id);
                    }
                }
            }
        }
        Err(e) => {
            debug!("Connection {} scheduler error: {}", id, e);
            self.results_buffer.push(HandleResult::Disconnect(id));
        }
    }
}
/// Drains `packets_buffer` into the target connections' write queues.
///
/// Successful ids are collected first and pushed into the bookkeeping sets
/// after the drain loop, keeping the simultaneous field borrows disjoint.
fn write_pending_packets(&mut self) {
    let mut enqueued_ids = Vec::new();
    for (target_id, data, is_keyframe, is_sequence_header, is_video) in self.packets_buffer.drain(..) {
        if let Some(target_conn) = self.connections.get_mut(target_id) {
            let enqueued =
                target_conn.enqueue_data(Bytes::from(data), is_keyframe, is_sequence_header, is_video);
            if enqueued {
                enqueued_ids.push(target_id);
            } else {
                // Queue refused the packet (critical backpressure): close.
                self.ids_to_close_buffer.push(target_id);
            }
        }
    }
    for id in enqueued_ids {
        self.pending_flush.insert(id);
        self.interest_dirty.insert(id);
    }
}
/// Services a writable event by flushing the connection's queue.
///
/// Returns a disconnect when the peer closed or the write failed; marks
/// interest dirty once the queue drains so write interest can be dropped.
fn handle_writable(&mut self, id: usize) -> Option<HandleResult> {
    let conn = match self.connections.get_mut(id) {
        Some(c) if c.state.can_write() => c,
        _ => return None,
    };
    match conn.try_flush() {
        // try_flush() == Ok(true) means the peer closed the socket.
        Ok(true) => Some(HandleResult::Disconnect(id)),
        Ok(false) => {
            if !conn.has_pending_writes() {
                self.interest_dirty.insert(id);
            }
            None
        }
        Err(_) => Some(HandleResult::Disconnect(id)),
    }
}
/// Drains every publisher channel, feeds received bytes through the
/// scheduler, and fans the resulting packets out to subscriber queues.
///
/// Returns the publishers that ended (channel disconnected or scheduler
/// error) so the caller can remove them.
fn process_publishers(&mut self) -> Vec<usize> {
    let mut publisher_ids_to_remove = Vec::new();
    let mut packets_to_write = Vec::new();
    let mut ids_to_close = Vec::new();
    // Snapshot ids first: `self` is mutated inside the loop below.
    let publisher_ids: Vec<usize> = self.publishers.iter().map(|(id, _)| id).collect();
    for pub_id in publisher_ids {
        // Clone the channel handle (cheap, refcounted) so the borrow of
        // self.publishers does not outlive this expression.
        let receiver = {
            let pub_state = match self.publishers.get(pub_id) {
                Some(p) => p,
                None => continue,
            };
            pub_state.receiver.clone()
        };
        loop {
            match receiver.try_recv() {
                Ok(bytes) => {
                    match self.scheduler.publish_bytes_received(pub_id, bytes) {
                        Ok(server_results) => {
                            for result in server_results {
                                match result {
                                    ServerResult::OutboundPacket {
                                        target_connection_id,
                                        packet,
                                        is_keyframe,
                                        is_sequence_header,
                                        is_video,
                                    } => {
                                        packets_to_write.push((
                                            target_connection_id,
                                            packet.bytes,
                                            is_keyframe,
                                            is_sequence_header,
                                            is_video,
                                        ));
                                    }
                                    ServerResult::DisconnectConnection {
                                        connection_id: close_id,
                                    } => {
                                        ids_to_close.push(close_id);
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            debug!("Publisher {} scheduler error: {}", pub_id, e);
                            publisher_ids_to_remove.push(pub_id);
                            break;
                        }
                    }
                }
                Err(crossbeam_channel::TryRecvError::Empty) => break,
                Err(crossbeam_channel::TryRecvError::Disconnected) => {
                    // Ingest side went away: tell subscribers the stream ended.
                    debug!("Publisher {} disconnected", pub_id);
                    self.send_delete_stream(pub_id, &mut packets_to_write, &mut ids_to_close);
                    publisher_ids_to_remove.push(pub_id);
                    break;
                }
            }
        }
    }
    // Enqueue fan-out packets; failures mark the target for closure.
    let mut enqueued_ids = Vec::new();
    for (target_id, data, is_keyframe, is_sequence_header, is_video) in packets_to_write {
        if let Some(target_conn) = self.connections.get_mut(target_id) {
            let enqueued =
                target_conn.enqueue_data(Bytes::from(data), is_keyframe, is_sequence_header, is_video);
            if enqueued {
                enqueued_ids.push(target_id);
            } else {
                ids_to_close.push(target_id);
            }
        }
    }
    for id in enqueued_ids {
        self.pending_flush.insert(id);
        self.interest_dirty.insert(id);
    }
    for close_id in ids_to_close {
        self.remove_connection(close_id);
    }
    publisher_ids_to_remove
}
/// Fabricates a `deleteStream` RTMP command on behalf of a publisher whose
/// channel disconnected, so subscribers learn that the stream ended.
///
/// Scheduler-produced packets are appended to `packets`; connections the
/// scheduler wants closed are appended to `ids_to_close`. Serialization
/// failures are silently ignored (best effort); scheduler errors are logged.
fn send_delete_stream(
    &mut self,
    pub_id: usize,
    packets: &mut Vec<(usize, Vec<u8>, bool, bool, bool)>,
    ids_to_close: &mut Vec<usize>,
) {
    let delete_stream_cmd = RtmpMessage::Amf0Command {
        command_name: "deleteStream".to_string(),
        transaction_id: 4.0,
        command_object: Amf0Value::Null,
        // vec! replaces the previous Vec::new() + push.
        additional_arguments: vec![Amf0Value::Number(1.0)],
    }
    .into_message_payload(RtmpTimestamp { value: 0 }, 1);
    if let Ok(payload) = delete_stream_cmd {
        let mut serializer = ChunkSerializer::new();
        if let Ok(packet) = serializer.serialize(&payload, false, true) {
            // Feed the synthetic command through the normal publish path so
            // the scheduler fans it out to every subscriber.
            match self.scheduler.publish_bytes_received(pub_id, packet.bytes) {
                Ok(server_results) => {
                    for result in server_results {
                        match result {
                            ServerResult::OutboundPacket {
                                target_connection_id,
                                packet,
                                is_keyframe,
                                is_sequence_header,
                                is_video,
                            } => {
                                packets.push((
                                    target_connection_id,
                                    packet.bytes,
                                    is_keyframe,
                                    is_sequence_header,
                                    is_video,
                                ));
                            }
                            ServerResult::DisconnectConnection {
                                connection_id: close_id,
                            } => {
                                ids_to_close.push(close_id);
                            }
                        }
                    }
                }
                Err(e) => {
                    log::warn!(
                        "Failed to process deleteStream command for publisher {}: {:?}",
                        pub_id, e
                    );
                }
            }
        }
    }
}
/// Flushes every connection queued in `pending_flush`; returns ids whose
/// sockets closed or errored during the flush.
///
/// Connections still holding data are re-inserted into `pending_flush`;
/// fully drained ones get their poller interest recomputed instead.
fn flush_pending(&mut self) -> Vec<usize> {
    let mut ids_to_close = Vec::new();
    // Drain into a Vec so pending_flush can be re-populated in the loop.
    let pending_ids: Vec<usize> = self.pending_flush.drain().collect();
    for id in pending_ids {
        if let Some(conn) = self.connections.get_mut(id) {
            if conn.has_pending_writes() {
                match conn.try_flush() {
                    // Ok(true) = peer closed; Err = write failure.
                    Ok(true) | Err(_) => {
                        ids_to_close.push(id);
                    }
                    Ok(false) => {
                        if conn.has_pending_writes() {
                            self.pending_flush.insert(id);
                        } else {
                            self.interest_dirty.insert(id);
                        }
                    }
                }
            } else {
                self.interest_dirty.insert(id);
            }
        }
    }
    ids_to_close
}
/// Collects the ids of connections idle longer than the configured
/// CONNECTION_TIMEOUT_SECS, logging each one at debug level.
fn check_timeouts(&mut self) -> Vec<usize> {
    let timeout = Duration::from_secs(CONNECTION_TIMEOUT_SECS);
    self.connections
        .iter()
        .filter(|(_, conn)| conn.is_timed_out(timeout))
        .map(|(id, _)| {
            debug!("Connection {} timed out", id);
            id
        })
        .collect()
}
/// Applies any deferred poller-interest changes recorded during this tick.
fn update_dirty_interests(&mut self) {
    // Drain into a Vec first: update_interest needs &mut self.
    let dirty: Vec<usize> = self.interest_dirty.drain().collect();
    for id in dirty {
        match self.update_interest(id) {
            Ok(()) => {}
            Err(e) => {
                log::warn!("Failed to update interest for connection {}: {:?}", id, e);
            }
        }
    }
}
/// Main event loop: accepts new connections and publishers, polls for
/// socket events, pumps publisher data, flushes write queues, updates
/// poller interests, and reaps timed-out connections — until `status`
/// becomes STATUS_END, after which pending writes are drained via
/// `graceful_shutdown`.
pub fn run(
    &mut self,
    connection_receiver: crossbeam_channel::Receiver<TcpStream>,
    publisher_receiver: crossbeam_channel::Receiver<(
        String,
        crossbeam_channel::Receiver<Vec<u8>>,
    )>,
) {
    info!("Reactor started");
    let poll_timeout = Duration::from_millis(POLL_TIMEOUT_MS);
    loop {
        if self.status.load(Ordering::Acquire) == STATUS_END {
            info!("Reactor received stop signal");
            break;
        }
        // Drain sockets handed over by the acceptor side.
        while let Ok(socket) = connection_receiver.try_recv() {
            match self.add_connection(socket) {
                Ok(token) => {
                    debug!("New connection added: {:?}", token);
                }
                Err(e) => {
                    error!("Failed to add connection: {:?}", e);
                }
            }
        }
        // Drain newly announced publishers.
        while let Ok((stream_key, receiver)) = publisher_receiver.try_recv() {
            if self.add_publisher(stream_key.clone(), receiver).is_some() {
                debug!("New publisher added for stream: {}", stream_key);
            }
        }
        let events = match self.poller.poll(Some(poll_timeout)) {
            Ok(events) => events,
            Err(e) => {
                error!("Poller error: {:?}", e);
                continue;
            }
        };
        let mut ids_to_close = Vec::new();
        for event in events {
            let poller_token = event.token;
            // Generation check filters events for recycled slab slots.
            let Some(id) = self.validate_connection(poller_token) else {
                continue;
            };
            if event.is_error() || event.is_hangup() {
                ids_to_close.push(id);
                continue;
            }
            if event.is_readable() {
                let results = self.handle_readable(id);
                for result in results {
                    let HandleResult::Disconnect(close_id) = result;
                    ids_to_close.push(close_id);
                }
            }
            if event.is_writable() {
                if let Some(HandleResult::Disconnect(close_id)) = self.handle_writable(id) {
                    ids_to_close.push(close_id);
                }
            }
        }
        let publisher_ids_to_remove = self.process_publishers();
        for pub_id in publisher_ids_to_remove {
            self.remove_publisher(pub_id);
        }
        let flush_closes = self.flush_pending();
        ids_to_close.extend(flush_closes);
        self.update_dirty_interests();
        let timed_out = self.check_timeouts();
        ids_to_close.extend(timed_out);
        // Dedup so an id flagged by multiple paths is removed only once.
        ids_to_close.sort_unstable();
        ids_to_close.dedup();
        for id in ids_to_close {
            self.remove_connection(id);
        }
    }
    self.graceful_shutdown();
    info!("Reactor stopped");
}
/// Best-effort drain of all pending writes before closing every socket,
/// bounded by GRACEFUL_SHUTDOWN_TIMEOUT_SECS.
fn graceful_shutdown(&mut self) {
    info!("Starting graceful shutdown...");
    let deadline = Instant::now() + Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
    for (_, conn) in self.connections.iter_mut() {
        conn.mark_closing();
    }
    // Retry flushing until everything drained or the deadline passes.
    while Instant::now() < deadline {
        let mut all_flushed = true;
        for (_, conn) in self.connections.iter_mut() {
            if conn.has_pending_writes() {
                all_flushed = false;
                if let Err(e) = conn.try_flush() {
                    debug!("Failed to flush connection during shutdown: {:?}", e);
                }
            }
        }
        if all_flushed {
            break;
        }
        std::thread::sleep(Duration::from_millis(10));
    }
    for (_, conn) in self.connections.iter_mut() {
        conn.shutdown();
    }
    info!("Graceful shutdown complete");
}
/// Test-only accessor: whether `id` is queued for an interest update.
#[cfg(test)]
pub fn is_interest_dirty(&self, id: usize) -> bool {
    self.interest_dirty.contains(&id)
}
/// Test-only accessor: drains and returns the dirty-interest set.
#[cfg(test)]
pub fn drain_interest_dirty(&mut self) -> Vec<usize> {
    self.interest_dirty.drain().collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// State-machine truth table for can_read / can_write / is_active.
#[test]
fn test_connection_state_transitions() {
    assert!(ConnectionState::Handshaking.can_read());
    assert!(ConnectionState::Handshaking.can_write());
    assert!(!ConnectionState::Handshaking.is_active());
    assert!(ConnectionState::Active.can_read());
    assert!(ConnectionState::Active.can_write());
    assert!(ConnectionState::Active.is_active());
    assert!(ConnectionState::SlowClient.is_active());
    assert!(!ConnectionState::Closing.can_read());
    assert!(ConnectionState::Closing.can_write());
    assert!(!ConnectionState::Closed.can_read());
    assert!(!ConnectionState::Closed.can_write());
}
/// A fresh connection with nothing queued wants read interest only.
#[test]
fn test_interest_desired() {
    use std::net::TcpListener;
    let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client = TcpStream::connect(addr).expect("Failed to connect");
    let token = ConnectionToken::new(0, 1);
    let conn = ReactorConnection::new(token, client).expect("Failed to create connection");
    assert_eq!(conn.desired_interest(), Interest::READABLE);
}
/// Enqueued data should reach the peer after a flush (best-effort check —
/// the peek tolerates a socket that has not yet received the bytes).
#[test]
fn test_graceful_shutdown_flushes_data() {
    use std::net::TcpListener;
    let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client = TcpStream::connect(addr).expect("Failed to connect");
    let (server_socket, _) = listener.accept().expect("Failed to accept");
    let token = ConnectionToken::new(0, 1);
    let mut conn = ReactorConnection::new(token, server_socket).expect("Failed to create connection");
    conn.state = ConnectionState::Active;
    let test_data = b"Hello, World!";
    conn.enqueue_data(Bytes::from_static(test_data), false, false, false);
    assert!(conn.has_pending_writes());
    let _ = conn.try_flush();
    client.set_nonblocking(false).expect("Failed to set blocking");
    let mut buf = vec![0u8; 100];
    use std::time::Duration;
    client.set_read_timeout(Some(Duration::from_millis(100))).expect("Failed to set timeout");
    match client.peek(&mut buf) {
        Ok(n) if n > 0 => {
            assert!(n >= test_data.len());
        }
        _ => {
            // Data may not have arrived within the timeout; tolerated.
        }
    }
}
/// A new connection is not timed out at 60s but is at 1ns.
#[test]
fn test_connection_timeout_detection() {
    use std::net::TcpListener;
    let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client = TcpStream::connect(addr).expect("Failed to connect");
    let token = ConnectionToken::new(0, 1);
    let conn = ReactorConnection::new(token, client).expect("Failed to create connection");
    assert!(!conn.is_timed_out(Duration::from_secs(60)));
    assert!(conn.is_timed_out(Duration::from_nanos(1)));
}
/// Reactor construction with default limits succeeds.
#[test]
fn test_reactor_creation() {
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let reactor = Reactor::new(3, None, stream_keys, status);
    assert!(reactor.is_ok());
}
/// Reusing a slab slot bumps the generation so tokens stay distinct.
#[test]
fn test_connection_generation_increments() {
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client1 = TcpStream::connect(addr).expect("Failed to connect");
    let (server1, _) = listener.accept().expect("Failed to accept");
    let token1 = reactor.add_connection(server1).expect("Failed to add connection");
    reactor.remove_connection(token1.id);
    let client2 = TcpStream::connect(addr).expect("Failed to connect");
    let (server2, _) = listener.accept().expect("Failed to accept");
    let token2 = reactor.add_connection(server2).expect("Failed to add connection");
    // Same slot is reused, but with a freshly incremented generation.
    assert_eq!(token1.id, token2.id);
    assert_eq!(token2.generation, token1.generation + 1);
    drop(client1);
    drop(client2);
}
/// Tokens validate while their connection lives and fail after removal.
#[test]
fn test_token_validation() {
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client = TcpStream::connect(addr).expect("Failed to connect");
    let (server, _) = listener.accept().expect("Failed to accept");
    let token = reactor.add_connection(server).expect("Failed to add connection");
    assert!(reactor.validate_connection(token.to_poller_token()).is_some());
    reactor.remove_connection(token.id);
    assert!(reactor.validate_connection(token.to_poller_token()).is_none());
    drop(client);
}
/// A stale token for a recycled slot must be rejected (ABA protection);
/// 64-bit only since 32-bit tokens cannot carry the generation.
#[test]
#[cfg(target_pointer_width = "64")]
fn test_generation_prevents_aba_problem() {
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let client_a = TcpStream::connect(addr).expect("Failed to connect A");
    let (server_a, _) = listener.accept().expect("Failed to accept A");
    let token_a = reactor.add_connection(server_a).expect("Failed to add connection A");
    let stale_poller_token = token_a.to_poller_token();
    reactor.remove_connection(token_a.id);
    drop(client_a);
    let client_b = TcpStream::connect(addr).expect("Failed to connect B");
    let (server_b, _) = listener.accept().expect("Failed to accept B");
    let token_b = reactor.add_connection(server_b).expect("Failed to add connection B");
    assert!(reactor.validate_connection(token_b.to_poller_token()).is_some());
    assert!(reactor.validate_connection(stale_poller_token).is_none());
    // Same slot, different generation.
    assert_eq!(token_a.id, token_b.id);
    assert_ne!(token_a.generation, token_b.generation);
    reactor.remove_connection(token_b.id);
    drop(client_b);
}
/// Adding and removing many connections keeps the slab count consistent.
#[test]
fn test_many_connections_creation() {
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let num_connections = 100;
    let mut clients = Vec::new();
    let mut tokens = Vec::new();
    for i in 0..num_connections {
        // unwrap_or_else avoids the unconditional format! allocation that
        // expect(&format!(..)) performed even on success (clippy:
        // expect_fun_call), matching the style of the perf tests below.
        let client = TcpStream::connect(addr).unwrap_or_else(|_| panic!("Failed to connect {}", i));
        let (server, _) = listener.accept().unwrap_or_else(|_| panic!("Failed to accept {}", i));
        let token = reactor
            .add_connection(server)
            .unwrap_or_else(|_| panic!("Failed to add connection {}", i));
        clients.push(client);
        tokens.push(token);
    }
    assert_eq!(reactor.connections.len(), num_connections);
    for token in &tokens {
        reactor.remove_connection(token.id);
    }
    assert_eq!(reactor.connections.len(), 0);
}
/// Perf measurement (run with --ignored): connection setup/teardown rates.
#[test]
#[ignore]
fn perf_connection_scaling() {
    use std::time::Instant;
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    // Stay well under the fd limit: each loop turn consumes ~3 descriptors.
    let max_fd = effective_max_connections(None);
    let num_connections = (max_fd / 3).min(1000);
    let mut clients = Vec::with_capacity(num_connections);
    let mut tokens = Vec::with_capacity(num_connections);
    let start = Instant::now();
    for i in 0..num_connections {
        let client = TcpStream::connect(addr).unwrap_or_else(|_| panic!("Failed to connect {}", i));
        let (server, _) = listener.accept().unwrap_or_else(|_| panic!("Failed to accept {}", i));
        let token = reactor.add_connection(server).unwrap_or_else(|_| panic!("Failed to add {}", i));
        clients.push(client);
        tokens.push(token);
    }
    let connect_time = start.elapsed();
    assert_eq!(reactor.connections.len(), num_connections);
    let cleanup_start = Instant::now();
    for token in &tokens {
        reactor.remove_connection(token.id);
    }
    let cleanup_time = cleanup_start.elapsed();
    println!();
    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║ RTMP Performance Test: Connection Scaling ║");
    println!("╠══════════════════════════════════════════════════════════╣");
    println!("║ Platform: {:>40} ║", std::env::consts::OS);
    println!("║ Arch: {:>40} ║", std::env::consts::ARCH);
    println!("║ Connections: {:>40} ║", num_connections);
    println!("╠══════════════════════════════════════════════════════════╣");
    println!("║ Connect time: {:>37?} ║", connect_time);
    println!("║ Per connection: {:>37?} ║", connect_time / num_connections as u32);
    println!("║ Cleanup time: {:>37?} ║", cleanup_time);
    println!("║ Per cleanup: {:>37?} ║", cleanup_time / num_connections as u32);
    println!("╚══════════════════════════════════════════════════════════╝");
    println!();
}
/// Perf measurement (run with --ignored): try_read throughput across
/// several chunk sizes.
#[test]
#[ignore]
fn perf_read_throughput() {
    use std::time::Instant;
    use std::io::Write;
    let stream_keys = dashmap::DashSet::new();
    let status = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, stream_keys, status).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let mut client = TcpStream::connect(addr).expect("Failed to connect");
    let (server, _) = listener.accept().expect("Failed to accept");
    client.set_nodelay(true).ok();
    let token = reactor.add_connection(server).expect("Failed to add connection");
    let test_sizes = [128, 1024, 4096, 8192, 16384, 65536];
    let iterations = 100;
    println!();
    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║ RTMP Performance Test: Read Throughput ║");
    println!("╠══════════════════════════════════════════════════════════╣");
    println!("║ Platform: {:>40} ║", std::env::consts::OS);
    println!("║ Arch: {:>40} ║", std::env::consts::ARCH);
    println!("║ Iterations: {:>40} ║", iterations);
    println!("╠══════════════════════════════════════════════════════════╣");
    for &size in &test_sizes {
        let data = vec![0xABu8; size];
        let mut total_bytes = 0usize;
        let start = Instant::now();
        for _ in 0..iterations {
            client.write_all(&data).expect("Failed to write");
            client.flush().expect("Failed to flush");
            total_bytes += size;
            // Give the kernel a moment to make the bytes readable.
            std::thread::sleep(std::time::Duration::from_micros(100));
            if let Some(conn) = reactor.connections.get_mut(token.id) {
                let _ = conn.try_read();
            }
        }
        let elapsed = start.elapsed();
        let throughput_mbps = (total_bytes as f64 / 1_000_000.0) / elapsed.as_secs_f64();
        println!("║ Chunk {:>6} B: {:>8.2} MB/s ({:>6} B x {:>3}) ║",
            size, throughput_mbps, size, iterations);
    }
    println!("╚══════════════════════════════════════════════════════════╝");
    println!();
    reactor.remove_connection(token.id);
}
#[test]
fn test_handle_writable_marks_interest_dirty_on_queue_drain() {
    use std::io::Read;

    // Wire a reactor to one freshly-accepted loopback connection.
    let keys = dashmap::DashSet::new();
    let run_flag = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, keys, run_flag).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let mut peer = TcpStream::connect(addr).expect("Failed to connect");
    let (accepted, _) = listener.accept().expect("Failed to accept");
    peer.set_nonblocking(true).ok();
    let tok = reactor.add_connection(accepted).expect("Failed to add connection");

    if let Some(c) = reactor.connections.get_mut(tok.id) {
        c.state = ConnectionState::Active;
    }

    // Queue a small payload so the connection has something to flush.
    let payload = b"Hello";
    if let Some(c) = reactor.connections.get_mut(tok.id) {
        c.enqueue_data(Bytes::from_static(payload), false, false, false);
        assert!(c.has_pending_writes());
    }

    // Clear any dirtiness left over from setup, then drive the writable path.
    reactor.drain_interest_dirty();
    let outcome = reactor.handle_writable(tok.id);
    assert!(outcome.is_none(), "Connection should not be closed");

    if let Some(c) = reactor.connections.get(tok.id) {
        assert!(!c.has_pending_writes(), "Queue should be drained");
    }
    assert!(
        reactor.is_interest_dirty(tok.id),
        "interest_dirty should contain connection ID after queue drain"
    );

    // Drain whatever the reactor wrote so teardown is clean.
    let mut scratch = vec![0u8; 100];
    peer.set_nonblocking(false).ok();
    peer.set_read_timeout(Some(std::time::Duration::from_millis(100))).ok();
    let _ = peer.read(&mut scratch);
    reactor.remove_connection(tok.id);
}
#[test]
fn test_flush_pending_marks_interest_dirty_on_queue_drain() {
    use std::io::Read;

    // Reactor plus one accepted loopback connection.
    let keys = dashmap::DashSet::new();
    let run_flag = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, keys, run_flag).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let mut peer = TcpStream::connect(addr).expect("Failed to connect");
    let (accepted, _) = listener.accept().expect("Failed to accept");
    peer.set_nonblocking(true).ok();
    let tok = reactor.add_connection(accepted).expect("Failed to add connection");

    if let Some(c) = reactor.connections.get_mut(tok.id) {
        c.state = ConnectionState::Active;
    }

    // Queue data, then schedule the connection for a deferred flush.
    let payload = b"World";
    if let Some(c) = reactor.connections.get_mut(tok.id) {
        c.enqueue_data(Bytes::from_static(payload), false, false, false);
    }
    reactor.pending_flush.insert(tok.id);

    // Reset dirtiness, run the flush pass, and verify the connection was
    // re-marked so its poller interest gets refreshed.
    reactor.drain_interest_dirty();
    let ids_to_close = reactor.flush_pending();
    assert!(ids_to_close.is_empty(), "No connections should need closing");
    assert!(
        reactor.is_interest_dirty(tok.id),
        "interest_dirty should contain connection ID after flush_pending drains queue"
    );

    // Consume the flushed bytes so teardown is clean.
    let mut scratch = vec![0u8; 100];
    peer.set_nonblocking(false).ok();
    peer.set_read_timeout(Some(std::time::Duration::from_millis(100))).ok();
    let _ = peer.read(&mut scratch);
    reactor.remove_connection(tok.id);
}
#[test]
fn test_flush_pending_marks_interest_dirty_when_no_pending_writes() {
    // Reactor plus one accepted loopback connection; the client side is kept
    // alive but never touched.
    let keys = dashmap::DashSet::new();
    let run_flag = Arc::new(AtomicUsize::new(STATUS_RUN));
    let mut reactor = Reactor::new(3, None, keys, run_flag).expect("Failed to create reactor");
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let addr = listener.local_addr().expect("Failed to get address");
    let _client = TcpStream::connect(addr).expect("Failed to connect");
    let (accepted, _) = listener.accept().expect("Failed to accept");
    let tok = reactor.add_connection(accepted).expect("Failed to add connection");

    // Activate the connection but deliberately queue nothing.
    if let Some(c) = reactor.connections.get_mut(tok.id) {
        c.state = ConnectionState::Active;
        assert!(!c.has_pending_writes());
    }
    reactor.pending_flush.insert(tok.id);

    // Even with an empty queue, the flush pass must mark the connection dirty
    // so the poller drops its WRITABLE interest.
    reactor.drain_interest_dirty();
    let ids_to_close = reactor.flush_pending();
    assert!(ids_to_close.is_empty(), "No connections should need closing");
    assert!(
        reactor.is_interest_dirty(tok.id),
        "interest_dirty should be marked even when no pending writes (to clear WRITABLE interest)"
    );
    reactor.remove_connection(tok.id);
}
}