use core::fmt;
pub use embedded_io_async::{self, Error, ErrorKind, ErrorType, Read, ReadExactError, Write};
use crate::time::Timer;
/// Extension methods available on every [`Read`] implementation.
pub trait ReadExt: Read {
    /// Reads and throws away data until the reader reports end of stream (a read of zero bytes).
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the underlying reader.
    async fn discard_all_data(&mut self) -> Result<(), Self::Error> {
        let mut scratch = [0u8; 128];

        loop {
            if self.read(&mut scratch).await? == 0 {
                return Ok(());
            }
        }
    }
}

impl<R: Read> ReadExt for R {}
/// Why formatting into a [`FormatBuffer`] failed.
///
/// `T` is the payload of [`FormatBufferWriteError::OutOfSpace`]: `()` when only the reason is
/// being recorded (see `FormatBuffer::error_state`), or the bytes buffered so far when the error
/// is handed back to a caller of `FormatBuffer::write`.
pub(crate) enum FormatBufferWriteError<T> {
    /// A `fmt::Error` occurred that was not caused by the buffer running out of space.
    FormatError,
    /// The buffer was full and a byte could not be stored.
    OutOfSpace(T),
}
/// A fixed-capacity (128 byte) sink for [`fmt::Write`] output.
///
/// The first `ignore_count` bytes written to it are dropped. This lets a caller re-run the same
/// formatting and pick up where a previous, full buffer left off.
pub(crate) struct FormatBuffer {
    /// Bytes stored so far, excluding the ignored prefix.
    pub data: heapless::Vec<u8, 128>,
    /// How many more leading bytes to drop before anything is stored in `data`.
    pub ignore_count: usize,
    /// Starts out as `FormatError`; set to `OutOfSpace(())` when a byte could not be stored
    /// because `data` was full. Used to explain a `fmt::Error` after the fact.
    pub error_state: FormatBufferWriteError<()>,
}
impl fmt::Write for FormatBuffer {
    /// Appends `s`, first using up any outstanding `ignore_count` by dropping leading bytes.
    ///
    /// Stores as many bytes as fit. As soon as one does not fit, records `OutOfSpace` in
    /// `error_state` and fails with `fmt::Error`.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let bytes = s.as_bytes();

        // Only the part of `s` beyond the ignored prefix is kept.
        let skip = self.ignore_count.min(bytes.len());
        self.ignore_count -= skip;

        for &byte in &bytes[skip..] {
            if self.data.push(byte).is_err() {
                self.error_state = FormatBufferWriteError::OutOfSpace(());
                return Err(fmt::Error);
            }
        }

        Ok(())
    }
}
impl FormatBuffer {
    /// Creates an empty buffer that drops the first `ignore_count` bytes written to it.
    pub fn new(ignore_count: usize) -> Self {
        Self {
            data: heapless::Vec::new(),
            ignore_count,
            error_state: FormatBufferWriteError::FormatError,
        }
    }

    /// Formats `value` (via its [`fmt::Display`] implementation) into the buffer.
    ///
    /// On success, returns every byte currently held in the buffer.
    ///
    /// # Errors
    ///
    /// - [`FormatBufferWriteError::OutOfSpace`], carrying the bytes that did fit, if the buffer
    ///   filled up before `value` finished formatting.
    /// - [`FormatBufferWriteError::FormatError`] if `value` itself reported a formatting error.
    pub fn write(
        &mut self,
        value: impl fmt::Display,
    ) -> Result<&[u8], FormatBufferWriteError<&[u8]>> {
        use fmt::Write;

        // `error_state` only explains *why* a `fmt::Error` happened. Reset it so that an
        // `OutOfSpace` left over from an earlier call is not blamed for a genuine formatting
        // error raised by `value` during this call.
        self.error_state = FormatBufferWriteError::FormatError;

        match write!(self, "{value}") {
            Ok(()) => Ok(self.data.as_slice()),
            Err(fmt::Error) => Err(match self.error_state {
                FormatBufferWriteError::FormatError => FormatBufferWriteError::FormatError,
                FormatBufferWriteError::OutOfSpace(()) => {
                    FormatBufferWriteError::OutOfSpace(self.data.as_slice())
                }
            }),
        }
    }
}
/// Extension methods available on every [`Write`] implementation.
pub trait WriteExt: Write {
    /// Writes formatted output without allocating. It is named `write_fmt` so that the `write!`
    /// macro can target it; the returned future must be awaited.
    ///
    /// `args` is formatted into a fixed-size [`FormatBuffer`]. Each time the buffer fills up, its
    /// contents are written out and formatting restarts from the beginning, skipping the bytes
    /// already sent. `args` must therefore produce the same output every time it is formatted.
    ///
    /// If `args` itself reports a formatting error, a warning is logged and `Ok(())` is returned.
    /// Any chunks already written are not retracted.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the underlying writer.
    async fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> Result<(), Self::Error> {
        let mut bytes_sent = 0;

        loop {
            let mut buffer = FormatBuffer::new(bytes_sent);

            let chunk = match buffer.write(args) {
                Ok(remainder) => return self.write_all(remainder).await,
                Err(FormatBufferWriteError::OutOfSpace(chunk)) => chunk,
                Err(FormatBufferWriteError::FormatError) => {
                    log_warn!("Skipping writing due to Format Error");
                    return Ok(());
                }
            };

            self.write_all(chunk).await?;
            bytes_sent += chunk.len();
        }
    }
}

