1use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 Arc,
8 },
9 task::{Context, Poll, Waker},
10};
11
12use futures::Stream;
13use smoltcp::{
14 iface::{Config as InterfaceConfig, Interface, SocketHandle, SocketSet},
15 phy::Device,
16 socket::tcp::{Socket as TcpSocket, SocketBuffer as TcpSocketBuffer, State as TcpState},
17 storage::RingBuffer,
18 time::{Duration, Instant},
19 wire::{HardwareAddress, IpAddress, IpCidr, IpProtocol, Ipv4Address, Ipv6Address, TcpPacket},
20};
21use spin::Mutex as SpinMutex;
22use tokio::{
23 io::{AsyncRead, AsyncWrite, ReadBuf},
24 sync::{
25 mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender},
26 Notify,
27 },
28};
29use tracing::{error, trace};
30
31use crate::{
32 device::VirtualDevice,
33 packet::{AnyIpPktFrame, IpPacket},
34 Runner,
35};
36
/// Size, in bytes, used for the send-side buffers of every accepted connection
/// (0x3FFF * 20 = 327,660): `handle_packet` allocates both a smoltcp socket
/// buffer and the `TcpSocketControl::send_buffer` ring with this length.
const DEFAULT_TCP_SEND_BUFFER_SIZE: u32 = 0x3FFF * 20;
/// Size, in bytes, used for the receive-side buffers of every accepted
/// connection (0x3FFF * 20 = 327,660): `handle_packet` allocates both a smoltcp
/// socket buffer and the `TcpSocketControl::recv_buffer` ring with this length.
const DEFAULT_TCP_RECV_BUFFER_SIZE: u32 = 0x3FFF * 20;
40
/// State of one direction (read half or write half) of a connection, as
/// tracked in `TcpSocketControl` and shared between a `TcpStream` and the
/// socket loop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TcpSocketState {
    /// The half is open; this is the state every new connection starts in.
    Normal,
    /// The stream side asked for the half to be closed (set on drop and by
    /// `poll_shutdown`); the socket loop has not acted on it yet.
    Close,
    /// The socket loop has called `close()` on the smoltcp socket for this
    /// half and is waiting for the socket to finish closing.
    Closing,
    /// The half is finished: the smoltcp socket closed, an I/O error aborted
    /// it, or the read half can no longer receive.
    Closed,
}
48
/// Per-connection state shared (behind a spin mutex) between a `TcpStream`
/// and the socket loop in `TcpListenerRunner::handle_socket`.
struct TcpSocketControl {
    /// Bytes queued by `TcpStream::poll_write`; the socket loop dequeues them
    /// into the smoltcp socket.
    send_buffer: RingBuffer<'static, u8>,
    /// Waker stored by `poll_write` (buffer full) or `poll_shutdown`; taken and
    /// woken by the socket loop.
    send_waker: Option<Waker>,
    /// Bytes the socket loop copied out of the smoltcp socket; drained by
    /// `TcpStream::poll_read`.
    recv_buffer: RingBuffer<'static, u8>,
    /// Waker stored by `poll_read` when no data is buffered; taken and woken
    /// by the socket loop.
    recv_waker: Option<Waker>,
    /// State of the read half.
    recv_state: TcpSocketState,
    /// State of the write half.
    send_state: TcpSocketState,
}
57
/// Message sent from `handle_packet` to `handle_socket` when a new connection
/// is accepted: the freshly created smoltcp socket plus the control block it
/// shares with the matching `TcpStream`.
struct TcpSocketCreation {
    /// Control block also held by the `TcpStream` handed to the listener.
    control: SharedControl,
    /// Socket to be added to the socket loop's `SocketSet`.
    socket: TcpSocket<'static>,
}
62
/// Shared `Notify` handle; `notify_one()` is called on it to cut short the
/// socket loop's timed wait in `handle_socket`.
type SharedNotify = Arc<Notify>;
/// Reference-counted, spin-locked control block of one connection.
type SharedControl = Arc<SpinMutex<TcpSocketControl>>;
65
/// Field-less type that only groups the associated functions which build and
/// drive the listener's background `Runner`.
struct TcpListenerRunner;
67
68impl TcpListenerRunner {
69 fn create(
70 device: VirtualDevice,
71 iface: Interface,
72 iface_ingress_tx: UnboundedSender<Vec<u8>>,
73 iface_ingress_tx_avail: Arc<AtomicBool>,
74 tcp_rx: Receiver<AnyIpPktFrame>,
75 stream_tx: UnboundedSender<TcpStream>,
76 sockets: HashMap<SocketHandle, SharedControl>,
77 ) -> Runner {
78 Runner::new(async move {
79 let notify = Arc::new(Notify::new());
80 let (socket_tx, socket_rx) = unbounded_channel::<TcpSocketCreation>();
81 let res = tokio::select! {
82 v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v,
83 v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v,
84 };
85 res?;
86 trace!("VirtDevice::poll thread exited");
87 Ok(())
88 })
89 }
90
91 async fn handle_packet(
92 notify: SharedNotify,
93 iface_ingress_tx: UnboundedSender<Vec<u8>>,
94 iface_ingress_tx_avail: Arc<AtomicBool>,
95 mut tcp_rx: Receiver<AnyIpPktFrame>,
96 stream_tx: UnboundedSender<TcpStream>,
97 socket_tx: UnboundedSender<TcpSocketCreation>,
98 ) -> std::io::Result<()> {
99 while let Some(frame) = tcp_rx.recv().await {
100 let packet = match IpPacket::new_checked(frame.as_slice()) {
101 Ok(p) => p,
102 Err(err) => {
103 error!("invalid TCP IP packet: {:?}", err,);
104 continue;
105 }
106 };
107
108 if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) {
110 iface_ingress_tx
111 .send(frame)
112 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
113 iface_ingress_tx_avail.store(true, Ordering::Release);
114 notify.notify_one();
115 continue;
116 }
117
118 let src_ip = packet.src_addr();
119 let dst_ip = packet.dst_addr();
120 let payload = packet.payload();
121
122 let packet = match TcpPacket::new_checked(payload) {
123 Ok(p) => p,
124 Err(err) => {
125 error!("invalid TCP err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
126 continue;
127 }
128 };
129 let src_port = packet.src_port();
130 let dst_port = packet.dst_port();
131
132 let src_addr = SocketAddr::new(src_ip, src_port);
133 let dst_addr = SocketAddr::new(dst_ip, dst_port);
134
135 if packet.syn() && !packet.ack() {
137 let mut socket = TcpSocket::new(
138 TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
139 TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
140 );
141 socket.set_keep_alive(Some(Duration::from_secs(28)));
142 socket.set_timeout(Some(Duration::from_secs(7200)));
144 if let Err(err) = socket.listen(dst_addr) {
148 error!("listen error: {:?}", err);
149 continue;
150 }
151
152 trace!("created TCP connection for {} <-> {}", src_addr, dst_addr);
153
154 let control = Arc::new(SpinMutex::new(TcpSocketControl {
155 send_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
156 send_waker: None,
157 recv_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
158 recv_waker: None,
159 recv_state: TcpSocketState::Normal,
160 send_state: TcpSocketState::Normal,
161 }));
162
163 stream_tx
164 .send(TcpStream {
165 src_addr,
166 dst_addr,
167 notify: notify.clone(),
168 control: control.clone(),
169 })
170 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
171 socket_tx
172 .send(TcpSocketCreation { control, socket })
173 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
174 }
175
176 iface_ingress_tx
178 .send(frame)
179 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
180 iface_ingress_tx_avail.store(true, Ordering::Release);
181 notify.notify_one();
182 }
183 Ok(())
184 }
185
186 async fn handle_socket(
187 notify: SharedNotify,
188 mut device: VirtualDevice,
189 mut iface: Interface,
190 iface_ingress_tx_avail: Arc<AtomicBool>,
191 mut sockets: HashMap<SocketHandle, SharedControl>,
192 mut socket_rx: UnboundedReceiver<TcpSocketCreation>,
193 ) -> std::io::Result<()> {
194 let mut socket_set = SocketSet::new(vec![]);
195 loop {
196 while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() {
197 let handle = socket_set.add(socket);
198 sockets.insert(handle, control);
199 }
200
201 let before_poll = Instant::now();
202 let updated_sockets = iface.poll(before_poll, &mut device, &mut socket_set);
203 if matches!(
204 updated_sockets,
205 smoltcp::iface::PollResult::SocketStateChanged
206 ) {
207 trace!("VirtDevice::poll costed {}", Instant::now() - before_poll);
208 }
209
210 let mut sockets_to_remove = Vec::new();
212
213 for (socket_handle, control) in sockets.iter() {
214 let socket_handle = *socket_handle;
215 let socket = socket_set.get_mut::<TcpSocket>(socket_handle);
216 let mut control = control.lock();
217
218 if socket.state() == TcpState::Closed {
220 sockets_to_remove.push(socket_handle);
221
222 control.send_state = TcpSocketState::Closed;
223 control.recv_state = TcpSocketState::Closed;
224
225 if let Some(waker) = control.send_waker.take() {
226 waker.wake();
227 }
228 if let Some(waker) = control.recv_waker.take() {
229 waker.wake();
230 }
231
232 trace!("closed TCP connection");
233 continue;
234 }
235
236 if matches!(control.send_state, TcpSocketState::Close)
241 && control.send_buffer.is_empty()
242 {
243 trace!("closing TCP Write Half, {:?}", socket.state());
244
245 socket.close();
246 control.send_state = TcpSocketState::Closing;
247 }
248
249 let mut wake_receiver = false;
251 while socket.can_recv() && !control.recv_buffer.is_full() {
252 let result = socket.recv(|buffer| {
253 let n = control.recv_buffer.enqueue_slice(buffer);
254 (n, ())
255 });
256
257 match result {
258 Ok(..) => wake_receiver = true,
259 Err(err) => {
260 error!("socket recv error: {:?}, {:?}", err, socket.state());
261
262 socket.abort();
264
265 if matches!(control.recv_state, TcpSocketState::Normal) {
266 control.recv_state = TcpSocketState::Closed;
267 }
268 wake_receiver = true;
269
270 break;
272 }
273 }
274 }
275
276 let states = [
279 TcpState::Listen,
280 TcpState::SynReceived,
281 TcpState::Established,
282 TcpState::FinWait1,
283 TcpState::FinWait2,
284 ];
285 if matches!(control.recv_state, TcpSocketState::Normal)
286 && !socket.may_recv()
287 && !states.contains(&socket.state())
288 {
289 trace!("closed TCP Read Half, {:?}", socket.state());
290
291 control.recv_state = TcpSocketState::Closed;
293 wake_receiver = true;
294 }
295
296 if wake_receiver && control.recv_waker.is_some() {
297 if let Some(waker) = control.recv_waker.take() {
298 waker.wake();
299 }
300 }
301
302 let mut wake_sender = false;
304 while socket.can_send() && !control.send_buffer.is_empty() {
305 let result = socket.send(|buffer| {
306 let n = control.send_buffer.dequeue_slice(buffer);
307 (n, ())
308 });
309
310 match result {
311 Ok(..) => wake_sender = true,
312 Err(err) => {
313 error!("socket send error: {:?}, {:?}", err, socket.state());
314
315 socket.abort();
317
318 if matches!(control.send_state, TcpSocketState::Normal) {
319 control.send_state = TcpSocketState::Closed;
320 }
321 wake_sender = true;
322
323 break;
325 }
326 }
327 }
328
329 if wake_sender && control.send_waker.is_some() {
330 if let Some(waker) = control.send_waker.take() {
331 waker.wake();
332 }
333 }
334 }
335
336 for socket_handle in sockets_to_remove {
337 sockets.remove(&socket_handle);
338 socket_set.remove(socket_handle);
339 }
340
341 if !iface_ingress_tx_avail.load(Ordering::Acquire) {
342 let next_duration = iface
343 .poll_delay(before_poll, &socket_set)
344 .unwrap_or(Duration::from_millis(5));
345 if next_duration != Duration::ZERO {
346 let _ = tokio::time::timeout(
347 tokio::time::Duration::from(next_duration),
348 notify.notified(),
349 )
350 .await;
351 }
352 }
353 }
354 }
355}
356
/// Listener side of the user-space TCP stack; implements `Stream` and yields
/// one item per accepted connection.
pub struct TcpListener {
    /// Receives the `TcpStream`s that the background runner creates.
    stream_rx: UnboundedReceiver<TcpStream>,
}
360
361impl TcpListener {
362 pub(super) fn new(
363 tcp_rx: Receiver<AnyIpPktFrame>,
364 stack_tx: Sender<AnyIpPktFrame>,
365 ) -> std::io::Result<(Runner, Self)> {
366 let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx);
367 let iface = Self::create_interface(&mut device)?;
368
369 let (stream_tx, stream_rx) = unbounded_channel();
370
371 let runner = TcpListenerRunner::create(
372 device,
373 iface,
374 iface_ingress_tx,
375 iface_ingress_tx_avail,
376 tcp_rx,
377 stream_tx,
378 HashMap::new(),
379 );
380
381 Ok((runner, Self { stream_rx }))
382 }
383
384 fn create_interface<D>(device: &mut D) -> std::io::Result<Interface>
385 where
386 D: Device + ?Sized,
387 {
388 let mut iface_config = InterfaceConfig::new(HardwareAddress::Ip);
389 iface_config.random_seed = rand::random();
390 let mut iface = Interface::new(iface_config, device, Instant::now());
391 iface.update_ip_addrs(|ip_addrs| {
392 ip_addrs
393 .push(IpCidr::new(IpAddress::v4(0, 0, 0, 1), 0))
394 .expect("iface IPv4");
395 ip_addrs
396 .push(IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 0))
397 .expect("iface IPv6");
398 });
399 iface
400 .routes_mut()
401 .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1))
402 .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
403 iface
404 .routes_mut()
405 .add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1))
406 .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
407 iface.set_any_ip(true);
408 Ok(iface)
409 }
410}
411
412impl Stream for TcpListener {
413 type Item = (TcpStream, SocketAddr, SocketAddr);
414
415 fn poll_next(
416 mut self: std::pin::Pin<&mut Self>,
417 cx: &mut std::task::Context<'_>,
418 ) -> std::task::Poll<Option<Self::Item>> {
419 self.stream_rx.poll_recv(cx).map(|stream| {
420 stream.map(|stream| {
421 let local_addr = *stream.local_addr();
422 let remote_addr: SocketAddr = *stream.remote_addr();
423 (stream, local_addr, remote_addr)
424 })
425 })
426 }
427}
428
/// One accepted TCP connection; implements `AsyncRead` and `AsyncWrite` on top
/// of the control block it shares with the socket loop.
pub struct TcpStream {
    /// Source address and port of the SYN segment that opened the connection.
    src_addr: SocketAddr,
    /// Destination address and port of that SYN segment.
    dst_addr: SocketAddr,
    /// Used to wake the socket loop after this stream changes shared state.
    notify: SharedNotify,
    /// Buffers, wakers and half-states shared with the socket loop.
    control: SharedControl,
}
435
436impl Drop for TcpStream {
437 fn drop(&mut self) {
438 let mut control = self.control.lock();
439
440 if matches!(control.recv_state, TcpSocketState::Normal) {
441 control.recv_state = TcpSocketState::Close;
442 }
443
444 if matches!(control.send_state, TcpSocketState::Normal) {
445 control.send_state = TcpSocketState::Close;
446 }
447
448 self.notify.notify_one();
449 }
450}
451
452impl TcpStream {
453 pub fn local_addr(&self) -> &SocketAddr {
454 &self.src_addr
455 }
456
457 pub fn remote_addr(&self) -> &SocketAddr {
458 &self.dst_addr
459 }
460}
461
462impl AsyncRead for TcpStream {
463 fn poll_read(
464 self: Pin<&mut Self>,
465 cx: &mut Context<'_>,
466 buf: &mut ReadBuf<'_>,
467 ) -> Poll<std::io::Result<()>> {
468 let mut control = self.control.lock();
469
470 if control.recv_buffer.is_empty() {
472 if matches!(control.recv_state, TcpSocketState::Closed) {
474 return Ok(()).into();
475 }
476
477 if let Some(old_waker) = control.recv_waker.replace(cx.waker().clone()) {
479 if !old_waker.will_wake(cx.waker()) {
480 old_waker.wake();
481 }
482 }
483
484 return Poll::Pending;
485 }
486
487 let recv_buf = buf.initialize_unfilled();
488 let n = control.recv_buffer.dequeue_slice(recv_buf);
489 buf.advance(n);
490
491 if n > 0 {
492 self.notify.notify_one();
493 }
494
495 Ok(()).into()
496 }
497}
498
499impl AsyncWrite for TcpStream {
500 fn poll_write(
501 self: Pin<&mut Self>,
502 cx: &mut Context<'_>,
503 buf: &[u8],
504 ) -> Poll<std::io::Result<usize>> {
505 let mut control = self.control.lock();
506
507 if !matches!(control.send_state, TcpSocketState::Normal) {
509 return Err(std::io::ErrorKind::BrokenPipe.into()).into();
510 }
511
512 if control.send_buffer.is_full() {
515 if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
516 if !old_waker.will_wake(cx.waker()) {
517 old_waker.wake();
518 }
519 }
520
521 return Poll::Pending;
522 }
523
524 let n = control.send_buffer.enqueue_slice(buf);
525
526 if n > 0 {
527 self.notify.notify_one();
528 }
529
530 Ok(n).into()
531 }
532
533 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
534 Ok(()).into()
535 }
536
537 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
538 let mut control = self.control.lock();
539
540 if matches!(control.send_state, TcpSocketState::Closed) {
541 return Ok(()).into();
542 }
543
544 if matches!(control.send_state, TcpSocketState::Normal) {
546 control.send_state = TcpSocketState::Close;
547 }
548
549 if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
550 if !old_waker.will_wake(cx.waker()) {
551 old_waker.wake();
552 }
553 }
554
555 self.notify.notify_one();
556
557 Poll::Pending
558 }
559}