use crate::envelope::{hex, Datagram, Protocol, Segment, Syn};
#[cfg(feature = "unstable-fs")]
use crate::fs::{Fs, FsConfig};
use crate::net::tcp::stream::BidiFlowControl;
use crate::net::{SocketPair, TcpListener, UdpSocket};
use crate::{Envelope, TRACING_TARGET};
use bytes::Bytes;
use indexmap::IndexMap;
use std::collections::VecDeque;
use std::fmt::Display;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::ops::RangeInclusive;
use std::sync::Arc;
#[cfg(feature = "unstable-fs")]
use std::sync::Mutex;
use tokio::sync::{mpsc, Notify};
use tokio::time::{Duration, Instant};
/// Broadcast flag given to every new UDP bind, and the value reported by
/// `Udp::is_broadcast_enabled` for a port that has no bind.
const DEFAULT_BROADCAST: bool = false;
/// Multicast-loop flag given to every new UDP bind, and the value reported by
/// `Udp::is_multicast_loop_enabled` for a port that has no bind.
const DEFAULT_MULTICAST_LOOP: bool = true;
/// A simulated host: its identity, clock, and UDP/TCP socket tables.
pub(crate) struct Host {
    /// Host name; appears in the port-exhaustion panic message.
    pub(crate) nodename: String,
    /// IP address of this host.
    pub(crate) addr: IpAddr,
    /// This host's clock (see `HostTimer`).
    pub(crate) timer: HostTimer,
    /// UDP bind table.
    pub(crate) udp: Udp,
    /// TCP listener and stream tables.
    pub(crate) tcp: Tcp,
    /// `Fs` state shared behind a mutex; only present with the `unstable-fs` feature.
    #[cfg(feature = "unstable-fs")]
    pub(crate) fs: Arc<Mutex<Fs>>,
    /// Next candidate considered by `assign_ephemeral_port`; always inside `ephemeral_ports`.
    next_ephemeral_port: u16,
    /// Inclusive range ephemeral ports are drawn from; allocation wraps from end back to start.
    ephemeral_ports: RangeInclusive<u16>,
}
impl Host {
    /// Builds a host named `nodename` at `addr`, with empty UDP/TCP tables, an `Fs`
    /// created from `fs_config`, and ephemeral allocation starting at the range's first port.
    ///
    /// `tcp_capacity` and `udp_capacity` are handed to the TCP and UDP tables respectively.
    #[cfg(feature = "unstable-fs")]
    pub(crate) fn new(
        nodename: impl Into<String>,
        addr: IpAddr,
        timer: HostTimer,
        ephemeral_ports: RangeInclusive<u16>,
        tcp_capacity: usize,
        udp_capacity: usize,
        fs_config: FsConfig,
    ) -> Host {
        let first_port = *ephemeral_ports.start();
        let nodename: String = nodename.into();
        let udp = Udp::new(udp_capacity);
        let tcp = Tcp::new(tcp_capacity);
        let fs = Arc::new(Mutex::new(Fs::new(fs_config)));
        Host {
            nodename,
            addr,
            timer,
            udp,
            tcp,
            fs,
            next_ephemeral_port: first_port,
            ephemeral_ports,
        }
    }

    /// Builds a host named `nodename` at `addr`, with empty UDP/TCP tables and ephemeral
    /// allocation starting at the range's first port.
    ///
    /// `tcp_capacity` and `udp_capacity` are handed to the TCP and UDP tables respectively.
    #[cfg(not(feature = "unstable-fs"))]
    pub(crate) fn new(
        nodename: impl Into<String>,
        addr: IpAddr,
        timer: HostTimer,
        ephemeral_ports: RangeInclusive<u16>,
        tcp_capacity: usize,
        udp_capacity: usize,
    ) -> Host {
        let first_port = *ephemeral_ports.start();
        let nodename: String = nodename.into();
        let udp = Udp::new(udp_capacity);
        let tcp = Tcp::new(tcp_capacity);
        Host {
            nodename,
            addr,
            timer,
            udp,
            tcp,
            next_ephemeral_port: first_port,
            ephemeral_ports,
        }
    }

