1use core::{
2 net::SocketAddr,
3 sync::atomic::{AtomicBool, Ordering},
4};
5
6use ax_errno::{AxError, AxResult, ax_err, ax_err_type};
7use ax_io::PollState;
8use ax_sync::Mutex;
9use smoltcp::{
10 iface::SocketHandle,
11 socket::udp::{self, BindError, SendError},
12 wire::{IpEndpoint, IpListenEndpoint},
13};
14use spin::RwLock;
15
16use super::{SOCKET_SET, SocketSetWrapper, addr::UNSPECIFIED_ENDPOINT};
17
/// A UDP socket backed by a smoltcp `udp::Socket` stored in the global
/// `SOCKET_SET`.
pub struct UdpSocket {
    /// Handle of the underlying smoltcp socket inside `SOCKET_SET`.
    handle: SocketHandle,
    /// Address set by a successful `bind()`; `None` while unbound.
    local_addr: RwLock<Option<IpEndpoint>>,
    /// Default peer set by `connect()`; `None` while unconnected.
    peer_addr: RwLock<Option<IpEndpoint>>,
    /// `true` when send/recv operations must not block.
    nonblock: AtomicBool,
}
25
26impl UdpSocket {
27 #[allow(clippy::new_without_default)]
29 pub fn new() -> Self {
30 let socket = SocketSetWrapper::new_udp_socket();
31 let handle = SOCKET_SET.add(socket);
32 Self {
33 handle,
34 local_addr: RwLock::new(None),
35 peer_addr: RwLock::new(None),
36 nonblock: AtomicBool::new(false),
37 }
38 }
39
40 pub fn local_addr(&self) -> AxResult<SocketAddr> {
43 match self.local_addr.try_read() {
44 Some(addr) => addr.map(Into::into).ok_or(AxError::NotConnected),
45 None => Err(AxError::NotConnected),
46 }
47 }
48
49 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
52 self.remote_endpoint().map(Into::into)
53 }
54
55 #[inline]
57 pub fn is_nonblocking(&self) -> bool {
58 self.nonblock.load(Ordering::Acquire)
59 }
60
61 #[inline]
70 pub fn set_nonblocking(&self, nonblocking: bool) {
71 self.nonblock.store(nonblocking, Ordering::Release);
72 }
73
74 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
79 let mut self_local_addr = self.local_addr.write();
80
81 if local_addr.port() == 0 {
82 local_addr.set_port(get_ephemeral_port()?);
83 }
84 if self_local_addr.is_some() {
85 return ax_err!(InvalidInput, "socket bind() failed: already bound");
86 }
87
88 let local_endpoint = IpEndpoint::from(local_addr);
89 let endpoint = IpListenEndpoint {
90 addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr),
91 port: local_endpoint.port,
92 };
93 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
94 socket.bind(endpoint).or_else(|e| match e {
95 BindError::InvalidState => ax_err!(AlreadyExists, "socket bind() failed"),
96 BindError::Unaddressable => ax_err!(InvalidInput, "socket bind() failed"),
97 })
98 })?;
99
100 *self_local_addr = Some(local_endpoint);
101 debug!("UDP socket {}: bound on {}", self.handle, endpoint);
102 Ok(())
103 }
104
105 pub fn send_to(&self, buf: &[u8], remote_addr: SocketAddr) -> AxResult<usize> {
108 if remote_addr.port() == 0 || remote_addr.ip().is_unspecified() {
109 return ax_err!(InvalidInput, "socket send_to() failed: invalid address");
110 }
111 self.send_impl(buf, IpEndpoint::from(remote_addr))
112 }
113
114 pub fn recv_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
117 self.recv_impl(|socket| match socket.recv_slice(buf) {
118 Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
119 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
120 })
121 }
122
123 pub fn peek_from(&self, buf: &mut [u8]) -> AxResult<(usize, SocketAddr)> {
126 self.recv_impl(|socket| match socket.peek_slice(buf) {
127 Ok((len, meta)) => Ok((len, SocketAddr::from(meta.endpoint))),
128 Err(_) => ax_err!(BadState, "socket recv_from() failed"),
129 })
130 }
131
132 pub fn connect(&self, addr: SocketAddr) -> AxResult {
140 let mut self_peer_addr = self.peer_addr.write();
141
142 if self.local_addr.read().is_none() {
143 self.bind(SocketAddr::from(UNSPECIFIED_ENDPOINT))?;
144 }
145
146 *self_peer_addr = Some(IpEndpoint::from(addr));
147 debug!("UDP socket {}: connected to {}", self.handle, addr);
148 Ok(())
149 }
150
151 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
153 let remote_endpoint = self.remote_endpoint()?;
154 self.send_impl(buf, remote_endpoint)
155 }
156
157 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
160 let remote_endpoint = self.remote_endpoint()?;
161 self.recv_impl(|socket| {
162 let (len, meta) = socket
163 .recv_slice(buf)
164 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
165 if !remote_endpoint.addr.is_unspecified() && remote_endpoint.addr != meta.endpoint.addr
166 {
167 return Err(AxError::WouldBlock);
168 }
169 if remote_endpoint.port != 0 && remote_endpoint.port != meta.endpoint.port {
170 return Err(AxError::WouldBlock);
171 }
172 Ok(len)
173 })
174 }
175
176 pub fn shutdown(&self) -> AxResult {
178 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
179 debug!("UDP socket {}: shutting down", self.handle);
180 socket.close();
181 });
182 SOCKET_SET.poll_interfaces();
183 Ok(())
184 }
185
186 pub fn poll(&self) -> AxResult<PollState> {
188 if self.local_addr.read().is_none() {
189 return Ok(PollState {
190 readable: false,
191 writable: false,
192 });
193 }
194 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
195 Ok(PollState {
196 readable: socket.can_recv(),
197 writable: socket.can_send(),
198 })
199 })
200 }
201}
202
203impl UdpSocket {
205 fn remote_endpoint(&self) -> AxResult<IpEndpoint> {
206 match self.peer_addr.try_read() {
207 Some(addr) => addr.ok_or(AxError::NotConnected),
208 None => Err(AxError::NotConnected),
209 }
210 }
211
212 fn send_impl(&self, buf: &[u8], remote_endpoint: IpEndpoint) -> AxResult<usize> {
213 if self.local_addr.read().is_none() {
214 return ax_err!(NotConnected, "socket send() failed");
215 }
216
217 self.block_on(|| {
218 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
219 if socket.can_send() {
220 socket
221 .send_slice(buf, remote_endpoint)
222 .map_err(|e| match e {
223 SendError::BufferFull => AxError::WouldBlock,
224 SendError::Unaddressable => {
225 ax_err_type!(ConnectionRefused, "socket send() failed")
226 }
227 })?;
228 Ok(buf.len())
229 } else {
230 Err(AxError::WouldBlock)
232 }
233 })
234 })
235 }
236
237 fn recv_impl<F, T>(&self, mut op: F) -> AxResult<T>
238 where
239 F: FnMut(&mut udp::Socket) -> AxResult<T>,
240 {
241 if self.local_addr.read().is_none() {
242 return ax_err!(NotConnected, "socket send() failed");
243 }
244
245 self.block_on(|| {
246 SOCKET_SET.with_socket_mut::<udp::Socket, _, _>(self.handle, |socket| {
247 if socket.can_recv() {
248 op(socket)
250 } else {
251 Err(AxError::WouldBlock)
253 }
254 })
255 })
256 }
257
258 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
259 where
260 F: FnMut() -> AxResult<T>,
261 {
262 if self.is_nonblocking() {
263 f()
264 } else {
265 loop {
266 SOCKET_SET.poll_interfaces();
267 match f() {
268 Ok(t) => return Ok(t),
269 Err(AxError::WouldBlock) => ax_task::yield_now(),
270 Err(e) => return Err(e),
271 }
272 }
273 }
274 }
275}
276
277impl Drop for UdpSocket {
278 fn drop(&mut self) {
279 self.shutdown().ok();
280 SOCKET_SET.remove(self.handle);
281 }
282}
283
284fn get_ephemeral_port() -> AxResult<u16> {
285 const PORT_START: u16 = 0xc000;
286 const PORT_END: u16 = 0xffff;
287 static CURR: Mutex<u16> = Mutex::new(PORT_START);
288 let mut curr = CURR.lock();
289
290 let port = *curr;
291 if *curr == PORT_END {
292 *curr = PORT_START;
293 } else {
294 *curr += 1;
295 }
296 Ok(port)
297}