embassy_net/
tcp.rs

1//! TCP sockets.
2//!
3//! # Listening
4//!
5//! `embassy-net` does not have a `TcpListener`. Instead, individual `TcpSocket`s can be put into
6//! listening mode by calling [`TcpSocket::accept`].
7//!
8//! Incoming connections when no socket is listening are rejected. To accept many incoming
9//! connections, create many sockets and put them all into listening mode.
10
11use core::future::{poll_fn, Future};
12use core::mem;
13use core::task::{Context, Poll};
14
15use embassy_time::Duration;
16use smoltcp::iface::{Interface, SocketHandle};
17use smoltcp::socket::tcp;
18pub use smoltcp::socket::tcp::State;
19use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
20
21use crate::time::duration_to_smoltcp;
22use crate::Stack;
23
/// Error returned by the read/write operations of `TcpSocket`, `TcpReader` and `TcpWriter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error {
    /// The connection was reset.
    ///
    /// This can happen on receiving a RST packet, or on timeout. `read_with`/`write_with` also
    /// report it when the corresponding half of the connection has already been closed.
    ConnectionReset,
}
33
/// Error returned by [`TcpSocket::connect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ConnectError {
    /// The socket is already connected or listening.
    InvalidState,
    /// The remote host rejected the connection with a RST packet.
    ConnectionReset,
    /// Connect timed out.
    // NOTE(review): `TcpSocket::connect` in this file never constructs this variant; a socket that
    // ends up closed during the handshake is reported as `ConnectionReset`. Confirm intent.
    TimedOut,
    /// No route to host (smoltcp reported the remote endpoint as unaddressable).
    NoRoute,
}
47
/// Error returned by [`TcpSocket::accept`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum AcceptError {
    /// The socket is already connected or listening.
    InvalidState,
    /// The listen endpoint was rejected as unaddressable (invalid listen port).
    InvalidPort,
    /// The remote host rejected the connection with a RST packet.
    ConnectionReset,
}
59
/// A TCP socket.
///
/// The socket lives in the stack's socket set; dropping the `TcpSocket` removes it from there.
pub struct TcpSocket<'a> {
    // Handle to the socket inside the stack (shared with any reader/writer split off it).
    io: TcpIo<'a>,
}
64
/// The reader half of a TCP socket.
///
/// Obtained from [`TcpSocket::split`]; it borrows the socket it was split from.
pub struct TcpReader<'a> {
    // Copy of the parent socket's handle; operations go to the same underlying socket.
    io: TcpIo<'a>,
}
69
/// The writer half of a TCP socket.
///
/// Obtained from [`TcpSocket::split`]; it borrows the socket it was split from.
pub struct TcpWriter<'a> {
    // Copy of the parent socket's handle; operations go to the same underlying socket.
    io: TcpIo<'a>,
}
74
75impl<'a> TcpReader<'a> {
76    /// Wait until the socket becomes readable.
77    ///
78    /// A socket becomes readable when the receive half of the full-duplex connection is open
79    /// (see [`may_recv()`](TcpSocket::may_recv)), and there is some pending data in the receive buffer.
80    ///
81    /// This is the equivalent of [read](#method.read), without buffering any data.
82    pub fn wait_read_ready(&self) -> impl Future<Output = ()> + '_ {
83        poll_fn(move |cx| self.io.poll_read_ready(cx))
84    }
85
86    /// Read data from the socket.
87    ///
88    /// Returns how many bytes were read, or an error. If no data is available, it waits
89    /// until there is at least one byte available.
90    ///
91    /// # Note
92    /// A return value of Ok(0) means that we have read all data and the remote
93    /// side has closed our receive half of the socket. The remote can no longer
94    /// send bytes.
95    ///
96    /// The send half of the socket is still open. If you want to reconnect using
97    /// the socket you split this reader off the send half needs to be closed using
98    /// [`abort()`](TcpSocket::abort).
99    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
100        self.io.read(buf).await
101    }
102
103    /// Call `f` with the largest contiguous slice of octets in the receive buffer,
104    /// and dequeue the amount of elements returned by `f`.
105    ///
106    /// If no data is available, it waits until there is at least one byte available.
107    pub async fn read_with<F, R>(&mut self, f: F) -> Result<R, Error>
108    where
109        F: FnOnce(&mut [u8]) -> (usize, R),
110    {
111        self.io.read_with(f).await
112    }
113
114    /// Return the maximum number of bytes inside the transmit buffer.
115    pub fn recv_capacity(&self) -> usize {
116        self.io.recv_capacity()
117    }
118
119    /// Return the amount of octets queued in the receive buffer. This value can be larger than
120    /// the slice read by the next `recv` or `peek` call because it includes all queued octets,
121    /// and not only the octets that may be returned as a contiguous slice.
122    pub fn recv_queue(&self) -> usize {
123        self.io.recv_queue()
124    }
125}
126
127impl<'a> TcpWriter<'a> {
128    /// Wait until the socket becomes writable.
129    ///
130    /// A socket becomes writable when the transmit half of the full-duplex connection is open
131    /// (see [`may_send()`](TcpSocket::may_send)), and the transmit buffer is not full.
132    ///
133    /// This is the equivalent of [write](#method.write), without sending any data.
134    pub fn wait_write_ready(&self) -> impl Future<Output = ()> + '_ {
135        poll_fn(move |cx| self.io.poll_write_ready(cx))
136    }
137
138    /// Write data to the socket.
139    ///
140    /// Returns how many bytes were written, or an error. If the socket is not ready to
141    /// accept data, it waits until it is.
142    pub fn write<'s>(&'s mut self, buf: &'s [u8]) -> impl Future<Output = Result<usize, Error>> + 's {
143        self.io.write(buf)
144    }
145
146    /// Flushes the written data to the socket.
147    ///
148    /// This waits until all data has been sent, and ACKed by the remote host. For a connection
149    /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent.
150    pub fn flush(&mut self) -> impl Future<Output = Result<(), Error>> + '_ {
151        self.io.flush()
152    }
153
154    /// Call `f` with the largest contiguous slice of octets in the transmit buffer,
155    /// and enqueue the amount of elements returned by `f`.
156    ///
157    /// If the socket is not ready to accept data, it waits until it is.
158    pub async fn write_with<F, R>(&mut self, f: F) -> Result<R, Error>
159    where
160        F: FnOnce(&mut [u8]) -> (usize, R),
161    {
162        self.io.write_with(f).await
163    }
164
165    /// Return the maximum number of bytes inside the transmit buffer.
166    pub fn send_capacity(&self) -> usize {
167        self.io.send_capacity()
168    }
169
170    /// Return the amount of octets queued in the transmit buffer.
171    pub fn send_queue(&self) -> usize {
172        self.io.send_queue()
173    }
174}
175
176impl<'a> TcpSocket<'a> {
177    /// Create a new TCP socket on the given stack, with the given buffers.
178    pub fn new(stack: Stack<'a>, rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
179        let handle = stack.with_mut(|i| {
180            let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
181            let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
182            i.sockets.add(tcp::Socket::new(
183                tcp::SocketBuffer::new(rx_buffer),
184                tcp::SocketBuffer::new(tx_buffer),
185            ))
186        });
187
188        Self {
189            io: TcpIo { stack, handle },
190        }
191    }
192
193    /// Return the maximum number of bytes inside the recv buffer.
194    pub fn recv_capacity(&self) -> usize {
195        self.io.recv_capacity()
196    }
197
198    /// Return the maximum number of bytes inside the transmit buffer.
199    pub fn send_capacity(&self) -> usize {
200        self.io.send_capacity()
201    }
202
203    /// Return the amount of octets queued in the transmit buffer.
204    pub fn send_queue(&self) -> usize {
205        self.io.send_queue()
206    }
207
208    /// Return the amount of octets queued in the receive buffer. This value can be larger than
209    /// the slice read by the next `recv` or `peek` call because it includes all queued octets,
210    /// and not only the octets that may be returned as a contiguous slice.
211    pub fn recv_queue(&self) -> usize {
212        self.io.recv_queue()
213    }
214
215    /// Call `f` with the largest contiguous slice of octets in the transmit buffer,
216    /// and enqueue the amount of elements returned by `f`.
217    ///
218    /// If the socket is not ready to accept data, it waits until it is.
219    pub async fn write_with<F, R>(&mut self, f: F) -> Result<R, Error>
220    where
221        F: FnOnce(&mut [u8]) -> (usize, R),
222    {
223        self.io.write_with(f).await
224    }
225
226    /// Call `f` with the largest contiguous slice of octets in the receive buffer,
227    /// and dequeue the amount of elements returned by `f`.
228    ///
229    /// If no data is available, it waits until there is at least one byte available.
230    pub async fn read_with<F, R>(&mut self, f: F) -> Result<R, Error>
231    where
232        F: FnOnce(&mut [u8]) -> (usize, R),
233    {
234        self.io.read_with(f).await
235    }
236
237    /// Split the socket into reader and a writer halves.
238    pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) {
239        (TcpReader { io: self.io }, TcpWriter { io: self.io })
240    }
241
242    /// Connect to a remote host.
243    pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<(), ConnectError>
244    where
245        T: Into<IpEndpoint>,
246    {
247        let local_port = self.io.stack.with_mut(|i| i.get_local_port());
248
249        match {
250            self.io
251                .with_mut(|s, i| s.connect(i.context(), remote_endpoint, local_port))
252        } {
253            Ok(()) => {}
254            Err(tcp::ConnectError::InvalidState) => return Err(ConnectError::InvalidState),
255            Err(tcp::ConnectError::Unaddressable) => return Err(ConnectError::NoRoute),
256        }
257
258        poll_fn(|cx| {
259            self.io.with_mut(|s, _| match s.state() {
260                tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(ConnectError::ConnectionReset)),
261                tcp::State::Listen => unreachable!(),
262                tcp::State::SynSent | tcp::State::SynReceived => {
263                    s.register_send_waker(cx.waker());
264                    Poll::Pending
265                }
266                _ => Poll::Ready(Ok(())),
267            })
268        })
269        .await
270    }
271
272    /// Accept a connection from a remote host.
273    ///
274    /// This function puts the socket in listening mode, and waits until a connection is received.
275    pub async fn accept<T>(&mut self, local_endpoint: T) -> Result<(), AcceptError>
276    where
277        T: Into<IpListenEndpoint>,
278    {
279        match self.io.with_mut(|s, _| s.listen(local_endpoint)) {
280            Ok(()) => {}
281            Err(tcp::ListenError::InvalidState) => return Err(AcceptError::InvalidState),
282            Err(tcp::ListenError::Unaddressable) => return Err(AcceptError::InvalidPort),
283        }
284
285        poll_fn(|cx| {
286            self.io.with_mut(|s, _| match s.state() {
287                tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
288                    s.register_send_waker(cx.waker());
289                    Poll::Pending
290                }
291                _ => Poll::Ready(Ok(())),
292            })
293        })
294        .await
295    }
296
297    /// Wait until the socket becomes readable.
298    ///
299    /// A socket becomes readable when the receive half of the full-duplex connection is open
300    /// (see [may_recv](#method.may_recv)), and there is some pending data in the receive buffer.
301    ///
302    /// This is the equivalent of [read](#method.read), without buffering any data.
303    pub fn wait_read_ready(&self) -> impl Future<Output = ()> + '_ {
304        poll_fn(move |cx| self.io.poll_read_ready(cx))
305    }
306
307    /// Read data from the socket.
308    ///
309    /// Returns how many bytes were read, or an error. If no data is available, it waits
310    /// until there is at least one byte available.
311    ///
312    /// A return value of Ok(0) means that the socket was closed and is longer
313    /// able to receive any data.
314    pub fn read<'s>(&'s mut self, buf: &'s mut [u8]) -> impl Future<Output = Result<usize, Error>> + 's {
315        self.io.read(buf)
316    }
317
318    /// Wait until the socket becomes writable.
319    ///
320    /// A socket becomes writable when the transmit half of the full-duplex connection is open
321    /// (see [may_send](#method.may_send)), and the transmit buffer is not full.
322    ///
323    /// This is the equivalent of [write](#method.write), without sending any data.
324    pub fn wait_write_ready(&self) -> impl Future<Output = ()> + '_ {
325        poll_fn(move |cx| self.io.poll_write_ready(cx))
326    }
327
328    /// Write data to the socket.
329    ///
330    /// Returns how many bytes were written, or an error. If the socket is not ready to
331    /// accept data, it waits until it is.
332    pub fn write<'s>(&'s mut self, buf: &'s [u8]) -> impl Future<Output = Result<usize, Error>> + 's {
333        self.io.write(buf)
334    }
335
336    /// Flushes the written data to the socket.
337    ///
338    /// This waits until all data has been sent, and ACKed by the remote host. For a connection
339    /// closed with [`abort()`](TcpSocket::abort) it will wait for the TCP RST packet to be sent.
340    pub fn flush(&mut self) -> impl Future<Output = Result<(), Error>> + '_ {
341        self.io.flush()
342    }
343
344    /// Set the timeout for the socket.
345    ///
346    /// If the timeout is set, the socket will be closed if no data is received for the
347    /// specified duration.
348    ///
349    /// # Note:
350    /// Set a keep alive interval ([`set_keep_alive`] to prevent timeouts when
351    /// the remote could still respond.
352    pub fn set_timeout(&mut self, duration: Option<Duration>) {
353        self.io
354            .with_mut(|s, _| s.set_timeout(duration.map(duration_to_smoltcp)))
355    }
356
357    /// Set the keep-alive interval for the socket.
358    ///
359    /// If the keep-alive interval is set, the socket will send keep-alive packets after
360    /// the specified duration of inactivity.
361    ///
362    /// If not set, the socket will not send keep-alive packets.
363    ///
364    /// By setting a [`timeout`](Self::timeout) larger then the keep alive you
365    /// can detect a remote endpoint that no longer answers.
366    pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
367        self.io
368            .with_mut(|s, _| s.set_keep_alive(interval.map(duration_to_smoltcp)))
369    }
370
371    /// Set the hop limit field in the IP header of sent packets.
372    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
373        self.io.with_mut(|s, _| s.set_hop_limit(hop_limit))
374    }
375
376    /// Get the local endpoint of the socket.
377    ///
378    /// Returns `None` if the socket is not bound (listening) or not connected.
379    pub fn local_endpoint(&self) -> Option<IpEndpoint> {
380        self.io.with(|s, _| s.local_endpoint())
381    }
382
383    /// Get the remote endpoint of the socket.
384    ///
385    /// Returns `None` if the socket is not connected.
386    pub fn remote_endpoint(&self) -> Option<IpEndpoint> {
387        self.io.with(|s, _| s.remote_endpoint())
388    }
389
390    /// Get the state of the socket.
391    pub fn state(&self) -> State {
392        self.io.with(|s, _| s.state())
393    }
394
395    /// Close the write half of the socket.
396    ///
397    /// This closes only the write half of the socket. The read half side remains open, the
398    /// socket can still receive data.
399    ///
400    /// Data that has been written to the socket and not yet sent (or not yet ACKed) will still
401    /// still sent. The last segment of the pending to send data is sent with the FIN flag set.
402    pub fn close(&mut self) {
403        self.io.with_mut(|s, _| s.close())
404    }
405
406    /// Forcibly close the socket.
407    ///
408    /// This instantly closes both the read and write halves of the socket. Any pending data
409    /// that has not been sent will be lost.
410    ///
411    /// Note that the TCP RST packet is not sent immediately - if the `TcpSocket` is dropped too soon
412    /// the remote host may not know the connection has been closed.
413    /// `abort()` callers should wait for a [`flush()`](TcpSocket::flush) call to complete before
414    /// dropping or reusing the socket.
415    pub fn abort(&mut self) {
416        self.io.with_mut(|s, _| s.abort())
417    }
418
419    /// Return whether the transmit half of the full-duplex connection is open.
420    ///
421    /// This function returns true if it's possible to send data and have it arrive
422    /// to the remote endpoint. However, it does not make any guarantees about the state
423    /// of the transmit buffer, and even if it returns true, [write](#method.write) may
424    /// not be able to enqueue any octets.
425    ///
426    /// In terms of the TCP state machine, the socket must be in the `ESTABLISHED` or
427    /// `CLOSE-WAIT` state.
428    pub fn may_send(&self) -> bool {
429        self.io.with(|s, _| s.may_send())
430    }
431
432    /// Check whether the transmit half of the full-duplex connection is open
433    /// (see [may_send](#method.may_send)), and the transmit buffer is not full.
434    pub fn can_send(&self) -> bool {
435        self.io.with(|s, _| s.can_send())
436    }
437
438    /// return whether the receive half of the full-duplex connection is open.
439    /// This function returns true if it’s possible to receive data from the remote endpoint.
440    /// It will return true while there is data in the receive buffer, and if there isn’t,
441    /// as long as the remote endpoint has not closed the connection.
442    pub fn may_recv(&self) -> bool {
443        self.io.with(|s, _| s.may_recv())
444    }
445
446    /// Get whether the socket is ready to receive data, i.e. whether there is some pending data in the receive buffer.
447    pub fn can_recv(&self) -> bool {
448        self.io.with(|s, _| s.can_recv())
449    }
450}
451
452impl<'a> Drop for TcpSocket<'a> {
453    fn drop(&mut self) {
454        self.io.stack.with_mut(|i| i.sockets.remove(self.io.handle));
455    }
456}
457
// Compile-time variance checks, not meant to be called: each function only type-checks if the
// type is covariant in its lifetime, i.e. a value with a longer lifetime `'b` can be used where a
// shorter lifetime `'a` is expected.
fn _assert_covariant<'a, 'b: 'a>(x: TcpSocket<'b>) -> TcpSocket<'a> {
    x
}
fn _assert_covariant_reader<'a, 'b: 'a>(x: TcpReader<'b>) -> TcpReader<'a> {
    x
}
fn _assert_covariant_writer<'a, 'b: 'a>(x: TcpWriter<'b>) -> TcpWriter<'a> {
    x
}
467
468// =======================
469
/// Lightweight handle to a TCP socket stored in a stack's socket set.
///
/// It is `Copy`, which is what lets `TcpSocket::split` give the same handle to both a
/// reader and a writer.
#[derive(Copy, Clone)]
struct TcpIo<'a> {
    // The stack whose socket set holds the socket.
    stack: Stack<'a>,
    // Key of the socket inside that socket set.
    handle: SocketHandle,
}
475
476impl<'d> TcpIo<'d> {
477    fn with<R>(&self, f: impl FnOnce(&tcp::Socket, &Interface) -> R) -> R {
478        self.stack.with(|i| {
479            let socket = i.sockets.get::<tcp::Socket>(self.handle);
480            f(socket, &i.iface)
481        })
482    }
483
484    fn with_mut<R>(&self, f: impl FnOnce(&mut tcp::Socket, &mut Interface) -> R) -> R {
485        self.stack.with_mut(|i| {
486            let socket = i.sockets.get_mut::<tcp::Socket>(self.handle);
487            let res = f(socket, &mut i.iface);
488            i.waker.wake();
489            res
490        })
491    }
492
493    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
494        self.with_mut(|s, _| {
495            if s.can_recv() {
496                Poll::Ready(())
497            } else {
498                s.register_recv_waker(cx.waker());
499                Poll::Pending
500            }
501        })
502    }
503
504    fn read<'s>(&'s mut self, buf: &'s mut [u8]) -> impl Future<Output = Result<usize, Error>> + 's {
505        poll_fn(|cx| {
506            // CAUTION: smoltcp semantics around EOF are different to what you'd expect
507            // from posix-like IO, so we have to tweak things here.
508            self.with_mut(|s, _| match s.recv_slice(buf) {
509                // Reading into empty buffer
510                Ok(0) if buf.is_empty() => {
511                    // embedded_io_async::Read's contract is to not block if buf is empty. While
512                    // this function is not a direct implementor of the trait method, we still don't
513                    // want our future to never resolve.
514                    Poll::Ready(Ok(0))
515                }
516                // No data ready
517                Ok(0) => {
518                    s.register_recv_waker(cx.waker());
519                    Poll::Pending
520                }
521                // Data ready!
522                Ok(n) => Poll::Ready(Ok(n)),
523                // EOF
524                Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
525                // Connection reset. TODO: this can also be timeouts etc, investigate.
526                Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
527            })
528        })
529    }
530
531    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
532        self.with_mut(|s, _| {
533            if s.can_send() {
534                Poll::Ready(())
535            } else {
536                s.register_send_waker(cx.waker());
537                Poll::Pending
538            }
539        })
540    }
541
542    fn write<'s>(&'s mut self, buf: &'s [u8]) -> impl Future<Output = Result<usize, Error>> + 's {
543        poll_fn(|cx| {
544            self.with_mut(|s, _| match s.send_slice(buf) {
545                // Not ready to send (no space in the tx buffer)
546                Ok(0) => {
547                    s.register_send_waker(cx.waker());
548                    Poll::Pending
549                }
550                // Some data sent
551                Ok(n) => Poll::Ready(Ok(n)),
552                // Connection reset. TODO: this can also be timeouts etc, investigate.
553                Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
554            })
555        })
556    }
557
558    async fn write_with<F, R>(&mut self, f: F) -> Result<R, Error>
559    where
560        F: FnOnce(&mut [u8]) -> (usize, R),
561    {
562        let mut f = Some(f);
563        poll_fn(move |cx| {
564            self.with_mut(|s, _| {
565                if !s.can_send() {
566                    if s.may_send() {
567                        // socket buffer is full wait until it has atleast one byte free
568                        s.register_send_waker(cx.waker());
569                        Poll::Pending
570                    } else {
571                        // if we can't transmit because the transmit half of the duplex connection is closed then return an error
572                        Poll::Ready(Err(Error::ConnectionReset))
573                    }
574                } else {
575                    Poll::Ready(match s.send(unwrap!(f.take())) {
576                        // Connection reset. TODO: this can also be timeouts etc, investigate.
577                        Err(tcp::SendError::InvalidState) => Err(Error::ConnectionReset),
578                        Ok(r) => Ok(r),
579                    })
580                }
581            })
582        })
583        .await
584    }
585
586    async fn read_with<F, R>(&mut self, f: F) -> Result<R, Error>
587    where
588        F: FnOnce(&mut [u8]) -> (usize, R),
589    {
590        let mut f = Some(f);
591        poll_fn(move |cx| {
592            self.with_mut(|s, _| {
593                if !s.can_recv() {
594                    if s.may_recv() {
595                        // socket buffer is empty wait until it has atleast one byte has arrived
596                        s.register_recv_waker(cx.waker());
597                        Poll::Pending
598                    } else {
599                        // if we can't receive because the receive half of the duplex connection is closed then return an error
600                        Poll::Ready(Err(Error::ConnectionReset))
601                    }
602                } else {
603                    Poll::Ready(match s.recv(unwrap!(f.take())) {
604                        // Connection reset. TODO: this can also be timeouts etc, investigate.
605                        Err(tcp::RecvError::Finished) | Err(tcp::RecvError::InvalidState) => {
606                            Err(Error::ConnectionReset)
607                        }
608                        Ok(r) => Ok(r),
609                    })
610                }
611            })
612        })
613        .await
614    }
615
616    fn flush(&mut self) -> impl Future<Output = Result<(), Error>> + '_ {
617        poll_fn(|cx| {
618            self.with_mut(|s, _| {
619                let data_pending = (s.send_queue() > 0) && s.state() != tcp::State::Closed;
620                let fin_pending = matches!(
621                    s.state(),
622                    tcp::State::FinWait1 | tcp::State::Closing | tcp::State::LastAck
623                );
624                let rst_pending = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
625
626                // If there are outstanding send operations, register for wake up and wait
627                // smoltcp issues wake-ups when octets are dequeued from the send buffer
628                if data_pending || fin_pending || rst_pending {
629                    s.register_send_waker(cx.waker());
630                    Poll::Pending
631                // No outstanding sends, socket is flushed
632                } else {
633                    Poll::Ready(Ok(()))
634                }
635            })
636        })
637    }
638
639    fn recv_capacity(&self) -> usize {
640        self.with(|s, _| s.recv_capacity())
641    }
642
643    fn send_capacity(&self) -> usize {
644        self.with(|s, _| s.send_capacity())
645    }
646
647    fn send_queue(&self) -> usize {
648        self.with(|s, _| s.send_queue())
649    }
650
651    fn recv_queue(&self) -> usize {
652        self.with(|s, _| s.recv_queue())
653    }
654}
655
656mod embedded_io_impls {
657    use super::*;
658
659    impl embedded_io_async::Error for ConnectError {
660        fn kind(&self) -> embedded_io_async::ErrorKind {
661            match self {
662                ConnectError::ConnectionReset => embedded_io_async::ErrorKind::ConnectionReset,
663                ConnectError::TimedOut => embedded_io_async::ErrorKind::TimedOut,
664                ConnectError::NoRoute => embedded_io_async::ErrorKind::NotConnected,
665                ConnectError::InvalidState => embedded_io_async::ErrorKind::Other,
666            }
667        }
668    }
669
670    impl embedded_io_async::Error for Error {
671        fn kind(&self) -> embedded_io_async::ErrorKind {
672            match self {
673                Error::ConnectionReset => embedded_io_async::ErrorKind::ConnectionReset,
674            }
675        }
676    }
677
678    impl<'d> embedded_io_async::ErrorType for TcpSocket<'d> {
679        type Error = Error;
680    }
681
682    impl<'d> embedded_io_async::Read for TcpSocket<'d> {
683        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
684            self.io.read(buf).await
685        }
686    }
687
688    impl<'d> embedded_io_async::ReadReady for TcpSocket<'d> {
689        fn read_ready(&mut self) -> Result<bool, Self::Error> {
690            Ok(self.io.with(|s, _| s.can_recv() || !s.may_recv()))
691        }
692    }
693
694    impl<'d> embedded_io_async::Write for TcpSocket<'d> {
695        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
696            self.io.write(buf).await
697        }
698
699        async fn flush(&mut self) -> Result<(), Self::Error> {
700            self.io.flush().await
701        }
702    }
703
704    impl<'d> embedded_io_async::WriteReady for TcpSocket<'d> {
705        fn write_ready(&mut self) -> Result<bool, Self::Error> {
706            Ok(self.io.with(|s, _| s.can_send()))
707        }
708    }
709
710    impl<'d> embedded_io_async::ErrorType for TcpReader<'d> {
711        type Error = Error;
712    }
713
714    impl<'d> embedded_io_async::Read for TcpReader<'d> {
715        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
716            self.io.read(buf).await
717        }
718    }
719
720    impl<'d> embedded_io_async::ReadReady for TcpReader<'d> {
721        fn read_ready(&mut self) -> Result<bool, Self::Error> {
722            Ok(self.io.with(|s, _| s.can_recv() || !s.may_recv()))
723        }
724    }
725
726    impl<'d> embedded_io_async::ErrorType for TcpWriter<'d> {
727        type Error = Error;
728    }
729
730    impl<'d> embedded_io_async::Write for TcpWriter<'d> {
731        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
732            self.io.write(buf).await
733        }
734
735        async fn flush(&mut self) -> Result<(), Self::Error> {
736            self.io.flush().await
737        }
738    }
739
740    impl<'d> embedded_io_async::WriteReady for TcpWriter<'d> {
741        fn write_ready(&mut self) -> Result<bool, Self::Error> {
742            Ok(self.io.with(|s, _| s.can_send()))
743        }
744    }
745}
746
747/// TCP client compatible with `embedded-nal-async` traits.
748pub mod client {
749    use core::cell::{Cell, UnsafeCell};
750    use core::mem::MaybeUninit;
751    use core::net::IpAddr;
752    use core::ptr::NonNull;
753
754    use super::*;
755
    /// TCP client connection pool compatible with `embedded-nal-async` traits.
    ///
    /// The pool is capable of managing up to N concurrent connections with tx and rx buffers according to TX_SZ and RX_SZ.
    pub struct TcpClient<'d, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> {
        // Stack on which connection sockets are created.
        stack: Stack<'d>,
        // Shared state holding the buffer pool that connections allocate from.
        state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
        // Timeout applied via `TcpSocket::set_timeout` to every new connection.
        socket_timeout: Option<Duration>,
    }
