use std::collections::VecDeque;
use indexmap::IndexMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::ops::RangeInclusive;
use std::path::PathBuf;
use std::task::Waker;
use std::time::Duration;
use bytes::{Bytes, BytesMut};
/// Port range the socket table's allocator draws ephemeral ports from.
/// `0xC000..=u16::MAX` is 49152..=65535, the IANA dynamic/private range.
pub const DEFAULT_EPHEMERAL_PORTS: RangeInclusive<u16> = 0xC000..=u16::MAX;
/// Default send-buffer capacity, `1 << 16` = 65536 (presumably bytes; no user in view).
pub const DEFAULT_SEND_BUF_CAP: usize = 1 << 16;
/// Default receive-buffer capacity, `1 << 16` = 65536 (presumably bytes; no user in view).
pub const DEFAULT_RECV_BUF_CAP: usize = 1 << 16;
/// Address family a socket is created in.
///
/// NOTE(review): the variant names presumably mirror `AF_INET`, `AF_INET6` and `AF_UNIX`.
// dead_code is allowed presumably because not every variant is constructed yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Domain {
    Inet,
    Inet6,
    Unix,
}
/// Socket type chosen at creation time.
///
/// NOTE(review): presumably mirrors `SOCK_STREAM`, `SOCK_DGRAM` and `SOCK_SEQPACKET`.
// dead_code is allowed presumably because not every variant is constructed yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Type {
    Stream,
    Dgram,
    SeqPacket,
}
/// A socket address: either an IP socket address or a Unix-domain path.
// dead_code is allowed presumably because not every variant is constructed yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Addr {
    /// IP address (v4 or v6) plus port.
    Inet(SocketAddr),
    /// Path naming a Unix-domain socket.
    Unix(PathBuf),
}
/// A socket option together with the value it carries.
///
/// NOTE(review): variant names presumably mirror the BSD `SO_*`, `TCP_*`, `IP_*` and
/// `IPV6_*` options of the same names.
// dead_code is allowed presumably because not every variant is constructed yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SocketOption {
    // Options without a protocol prefix.
    Broadcast(bool),
    ReuseAddr(bool),
    ReusePort(bool),
    Linger(Option<Duration>),
    RecvBufferSize(usize),
    SendBufferSize(usize),
    KeepAlive(bool),
    // `Tcp*` options.
    TcpNoDelay(bool),
    TcpKeepIdle(Duration),
    TcpKeepInterval(Duration),
    TcpKeepCount(u32),
    // IPv4 options.
    IpTtl(u8),
    IpMulticastTtl(u8),
    IpMulticastLoop(bool),
    IpAddMembership { group: Ipv4Addr, iface: Ipv4Addr },
    IpDropMembership { group: Ipv4Addr, iface: Ipv4Addr },
    // IPv6 options; `iface` in the group variants is presumably an interface index.
    Ipv6Only(bool),
    Ipv6MulticastHops(u8),
    Ipv6MulticastLoop(bool),
    Ipv6JoinGroup { group: Ipv6Addr, iface: u32 },
    Ipv6LeaveGroup { group: Ipv6Addr, iface: u32 },
}
/// Names a socket option without carrying its value.
///
/// Unlike `SocketOption`, there are no kinds for the multicast membership / group variants.
// dead_code is allowed presumably because not every variant is constructed yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketOptionKind {
    // Options without a protocol prefix.
    Broadcast,
    ReuseAddr,
    ReusePort,
    Linger,
    RecvBufferSize,
    SendBufferSize,
    KeepAlive,
    // `Tcp*` options.
    TcpNoDelay,
    TcpKeepIdle,
    TcpKeepInterval,
    TcpKeepCount,
    // IPv4 options.
    IpTtl,
    IpMulticastTtl,
    IpMulticastLoop,
    // IPv6 options.
    Ipv6Only,
    Ipv6MulticastHops,
    Ipv6MulticastLoop,
}
/// Opaque descriptor identifying one socket in a `SocketTable`.
///
/// The wrapped counter is private, so descriptors can only be minted inside this module.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Fd(u64);
/// Connection state stored in a TCP control block.
///
/// NOTE(review): the names follow the RFC 793 state machine, but there is no `Listen` or
/// `TimeWait` variant. Listening appears to be tracked separately, via `ListenState` — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TcpState {
    // Handshake in progress.
    SynSent,
    SynReceived,
    Established,
    // Local side closing first.
    FinWait1,
    FinWait2,
    // Peer closed first.
    CloseWait,
    LastAck,
    // Both sides closing at once.
    Closing,
    Closed,
}
/// Per-connection TCP state.
///
/// Sequence-number fields are named after the RFC 793 variables (SND.NXT, SND.UNA, ...).
/// The notes below describe their presumed meaning; none of it is enforced in this type.
#[derive(Debug)]
pub struct Tcb {
    /// Current position in the TCP state machine.
    pub state: TcpState,
    /// Remote endpoint of the connection.
    pub peer: SocketAddr,
    /// SND.NXT — presumably the next sequence number to send.
    pub snd_nxt: u32,
    /// SND.UNA — presumably the oldest unacknowledged sequence number.
    pub snd_una: u32,
    /// SND.WND — presumably the peer's advertised window. It is a `u16`, so window scaling
    /// presumably isn't applied — confirm.
    pub snd_wnd: u16,
    /// RCV.NXT — presumably the next sequence number expected from the peer.
    pub rcv_nxt: u32,
    /// Presumably outgoing bytes queued and/or awaiting acknowledgement — confirm.
    pub send_buf: BytesMut,
    /// Presumably received bytes not yet consumed by the reader — confirm.
    pub recv_buf: BytesMut,
    /// Presumably set once the local side shuts down writing.
    pub wr_closed: bool,
    /// Presumably set once a FIN from the peer has been processed.
    pub peer_fin: bool,
    /// Presumably the sequence number of the local FIN once one is sent.
    pub fin_seq: Option<u32>,
    /// Presumably set when the connection was reset.
    pub reset: bool,
    /// Presumably set when retransmission gave up.
    pub timed_out: bool,
    /// Presumably a count of segments sent since the last ACK arrived — confirm.
    pub egress_since_ack: u32,
    /// Presumably retransmission attempts for the oldest unacknowledged data — confirm.
    pub retx_attempts: u32,
}
/// Accept-side state attached to a listening socket.
#[derive(Debug)]
pub struct ListenState {
    /// Backlog recorded at creation. Nothing here enforces it; presumably it caps `ready`.
    pub backlog: usize,
    /// Descriptors waiting to be handed out (presumably accepted connections).
    pub ready: VecDeque<Fd>,
    /// Wakers of tasks waiting on `ready` (presumably accept callers).
    pub accept_wakers: Vec<Waker>,
}
impl ListenState {
    /// Creates listen state with the given `backlog` and nothing queued.
    pub fn new(backlog: usize) -> Self {
        ListenState {
            backlog,
            ready: VecDeque::default(),
            accept_wakers: Vec::default(),
        }
    }
}
/// State of one simulated socket, as stored in a `SocketTable`.
#[derive(Debug)]
pub struct Socket {
    pub domain: Domain,
    pub ty: Type,
    /// Binding key once bound; `None` for a fresh socket.
    pub bound: Option<BindKey>,
    /// Peer address; `None` for a fresh socket.
    pub peer: Option<Addr>,
    pub broadcast: bool,
    /// IP time-to-live; a fresh socket starts at 64.
    pub ttl: u8,
    pub tcp_nodelay: bool,
    /// Queued `(address, payload)` pairs — presumably source address and datagram; confirm.
    pub recv_queue: VecDeque<(Addr, Bytes)>,
    /// Woken and cleared by `wake_read`.
    pub read_wakers: Vec<Waker>,
    /// Woken and cleared by `wake_write`.
    pub write_wakers: Vec<Waker>,
    /// TCP control block; `None` for a fresh socket.
    pub tcb: Option<Tcb>,
    /// Listen/accept state; `None` for a fresh socket.
    pub listen: Option<ListenState>,
    pub connect_waker: Option<Waker>,
    pub fd_closed: bool,
}
impl Socket {
    /// Creates an unbound, unconnected socket in `domain` with type `ty`.
    ///
    /// The socket starts with a TTL of 64 and with broadcast and TCP_NODELAY off. It has no
    /// TCP or listen state, and its queues and waker lists are empty.
    pub fn new(domain: Domain, ty: Type) -> Self {
        Socket {
            domain,
            ty,
            bound: None,
            peer: None,
            broadcast: false,
            ttl: 64,
            tcp_nodelay: false,
            recv_queue: VecDeque::default(),
            read_wakers: vec![],
            write_wakers: vec![],
            tcb: None,
            listen: None,
            connect_waker: None,
            fd_closed: false,
        }
    }

