Skip to main content

ax_net/smoltcp_impl/
tcp.rs

1use core::{
2    cell::UnsafeCell,
3    net::SocketAddr,
4    sync::atomic::{AtomicBool, AtomicU8, Ordering},
5};
6
7use ax_errno::{AxError, AxResult, ax_err, ax_err_type};
8use ax_io::PollState;
9use ax_sync::Mutex;
10use smoltcp::{
11    iface::SocketHandle,
12    socket::tcp::{self, ConnectError, State},
13    wire::{IpEndpoint, IpListenEndpoint},
14};
15
16use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper, addr::UNSPECIFIED_ENDPOINT};
17
// Socket state machine (the values stored in `TcpSocket::state`):
//
//   CLOSED -(connect)-> BUSY -> CONNECTING -> CONNECTED -(shutdown)-> BUSY -> CLOSED
//     |                             |
//     |                             +-(handshake fails)-> CLOSED
//     |
//     +-(listen)-> BUSY -> LISTENING -(shutdown)-> BUSY -> CLOSED
//     |
//     +-(bind)-> BUSY -> CLOSED
//
// `BUSY` is a transient state held while one thread mutates the socket's
// non-atomic fields; if that mutation fails, the state falls back to where
// the transition started.

/// No connection (the socket may or may not be bound).
const STATE_CLOSED: u8 = 0;
/// A state transition is in progress.
const STATE_BUSY: u8 = 1;
/// `connect` was issued and the handshake has not finished yet.
const STATE_CONNECTING: u8 = 2;
/// The connection is established.
const STATE_CONNECTED: u8 = 3;
/// The socket accepts incoming connections.
const STATE_LISTENING: u8 = 4;
29
30/// A TCP socket that provides POSIX-like APIs.
31///
32/// - [`connect`] is for TCP clients.
33/// - [`bind`], [`listen`], and [`accept`] are for TCP servers.
34/// - Other methods are for both TCP clients and servers.
35///
36/// [`connect`]: TcpSocket::connect
37/// [`bind`]: TcpSocket::bind
38/// [`listen`]: TcpSocket::listen
39/// [`accept`]: TcpSocket::accept
40pub struct TcpSocket {
41    state: AtomicU8,
42    handle: UnsafeCell<Option<SocketHandle>>,
43    local_addr: UnsafeCell<IpEndpoint>,
44    peer_addr: UnsafeCell<IpEndpoint>,
45    nonblock: AtomicBool,
46    reuse_addr: AtomicBool,
47}
48
49unsafe impl Sync for TcpSocket {}
50
51impl Default for TcpSocket {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl TcpSocket {
58    /// Creates a new TCP socket.
59    pub const fn new() -> Self {
60        Self {
61            state: AtomicU8::new(STATE_CLOSED),
62            handle: UnsafeCell::new(None),
63            local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
64            peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
65            nonblock: AtomicBool::new(false),
66            reuse_addr: AtomicBool::new(false),
67        }
68    }
69
70    /// Creates a new TCP socket that is already connected.
71    const fn new_connected(
72        handle: SocketHandle,
73        local_addr: IpEndpoint,
74        peer_addr: IpEndpoint,
75    ) -> Self {
76        Self {
77            state: AtomicU8::new(STATE_CONNECTED),
78            handle: UnsafeCell::new(Some(handle)),
79            local_addr: UnsafeCell::new(local_addr),
80            peer_addr: UnsafeCell::new(peer_addr),
81            nonblock: AtomicBool::new(false),
82            reuse_addr: AtomicBool::new(false),
83        }
84    }
85
86    /// Returns the local address and port, or
87    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
88    pub fn local_addr(&self) -> AxResult<SocketAddr> {
89        match self.get_state() {
90            STATE_CONNECTED | STATE_LISTENING => {
91                Ok(SocketAddr::from(unsafe { self.local_addr.get().read() }))
92            }
93            _ => Err(AxError::NotConnected),
94        }
95    }
96
97    /// Returns the remote address and port, or
98    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
99    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
100        match self.get_state() {
101            STATE_CONNECTED | STATE_LISTENING => {
102                Ok(SocketAddr::from(unsafe { self.peer_addr.get().read() }))
103            }
104            _ => Err(AxError::NotConnected),
105        }
106    }
107
108    /// Returns whether this socket is in nonblocking mode.
109    #[inline]
110    pub fn is_nonblocking(&self) -> bool {
111        self.nonblock.load(Ordering::Acquire)
112    }
113
114    /// Moves this TCP stream into or out of nonblocking mode.
115    ///
116    /// This will result in `read`, `write`, `recv` and `send` operations
117    /// becoming nonblocking, i.e., immediately returning from their calls.
118    /// If the IO operation is successful, `Ok` is returned and no further
119    /// action is required. If the IO operation could not be completed and needs
120    /// to be retried, an error with kind  [`Err(WouldBlock)`](AxError::WouldBlock) is
121    /// returned.
122    #[inline]
123    pub fn set_nonblocking(&self, nonblocking: bool) {
124        self.nonblock.store(nonblocking, Ordering::Release);
125    }
126
127    /// Returns whether SO_REUSEADDR behavior is enabled.
128    #[inline]
129    pub fn is_reuse_addr(&self) -> bool {
130        self.reuse_addr.load(Ordering::Acquire)
131    }
132
133    /// Enables or disables SO_REUSEADDR behavior.
134    #[inline]
135    pub fn set_reuseaddr(&self, reuse: bool) {
136        self.reuse_addr.store(reuse, Ordering::Release);
137    }
138
139    /// Connects to the given address and port.
140    ///
141    /// The local port is generated automatically.
142    pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
143        self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
144            // SAFETY: no other threads can read or write these fields.
145            let handle = unsafe { self.handle.get().read() }
146                .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
147
148            // TODO: check remote addr unreachable
149            let bound_endpoint = self.bound_endpoint()?;
150            let iface = &ETH0.iface;
151            let (local_endpoint, remote_endpoint) = SOCKET_SET
152                .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
153                    socket
154                        .connect(iface.lock().context(), remote_addr, bound_endpoint)
155                        .or_else(|e| match e {
156                            ConnectError::InvalidState => {
157                                ax_err!(BadState, "socket connect() failed")
158                            }
159                            ConnectError::Unaddressable => {
160                                ax_err!(ConnectionRefused, "socket connect() failed")
161                            }
162                        })?;
163                    AxResult::Ok((
164                        socket.local_endpoint().unwrap(),
165                        socket.remote_endpoint().unwrap(),
166                    ))
167                })?;
168            unsafe {
169                // SAFETY: no other threads can read or write these fields as we
170                // have changed the state to `BUSY`.
171                self.local_addr.get().write(local_endpoint);
172                self.peer_addr.get().write(remote_endpoint);
173                self.handle.get().write(Some(handle));
174            }
175            Ok(())
176        })
177        .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; // EISCONN
178
179        // Here our state must be `CONNECTING`, and only one thread can run here.
180        if self.is_nonblocking() {
181            Err(AxError::WouldBlock)
182        } else {
183            self.block_on(|| {
184                let PollState { writable, .. } = self.poll_connect()?;
185                if !writable {
186                    Err(AxError::WouldBlock)
187                } else if self.get_state() == STATE_CONNECTED {
188                    Ok(())
189                } else {
190                    ax_err!(ConnectionRefused, "socket connect() failed")
191                }
192            })
193        }
194    }
195
196    /// Binds an unbound socket to the given address and port.
197    ///
198    /// If the given port is 0, it generates one automatically.
199    ///
200    /// It must be called before [`listen`](Self::listen) and
201    /// [`accept`](Self::accept).
202    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
203        self.update_state(STATE_CLOSED, STATE_CLOSED, || {
204            // TODO: check addr is available
205            if local_addr.port() == 0 {
206                local_addr.set_port(get_ephemeral_port()?);
207            }
208            // SAFETY: no other threads can read or write `self.local_addr` as we
209            // have changed the state to `BUSY`.
210            unsafe {
211                let old = self.local_addr.get().read();
212                if old != UNSPECIFIED_ENDPOINT {
213                    return ax_err!(InvalidInput, "socket bind() failed: already bound");
214                }
215                self.local_addr.get().write(IpEndpoint::from(local_addr));
216            }
217            Ok(())
218        })
219        .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
220    }
221
222    /// Starts listening on the bound address and port.
223    ///
224    /// It must be called after [`bind`](Self::bind) and before
225    /// [`accept`](Self::accept).
226    pub fn listen(&self) -> AxResult {
227        self.update_state(STATE_CLOSED, STATE_LISTENING, || {
228            let bound_endpoint = self.bound_endpoint()?;
229            unsafe {
230                (*self.local_addr.get()).port = bound_endpoint.port;
231            }
232            LISTEN_TABLE.listen(bound_endpoint)?;
233            debug!("TCP socket listening on {bound_endpoint}");
234            Ok(())
235        })
236        .unwrap_or(Ok(())) // ignore simultaneous `listen`s.
237    }
238
239    /// Accepts a new connection.
240    ///
241    /// This function will block the calling thread until a new TCP connection
242    /// is established. When established, a new [`TcpSocket`] is returned.
243    ///
244    /// It must be called after [`bind`](Self::bind) and [`listen`](Self::listen).
245    pub fn accept(&self) -> AxResult<TcpSocket> {
246        if !self.is_listening() {
247            return ax_err!(InvalidInput, "socket accept() failed: not listen");
248        }
249
250        // SAFETY: `self.local_addr` should be initialized after `bind()`.
251        let local_port = unsafe { self.local_addr.get().read().port };
252        self.block_on(|| {
253            let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
254            debug!("TCP socket accepted a new connection {peer_addr}");
255            Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
256        })
257    }
258
259    /// Close the connection.
260    pub fn shutdown(&self) -> AxResult {
261        // stream
262        self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
263            // SAFETY: `self.handle` should be initialized in a connected socket, and
264            // no other threads can read or write it.
265            let handle = unsafe { self.handle.get().read().unwrap() };
266            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
267                debug!("TCP socket {handle}: shutting down");
268                socket.close();
269            });
270            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
271            SOCKET_SET.poll_interfaces();
272            Ok(())
273        })
274        .unwrap_or(Ok(()))?;
275
276        // listener
277        self.update_state(STATE_LISTENING, STATE_CLOSED, || {
278            // SAFETY: `self.local_addr` should be initialized in a listening socket,
279            // and no other threads can read or write it.
280            let local_port = unsafe { self.local_addr.get().read().port };
281            unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; // clear bound address
282            LISTEN_TABLE.unlisten(local_port);
283            SOCKET_SET.poll_interfaces();
284            Ok(())
285        })
286        .unwrap_or(Ok(()))?;
287
288        // ignore for other states
289        Ok(())
290    }
291
292    /// Receives data from the socket, stores it in the given buffer.
293    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
294        if self.is_connecting() {
295            return Err(AxError::WouldBlock);
296        } else if !self.is_connected() {
297            return ax_err!(NotConnected, "socket recv() failed");
298        }
299
300        // SAFETY: `self.handle` should be initialized in a connected socket.
301        let handle = unsafe { self.handle.get().read().unwrap() };
302        self.block_on(|| {
303            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
304                if !socket.is_active() {
305                    // not open
306                    ax_err!(ConnectionRefused, "socket recv() failed")
307                } else if !socket.may_recv() {
308                    // connection closed
309                    Ok(0)
310                } else if socket.recv_queue() > 0 {
311                    // data available
312                    // TODO: use socket.recv(|buf| {...})
313                    let len = socket
314                        .recv_slice(buf)
315                        .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
316                    Ok(len)
317                } else {
318                    // no more data
319                    Err(AxError::WouldBlock)
320                }
321            })
322        })
323    }
324
325    /// Transmits data in the given buffer.
326    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
327        if self.is_connecting() {
328            return Err(AxError::WouldBlock);
329        } else if !self.is_connected() {
330            return ax_err!(NotConnected, "socket send() failed");
331        }
332
333        // SAFETY: `self.handle` should be initialized in a connected socket.
334        let handle = unsafe { self.handle.get().read().unwrap() };
335        self.block_on(|| {
336            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
337                if !socket.is_active() || !socket.may_send() {
338                    // closed by remote
339                    ax_err!(ConnectionReset, "socket send() failed")
340                } else if socket.can_send() {
341                    // connected, and the tx buffer is not full
342                    // TODO: use socket.send(|buf| {...})
343                    let len = socket
344                        .send_slice(buf)
345                        .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
346                    Ok(len)
347                } else {
348                    // tx buffer is full
349                    Err(AxError::WouldBlock)
350                }
351            })
352        })
353    }
354
355    /// Whether the socket is readable or writable.
356    pub fn poll(&self) -> AxResult<PollState> {
357        match self.get_state() {
358            STATE_CONNECTING => self.poll_connect(),
359            STATE_CONNECTED => self.poll_stream(),
360            STATE_LISTENING => self.poll_listener(),
361            _ => Ok(PollState {
362                readable: false,
363                writable: false,
364            }),
365        }
366    }
367
368    /// Checks if Nagle's algorithm is enabled for this TCP socket.
369    pub fn nodelay(&self) -> AxResult<bool> {
370        if let Some(h) = unsafe { self.handle.get().read() } {
371            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.nagle_enabled()))
372        } else {
373            ax_err!(NotConnected, "socket is not connected")
374        }
375    }
376
377    /// Enables or disables Nagle's algorithm for this TCP socket.
378    pub fn set_nodelay(&self, enabled: bool) -> AxResult<()> {
379        if let Some(h) = unsafe { self.handle.get().read() } {
380            SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(h, |socket| {
381                socket.set_nagle_enabled(enabled);
382            });
383            Ok(())
384        } else {
385            ax_err!(NotConnected, "socket is not connected")
386        }
387    }
388
389    /// Returns the maximum capacity of the receive buffer in bytes.
390    pub fn recv_capacity(&self) -> AxResult<usize> {
391        if let Some(h) = unsafe { self.handle.get().read() } {
392            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.recv_capacity()))
393        } else {
394            ax_err!(NotConnected, "socket is not connected")
395        }
396    }
397
398    /// Returns the maximum capacity of the send buffer in bytes.
399    pub fn send_capacity(&self) -> AxResult<usize> {
400        if let Some(h) = unsafe { self.handle.get().read() } {
401            Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.send_capacity()))
402        } else {
403            ax_err!(NotConnected, "socket is not connected")
404        }
405    }
406}
407
408/// Private methods
409impl TcpSocket {
410    #[inline]
411    fn get_state(&self) -> u8 {
412        self.state.load(Ordering::Acquire)
413    }
414
415    #[inline]
416    fn set_state(&self, state: u8) {
417        self.state.store(state, Ordering::Release);
418    }
419
420    /// Update the state of the socket atomically.
421    ///
422    /// If the current state is `expect`, it first changes the state to `STATE_BUSY`,
423    /// then calls the given function. If the function returns `Ok`, it changes the
424    /// state to `new`, otherwise it changes the state back to `expect`.
425    ///
426    /// It returns `Ok` if the current state is `expect`, otherwise it returns
427    /// the current state in `Err`.
428    fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
429    where
430        F: FnOnce() -> AxResult<T>,
431    {
432        match self
433            .state
434            .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
435        {
436            Ok(_) => {
437                let res = f();
438                if res.is_ok() {
439                    self.set_state(new);
440                } else {
441                    self.set_state(expect);
442                }
443                Ok(res)
444            }
445            Err(old) => Err(old),
446        }
447    }
448
449    #[inline]
450    fn is_connecting(&self) -> bool {
451        self.get_state() == STATE_CONNECTING
452    }
453
454    #[inline]
455    fn is_connected(&self) -> bool {
456        self.get_state() == STATE_CONNECTED
457    }
458
459    #[inline]
460    fn is_listening(&self) -> bool {
461        self.get_state() == STATE_LISTENING
462    }
463
464    fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
465        // SAFETY: no other threads can read or write `self.local_addr`.
466        let local_addr = unsafe { self.local_addr.get().read() };
467        let port = if local_addr.port != 0 {
468            local_addr.port
469        } else {
470            get_ephemeral_port()?
471        };
472        assert_ne!(port, 0);
473        let addr = if !local_addr.addr.is_unspecified() {
474            Some(local_addr.addr)
475        } else {
476            None
477        };
478        Ok(IpListenEndpoint { addr, port })
479    }
480
481    fn poll_connect(&self) -> AxResult<PollState> {
482        // SAFETY: `self.handle` should be initialized above.
483        let handle = unsafe { self.handle.get().read().unwrap() };
484        let writable =
485            SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
486                State::SynSent => false, // wait for connection
487                State::Established => {
488                    self.set_state(STATE_CONNECTED); // connected
489                    debug!(
490                        "TCP socket {}: connected to {}",
491                        handle,
492                        socket.remote_endpoint().unwrap(),
493                    );
494                    true
495                }
496                _ => {
497                    unsafe {
498                        self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
499                        self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
500                    }
501                    self.set_state(STATE_CLOSED); // connection failed
502                    true
503                }
504            });
505        Ok(PollState {
506            readable: false,
507            writable,
508        })
509    }
510
511    fn poll_stream(&self) -> AxResult<PollState> {
512        // SAFETY: `self.handle` should be initialized in a connected socket.
513        let handle = unsafe { self.handle.get().read().unwrap() };
514        SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
515            Ok(PollState {
516                readable: !socket.may_recv() || socket.can_recv(),
517                writable: !socket.may_send() || socket.can_send(),
518            })
519        })
520    }
521
522    fn poll_listener(&self) -> AxResult<PollState> {
523        // SAFETY: `self.local_addr` should be initialized in a listening socket.
524        let local_addr = unsafe { self.local_addr.get().read() };
525        Ok(PollState {
526            readable: LISTEN_TABLE.can_accept(local_addr.port)?,
527            writable: false,
528        })
529    }
530
531    /// Block the current thread until the given function completes or fails.
532    ///
533    /// If the socket is non-blocking, it calls the function once and returns
534    /// immediately. Otherwise, it may call the function multiple times if it
535    /// returns [`Err(WouldBlock)`](AxError::WouldBlock).
536    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
537    where
538        F: FnMut() -> AxResult<T>,
539    {
540        if self.is_nonblocking() {
541            f()
542        } else {
543            loop {
544                SOCKET_SET.poll_interfaces();
545                match f() {
546                    Ok(t) => return Ok(t),
547                    Err(AxError::WouldBlock) => ax_task::yield_now(),
548                    Err(e) => return Err(e),
549                }
550            }
551        }
552    }
553}
554
555impl Drop for TcpSocket {
556    fn drop(&mut self) {
557        self.shutdown().ok();
558        // Safe because we have mut reference to `self`.
559        if let Some(handle) = unsafe { self.handle.get().read() } {
560            SOCKET_SET.remove(handle);
561        }
562    }
563}
564
565fn get_ephemeral_port() -> AxResult<u16> {
566    const PORT_START: u16 = 0xc000;
567    const PORT_END: u16 = 0xffff;
568    static CURR: Mutex<u16> = Mutex::new(PORT_START);
569
570    let mut curr = CURR.lock();
571    let mut tries = 0;
572    // TODO: more robust
573    while tries <= PORT_END - PORT_START {
574        let port = *curr;
575        if *curr == PORT_END {
576            *curr = PORT_START;
577        } else {
578            *curr += 1;
579        }
580        if LISTEN_TABLE.can_listen(port) {
581            return Ok(port);
582        }
583        tries += 1;
584    }
585    ax_err!(AddrInUse, "no available ports!")
586}