use core::net::SocketAddr;
use embedded_io_async::{ErrorType, Read, Write};
use embedded_nal_async::TcpConnect;
/// Capacity, in bytes, of the read-ahead buffer held by each `BufferedConnection` (4 KiB).
const READ_BUFFER_SIZE: usize = 4096;
/// A `TcpConnect` adapter whose connections buffer reads.
///
/// Every connection produced by `connect` wraps the inner connection in a
/// `BufferedConnection`.
#[derive(Debug, Clone)]
pub struct BufferedTcpConnect<T> {
    /// The underlying connector that actually opens connections.
    inner: T,
}
impl<T> BufferedTcpConnect<T> {
pub const fn new(inner: T) -> Self {
Self { inner }
}
}
/// Opens connections through the inner connector and wraps each one in a
/// `BufferedConnection`; errors from the inner connector are returned as-is.
impl<T> TcpConnect for BufferedTcpConnect<T>
where
    T: TcpConnect,
{
    type Error = T::Error;
    type Connection<'a>
        = BufferedConnection<T::Connection<'a>>
    where
        Self: 'a;

    async fn connect<'a>(
        &'a self,
        remote: SocketAddr,
    ) -> Result<Self::Connection<'a>, Self::Error> {
        // Connect first; only a successful connection gets wrapped.
        self.inner
            .connect(remote)
            .await
            .map(BufferedConnection::new)
    }
}
/// A connection wrapper that serves reads from a fixed-size internal buffer.
///
/// Small reads are satisfied from bytes fetched earlier from `inner` in one
/// larger read, which reduces the number of calls into the wrapped connection.
///
/// Invariant: `start <= end <= buf.len()`, and `buf[start..end]` holds bytes
/// already read from `inner` but not yet handed to a caller.
pub struct BufferedConnection<C> {
    /// The wrapped connection.
    inner: C,
    /// Backing storage for read-ahead bytes.
    buf: [u8; READ_BUFFER_SIZE],
    /// Index of the next not-yet-returned byte in `buf`.
    start: usize,
    /// One past the last valid byte in `buf`.
    end: usize,
}

