1use std::{
2 collections::HashMap,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 Arc,
8 },
9 task::{Context, Poll, Waker},
10};
11
12use futures::Stream;
13use smoltcp::{
14 iface::{Config as InterfaceConfig, Interface, SocketHandle, SocketSet},
15 phy::Device,
16 socket::tcp::{Socket as TcpSocket, SocketBuffer as TcpSocketBuffer, State as TcpState},
17 storage::RingBuffer,
18 time::{Duration, Instant},
19 wire::{HardwareAddress, IpAddress, IpCidr, IpProtocol, Ipv4Address, Ipv6Address, TcpPacket},
20};
21use spin::Mutex as SpinMutex;
22use tokio::{
23 io::{AsyncRead, AsyncWrite, ReadBuf},
24 sync::{
25 mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender},
26 Notify,
27 },
28};
29use tracing::{error, trace};
30
31use crate::{
32 device::VirtualDevice,
33 packet::{AnyIpPktFrame, IpPacket},
34 Runner,
35};
36
/// Size, in bytes, of every send-side buffer allocated per connection
/// (both the smoltcp socket's transmit buffer and the stream's send ring).
const DEFAULT_TCP_SEND_BUFFER_SIZE: u32 = 20 * 0x3FFF;
/// Size, in bytes, of every receive-side buffer allocated per connection
/// (both the smoltcp socket's receive buffer and the stream's receive ring).
const DEFAULT_TCP_RECV_BUFFER_SIZE: u32 = 20 * 0x3FFF;
40
/// State of one direction (read half or write half) of a [`TcpStream`],
/// shared between the stream and the runner through [`TcpSocketControl`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TcpSocketState {
    /// The half is open and usable.
    Normal,
    /// Closing has been requested (by `poll_shutdown` or by dropping the
    /// stream) but the runner has not acted on it yet.
    Close,
    /// The runner has called `close()` on the smoltcp socket and is waiting
    /// for the socket to reach its closed state.
    Closing,
    /// The half is closed; reads report EOF and writes fail.
    Closed,
}
48
49struct TcpSocketControl {
50 send_buffer: RingBuffer<'static, u8>,
51 send_waker: Option<Waker>,
52 recv_buffer: RingBuffer<'static, u8>,
53 recv_waker: Option<Waker>,
54 recv_state: TcpSocketState,
55 send_state: TcpSocketState,
56}
57
58struct TcpSocketCreation {
59 control: SharedControl,
60 socket: TcpSocket<'static>,
61}
62
63type SharedNotify = Arc<Notify>;
64type SharedControl = Arc<SpinMutex<TcpSocketControl>>;
65
/// Stateless namespace for the background tasks that drive the TCP stack.
struct TcpListenerRunner;
67
68impl TcpListenerRunner {
69 fn create(
70 device: VirtualDevice,
71 iface: Interface,
72 iface_ingress_tx: UnboundedSender<Vec<u8>>,
73 iface_ingress_tx_avail: Arc<AtomicBool>,
74 tcp_rx: Receiver<AnyIpPktFrame>,
75 stream_tx: UnboundedSender<TcpStream>,
76 sockets: HashMap<SocketHandle, SharedControl>,
77 ) -> Runner {
78 Runner::new(async move {
79 let notify = Arc::new(Notify::new());
80 let (socket_tx, socket_rx) = unbounded_channel::<TcpSocketCreation>();
81 let res = tokio::select! {
82 v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v,
83 v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v,
84 };
85 res?;
86 trace!("VirtDevice::poll thread exited");
87 Ok(())
88 })
89 }
90
91 async fn handle_packet(
92 notify: SharedNotify,
93 iface_ingress_tx: UnboundedSender<Vec<u8>>,
94 iface_ingress_tx_avail: Arc<AtomicBool>,
95 mut tcp_rx: Receiver<AnyIpPktFrame>,
96 stream_tx: UnboundedSender<TcpStream>,
97 socket_tx: UnboundedSender<TcpSocketCreation>,
98 ) -> std::io::Result<()> {
99 while let Some(frame) = tcp_rx.recv().await {
100 let packet = match IpPacket::new_checked(frame.as_slice()) {
101 Ok(p) => p,
102 Err(err) => {
103 error!("invalid TCP IP packet: {:?}", err,);
104 continue;
105 }
106 };
107
108 if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) {
110 iface_ingress_tx
111 .send(frame)
112 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
113 iface_ingress_tx_avail.store(true, Ordering::Release);
114 notify.notify_one();
115 continue;
116 }
117
118 let src_ip = packet.src_addr();
119 let dst_ip = packet.dst_addr();
120 let payload = packet.payload();
121
122 let packet = match TcpPacket::new_checked(payload) {
123 Ok(p) => p,
124 Err(err) => {
125 error!("invalid TCP err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
126 continue;
127 }
128 };
129 let src_port = packet.src_port();
130 let dst_port = packet.dst_port();
131
132 let src_addr = SocketAddr::new(src_ip, src_port);
133 let dst_addr = SocketAddr::new(dst_ip, dst_port);
134
135 if packet.syn() && !packet.ack() {
137 let mut socket = TcpSocket::new(
138 TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
139 TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
140 );
141 socket.set_keep_alive(Some(Duration::from_secs(28)));
142 socket.set_timeout(Some(Duration::from_secs(7200)));
144 if let Err(err) = socket.listen(dst_addr) {
148 error!("listen error: {:?}", err);
149 continue;
150 }
151
152 trace!("created TCP connection for {} <-> {}", src_addr, dst_addr);
153
154 let control = Arc::new(SpinMutex::new(TcpSocketControl {
155 send_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
156 send_waker: None,
157 recv_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
158 recv_waker: None,
159 recv_state: TcpSocketState::Normal,
160 send_state: TcpSocketState::Normal,
161 }));
162
163 stream_tx
164 .send(TcpStream {
165 src_addr,
166 dst_addr,
167 notify: notify.clone(),
168 control: control.clone(),
169 })
170 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
171 socket_tx
172 .send(TcpSocketCreation { control, socket })
173 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
174 }
175
176 iface_ingress_tx
178 .send(frame)
179 .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
180 iface_ingress_tx_avail.store(true, Ordering::Release);
181 notify.notify_one();
182 }
183 Ok(())
184 }
185
186 async fn handle_socket(
187 notify: SharedNotify,
188 mut device: VirtualDevice,
189 mut iface: Interface,
190 iface_ingress_tx_avail: Arc<AtomicBool>,
191 mut sockets: HashMap<SocketHandle, SharedControl>,
192 mut socket_rx: UnboundedReceiver<TcpSocketCreation>,
193 ) -> std::io::Result<()> {
194 let mut socket_set = SocketSet::new(vec![]);
195 loop {
196 while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() {
197 let handle = socket_set.add(socket);
198 sockets.insert(handle, control);
199 }
200
201 let before_poll = Instant::now();
202 let updated_sockets = iface.poll(before_poll, &mut device, &mut socket_set);
203 if matches!(
204 updated_sockets,
205 smoltcp::iface::PollResult::SocketStateChanged
206 ) {
207 trace!("VirtDevice::poll costed {}", Instant::now() - before_poll);
208 }
209
210 let mut sockets_to_remove = Vec::new();
212
213 for (socket_handle, control) in sockets.iter() {
214 let socket_handle = *socket_handle;
215 let socket = socket_set.get_mut::<TcpSocket>(socket_handle);
216 let mut control = control.lock();
217
218 if socket.state() == TcpState::Closed {
220 sockets_to_remove.push(socket_handle);
221
222 control.send_state = TcpSocketState::Closed;
223 control.recv_state = TcpSocketState::Closed;
224
225 if let Some(waker) = control.send_waker.take() {
226 waker.wake();
227 }
228 if let Some(waker) = control.recv_waker.take() {
229 waker.wake();
230 }
231
232 trace!("closed TCP connection");
233 continue;
234 }
235
236 if matches!(control.send_state, TcpSocketState::Close) {
238 trace!("closing TCP Write Half, {:?}", socket.state());
239
240 socket.close();
242 control.send_state = TcpSocketState::Closing;
243
244 }
246
247 let mut wake_receiver = false;
249 while socket.can_recv() && !control.recv_buffer.is_full() {
250 let result = socket.recv(|buffer| {
251 let n = control.recv_buffer.enqueue_slice(buffer);
252 (n, ())
253 });
254
255 match result {
256 Ok(..) => wake_receiver = true,
257 Err(err) => {
258 error!("socket recv error: {:?}, {:?}", err, socket.state());
259
260 socket.abort();
262
263 if matches!(control.recv_state, TcpSocketState::Normal) {
264 control.recv_state = TcpSocketState::Closed;
265 }
266 wake_receiver = true;
267
268 break;
270 }
271 }
272 }
273
274 let states = [
277 TcpState::Listen,
278 TcpState::SynReceived,
279 TcpState::Established,
280 TcpState::FinWait1,
281 TcpState::FinWait2,
282 ];
283 if matches!(control.recv_state, TcpSocketState::Normal)
284 && !socket.may_recv()
285 && !states.contains(&socket.state())
286 {
287 trace!("closed TCP Read Half, {:?}", socket.state());
288
289 control.recv_state = TcpSocketState::Closed;
291 wake_receiver = true;
292 }
293
294 if wake_receiver && control.recv_waker.is_some() {
295 if let Some(waker) = control.recv_waker.take() {
296 waker.wake();
297 }
298 }
299
300 let mut wake_sender = false;
302 while socket.can_send() && !control.send_buffer.is_empty() {
303 let result = socket.send(|buffer| {
304 let n = control.send_buffer.dequeue_slice(buffer);
305 (n, ())
306 });
307
308 match result {
309 Ok(..) => wake_sender = true,
310 Err(err) => {
311 error!("socket send error: {:?}, {:?}", err, socket.state());
312
313 socket.abort();
315
316 if matches!(control.send_state, TcpSocketState::Normal) {
317 control.send_state = TcpSocketState::Closed;
318 }
319 wake_sender = true;
320
321 break;
323 }
324 }
325 }
326
327 if wake_sender && control.send_waker.is_some() {
328 if let Some(waker) = control.send_waker.take() {
329 waker.wake();
330 }
331 }
332 }
333
334 for socket_handle in sockets_to_remove {
335 sockets.remove(&socket_handle);
336 socket_set.remove(socket_handle);
337 }
338
339 if !iface_ingress_tx_avail.load(Ordering::Acquire) {
340 let next_duration = iface
341 .poll_delay(before_poll, &socket_set)
342 .unwrap_or(Duration::from_millis(5));
343 if next_duration != Duration::ZERO {
344 let _ = tokio::time::timeout(
345 tokio::time::Duration::from(next_duration),
346 notify.notified(),
347 )
348 .await;
349 }
350 }
351 }
352 }
353}
354
355pub struct TcpListener {
356 stream_rx: UnboundedReceiver<TcpStream>,
357}
358
359impl TcpListener {
360 pub(super) fn new(
361 tcp_rx: Receiver<AnyIpPktFrame>,
362 stack_tx: Sender<AnyIpPktFrame>,
363 ) -> std::io::Result<(Runner, Self)> {
364 let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx);
365 let iface = Self::create_interface(&mut device)?;
366
367 let (stream_tx, stream_rx) = unbounded_channel();
368
369 let runner = TcpListenerRunner::create(
370 device,
371 iface,
372 iface_ingress_tx,
373 iface_ingress_tx_avail,
374 tcp_rx,
375 stream_tx,
376 HashMap::new(),
377 );
378
379 Ok((runner, Self { stream_rx }))
380 }
381
382 fn create_interface<D>(device: &mut D) -> std::io::Result<Interface>
383 where
384 D: Device + ?Sized,
385 {
386 let mut iface_config = InterfaceConfig::new(HardwareAddress::Ip);
387 iface_config.random_seed = rand::random();
388 let mut iface = Interface::new(iface_config, device, Instant::now());
389 iface.update_ip_addrs(|ip_addrs| {
390 ip_addrs
391 .push(IpCidr::new(IpAddress::v4(0, 0, 0, 1), 0))
392 .expect("iface IPv4");
393 ip_addrs
394 .push(IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 0))
395 .expect("iface IPv6");
396 });
397 iface
398 .routes_mut()
399 .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1))
400 .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
401 iface
402 .routes_mut()
403 .add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1))
404 .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
405 iface.set_any_ip(true);
406 Ok(iface)
407 }
408}
409
410impl Stream for TcpListener {
411 type Item = (TcpStream, SocketAddr, SocketAddr);
412
413 fn poll_next(
414 mut self: std::pin::Pin<&mut Self>,
415 cx: &mut std::task::Context<'_>,
416 ) -> std::task::Poll<Option<Self::Item>> {
417 self.stream_rx.poll_recv(cx).map(|stream| {
418 stream.map(|stream| {
419 let local_addr = *stream.local_addr();
420 let remote_addr: SocketAddr = *stream.remote_addr();
421 (stream, local_addr, remote_addr)
422 })
423 })
424 }
425}
426
427pub struct TcpStream {
428 src_addr: SocketAddr,
429 dst_addr: SocketAddr,
430 notify: SharedNotify,
431 control: SharedControl,
432}
433
434impl Drop for TcpStream {
435 fn drop(&mut self) {
436 let mut control = self.control.lock();
437
438 if matches!(control.recv_state, TcpSocketState::Normal) {
439 control.recv_state = TcpSocketState::Close;
440 }
441
442 if matches!(control.send_state, TcpSocketState::Normal) {
443 control.send_state = TcpSocketState::Close;
444 }
445
446 self.notify.notify_one();
447 }
448}
449
450impl TcpStream {
451 pub fn local_addr(&self) -> &SocketAddr {
452 &self.src_addr
453 }
454
455 pub fn remote_addr(&self) -> &SocketAddr {
456 &self.dst_addr
457 }
458}
459
460impl AsyncRead for TcpStream {
461 fn poll_read(
462 self: Pin<&mut Self>,
463 cx: &mut Context<'_>,
464 buf: &mut ReadBuf<'_>,
465 ) -> Poll<std::io::Result<()>> {
466 let mut control = self.control.lock();
467
468 if control.recv_buffer.is_empty() {
470 if matches!(control.recv_state, TcpSocketState::Closed) {
472 return Ok(()).into();
473 }
474
475 if let Some(old_waker) = control.recv_waker.replace(cx.waker().clone()) {
477 if !old_waker.will_wake(cx.waker()) {
478 old_waker.wake();
479 }
480 }
481
482 return Poll::Pending;
483 }
484
485 let recv_buf = unsafe {
486 std::mem::transmute::<&mut [std::mem::MaybeUninit<u8>], &mut [u8]>(buf.unfilled_mut())
487 };
488 let n = control.recv_buffer.dequeue_slice(recv_buf);
489 buf.advance(n);
490
491 if n > 0 {
492 self.notify.notify_one();
493 }
494
495 Ok(()).into()
496 }
497}
498
499impl AsyncWrite for TcpStream {
500 fn poll_write(
501 self: Pin<&mut Self>,
502 cx: &mut Context<'_>,
503 buf: &[u8],
504 ) -> Poll<std::io::Result<usize>> {
505 let mut control = self.control.lock();
506
507 if !matches!(control.send_state, TcpSocketState::Normal) {
509 return Err(std::io::ErrorKind::BrokenPipe.into()).into();
510 }
511
512 if control.send_buffer.is_full() {
515 if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
516 if !old_waker.will_wake(cx.waker()) {
517 old_waker.wake();
518 }
519 }
520
521 return Poll::Pending;
522 }
523
524 let n = control.send_buffer.enqueue_slice(buf);
525
526 if n > 0 {
527 self.notify.notify_one();
528 }
529
530 Ok(n).into()
531 }
532
533 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
534 Ok(()).into()
535 }
536
537 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
538 let mut control = self.control.lock();
539
540 if matches!(control.send_state, TcpSocketState::Closed) {
541 return Ok(()).into();
542 }
543
544 if matches!(control.send_state, TcpSocketState::Normal) {
546 control.send_state = TcpSocketState::Close;
547 }
548
549 if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
550 if !old_waker.will_wake(cx.waker()) {
551 old_waker.wake();
552 }
553 }
554
555 self.notify.notify_one();
556
557 Poll::Pending
558 }
559}