    /// Hands out the next free ephemeral port, round-robin over `ephemeral_ports`.
    ///
    /// A port counts as taken when either the UDP or the TCP table uses it. The cursor always
    /// advances past each candidate examined, so allocation resumes where it left off.
    ///
    /// # Panics
    /// If every port in the range is taken.
    pub(crate) fn assign_ephemeral_port(&mut self) -> u16 {
        let first = *self.ephemeral_ports.start();
        let last = *self.ephemeral_ports.end();
        // Visit each port in the range at most once per call.
        let attempts = self.ephemeral_ports.clone().count();

        for _ in 0..attempts {
            let candidate = self.next_ephemeral_port;
            // Compare before incrementing so that a range ending at u16::MAX cannot overflow.
            self.next_ephemeral_port = if candidate == last { first } else { candidate + 1 };

            let taken = self.udp.is_port_assigned(candidate) || self.tcp.is_port_assigned(candidate);
            if !taken {
                return candidate;
            }
        }

        panic!("Host: '{}' ports exhausted", self.nodename)
    }

    /// Delivers an envelope arriving from the network to the matching protocol table.
    ///
    /// # Errors
    /// UDP delivery never fails. TCP delivery may return a protocol message (e.g. a RST)
    /// produced by `Tcp::receive_from_network`.
    pub(crate) fn receive_from_network(&mut self, envelope: Envelope) -> Result<(), Protocol> {
        let Envelope { src, dst, message } = envelope;
        tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %message, "Delivered");

        match message {
            Protocol::Udp(datagram) => {
                self.udp.receive_from_network(src, dst, datagram);
                Ok(())
            }
            Protocol::Tcp(segment) => self.tcp.receive_from_network(src, dst, segment),
        }
    }
}
/// A host's clock, combining explicitly ticked time with real time since a recorded instant.
pub(crate) struct HostTimer {
    /// Time accumulated via `tick`.
    elapsed: Duration,
    /// Instant recorded by the latest `now` call; real time since it is added to `elapsed`.
    now: Option<Instant>,
    /// Offset added to `elapsed()` to form `sim_elapsed()`.
    start_offset: Duration,
    /// Base added to `sim_elapsed()` to form `since_epoch()`.
    since_epoch: Duration,
}

impl HostTimer {
    /// Creates a timer with no ticked time and no recorded instant.
    pub(crate) fn new(start_offset: Duration, since_epoch: Duration) -> Self {
        HostTimer {
            elapsed: Duration::ZERO,
            now: None,
            start_offset,
            since_epoch,
        }
    }

    /// Adds `duration` to the accumulated time.
    pub(crate) fn tick(&mut self, duration: Duration) {
        self.elapsed = self.elapsed + duration;
    }

    /// Records `now` as the reference instant used by `elapsed`.
    pub(crate) fn now(&mut self, now: Instant) {
        self.now = Some(now);
    }

    /// Accumulated ticked time plus the real time elapsed since the recorded instant.
    ///
    /// # Panics
    /// If `now` has never been called.
    pub(crate) fn elapsed(&self) -> Duration {
        let anchor = match self.now {
            Some(instant) => instant,
            None => panic!("host instant not set"),
        };
        self.elapsed + anchor.elapsed()
    }

    /// `elapsed()` shifted by the start offset.
    pub(crate) fn sim_elapsed(&self) -> Duration {
        self.elapsed() + self.start_offset
    }