764
765    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, N, TX_SZ, RX_SZ> {
766        /// Create a new `TcpClient`.
767        pub fn new(stack: Stack<'d>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self {
768            Self {
769                stack,
770                state,
771                socket_timeout: None,
772            }
773        }
774
775        /// Set the timeout for each socket created by this `TcpClient`.
776        ///
777        /// If the timeout is set, the socket will be closed if no data is received for the
778        /// specified duration.
779        pub fn set_timeout(&mut self, timeout: Option<Duration>) {
780            self.socket_timeout = timeout;
781        }
782    }
783
784    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_nal_async::TcpConnect
785        for TcpClient<'d, N, TX_SZ, RX_SZ>
786    {
787        type Error = Error;
788        type Connection<'m>
789            = TcpConnection<'m, N, TX_SZ, RX_SZ>
790        where
791            Self: 'm;
792
793        async fn connect<'a>(&'a self, remote: core::net::SocketAddr) -> Result<Self::Connection<'a>, Self::Error> {
794            let addr: crate::IpAddress = match remote.ip() {
795                #[cfg(feature = "proto-ipv4")]
796                IpAddr::V4(addr) => crate::IpAddress::Ipv4(addr),
797                #[cfg(not(feature = "proto-ipv4"))]
798                IpAddr::V4(_) => panic!("ipv4 support not enabled"),
799                #[cfg(feature = "proto-ipv6")]
800                IpAddr::V6(addr) => crate::IpAddress::Ipv6(addr),
801                #[cfg(not(feature = "proto-ipv6"))]
802                IpAddr::V6(_) => panic!("ipv6 support not enabled"),
803            };
804            let remote_endpoint = (addr, remote.port());
805            let mut socket = TcpConnection::new(self.stack, self.state)?;
806            socket.socket.set_timeout(self.socket_timeout);
807            socket
808                .socket
809                .connect(remote_endpoint)
810                .await
811                .map_err(|_| Error::ConnectionReset)?;
812            Ok(socket)
813        }
814    }
815
    /// Opened TCP connection in a [`TcpClient`].
    ///
    /// Owns one buffer pair taken from the client's [`TcpClientState`] pool. The pair is handed
    /// back to that pool when the connection is dropped.
    pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> {
        // Socket whose RX/TX buffers live in the pooled slot that `bufs` points at.
        socket: TcpSocket<'d>,
        // Owner of the pool `bufs` was allocated from; needed in `Drop` to release it.
        state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
        // Pool slot holding the (TX, RX) byte buffers that back `socket`.
        bufs: NonNull<([u8; TX_SZ], [u8; RX_SZ])>,
    }
