use std::io::{Error, ErrorKind, Result};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use crate::kernel::packet::{self, Packet, TcpFlags, TcpSegment, Transport};
use crate::kernel::socket::{Addr, BindKey, Domain, Fd, Socket, Tcb, TcpState, Type};
use crate::kernel::Kernel;
/// Drives an active open (`connect`) of stream socket `fd` toward `peer`.
///
/// On the first poll (no TCB yet) this binds an ephemeral local endpoint if the
/// socket is not bound yet, installs a TCB in `SynSent`, registers the
/// connection's 4-tuple for demultiplexing, queues a SYN and parks the task.
/// Later polls report the outcome recorded in the TCB:
/// * `Established` resolves to `Ok(())`;
/// * `SynSent` / `SynReceived` re-park the task;
/// * any other state fails with `TimedOut` when the TCB's `timed_out` flag is
///   set and with `ConnectionRefused` otherwise.
///
/// # Errors
/// Propagates `auto_bind` failures (`AddrNotAvailable`, `AddrInUse`).
///
/// # Panics
/// If `fd` does not resolve; the caller is expected to have validated it.
pub(super) fn poll_connect(
k: &mut Kernel,
fd: Fd,
cx: &mut Context<'_>,
domain: Domain,
peer: SocketAddr,
is_bound: bool,
) -> Poll<Result<()>> {
// A TCB already exists: this is a re-poll, so report the handshake outcome.
if let Some(tcb) = &k.lookup(fd).expect("fd validated").tcb {
return match tcb.state {
TcpState::Established => Poll::Ready(Ok(())),
TcpState::SynSent | TcpState::SynReceived => {
park_connect(k, fd, cx);
Poll::Pending
}
// The handshake did not (or no longer) complete: `timed_out` is set
// when retransmissions ran out; anything else is reported as refused.
TcpState::Closed
| TcpState::FinWait1
| TcpState::FinWait2
| TcpState::CloseWait
| TcpState::LastAck
| TcpState::Closing => {
if tcb.timed_out {
Poll::Ready(Err(Error::from(ErrorKind::TimedOut)))
} else {
Poll::Ready(Err(Error::from(ErrorKind::ConnectionRefused)))
}
}
};
}
// First poll: pick an ephemeral local endpoint unless the caller bound one.
if !is_bound {
auto_bind(k, fd, domain, peer.ip())?;
}
// NOTE(review): the TCB is only installed below, so `local_endpoint` always
// sees `tcb == None` here and cannot use `peer` to pick a concrete source IP
// for a wildcard bind. A socket explicitly bound to the unspecified address
// therefore sends its SYN (and registers its connection) from that
// unspecified address. Confirm whether that is intended.
let src = local_endpoint(k, fd);
let isn = initial_sequence(k);
{
let st = k.lookup_mut(fd).expect("fd validated");
// snd_una/snd_nxt start at ISN + 1: the SYN occupies one sequence number.
st.tcb = Some(Tcb {
state: TcpState::SynSent,
peer,
snd_nxt: isn.wrapping_add(1),
snd_una: isn.wrapping_add(1),
snd_wnd: DEFAULT_WINDOW,
rcv_nxt: 0,
send_buf: BytesMut::new(),
recv_buf: BytesMut::new(),
wr_closed: false,
peer_fin: false,
fin_seq: None,
reset: false,
timed_out: false,
egress_since_ack: 0,
retx_attempts: 0,
});
st.peer = Some(Addr::Inet(peer));
}
// Register the 4-tuple so `deliver` routes the SYN/ACK to this socket.
k.sockets.insert_connection(src, peer, fd);
// Queue the SYN; `check_retx` re-sends it while the TCB stays in `SynSent`.
emit(
k,
src,
peer,
TcpSegment {
src_port: src.port(),
dst_port: peer.port(),
seq: isn,
ack: 0,
flags: TcpFlags {
syn: true,
..TcpFlags::default()
},
window: DEFAULT_WINDOW,
payload: Bytes::new(),
},
);
// `handle_on_connection` (on SYN/ACK) or `abort_with` wakes this waker.
park_connect(k, fd, cx);
Poll::Pending
}
/// Demultiplexes one inbound TCP segment carried by `pkt`.
///
/// A segment matching a known 4-tuple goes to that connection's state machine.
/// A bare SYN (SYN without ACK) is offered to a listener on the destination
/// endpoint and answered with a RST when there is none. Any other stray
/// segment is answered with a RST unless it is itself a RST.
pub(super) fn deliver(k: &mut Kernel, pkt: &Packet, s: &TcpSegment) {
    let local = SocketAddr::new(pkt.dst, s.dst_port);
    let remote = SocketAddr::new(pkt.src, s.src_port);

    if let Some(conn) = k.sockets.find_connection(local, remote) {
        handle_on_connection(k, conn, local, remote, s);
    } else if s.flags.syn && !s.flags.ack {
        match find_listener(k, local) {
            Some(listener) => accept_syn(k, listener, local, remote, s),
            None => emit_rst(k, local, remote, s),
        }
    } else if !s.flags.rst {
        // Never answer a RST with a RST.
        emit_rst(k, local, remote, s);
    }
}
/// Answers a segment that matches no connection with a RST.
///
/// If the offending segment carried an ACK, the RST reuses that ACK number as
/// its sequence number and carries no ACK itself. Otherwise the RST has
/// sequence 0 and acknowledges everything the segment occupied: its payload
/// plus one for each of SYN and FIN.
fn emit_rst(k: &mut Kernel, local: SocketAddr, remote: SocketAddr, s: &TcpSegment) {
    let mut flags = TcpFlags {
        rst: true,
        ..TcpFlags::default()
    };
    let (seq, ack) = if s.flags.ack {
        (s.ack, 0)
    } else {
        flags.ack = true;
        let occupied =
            s.payload.len() as u32 + u32::from(s.flags.syn) + u32::from(s.flags.fin);
        (0, s.seq.wrapping_add(occupied))
    };
    let segment = TcpSegment {
        src_port: local.port(),
        dst_port: remote.port(),
        seq,
        ack,
        flags,
        window: 0,
        payload: Bytes::new(),
    };
    emit(k, local, remote, segment);
}
/// Runs the per-connection state machine for a segment addressed to `fd`.
///
/// * Any RST aborts the connection (`abort_connection`).
/// * `SynSent` + SYN/ACK that acknowledges exactly our SYN completes an active
///   open. The TCB becomes `Established`, the connecting task is woken and the
///   final handshake ACK is emitted.
/// * `SynReceived` + ACK of our SYN/ACK completes a passive open. The child
///   becomes `Established` and is queued on its listener. Any payload or FIN
///   carried by that same segment is then processed like any established
///   segment (RFC 793: after entering ESTABLISHED, "continue processing").
/// * Synchronized states delegate to `handle_established`.
/// * Everything else, including handshake segments that do not advance the
///   handshake, is dropped.
fn handle_on_connection(
    k: &mut Kernel,
    fd: Fd,
    local: SocketAddr,
    remote: SocketAddr,
    s: &TcpSegment,
) {
    if s.flags.rst {
        abort_connection(k, fd);
        return;
    }
    let (state, snd_nxt) = {
        let tcb = k
            .lookup(fd)
            .expect("fd present")
            .tcb
            .as_ref()
            .expect("tcb present");
        (tcb.state, tcb.snd_nxt)
    };
    match state {
        TcpState::SynSent if s.flags.syn && s.flags.ack => {
            // A SYN/ACK must acknowledge exactly our SYN (snd_nxt == ISN + 1).
            // Anything else is stale or bogus and must not complete the open.
            if s.ack != snd_nxt {
                return;
            }
            let recv_cap = k.recv_buf_cap;
            let rcv_nxt = {
                let tcb = k.lookup_mut(fd).unwrap().tcb.as_mut().unwrap();
                tcb.state = TcpState::Established;
                // The peer's SYN occupies one sequence number.
                tcb.rcv_nxt = s.seq.wrapping_add(1);
                tcb.snd_wnd = s.window;
                tcb.rcv_nxt
            };
            wake_connect(k, fd);
            emit(
                k,
                local,
                remote,
                TcpSegment {
                    src_port: local.port(),
                    dst_port: remote.port(),
                    seq: snd_nxt,
                    ack: rcv_nxt,
                    flags: TcpFlags {
                        ack: true,
                        ..TcpFlags::default()
                    },
                    window: advertised_window(recv_cap, 0),
                    payload: Bytes::new(),
                },
            );
        }
        TcpState::SynReceived if s.flags.ack && !s.flags.syn => {
            // Only an ACK of our SYN/ACK completes the passive open.
            if s.ack != snd_nxt {
                return;
            }
            {
                let tcb = k.lookup_mut(fd).unwrap().tcb.as_mut().unwrap();
                tcb.state = TcpState::Established;
                tcb.snd_wnd = s.window;
            }
            push_to_listener(k, fd, local);
            // The completing ACK may already carry data or a FIN, e.g. when the
            // peer's bare handshake ACK was lost and this is its first data
            // segment. Process it now instead of discarding it and relying on
            // the peer's retransmission timer.
            if !s.payload.is_empty() || s.flags.fin {
                handle_established(k, fd, local, remote, s);
            }
        }
        TcpState::Established
        | TcpState::FinWait1
        | TcpState::FinWait2
        | TcpState::CloseWait
        | TcpState::Closing
        | TcpState::LastAck => handle_established(k, fd, local, remote, s),
        // A closed connection ignores further segments.
        TcpState::Closed => {}
        // Handshake segments that do not advance the handshake (e.g. a
        // retransmitted SYN) are dropped; `check_retx` re-sends our own
        // SYN or SYN/ACK.
        TcpState::SynSent | TcpState::SynReceived => {}
    }
}
/// Processes a segment for a connection in a synchronized state
/// (`Established`, `FinWait1`, `FinWait2`, `CloseWait`, `Closing`, `LastAck`).
///
/// 1. ACK: if it acknowledges new sequence space (`snd_una < ack <= snd_nxt`),
///    the acknowledged bytes are released from `send_buf` and the
///    retransmission counters are reset. When our FIN is covered, the close
///    state advances (`FinWait1`→`FinWait2`, `Closing`/`LastAck`→`Closed`).
///    The peer's window is recorded for every ACK-bearing segment and writers
///    are woken.
/// 2. Data: bytes from `rcv_nxt` onward are appended to `recv_buf` up to its
///    capacity. A retransmission that overlaps bytes already received is
///    trimmed to its new tail. Segments starting beyond `rcv_nxt` are not
///    buffered.
/// 3. FIN: accepted only once every byte before it has been received.
///
/// Every segment that occupies sequence space (payload or FIN) is answered with
/// an ACK, including duplicates and out-of-order segments.
fn handle_established(
    k: &mut Kernel,
    fd: Fd,
    local: SocketAddr,
    remote: SocketAddr,
    s: &TcpSegment,
) {
    let recv_cap = k.recv_buf_cap;
    // Acknowledge anything that consumes sequence space, acceptable or not
    // (RFC 793, "SEGMENT ARRIVES"). Duplicates matter most: they mean our
    // previous ACK was lost. A sender only clears its retransmission counters
    // on an ACK of new data, so staying silent would let the peer retransmit
    // until it aborts with `TimedOut`.
    let send_ack = !s.payload.is_empty() || s.flags.fin;
    let mut wake_write = false;
    let mut wake_read = false;
    {
        let st = k.lookup_mut(fd).unwrap();
        let tcb = st.tcb.as_mut().unwrap();

        if s.flags.ack {
            let acked = s.ack.wrapping_sub(tcb.snd_una);
            let in_flight = tcb.snd_nxt.wrapping_sub(tcb.snd_una);
            if acked > 0 && acked <= in_flight {
                // Our FIN sits at `fin_seq`; it is covered only when the ACK
                // reaches one past it. It is not a byte of `send_buf`.
                let fin_acked = tcb
                    .fin_seq
                    .map_or(false, |fs| s.ack == fs.wrapping_add(1));
                let data_bytes = if fin_acked { acked - 1 } else { acked };
                if data_bytes > 0 {
                    let _ = tcb.send_buf.split_to(data_bytes as usize);
                }
                tcb.snd_una = s.ack;
                tcb.egress_since_ack = 0;
                tcb.retx_attempts = 0;
                if fin_acked {
                    tcb.state = match tcb.state {
                        TcpState::FinWait1 => TcpState::FinWait2,
                        TcpState::Closing => TcpState::Closed,
                        TcpState::LastAck => TcpState::Closed,
                        other => other,
                    };
                }
            }
            tcb.snd_wnd = s.window;
            wake_write = true;
        }

        if !s.payload.is_empty() && !tcb.peer_fin {
            // Position of rcv_nxt inside this segment. An overlapping
            // retransmission yields a small offset and only its tail is new.
            // A segment starting after rcv_nxt wraps to a huge offset and is
            // rejected.
            let already_have = tcb.rcv_nxt.wrapping_sub(s.seq) as usize;
            if already_have < s.payload.len() {
                let fresh = &s.payload[already_have..];
                let room = recv_cap.saturating_sub(tcb.recv_buf.len());
                let n = fresh.len().min(room);
                if n > 0 {
                    tcb.recv_buf.extend_from_slice(&fresh[..n]);
                    tcb.rcv_nxt = tcb.rcv_nxt.wrapping_add(n as u32);
                    wake_read = true;
                }
            }
        }

        if s.flags.fin && !tcb.peer_fin {
            let fin_seq = s.seq.wrapping_add(s.payload.len() as u32);
            if fin_seq == tcb.rcv_nxt {
                tcb.peer_fin = true;
                // The FIN occupies one sequence number.
                tcb.rcv_nxt = tcb.rcv_nxt.wrapping_add(1);
                wake_read = true;
                tcb.state = match tcb.state {
                    TcpState::Established => TcpState::CloseWait,
                    TcpState::FinWait1 => TcpState::Closing,
                    TcpState::FinWait2 => TcpState::Closed,
                    other => other,
                };
            }
        }

        if wake_write {
            st.wake_write();
        }
        if wake_read {
            st.wake_read();
        }
    }
    if send_ack {
        let (snd_nxt, rcv_nxt, window) = {
            let tcb = k.lookup(fd).unwrap().tcb.as_ref().unwrap();
            (
                tcb.snd_nxt,
                tcb.rcv_nxt,
                advertised_window(recv_cap, tcb.recv_buf.len()),
            )
        };
        emit(
            k,
            local,
            remote,
            TcpSegment {
                src_port: local.port(),
                dst_port: remote.port(),
                seq: snd_nxt,
                ack: rcv_nxt,
                flags: TcpFlags {
                    ack: true,
                    ..TcpFlags::default()
                },
                window,
                payload: Bytes::new(),
            },
        );
    }
}
/// Handles a SYN that reached a listening socket (passive open).
///
/// If half-open children plus connections already waiting in the accept queue
/// have reached the listener's backlog, the SYN is dropped silently. Otherwise
/// a child socket is created in `SynReceived`, bound to the concrete local
/// address the SYN was sent to, and registered for demultiplexing. A SYN/ACK
/// is then queued.
fn accept_syn(
    k: &mut Kernel,
    listener_fd: Fd,
    local: SocketAddr,
    remote: SocketAddr,
    s: &TcpSegment,
) {
    let (backlog, accept_queue, domain, ty) = {
        let lst = k.lookup(listener_fd).expect("listener present");
        let listen = lst.listen.as_ref().expect("listener has ListenState");
        (listen.backlog, listen.ready.len(), lst.domain, lst.ty)
    };
    let half_open = count_children(k, listener_fd, local);
    if half_open + accept_queue >= backlog {
        // Backlog full: drop the SYN and let the client retransmit it.
        return;
    }

    let child = k.sockets.insert(Socket::new(domain, ty));
    let key = BindKey {
        domain,
        ty,
        local_addr: local.ip(),
        local_port: local.port(),
    };
    k.sockets.insert_binding(key.clone(), child);

    let isn = initial_sequence(k);
    // The peer's SYN occupies one sequence number.
    let peer_next = s.seq.wrapping_add(1);
    let sock = k.sockets.get_mut(child).unwrap();
    sock.bound = Some(key);
    sock.peer = Some(Addr::Inet(remote));
    sock.tcb = Some(Tcb {
        state: TcpState::SynReceived,
        peer: remote,
        snd_nxt: isn.wrapping_add(1),
        snd_una: isn.wrapping_add(1),
        snd_wnd: s.window,
        rcv_nxt: peer_next,
        send_buf: BytesMut::new(),
        recv_buf: BytesMut::new(),
        wr_closed: false,
        peer_fin: false,
        fin_seq: None,
        reset: false,
        timed_out: false,
        egress_since_ack: 0,
        retx_attempts: 0,
    });
    k.sockets.insert_connection(local, remote, child);

    let syn_ack = TcpSegment {
        src_port: local.port(),
        dst_port: remote.port(),
        seq: isn,
        ack: peer_next,
        flags: TcpFlags {
            syn: true,
            ack: true,
            ..TcpFlags::default()
        },
        window: DEFAULT_WINDOW,
        payload: Bytes::new(),
    };
    emit(k, local, remote, syn_ack);
}
/// Appends a newly established child connection to its listener's accept
/// queue and wakes every task parked in `accept`. Does nothing if no listener
/// is found for `local` any more.
fn push_to_listener(k: &mut Kernel, child: Fd, local: SocketAddr) {
    if let Some(listener_fd) = find_listener(k, local) {
        let listen = k
            .lookup_mut(listener_fd)
            .unwrap()
            .listen
            .as_mut()
            .expect("listener");
        listen.ready.push_back(child);
        let parked: Vec<_> = listen.accept_wakers.drain(..).collect();
        for waker in parked {
            waker.wake();
        }
    }
}
/// Remembers the current task's waker so a pending `connect` on `fd` can be
/// woken later; replaces any previously stored waker.
fn park_connect(k: &mut Kernel, fd: Fd, cx: &mut Context<'_>) {
    let slot = &mut k.lookup_mut(fd).expect("fd present").connect_waker;
    *slot = Some(cx.waker().clone());
}
/// Takes the waker parked by a pending `connect` on `fd`, if any, and wakes it.
fn wake_connect(k: &mut Kernel, fd: Fd) {
    let parked = k.lookup_mut(fd).unwrap().connect_waker.take();
    if let Some(waker) = parked {
        waker.wake();
    }
}
/// Handles the user closing descriptor `fd` of a socket.
///
/// Returns `true` for outcomes that need no further protocol work. These are
/// the plain reap, the RST-on-unread-data path and the listener teardown;
/// presumably the caller then removes the socket. Returns `false` when the
/// socket must linger with `fd_closed` set until `reap_closed` sees its TCB
/// reach `Closed` (or get reset).
///
/// * Listener (`listen` set, no TCB): every connection still waiting in the
///   accept queue is sent a RST and removed. So is every `SynReceived` child
///   bound to the listener's port, and to its address unless the listener is
///   bound to the wildcard.
/// * Connection not reset/timed out and not in `Closed`/`SynSent`/
///   `SynReceived`, with unread bytes in `recv_buf`: a RST replaces the FIN.
/// * Such a connection with nothing unread: a FIN is scheduled (unless the
///   write side is already shut down) and the socket lingers.
/// * Anything else (unknown fd, non-stream socket, aborted, closed, or
///   mid-handshake connection): reaped without sending anything.
pub(super) fn on_close(k: &mut Kernel, fd: Fd) -> bool {
enum Action {
Reap,
Linger,
Rst {
local: SocketAddr,
remote: SocketAddr,
},
CloseListener {
local: SocketAddr,
},
}
// Decide what to do while only holding a shared borrow of the socket.
let action = {
let Some(st) = k.sockets.get(fd) else {
return true;
};
match (st.ty, st.tcb.as_ref(), st.listen.as_ref()) {
(Type::Stream, None, Some(_)) => {
Action::CloseListener {
local: bound_endpoint(st),
}
}
(Type::Stream, Some(tcb), _)
if !tcb.reset
&& !tcb.timed_out
&& tcb.state != TcpState::Closed
&& tcb.state != TcpState::SynSent
&& tcb.state != TcpState::SynReceived =>
{
// Closing with unread data aborts (RST) instead of closing gracefully.
if !tcb.recv_buf.is_empty() {
Action::Rst {
local: bound_endpoint(st),
remote: tcb.peer,
}
} else {
Action::Linger
}
}
_ => Action::Reap,
}
};
match action {
Action::Reap => true,
Action::Rst { local, remote } => {
// The RST carries the connection's current send/receive positions.
let (seq, ack) = {
let tcb = k.lookup(fd).unwrap().tcb.as_ref().unwrap();
(tcb.snd_nxt, tcb.rcv_nxt)
};
emit(
k,
local,
remote,
TcpSegment {
src_port: local.port(),
dst_port: remote.port(),
seq,
ack,
flags: TcpFlags {
rst: true,
ack: true,
..TcpFlags::default()
},
window: 0,
payload: Bytes::new(),
},
);
true
}
Action::Linger => {
let st = k.sockets.get_mut(fd).unwrap();
st.fd_closed = true;
let tcb = st.tcb.as_mut().unwrap();
if !tcb.wr_closed {
// `send_buf` holds every byte from `snd_una` onward, so the FIN
// takes the sequence number right after the last queued byte.
// `segment_all` transmits it once all data is in flight.
let fin_seq = tcb.snd_una.wrapping_add(tcb.send_buf.len() as u32);
tcb.fin_seq = Some(fin_seq);
tcb.wr_closed = true;
tcb.state = match tcb.state {
TcpState::Established => TcpState::FinWait1,
TcpState::CloseWait => TcpState::LastAck,
other => other,
};
}
false
}
Action::CloseListener { local } => {
let listener_port = local.port();
let wildcard = local.ip().is_unspecified();
// Start with connections that finished the handshake but were never accepted.
let mut children: Vec<Fd> = k
.sockets
.get(fd)
.and_then(|s| s.listen.as_ref())
.map(|l| l.ready.iter().copied().collect())
.unwrap_or_default();
// Half-open children are not queued anywhere; find them by their binding.
for (child_fd, st) in k.sockets.iter() {
if child_fd == fd || children.contains(&child_fd) {
continue;
}
let (Some(tcb), Some(bind)) = (st.tcb.as_ref(), st.bound.as_ref()) else {
continue;
};
if tcb.state != TcpState::SynReceived {
continue;
}
if bind.local_port != listener_port {
continue;
}
if !wildcard && bind.local_addr != local.ip() {
continue;
}
children.push(child_fd);
}
// Reset and drop each child; a child without a TCB is dropped without a RST.
for child in children {
let Some(st) = k.sockets.get(child) else {
continue;
};
let Some(tcb) = st.tcb.as_ref() else {
k.sockets.remove(child);
continue;
};
let child_local = bound_endpoint(st);
let (seq, ack, peer) = (tcb.snd_nxt, tcb.rcv_nxt, tcb.peer);
emit(
k,
child_local,
peer,
TcpSegment {
src_port: child_local.port(),
dst_port: peer.port(),
seq,
ack,
flags: TcpFlags {
rst: true,
ack: true,
..TcpFlags::default()
},
window: 0,
payload: Bytes::new(),
},
);
k.sockets.remove(child);
}
true
}
}
}
/// Removes every socket whose descriptor was closed and that has no protocol
/// work left: it has no TCB, its TCB is `Closed`, or its TCB was reset.
pub(super) fn reap_closed(k: &mut Kernel) {
    let victims: Vec<Fd> = k
        .sockets
        .iter()
        .filter_map(|(fd, s)| {
            let finished = s
                .tcb
                .as_ref()
                .map_or(true, |t| t.state == TcpState::Closed || t.reset);
            (s.fd_closed && finished).then_some(fd)
        })
        .collect();
    for victim in victims {
        k.sockets.remove(victim);
    }
}
/// Why `abort_with` tears a connection down; selects which TCB error flag
/// (`reset` or `timed_out`) gets recorded.
enum AbortReason {
    /// A RST arrived for the connection.
    Reset,
    /// The retransmission budget (`retx_max`) was exhausted.
    TimedOut,
}