    /// Queues `waker` for the next `wake_read`, unless an equivalent waker is already queued.
    pub fn register_read_waker(&mut self, waker: &Waker) {
        let slot = &mut self.read_wakers;
        register(slot, waker);
    }

    /// Queues `waker` for the next `wake_write`, unless an equivalent waker is already queued.
    pub fn register_write_waker(&mut self, waker: &Waker) {
        let slot = &mut self.write_wakers;
        register(slot, waker);
    }

    /// Wakes every queued read waker and empties the list.
    pub fn wake_read(&mut self) {
        let slot = &mut self.read_wakers;
        wake_all(slot);
    }

    /// Wakes every queued write waker and empties the list.
    pub fn wake_write(&mut self) {
        let slot = &mut self.write_wakers;
        wake_all(slot);
    }
}
/// Adds a clone of `waker` to `slot` unless `slot` already holds a waker for which
/// [`Waker::will_wake`] reports the same task, so re-polling does not grow the list.
fn register(slot: &mut Vec<Waker>, waker: &Waker) {
    let already_queued = slot.iter().any(|queued| queued.will_wake(waker));
    if !already_queued {
        slot.push(waker.clone());
    }
}
/// Wakes every waker in `slot` in queue order and leaves `slot` empty. `drain` keeps the
/// vector's allocation for reuse.
fn wake_all(slot: &mut Vec<Waker>) {
    slot.drain(..).for_each(Waker::wake);
}
/// Identifies a local binding: address family, socket type, local IP and local port.
///
/// `SocketTable` indexes bound descriptors by this key, and several descriptors may share
/// one key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BindKey {
    pub domain: Domain,
    pub ty: Type,
    pub local_addr: IpAddr,
    pub local_port: u16,
}
/// Owns every socket, keyed by [`Fd`], and keeps two secondary indexes:
/// - from a [`BindKey`] to the descriptors bound to it;
/// - from a `(local, remote)` address pair to a connected descriptor.
///
/// All maps are `IndexMap`s and removals use order-preserving operations, so iteration
/// follows insertion order.
#[derive(Debug)]
pub struct SocketTable {
    /// Raw value of the next descriptor handed out by `insert`; starts at 1.
    next_id: u64,
    sockets: IndexMap<Fd, Socket>,
    /// Several descriptors may share one key; `insert_binding` appends in call order.
    bindings: IndexMap<BindKey, Vec<Fd>>,
    /// Keyed by `(local, remote)`; `insert_connection` rejects unspecified local IPs.
    connections: IndexMap<(SocketAddr, SocketAddr), Fd>,
    ports: PortAllocator,
}
impl SocketTable {
    /// Creates an empty table. The first descriptor handed out is `Fd(1)`, and ephemeral
    /// ports come from [`DEFAULT_EPHEMERAL_PORTS`].
    pub fn new() -> Self {
        SocketTable {
            next_id: 1,
            sockets: IndexMap::new(),
            bindings: IndexMap::new(),
            connections: IndexMap::new(),
            ports: PortAllocator::new(DEFAULT_EPHEMERAL_PORTS),
        }
    }

