use std::collections::VecDeque;
use std::io::{Error, ErrorKind};
use std::net::{IpAddr, SocketAddr};
use std::task::{Context, Poll};
use tokio::io::ReadBuf;
use crate::kernel::socket::{BindKey, SocketTable, DEFAULT_RECV_BUF_CAP, DEFAULT_SEND_BUF_CAP};
mod packet;
mod socket;
mod tcp;
mod udp;
mod uds;
pub use socket::{Addr, Domain, Fd, SocketOption, SocketOptionKind, Type};
pub use socket::{ListenState, Socket, Tcb, TcpState};
pub use packet::{Packet, TcpFlags, TcpSegment, Transport, UdpDatagram};
/// Default MTU, in bytes, for non-loopback links (classic Ethernet payload size).
pub const DEFAULT_MTU: u32 = 1_500;
/// Default MTU, in bytes, for loopback traffic (64 KiB, matching Linux's `lo`).
pub const DEFAULT_LOOPBACK_MTU: u32 = 65_536;
/// Default listen backlog exposed through [`KernelConfig::default_backlog`].
pub const DEFAULT_BACKLOG: usize = 1_024;
/// Default value for [`KernelConfig::retx_threshold`].
pub const DEFAULT_RETX_THRESHOLD: u32 = 3;
/// Default value for [`KernelConfig::retx_max`].
pub const DEFAULT_RETX_MAX: u32 = 5;
/// Tunables used to build a [`Kernel`] through [`Kernel::with_config`].
///
/// Start from [`KernelConfig::default`] and override individual fields with the
/// builder-style setters. All fields are plain integers, so configs can be compared
/// with `==` (handy for asserting on a config in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelConfig {
    /// MTU, in bytes, for non-loopback links (default [`DEFAULT_MTU`]).
    pub mtu: u32,
    /// MTU, in bytes, for loopback traffic (default [`DEFAULT_LOOPBACK_MTU`]).
    pub loopback_mtu: u32,
    /// Send-buffer capacity handed to sockets (default `DEFAULT_SEND_BUF_CAP`).
    pub send_buf_cap: usize,
    /// Receive-buffer capacity handed to sockets (default `DEFAULT_RECV_BUF_CAP`).
    pub recv_buf_cap: usize,
    /// Default listen backlog (default [`DEFAULT_BACKLOG`]).
    pub default_backlog: usize,
    /// TCP retransmission threshold; semantics live in the tcp module
    /// (default [`DEFAULT_RETX_THRESHOLD`]).
    pub retx_threshold: u32,
    /// TCP retransmission limit; semantics live in the tcp module
    /// (default [`DEFAULT_RETX_MAX`]).
    pub retx_max: u32,
}
impl Default for KernelConfig {
    /// Returns the stock configuration assembled from the `DEFAULT_*` constants.
    fn default() -> Self {
        let (mtu, loopback_mtu) = (DEFAULT_MTU, DEFAULT_LOOPBACK_MTU);
        let (send_buf_cap, recv_buf_cap) = (DEFAULT_SEND_BUF_CAP, DEFAULT_RECV_BUF_CAP);
        let (retx_threshold, retx_max) = (DEFAULT_RETX_THRESHOLD, DEFAULT_RETX_MAX);
        Self {
            mtu,
            loopback_mtu,
            send_buf_cap,
            recv_buf_cap,
            default_backlog: DEFAULT_BACKLOG,
            retx_threshold,
            retx_max,
        }
    }
}
/// Builder-style setters. Each consumes the config and returns the updated copy, so the
/// result must be used. Discarding it throws the change (and the config) away.
impl KernelConfig {
    /// Returns this config with `mtu` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn mtu(mut self, v: u32) -> Self {
        self.mtu = v;
        self
    }

    /// Returns this config with `loopback_mtu` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn loopback_mtu(mut self, v: u32) -> Self {
        self.loopback_mtu = v;
        self
    }

    /// Returns this config with `send_buf_cap` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn send_buf_cap(mut self, v: usize) -> Self {
        self.send_buf_cap = v;
        self
    }

    /// Returns this config with `recv_buf_cap` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn recv_buf_cap(mut self, v: usize) -> Self {
        self.recv_buf_cap = v;
        self
    }

    /// Returns this config with `default_backlog` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn default_backlog(mut self, v: usize) -> Self {
        self.default_backlog = v;
        self
    }

    /// Returns this config with `retx_threshold` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn retx_threshold(mut self, v: u32) -> Self {
        self.retx_threshold = v;
        self
    }

    /// Returns this config with `retx_max` set to `v`.
    #[must_use = "the updated config is returned, not applied in place"]
    pub fn retx_max(mut self, v: u32) -> Self {
        self.retx_max = v;
        self
    }
}
/// `true` on targets whose errno numbering follows 4.4BSD (Apple platforms and the
/// BSDs) rather than Linux.
const BSD_ERRNO: bool = cfg!(any(
    target_vendor = "apple",
    target_os = "freebsd",
    target_os = "netbsd",
    target_os = "openbsd",
    target_os = "dragonfly"
));
/// Raw OS error code for `EAFNOSUPPORT` ("address family not supported").
///
/// `std::io::Error::from_raw_os_error` decodes its argument with the *host's* error
/// table, so the number must match the compilation target. It is 97 on Linux, 47 on
/// BSD-derived systems, and Winsock's `WSAEAFNOSUPPORT` (10047) on Windows.
/// NOTE(review): Linux on MIPS/SPARC/Alpha/PA-RISC uses a different errno layout and
/// is not covered here.
const EAFNOSUPPORT: i32 = if cfg!(windows) {
    10047
} else if BSD_ERRNO {
    47
} else {
    97
};
/// Raw OS error code for `EMSGSIZE` ("message too long").
///
/// It is 90 on Linux, 40 on BSD-derived systems, and Winsock's `WSAEMSGSIZE` (10040)
/// on Windows (see the note on [`EAFNOSUPPORT`] for why this must vary).
pub(crate) const EMSGSIZE: i32 = if cfg!(windows) {
    10040
} else if BSD_ERRNO {
    40
} else {
    90
};
/// In-process model of a host's network stack. It holds the socket table, the set of
/// addresses treated as local, and the queue of packets awaiting routing.
///
/// The tunable fields are copied from a [`KernelConfig`] by [`Kernel::with_config`].
#[derive(Debug)]
pub struct Kernel {
/// Socket table, indexed by [`Fd`].
sockets: SocketTable,
/// Addresses registered through `add_address`. Loopback addresses need not be
/// listed, because `is_local` treats them as local unconditionally.
addresses: Vec<IpAddr>,
/// Copied from `KernelConfig::mtu`.
pub(crate) mtu: u32,
/// Copied from `KernelConfig::loopback_mtu`.
pub(crate) loopback_mtu: u32,
/// Copied from `KernelConfig::send_buf_cap`.
pub(crate) send_buf_cap: usize,
/// Copied from `KernelConfig::recv_buf_cap`.
pub(crate) recv_buf_cap: usize,
/// Copied from `KernelConfig::default_backlog`.
pub(crate) default_backlog: usize,
/// Copied from `KernelConfig::retx_threshold`.
pub(crate) retx_threshold: u32,
/// Copied from `KernelConfig::retx_max`.
pub(crate) retx_max: u32,
/// Packets queued for transmission. `egress` drains this queue: packets for local
/// addresses go back through `deliver`, and the rest are handed to the caller.
outbound: VecDeque<Packet>,
/// Seeded with `0x0100_0000` by `with_config`. Presumably the next TCP initial
/// sequence number consumed by the tcp module (TODO confirm there).
tcp_isn: u32,
}
impl Kernel {
/// Creates a kernel configured with [`KernelConfig::default`].
pub fn new() -> Self {
    let defaults = KernelConfig::default();
    Self::with_config(defaults)
}
pub fn with_config(cfg: KernelConfig) -> Self {
Self {
sockets: SocketTable::new(),
addresses: Vec::new(),
mtu: cfg.mtu,
loopback_mtu: cfg.loopback_mtu,
send_buf_cap: cfg.send_buf_cap,
recv_buf_cap: cfg.recv_buf_cap,
default_backlog: cfg.default_backlog,
retx_threshold: cfg.retx_threshold,
retx_max: cfg.retx_max,
outbound: VecDeque::new(),
tcp_isn: 0x0100_0000,
}
}
/// Inserts a freshly constructed `Socket` of (`domain`, `ty`) into the table and
/// returns the descriptor the table assigned to it.
fn mk_socket(&mut self, domain: Domain, ty: Type) -> Fd {
    let socket = Socket::new(domain, ty);
    self.sockets.insert(socket)
}
pub fn open(&mut self, domain: Domain, ty: Type) -> Fd {
self.mk_socket(domain, ty)
}
/// Closes `fd`.
///
/// The table entry is removed only when `tcp::on_close` reports that it may be
/// released now. Otherwise the socket stays in the table (presumably so a TCP
/// teardown can finish and be reaped later during `egress`).
pub fn close(&mut self, fd: Fd) {
    let release_now = tcp::on_close(self, fd);
    if !release_now {
        return;
    }
    self.sockets.remove(fd);
}
/// Iterates over every socket currently in the table, paired with its descriptor.
pub fn sockets(&self) -> impl Iterator<Item = (Fd, &Socket)> {
    let table = &self.sockets;
    table.iter()
}
pub(crate) fn lookup(&self, fd: Fd) -> std::io::Result<&Socket> {
self.sockets
.get(fd)
.ok_or_else(|| Error::from(ErrorKind::NotFound))
}
pub(crate) fn lookup_mut(&mut self, fd: Fd) -> std::io::Result<&mut Socket> {
self.sockets
.get_mut(fd)
.ok_or_else(|| Error::from(ErrorKind::NotFound))
}
/// Registers `addr` as belonging to this host. Registering the same address twice
/// has no effect.
pub fn add_address(&mut self, addr: IpAddr) {
    if self.addresses.contains(&addr) {
        return;
    }
    self.addresses.push(addr);
}
/// Whether `addr` belongs to this host. That is true of any loopback address, and of
/// any address registered with [`Kernel::add_address`].
pub fn is_local(&self, addr: IpAddr) -> bool {
    addr.is_loopback() || self.addresses.contains(&addr)
}
/// Creates a socket of type `ty` bound to the inet address `addr` and returns its fd.
///
/// A port of `0` asks the socket table for an ephemeral port. The bind fails if another
/// socket of the same domain and type already holds the port and either both use the
/// same IP or either side is bound to the unspecified (wildcard) address.
///
/// # Errors
/// - `ErrorKind::AddrNotAvailable` if the IP is neither unspecified nor local.
/// - `ErrorKind::AddrInUse` if no ephemeral port is available or the address conflicts
///   with an existing binding.
///
/// # Panics
/// On `Addr::Unix`, which is not implemented.
pub fn bind(&mut self, addr: &Addr, ty: Type) -> std::io::Result<Fd> {
    let (domain, ip, requested_port) = match addr {
        Addr::Inet(sa) => {
            let domain = if sa.is_ipv4() { Domain::Inet } else { Domain::Inet6 };
            (domain, sa.ip(), sa.port())
        }
        Addr::Unix(_) => unimplemented!("AF_UNIX bind"),
    };

    let wildcard = ip.is_unspecified();
    if !(wildcard || self.is_local(ip)) {
        return Err(Error::from(ErrorKind::AddrNotAvailable));
    }

    let local_port = match requested_port {
        0 => self
            .sockets
            .allocate_port(domain, ty)
            .ok_or_else(|| Error::from(ErrorKind::AddrInUse))?,
        explicit => explicit,
    };

    for (existing, _) in self.sockets.bindings_on_port(domain, ty, local_port) {
        let overlaps =
            existing.local_addr == ip || existing.local_addr.is_unspecified() || wildcard;
        if overlaps {
            return Err(Error::from(ErrorKind::AddrInUse));
        }
    }

    let key = BindKey {
        domain,
        ty,
        local_addr: ip,
        local_port,
    };
    let fd = self.mk_socket(domain, ty);
    self.sockets.insert_binding(key.clone(), fd);
    self.lookup_mut(fd).expect("socket entry present").bound = Some(key);
    Ok(fd)
}
/// Applies the socket option `opt` to `fd`.
///
/// Only `Broadcast`, `IpTtl` and `TcpNoDelay` are wired up. Any other option panics.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table.
pub fn set_option(&mut self, fd: Fd, opt: SocketOption) -> std::io::Result<()> {
    let sock = self.lookup_mut(fd)?;
    match opt {
        SocketOption::Broadcast(enabled) => sock.broadcast = enabled,
        SocketOption::TcpNoDelay(enabled) => sock.tcp_nodelay = enabled,
        SocketOption::IpTtl(ttl) => sock.ttl = ttl,
        _ => unimplemented!("set_option {:?}", opt),
    }
    Ok(())
}
/// Reads the current value of the option `kind` on `fd`.
///
/// Only `Broadcast`, `IpTtl` and `TcpNoDelay` are wired up. Any other kind panics.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table.
pub fn get_option(&self, fd: Fd, kind: SocketOptionKind) -> std::io::Result<SocketOption> {
    let sock = self.lookup(fd)?;
    let value = match kind {
        SocketOptionKind::Broadcast => SocketOption::Broadcast(sock.broadcast),
        SocketOptionKind::TcpNoDelay => SocketOption::TcpNoDelay(sock.tcp_nodelay),
        SocketOptionKind::IpTtl => SocketOption::IpTtl(sock.ttl),
        _ => unimplemented!("get_option {:?}", kind),
    };
    Ok(value)
}
/// Returns the address `fd` is bound to, always reported as `Addr::Inet`.
///
/// # Errors
/// - `ErrorKind::NotFound` if `fd` is not in the socket table.
/// - `ErrorKind::InvalidInput` if the socket has no binding.
pub fn local_addr(&self, fd: Fd) -> std::io::Result<Addr> {
    let sock = self.lookup(fd)?;
    match sock.bound.as_ref() {
        Some(key) => Ok(Addr::Inet(SocketAddr::new(key.local_addr, key.local_port))),
        None => Err(Error::from(ErrorKind::InvalidInput)),
    }
}
pub fn peer_addr(&self, fd: Fd) -> std::io::Result<Addr> {
self.lookup(fd)?
.peer
.clone()
.ok_or_else(|| Error::from(ErrorKind::NotConnected))
}
pub fn poll_send_to(
&mut self,
fd: Fd,
cx: &mut Context<'_>,
buf: &[u8],
dst: &Addr,
) -> Poll<std::io::Result<usize>> {
let Addr::Inet(dst_sa) = dst else {
panic!("AF_UNIX not wired through poll_send_to");
};
let (ty, domain) = match self.lookup(fd) {
Ok(st) => (st.ty, st.domain),
Err(e) => return Poll::Ready(Err(e)),
};
assert_eq!(ty, Type::Dgram, "poll_send_to on non-Dgram fd");
match (domain, dst_sa) {
(Domain::Inet, SocketAddr::V4(_)) | (Domain::Inet6, SocketAddr::V6(_)) => {}
_ => return Poll::Ready(Err(Error::from_raw_os_error(EAFNOSUPPORT))),
}
udp::send_to(self, fd, cx, buf, dst_sa)
}
/// Receives one datagram on `fd` into `buf` and returns the sender's address.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table, or anything reported
/// by `udp::recv_from`.
///
/// # Panics
/// If `fd` is not a `Type::Dgram` socket.
pub fn poll_recv_from(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<Addr>> {
    match self.lookup_mut(fd) {
        Ok(sock) => {
            assert_eq!(sock.ty, Type::Dgram, "poll_recv_from on non-Dgram fd");
            udp::recv_from(sock, cx, buf)
        }
        Err(err) => Poll::Ready(Err(err)),
    }
}
pub fn poll_connect(
&mut self,
fd: Fd,
cx: &mut Context<'_>,
addr: &Addr,
) -> Poll<std::io::Result<()>> {
let Addr::Inet(peer) = addr else {
panic!("AF_UNIX not wired through connect");
};
let (domain, ty, is_bound) = match self.lookup(fd) {
Ok(st) => (st.domain, st.ty, st.bound.is_some()),
Err(e) => return Poll::Ready(Err(e)),
};
match (domain, peer) {
(Domain::Inet, SocketAddr::V4(_)) | (Domain::Inet6, SocketAddr::V6(_)) => {}
_ => return Poll::Ready(Err(Error::from_raw_os_error(EAFNOSUPPORT))),
}
match ty {
Type::Dgram => {
if !is_bound {
if let Err(e) = udp::auto_bind(self, fd, domain, ty, peer.ip()) {
return Poll::Ready(Err(e));
}
}
self.lookup_mut(fd).expect("socket present").peer = Some(Addr::Inet(*peer));
Poll::Ready(Ok(()))
}
Type::Stream => tcp::poll_connect(self, fd, cx, domain, *peer, is_bound),
Type::SeqPacket => unimplemented!("SOCK_SEQPACKET connect"),
}
}
/// Puts the bound stream socket `fd` into the listening state, creating its accept
/// queue with `ListenState::new(backlog)`.
///
/// Calling `listen` again on a socket that is already listening keeps its existing
/// accept queue and any parked `poll_accept` wakers.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table.
///
/// # Panics
/// If `fd` is not a `Type::Stream` socket or has not been bound.
pub fn listen(&mut self, fd: Fd, backlog: usize) -> std::io::Result<()> {
    let st = self.lookup_mut(fd)?;
    assert_eq!(st.ty, Type::Stream, "listen on non-Stream fd");
    assert!(st.bound.is_some(), "listen on unbound fd");
    // Replacing an existing `ListenState` would drop the connections already queued in
    // `ready`, leaving their fds orphaned in the socket table. It would also drop the
    // wakers of tasks parked in `poll_accept`, and those tasks would never be woken.
    // NOTE(review): as a result, a repeated call does not apply the new `backlog`.
    // Linux updates the backlog in place; revisit if `ListenState` exposes its limit.
    if st.listen.is_none() {
        st.listen = Some(ListenState::new(backlog));
    }
    Ok(())
}
/// Pops the next established connection from the listener `fd`.
///
/// Returns the child's fd and its peer address (taken from the child's TCB). If the
/// queue is empty, the caller's waker is registered, unless an equivalent waker is
/// already registered, and `Poll::Pending` is returned.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table.
///
/// # Panics
/// If `fd` is not listening, or a queued child is missing from the table or has no TCB.
pub fn poll_accept(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
) -> Poll<std::io::Result<(Fd, SocketAddr)>> {
    let sock = match self.lookup_mut(fd) {
        Ok(sock) => sock,
        Err(err) => return Poll::Ready(Err(err)),
    };
    let listener = sock.listen.as_mut().expect("poll_accept on non-listener fd");
    match listener.ready.pop_front() {
        Some(child) => {
            let tcb = self
                .lookup(child)
                .expect("accepted fd present")
                .tcb
                .as_ref()
                .expect("accepted fd has TCB");
            Poll::Ready(Ok((child, tcb.peer)))
        }
        None => {
            let waker = cx.waker();
            let already_parked = listener.accept_wakers.iter().any(|w| w.will_wake(waker));
            if !already_parked {
                listener.accept_wakers.push(waker.clone());
            }
            Poll::Pending
        }
    }
}
/// Sends `buf` on the connected socket `fd`.
///
/// Datagram sockets send to the peer recorded by `poll_connect`. Stream sockets
/// delegate to `tcp::poll_send`.
///
/// # Errors
/// - `ErrorKind::NotFound` if `fd` is not in the socket table.
/// - `ErrorKind::NotConnected` for a datagram socket with no recorded peer.
///
/// # Panics
/// For `Type::SeqPacket`, or if a UDP peer is stored as `Addr::Unix`.
pub fn poll_send(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<std::io::Result<usize>> {
    let (ty, peer) = match self.lookup(fd).map(|sock| (sock.ty, sock.peer.clone())) {
        Ok(pair) => pair,
        Err(err) => return Poll::Ready(Err(err)),
    };
    match ty {
        Type::Stream => tcp::poll_send(self, fd, cx, buf),
        Type::Dgram => match peer {
            Some(Addr::Inet(peer_sa)) => udp::send_to(self, fd, cx, buf, &peer_sa),
            Some(_) => panic!("UDP peer stored as Addr::Unix"),
            None => Poll::Ready(Err(Error::from(ErrorKind::NotConnected))),
        },
        Type::SeqPacket => unimplemented!("SOCK_SEQPACKET poll_send"),
    }
}
/// Receives into `buf` from the socket `fd` and returns the number of bytes written.
///
/// Datagram sockets read via `udp::recv`. Stream sockets delegate to `tcp::poll_recv`.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table, or anything reported by
/// the protocol handler.
///
/// # Panics
/// For `Type::SeqPacket`.
pub fn poll_recv(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
    let ty = match self.lookup(fd).map(|sock| sock.ty) {
        Ok(ty) => ty,
        Err(err) => return Poll::Ready(Err(err)),
    };
    match ty {
        Type::Stream => tcp::poll_recv(self, fd, cx, buf),
        Type::Dgram => {
            let sock = self.lookup_mut(fd).expect("fd validated");
            let mut read_buf = ReadBuf::new(buf);
            // On success the datagram length is however much of `buf` was filled.
            udp::recv(sock, cx, &mut read_buf).map_ok(|()| read_buf.filled().len())
        }
        Type::SeqPacket => unimplemented!("SOCK_SEQPACKET poll_recv"),
    }
}
/// Shuts down the write half of the stream socket `fd` via `tcp::poll_shutdown_write`.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table.
///
/// # Panics
/// For any non-stream socket.
pub fn poll_shutdown_write(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
    let ty = match self.lookup(fd).map(|sock| sock.ty) {
        Ok(ty) => ty,
        Err(err) => return Poll::Ready(Err(err)),
    };
    match ty {
        Type::Dgram | Type::SeqPacket => {
            unimplemented!("poll_shutdown_write on non-Stream fd")
        }
        Type::Stream => tcp::poll_shutdown_write(self, fd, cx),
    }
}
/// Peeks at the next datagram on `fd` into `buf` without consuming it, and returns the
/// sender's address (as reported by `udp::peek_from`).
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table, or anything reported by
/// `udp::peek_from`.
///
/// # Panics
/// If `fd` is not a `Type::Dgram` socket.
pub fn poll_peek_from(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<Addr>> {
    match self.lookup_mut(fd) {
        Ok(sock) => {
            assert_eq!(sock.ty, Type::Dgram, "poll_peek_from on non-Dgram fd");
            udp::peek_from(sock, cx, buf)
        }
        Err(err) => Poll::Ready(Err(err)),
    }
}
/// Peeks at pending data on `fd` into `buf` and returns the number of bytes copied.
///
/// Datagram sockets use `udp::peek_from` and discard the source address. Stream
/// sockets delegate to `tcp::poll_peek`.
///
/// # Errors
/// `ErrorKind::NotFound` if `fd` is not in the socket table, or anything reported by
/// the protocol handler.
///
/// # Panics
/// For `Type::SeqPacket`.
pub fn poll_peek(
    &mut self,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
    let ty = match self.lookup(fd).map(|sock| sock.ty) {
        Ok(ty) => ty,
        Err(err) => return Poll::Ready(Err(err)),
    };
    match ty {
        Type::Stream => tcp::poll_peek(self, fd, cx, buf),
        Type::Dgram => {
            let sock = self.lookup_mut(fd).expect("fd validated");
            let mut read_buf = ReadBuf::new(buf);
            udp::peek_from(sock, cx, &mut read_buf).map_ok(|_source| read_buf.filled().len())
        }
        Type::SeqPacket => unimplemented!("SOCK_SEQPACKET poll_peek"),
    }
}
/// Hands an inbound packet to the protocol handler selected by its transport payload.
pub fn deliver(&mut self, pkt: Packet) {
    // `pkt` is an owned local, not part of `self`, so borrowing its payload alongside
    // `&mut self` is fine. The previous `payload.clone()` copied every datagram/segment
    // for no reason.
    match &pkt.payload {
        Transport::Udp(datagram) => udp::deliver(self, &pkt, datagram),
        Transport::Tcp(segment) => tcp::deliver(self, &pkt, segment),
    }
}
/// Flushes everything this kernel has to send, appending off-host packets to `out`.
///
/// First runs `tcp::check_retx`. Then it loops: `tcp::segment_all` runs (presumably
/// moving pending stream data onto the outbound queue), and every queued packet is
/// routed. Packets addressed to a local address (see [`Kernel::is_local`]) are fed
/// straight back through [`Kernel::deliver`], and everything else is appended to `out`.
/// Local delivery may enqueue further packets, so the loop repeats until a round finds
/// the queue empty. Finally `tcp::reap_closed` runs.
pub fn egress(&mut self, out: &mut Vec<Packet>) {
    tcp::check_retx(self);
    loop {
        tcp::segment_all(self);
        if self.outbound.is_empty() {
            break;
        }
        // Detach the current batch so `deliver` can push new packets onto
        // `self.outbound` while we iterate; those are handled next round. Iterating
        // the taken queue directly avoids copying it into an intermediate `Vec`.
        for pkt in std::mem::take(&mut self.outbound) {
            if self.is_local(pkt.dst) {
                self.deliver(pkt);
            } else {
                out.push(pkt);
            }
        }
    }
    tcp::reap_closed(self);
}
}
impl Default for Kernel {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;
    use std::net::SocketAddr;

    use super::*;

    /// Shorthand for an `Addr::Inet` parsed from `"ip:port"` notation.
    fn inet(s: &str) -> Addr {
        Addr::Inet(s.parse().unwrap())
    }

    /// Loopback needs no registration; other addresses become local once added.
    #[test]
    fn loopback_is_implicit_local() {
        let mut k = Kernel::new();
        assert!(k.is_local("127.0.0.1".parse().unwrap()));
        assert!(k.is_local("::1".parse().unwrap()));
        assert!(!k.is_local("10.0.0.1".parse().unwrap()));
        k.add_address("10.0.0.1".parse().unwrap());
        assert!(k.is_local("10.0.0.1".parse().unwrap()));
    }

    #[test]
    fn bind_records_local_addr() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:5000"), Type::Dgram).unwrap();
        assert_eq!(k.local_addr(s).unwrap(), inet("127.0.0.1:5000"));
    }

    /// Port 0 is replaced by a port from the IANA dynamic range.
    #[test]
    fn bind_port_zero_allocates_ephemeral() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        let Addr::Inet(SocketAddr::V4(v4)) = k.local_addr(s).unwrap() else {
            panic!("expected v4")
        };
        assert!((49152..=65535).contains(&v4.port()));
    }

    #[test]
    fn bind_conflict_is_addr_in_use() {
        let mut k = Kernel::new();
        k.bind(&inet("127.0.0.1:5000"), Type::Dgram).unwrap();
        let err = k.bind(&inet("127.0.0.1:5000"), Type::Dgram).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::AddrInUse);
    }