/// Shows the inner connection and the number of pending buffered bytes.
///
/// This is written by hand rather than derived so the output does not include
/// the whole buffer array.
impl<C: core::fmt::Debug> core::fmt::Debug for BufferedConnection<C> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("BufferedConnection")
            .field("inner", &self.inner)
            .field("buffered", &(self.end - self.start))
            .finish()
    }
}
impl<C> BufferedConnection<C> {
    /// Wraps `inner` with an empty read buffer (`start == end == 0`).
    fn new(inner: C) -> Self {
        let buf = [0u8; READ_BUFFER_SIZE];
        Self {
            start: 0,
            end: 0,
            buf,
            inner,
        }
    }
}
impl<C: ErrorType> ErrorType for BufferedConnection<C> {
type Error = C::Error;
}
impl<C: Read> Read for BufferedConnection<C> {
    /// Reads into `buf`, serving bytes from the internal buffer when possible.
    ///
    /// - An empty `buf` returns `Ok(0)` without touching the inner connection.
    /// - Pending buffered bytes are copied out first, with no inner read.
    /// - With nothing buffered, a `buf` at least as large as the internal buffer
    ///   is filled directly by the inner connection, avoiding a double copy.
    /// - Otherwise one inner read refills the buffer. If the inner read returns
    ///   `Ok(0)` (end of stream), that is returned to the caller.
    ///
    /// Inner errors are propagated with `?` before any buffer state is updated.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        if buf.is_empty() {
            return Ok(0);
        }

        if self.start == self.end {
            if buf.len() >= self.buf.len() {
                return self.inner.read(buf).await;
            }
            match self.inner.read(&mut self.buf).await? {
                0 => return Ok(0),
                filled => {
                    self.start = 0;
                    self.end = filled;
                }
            }
        }

        let pending = &self.buf[self.start..self.end];
        let count = pending.len().min(buf.len());
        buf[..count].copy_from_slice(&pending[..count]);
        self.start += count;
        Ok(count)
    }
}
/// Passes writes straight through to the inner connection.
///
/// Only reads are buffered; the read buffer is not touched here.
impl<C: Write> Write for BufferedConnection<C> {
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        self.inner.write(buf).await
    }

    async fn flush(&mut self) -> Result<(), Self::Error> {
        self.inner.flush().await
    }

    /// Delegates to the inner connection's `write_all`.
    ///
    /// Without this override, the trait's default loop over `write` would run
    /// here and bypass any `write_all` the inner connection provides itself.
    async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        self.inner.write_all(buf).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::convert::Infallible;
    use core::future::Future;
    use core::task::{Context, Poll, Waker};

    /// Drives `fut` to completion by busy-polling it with a no-op waker.
    ///
    /// This is enough here because the mock I/O below never returns `Pending`.
    fn block_on<F: Future>(fut: F) -> F::Output {
        let mut fut = core::pin::pin!(fut);
        let mut cx = Context::from_waker(Waker::noop());
        loop {
            if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
                return v;
            }
        }
    }

    /// In-memory connection mock.
    ///
    /// Each read returns at most `chunk` bytes from `data`, and `reads` counts
    /// every `read` call, including zero-length ones.
    struct CountingConn {
        data: Vec<u8>,
        pos: usize,
        reads: usize,
        chunk: usize,
    }

    impl CountingConn {
        fn new(data: Vec<u8>, chunk: usize) -> Self {
            Self {
                data,
                pos: 0,
                reads: 0,
                chunk,
            }
        }
    }

    impl ErrorType for CountingConn {
        type Error = Infallible;
    }

    impl Read for CountingConn {
        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Infallible> {
            self.reads += 1;
            if buf.is_empty() {
                return Ok(0);
            }
            let remaining = self.data.len() - self.pos;
            let n = remaining.min(buf.len()).min(self.chunk);
            buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
            self.pos += n;
            Ok(n)
        }
    }

    impl Write for CountingConn {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Infallible> {
            Ok(buf.len())
        }
        async fn flush(&mut self) -> Result<(), Infallible> {
            Ok(())
        }
    }

    /// Reads one byte at a time until end of stream and returns everything read.
    fn drain_byte_by_byte(conn: &mut BufferedConnection<CountingConn>) -> Vec<u8> {
        let mut out = Vec::new();
        loop {
            let mut b = [0u8; 1];
            let n = block_on(conn.read(&mut b)).unwrap();
            if n == 0 {
                break;
            }
            out.push(b[0]);
        }
        out
    }

    #[test]
    fn coalesces_byte_reads() {
        let data: Vec<u8> = (0..200u8).collect();
        let mut conn = BufferedConnection::new(CountingConn::new(data.clone(), 4096));
        let out = drain_byte_by_byte(&mut conn);
        assert_eq!(out, data);
        // One inner read fetches all 200 bytes, and one more observes EOF.
        assert_eq!(conn.inner.reads, 2);
    }

    #[test]
    fn large_read_bypasses_buffer() {
        let data: Vec<u8> = (0..2000u32).map(|n| n as u8).collect();
        let mut conn = BufferedConnection::new(CountingConn::new(data.clone(), 4096));
        let mut buf = [0u8; 4096];
        let n = block_on(conn.read(&mut buf)).unwrap();
        assert_eq!(n, data.len());
        assert_eq!(&buf[..n], &data[..]);
        assert_eq!(conn.inner.reads, 1);
        assert_eq!(conn.start, 0);
        assert_eq!(conn.end, 0);
    }

    #[test]
    fn large_read_drains_buffer_before_bypassing() {
        let data: Vec<u8> = (0..100u8).collect();
        let mut conn = BufferedConnection::new(CountingConn::new(data.clone(), 4096));

        let mut small = [0u8; 10];
        assert_eq!(block_on(conn.read(&mut small)).unwrap(), 10);
        assert_eq!(&small[..], &data[..10]);
        assert_eq!(conn.inner.reads, 1);

        // 90 bytes are still buffered, so a large read must return them
        // without calling the inner connection again.
        let mut big = [0u8; READ_BUFFER_SIZE];
        let n = block_on(conn.read(&mut big)).unwrap();
        assert_eq!(n, 90);
        assert_eq!(&big[..n], &data[10..]);
        assert_eq!(conn.inner.reads, 1);
        assert_eq!(conn.start, conn.end);
    }

    #[test]
    fn empty_caller_buf_returns_zero_without_inner_read() {
        let mut conn = BufferedConnection::new(CountingConn::new(vec![1, 2, 3], 4096));
        let n = block_on(conn.read(&mut [])).unwrap();
        assert_eq!(n, 0);
        assert_eq!(conn.inner.reads, 0);
    }

    #[test]
    fn serves_partial_then_refills() {
        let data: Vec<u8> = (0..50u8).collect();
        let mut conn = BufferedConnection::new(CountingConn::new(data.clone(), 8));
        let out = drain_byte_by_byte(&mut conn);
        assert_eq!(out, data);
        // 6 reads of 8 bytes, 1 read of the last 2 bytes, and 1 EOF read.
        assert_eq!(conn.inner.reads, 8);
    }
}