822
823    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> {
824        fn new(stack: Stack<'d>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Result<Self, Error> {
825            let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?;
826            Ok(Self {
827                socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().1, &mut bufs.as_mut().0) },
828                state,
829                bufs,
830            })
831        }
832    }
833
834    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> Drop for TcpConnection<'d, N, TX_SZ, RX_SZ> {
835        fn drop(&mut self) {
836            unsafe {
837                self.socket.close();
838                self.state.pool.free(self.bufs);
839            }
840        }
841    }
842
    /// Connection I/O reports the same [`Error`] type that the wrapped socket's operations return.
    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io_async::ErrorType
        for TcpConnection<'d, N, TX_SZ, RX_SZ>
    {
        type Error = Error;
    }
848
    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io_async::Read
        for TcpConnection<'d, N, TX_SZ, RX_SZ>
    {
        /// Reads received bytes into `buf` by delegating to the wrapped socket's `read`.
        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
            self.socket.read(buf).await
        }
    }
856
    impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> embedded_io_async::Write
        for TcpConnection<'d, N, TX_SZ, RX_SZ>
    {
        /// Queues bytes from `buf` for sending by delegating to the wrapped socket's `write`.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.socket.write(buf).await
        }

        /// Delegates to the wrapped socket's `flush`.
        async fn flush(&mut self) -> Result<(), Self::Error> {
            self.socket.flush().await
        }
    }