impl<W: Write> WriteExt for W {}
/// A connection socket that can be split into reading and writing halves and closed, either
/// abruptly ([`Socket::abort`]) or gracefully ([`Socket::shutdown`]).
///
/// `Runtime` selects which [`Timer`] implementation may be used to bound the closing operations.
pub trait Socket<Runtime>: Sized {
    /// The error type reported by both halves of the socket.
    type Error: Error + 'static;
    /// The reading half produced by [`Socket::split`].
    type ReadHalf<'a>: Read<Error = Self::Error>
    where
        Self: 'a;
    /// The writing half produced by [`Socket::split`].
    type WriteHalf<'a>: Write<Error = Self::Error>
    where
        Self: 'a;
    /// Splits the socket into a reading half and a writing half, both borrowing `self`.
    fn split(&mut self) -> (Self::ReadHalf<'_>, Self::WriteHalf<'_>);
    /// Terminates the connection, consuming the socket. Any waiting is bounded with `timer`
    /// according to `timeouts`.
    ///
    /// NOTE(review): what is sent to the peer is implementation-defined; see the individual
    /// implementations.
    async fn abort<T: Timer<Runtime>>(
        self,
        timeouts: &crate::Timeouts,
        timer: &mut T,
    ) -> Result<(), super::Error<Self::Error>>;
    /// Closes the connection gracefully, consuming the socket. Any waiting is bounded with
    /// `timer` according to `timeouts`.
    ///
    /// The implementations in this file close the sending side and then discard incoming data
    /// until the peer closes its side. Errors wrap either an I/O failure or a timeout.
    async fn shutdown<T: Timer<Runtime>>(
        self,
        timeouts: &crate::Timeouts,
        timer: &mut T,
    ) -> Result<(), super::Error<Self::Error>>;
}
#[cfg(any(feature = "tokio", test))]
pub(crate) mod tokio_support {
    //! Adapters that let tokio I/O types be used through the embedded-io-async traits.

    use embedded_io_async::{Error, ErrorKind, ErrorType, Read, Write};

    /// A [`std::io::Error`] wrapped so that it implements [`embedded_io_async::Error`].
    #[derive(Debug, thiserror::Error)]
    #[error(transparent)]
    pub struct TokioIoError(pub std::io::Error);

    impl Error for TokioIoError {
        /// Always reports [`ErrorKind::Other`]; the wrapped `std::io::ErrorKind` is not mapped.
        fn kind(&self) -> ErrorKind {
            ErrorKind::Other
        }
    }

    /// Exposes a tokio `AsyncRead` and/or `AsyncWrite` value through embedded-io-async's
    /// [`Read`] / [`Write`] traits.
    pub struct TokioIo<T>(pub T);

    impl<T> ErrorType for TokioIo<T> {
        type Error = TokioIoError;
    }

