use core::fmt::Display;
use core::fmt::Formatter;
use core::net::SocketAddr;
use embedded_io_async::ErrorKind;
/// A TCP endpoint that can either open an outgoing connection or accept an
/// incoming one.
///
/// Both operations borrow the socket mutably and yield a value implementing
/// [`TcpConnection`] on success, or an [`Error`] on failure.
// `async fn` is used directly in this public trait, so the corresponding lint
// is expected and silenced here.
#[expect(async_fn_in_trait)]
pub trait TcpSocket {
    /// Opens a connection to `address`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the connection could not be established. The
    /// test suite bundled in this file expects [`Error::ConnectionReset`] when
    /// connecting to a port it assumes nobody is listening on.
    async fn connect(&mut self, address: SocketAddr) -> Result<impl TcpConnection, Error>;

    /// Accepts a connection on `address`, returning the connection together
    /// with a [`SocketAddr`].
    ///
    /// The bundled test suite treats the returned address as the remote peer's
    /// address, and also passes an all-zero IP as `address` while the client
    /// connects to a concrete IP.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if no connection could be accepted. The bundled
    /// test suite expects [`Error::InvalidPort`] when `address` has port `0`.
    async fn accept(
        &mut self,
        address: SocketAddr,
    ) -> Result<(impl TcpConnection, SocketAddr), Error>;
}
/// An established TCP connection.
///
/// Implementors must be debuggable and provide asynchronous reading and
/// writing through `embedded_io_async`, with [`Error`] as the I/O error type.
// `async fn` is used directly in this public trait, so the corresponding lint
// is expected and silenced here.
#[expect(async_fn_in_trait)]
pub trait TcpConnection:
    core::fmt::Debug
    + embedded_io_async::Read
    + embedded_io_async::Write
    + embedded_io_async::ErrorType<Error = Error>
{
    /// Closes the connection, consuming it so it cannot be used afterwards.
    ///
    /// The test suite bundled in this file expects the peer's next `read` to
    /// return `0` bytes once this side has closed.
    async fn close(self);
}
/// Errors reported by [`TcpSocket`] and [`TcpConnection`] operations.
///
/// The variant order is significant: the derived ordering follows it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Error {
    /// The connection was reset.
    ConnectionReset,
    /// The socket is in an invalid state for the requested operation.
    InvalidState,
    /// The provided port is invalid.
    InvalidPort,
    /// The provided address is invalid.
    InvalidAddress,
    /// The connection timed out.
    TimedOut,
    /// There is no route to the host.
    NoRoute,
    /// Access to the resource was not permitted.
    PermissionDenied,
    /// The network stack is down.
    NetworkDown,
    /// An unspecified error.
    Other,
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
match self {
Error::ConnectionReset => {
write!(f, "The connection was reset by timeout of RST packet.")
}
Error::InvalidState => {
write!(f, "The socket is in an invalid state.")
}
Error::InvalidPort => {
write!(f, "The provided port is invalid.")
}
Error::TimedOut => {
write!(f, "The connection timed out.")
}
Error::NoRoute => {
write!(f, "No route to host.")
}
Error::Other => {
write!(
f,
"Unspecified error, please open a bug report if you encounter this error."
)
}
Error::InvalidAddress => {
write!(f, "The provided address is invalid.")
}
Error::PermissionDenied => {
write!(f, "No permission to access the resource.")
}
Error::NetworkDown => {
write!(f, "The network stack is down.")
}
}
}
}
// Marks `Error` as a standard error type; every method keeps its default
// implementation.
impl core::error::Error for Error {}
// `Error` names itself as its associated `embedded_io_async` error type.
impl embedded_io_async::ErrorType for Error {
    type Error = Self;
}
// Translation of `Error` into the generic `embedded_io_async` error kinds.
impl embedded_io_async::Error for Error {
    /// Returns the [`ErrorKind`] that corresponds to this error.
    ///
    /// Variants sharing a target kind are grouped into one arm; the match is
    /// kept exhaustive so a new variant must be mapped explicitly.
    fn kind(&self) -> ErrorKind {
        match *self {
            Self::ConnectionReset => ErrorKind::ConnectionReset,
            Self::InvalidState | Self::InvalidPort | Self::InvalidAddress => {
                ErrorKind::InvalidInput
            }
            Self::TimedOut => ErrorKind::TimedOut,
            Self::PermissionDenied => ErrorKind::PermissionDenied,
            Self::NetworkDown => ErrorKind::NotConnected,
            Self::NoRoute | Self::Other => ErrorKind::Other,
        }
    }
}
#[doc(hidden)]
#[cfg(feature = "test-suites")]
#[cfg_attr(coverage_nightly, coverage(off))]
pub mod test_suite {
#![expect(missing_docs, reason = "tests")]
use crate::net::tcp::{Error, TcpConnection, TcpSocket};
use embedded_io_async::{Read, Write};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
// Checks that a client can connect to an accepting server on port 59001 and
// that the server is handed a plausible remote address.
//
// `ip_address` is parsed as an IP address; the function panics if parsing
// fails or any expectation is not met. The client retries `connect` until it
// succeeds.
pub async fn test_connect(
    mut client: impl TcpSocket,
    mut server: impl TcpSocket,
    ip_address: &str,
) {
    let ip_address = ip_address.parse().unwrap();
    let listen_on = SocketAddr::new(ip_address, 59001);

    let accept_side = async {
        let accepted = server.accept(listen_on).await;
        let (connection, remote_addr) = accepted.unwrap();
        // The peer is either reported via loopback or via the tested address.
        assert!(remote_addr.ip().is_loopback() || remote_addr.ip() == ip_address);
        // The peer's port must differ from the listening port and be non-zero.
        assert_ne!(remote_addr.port(), listen_on.port());
        assert_ne!(remote_addr.port(), 0);
        connection.close().await;
    };

    let connect_side = async {
        // Keep trying until a connect attempt succeeds.
        let connection = loop {
            match client.connect(listen_on).await {
                Ok(established) => break established,
                Err(_) => continue,
            }
        };
        connection.close().await;
    };

    futures::join!(accept_side, connect_side);
}
// Exercises a two-round request/response exchange on port 59003.
//
// Round one: the client sends a message and the server echoes it back.
// Round two: the client sends a second message and the server acknowledges it.
// Panics if `ip_address` does not parse or any step fails.
pub async fn test_send_recv(
    mut client: impl TcpSocket,
    mut server: impl TcpSocket,
    ip_address: &str,
) {
    const FIRST_MESSAGE: &[u8] = b"Test message from client";
    const SECOND_MESSAGE: &[u8] = b"Second message";
    const ACKNOWLEDGEMENT: &[u8] = b"Acknowledged";

    let endpoint = SocketAddr::new(ip_address.parse().unwrap(), 59003);

    let server_side = async {
        let (mut stream, _) = server.accept(endpoint).await.unwrap();
        let mut scratch = [0u8; 256];

        // Round one: echo whatever was received.
        let len = stream.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[..len], FIRST_MESSAGE);
        stream.write_all(&scratch[..len]).await.unwrap();
        stream.flush().await.unwrap();

        // Round two: acknowledge the follow-up message.
        let len = stream.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[..len], SECOND_MESSAGE);
        stream.write_all(ACKNOWLEDGEMENT).await.unwrap();
        stream.flush().await.unwrap();