    /// Stores `socket` under a freshly minted descriptor and returns that descriptor.
    ///
    /// The counter only moves forward, so descriptors are never reused.
    pub fn insert(&mut self, socket: Socket) -> Fd {
        let id = self.next_id;
        self.next_id = id + 1;
        let fd = Fd(id);
        self.sockets.insert(fd, socket);
        fd
    }

    /// Returns the socket stored under `fd`, if any.
    pub fn get(&self, fd: Fd) -> Option<&Socket> {
        self.sockets.get(&fd)
    }

    /// Iterates over all sockets with their descriptors, in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (Fd, &Socket)> {
        self.sockets.iter().map(|(fd, socket)| (*fd, socket))
    }

    /// Returns the socket stored under `fd` mutably, if any.
    pub fn get_mut(&mut self, fd: Fd) -> Option<&mut Socket> {
        self.sockets.get_mut(&fd)
    }

    /// Removes `fd` from the table and from every binding and connection index entry that
    /// refers to it. Returns the socket if it was present.
    ///
    /// Binding lists left empty are dropped. The indexes are purged even when `fd` itself is
    /// not in the table. Both scans are linear in the size of the indexes.
    pub fn remove(&mut self, fd: Fd) -> Option<Socket> {
        for owners in self.bindings.values_mut() {
            owners.retain(|owner| *owner != fd);
        }
        self.bindings.retain(|_, owners| !owners.is_empty());
        self.connections.retain(|_, owner| *owner != fd);
        self.sockets.shift_remove(&fd)
    }

    /// Returns every descriptor bound to `key`, in binding order; empty if there are none.
    pub fn find_by_bind(&self, key: &BindKey) -> &[Fd] {
        match self.bindings.get(key) {
            Some(owners) => owners.as_slice(),
            None => &[],
        }
    }

    /// Appends `fd` to the descriptors bound to `key`, creating the list on first use.
    ///
    /// There is no duplicate check: binding the same `fd` twice lists it twice.
    pub fn insert_binding(&mut self, key: BindKey, fd: Fd) {
        let owners = self.bindings.entry(key).or_default();
        owners.push(fd);
    }

    /// Records that the connection `local` -> `remote` belongs to `fd`, replacing any
    /// descriptor previously recorded for the same pair.
    ///
    /// # Panics
    ///
    /// Panics if `local` has an unspecified IP (`0.0.0.0` / `::`).
    pub fn insert_connection(&mut self, local: SocketAddr, remote: SocketAddr, fd: Fd) {
        if local.ip().is_unspecified() {
            panic!("connection index requires concrete local addr");
        }
        self.connections.insert((local, remote), fd);
    }

    /// Looks up the descriptor recorded for the connection `local` -> `remote`.
    pub fn find_connection(&self, local: SocketAddr, remote: SocketAddr) -> Option<Fd> {
        let key = (local, remote);
        self.connections.get(&key).copied()
    }

    /// Iterates `(remote, fd)` for every recorded connection whose local address is `local`,
    /// in insertion order.
    pub fn connections_on(
        &self,
        local: SocketAddr,
    ) -> impl Iterator<Item = (SocketAddr, Fd)> + '_ {
        self.connections
            .iter()
            .filter_map(move |(&(from, to), &fd)| (from == local).then_some((to, fd)))
    }

    /// Iterates the binding keys, with their descriptors, that match `domain`/`ty` on
    /// `port`, whatever their local address.
    pub fn bindings_on_port(
        &self,
        domain: Domain,
        ty: Type,
        port: u16,
    ) -> impl Iterator<Item = (&BindKey, &[Fd])> {
        self.bindings.iter().filter_map(move |(key, owners)| {
            let hit = key.local_port == port && key.domain == domain && key.ty == ty;
            hit.then_some((key, owners.as_slice()))
        })
    }

    /// Picks an ephemeral port for `domain`/`ty` through the table's port allocator.
    ///
    /// A port counts as taken when any binding of the same domain and type uses it,
    /// whatever its local address. Returns `None` when the allocator finds no free port.
    pub fn allocate_port(&mut self, domain: Domain, ty: Type) -> Option<u16> {
        let Self { bindings, ports, .. } = self;
        ports.allocate(|candidate| {
            bindings
                .keys()
                .any(|key| key.local_port == candidate && key.domain == domain && key.ty == ty)
        })
    }
}
impl Default for SocketTable {
    /// Same as [`SocketTable::new`].
    fn default() -> Self {
        SocketTable::new()
    }
}
/// Round-robin allocator over an inclusive range of ports.
///
/// The cursor remembers where the previous search stopped. Successive allocations therefore
/// rotate through the range instead of always returning the lowest free port.
#[derive(Debug)]
pub struct PortAllocator {
    range: RangeInclusive<u16>,
    /// Next port to probe; always inside `range` when `range` is non-empty.
    cursor: u16,
}
impl PortAllocator {
    /// Creates an allocator over `range`; the first search starts at its lowest port.
    pub fn new(range: RangeInclusive<u16>) -> Self {
        let cursor = *range.start();
        Self { range, cursor }
    }