    impl<T: tokio::io::AsyncRead + Unpin> Read for TokioIo<T> {
        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
            tokio::io::AsyncReadExt::read(&mut self.0, buf)
                .await
                .map_err(TokioIoError)
        }
    }

    impl<T: tokio::io::AsyncWrite + Unpin> Write for TokioIo<T> {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            tokio::io::AsyncWriteExt::write(&mut self.0, buf)
                .await
                .map_err(TokioIoError)
        }

        async fn flush(&mut self) -> Result<(), Self::Error> {
            tokio::io::AsyncWriteExt::flush(&mut self.0)
                .await
                .map_err(TokioIoError)
        }
    }

    impl super::Socket<crate::TokioRuntime> for tokio::net::TcpStream {
        type Error = TokioIoError;
        type ReadHalf<'a> = TokioIo<tokio::net::tcp::ReadHalf<'a>>;
        type WriteHalf<'a> = TokioIo<tokio::net::tcp::WriteHalf<'a>>;

        /// Borrows the stream as a tokio read/write half pair, each wrapped in [`TokioIo`].
        fn split(&mut self) -> (Self::ReadHalf<'_>, Self::WriteHalf<'_>) {
            let (reader, writer) = tokio::net::TcpStream::split(self);
            (TokioIo(reader), TokioIo(writer))
        }

        /// Does nothing beyond consuming (and so dropping) the stream; the timeouts are unused.
        async fn abort<T: crate::Timer<crate::TokioRuntime>>(
            self,
            _timeouts: &crate::Timeouts,
            _timer: &mut T,
        ) -> Result<(), crate::Error<Self::Error>> {
            Ok(())
        }

        /// Shuts down the writing side (bounded by `timeouts.write`), then reads and discards
        /// data until the peer closes its side. Each individual read is bounded by
        /// `timeouts.read_request`.
        async fn shutdown<T: crate::Timer<crate::TokioRuntime>>(
            mut self,
            timeouts: &crate::Timeouts,
            timer: &mut T,
        ) -> Result<(), crate::Error<Self::Error>> {
            let shutdown_result = timer
                .run_with_timeout(
                    timeouts.write,
                    tokio::io::AsyncWriteExt::shutdown(&mut self),
                )
                .await
                .map_err(crate::Error::WriteTimeout)?;
            shutdown_result.map_err(|error| crate::Error::Write(TokioIoError(error)))?;

            let mut scratch = [0u8; 128];

            loop {
                let bytes_read = timer
                    .run_with_timeout(
                        timeouts.read_request,
                        tokio::io::AsyncReadExt::read(&mut self, &mut scratch),
                    )
                    .await
                    .map_err(crate::Error::ReadTimeout)?
                    .map_err(|error| crate::Error::Read(TokioIoError(error)))?;

                if bytes_read == 0 {
                    return Ok(());
                }
            }
        }
    }
}
#[cfg(feature = "embassy")]
impl<'s> Socket<super::EmbassyRuntime> for embassy_net::tcp::TcpSocket<'s> {
    type Error = embassy_net::tcp::Error;

    type ReadHalf<'a>
        = embassy_net::tcp::TcpReader<'a>
    where
        's: 'a;

    type WriteHalf<'a>
        = embassy_net::tcp::TcpWriter<'a>
    where
        's: 'a;

    /// Delegates to the socket's inherent `split`.
    fn split(&mut self) -> (Self::ReadHalf<'_>, Self::WriteHalf<'_>) {
        embassy_net::tcp::TcpSocket::split(self)
    }

    /// Calls the socket's inherent `abort`, then waits for the socket to flush, bounded by
    /// `timeouts.write`.
    async fn abort<T: crate::Timer<super::EmbassyRuntime>>(
        mut self,
        timeouts: &crate::Timeouts,
        timer: &mut T,
    ) -> Result<(), crate::Error<Self::Error>> {
        log_info!("Abort");

        // Written out in full so it is obvious this is the inherent method, not `Socket::abort`.
        embassy_net::tcp::TcpSocket::abort(&mut self);

        let flush_result = timer
            .run_with_timeout(timeouts.write.clone(), self.flush())
            .await
            .map_err(crate::Error::WriteTimeout)?;

        flush_result.map_err(crate::Error::Write)
    }

    /// Closes the connection gracefully.
    ///
    /// Calls the socket's inherent `close`, then concurrently:
    /// - discards incoming data until end of stream, with one `timeouts.read_request` bound for
    ///   the whole drain;
    /// - flushes the writing half.
    ///
    /// Finally waits, bounded by `timeouts.write`, for the socket itself to flush.
    async fn shutdown<T: crate::Timer<super::EmbassyRuntime>>(
        mut self,
        timeouts: &crate::Timeouts,
        timer: &mut T,
    ) -> Result<(), crate::Error<Self::Error>> {
        self.close();

        let (mut reader, mut writer) = self.split();

        // The flushing future parks on `pending` once the flush succeeds, so it can only finish
        // this race by failing. Assuming `select` resolves with whichever future completes first,
        // the race ends when draining ends (or times out) or when flushing fails.
        crate::futures::select(
            async {
                timer
                    .run_with_timeout(timeouts.read_request.clone(), reader.discard_all_data())
                    .await
                    .map_err(crate::Error::ReadTimeout)?
                    .map_err(crate::Error::Read)
            },
            async {
                writer.flush().await.map_err(crate::Error::Write)?;
                core::future::pending().await
            },
        )
        .await?;

        let flush_result = timer
            .run_with_timeout(timeouts.write.clone(), self.flush())
            .await
            .map_err(crate::Error::WriteTimeout)?;

        flush_result.map_err(crate::Error::Write)
    }
}