1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
use embedded_nal::nb;

use crate::TcpExactStack;

/// A wrapper around a [`TcpClientStack`](embedded_nal::TcpClientStack) that provides
/// [`TcpExactStack`].
///
/// The implementation is comparatively crude: each socket gets an `N`-byte receive buffer and an
/// `N`-byte send buffer, and data is copied into them on demand.
///
/// Using this is generally not recommended -- TCP stacks usually have the buffers in there
/// somewhere, and "just" need to expose them.
pub struct BufferedStack<ST: embedded_nal::TcpClientStack, const N: usize>(ST);

impl<ST: embedded_nal::TcpClientStack, const N: usize> BufferedStack<ST, N> {
    /// Wrap `wrapped` so that its sockets get `N`-byte receive and send buffers.
    pub fn new(wrapped: ST) -> Self {
        BufferedStack(wrapped)
    }

    /// Attempt sending any content of the send buffer.
    ///
    /// Returns `Ok(())` only if the send buffer is empty afterwards. If the wrapped stack takes
    /// only part of the data, the unsent rest is moved to the front of the buffer and
    /// `WouldBlock` is returned. Errors (including `WouldBlock`) reported by the wrapped stack are
    /// passed through unchanged, leaving the buffer as it was.
    fn try_flush_sendbuffer(
        &mut self,
        socket: &mut <Self as embedded_nal::TcpClientStack>::TcpSocket,
    ) -> Result<(), embedded_nal::nb::Error<<Self as embedded_nal::TcpClientStack>::Error>> {
        if socket.sendbuf.is_empty() {
            return Ok(());
        }

        // Both WouldBlock and actual errors go out here. Actual errors are a bit late (given the
        // send_all already returned successfully), but then again, this could just as well have
        // happened while things are in the OS's buffer.
        let sent = self.0.send(&mut socket.socket, &socket.sendbuf)?;

        if sent == socket.sendbuf.len() {
            // All flushed, we can go on
            socket.sendbuf.clear();
            Ok(())
        } else {
            // Keep the unsent tail, in order, at the start of the buffer.
            let remaining = socket.sendbuf.len() - sent;
            socket.sendbuf.copy_within(sent.., 0);
            socket.sendbuf.truncate(remaining);
            Err(embedded_nal::nb::Error::WouldBlock)
        }
    }
}

/// Socket wrapper for [`BufferedStack`]: the wrapped stack's socket plus per-socket buffers.
///
/// For a listening socket (one that accepts), the buffers are unused -- too bad the TCP socket
/// API doesn't have types for different roles.
pub struct BufferedSocket<SO, const N: usize> {
    /// The wrapped stack's socket.
    socket: SO,
    /// Bytes received by an incomplete `receive_exact` that have not been handed out yet.
    recvbuf: heapless::Vec<u8, N>,
    /// Bytes accepted by `send_all` that the wrapped stack has not taken yet.
    sendbuf: heapless::Vec<u8, N>,
}

impl<ST: embedded_nal::TcpFullStack, const N: usize> embedded_nal::TcpFullStack
    for BufferedStack<ST, N>
{
    // Binding and listening only concern the wrapped socket; the buffers play no role there.
    fn bind(&mut self, socket: &mut Self::TcpSocket, port: u16) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.bind(inner, port)
    }

    fn listen(&mut self, socket: &mut Self::TcpSocket) -> Result<(), Self::Error> {
        let inner = &mut socket.socket;
        self.0.listen(inner)
    }

    /// Accept a connection on the wrapped listening socket and give it fresh, empty buffers.
    fn accept(
        &mut self,
        socket: &mut Self::TcpSocket,
    ) -> Result<(Self::TcpSocket, embedded_nal::SocketAddr), embedded_nal::nb::Error<Self::Error>>
    {
        let (connection, remote) = self.0.accept(&mut socket.socket)?;
        let wrapped = BufferedSocket {
            socket: connection,
            recvbuf: Default::default(),
            sendbuf: Default::default(),
        };
        Ok((wrapped, remote))
    }
}