    /// `sim_elapsed()` shifted by the `since_epoch` base.
    pub(crate) fn since_epoch(&self) -> Duration {
        self.sim_elapsed() + self.since_epoch
    }
}
/// UDP state for a host: at most one bind per local port.
pub(crate) struct Udp {
    /// Binds keyed by local port.
    binds: IndexMap<u16, UdpBind>,
    /// Capacity of the receive channel created for each new bind.
    capacity: usize,
}
/// State of a single bound UDP socket.
struct UdpBind {
    /// Address passed to `bind`; compared against incoming destinations with `matches`.
    bind_addr: SocketAddr,
    /// Peer set by `connect`; when present, datagrams from any other source are dropped.
    target_addr: Option<SocketAddr>,
    /// Broadcast flag, changed via `Udp::set_broadcast`.
    broadcast: bool,
    /// Multicast-loop flag, changed via `Udp::set_multicast_loop`.
    multicast_loop: bool,
    /// Sending half of the socket's receive queue, carrying (datagram, source address);
    /// datagrams are dropped when it is full or closed.
    queue: mpsc::Sender<(Datagram, SocketAddr)>,
}
impl Udp {
fn new(capacity: usize) -> Self {
Self {
binds: IndexMap::new(),
capacity,
}
}
pub(crate) fn is_port_assigned(&self, port: u16) -> bool {
self.binds.keys().any(|p| *p == port)
}
pub(crate) fn is_broadcast_enabled(&self, port: u16) -> bool {
self.binds
.get(&port)
.map(|bind| bind.broadcast)
.unwrap_or(DEFAULT_BROADCAST)
}
pub(crate) fn is_multicast_loop_enabled(&self, port: u16) -> bool {
self.binds
.get(&port)
.map(|bind| bind.multicast_loop)
.unwrap_or(DEFAULT_MULTICAST_LOOP)
}
pub(crate) fn set_broadcast(&mut self, port: u16, on: bool) {
self.binds
.entry(port)
.and_modify(|bind| bind.broadcast = on);
}
pub(crate) fn set_multicast_loop(&mut self, port: u16, on: bool) {
self.binds
.entry(port)
.and_modify(|bind| bind.multicast_loop = on);
}
pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<UdpSocket> {
let (tx, rx) = mpsc::channel(self.capacity);
let bind = UdpBind {
bind_addr: addr,
target_addr: None,
broadcast: DEFAULT_BROADCAST,
multicast_loop: DEFAULT_MULTICAST_LOOP,
queue: tx,
};
match self.binds.entry(addr.port()) {
indexmap::map::Entry::Occupied(_) => {
return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
}
indexmap::map::Entry::Vacant(entry) => entry.insert(bind),
};
tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Bind");
Ok(UdpSocket::new(addr, rx))
}
pub(crate) fn connect(&mut self, src: SocketAddr, dst: SocketAddr) {
let Some(bind) = self.binds.get_mut(&src.port()) else {
panic!("Connect failed (no matching bind) for {src}");
};
bind.target_addr = Some(dst);
}
fn receive_from_network(&mut self, src: SocketAddr, dst: SocketAddr, datagram: Datagram) {
if let Some(bind) = self.binds.get_mut(&dst.port()) {
if let Some(target) = bind.target_addr {
if !matches(target, src) {
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Connect Addr not matching)");
return;
}
}
if !matches(bind.bind_addr, dst) {
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Addr not bound)");
return;
}
if let Err(err) = bind.queue.try_send((datagram, src)) {
match err {
mpsc::error::TrySendError::Full((datagram, _)) => {
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Full buffer)");
}
mpsc::error::TrySendError::Closed((datagram, _)) => {
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Receiver closed)");
}
}
}
}
}
pub(crate) fn unbind(&mut self, addr: SocketAddr) {
let exists = self.binds.swap_remove(&addr.port());
assert!(exists.is_some(), "unknown bind {addr}");
tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Unbind");
}
}
/// TCP state for a host: listeners keyed by port and established streams keyed by address pair.
pub(crate) struct Tcp {
    /// Listeners keyed by local port.
    binds: IndexMap<u16, ServerSocket>,
    /// Maximum number of SYNs queued on one listener; receiving another one panics.
    server_socket_capacity: usize,
    /// Stream sockets keyed by their (local, peer) address pair.
    sockets: IndexMap<SocketPair, StreamSocket>,
    /// Capacity passed to `StreamSocket::new` for each new stream.
    socket_capacity: usize,
}
/// A TCP listener.
struct ServerSocket {
    /// Address passed to `bind`; compared against a SYN's destination with `matches`.
    bind_addr: SocketAddr,
    /// Notified once per queued SYN; the same `Notify` is handed to the `TcpListener`.
    notify: Arc<Notify>,
    /// Pending (SYN, source address) pairs, consumed front-first by `Tcp::accept`.
    deque: VecDeque<(Syn, SocketAddr)>,
}
/// One side of an established TCP stream.
struct StreamSocket {
    /// Received segments not yet forwarded to the reader, keyed by sequence number.
    buf: IndexMap<u64, SequencedSegment>,
    /// Sequence number the next outgoing segment will get (starts at 1).
    next_send_seq: u64,
    /// Sequence number of the last segment forwarded to the reader (0 before any).
    recv_seq: u64,
    /// Delivers in-order segments to the reader.
    sender: mpsc::Sender<SequencedSegment>,
    /// Flow-control handle; a clone is returned to the caller of `StreamSocket::new`.
    flow_control: BidiFlowControl,
    /// Outstanding references; the stream is removed from `Tcp` when this reaches 0.
    /// NOTE(review): starts at 2, presumably one per stream half — confirm against callers.
    ref_ct: usize,
}
/// An item delivered, in sequence order, to a TCP stream's reader.
#[derive(Debug)]
pub(crate) enum SequencedSegment {
    /// Payload bytes.
    Data(Bytes),
    /// The peer's FIN marker.
    Fin,
}
/// Renders data segments through the crate's `hex` helper and FIN as a fixed label.
impl Display for SequencedSegment {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Fin => f.write_str("TCP FIN"),
            Self::Data(ref bytes) => hex("TCP", bytes, f),
        }
    }
}
impl StreamSocket {
    /// Creates a stream socket, the receiving end of its delivery channel, and a
    /// `BidiFlowControl` (a clone of which is kept in the socket), all sized by `capacity`.
    fn new(capacity: usize) -> (Self, mpsc::Receiver<SequencedSegment>, BidiFlowControl) {
        let (tx, rx) = mpsc::channel(capacity);
        let flow_control = BidiFlowControl::new(capacity);
        let sock = Self {
            buf: IndexMap::new(),
            // Outgoing sequence numbers start at 1.
            next_send_seq: 1,
            // Nothing delivered yet, so the first inbound segment expected is seq 1.
            recv_seq: 0,
            sender: tx,
            flow_control: flow_control.clone(),
            // Decremented by `Tcp::close_stream_half`; the stream is removed at 0.
            ref_ct: 2,
        };
        (sock, rx, flow_control)
    }

