microsandbox_network/backend.rs
1//! `SmoltcpBackend` — libkrun [`NetBackend`] implementation that bridges the
2//! NetWorker thread to the smoltcp poll thread via lock-free queues.
3//!
4//! The NetWorker calls [`write_frame()`](NetBackend::write_frame) when the
5//! guest sends a frame and [`read_frame()`](NetBackend::read_frame) to deliver
6//! frames back to the guest. Frames flow through [`SharedState`]'s
7//! `tx_ring`/`rx_ring` queues with [`WakePipe`](crate::shared::WakePipe)
8//! notifications. Unix libkrun registers [`raw_socket_fd`](NetBackend::raw_socket_fd)
9//! in edge-triggered mode, while Windows libkrun waits on an event source. Reads
10//! must drain the wake primitive before returning.
11
12#[cfg(unix)]
13use std::os::fd::RawFd;
14use std::sync::Arc;
15
16use msb_krun::backends::net::{NetBackend, ReadError, WriteError};
17#[cfg(windows)]
18use msb_krun_utils::event::{EventSource, EventToken};
19
20use crate::shared::SharedState;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Length in bytes of the virtio-net header (`virtio_net_hdr_v1`).
///
/// libkrun's NetWorker places this header in front of every frame buffer, so
/// the backend removes it on TX (guest → smoltcp) and writes a zero-filled
/// one on RX (smoltcp → guest).
///
/// Field widths per the virtio spec: `flags` (1) + `gso_type` (1) +
/// `hdr_len` (2) + `gso_size` (2) + `csum_start` (2) + `csum_offset` (2) +
/// `num_buffers` (2) = 12 bytes.
const VIRTIO_NET_HDR_LEN: usize = 1 + 1 + 2 + 2 + 2 + 2 + 2;
32
33//--------------------------------------------------------------------------------------------------
34// Types
35//--------------------------------------------------------------------------------------------------
36
37/// Network backend that bridges libkrun's NetWorker to smoltcp via lock-free
38/// queues.
39///
40/// - **TX path** (`write_frame`): strips the virtio-net header, pushes the
41/// ethernet frame to `tx_ring`, wakes the smoltcp poll thread.
42/// - **RX path** (`read_frame`): pops a frame from `rx_ring`, prepends a
43/// zeroed virtio-net header for the guest.
44/// - **Wake source**: returns `rx_wake`'s pollable fd on Unix or waitable
45/// event handle on Windows so the NetWorker can detect new frames.
46pub struct SmoltcpBackend {
47 shared: Arc<SharedState>,
48}
49
50//--------------------------------------------------------------------------------------------------
51// Methods
52//--------------------------------------------------------------------------------------------------
53
54impl SmoltcpBackend {
55 /// Create a new backend connected to the given shared state.
56 pub fn new(shared: Arc<SharedState>) -> Self {
57 Self { shared }
58 }
59}
60
61//--------------------------------------------------------------------------------------------------
62// Trait Implementations
63//--------------------------------------------------------------------------------------------------
64
65impl NetBackend for SmoltcpBackend {
66 /// Guest is sending a frame. Strip the virtio-net header and enqueue
67 /// the raw ethernet frame for smoltcp.
68 fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> {
69 let ethernet_frame = buf[hdr_len..].to_vec();
70 self.shared.add_tx_bytes(ethernet_frame.len());
71 self.shared
72 .tx_ring
73 .push(ethernet_frame)
74 .map_err(|_| WriteError::NothingWritten)?;
75 self.shared.tx_wake.wake();
76 Ok(())
77 }
78
79 /// Deliver a frame from smoltcp to the guest. Prepends a zeroed
80 /// virtio-net header.
81 fn read_frame(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
82 self.shared.rx_wake.drain();
83
84 let frame = self.shared.rx_ring.pop().ok_or(ReadError::NothingRead)?;
85
86 let total_len = VIRTIO_NET_HDR_LEN + frame.len();
87 if total_len > buf.len() {
88 // Frame too large for the buffer — drop it to avoid panicking.
89 tracing::debug!(
90 frame_len = frame.len(),
91 buf_len = buf.len(),
92 "dropping oversized frame from rx_ring"
93 );
94 return Err(ReadError::NothingRead);
95 }
96
97 // Prepend zeroed virtio-net header.
98 buf[..VIRTIO_NET_HDR_LEN].fill(0);
99 buf[VIRTIO_NET_HDR_LEN..total_len].copy_from_slice(&frame);
100
101 Ok(total_len)
102 }
103
104 /// No partial writes — queue push is atomic.
105 fn has_unfinished_write(&self) -> bool {
106 false
107 }
108
109 /// No partial writes — nothing to finish.
110 fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> {
111 Ok(())
112 }
113
114 /// File descriptor for NetWorker's epoll. Becomes readable when
115 /// `rx_ring` has frames for the guest (i.e. when smoltcp's
116 /// `SmoltcpDevice::transmit()` pushes a frame and wakes `rx_wake`).
117 #[cfg(unix)]
118 fn raw_socket_fd(&self) -> RawFd {
119 self.shared.rx_wake.as_raw_fd()
120 }
121
122 /// Waitable event source for NetWorker on Windows.
123 #[cfg(windows)]
124 fn event_source(&self, token: EventToken) -> EventSource {
125 EventSource::waitable_handle(self.shared.rx_wake.as_raw_handle(), token)
126 }
127}
128
129//--------------------------------------------------------------------------------------------------
130// Tests
131//--------------------------------------------------------------------------------------------------
132
#[cfg(all(test, unix))]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Non-blocking readiness probe: `true` when a zero-timeout `poll(2)`
    /// reports `POLLIN` on `fd`.
    fn is_readable(fd: RawFd) -> bool {
        let mut probe = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };

        // SAFETY: `probe` is a single initialised pollfd that outlives the
        // call, `nfds` is 1 to match, and `fd` is a live file descriptor.
        let ready = unsafe { libc::poll(&mut probe, 1, 0) };
        if ready < 0 {
            panic!("poll failed: {}", std::io::Error::last_os_error());
        }

        ready == 1 && (probe.revents & libc::POLLIN) != 0
    }

    #[test]
    fn read_frame_drains_rx_wake_pipe() {
        let shared = Arc::new(SharedState::new(4));
        let mut backend = SmoltcpBackend::new(Arc::clone(&shared));
        let mut buf = [0u8; 64];
        let wake_fd = backend.raw_socket_fd();

        // Queuing a frame makes the wake fd readable.
        assert!(shared.push_rx_frame_and_wake(vec![0xaa, 0xbb]));
        assert!(is_readable(wake_fd));

        // Reading the frame consumes the pending wake along with the payload.
        let n = backend.read_frame(&mut buf).expect("frame should be read");
        assert_eq!(n, VIRTIO_NET_HDR_LEN + 2);
        assert_eq!(&buf[VIRTIO_NET_HDR_LEN..n], &[0xaa, 0xbb]);
        assert!(!is_readable(wake_fd));

        // Draining must not prevent later wake-ups from being observed.
        assert!(shared.push_rx_frame_and_wake(vec![0xcc]));
        assert!(is_readable(wake_fd));
    }
}