use embedded_nal::nb;
use crate::TcpExactStack;
/// Wrapper around a [`embedded_nal::TcpClientStack`] whose sockets carry an
/// `N`-byte receive buffer and an `N`-byte send buffer.
///
/// The wrapped stack is stored as the single tuple field.
pub struct BufferedStack<ST, const N: usize>(ST)
where
    ST: embedded_nal::TcpClientStack;
impl<ST: embedded_nal::TcpClientStack, const N: usize> BufferedStack<ST, N> {
    /// Wraps `wrapped` so that its sockets get `N`-byte send and receive buffers.
    pub fn new(wrapped: ST) -> Self {
        BufferedStack(wrapped)
    }

    /// Tries to hand the socket's pending send buffer to the wrapped stack.
    ///
    /// Returns `Ok(())` when the send buffer is empty afterwards, including when
    /// it was already empty.
    ///
    /// # Errors
    ///
    /// * If the wrapped stack accepted only part of the pending data, the sent
    ///   prefix is removed from the buffer and `WouldBlock` is returned.
    /// * Errors of the wrapped stack are passed through and leave the buffer
    ///   unchanged.
    fn try_flush_sendbuffer(
        &mut self,
        socket: &mut <Self as embedded_nal::TcpClientStack>::TcpSocket,
    ) -> Result<(), embedded_nal::nb::Error<<Self as embedded_nal::TcpClientStack>::Error>> {
        if socket.sendbuf.is_empty() {
            return Ok(());
        }
        let sent = self.0.send(&mut socket.socket, &socket.sendbuf)?;
        if sent == socket.sendbuf.len() {
            socket.sendbuf.clear();
            Ok(())
        } else {
            // Move the unsent tail to the front so later flushes continue in order.
            socket.sendbuf.copy_within(sent.., 0);
            socket.sendbuf.truncate(socket.sendbuf.len() - sent);
            Err(embedded_nal::nb::Error::WouldBlock)
        }
    }
}
/// Socket handle of a `BufferedStack`: the wrapped stack's socket together with
/// the bytes buffered for it in each direction.
pub struct BufferedSocket<SO, const N: usize> {
    /// Outgoing bytes that the wrapped stack has not accepted yet.
    sendbuf: heapless::Vec<u8, N>,
    /// Incoming bytes already read from the wrapped socket but not yet handed out.
    recvbuf: heapless::Vec<u8, N>,
    /// The wrapped stack's own socket.
    socket: SO,
}
impl<ST, const N: usize> embedded_nal::TcpFullStack for BufferedStack<ST, N>
where
    ST: embedded_nal::TcpFullStack,
{
    /// Binds the wrapped socket to `port`; the buffers are not involved.
    fn bind(&mut self, socket: &mut Self::TcpSocket, port: u16) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.bind(inner, port)
    }

    /// Puts the wrapped socket into listening mode; the buffers are not involved.
    fn listen(&mut self, socket: &mut Self::TcpSocket) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.listen(inner)
    }

    /// Accepts a connection on the wrapped stack and wraps the new socket with
    /// fresh, empty send and receive buffers.
    fn accept(
        &mut self,
        socket: &mut Self::TcpSocket,
    ) -> Result<(Self::TcpSocket, embedded_nal::SocketAddr), embedded_nal::nb::Error<Self::Error>>
    {
        let (connection, remote) = self.0.accept(&mut socket.socket)?;
        let wrapped = BufferedSocket {
            socket: connection,
            recvbuf: Default::default(),
            sendbuf: Default::default(),
        };
        Ok((wrapped, remote))
    }
}
impl<ST: embedded_nal::TcpClientStack, const N: usize> embedded_nal::TcpClientStack
for BufferedStack<ST, N>
{
type TcpSocket = BufferedSocket<ST::TcpSocket, N>;
type Error = ST::Error;
fn socket(&mut self) -> Result<Self::TcpSocket, Self::Error> {
Ok(BufferedSocket {
socket: self.0.socket()?,
recvbuf: Default::default(),
sendbuf: Default::default(),
})
}
fn connect(
&mut self,
socket: &mut Self::TcpSocket,
addr: embedded_nal::SocketAddr,
) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
self.0.connect(&mut socket.socket, addr)
}
fn is_connected(&mut self, socket: &Self::TcpSocket) -> Result<bool, Self::Error> {
self.0.is_connected(&socket.socket)
}
fn send(
&mut self,
socket: &mut Self::TcpSocket,
buffer: &[u8],
) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
self.try_flush_sendbuffer(socket)?;
assert!(socket.sendbuf.is_empty());
self.0.send(&mut socket.socket, buffer)
}
fn receive(
&mut self,
socket: &mut Self::TcpSocket,
buffer: &mut [u8],
) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
match self.try_flush_sendbuffer(socket) {
Ok(()) => (),
Err(nb::Error::WouldBlock) => (),
Err(e) => return Err(e),
};
match socket.recvbuf.len() {
0 => self.0.receive(&mut socket.socket, buffer),
present if present >= buffer.len() => {
buffer[..present].copy_from_slice(&socket.recvbuf);
socket.recvbuf.clear();
Ok(present)
}
present => {
buffer.copy_from_slice(&socket.recvbuf[..buffer.len()]);
socket.recvbuf.copy_within(buffer.len().., 0);
socket.recvbuf.truncate(present - buffer.len());
Ok(buffer.len())
}
}
}
fn close(&mut self, mut socket: Self::TcpSocket) -> Result<(), Self::Error> {
match self.try_flush_sendbuffer(&mut socket) {
Ok(()) => (),
Err(nb::Error::WouldBlock) => (),
Err(nb::Error::Other(_)) => (),
}
self.0.close(socket.socket)
}
}
impl<ST: embedded_nal::TcpClientStack, const N: usize> TcpExactStack
    for BufferedStack<ST, N>
{
    const RECVBUFLEN: usize = N;
    const SENDBUFLEN: usize = N;

    /// Fills `buffer` completely, or returns `WouldBlock`.
    ///
    /// Bytes that arrive before enough are available are kept in the socket's
    /// receive buffer, so repeated calls accumulate data until the request can
    /// be satisfied.
    ///
    /// # Panics
    ///
    /// Panics if `buffer.len()` exceeds `RECVBUFLEN` (`N`).
    fn receive_exact(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &mut [u8],
    ) -> nb::Result<(), Self::Error> {
        let len_start = socket.recvbuf.len();
        if len_start < buffer.len() {
            // Extend the receive buffer with zero bytes up to the requested length.
            // This lets the wrapped stack write straight into the missing tail
            // without exposing uninitialized memory. It fails if the request is
            // larger than the buffer capacity.
            socket
                .recvbuf
                .resize(buffer.len(), 0)
                .expect("receive_exact buffer exceeds RECVBUFLEN");
            let result = self
                .0
                .receive(&mut socket.socket, &mut socket.recvbuf[len_start..]);
            // Keep only bytes that were actually received. On error, drop the
            // placeholder tail so it is never mistaken for data by a later call.
            let received = match result {
                Ok(n) => n,
                Err(e) => {
                    socket.recvbuf.truncate(len_start);
                    return Err(e);
                }
            };
            socket.recvbuf.truncate(len_start + received);
        }
        if socket.recvbuf.len() >= buffer.len() {
            use embedded_nal::TcpClientStack;
            self.receive(socket, buffer).map(|_| ())
        } else {
            Err(nb::Error::WouldBlock)
        }
    }

    /// Sends all of `buffer`, or nothing.
    ///
    /// Whatever the wrapped stack does not take immediately is copied into the
    /// socket's send buffer. It is flushed by later `send`, `receive` or `close`
    /// calls. In that case the data counts as accepted and `Ok(())` is returned.
    ///
    /// An error (such as `WouldBlock`) is only returned when none of `buffer`
    /// was taken, so retrying the same call cannot duplicate data.
    ///
    /// # Panics
    ///
    /// Panics if the part of `buffer` not taken by the wrapped stack is longer
    /// than `SENDBUFLEN` (`N`).
    fn send_all(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &[u8],
    ) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
        use embedded_nal::TcpClientStack;
        let sent = self.send(socket, buffer)?;
        if sent < buffer.len() {
            assert!(
                socket.sendbuf.is_empty(),
                "Internal post-condition of send() violated"
            );
            socket
                .sendbuf
                .extend_from_slice(&buffer[sent..])
                .expect("Send leftovers exceed buffer announced in SENDBUFLEN");
        }
        Ok(())
    }
}