    /// Bindings are scoped by socket type, so UDP and TCP may share a port.
    #[test]
    fn bind_different_protocols_can_share_port() {
        let mut k = Kernel::new();
        k.bind(&inet("127.0.0.1:5000"), Type::Dgram).unwrap();
        k.bind(&inet("127.0.0.1:5000"), Type::Stream).unwrap();
    }

    #[test]
    fn bind_rejects_non_local_addr() {
        let mut k = Kernel::new();
        let err = k.bind(&inet("10.0.0.1:5000"), Type::Dgram).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::AddrNotAvailable);
    }

    #[test]
    fn bind_wildcard_addr_is_allowed() {
        let mut k = Kernel::new();
        k.bind(&inet("0.0.0.0:5000"), Type::Dgram).unwrap();
    }

    #[test]
    fn distinct_specific_ips_coexist() {
        let mut k = Kernel::new();
        k.add_address("10.0.0.1".parse().unwrap());
        k.add_address("10.0.0.2".parse().unwrap());
        k.bind(&inet("10.0.0.1:5000"), Type::Dgram).unwrap();
        k.bind(&inet("10.0.0.2:5000"), Type::Dgram).unwrap();
    }

    /// A `Context` whose waker does nothing, for driving `poll_*` calls by hand.
    fn noop_cx() -> Context<'static> {
        use std::task::Waker;
        Context::from_waker(Waker::noop())
    }

    /// Sending to the limited-broadcast address needs `SO_BROADCAST`.
    #[test]
    fn udp_broadcast_send_requires_broadcast_option() {
        let mut k = Kernel::new();
        k.add_address("10.0.0.1".parse().unwrap());
        let s = k.bind(&inet("10.0.0.1:0"), Type::Dgram).unwrap();
        let dst = Addr::Inet("255.255.255.255:9000".parse().unwrap());
        let Poll::Ready(Err(e)) = k.poll_send_to(s, &mut noop_cx(), b"x", &dst) else {
            panic!("expected broadcast rejection");
        };
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        k.set_option(s, SocketOption::Broadcast(true)).unwrap();
        let Poll::Ready(Ok(_)) = k.poll_send_to(s, &mut noop_cx(), b"x", &dst) else {
            panic!("broadcast send should succeed with SO_BROADCAST");
        };
    }

    /// Ephemeral allocation is per (domain, type), not per IP.
    #[test]
    fn bind_zero_avoids_ports_taken_on_other_ips() {
        let mut k = Kernel::new();
        k.add_address("10.0.0.1".parse().unwrap());
        k.bind(&inet("10.0.0.1:49152"), Type::Dgram).unwrap();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        let Addr::Inet(sa) = k.local_addr(s).unwrap() else {
            panic!("v4 expected")
        };
        assert_ne!(sa.port(), 49152);
    }

    #[test]
    fn broadcast_option_roundtrips() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        assert_eq!(
            k.get_option(s, SocketOptionKind::Broadcast).unwrap(),
            SocketOption::Broadcast(false)
        );
        k.set_option(s, SocketOption::Broadcast(true)).unwrap();
        assert_eq!(
            k.get_option(s, SocketOptionKind::Broadcast).unwrap(),
            SocketOption::Broadcast(true)
        );
    }

    /// A wildcard binding blocks a later specific binding on the same port.
    #[test]
    fn wildcard_then_specific_bind_conflicts() {
        let mut k = Kernel::new();
        k.bind(&inet("0.0.0.0:5000"), Type::Dgram).unwrap();
        let err = k.bind(&inet("127.0.0.1:5000"), Type::Dgram).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::AddrInUse);
    }

    /// `poll_send_to` rejects a destination of the wrong address family.
    #[test]
    fn send_to_wrong_family_is_eafnosupport() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        let dst = inet("[::1]:9000");
        let Poll::Ready(Err(e)) = k.poll_send_to(s, &mut noop_cx(), b"x", &dst) else {
            panic!("expected address-family rejection");
        };
        assert_eq!(e.raw_os_error(), Some(EAFNOSUPPORT));
    }

    /// `poll_connect` rejects a peer of the wrong address family.
    #[test]
    fn connect_wrong_family_is_eafnosupport() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        let dst = inet("[::1]:9000");
        let Poll::Ready(Err(e)) = k.poll_connect(s, &mut noop_cx(), &dst) else {
            panic!("expected address-family rejection");
        };
        assert_eq!(e.raw_os_error(), Some(EAFNOSUPPORT));
    }

    /// Connecting a bound UDP socket completes immediately and records the peer.
    #[test]
    fn udp_connect_records_peer() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        let dst = inet("127.0.0.1:9000");
        let Poll::Ready(Ok(())) = k.poll_connect(s, &mut noop_cx(), &dst) else {
            panic!("UDP connect should complete immediately");
        };
        assert_eq!(k.peer_addr(s).unwrap(), dst);
    }

    /// Without a connect there is no peer to report.
    #[test]
    fn unconnected_peer_addr_is_not_connected() {
        let mut k = Kernel::new();
        let s = k.bind(&inet("127.0.0.1:0"), Type::Dgram).unwrap();
        assert_eq!(k.peer_addr(s).unwrap_err().kind(), ErrorKind::NotConnected);
    }
}