/// Aborts `fd` because a RST arrived for it.
fn abort_connection(k: &mut Kernel, fd: Fd) {
    abort_with(k, fd, AbortReason::Reset)
}

/// Aborts `fd` because its retransmissions went unanswered.
fn abort_timed_out(k: &mut Kernel, fd: Fd) {
    abort_with(k, fd, AbortReason::TimedOut)
}
/// Maps a TCB's abort flags to the error that I/O calls report.
///
/// `ConnectionReset` wins over `TimedOut`; `None` when neither flag is set.
fn abort_error(tcb: &Tcb) -> Option<Error> {
    let kind = match (tcb.reset, tcb.timed_out) {
        (true, _) => ErrorKind::ConnectionReset,
        (false, true) => ErrorKind::TimedOut,
        (false, false) => return None,
    };
    Some(Error::from(kind))
}
/// Forces `fd` into `Closed` and records `reason` on its TCB (if any).
///
/// Both buffers are discarded, and every task parked on the socket (connect,
/// read, write) is woken so it observes the error.
fn abort_with(k: &mut Kernel, fd: Fd, reason: AbortReason) {
    let st = k.lookup_mut(fd).unwrap();
    if let Some(tcb) = &mut st.tcb {
        let flag = match reason {
            AbortReason::Reset => &mut tcb.reset,
            AbortReason::TimedOut => &mut tcb.timed_out,
        };
        *flag = true;
        tcb.state = TcpState::Closed;
        tcb.send_buf.clear();
        tcb.recv_buf.clear();
    }
    if let Some(waker) = st.connect_waker.take() {
        waker.wake();
    }
    st.wake_read();
    st.wake_write();
}
/// Finds a listening stream socket for `local`.
///
/// A listener bound to the exact address is preferred over one bound to the
/// unspecified (wildcard) address of the same family and port.
fn find_listener(k: &Kernel, local: SocketAddr) -> Option<Fd> {
    let (domain, any_ip) = match local {
        SocketAddr::V4(_) => (Domain::Inet, IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
        SocketAddr::V6(_) => (Domain::Inet6, IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
    };
    for ip in [local.ip(), any_ip] {
        let key = BindKey {
            domain,
            ty: Type::Stream,
            local_addr: ip,
            local_port: local.port(),
        };
        let found = k
            .sockets
            .find_by_bind(&key)
            .into_iter()
            .copied()
            .find(|&fd| k.sockets.get(fd).unwrap().listen.is_some());
        if found.is_some() {
            return found;
        }
    }
    None
}
/// Counts half-open (`SynReceived`) connections whose local endpoint is
/// `local`, ignoring the listener itself; used for the backlog check.
fn count_children(k: &Kernel, listener_fd: Fd, local: SocketAddr) -> usize {
    k.sockets
        .connections_on(local)
        .map(|(_, fd)| fd)
        .filter(|&fd| fd != listener_fd)
        .filter(|&fd| {
            k.sockets
                .get(fd)
                .and_then(|s| s.tcb.as_ref())
                .map_or(false, |t| t.state == TcpState::SynReceived)
        })
        .count()
}
/// Wraps `seg` in a TTL-64 IP packet from `src` to `dst` and appends it to the
/// kernel's outbound queue.
fn emit(k: &mut Kernel, src: SocketAddr, dst: SocketAddr, seg: TcpSegment) {
    let packet = Packet {
        src: src.ip(),
        dst: dst.ip(),
        ttl: 64,
        payload: Transport::Tcp(seg),
    };
    k.outbound.push_back(packet);
}
/// Chooses the source endpoint for traffic from `fd`.
///
/// A concrete bound address is used as is. For a wildcard bind with a TCB,
/// the address follows the TCB's peer: the family's loopback address for a
/// loopback peer; otherwise the first configured address of the peer's
/// family, falling back to the wildcard if none matches. Without a TCB the
/// wildcard is returned unchanged.
///
/// # Panics
/// If `fd` does not resolve or is not bound.
fn local_endpoint(k: &Kernel, fd: Fd) -> SocketAddr {
    let st = k.lookup(fd).unwrap();
    let bind = st.bound.as_ref().expect("fd bound by now");
    let port = bind.local_port;
    if !bind.local_addr.is_unspecified() {
        return SocketAddr::new(bind.local_addr, port);
    }
    let ip = match st.tcb.as_ref().map(|t| t.peer.ip()) {
        Some(IpAddr::V4(p)) if p.is_loopback() => IpAddr::V4(Ipv4Addr::LOCALHOST),
        Some(IpAddr::V6(p)) if p.is_loopback() => IpAddr::V6(Ipv6Addr::LOCALHOST),
        Some(p) => k
            .addresses
            .iter()
            .copied()
            .find(|a| a.is_ipv4() == p.is_ipv4())
            .unwrap_or(bind.local_addr),
        None => bind.local_addr,
    };
    SocketAddr::new(ip, port)
}
/// Binds a not-yet-bound stream socket to an ephemeral port before `connect`.
///
/// The local address is the family's loopback address when `dst` is loopback;
/// otherwise it is the first configured address of `dst`'s family.
///
/// # Errors
/// `AddrNotAvailable` if no configured address matches `dst`'s family;
/// `AddrInUse` if no ephemeral port could be allocated.
fn auto_bind(k: &mut Kernel, fd: Fd, domain: Domain, dst: IpAddr) -> Result<()> {
    let local_ip = match dst {
        IpAddr::V4(_) if dst.is_loopback() => IpAddr::V4(Ipv4Addr::LOCALHOST),
        IpAddr::V6(_) if dst.is_loopback() => IpAddr::V6(Ipv6Addr::LOCALHOST),
        _ => match k.addresses.iter().copied().find(|a| a.is_ipv4() == dst.is_ipv4()) {
            Some(addr) => addr,
            None => return Err(Error::from(ErrorKind::AddrNotAvailable)),
        },
    };
    let Some(local_port) = k.sockets.allocate_port(domain, Type::Stream) else {
        return Err(Error::from(ErrorKind::AddrInUse));
    };
    let key = BindKey {
        domain,
        ty: Type::Stream,
        local_addr: local_ip,
        local_port,
    };
    k.sockets.insert_binding(key.clone(), fd);
    k.sockets.get_mut(fd).expect("fd present").bound = Some(key);
    Ok(())
}
/// Returns the next initial sequence number. Successive calls hand out values
/// 0x1_0000 apart (wrapping at 2^32).
fn initial_sequence(k: &mut Kernel) -> u32 {
    let issued = k.tcp_isn;
    k.tcp_isn = issued.wrapping_add(0x1_0000);
    issued
}
/// Copies as much of `buf` as fits into the connection's send buffer.
///
/// The bytes are segmented and transmitted later by `segment_all`. Returns
/// `Ready(Ok(n))` with the number of bytes accepted, or `Pending` (after
/// registering the write waker) when the send buffer is full.
///
/// # Errors
/// * The lookup error for an invalid `fd`.
/// * `NotConnected` without a TCB or outside `Established`/`CloseWait`.
/// * `ConnectionReset`/`TimedOut` for an aborted connection.
/// * `BrokenPipe` once the write side was shut down.
pub(super) fn poll_send(
    k: &mut Kernel,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<Result<usize>> {
    let send_cap = k.send_buf_cap;
    let st = k.lookup_mut(fd)?;
    let space = {
        let Some(tcb) = st.tcb.as_ref() else {
            return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
        };
        if let Some(e) = abort_error(tcb) {
            return Poll::Ready(Err(e));
        }
        if tcb.wr_closed {
            return Poll::Ready(Err(Error::from(ErrorKind::BrokenPipe)));
        }
        if !matches!(tcb.state, TcpState::Established | TcpState::CloseWait) {
            return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
        }
        send_cap.saturating_sub(tcb.send_buf.len())
    };
    if space == 0 {
        st.register_write_waker(cx.waker());
        return Poll::Pending;
    }
    let accepted = &buf[..buf.len().min(space)];
    st.tcb
        .as_mut()
        .unwrap()
        .send_buf
        .extend_from_slice(accepted);
    Poll::Ready(Ok(accepted.len()))
}
/// Half-closes the connection's write side (`shutdown(Write)`).
///
/// It records the FIN's sequence number (just past the last byte already
/// queued in `send_buf`), marks the write side closed, and moves
/// `Established`→`FinWait1` or `CloseWait`→`LastAck`. The FIN itself is sent
/// later by `segment_all`. Calling it again after success is a no-op.
///
/// # Errors
/// The lookup error for an invalid `fd`; `NotConnected` without a TCB;
/// `ConnectionReset`/`TimedOut` for an aborted connection.
pub(super) fn poll_shutdown_write(
    k: &mut Kernel,
    fd: Fd,
    _cx: &mut Context<'_>,
) -> Poll<Result<()>> {
    let tcb = match k.lookup_mut(fd)?.tcb.as_mut() {
        Some(tcb) => tcb,
        None => return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))),
    };
    if let Some(e) = abort_error(tcb) {
        return Poll::Ready(Err(e));
    }
    if !tcb.wr_closed {
        tcb.fin_seq = Some(tcb.snd_una.wrapping_add(tcb.send_buf.len() as u32));
        tcb.wr_closed = true;
        tcb.state = match tcb.state {
            TcpState::Established => TcpState::FinWait1,
            TcpState::CloseWait => TcpState::LastAck,
            other => other,
        };
    }
    Poll::Ready(Ok(()))
}
/// Reads buffered in-order bytes from `fd` into `buf`.
///
/// Returns one of:
/// * `Ready(Ok(n))` with the bytes copied (`n == 0` only for an empty `buf`);
/// * `Ready(Ok(0))` at end of stream, once the peer's FIN was received and the
///   buffer is drained;
/// * `Pending` after registering the read waker when nothing is buffered yet.
///
/// When a read frees enough buffer space, a pure ACK advertising the larger
/// receive window is sent. That happens when either:
/// * the read drained at least half the buffer capacity, or
/// * it moved the free space from below half the capacity to at least half.
///
/// The second condition matters for readers that consume in small chunks.
/// A peer running this code that saw a zero window has nothing in flight.
/// Neither `segment_one` (no window) nor `check_retx` (nothing
/// unacknowledged) will make it transmit again, so only this update restarts
/// the flow.
///
/// # Errors
/// * The lookup error for an invalid `fd`.
/// * `NotConnected` without a TCB, or with nothing buffered outside
///   `Established`/`FinWait1`/`FinWait2`/`CloseWait`.
/// * `ConnectionReset`/`TimedOut` for an aborted connection.
pub(super) fn poll_recv(
    k: &mut Kernel,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut [u8],
) -> Poll<Result<usize>> {
    let recv_cap = k.recv_buf_cap;
    let half = recv_cap / 2;
    let (n, should_update_window, local, remote) = {
        let st = match k.lookup_mut(fd) {
            Ok(st) => st,
            Err(e) => return Poll::Ready(Err(e)),
        };
        let (empty, peer_fin, abort_err, readable_state, peer) = match st.tcb.as_ref() {
            None => return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))),
            Some(t) => (
                t.recv_buf.is_empty(),
                t.peer_fin,
                abort_error(t),
                matches!(
                    t.state,
                    TcpState::Established
                        | TcpState::FinWait1
                        | TcpState::FinWait2
                        | TcpState::CloseWait
                ),
                t.peer,
            ),
        };
        if let Some(e) = abort_err {
            return Poll::Ready(Err(e));
        }
        if empty {
            if peer_fin {
                return Poll::Ready(Ok(0));
            }
            if !readable_state {
                return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
            }
            st.register_read_waker(cx.waker());
            return Poll::Pending;
        }
        let local = bound_endpoint(st);
        let tcb = st.tcb.as_mut().unwrap();
        let buffered_before = tcb.recv_buf.len();
        let n = buffered_before.min(buf.len());
        let drained = tcb.recv_buf.split_to(n);
        buf[..n].copy_from_slice(&drained);
        // Free space only grows across reads, so a reader that keeps draining
        // crosses the half-capacity mark exactly once. That read announces the
        // reopened window even if every individual read is small.
        let free_before = recv_cap.saturating_sub(buffered_before);
        let free_after = recv_cap.saturating_sub(buffered_before - n);
        let reopened = free_before < half && free_after >= half;
        (n, n >= half || reopened, local, peer)
    };
    if should_update_window {
        let (snd_nxt, rcv_nxt, window) = {
            let tcb = k.lookup(fd).unwrap().tcb.as_ref().unwrap();
            (
                tcb.snd_nxt,
                tcb.rcv_nxt,
                advertised_window(recv_cap, tcb.recv_buf.len()),
            )
        };
        emit(
            k,
            local,
            remote,
            TcpSegment {
                src_port: local.port(),
                dst_port: remote.port(),
                seq: snd_nxt,
                ack: rcv_nxt,
                flags: TcpFlags {
                    ack: true,
                    ..TcpFlags::default()
                },
                window,
                payload: Bytes::new(),
            },
        );
    }
    Poll::Ready(Ok(n))
}
/// Like `poll_recv`, but copies buffered bytes into `buf` without consuming
/// them and never sends a window update.
///
/// # Errors
/// Same as `poll_recv`.
pub(super) fn poll_peek(
    k: &mut Kernel,
    fd: Fd,
    cx: &mut Context<'_>,
    buf: &mut [u8],
) -> Poll<Result<usize>> {
    let st = k.lookup_mut(fd)?;
    let tcb = match st.tcb.as_ref() {
        Some(tcb) => tcb,
        None => return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))),
    };
    if let Some(e) = abort_error(tcb) {
        return Poll::Ready(Err(e));
    }
    let buffered = tcb.recv_buf.len();
    if buffered > 0 {
        let n = buffered.min(buf.len());
        buf[..n].copy_from_slice(&tcb.recv_buf[..n]);
        return Poll::Ready(Ok(n));
    }
    if tcb.peer_fin {
        return Poll::Ready(Ok(0));
    }
    let readable = matches!(
        tcb.state,
        TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 | TcpState::CloseWait
    );
    if !readable {
        return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
    }
    st.register_read_waker(cx.waker());
    Poll::Pending
}
/// Retransmission timer tick.
///
/// Two kinds of connection get their `egress_since_ack` counter bumped:
/// those mid-handshake, and those in a sending state with unacknowledged
/// sequence space. When the counter reaches `retx_threshold`, one of two
/// things happens:
/// * After `retx_max` attempts the connection is aborted with `TimedOut`.
/// * Otherwise it retransmits. The SYN or SYN/ACK is re-emitted directly;
///   for data the send position is rewound to `snd_una`, so `segment_all`
///   resends everything unacknowledged (go-back-N).
pub(super) fn check_retx(k: &mut Kernel) {
    let threshold = k.retx_threshold;
    let max = k.retx_max;
    let timed: Vec<Fd> = k
        .sockets
        .iter()
        .filter_map(|(fd, st)| {
            let tcb = st.tcb.as_ref()?;
            let needs_timer = match tcb.state {
                TcpState::SynSent | TcpState::SynReceived => true,
                TcpState::Established
                | TcpState::CloseWait
                | TcpState::FinWait1
                | TcpState::Closing
                | TcpState::LastAck => tcb.snd_una != tcb.snd_nxt,
                TcpState::FinWait2 | TcpState::Closed => false,
            };
            needs_timer.then_some(fd)
        })
        .collect();

    let mut expired: Vec<Fd> = Vec::new();
    let mut handshakes: Vec<Fd> = Vec::new();
    for fd in timed {
        let tcb = k.sockets.get_mut(fd).unwrap().tcb.as_mut().unwrap();
        tcb.egress_since_ack += 1;
        if tcb.egress_since_ack >= threshold {
            if tcb.retx_attempts >= max {
                expired.push(fd);
            } else {
                tcb.retx_attempts += 1;
                tcb.egress_since_ack = 0;
                if matches!(tcb.state, TcpState::SynSent | TcpState::SynReceived) {
                    handshakes.push(fd);
                } else {
                    tcb.snd_nxt = tcb.snd_una;
                }
            }
        }
    }

    for fd in handshakes {
        emit_handshake(k, fd);
    }
    for fd in expired {
        abort_timed_out(k, fd);
    }
}
/// Re-sends the handshake segment for `fd`: a SYN in `SynSent`, a SYN/ACK in
/// `SynReceived`.
///
/// The SYN's sequence number is `snd_una - 1`, because `snd_una` starts at
/// ISN + 1.
///
/// # Panics
/// If `fd` is missing, has no TCB or binding, or is in any other state.
fn emit_handshake(k: &mut Kernel, fd: Fd) {
    let (local, remote, seq, ack) = {
        let st = k.lookup(fd).expect("retx candidate");
        let tcb = st.tcb.as_ref().expect("handshake state has tcb");
        let local = bound_endpoint(st);
        let remote = tcb.peer;
        let seq = tcb.snd_una.wrapping_sub(1);
        let ack = match tcb.state {
            TcpState::SynSent => None,
            TcpState::SynReceived => Some(tcb.rcv_nxt),
            _ => unreachable!("emit_handshake only called for SynSent/SynReceived"),
        };
        (local, remote, seq, ack)
    };
    let segment = TcpSegment {
        src_port: local.port(),
        dst_port: remote.port(),
        seq,
        ack: ack.unwrap_or(0),
        flags: TcpFlags {
            syn: true,
            ack: ack.is_some(),
            ..TcpFlags::default()
        },
        window: DEFAULT_WINDOW,
        payload: Bytes::new(),
    };
    emit(k, local, remote, segment);
}
/// Transmits pending data and FINs for every connection that can send.
///
/// A connection qualifies when it is in a sending state and either has bytes
/// in `send_buf` beyond those already in flight, or has reached the sequence
/// number reserved for its FIN. Window and MSS limits are applied by
/// `segment_one`.
pub(super) fn segment_all(k: &mut Kernel) {
    let ready: Vec<Fd> = k
        .sockets
        .iter()
        .filter_map(|(fd, s)| {
            let t = s.tcb.as_ref()?;
            let can_send = matches!(
                t.state,
                TcpState::Established
                    | TcpState::CloseWait
                    | TcpState::FinWait1
                    | TcpState::Closing
                    | TcpState::LastAck
            );
            let in_flight = t.snd_nxt.wrapping_sub(t.snd_una) as usize;
            let has_data = t.send_buf.len() > in_flight;
            let fin_pending = t.fin_seq == Some(t.snd_nxt);
            (can_send && (has_data || fin_pending)).then_some(fd)
        })
        .collect();
    for fd in ready {
        segment_one(k, fd);
    }
}
/// Emits segments for `fd` until nothing more may be sent.
///
/// Data from `send_buf` goes out in chunks limited by the MSS and by the
/// peer's remaining window. Once every queued byte is in flight, the FIN is
/// sent at its reserved sequence number (also only while window remains).
/// Every segment carries an ACK of `rcv_nxt` and our advertised receive window.
/// Data segments set PSH.
///
/// NOTE(review): relies on `mss_for` returning a non-zero MSS. With an MSS of
/// 0 each iteration emits an empty segment without advancing `snd_nxt`, so
/// the loop never ends.
fn segment_one(k: &mut Kernel, fd: Fd) {
    let local = bound_endpoint(k.lookup(fd).unwrap());
    let mss = mss_for(k, local.ip());
    let recv_cap = k.recv_buf_cap;
    loop {
        let tcb = k.lookup_mut(fd).unwrap().tcb.as_mut().unwrap();
        let in_flight = tcb.snd_nxt.wrapping_sub(tcb.snd_una) as usize;
        let unsent = tcb.send_buf.len().saturating_sub(in_flight);
        let window_left = (tcb.snd_wnd as usize).saturating_sub(in_flight);
        if window_left == 0 {
            return;
        }
        let seq = tcb.snd_nxt;
        let (payload, is_fin) = if unsent > 0 {
            let len = unsent.min(mss).min(window_left);
            let chunk = &tcb.send_buf[in_flight..in_flight + len];
            (Bytes::copy_from_slice(chunk), false)
        } else if tcb.fin_seq == Some(seq) {
            (Bytes::new(), true)
        } else {
            return;
        };
        // A FIN occupies one sequence number; data occupies its length.
        let advance = if is_fin { 1 } else { payload.len() as u32 };
        tcb.snd_nxt = seq.wrapping_add(advance);
        let remote = tcb.peer;
        let rcv_nxt = tcb.rcv_nxt;
        let window = advertised_window(recv_cap, tcb.recv_buf.len());
        let psh = !is_fin && !payload.is_empty();
        emit(
            k,
            local,
            remote,
            TcpSegment {
                src_port: local.port(),
                dst_port: remote.port(),
                seq,
                ack: rcv_nxt,
                flags: TcpFlags {
                    ack: true,
                    psh,
                    fin: is_fin,
                    ..TcpFlags::default()
                },
                window,
                payload,
            },
        );
    }
}
/// Largest TCP payload per segment sent from `src_ip`.
///
/// It is `loopback_mtu` for a loopback source (the kernel's `mtu` otherwise),
/// minus the IP header for the address family and the TCP header, floored at
/// zero.
fn mss_for(k: &Kernel, src_ip: IpAddr) -> usize {
    let mtu = if src_ip.is_loopback() {
        k.loopback_mtu
    } else {
        k.mtu
    };
    let ip_header = if src_ip.is_ipv4() {
        packet::IPV4_HEADER_SIZE as u32
    } else {
        packet::IPV6_HEADER_SIZE as u32
    };
    let overhead = ip_header + packet::TCP_HEADER_SIZE as u32;
    mtu.saturating_sub(overhead) as usize
}
fn bound_endpoint(st: &Socket) -> SocketAddr {
let bind = st.bound.as_ref().expect("bound at handshake time");
SocketAddr::new(bind.local_addr, bind.local_port)
}
/// Receive window to advertise: the free space in the receive buffer, clamped
/// to `u16::MAX`, the largest value a segment's 16-bit window field holds.
fn advertised_window(recv_buf_cap: usize, recv_buf_len: usize) -> u16 {
    let free = recv_buf_cap.saturating_sub(recv_buf_len);
    if free > usize::from(u16::MAX) {
        u16::MAX
    } else {
        free as u16
    }
}
/// Window advertised on handshake segments (SYN, SYN/ACK): the largest value
/// the 16-bit window field can hold.
const DEFAULT_WINDOW: u16 = u16::MAX;