    /// Returns the first port for which `in_use` returns `false`. The scan starts at the
    /// cursor and wraps at the end of the range.
    ///
    /// Each port is probed at most once per call. Returns `None` when every port is in use or
    /// the range is empty (e.g. `12..=10`). On success the cursor moves just past the returned
    /// port; on exhaustion it ends where it started.
    pub fn allocate(&mut self, mut in_use: impl FnMut(u16) -> bool) -> Option<u16> {
        // An empty range has nothing to hand out. Without this guard the cursor, which starts
        // at `start > end`, never equals `end`. The loop would then hand out ports outside
        // the range, or walk up to 65535 and overflow `port + 1`.
        if self.range.is_empty() {
            return None;
        }
        let (first, last) = (*self.range.start(), *self.range.end());
        let origin = self.cursor;
        loop {
            let port = self.cursor;
            // Advance, wrapping to `first` after `last`. Testing `last` first means
            // `port + 1` cannot overflow even when the range ends at `u16::MAX`.
            self.cursor = if port == last { first } else { port + 1 };
            if !in_use(port) {
                return Some(port);
            }
            if self.cursor == origin {
                return None;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh IPv4 datagram socket for table tests.
    fn udp_socket() -> Socket {
        Socket::new(Domain::Inet, Type::Dgram)
    }

    #[test]
    fn port_allocator_skips_in_use() {
        let in_use = |port: u16| port == 10;
        let mut alloc = PortAllocator::new(10..=12);
        assert_eq!(alloc.allocate(in_use), Some(11));
        assert_eq!(alloc.allocate(in_use), Some(12));
    }

    #[test]
    fn port_allocator_exhausts() {
        let mut alloc = PortAllocator::new(10..=11);
        assert_eq!(alloc.allocate(|port| (10..=11).contains(&port)), None);
    }

    #[test]
    fn fds_are_unique() {
        let mut table = SocketTable::new();
        let first = table.insert(udp_socket());
        let second = table.insert(udp_socket());
        assert_ne!(first, second);
    }

    #[test]
    fn multiple_bindings_on_same_key() {
        let mut table = SocketTable::new();
        let fds = [table.insert(udp_socket()), table.insert(udp_socket())];
        let key = BindKey {
            domain: Domain::Inet,
            ty: Type::Dgram,
            local_addr: "10.0.0.1".parse().unwrap(),
            local_port: 5000,
        };
        for &fd in &fds {
            table.insert_binding(key.clone(), fd);
        }
        assert_eq!(table.find_by_bind(&key), &fds);
    }
}