#![cfg(all(feature = "std", feature = "async-io"))]
use core::future::poll_fn;
use core::pin::pin;
use core::task::Poll;
use alloc::vec::Vec;
use std::io::{Read, Write};
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
use async_io::Async;
use embassy_futures::select::{select, Either};
use embassy_time::{Duration, Timer};
use crate::error::{Error, ErrorCode};
use crate::transport::network::Address;
use crate::utils::cell::RefCell;
use crate::utils::sync::blocking::Mutex;
use crate::utils::sync::{IfMutex, Notification};
use super::{NetworkReceive, NetworkSend, MAX_RX_LARGE_PACKET_SIZE};
extern crate alloc;
/// Upper bound for one `send_to` call, covering both connecting and writing the frame.
const SEND_TIMEOUT: Duration = Duration::from_secs(5);
/// Default number of TCP connections a [`TcpNetwork`] keeps in its pool.
pub const DEFAULT_MAX_TCP_CONNECTIONS: usize = 8;
/// Size of the little-endian `u32` length prefix that precedes every frame.
const FRAME_HDR_LEN: usize = core::mem::size_of::<u32>();
/// Largest receive buffer a single connection may accumulate: one
/// maximum-size packet together with its length prefix.
const MAX_RX_BUF_SIZE: usize = FRAME_HDR_LEN + MAX_RX_LARGE_PACKET_SIZE;
/// TCP transport that carries each packet as one frame: a 4-byte little-endian
/// length prefix followed by the payload.
///
/// Inbound connections are accepted from `listener`; outbound connections are
/// opened on demand when sending. At most `N` connections are pooled; when the
/// pool is full the oldest connection is shut down and evicted.
pub struct TcpNetwork<const N: usize = DEFAULT_MAX_TCP_CONNECTIONS> {
    /// Listening socket that inbound connections are accepted from.
    listener: Async<TcpListener>,
    /// Connection pool together with the per-connection receive buffers.
    inner: Mutex<RefCell<TcpInner>>,
    /// Notified by `ensure_connected` so a pending `wait_available` rebuilds
    /// its set of polled connections.
    pool_changed: Notification,
    /// Held by `send_to` for the whole framed write so concurrent senders
    /// do not interleave their bytes.
    send_mutex: IfMutex<()>,
}
impl<const N: usize> TcpNetwork<N> {
pub fn new(listener: Async<TcpListener>) -> Self {
Self {
listener,
inner: Mutex::new(RefCell::new(TcpInner::new())),
pool_changed: Notification::new(),
send_mutex: IfMutex::new(()),
}
}
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.get_ref().local_addr()
}
async fn wait_available(&self) -> Result<(), Error> {
let ready = self
.inner
.lock(|inner| inner.borrow().find_ready_connection().is_some());
if ready {
return Ok(());
}
let ready = self
.inner
.lock(|inner| inner.borrow_mut().try_read_all().is_some());
if ready {
return Ok(());
}
loop {
let event = {
let io_poll = poll_fn(|cx| {
if self.listener.poll_readable(cx).is_ready() {
return Poll::Ready(IoEvent::ListenerReady);
}
self.inner.lock(|inner| {
let inner = inner.borrow();
for (i, conn) in inner.connections.iter().enumerate() {
match conn.stream.poll_readable(cx) {
Poll::Ready(Ok(())) => {
return Poll::Ready(IoEvent::ConnectionReadable(i));
}
Poll::Ready(Err(_)) => {
return Poll::Ready(IoEvent::ConnectionError(i));
}
Poll::Pending => {}
}
}
Poll::Pending
})
});
let mut io_poll = pin!(io_poll);
let mut pool_notif = pin!(self.pool_changed.wait());
match select(&mut io_poll, &mut pool_notif).await {
Either::First(event) => event,
Either::Second(()) => IoEvent::PoolChanged,
}
};
match event {
IoEvent::ListenerReady => {
match self.listener.accept().await {
Ok((stream, addr)) => {
self.inner.lock(|inner| {
inner.borrow_mut().add_connection::<N>(stream, addr);
});
}
Err(e) => {
return Err(e.into());
}
}
}
IoEvent::ConnectionReadable(idx) => {
self.inner.lock(|inner| {
let mut inner = inner.borrow_mut();
if idx < inner.connections.len()
&& inner.connections[idx].try_read_nonblocking().is_err()
{
inner.remove_connection(idx);
}
});
}
IoEvent::ConnectionError(idx) => {
self.inner.lock(|inner| {
let mut inner = inner.borrow_mut();
if idx < inner.connections.len() {
inner.remove_connection(idx);
}
});
}
IoEvent::PoolChanged => {
}
}
let ready = self
.inner
.lock(|inner| inner.borrow().find_ready_connection().is_some());
if ready {
return Ok(());
}
let ready = self
.inner
.lock(|inner| inner.borrow_mut().try_read_all().is_some());
if ready {
return Ok(());
}
}
}
async fn recv_from(&self, buffer: &mut [u8]) -> Result<(usize, Address), Error> {
loop {
self.wait_available().await?;
let result = self.inner.lock(|inner| {
let mut inner = inner.borrow_mut();
let Some(idx) = inner.find_ready_connection() else {
return Ok(None);
};
let conn = &mut inner.connections[idx];
let Some(msg_len) = conn.has_complete_message() else {
return Ok(None);
};
if msg_len > buffer.len() {
conn.rx_buf.drain(..FRAME_HDR_LEN + msg_len);
Err(ErrorCode::BufferTooSmall)?;
}
let payload = conn.extract_message(msg_len);
buffer[..msg_len].copy_from_slice(&payload);
let remote = conn.remote;
inner.poll_index = (idx + 1) % inner.connections.len().max(1);
Ok::<_, Error>(Some((msg_len, Address::Tcp(remote))))
})?;
if let Some((len, addr)) = result {
break Ok((len, addr));
}
}
}
async fn send_to(&self, data: &[u8], addr: Address) -> Result<(), Error> {
let sock_addr = addr.tcp().ok_or(ErrorCode::NoNetworkInterface)?;
let send_result = select(
pin!(async {
let conn_id = self.ensure_connected(sock_addr).await?;
let _send_guard = self.send_mutex.lock().await;
let result = self.send_framed_to(sock_addr, conn_id, data).await;
if result.is_err() {
self.inner.lock(|inner| {
let mut inner = inner.borrow_mut();
if let Some(idx) = inner.find_connection_exact(&sock_addr, conn_id) {
inner.remove_connection(idx);
}
});
let new_conn_id = self.ensure_connected(sock_addr).await?;
self.send_framed_to(sock_addr, new_conn_id, data).await?;
}
Ok::<(), Error>(())
}),
pin!(Timer::after(SEND_TIMEOUT)),
)
.await;
match send_result {
Either::First(result) => result,
Either::Second(()) => Err(ErrorCode::TxTimeout.into()),
}
}
async fn ensure_connected(&self, addr: SocketAddr) -> Result<u64, Error> {
let existing_id = self
.inner
.lock(|inner| inner.borrow().find_connection_id(&addr));
if let Some(id) = existing_id {
return Ok(id);
}
let stream = Async::<TcpStream>::connect(addr).await?;
let conn_id = self.inner.lock(|inner| {
let mut inner = inner.borrow_mut();
if let Some(id) = inner.find_connection_id(&addr) {
let _ = stream.get_ref().shutdown(Shutdown::Both);
return id;
}
inner.add_connection::<N>(stream, addr)
});
self.pool_changed.notify();
Ok(conn_id)
}
async fn send_framed_to(
&self,
addr: SocketAddr,
conn_id: u64,
data: &[u8],
) -> Result<(), Error> {
let len = data.len() as u32;
let mut frame = Vec::with_capacity(FRAME_HDR_LEN + data.len());
frame.extend_from_slice(&len.to_le_bytes());
frame.extend_from_slice(data);
self.write_all_to(addr, conn_id, &frame).await
}
async fn write_all_to(
&self,
addr: SocketAddr,
conn_id: u64,
mut data: &[u8],
) -> Result<(), Error> {
while !data.is_empty() {
let written: Result<usize, Error> = self.inner.lock(|inner| {
let inner = inner.borrow();
if let Some(idx) = inner.find_connection_exact(&addr, conn_id) {
match inner.connections[idx].stream.get_ref().write(data) {
Ok(0) => Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write to TCP stream",
)
.into()),
Ok(n) => Ok(n),
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
Err(e) => Err(e.into()),
}
} else {
Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"TCP connection removed from pool",
)
.into())
}
});
match written? {
0 => {
let mut registered = false;
poll_fn(|cx| {
if !registered {
registered = true;
self.inner.lock(|inner| {
let inner = inner.borrow();
if let Some(idx) = inner.find_connection_exact(&addr, conn_id) {
let _ = inner.connections[idx].stream.poll_writable(cx);
Poll::Pending
} else {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"TCP connection removed from pool",
)))
}
})
} else {
Poll::Ready(Ok(()))
}
})
.await?;
}
n => data = &data[n..],
}
}
Ok(())
}
}
impl<const N: usize> NetworkSend for TcpNetwork<N> {
    /// Forwards to the inherent `TcpNetwork::send_to`, which needs only a shared borrow.
    async fn send_to(&mut self, data: &[u8], addr: Address) -> Result<(), Error> {
        let shared: &Self = self;
        TcpNetwork::<N>::send_to(shared, data, addr).await
    }
}
impl<const N: usize> NetworkReceive for TcpNetwork<N> {
    /// Forwards to the inherent `TcpNetwork::wait_available`.
    async fn wait_available(&mut self) -> Result<(), Error> {
        let shared: &Self = self;
        TcpNetwork::<N>::wait_available(shared).await
    }

