use core::{
fmt::{self, Display},
future::Future,
net::SocketAddr,
};
use embassy_time::Duration;
use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
use crate::{Readable, TcpAccept, TcpConnect, TcpShutdown};
#[derive(Debug)]
pub enum WithTimeoutError<E> {
Error(E),
Timeout,
}
impl<E> From<E> for WithTimeoutError<E> {
fn from(e: E) -> Self {
Self::Error(e)
}
}
impl<E> fmt::Display for WithTimeoutError<E>
where
E: Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Error(e) => write!(f, "{}", e),
Self::Timeout => write!(f, "Operation timed out"),
}
}
}
impl<E> embedded_io_async::Error for WithTimeoutError<E>
where
E: embedded_io_async::Error,
{
fn kind(&self) -> ErrorKind {
match self {
Self::Error(e) => e.kind(),
Self::Timeout => ErrorKind::TimedOut,
}
}
}
pub async fn with_timeout<F, T, E>(timeout_ms: u32, fut: F) -> Result<T, WithTimeoutError<E>>
where
F: Future<Output = Result<T, E>>,
{
map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await)
}
pub struct WithTimeout<T>(T, u32);
impl<T> WithTimeout<T> {
pub const fn new(timeout_ms: u32, io: T) -> Self {
Self(io, timeout_ms)
}
pub fn io(&self) -> &T {
&self.0
}
pub fn io_mut(&mut self) -> &mut T {
&mut self.0
}
pub fn timeout_ms(&self) -> u32 {
self.1
}
pub fn into_io(self) -> T {
self.0
}
}
impl<T> ErrorType for WithTimeout<T>
where
T: ErrorType,
{
type Error = WithTimeoutError<T::Error>;
}
impl<T> Read for WithTimeout<T>
where
T: Read,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
with_timeout(self.1, self.0.read(buf)).await
}
}
impl<T> Write for WithTimeout<T>
where
T: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
with_timeout(self.1, self.0.write(buf)).await
}
async fn flush(&mut self) -> Result<(), Self::Error> {
with_timeout(self.1, self.0.flush()).await
}
}
impl<T> TcpConnect for WithTimeout<T>
where
T: TcpConnect,
{
type Error = WithTimeoutError<T::Error>;
type Socket<'a>
= WithTimeout<T::Socket<'a>>
where
Self: 'a;
async fn connect(&self, remote: SocketAddr) -> Result<Self::Socket<'_>, Self::Error> {
with_timeout(self.1, self.0.connect(remote))
.await
.map(|s| WithTimeout::new(self.1, s))
}
}
impl<T> Readable for WithTimeout<T>
where
T: Readable,
{
async fn readable(&mut self) -> Result<(), Self::Error> {
with_timeout(self.1, self.0.readable()).await
}
}
impl<T> TcpShutdown for WithTimeout<T>
where
T: TcpShutdown,
{
async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> {
with_timeout(self.1, self.0.close(what)).await
}
async fn abort(&mut self) -> Result<(), Self::Error> {
with_timeout(self.1, self.0.abort()).await
}
}
impl<T> TcpAccept for WithTimeout<T>
where
T: TcpAccept,
{
type Error = WithTimeoutError<T::Error>;
type Socket<'a>
= WithTimeout<T::Socket<'a>>
where
Self: 'a;
async fn accept(&self) -> Result<(SocketAddr, Self::Socket<'_>), Self::Error> {
let (addr, socket) = self.0.accept().await?;
Ok((addr, WithTimeout::new(self.1, socket)))
}
}
fn map_result<T, E>(
result: Result<Result<T, E>, embassy_time::TimeoutError>,
) -> Result<T, WithTimeoutError<E>> {
match result {
Ok(Ok(t)) => Ok(t),
Ok(Err(e)) => Err(WithTimeoutError::Error(e)),
Err(_) => Err(WithTimeoutError::Timeout),
}
}