impl<ST: embedded_nal::TcpClientStack, const N: usize> embedded_nal::TcpClientStack
    for BufferedStack<ST, N>
{
    type TcpSocket = BufferedSocket<ST::TcpSocket, N>;
    type Error = ST::Error;

    /// Open a socket on the wrapped stack, with empty receive and send buffers.
    fn socket(&mut self) -> Result<Self::TcpSocket, Self::Error> {
        Ok(BufferedSocket {
            socket: self.0.socket()?,
            recvbuf: Default::default(),
            sendbuf: Default::default(),
        })
    }
    fn connect(
        &mut self,
        socket: &mut Self::TcpSocket,
        addr: embedded_nal::SocketAddr,
    ) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
        self.0.connect(&mut socket.socket, addr)
    }
    fn is_connected(&mut self, socket: &Self::TcpSocket) -> Result<bool, Self::Error> {
        self.0.is_connected(&socket.socket)
    }
    /// Send data directly through the wrapped stack.
    ///
    /// Anything enqueued by an earlier `send_all` must go out first to preserve byte order; while
    /// that is still pending, this returns `WouldBlock` without accepting any of `buffer`.
    fn send(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &[u8],
    ) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
        // First, send out anything that is enqueued
        self.try_flush_sendbuffer(socket)?;

        // try_flush_sendbuffer only returns Ok once the send buffer is empty; send_all relies on
        // this post-condition to have room for its leftovers.
        assert!(socket.sendbuf.is_empty());
        self.0.send(&mut socket.socket, buffer)
    }
    /// Receive data, serving bytes left over from an incomplete `receive_exact` first.
    ///
    /// Returns the number of bytes written to the front of `buffer`.
    fn receive(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &mut [u8],
    ) -> Result<usize, embedded_nal::nb::Error<Self::Error>> {
        // There is no task that'd flush the buffer out, so we depend on something to make
        // progress. The read is definitely the best candidate to make that.
        match self.try_flush_sendbuffer(socket) {
            Ok(()) => (),
            // Maybe we made progress, maybe not -- but anyway we tried, and that's all that
            // matters in the receive path. No need to stop receiving just because we have a full
            // send buffer.
            Err(nb::Error::WouldBlock) => (),
            // Ensure the error isn't lost. This may not be 100% precise in half-open connections,
            // but I doubt embedded-nal aims to support them. (If we'd want to, we'd need to mark
            // the send buffer as having erred).
            Err(e) => return Err(e),
        };

        match socket.recvbuf.len() {
            // The common case: nothing buffered, read straight from the wrapped stack.
            0 => self.0.receive(&mut socket.socket, buffer),
            // Everything buffered fits (sure we could try to receive more, but it's TCP and
            // prepared to get data piecemeal, so just eat it as it is).
            present if present <= buffer.len() => {
                buffer[..present].copy_from_slice(&socket.recvbuf);
                socket.recvbuf.clear();
                Ok(present)
            }
            // More is buffered than requested (e.g. a long, incomplete receive_exact followed by
            // a short read): hand out the front and keep the rest, in order, for the next read.
            present => {
                let taken = buffer.len();
                buffer.copy_from_slice(&socket.recvbuf[..taken]);
                socket.recvbuf.copy_within(taken.., 0);
                socket.recvbuf.truncate(present - taken);
                Ok(taken)
            }
        }
    }
    /// Close the socket after one last, best-effort attempt to flush the send buffer.
    fn close(&mut self, mut socket: Self::TcpSocket) -> Result<(), Self::Error> {
        match self.try_flush_sendbuffer(&mut socket) {
            Ok(()) => (),
            // As close can't WouldBlock, it would appear that not having sent some data is
            // considered acceptable in embedded-nal
            Err(nb::Error::WouldBlock) => (),
            // ... and then it's just logical that errors from there are discarded too.
            Err(nb::Error::Other(_)) => (),
        }
        self.0.close(socket.socket)
    }
}

impl<ST: embedded_nal::TcpClientStack, const N: usize> TcpExactStack
    for BufferedStack<ST, N>
{
    const RECVBUFLEN: usize = N;

    const SENDBUFLEN: usize = N;

    fn receive_exact(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &mut [u8],
    ) -> nb::Result<(), Self::Error> {
        let len_start = socket.recvbuf.len();
        let missing = buffer.len().checked_sub(len_start);

        if let Some(missing) = missing {
            if missing > 0 {
                // unsafe: All u8 values are valid.
                //
                // The safe alternative would be `socket.recvbuf.resize_default(buffer.len());`,
                // which needlessly zeroes out text.
                //
                // There are proposals out there on how to do these things more elegantly, but
                // AFAICT they're not done yet (and I can't look it up right now).
                unsafe {
                    socket.recvbuf.set_len(buffer.len());
                }
                // Note: This panics at the bounds check when too much is asked.
                let received = self.0.receive(
                    &mut socket.socket,
                    &mut socket.recvbuf[len_start..buffer.len()],
                )?;
                socket.recvbuf.truncate(len_start + received);
            }
        }

        if socket.recvbuf.len() >= buffer.len() {
            // It *can* be greater than, if receive_exact was incompletely called earlier; receive
            // already handles the back-rotation of any leftovers, and is guaranteed to succeed in
            // this case.
            use embedded_nal::TcpClientStack;
            self.receive(socket, buffer).map(|_| ())
        } else {
            Err(nb::Error::WouldBlock)
        }
    }

    fn send_all(
        &mut self,
        socket: &mut Self::TcpSocket,
        buffer: &[u8],
    ) -> Result<(), embedded_nal::nb::Error<Self::Error>> {
        use embedded_nal::TcpClientStack;

        match self.send(socket, buffer) {
            Err(e) => Err(e),
            Ok(n) if n == buffer.len() => Ok(()),
            Ok(n) => {
                assert!(
                    socket.sendbuf.is_empty(),
                    "Internal post-condition of send() violated"
                );
                socket
                    .sendbuf
                    .extend_from_slice(&buffer[n..])
                    .expect("Send leftovers exceed buffer announced in SENDBUFLEN");

                Err(embedded_nal::nb::Error::WouldBlock)
            }
        }
    }
}