    /// Forwards to the inherent `TcpNetwork::recv_from`.
    async fn recv_from(&mut self, buffer: &mut [u8]) -> Result<(usize, Address), Error> {
        let shared: &Self = self;
        TcpNetwork::<N>::recv_from(shared, buffer).await
    }
}
impl<const N: usize> NetworkSend for &TcpNetwork<N> {
    /// Forwards to the inherent `TcpNetwork::send_to` on the borrowed network.
    async fn send_to(&mut self, data: &[u8], addr: Address) -> Result<(), Error> {
        let network: &TcpNetwork<N> = *self;
        TcpNetwork::<N>::send_to(network, data, addr).await
    }
}
impl<const N: usize> NetworkReceive for &TcpNetwork<N> {
    /// Forwards to the inherent `TcpNetwork::wait_available` on the borrowed network.
    async fn wait_available(&mut self) -> Result<(), Error> {
        let network: &TcpNetwork<N> = *self;
        TcpNetwork::<N>::wait_available(network).await
    }

    /// Forwards to the inherent `TcpNetwork::recv_from` on the borrowed network.
    async fn recv_from(&mut self, buffer: &mut [u8]) -> Result<(usize, Address), Error> {
        let network: &TcpNetwork<N> = *self;
        TcpNetwork::<N>::recv_from(network, buffer).await
    }
}
/// Mutable connection-pool state kept behind `TcpNetwork::inner`.
struct TcpInner {
    /// Open connections, oldest first (new ones are pushed to the back and
    /// eviction removes index 0).
    connections: Vec<TcpConnection>,
    /// Index at which `find_ready_connection` starts its round-robin scan.
    poll_index: usize,
    /// Counter used to hand out a distinct id to every added connection.
    next_conn_id: u64,
}
impl TcpInner {
    /// Creates an empty pool whose connection ids start at 0.
    const fn new() -> Self {
        Self {
            connections: Vec::new(),
            poll_index: 0,
            next_conn_id: 0,
        }
    }