    /// Returns the next outgoing sequence number and advances the counter.
    fn assign_seq(&mut self) -> u64 {
        let seq = self.next_send_seq;
        self.next_send_seq += 1;
        seq
    }

    /// Stores `segment` under `seq`, then forwards every consecutive buffered segment
    /// starting at `recv_seq + 1` to the reader. It stops at the first gap or when the
    /// channel has no free slot.
    ///
    /// # Errors
    /// Returns `Protocol::Tcp(Segment::Rst)` if the reader's end of the channel is closed.
    ///
    /// # Panics
    /// If a segment with the same `seq` is already buffered.
    fn buffer(&mut self, seq: u64, segment: SequencedSegment) -> Result<(), Protocol> {
        use mpsc::error::TrySendError::*;
        let exists = self.buf.insert(seq, segment);
        assert!(exists.is_none(), "duplicate segment {seq}");
        while self.buf.contains_key(&(self.recv_seq + 1)) {
            // Tentatively advance; rolled back below if the channel is full.
            self.recv_seq += 1;
            // Reserve a slot first so the segment is removed from `buf` only once it can be sent.
            match self.sender.try_reserve() {
                Ok(permit) => {
                    let segment = self.buf.swap_remove(&self.recv_seq).unwrap();
                    permit.send(segment)
                }
                // On this path `recv_seq` stays advanced and the segment stays in `buf`.
                Err(Closed(())) => return Err(Protocol::Tcp(Segment::Rst)),
                Err(Full(())) => {
                    // Leave the segment buffered; a later call can retry delivery.
                    self.recv_seq -= 1;
                    break;
                }
            }
        }
        Ok(())
    }
}
impl Tcp {
fn new(capacity: usize) -> Self {
Self {
binds: IndexMap::new(),
sockets: IndexMap::new(),
server_socket_capacity: capacity,
socket_capacity: capacity,
}
}
fn is_port_assigned(&self, port: u16) -> bool {
self.binds.keys().any(|p| *p == port) || self.sockets.keys().any(|a| a.local.port() == port)
}
pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<TcpListener> {
let notify = Arc::new(Notify::new());
let sock = ServerSocket {
bind_addr: addr,
notify: notify.clone(),
deque: VecDeque::new(),
};
match self.binds.entry(addr.port()) {
indexmap::map::Entry::Occupied(_) => {
return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
}
indexmap::map::Entry::Vacant(entry) => entry.insert(sock),
};
tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Bind");
Ok(TcpListener::new(addr, notify))
}
pub(crate) fn new_stream(
&mut self,
pair: SocketPair,
) -> (mpsc::Receiver<SequencedSegment>, BidiFlowControl) {
let (sock, rx, bidi) = StreamSocket::new(self.socket_capacity);
let exists = self.sockets.insert(pair, sock);
assert!(exists.is_none(), "{pair:?} is already connected");
(rx, bidi)
}
pub(crate) fn flow_control(&self, pair: SocketPair) -> BidiFlowControl {
self.sockets
.get(&pair)
.expect("missing stream socket")
.flow_control
.clone()
}
pub(crate) fn stream_count(&self) -> usize {
self.sockets.len()
}
pub(crate) fn accept(&mut self, addr: SocketAddr) -> Option<(Syn, SocketAddr)> {
self.binds[&addr.port()].deque.pop_front()
}
pub(crate) fn assign_send_seq(&mut self, pair: SocketPair) -> Option<u64> {
let sock = self.sockets.get_mut(&pair)?;
Some(sock.assign_seq())
}
fn receive_from_network(
&mut self,
src: SocketAddr,
dst: SocketAddr,
segment: Segment,
) -> Result<(), Protocol> {
match segment {
Segment::Syn(syn) => {
if let Some(b) = self.binds.get_mut(&dst.port()) {
if b.deque.len() == self.server_socket_capacity {
panic!("{dst} server socket buffer full");
}
if matches(b.bind_addr, dst) {
b.deque.push_back((syn, src));
b.notify.notify_one();
}
}
}
Segment::Data(seq, data) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
Some(sock) => sock.buffer(seq, SequencedSegment::Data(data))?,
None => return Err(Protocol::Tcp(Segment::Rst)),
},
Segment::Fin(seq) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
Some(sock) => sock.buffer(seq, SequencedSegment::Fin)?,
None => return Err(Protocol::Tcp(Segment::Rst)),
},
Segment::Rst => {
if self.sockets.get(&SocketPair::new(dst, src)).is_some() {
self.sockets
.swap_remove(&SocketPair::new(dst, src))
.unwrap();
}
}
};
Ok(())
}
pub(crate) fn has_buffered_data(&self, pair: SocketPair) -> bool {
self.sockets
.get(&pair)
.map(|s| {
s.buf
.values()
.any(|seg| matches!(seg, SequencedSegment::Data(_)))
})
.unwrap_or(false)
}
pub(crate) fn reset_stream(&mut self, pair: SocketPair) {
self.sockets.swap_remove(&pair);
}
pub(crate) fn close_stream_half(&mut self, pair: SocketPair) {
if let Some(sock) = self.sockets.get_mut(&pair) {
sock.ref_ct -= 1;
if sock.ref_ct == 0 {
self.sockets.swap_remove(&pair).unwrap();
}
}
}
pub(crate) fn unbind(&mut self, addr: SocketAddr) {
let exists = self.binds.swap_remove(&addr.port());
assert!(exists.is_some(), "unknown bind {addr}");
tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Unbind");
}
}
/// Returns `true` if a socket bound to `bind` accepts traffic addressed to `dst`.
///
/// A bind on an unspecified IP (`0.0.0.0` or `::`) accepts any destination with the same port;
/// any other bind requires `bind == dst`.
pub fn matches(bind: SocketAddr, dst: SocketAddr) -> bool {
    let wildcard_hit = bind.ip().is_unspecified() && bind.port() == dst.port();
    wildcard_hit || bind == dst
}
/// Returns `true` when `dst` is a loopback address or has the same IP as `src` (ports are ignored).
pub(crate) fn is_same(src: SocketAddr, dst: SocketAddr) -> bool {
    let dst_ip = dst.ip();
    dst_ip == src.ip() || dst_ip.is_loopback()
}
#[cfg(test)]
mod test {
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::Duration;