868
    /// State for [`TcpClient`].
    ///
    /// Holds `N` fixed-size buffer pairs, each made of `TX_SZ` bytes of TX buffer and `RX_SZ`
    /// bytes of RX buffer. At most `N` connections opened through a `TcpClient` that uses this
    /// state can be alive at the same time.
    pub struct TcpClientState<const N: usize, const TX_SZ: usize, const RX_SZ: usize> {
        // Each slot is a `(tx, rx)` buffer pair handed out to one `TcpConnection`.
        pool: Pool<([u8; TX_SZ], [u8; RX_SZ]), N>,
    }

    impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> {
        /// Create a new `TcpClientState`.
        ///
        /// Every buffer pair starts out free. Being `const`, this can initialize a `static`.
        pub const fn new() -> Self {
            Self { pool: Pool::new() }
        }
    }
880
881    struct Pool<T, const N: usize> {
882        used: [Cell<bool>; N],
883        data: [UnsafeCell<MaybeUninit<T>>; N],
884    }
885
886    impl<T, const N: usize> Pool<T, N> {
887        const VALUE: Cell<bool> = Cell::new(false);
888        const UNINIT: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit());
889
890        const fn new() -> Self {
891            Self {
892                used: [Self::VALUE; N],
893                data: [Self::UNINIT; N],
894            }
895        }
896    }
897
898    impl<T, const N: usize> Pool<T, N> {
899        fn alloc(&self) -> Option<NonNull<T>> {
900            for n in 0..N {
901                // this can't race because Pool is not Sync.
902                if !self.used[n].get() {
903                    self.used[n].set(true);
904                    let p = self.data[n].get() as *mut T;
905                    return Some(unsafe { NonNull::new_unchecked(p) });
906                }
907            }
908            None
909        }
910
911        /// safety: p must be a pointer obtained from self.alloc that hasn't been freed yet.
912        unsafe fn free(&self, p: NonNull<T>) {
913            let origin = self.data.as_ptr() as *mut T;
914            let n = p.as_ptr().offset_from(origin);
915            assert!(n >= 0);
916            assert!((n as usize) < N);
917            self.used[n as usize].set(false);
918        }
919    }
920}