    /// Returns the id of the first pooled connection whose peer is `addr`.
    fn find_connection_id(&self, addr: &SocketAddr) -> Option<u64> {
        self.connections
            .iter()
            .find(|c| c.remote == *addr)
            .map(|c| c.conn_id)
    }

    /// Returns the index of the connection matching both `addr` and `conn_id`.
    ///
    /// Matching on the id too keeps callers from acting on a newer connection
    /// to the same peer that replaced the one they looked up earlier.
    fn find_connection_exact(&self, addr: &SocketAddr, conn_id: u64) -> Option<usize> {
        self.connections
            .iter()
            .position(|c| c.remote == *addr && c.conn_id == conn_id)
    }

    /// Adds `stream` (peer `addr`) to the pool and returns its new id.
    ///
    /// `TCP_NODELAY` is enabled on a best-effort basis. When the pool already
    /// holds `N` connections, the oldest one is shut down and evicted first.
    fn add_connection<const N: usize>(
        &mut self,
        stream: Async<TcpStream>,
        addr: SocketAddr,
    ) -> u64 {
        stream.get_ref().set_nodelay(true).ok();
        if self.connections.len() >= N {
            // `remove_connection` ignores out-of-range indices, so an empty pool
            // with `N == 0` cannot panic here (`Vec::remove(0)` would); such a
            // pool simply keeps at most one connection.
            self.remove_connection(0);
        }
        let conn_id = self.next_conn_id;
        self.next_conn_id += 1;
        self.connections
            .push(TcpConnection::new(conn_id, addr, stream));
        conn_id
    }

