1use core::{
2 cell::UnsafeCell,
3 net::SocketAddr,
4 sync::atomic::{AtomicBool, AtomicU8, Ordering},
5};
6
7use ax_errno::{AxError, AxResult, ax_err, ax_err_type};
8use ax_io::PollState;
9use ax_sync::Mutex;
10use smoltcp::{
11 iface::SocketHandle,
12 socket::tcp::{self, ConnectError, State},
13 wire::{IpEndpoint, IpListenEndpoint},
14};
15
16use super::{ETH0, LISTEN_TABLE, SOCKET_SET, SocketSetWrapper, addr::UNSPECIFIED_ENDPOINT};
17
// Values stored in `TcpSocket::state` (an `AtomicU8`).

/// Neither connected nor listening; the state installed by `TcpSocket::new`.
const STATE_CLOSED: u8 = 0;
/// Transient marker installed by `update_state` while its closure is running.
const STATE_BUSY: u8 = 1;
/// `connect()` has started the handshake; `poll_connect` moves the socket on.
const STATE_CONNECTING: u8 = 2;
/// The connection is established (also the state given by `new_connected`).
const STATE_CONNECTED: u8 = 3;
/// `listen()` succeeded for this socket.
const STATE_LISTENING: u8 = 4;
29
/// A TCP socket backed by a smoltcp socket stored in `SOCKET_SET`.
///
/// The socket's lifecycle is tracked by `state`, which holds one of the
/// `STATE_*` constants. The `UnsafeCell` fields are read and written through
/// raw pointers by the methods of this type.
pub struct TcpSocket {
    /// Current lifecycle state (one of the `STATE_*` constants).
    state: AtomicU8,
    /// Handle of the underlying smoltcp socket; `None` until `connect()`
    /// allocates one or `new_connected` supplies one.
    handle: UnsafeCell<Option<SocketHandle>>,
    /// Local endpoint; `UNSPECIFIED_ENDPOINT` while unbound.
    local_addr: UnsafeCell<IpEndpoint>,
    /// Remote endpoint; `UNSPECIFIED_ENDPOINT` while there is no peer.
    peer_addr: UnsafeCell<IpEndpoint>,
    /// Non-blocking flag consulted by `connect()` and `block_on`.
    nonblock: AtomicBool,
    /// Flag stored by `set_reuseaddr` and returned by `is_reuse_addr`.
    reuse_addr: AtomicBool,
}
48
// SAFETY (NOTE(review)): the `UnsafeCell` fields are not synchronized by the
// type system. Most writes happen inside `update_state` closures, i.e. while
// `state` is `STATE_BUSY`; `poll_connect` also writes the address cells while
// the state is `STATE_CONNECTING`. Presumably callers never run those paths
// concurrently with readers — TODO confirm.
unsafe impl Sync for TcpSocket {}
50
51impl Default for TcpSocket {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl TcpSocket {
58 pub const fn new() -> Self {
60 Self {
61 state: AtomicU8::new(STATE_CLOSED),
62 handle: UnsafeCell::new(None),
63 local_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
64 peer_addr: UnsafeCell::new(UNSPECIFIED_ENDPOINT),
65 nonblock: AtomicBool::new(false),
66 reuse_addr: AtomicBool::new(false),
67 }
68 }
69
70 const fn new_connected(
72 handle: SocketHandle,
73 local_addr: IpEndpoint,
74 peer_addr: IpEndpoint,
75 ) -> Self {
76 Self {
77 state: AtomicU8::new(STATE_CONNECTED),
78 handle: UnsafeCell::new(Some(handle)),
79 local_addr: UnsafeCell::new(local_addr),
80 peer_addr: UnsafeCell::new(peer_addr),
81 nonblock: AtomicBool::new(false),
82 reuse_addr: AtomicBool::new(false),
83 }
84 }
85
86 pub fn local_addr(&self) -> AxResult<SocketAddr> {
89 match self.get_state() {
90 STATE_CONNECTED | STATE_LISTENING => {
91 Ok(SocketAddr::from(unsafe { self.local_addr.get().read() }))
92 }
93 _ => Err(AxError::NotConnected),
94 }
95 }
96
97 pub fn peer_addr(&self) -> AxResult<SocketAddr> {
100 match self.get_state() {
101 STATE_CONNECTED | STATE_LISTENING => {
102 Ok(SocketAddr::from(unsafe { self.peer_addr.get().read() }))
103 }
104 _ => Err(AxError::NotConnected),
105 }
106 }
107
108 #[inline]
110 pub fn is_nonblocking(&self) -> bool {
111 self.nonblock.load(Ordering::Acquire)
112 }
113
114 #[inline]
123 pub fn set_nonblocking(&self, nonblocking: bool) {
124 self.nonblock.store(nonblocking, Ordering::Release);
125 }
126
127 #[inline]
129 pub fn is_reuse_addr(&self) -> bool {
130 self.reuse_addr.load(Ordering::Acquire)
131 }
132
133 #[inline]
135 pub fn set_reuseaddr(&self, reuse: bool) {
136 self.reuse_addr.store(reuse, Ordering::Release);
137 }
138
139 pub fn connect(&self, remote_addr: SocketAddr) -> AxResult {
143 self.update_state(STATE_CLOSED, STATE_CONNECTING, || {
144 let handle = unsafe { self.handle.get().read() }
146 .unwrap_or_else(|| SOCKET_SET.add(SocketSetWrapper::new_tcp_socket()));
147
148 let bound_endpoint = self.bound_endpoint()?;
150 let iface = Ð0.iface;
151 let (local_endpoint, remote_endpoint) = SOCKET_SET
152 .with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
153 socket
154 .connect(iface.lock().context(), remote_addr, bound_endpoint)
155 .or_else(|e| match e {
156 ConnectError::InvalidState => {
157 ax_err!(BadState, "socket connect() failed")
158 }
159 ConnectError::Unaddressable => {
160 ax_err!(ConnectionRefused, "socket connect() failed")
161 }
162 })?;
163 AxResult::Ok((
164 socket.local_endpoint().unwrap(),
165 socket.remote_endpoint().unwrap(),
166 ))
167 })?;
168 unsafe {
169 self.local_addr.get().write(local_endpoint);
172 self.peer_addr.get().write(remote_endpoint);
173 self.handle.get().write(Some(handle));
174 }
175 Ok(())
176 })
177 .unwrap_or_else(|_| ax_err!(AlreadyExists, "socket connect() failed: already connected"))?; if self.is_nonblocking() {
181 Err(AxError::WouldBlock)
182 } else {
183 self.block_on(|| {
184 let PollState { writable, .. } = self.poll_connect()?;
185 if !writable {
186 Err(AxError::WouldBlock)
187 } else if self.get_state() == STATE_CONNECTED {
188 Ok(())
189 } else {
190 ax_err!(ConnectionRefused, "socket connect() failed")
191 }
192 })
193 }
194 }
195
196 pub fn bind(&self, mut local_addr: SocketAddr) -> AxResult {
203 self.update_state(STATE_CLOSED, STATE_CLOSED, || {
204 if local_addr.port() == 0 {
206 local_addr.set_port(get_ephemeral_port()?);
207 }
208 unsafe {
211 let old = self.local_addr.get().read();
212 if old != UNSPECIFIED_ENDPOINT {
213 return ax_err!(InvalidInput, "socket bind() failed: already bound");
214 }
215 self.local_addr.get().write(IpEndpoint::from(local_addr));
216 }
217 Ok(())
218 })
219 .unwrap_or_else(|_| ax_err!(InvalidInput, "socket bind() failed: already bound"))
220 }
221
222 pub fn listen(&self) -> AxResult {
227 self.update_state(STATE_CLOSED, STATE_LISTENING, || {
228 let bound_endpoint = self.bound_endpoint()?;
229 unsafe {
230 (*self.local_addr.get()).port = bound_endpoint.port;
231 }
232 LISTEN_TABLE.listen(bound_endpoint)?;
233 debug!("TCP socket listening on {bound_endpoint}");
234 Ok(())
235 })
236 .unwrap_or(Ok(())) }
238
239 pub fn accept(&self) -> AxResult<TcpSocket> {
246 if !self.is_listening() {
247 return ax_err!(InvalidInput, "socket accept() failed: not listen");
248 }
249
250 let local_port = unsafe { self.local_addr.get().read().port };
252 self.block_on(|| {
253 let (handle, (local_addr, peer_addr)) = LISTEN_TABLE.accept(local_port)?;
254 debug!("TCP socket accepted a new connection {peer_addr}");
255 Ok(TcpSocket::new_connected(handle, local_addr, peer_addr))
256 })
257 }
258
259 pub fn shutdown(&self) -> AxResult {
261 self.update_state(STATE_CONNECTED, STATE_CLOSED, || {
263 let handle = unsafe { self.handle.get().read().unwrap() };
266 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
267 debug!("TCP socket {handle}: shutting down");
268 socket.close();
269 });
270 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; SOCKET_SET.poll_interfaces();
272 Ok(())
273 })
274 .unwrap_or(Ok(()))?;
275
276 self.update_state(STATE_LISTENING, STATE_CLOSED, || {
278 let local_port = unsafe { self.local_addr.get().read().port };
281 unsafe { self.local_addr.get().write(UNSPECIFIED_ENDPOINT) }; LISTEN_TABLE.unlisten(local_port);
283 SOCKET_SET.poll_interfaces();
284 Ok(())
285 })
286 .unwrap_or(Ok(()))?;
287
288 Ok(())
290 }
291
292 pub fn recv(&self, buf: &mut [u8]) -> AxResult<usize> {
294 if self.is_connecting() {
295 return Err(AxError::WouldBlock);
296 } else if !self.is_connected() {
297 return ax_err!(NotConnected, "socket recv() failed");
298 }
299
300 let handle = unsafe { self.handle.get().read().unwrap() };
302 self.block_on(|| {
303 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
304 if !socket.is_active() {
305 ax_err!(ConnectionRefused, "socket recv() failed")
307 } else if !socket.may_recv() {
308 Ok(0)
310 } else if socket.recv_queue() > 0 {
311 let len = socket
314 .recv_slice(buf)
315 .map_err(|_| ax_err_type!(BadState, "socket recv() failed"))?;
316 Ok(len)
317 } else {
318 Err(AxError::WouldBlock)
320 }
321 })
322 })
323 }
324
325 pub fn send(&self, buf: &[u8]) -> AxResult<usize> {
327 if self.is_connecting() {
328 return Err(AxError::WouldBlock);
329 } else if !self.is_connected() {
330 return ax_err!(NotConnected, "socket send() failed");
331 }
332
333 let handle = unsafe { self.handle.get().read().unwrap() };
335 self.block_on(|| {
336 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(handle, |socket| {
337 if !socket.is_active() || !socket.may_send() {
338 ax_err!(ConnectionReset, "socket send() failed")
340 } else if socket.can_send() {
341 let len = socket
344 .send_slice(buf)
345 .map_err(|_| ax_err_type!(BadState, "socket send() failed"))?;
346 Ok(len)
347 } else {
348 Err(AxError::WouldBlock)
350 }
351 })
352 })
353 }
354
355 pub fn poll(&self) -> AxResult<PollState> {
357 match self.get_state() {
358 STATE_CONNECTING => self.poll_connect(),
359 STATE_CONNECTED => self.poll_stream(),
360 STATE_LISTENING => self.poll_listener(),
361 _ => Ok(PollState {
362 readable: false,
363 writable: false,
364 }),
365 }
366 }
367
368 pub fn nodelay(&self) -> AxResult<bool> {
370 if let Some(h) = unsafe { self.handle.get().read() } {
371 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.nagle_enabled()))
372 } else {
373 ax_err!(NotConnected, "socket is not connected")
374 }
375 }
376
377 pub fn set_nodelay(&self, enabled: bool) -> AxResult<()> {
379 if let Some(h) = unsafe { self.handle.get().read() } {
380 SOCKET_SET.with_socket_mut::<tcp::Socket, _, _>(h, |socket| {
381 socket.set_nagle_enabled(enabled);
382 });
383 Ok(())
384 } else {
385 ax_err!(NotConnected, "socket is not connected")
386 }
387 }
388
389 pub fn recv_capacity(&self) -> AxResult<usize> {
391 if let Some(h) = unsafe { self.handle.get().read() } {
392 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.recv_capacity()))
393 } else {
394 ax_err!(NotConnected, "socket is not connected")
395 }
396 }
397
398 pub fn send_capacity(&self) -> AxResult<usize> {
400 if let Some(h) = unsafe { self.handle.get().read() } {
401 Ok(SOCKET_SET.with_socket::<tcp::Socket, _, _>(h, |socket| socket.send_capacity()))
402 } else {
403 ax_err!(NotConnected, "socket is not connected")
404 }
405 }
406}
407
408impl TcpSocket {
410 #[inline]
411 fn get_state(&self) -> u8 {
412 self.state.load(Ordering::Acquire)
413 }
414
415 #[inline]
416 fn set_state(&self, state: u8) {
417 self.state.store(state, Ordering::Release);
418 }
419
420 fn update_state<F, T>(&self, expect: u8, new: u8, f: F) -> Result<AxResult<T>, u8>
429 where
430 F: FnOnce() -> AxResult<T>,
431 {
432 match self
433 .state
434 .compare_exchange(expect, STATE_BUSY, Ordering::Acquire, Ordering::Acquire)
435 {
436 Ok(_) => {
437 let res = f();
438 if res.is_ok() {
439 self.set_state(new);
440 } else {
441 self.set_state(expect);
442 }
443 Ok(res)
444 }
445 Err(old) => Err(old),
446 }
447 }
448
449 #[inline]
450 fn is_connecting(&self) -> bool {
451 self.get_state() == STATE_CONNECTING
452 }
453
454 #[inline]
455 fn is_connected(&self) -> bool {
456 self.get_state() == STATE_CONNECTED
457 }
458
459 #[inline]
460 fn is_listening(&self) -> bool {
461 self.get_state() == STATE_LISTENING
462 }
463
464 fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
465 let local_addr = unsafe { self.local_addr.get().read() };
467 let port = if local_addr.port != 0 {
468 local_addr.port
469 } else {
470 get_ephemeral_port()?
471 };
472 assert_ne!(port, 0);
473 let addr = if !local_addr.addr.is_unspecified() {
474 Some(local_addr.addr)
475 } else {
476 None
477 };
478 Ok(IpListenEndpoint { addr, port })
479 }
480
481 fn poll_connect(&self) -> AxResult<PollState> {
482 let handle = unsafe { self.handle.get().read().unwrap() };
484 let writable =
485 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| match socket.state() {
486 State::SynSent => false, State::Established => {
488 self.set_state(STATE_CONNECTED); debug!(
490 "TCP socket {}: connected to {}",
491 handle,
492 socket.remote_endpoint().unwrap(),
493 );
494 true
495 }
496 _ => {
497 unsafe {
498 self.local_addr.get().write(UNSPECIFIED_ENDPOINT);
499 self.peer_addr.get().write(UNSPECIFIED_ENDPOINT);
500 }
501 self.set_state(STATE_CLOSED); true
503 }
504 });
505 Ok(PollState {
506 readable: false,
507 writable,
508 })
509 }
510
511 fn poll_stream(&self) -> AxResult<PollState> {
512 let handle = unsafe { self.handle.get().read().unwrap() };
514 SOCKET_SET.with_socket::<tcp::Socket, _, _>(handle, |socket| {
515 Ok(PollState {
516 readable: !socket.may_recv() || socket.can_recv(),
517 writable: !socket.may_send() || socket.can_send(),
518 })
519 })
520 }
521
522 fn poll_listener(&self) -> AxResult<PollState> {
523 let local_addr = unsafe { self.local_addr.get().read() };
525 Ok(PollState {
526 readable: LISTEN_TABLE.can_accept(local_addr.port)?,
527 writable: false,
528 })
529 }
530
531 fn block_on<F, T>(&self, mut f: F) -> AxResult<T>
537 where
538 F: FnMut() -> AxResult<T>,
539 {
540 if self.is_nonblocking() {
541 f()
542 } else {
543 loop {
544 SOCKET_SET.poll_interfaces();
545 match f() {
546 Ok(t) => return Ok(t),
547 Err(AxError::WouldBlock) => ax_task::yield_now(),
548 Err(e) => return Err(e),
549 }
550 }
551 }
552 }
553}
554
555impl Drop for TcpSocket {
556 fn drop(&mut self) {
557 self.shutdown().ok();
558 if let Some(handle) = unsafe { self.handle.get().read() } {
560 SOCKET_SET.remove(handle);
561 }
562 }
563}
564
565fn get_ephemeral_port() -> AxResult<u16> {
566 const PORT_START: u16 = 0xc000;
567 const PORT_END: u16 = 0xffff;
568 static CURR: Mutex<u16> = Mutex::new(PORT_START);
569
570 let mut curr = CURR.lock();
571 let mut tries = 0;
572 while tries <= PORT_END - PORT_START {
574 let port = *curr;
575 if *curr == PORT_END {
576 *curr = PORT_START;
577 } else {
578 *curr += 1;
579 }
580 if LISTEN_TABLE.can_listen(port) {
581 return Ok(port);
582 }
583 tries += 1;
584 }
585 ax_err!(AddrInUse, "no available ports!")
586}