use std::collections::VecDeque;
use std::io;
use std::net::SocketAddr;
use std::os::fd::RawFd;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use crate::accumulator::AccumulatorTable;
use crate::buffer::send_copy::SendCopyPool;
use crate::config::Config;
use crate::connection::{ConnectionTable, RecvMode};
use crate::disk_io_pool::DiskIoPool;
use crate::handler::{ConnSendState, DriverCtx};
use mio::Interest;
/// mio token 0. Connection streams are registered with `mio::Token(idx + 1)`, so this token
/// never collides with a connection. Presumably used for the wake-up eventfd (its
/// registration is not shown here) — TODO confirm.
pub(crate) const WAKE_TOKEN: mio::Token = mio::Token(0);
/// A queued outgoing buffer plus the number of its bytes already written; the unsent
/// remainder is `buf[offset..]`.
pub(crate) type PendingSend = (Vec<u8>, usize);
/// Readiness-based (mio) I/O driver state.
///
/// Owns the `mio::Poll` instance, per-connection state vectors indexed by connection index
/// (all sized to `config.max_connections`), bound UDP sockets, and the response channels /
/// pool handles for the resolver, spawner, blocking and disk-I/O pools.
///
/// Field declaration order is also drop order in Rust; do not reorder fields casually.
pub(crate) struct Driver {
/// Connection slot table; `close_connection` marks slots `RecvMode::Closed` and releases them.
pub(crate) connections: ConnectionTable,
/// Per-connection accumulators created with `config.recv_buffer.buffer_size`
/// (presumably for partially received data — TODO confirm).
pub(crate) accumulators: AccumulatorTable,
/// Pool sized from `config.send_copy_count` / `config.send_copy_slot_size`.
pub(crate) send_copy_pool: SendCopyPool,
/// Per-connection send state; on close its `queue` is cleared and `in_flight` reset.
pub(crate) send_queues: Vec<ConnSendState>,
/// Receiver of `(fd, peer address)` pairs (presumably accepted sockets from an acceptor).
pub(crate) accept_rx: Option<crossbeam_channel::Receiver<(RawFd, SocketAddr)>>,
/// Wake handle built from the eventfd passed to `new`; copied into every `DriverCtx`.
pub(crate) wake_handle: crate::wakeup::WakeHandle,
/// Shared shutdown flag (set/read outside this section).
pub(crate) shutdown_flag: Arc<AtomicBool>,
/// Local shutdown request; exposed to handlers as `DriverCtx::shutdown_requested`.
pub(crate) shutdown_local: bool,
/// TLS state; `Some` only when a server or client TLS config is present.
pub(crate) tls_table: Option<crate::tls::TlsTable>,
/// Per-connection socket addresses, zero-initialized at construction.
pub(crate) connect_addrs: Vec<libc::sockaddr_storage>,
/// The readiness poller all sockets are registered with.
pub(crate) poll: mio::Poll,
/// Event buffer (capacity 1024).
pub(crate) events: mio::Events,
/// Resolver response channel pair and pool handle.
pub(crate) resolve_rx: Option<crossbeam_channel::Receiver<crate::resolver::ResolveResponse>>,
pub(crate) resolve_tx: Option<crossbeam_channel::Sender<crate::resolver::ResolveResponse>>,
pub(crate) resolver: Option<Arc<crate::resolver::ResolverPool>>,
/// Spawner response channel pair and pool handle.
pub(crate) spawn_rx: Option<crossbeam_channel::Receiver<crate::spawner::SpawnResponse>>,
pub(crate) spawn_tx: Option<crossbeam_channel::Sender<crate::spawner::SpawnResponse>>,
pub(crate) spawner: Option<Arc<crate::spawner::SpawnerPool>>,
/// Blocking-pool response channel pair and pool handle.
pub(crate) blocking_rx: Option<crossbeam_channel::Receiver<crate::blocking::BlockingResponse>>,
pub(crate) blocking_tx: Option<crossbeam_channel::Sender<crate::blocking::BlockingResponse>>,
pub(crate) blocking_pool: Option<Arc<crate::blocking::BlockingPool>>,
/// Per-connection TCP streams, registered with `mio::Token(idx + 1)`.
pub(crate) tcp_streams: Vec<Option<mio::net::TcpStream>>,
/// Per-connection queues of not-yet-written buffers, flushed with `writev` in `flush_sends`.
pub(crate) pending_sends: Vec<VecDeque<PendingSend>>,
/// Per-connection writability; cleared on `WouldBlock` and on close
/// (presumably set again on a writable event elsewhere — TODO confirm).
pub(crate) writable: Vec<bool>,
/// Per-connection optional deadline; cleared on close.
pub(crate) connect_deadlines: Vec<Option<std::time::Instant>>,
/// 16 KiB scratch buffer (usage not visible in this section).
pub(crate) tls_scratch: Vec<u8>,
/// Raw eventfd passed to `new` (the same fd backs `wake_handle`).
pub(crate) wake_pipe_fd: RawFd,
/// `config.tcp_nodelay`, forwarded to `DriverCtx`.
pub(crate) tcp_nodelay: bool,
/// Per-connection queue of `u32` values; cleared on close (meaning not visible here).
pub(crate) send_completions: Vec<VecDeque<u32>>,
/// UDP sockets bound from `config.udp_bind`, registered READABLE at `udp_token_base + i`.
pub(crate) udp_sockets: Vec<mio::net::UdpSocket>,
/// First UDP token: `max_connections + 2`, above every connection token.
pub(crate) udp_token_base: usize,
/// Disk-I/O response channel pair and pool handle.
pub(crate) disk_io_rx: Option<crossbeam_channel::Receiver<crate::disk_io_pool::DiskIoResponse>>,
pub(crate) disk_io_tx: Option<crossbeam_channel::Sender<crate::disk_io_pool::DiskIoResponse>>,
pub(crate) disk_io_pool: Option<Arc<DiskIoPool>>,
/// Counter starting at 0, lent mutably to handlers (presumably disk-I/O request ids).
pub(crate) next_disk_io_seq: u32,
/// Direct-I/O file table and fd slots; present / non-empty only when `config.direct_io` is set.
pub(crate) direct_io_files: Option<crate::direct_io::DirectIoFileTable>,
pub(crate) direct_io_fds: Vec<Option<RawFd>>,
/// Filesystem file table and fd slots; present / non-empty only when `config.fs` is set.
pub(crate) fs_files: Option<crate::fs::FsFileTable>,
pub(crate) fs_fds: Vec<Option<RawFd>>,
/// Map of `u32` -> `u16` for in-progress fs opens (key/value meaning not visible here).
pub(crate) pending_fs_opens: std::collections::HashMap<u32, u16>,
}
impl Driver {
    /// Creates a driver sized for `config.max_connections` connection slots.
    ///
    /// Creates the `mio::Poll` (room for 1024 events per poll), builds the TLS table only when
    /// a server or client TLS config is present, and binds every address in `config.udp_bind`
    /// as a non-blocking UDP socket registered for READABLE at token `udp_token_base + i`.
    /// `eventfd` is kept both as the driver's `WakeHandle` and as `wake_pipe_fd`.
    ///
    /// # Errors
    ///
    /// Returns any error from creating the poll instance, binding a UDP address (the message
    /// includes the address), switching it to non-blocking mode, or registering it.
    // Every channel and pool handle is injected by the caller, hence the long parameter list.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        config: &Config,
        accept_rx: Option<crossbeam_channel::Receiver<(RawFd, SocketAddr)>>,
        eventfd: RawFd,
        shutdown_flag: Arc<AtomicBool>,
        resolve_rx: Option<crossbeam_channel::Receiver<crate::resolver::ResolveResponse>>,
        resolve_tx: Option<crossbeam_channel::Sender<crate::resolver::ResolveResponse>>,
        resolver: Option<Arc<crate::resolver::ResolverPool>>,
        spawn_rx: Option<crossbeam_channel::Receiver<crate::spawner::SpawnResponse>>,
        spawn_tx: Option<crossbeam_channel::Sender<crate::spawner::SpawnResponse>>,
        spawner: Option<Arc<crate::spawner::SpawnerPool>>,
        blocking_rx: Option<crossbeam_channel::Receiver<crate::blocking::BlockingResponse>>,
        blocking_tx: Option<crossbeam_channel::Sender<crate::blocking::BlockingResponse>>,
        blocking_pool: Option<Arc<crate::blocking::BlockingPool>>,
        disk_io_rx: Option<crossbeam_channel::Receiver<crate::disk_io_pool::DiskIoResponse>>,
        disk_io_tx: Option<crossbeam_channel::Sender<crate::disk_io_pool::DiskIoResponse>>,
        disk_io_pool: Option<Arc<DiskIoPool>>,
    ) -> io::Result<Self> {
        let max_conn = config.max_connections as usize;
        let poll = mio::Poll::new()?;
        let events = mio::Events::with_capacity(1024);
        let tls_table = {
            let server_config = config.tls.as_ref().map(|t| t.server_config.clone());
            let client_config = config.tls_client.as_ref().map(|t| t.client_config.clone());
            if server_config.is_some() || client_config.is_some() {
                Some(crate::tls::TlsTable::new(
                    config.max_connections,
                    server_config,
                    client_config,
                ))
            } else {
                None
            }
        };
        // Connection tokens occupy 1..=max_conn (`idx + 1`); UDP tokens start past that range.
        let udp_token_base = max_conn + 2;
        let mut udp_sockets = Vec::with_capacity(config.udp_bind.len());
        for (i, addr) in config.udp_bind.iter().enumerate() {
            let std_socket = std::net::UdpSocket::bind(addr)
                .map_err(|e| io::Error::new(e.kind(), format!("UDP bind {addr}: {e}")))?;
            std_socket.set_nonblocking(true)?;
            let mut mio_socket = mio::net::UdpSocket::from_std(std_socket);
            poll.registry().register(
                &mut mio_socket,
                mio::Token(udp_token_base + i),
                Interest::READABLE,
            )?;
            udp_sockets.push(mio_socket);
        }
        Ok(Driver {
            connections: ConnectionTable::new(config.max_connections),
            accumulators: AccumulatorTable::new(
                config.max_connections,
                config.recv_buffer.buffer_size as usize,
            ),
            send_copy_pool: SendCopyPool::new(config.send_copy_count, config.send_copy_slot_size),
            send_queues: (0..max_conn).map(|_| ConnSendState::new()).collect(),
            accept_rx,
            wake_handle: crate::wakeup::WakeHandle::from_raw_fd(eventfd),
            shutdown_flag,
            shutdown_local: false,
            tls_table,
            // SAFETY: `sockaddr_storage` is a plain C struct; all-zero bytes is a valid value.
            connect_addrs: vec![unsafe { std::mem::zeroed() }; max_conn],
            poll,
            events,
            resolve_rx,
            resolve_tx,
            resolver,
            spawn_rx,
            spawn_tx,
            spawner,
            blocking_rx,
            blocking_tx,
            blocking_pool,
            tcp_streams: (0..max_conn).map(|_| None).collect(),
            pending_sends: (0..max_conn).map(|_| VecDeque::new()).collect(),
            writable: vec![false; max_conn],
            connect_deadlines: vec![None; max_conn],
            tls_scratch: vec![0u8; 16384],
            wake_pipe_fd: eventfd,
            tcp_nodelay: config.tcp_nodelay,
            send_completions: (0..max_conn).map(|_| VecDeque::new()).collect(),
            udp_sockets,
            udp_token_base,
            disk_io_rx,
            disk_io_tx,
            disk_io_pool,
            next_disk_io_seq: 0,
            direct_io_files: config
                .direct_io
                .as_ref()
                .map(|dio| crate::direct_io::DirectIoFileTable::new(dio.max_files)),
            direct_io_fds: config
                .direct_io
                .as_ref()
                .map(|dio| vec![None; dio.max_files as usize])
                .unwrap_or_default(),
            fs_files: config
                .fs
                .as_ref()
                .map(|fs| crate::fs::FsFileTable::new(fs.max_files)),
            fs_fds: config
                .fs
                .as_ref()
                .map(|fs| vec![None; fs.max_files as usize])
                .unwrap_or_default(),
            pending_fs_opens: std::collections::HashMap::new(),
        })
    }

    /// Lends the driver's state to handler code as a `DriverCtx`.
    ///
    /// The TLS table is passed as a raw pointer, null when TLS is not configured; it stays
    /// valid only while the returned context (and thus the `&mut self` borrow) is alive.
    /// Timestamp fields (feature `timestamps`) start disabled / null.
    pub(crate) fn make_ctx(&mut self) -> DriverCtx<'_> {
        let tls_ptr = self
            .tls_table
            .as_mut()
            .map(|t| t as *mut _)
            .unwrap_or(std::ptr::null_mut());
        DriverCtx {
            connections: &mut self.connections,
            send_copy_pool: &mut self.send_copy_pool,
            tls_table: tls_ptr,
            shutdown_requested: &mut self.shutdown_local,
            connect_addrs: &mut self.connect_addrs,
            tcp_nodelay: self.tcp_nodelay,
            #[cfg(feature = "timestamps")]
            timestamps: false,
            #[cfg(feature = "timestamps")]
            recvmsg_msghdr: std::ptr::null(),
            send_queues: &mut self.send_queues,
            pending_sends: &mut self.pending_sends,
            tcp_streams: &mut self.tcp_streams,
            poll: &mut self.poll,
            writable: &mut self.writable,
            send_completions: &mut self.send_completions,
            connect_deadlines: &mut self.connect_deadlines,
            disk_io_pool: &self.disk_io_pool,
            disk_io_tx: &self.disk_io_tx,
            wake_handle: self.wake_handle,
            next_disk_io_seq: &mut self.next_disk_io_seq,
            direct_io_files: &mut self.direct_io_files,
            direct_io_fds: &mut self.direct_io_fds,
            fs_files: &mut self.fs_files,
            fs_fds: &mut self.fs_fds,
            pending_fs_opens: &mut self.pending_fs_opens,
        }
    }

    /// Closes connection slot `conn_index` and resets all of its per-slot driver state.
    ///
    /// Unknown or already-closed slots are a no-op, so the close metrics are counted once per
    /// connection. For a live TCP stream, the socket is switched to blocking mode and queued
    /// pending sends are written out; for TLS connections a `close_notify` is queued and
    /// flushed. The stream is then deregistered and dropped, the send/writability/deadline
    /// state is cleared, TLS state is removed, and the slot is released.
    ///
    /// NOTE(review): the final flush uses blocking writes, so a peer that stops reading can
    /// block the caller indefinitely; consider bounding it (e.g. with `SO_SNDTIMEO`).
    pub(crate) fn close_connection(&mut self, conn_index: u32) {
        let idx = conn_index as usize;
        let Some(conn) = self.connections.get_mut(conn_index) else {
            return;
        };
        if matches!(conn.recv_mode, RecvMode::Closed) {
            return;
        }
        conn.recv_mode = RecvMode::Closed;

        if let Some(ref mut stream) = self.tcp_streams[idx] {
            use std::io::Write;
            use std::os::fd::AsRawFd;
            let fd = stream.as_raw_fd();
            // SAFETY: `fd` belongs to `stream`, which stays open for this whole block; F_GETFL
            // and F_SETFL only read and update the descriptor's status flags.
            unsafe {
                let flags = libc::fcntl(fd, libc::F_GETFL);
                // F_GETFL returns -1 on failure; masking that would switch on every other
                // settable status flag (O_APPEND, O_ASYNC, ...), so leave the flags alone.
                if flags >= 0 {
                    libc::fcntl(fd, libc::F_SETFL, flags & !libc::O_NONBLOCK);
                }
            }
            // Stop at the first failed write: the stream is broken at that point, and writing
            // later buffers after a partially-written one would leave a gap in the byte stream.
            // Dropping the drain iterator early still empties the queue.
            for (data, offset) in self.pending_sends[idx].drain(..) {
                if stream.write_all(&data[offset..]).is_err() {
                    break;
                }
            }
            let _ = stream.flush();
            if let Some(ref mut tls_table) = self.tls_table
                && tls_table.has(conn_index)
            {
                if let Some(tls_conn) = tls_table.get_mut(conn_index) {
                    tls_conn.conn.send_close_notify();
                }
                crate::tls::flush_tls_output_mio(tls_table, stream, conn_index);
                tls_table.remove(conn_index);
            }
        }
        // Drop TLS session state even when no stream was attached, so a reused slot never
        // starts with a leftover session.
        if let Some(ref mut tls_table) = self.tls_table
            && tls_table.has(conn_index)
        {
            tls_table.remove(conn_index);
        }
        if let Some(mut stream) = self.tcp_streams[idx].take() {
            let _ = self.poll.registry().deregister(&mut stream);
        }
        self.pending_sends[idx].clear();
        self.writable[idx] = false;
        self.connect_deadlines[idx] = None;
        self.send_completions[idx].clear();
        self.send_queues[idx].queue.clear();
        self.send_queues[idx].in_flight = false;
        if self.connections.get(conn_index).is_some() {
            self.connections.release(conn_index);
        }
        crate::metrics::CONNECTIONS_CLOSED.increment();
        crate::metrics::CONNECTIONS_ACTIVE.decrement();
    }

    /// Writes as much of `conn_index`'s pending send queue as the socket accepts, using
    /// `writev` with up to 1024 buffers per call.
    ///
    /// Returns `(done, bytes_written)`. `done` is `false` only when the socket would block;
    /// `writable[conn_index]` is then cleared and the unsent data stays queued. `done` is
    /// `true` when the queue drained, there is no stream, the peer accepted 0 bytes, or a
    /// write error occurred. Only after a full drain is the stream's interest reset to
    /// READABLE. Interrupted calls (`EINTR`) are retried.
    pub(crate) fn flush_sends(&mut self, conn_index: u32) -> (bool, u32) {
        use std::os::fd::AsRawFd;
        // Linux IOV_MAX; larger batches make `writev` fail with EINVAL.
        const MAX_IOVECS: usize = 1024;

        let idx = conn_index as usize;
        let Some(fd) = self.tcp_streams[idx].as_ref().map(|s| s.as_raw_fd()) else {
            return (true, 0);
        };
        let queue = &mut self.pending_sends[idx];
        let mut total_written: u32 = 0;
        // Reused across iterations; entries are raw pointers and are rebuilt before each call.
        let mut iovecs: Vec<libc::iovec> = Vec::with_capacity(queue.len().min(MAX_IOVECS));
        while !queue.is_empty() {
            iovecs.clear();
            iovecs.extend(
                queue
                    .iter()
                    .map(|(data, offset)| &data[*offset..])
                    .filter(|rest| !rest.is_empty())
                    .take(MAX_IOVECS)
                    .map(|rest| libc::iovec {
                        iov_base: rest.as_ptr() as *mut libc::c_void,
                        iov_len: rest.len(),
                    }),
            );
            if iovecs.is_empty() {
                // Only zero-length entries remain; nothing left to send.
                queue.clear();
                break;
            }
            // SAFETY: every iovec points into a buffer owned by `queue`, which is not modified
            // until `writev` returns; the count is at most MAX_IOVECS, so the cast cannot wrap.
            let result = unsafe { libc::writev(fd, iovecs.as_ptr(), iovecs.len() as i32) };
            if result < 0 {
                let err = io::Error::last_os_error();
                match err.kind() {
                    // A signal arrived before any data was written; simply retry.
                    io::ErrorKind::Interrupted => continue,
                    io::ErrorKind::WouldBlock => {
                        self.writable[idx] = false;
                        return (false, total_written);
                    }
                    _ => return (true, total_written),
                }
            }
            if result == 0 {
                return (true, total_written);
            }
            total_written =
                total_written.saturating_add(u32::try_from(result).unwrap_or(u32::MAX));
            // Retire fully-written buffers and advance the offset of a partially-written one.
            let mut remaining = result as usize;
            while remaining > 0 {
                let Some((data, offset)) = queue.front_mut() else {
                    break;
                };
                let avail = data.len() - *offset;
                if remaining >= avail {
                    remaining -= avail;
                    queue.pop_front();
                } else {
                    *offset += remaining;
                    break;
                }
            }
        }
        // Everything is flushed: stop asking for writable events.
        if let Some(stream) = self.tcp_streams[idx].as_mut() {
            let _ = self.poll.registry().reregister(
                stream,
                mio::Token(idx + 1),
                mio::Interest::READABLE,
            );
        }
        (true, total_written)
    }

    /// Re-registers `conn_index`'s stream for READABLE | WRITABLE so a later poll reports
    /// writability. No-op when the slot has no stream; registration errors are ignored.
    pub(crate) fn register_writable(&mut self, conn_index: u32) {
        let idx = conn_index as usize;
        if let Some(stream) = self.tcp_streams[idx].as_mut() {
            let _ = self.poll.registry().reregister(
                stream,
                mio::Token(idx + 1),
                mio::Interest::READABLE | mio::Interest::WRITABLE,
            );
        }
    }
}