    /// Shuts down and removes the connection at `index`; out-of-range indices
    /// are ignored.
    fn remove_connection(&mut self, index: usize) {
        if index < self.connections.len() {
            let evicted = self.connections.remove(index);
            // Best effort: the socket is closed when `evicted` is dropped anyway.
            let _ = evicted.stream.get_ref().shutdown(Shutdown::Both);
        }
    }

    /// Returns the index of a connection holding a complete frame, scanning
    /// round-robin starting at `poll_index`.
    fn find_ready_connection(&self) -> Option<usize> {
        let len = self.connections.len();
        // An empty range also covers `len == 0`, so the modulo never divides by zero.
        (0..len)
            .map(|offset| (self.poll_index + offset) % len)
            .find(|&idx| self.connections[idx].has_complete_message().is_some())
    }

    /// Performs one non-blocking read on each connection in order, removing
    /// any connection whose read fails.
    ///
    /// Returns the index of the first connection that holds a complete frame
    /// after its read; connections after it are not read in this call.
    fn try_read_all(&mut self) -> Option<usize> {
        let mut i = 0;
        while i < self.connections.len() {
            if self.connections[i].try_read_nonblocking().is_err() {
                // Removal shifts the next connection into slot `i`.
                self.remove_connection(i);
                continue;
            }
            if self.connections[i].has_complete_message().is_some() {
                return Some(i);
            }
            i += 1;
        }
        None
    }
}
/// One pooled TCP stream plus the received bytes not yet returned as frames.
struct TcpConnection {
    /// Pool-assigned id; distinguishes a reconnected stream to the same `remote`.
    conn_id: u64,
    /// Peer address the connection is looked up by.
    remote: SocketAddr,
    /// Stream wrapped in `async_io::Async` so it can be polled for readiness.
    stream: Async<TcpStream>,
    /// Received bytes; whole frames are drained from the front.
    rx_buf: Vec<u8>,
}
impl TcpConnection {
const fn new(conn_id: u64, remote: SocketAddr, stream: Async<TcpStream>) -> Self {
Self {
conn_id,
remote,
stream,
rx_buf: Vec::new(),
}
}
fn has_complete_message(&self) -> Option<usize> {
if self.rx_buf.len() < FRAME_HDR_LEN {
return None;
}
let hdr: [u8; 4] = self.rx_buf[..4].try_into().unwrap();
let msg_len = u32::from_le_bytes(hdr) as usize;
if self.rx_buf.len() >= FRAME_HDR_LEN + msg_len {
Some(msg_len)
} else {
None
}
}
fn extract_message(&mut self, msg_len: usize) -> Vec<u8> {
let payload = self.rx_buf[FRAME_HDR_LEN..FRAME_HDR_LEN + msg_len].to_vec();
self.rx_buf.drain(..FRAME_HDR_LEN + msg_len);
payload
}
fn try_read_nonblocking(&mut self) -> std::io::Result<usize> {
if self.rx_buf.len() >= MAX_RX_BUF_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::OutOfMemory,
"TCP receive buffer exceeded maximum size",
));
}
let mut tmp = [0u8; 4096];
match self.stream.get_ref().read(&mut tmp) {
Ok(0) => Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"TCP connection closed by peer",
)),
Ok(n) => {
self.rx_buf.extend_from_slice(&tmp[..n]);
if self.rx_buf.len() > MAX_RX_BUF_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::OutOfMemory,
"TCP receive buffer exceeded maximum size",
));
}
Ok(n)
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
Err(e) => Err(e),
}
}
}
/// Event that wakes the polling loop in `TcpNetwork::wait_available`.
///
/// Connection indices refer to `TcpInner::connections` at the time of the poll.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IoEvent {
    /// The listener reported readiness; a connection can be accepted.
    ListenerReady,
    /// The connection at this index reported read readiness.
    ConnectionReadable(usize),
    /// Polling the connection at this index returned an I/O error.
    ConnectionError(usize),
    /// `pool_changed` was notified, so the set of polled connections must be rebuilt.
    PoolChanged,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    use async_io::Async;
    use futures_lite::future::block_on;

    /// Binds an `async-io` listener on an OS-assigned loopback port.
    fn ephemeral_listener() -> Async<std::net::TcpListener> {
        Async::<std::net::TcpListener>::bind(SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::LOCALHOST,
            0,
        )))
        .unwrap()
    }

    /// Returns the address `listener` is bound to.
    fn local_addr(listener: &Async<std::net::TcpListener>) -> SocketAddr {
        listener.get_ref().local_addr().unwrap()
    }

    /// Writes `payload` using the transport's framing: a 4-byte little-endian
    /// length prefix followed by the payload bytes.
    fn send_framed_raw(stream: &mut std::net::TcpStream, payload: &[u8]) {
        let len = payload.len() as u32;
        stream.write_all(&len.to_le_bytes()).unwrap();
        stream.write_all(payload).unwrap();
        stream.flush().unwrap();
    }

    #[test]
    fn has_complete_message_empty_buffer() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        assert!(conn.has_complete_message().is_none());
    }

    #[test]
    fn has_complete_message_partial_header() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.push(0x05);
        assert!(conn.has_complete_message().is_none());
        conn.rx_buf.extend_from_slice(&[0x00, 0x00]);
        assert!(conn.has_complete_message().is_none());
    }

    #[test]
    fn has_complete_message_header_only_no_payload() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.extend_from_slice(&5u32.to_le_bytes());
        assert!(conn.has_complete_message().is_none());
    }

    #[test]
    fn has_complete_message_partial_payload() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.extend_from_slice(&5u32.to_le_bytes());
        conn.rx_buf.extend_from_slice(&[1, 2, 3]);
        assert!(conn.has_complete_message().is_none());
    }

    #[test]
    fn has_complete_message_exact_payload() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.extend_from_slice(&5u32.to_le_bytes());
        conn.rx_buf.extend_from_slice(&[1, 2, 3, 4, 5]);
        assert_eq!(conn.has_complete_message(), Some(5));
    }

    #[test]
    fn has_complete_message_zero_length() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.extend_from_slice(&0u32.to_le_bytes());
        assert_eq!(conn.has_complete_message(), Some(0));
    }

    #[test]
    fn extract_message_removes_frame() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.extend_from_slice(&3u32.to_le_bytes());
        conn.rx_buf.extend_from_slice(b"abc");
        conn.rx_buf.extend_from_slice(&2u32.to_le_bytes());
        conn.rx_buf.extend_from_slice(b"xy");
        assert_eq!(conn.has_complete_message(), Some(3));
        let msg = conn.extract_message(3);
        assert_eq!(msg, b"abc");
        assert_eq!(conn.has_complete_message(), Some(2));
        let msg = conn.extract_message(2);
        assert_eq!(msg, b"xy");
        assert!(conn.rx_buf.is_empty());
        assert!(conn.has_complete_message().is_none());
    }

    /// Adds a real loopback connection to `inner`, recorded under the
    /// (possibly fake) peer address `addr`. The listener is returned so the
    /// caller can keep it alive for the duration of the test.
    fn make_inner_connection(
        inner: &mut TcpInner,
        addr: SocketAddr,
    ) -> (u64, std::net::TcpListener) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let stream = Async::new(raw).unwrap();
        let id = inner.add_connection::<8>(stream, addr);
        (id, listener)
    }

    #[test]
    fn add_connection_assigns_unique_ids() {
        let mut inner = TcpInner::new();
        let (id1, _l1) = make_inner_connection(&mut inner, "1.2.3.4:100".parse().unwrap());
        let (id2, _l2) = make_inner_connection(&mut inner, "1.2.3.4:101".parse().unwrap());
        let (id3, _l3) = make_inner_connection(&mut inner, "1.2.3.4:102".parse().unwrap());
        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_eq!(inner.connections.len(), 3);
    }

    #[test]
    fn add_connection_evicts_oldest_at_capacity() {
        let mut inner = TcpInner::new();
        let mut _listeners = Vec::new();
        let listener1 = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let l1addr = listener1.local_addr().unwrap();
        let raw1 = std::net::TcpStream::connect(l1addr).unwrap();
        raw1.set_nonblocking(true).unwrap();
        let id1 =
            inner.add_connection::<2>(Async::new(raw1).unwrap(), "1.2.3.4:100".parse().unwrap());
        _listeners.push(listener1);
        let listener2 = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let l2addr = listener2.local_addr().unwrap();
        let raw2 = std::net::TcpStream::connect(l2addr).unwrap();
        raw2.set_nonblocking(true).unwrap();
        let _id2 =
            inner.add_connection::<2>(Async::new(raw2).unwrap(), "1.2.3.4:101".parse().unwrap());
        _listeners.push(listener2);
        assert_eq!(inner.connections.len(), 2);
        let listener3 = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let l3addr = listener3.local_addr().unwrap();
        let raw3 = std::net::TcpStream::connect(l3addr).unwrap();
        raw3.set_nonblocking(true).unwrap();
        let _id3 =
            inner.add_connection::<2>(Async::new(raw3).unwrap(), "1.2.3.4:102".parse().unwrap());
        _listeners.push(listener3);
        assert_eq!(inner.connections.len(), 2);
        assert!(inner
            .find_connection_exact(&"1.2.3.4:100".parse().unwrap(), id1)
            .is_none());
    }

    #[test]
    fn find_connection_exact_matches_both_addr_and_id() {
        let mut inner = TcpInner::new();
        let addr: SocketAddr = "1.2.3.4:100".parse().unwrap();
        let (id1, _l1) = make_inner_connection(&mut inner, addr);
        assert!(inner.find_connection_exact(&addr, id1).is_some());
        assert!(inner.find_connection_exact(&addr, id1 + 999).is_none());
        assert!(inner
            .find_connection_exact(&"5.6.7.8:200".parse().unwrap(), id1)
            .is_none());
    }

    #[test]
    fn remove_connection_drops_partial_buffer() {
        let mut inner = TcpInner::new();
        let addr: SocketAddr = "1.2.3.4:100".parse().unwrap();
        let (id, _l) = make_inner_connection(&mut inner, addr);
        let idx = inner.find_connection_exact(&addr, id).unwrap();
        inner.connections[idx]
            .rx_buf
            .extend_from_slice(&[0x05, 0x00, 0x01]);
        inner.remove_connection(idx);
        assert!(inner.connections.is_empty());
    }

    #[test]
    fn find_ready_connection_round_robin() {
        let mut inner = TcpInner::new();
        let addr_a: SocketAddr = "1.2.3.4:100".parse().unwrap();
        let addr_b: SocketAddr = "1.2.3.4:101".parse().unwrap();
        let (_id_a, _la) = make_inner_connection(&mut inner, addr_a);
        let (_id_b, _lb) = make_inner_connection(&mut inner, addr_b);
        inner.connections[0]
            .rx_buf
            .extend_from_slice(&3u32.to_le_bytes());
        inner.connections[0].rx_buf.extend_from_slice(b"aaa");
        inner.connections[1]
            .rx_buf
            .extend_from_slice(&3u32.to_le_bytes());
        inner.connections[1].rx_buf.extend_from_slice(b"bbb");
        inner.poll_index = 0;
        assert_eq!(inner.find_ready_connection(), Some(0));
        inner.poll_index = 1;
        assert_eq!(inner.find_ready_connection(), Some(1));
    }

    #[test]
    fn send_and_receive_single_message() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"hello matter");
                stream
            });
            let mut buf = [0u8; 256];
            let net: &TcpNetwork<8> = &tcp;
            let (len, from) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"hello matter");
            assert!(from.is_tcp());
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn send_and_receive_multiple_messages_same_connection() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"msg1");
                send_framed_raw(&mut stream, b"msg2");
                send_framed_raw(&mut stream, b"msg3");
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 256];
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"msg1");
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"msg2");
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"msg3");
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn send_to_creates_connection_and_delivers() {
        block_on(async {
            let server_listener = ephemeral_listener();
            let server_addr = local_addr(&server_listener);
            let our_listener = ephemeral_listener();
            let tcp = TcpNetwork::<8>::new(our_listener);
            let net: &TcpNetwork<8> = &tcp;
            NetworkSend::send_to(&mut { net }, b"outgoing", Address::Tcp(server_addr))
                .await
                .unwrap();
            let (stream, _) = server_listener.accept().await.unwrap();
            let raw = stream.into_inner().unwrap();
            raw.set_nonblocking(false).unwrap();
            let mut reader = std::io::BufReader::new(raw);
            let mut hdr = [0u8; 4];
            std::io::Read::read_exact(&mut reader, &mut hdr).unwrap();
            let msg_len = u32::from_le_bytes(hdr) as usize;
            let mut payload = vec![0u8; msg_len];
            std::io::Read::read_exact(&mut reader, &mut payload).unwrap();
            assert_eq!(payload, b"outgoing");
        });
    }

    #[test]
    fn recv_from_returns_tcp_address() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"x");
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 256];
            let (_, from) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            match from {
                Address::Tcp(sa) => {
                    assert!(sa.ip().is_loopback());
                }
                other => panic!("Expected Address::Tcp, got {:?}", other),
            }
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn multiple_clients_multiplexed() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let c1 = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"client1");
                stream
            });
            let c2 = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"client2");
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 256];
            let mut messages = Vec::new();
            for _ in 0..2 {
                let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                    .await
                    .unwrap();
                messages.push(buf[..len].to_vec());
            }
            messages.sort();
            assert_eq!(messages, vec![b"client1".to_vec(), b"client2".to_vec()]);
            let _s1 = c1.join().unwrap();
            let _s2 = c2.join().unwrap();
        });
    }

    #[test]
    fn message_too_large_for_buffer_returns_no_space() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                let payload = vec![0xAB; 100];
                send_framed_raw(&mut stream, &payload);
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 10];
            let result = NetworkReceive::recv_from(&mut { net }, &mut buf).await;
            assert!(result.is_err());
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn conn_pool_eviction_at_capacity() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<2>::new(listener);
            let mut streams = Vec::new();
            for _ in 0..3 {
                let c = std::thread::spawn(move || {
                    let mut stream = std::net::TcpStream::connect(addr).unwrap();
                    send_framed_raw(&mut stream, b"hi");
                    stream
                });
                streams.push(c);
            }
            let net: &TcpNetwork<2> = &tcp;
            let mut buf = [0u8; 256];
            let mut received = 0;
            for _ in 0..3 {
                // Eviction may discard a connection before its frame is read,
                // so bound each attempt instead of awaiting a frame that may
                // never arrive (which would hang the test).
                let got_message = futures_lite::future::or(
                    async {
                        NetworkReceive::wait_available(&mut { net }).await.is_ok()
                            && NetworkReceive::recv_from(&mut { net }, &mut buf)
                                .await
                                .is_ok()
                    },
                    async {
                        async_io::Timer::after(std::time::Duration::from_secs(2)).await;
                        false
                    },
                )
                .await;
                if got_message {
                    received += 1;
                }
            }
            assert!(
                received >= 2,
                "Expected at least 2 messages, got {received}"
            );
            for c in streams {
                let _s = c.join().unwrap();
            }
        });
    }

    #[test]
    fn framing_4_byte_le_prefix() {
        block_on(async {
            let server_listener = ephemeral_listener();
            let server_addr = local_addr(&server_listener);
            let our_listener = ephemeral_listener();
            let tcp = TcpNetwork::<8>::new(our_listener);
            let net: &TcpNetwork<8> = &tcp;
            let payload = vec![0x42; 300];
            NetworkSend::send_to(&mut { net }, &payload, Address::Tcp(server_addr))
                .await
                .unwrap();
            let (stream, _) = server_listener.accept().await.unwrap();
            let raw = stream.into_inner().unwrap();
            raw.set_nonblocking(false).unwrap();
            let mut hdr = [0u8; 4];
            std::io::Read::read_exact(&mut &raw, &mut hdr).unwrap();
            let wire_len = u32::from_le_bytes(hdr);
            assert_eq!(wire_len, 300);
            let mut body = vec![0u8; 300];
            std::io::Read::read_exact(&mut &raw, &mut body).unwrap();
            assert_eq!(body, payload);
        });
    }

    #[test]
    fn send_to_wrong_address_type_fails() {
        block_on(async {
            let listener = ephemeral_listener();
            let tcp = TcpNetwork::<8>::new(listener);
            let net: &TcpNetwork<8> = &tcp;
            let result = NetworkSend::send_to(
                &mut { net },
                b"data",
                Address::Udp(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9999))),
            )
            .await;
            assert!(result.is_err());
        });
    }

    #[test]
    fn back_to_back_frames_in_single_tcp_write() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                let mut wire = Vec::new();
                wire.extend_from_slice(&4u32.to_le_bytes());
                wire.extend_from_slice(b"aaaa");
                wire.extend_from_slice(&3u32.to_le_bytes());
                wire.extend_from_slice(b"bbb");
                stream.write_all(&wire).unwrap();
                stream.flush().unwrap();
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 256];
            let (len1, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len1], b"aaaa");
            let (len2, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len2], b"bbb");
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn zero_length_message() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let client = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                send_framed_raw(&mut stream, b"");
                send_framed_raw(&mut stream, b"after-empty");
                stream
            });
            let net: &TcpNetwork<8> = &tcp;
            let mut buf = [0u8; 256];
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(len, 0);
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"after-empty");
            let _stream = client.join().unwrap();
        });
    }

    #[test]
    fn connection_closed_by_peer_is_handled() {
        block_on(async {
            let listener = ephemeral_listener();
            let addr = local_addr(&listener);
            let tcp = TcpNetwork::<8>::new(listener);
            let net: &TcpNetwork<8> = &tcp;
            let client = std::thread::spawn(move || {
                let stream = std::net::TcpStream::connect(addr).unwrap();
                std::thread::sleep(std::time::Duration::from_millis(50));
                drop(stream);
            });
            client.join().unwrap();
            let client2 = std::thread::spawn(move || {
                let mut stream = std::net::TcpStream::connect(addr).unwrap();
                std::thread::sleep(std::time::Duration::from_millis(50));
                send_framed_raw(&mut stream, b"alive");
                stream
            });
            let mut buf = [0u8; 256];
            let (len, _) = NetworkReceive::recv_from(&mut { net }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"alive");
            let _stream = client2.join().unwrap();
        });
    }

    #[test]
    fn ensure_connected_reuses_existing_connection() {
        block_on(async {
            let server_listener = ephemeral_listener();
            let server_addr = local_addr(&server_listener);
            let our_listener = ephemeral_listener();
            let tcp = TcpNetwork::<8>::new(our_listener);
            let id1 = tcp.ensure_connected(server_addr).await.unwrap();
            let id2 = tcp.ensure_connected(server_addr).await.unwrap();
            assert_eq!(id1, id2);
            tcp.inner.lock(|inner| {
                assert_eq!(inner.borrow().connections.len(), 1);
            });
        });
    }

    #[test]
    fn send_framed_encodes_correctly() {
        block_on(async {
            let server_listener = ephemeral_listener();
            let server_addr = local_addr(&server_listener);
            let our_listener = ephemeral_listener();
            let tcp = TcpNetwork::<8>::new(our_listener);
            let conn_id = tcp.ensure_connected(server_addr).await.unwrap();
            tcp.send_framed_to(server_addr, conn_id, b"test")
                .await
                .unwrap();
            let (stream, _) = server_listener.accept().await.unwrap();
            let raw = stream.into_inner().unwrap();
            raw.set_nonblocking(false).unwrap();
            let mut hdr = [0u8; 4];
            std::io::Read::read_exact(&mut &raw, &mut hdr).unwrap();
            assert_eq!(u32::from_le_bytes(hdr), 4);
            let mut body = [0u8; 4];
            std::io::Read::read_exact(&mut &raw, &mut body).unwrap();
            assert_eq!(&body, b"test");
        });
    }

    #[test]
    fn bidirectional_communication() {
        block_on(async {
            let listener_a = ephemeral_listener();
            let addr_a = local_addr(&listener_a);
            let tcp_a = TcpNetwork::<8>::new(listener_a);
            let listener_b = ephemeral_listener();
            let addr_b = local_addr(&listener_b);
            let tcp_b = TcpNetwork::<8>::new(listener_b);
            let net_a: &TcpNetwork<8> = &tcp_a;
            let net_b: &TcpNetwork<8> = &tcp_b;
            NetworkSend::send_to(&mut { net_a }, b"a-to-b", Address::Tcp(addr_b))
                .await
                .unwrap();
            let mut buf = [0u8; 256];
            let (len, _) = NetworkReceive::recv_from(&mut { net_b }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"a-to-b");
            NetworkSend::send_to(&mut { net_b }, b"b-to-a", Address::Tcp(addr_a))
                .await
                .unwrap();
            let (len, _) = NetworkReceive::recv_from(&mut { net_a }, &mut buf)
                .await
                .unwrap();
            assert_eq!(&buf[..len], b"b-to-a");
        });
    }

    #[test]
    fn try_read_all_removes_dead_connections() {
        let mut inner = TcpInner::new();
        let addr: SocketAddr = "1.2.3.4:100".parse().unwrap();
        let peer_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(peer_addr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let (accepted, _) = peer_listener.accept().unwrap();
        accepted.shutdown(std::net::Shutdown::Both).unwrap();
        drop(accepted);
        drop(peer_listener);
        std::thread::sleep(std::time::Duration::from_millis(50));
        inner.add_connection::<8>(Async::new(raw).unwrap(), addr);
        assert_eq!(inner.connections.len(), 1);
        let _ = inner.try_read_all();
        assert_eq!(inner.connections.len(), 0);
    }

    #[test]
    fn try_read_nonblocking_rejects_oversized_buffer() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let laddr = listener.local_addr().unwrap();
        let raw = std::net::TcpStream::connect(laddr).unwrap();
        raw.set_nonblocking(true).unwrap();
        let mut conn = TcpConnection::new(0, laddr, Async::new(raw).unwrap());
        conn.rx_buf.resize(MAX_RX_BUF_SIZE, 0);
        let result = conn.try_read_nonblocking();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::OutOfMemory);
    }
}