Skip to main content

ax_net/smoltcp_impl/
udp.rs

1use core::{
2    net::SocketAddr,
3    sync::atomic::{AtomicBool, Ordering},
4};
5
6use ax_errno::{AxError, AxResult, ax_err, ax_err_type};
7use ax_io::PollState;
8use ax_sync::Mutex;
9use smoltcp::{
10    iface::SocketHandle,
11    socket::udp::{self, BindError, SendError},
12    wire::{IpEndpoint, IpListenEndpoint},
13};
14use spin::RwLock;
15
16use super::{SOCKET_SET, SocketSetWrapper, addr::UNSPECIFIED_ENDPOINT};
17
18/// A UDP socket that provides POSIX-like APIs.
19pub struct UdpSocket {
20    handle: SocketHandle,
21    local_addr: RwLock<Option<IpEndpoint>>,
22    peer_addr: RwLock<Option<IpEndpoint>>,
23    nonblock: AtomicBool,
24}
25
26impl UdpSocket {
27    /// Creates a new UDP socket.
28    #[allow(clippy::new_without_default)]
29    pub fn new() -> Self {
30        let socket = SocketSetWrapper::new_udp_socket();
31        let handle = SOCKET_SET.add(socket);
32        Self {
33            handle,
34            local_addr: RwLock::new(None),
35            peer_addr: RwLock::new(None),
36            nonblock: AtomicBool::new(false),
37        }
38    }
39
40    /// Returns the local address and port, or
41    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
42    pub fn local_addr(&self) -> AxResult<SocketAddr> {
43        match self.local_addr.try_read() {
44            Some(addr) => addr.map(Into::into).ok_or(AxError::NotConnected),
45            None => Err(AxError::NotConnected),
46        }
47    }
48
49    /// Returns the remote address and port, or
50    /// [`Err(NotConnected)`](AxError::NotConnected) if not connected.
51    pub fn peer_addr(&self) -> AxResult<SocketAddr> {
52        self.remote_endpoint().map(Into::into)
53    }
54
55    /// Returns whether this socket is in nonblocking mode.
56    #[inline]
57    pub fn is_nonblocking(&self) -> bool {
58        self.nonblock.load(Ordering::Acquire)
59    }
60
61    /// Moves this UDP socket into or out of nonblocking mode.
62    ///
63    /// This will result in `recv`, `recv_from`, `send`, and `send_to`
64    /// operations becoming nonblocking, i.e., immediately returning from their
65    /// calls. If the IO operation is successful, `Ok` is returned and no
66    /// further action is required. If the IO operation could not be completed
67    /// and needs to be retried, an error with kind
68    /// [`Err(WouldBlock)`](AxError::WouldBlock) is returned.
69    #[inline]
70    pub fn set_nonblocking(&self, nonblocking: bool) {
71        self.nonblock.store(nonblocking, Ordering::Release);
72    }
73
74    /// Binds an unbound socket to the given address and port.
75    ///
76    /// It must be called before [`send_to`](Self::send_to) and
77    /// [`recv_from`](Self::recv_from).
78    pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
79        let mut self_local_addr = self.local_addr.write();
80
81        if local_addr.port() == 0 {
82            local_addr.set_port(get_ephemeral_port()?);
83        }
84        if self_local_addr.is_some() {
85            return ax_err!(InvalidInput, "socket bind() failed: already bound");
86        }
87
88        let local_endpoint = IpEndpoint::from(local_addr);
89        let endpoint = IpListenEndpoint {
90            addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr),
91            port: local_endpoint.port,
92        };
93        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
94            socket.bind(endpoint).or_else(|e| match e {
95                BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
96                BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
97            })
98        })?;
99
100        *self_local_addr = Some(local_endpoint);
101        debug!("UDP socket {}: bound on {}", self.handle, endpoint);
102        Ok(())
103    }
104
105    /// Sends data on the socket to the given address. On success, returns the
106    /// number of bytes written.
107    pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
108        if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
109            return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
110        }
111        self.send_impl(buf, IpEndpoint::from(remote_addr))
112    }
113
114    /// Receives a single datagram message on the socket. On success, returns
115    /// the number of bytes read and the origin.
116    pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
117        self.recv_impl(|socket| match socket.recv_slice(buf) {
118            Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
119            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
120        })
121    }
122
123    /// Receives a single datagram message on the socket, without removing it from
124    /// the queue. On success, returns the number of bytes read and the origin.
125    pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
126        self.recv_impl(|socket| match socket.peek_slice(buf) {
127            Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
128            Err(_) => ax_err!(BadState, "socket recv_from() failed"),
129        })
130    }
131
132    /// Connects this UDP socket to a remote address, allowing the `send` and
133    /// `recv` to be used to send data and also applies filters to only receive
134    /// data from the specified address.
135    ///
136    /// The local port will be generated automatically if the socket is not bound.
137    /// It must be called before [`send`](Self::send) and
138    /// [`recv`](Self::recv).
139    pub fn connect(&self, addr: SocketAddr) -> AxResult {
140        let mut self_peer_addr = self.peer_addr.write();
141
142        if self.local_addr.read().is_none() {
143            self.bind(SocketAddr::from(UNSPECIFIED_ENDPOINT))?;
144        }
145
146        *self_peer_addr = Some(IpEndpoint::from(addr));
147        debug!("UDP socket {}: connected to {}", self.handle, addr);
148        Ok(())
149    }
150
151    /// Sends data on the socket to the remote address to which it is connected.
152    pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
153        let remote_endpoint = self.remote_endpoint()?;
154        self.send_impl(buf, remote_endpoint)
155    }
156
157    /// Receives a single datagram message on the socket from the remote address
158    /// to which it is connected. On success, returns the number of bytes read.
159    pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
160        let remote_endpoint = self.remote_endpoint()?;
161        self.recv_impl(|socket| {
162            let (len, meta) = socket
163                .recv_slice(buf)
164                .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
165            if !remote_endpoint.addr.is_unspecified() && remote_endpoint.addr != meta.endpoint.addr
166            {
167                return Err(AxError::WouldBlock);
168            }
169            if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
170                return Err(AxError::WouldBlock);
171            }
172            Ok(len)
173        })
174    }
175
176    /// Close the socket.
177    pub fn shutdown(&self) -> AxResult {
178        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
179            debug!("UDP socket {}: shutting down", self.handle);
180            socket.close();
181        });
182        SOCKET_SET.poll_interfaces();
183        Ok(())
184    }
185
186    /// Whether the socket is readable or writable.
187    pub fn poll(&self) -> AxResult<PollState> {
188        if self.local_addr.read().is_none() {
189            return Ok(PollState {
190                readable: false,
191                writable: false,
192            });
193        }
194        SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
195            Ok(PollState {
196                readable: socket.can_recv(),
197                writable: socket.can_send(),
198            })
199        })
200    }
201}
202
203/// Private methods
204impl UdpSocket {
205    fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
206        match self.peer_addr.try_read() {
207            Some(addr) => addr.ok_or(AxError::NotConnected),
208            None => Err(AxError::NotConnected),
209        }
210    }
211
212    fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
213        if self.local_addr.read().is_none() {
214            return ax_err!(NotConnected, "socket send() failed");
215        }
216
217        self.block_on(|| {
218            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
219                if socket.can_send() {
220                    socket
221                        .send_slice(buf, remote_endpoint)
222                        .map_err(|e| match e {
223                            SendError::BufferFull => AxError::WouldBlock,
224                            SendError::Unaddressable => {
225                                ax_err_type!(ConnectionRefused, "socket send() failed")
226                            }
227                        })?;
228                    Ok(buf.len())
229                } else {
230                    // tx buffer is full
231                    Err(AxError::WouldBlock)
232                }
233            })
234        })
235    }
236
237    fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
238    where
239        F: FnMut(&mut udp::Socket) -> AxResult<T>,
240    {
241        if self.local_addr.read().is_none() {
242            return ax_err!(NotConnected, "socket send() failed");
243        }
244
245        self.block_on(|| {
246            SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
247                if socket.can_recv() {
248                    // data available
249                    op(socket)
250                } else {
251                    // no more data
252                    Err(AxError::WouldBlock)
253                }
254            })
255        })
256    }
257
258    fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
259    where
260        F: FnMut() -> AxResult<T>,
261    {
262        if self.is_nonblocking() {
263            f()
264        } else {
265            loop {
266                SOCKET_SET.poll_interfaces();
267                match f() {
268                    Ok(t) => return Ok(t),
269                    Err(AxError::WouldBlock) => ax_task::yield_now(),
270                    Err(e) => return Err(e),
271                }
272            }
273        }
274    }
275}
276
277impl Drop for UdpSocket {
278    fn drop(&mut self) {
279        self.shutdown().ok();
280        SOCKET_SET.remove(self.handle);
281    }
282}
283
284fn get_ephemeral_port() -> AxResult<u16> {
285    const PORT_START: u16 = 0xc000;
286    const PORT_END: u16 = 0xffff;
287    static CURR: Mutex<u16> = Mutex::new(PORT_START);
288    let mut curr = CURR.lock();
289
290    let port = *curr;
291    if *curr == PORT_END {
292        *curr = PORT_START;
293    } else {
294        *curr += 1;
295    }
296    Ok(port)
297}