    #[cfg(feature = "unstable-fs")]
    use crate::fs::FsConfig;
    use crate::{host::HostTimer, Host, Result};

    /// Bounds of the ephemeral range used by these tests.
    const FIRST_PORT: u16 = 49152;
    const LAST_PORT: u16 = 49162;

    /// Host on the unspecified IPv4 address with the test range and TCP/UDP capacity of 1.
    #[cfg(feature = "unstable-fs")]
    fn test_host() -> Host {
        let addr: IpAddr = Ipv4Addr::UNSPECIFIED.into();
        let timer = HostTimer::new(Duration::ZERO, Duration::ZERO);
        Host::new(
            "host",
            addr,
            timer,
            FIRST_PORT..=LAST_PORT,
            1,
            1,
            FsConfig::default(),
        )
    }

    /// Host on the unspecified IPv4 address with the test range and TCP/UDP capacity of 1.
    #[cfg(not(feature = "unstable-fs"))]
    fn test_host() -> Host {
        let addr: IpAddr = Ipv4Addr::UNSPECIFIED.into();
        let timer = HostTimer::new(Duration::ZERO, Duration::ZERO);
        Host::new("host", addr, timer, FIRST_PORT..=LAST_PORT, 1, 1)
    }

    #[test]
    fn recycle_ports() -> Result {
        let mut host = test_host();

        // Occupy the two highest ports so the allocator has to skip them.
        for port in [LAST_PORT - 1, LAST_PORT] {
            host.udp.bind((host.addr, port).into())?;
        }

        // Hand out every remaining port: FIRST_PORT ..= LAST_PORT - 2.
        for _ in FIRST_PORT..LAST_PORT - 1 {
            host.assign_ephemeral_port();
        }

        // After skipping the bound ports, allocation wraps back to the start of the range.
        assert_eq!(host.assign_ephemeral_port(), FIRST_PORT);
        Ok(())
    }
}