        stream.close().await;
    };

    let client_side = async {
        // Keep trying until a connect attempt succeeds.
        let mut stream = loop {
            match client.connect(endpoint).await {
                Ok(established) => break established,
                Err(_) => continue,
            }
        };

        // Round one: send and expect the echo.
        stream.write_all(FIRST_MESSAGE).await.unwrap();
        stream.flush().await.unwrap();
        let mut scratch = [0u8; 256];
        let len = stream.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[..len], FIRST_MESSAGE);

        // Round two: send and expect the acknowledgement.
        stream.write_all(SECOND_MESSAGE).await.unwrap();
        stream.flush().await.unwrap();
        let len = stream.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[..len], ACKNOWLEDGEMENT);

        stream.close().await;
    };

    futures::join!(server_side, client_side);
}
pub async fn test_connect_refused(mut client: impl TcpSocket, ip_address: &str) {
let ip_address = ip_address.parse().unwrap();
let server_addr = SocketAddr::new(ip_address, 59900);
assert_eq!(
client.connect(server_addr).await.unwrap_err(),
Error::ConnectionReset
);
}
pub async fn test_accept_with_zero_port(mut server: impl TcpSocket, ip_address: &str) {
let ip_address = ip_address.parse().unwrap();
let server_addr = SocketAddr::new(ip_address, 0);
assert_eq!(
server.accept(server_addr).await.unwrap_err(),
Error::InvalidPort
);
}
// Checks that a server accepting on the all-zero (unspecified) IP of the same
// address family receives a client that connects to the concrete `ip_address`
// on port 59910, and that data flows in both directions.
//
// Panics if `ip_address` does not parse or any step fails. The client retries
// `connect` until it succeeds.
pub async fn test_accept_all_zero_ip(
    mut client: impl TcpSocket,
    mut server: impl TcpSocket,
    ip_address: &str,
) {
    let port = 59910;
    let ip_address: IpAddr = ip_address.parse().unwrap();
    let connect_address = SocketAddr::new(ip_address, port);
    // Listen on the unspecified address of the same family as the tested IP.
    let all_zero_address = match ip_address {
        IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port),
        IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port),
    };
    let server_task = async {
        let (mut connection, _) = server.accept(all_zero_address).await.unwrap();
        let mut buffer = [0u8; 256];
        let read = connection.read(&mut buffer).await.unwrap();
        assert_eq!(&buffer[..read], b"Test message from client");
        connection.write_all(&buffer[..read]).await.unwrap();
        connection.flush().await.unwrap();
        // Close explicitly, like every other test in this suite does.
        connection.close().await;
    };
    let client_task = async {
        let mut connection = loop {
            if let Ok(connection) = client.connect(connect_address).await {
                break connection;
            }
        };
        connection
            .write_all(b"Test message from client")
            .await
            .unwrap();
        connection.flush().await.unwrap();
        // Consume the echo before closing; otherwise the server's echo write
        // can race with this side tearing the connection down.
        let mut buffer = [0u8; 256];
        let read = connection.read(&mut buffer).await.unwrap();
        assert_eq!(&buffer[..read], b"Test message from client");
        connection.close().await;
    };
    futures::join!(server_task, client_task);
}
// Checks that closing a connection is observed by the peer as end-of-stream.
//
// The server (port 59004) sends a short payload and closes; the client must
// first read that payload and then get a zero-length read. Panics if
// `ip_address` does not parse or any step fails.
pub async fn test_close_connection(
    mut client: impl TcpSocket,
    mut server: impl TcpSocket,
    ip_address: &str,
) {
    const PAYLOAD: &[u8] = b"Hello";

    let endpoint = SocketAddr::new(ip_address.parse().unwrap(), 59004);

    let sending_side = async {
        let (mut stream, _) = server.accept(endpoint).await.unwrap();
        stream.write_all(PAYLOAD).await.unwrap();
        stream.flush().await.unwrap();
        stream.close().await;
    };

    let receiving_side = async {
        // Keep trying until a connect attempt succeeds.
        let mut stream = loop {
            match client.connect(endpoint).await {
                Ok(established) => break established,
                Err(_) => continue,
            }
        };
        let mut scratch = [0u8; 32];

        let payload_len = stream.read(&mut scratch).await.unwrap();
        assert_eq!(&scratch[..payload_len], PAYLOAD);

        // After the server has closed, the next read must report EOF.
        let read = stream.read(&mut scratch).await.unwrap();
        assert_eq!(
            read, 0,
            "Expected EOF (0 bytes) after server closed connection"
        );
        stream.close().await;
    };

    futures::join!(sending_side, receiving_side);
}
}