Skip to main content

microsandbox_network/
backend.rs

//! `SmoltcpBackend` — libkrun [`NetBackend`] implementation that bridges the
//! NetWorker thread to the smoltcp poll thread via lock-free queues.
//!
//! The NetWorker calls [`write_frame()`](NetBackend::write_frame) when the
//! guest sends a frame and [`read_frame()`](NetBackend::read_frame) to deliver
//! frames back to the guest. Frames flow through [`SharedState`]'s
//! `tx_ring`/`rx_ring` queues with [`WakePipe`](crate::shared::WakePipe)
//! notifications. On Unix, libkrun registers [`raw_socket_fd`](NetBackend::raw_socket_fd)
//! in edge-triggered mode; on Windows, libkrun waits on an event source. Reads
//! must therefore drain the wake primitive before returning.
11
12#[cfg(unix)]
13use std::os::fd::RawFd;
14use std::sync::Arc;
15
16use msb_krun::backends::net::{NetBackend, ReadError, WriteError};
17#[cfg(windows)]
18use msb_krun_utils::event::{EventSource, EventToken};
19
20use crate::shared::SharedState;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Length in bytes of the virtio-net header (`virtio_net_hdr_v1`) that
/// libkrun's NetWorker places in front of every frame buffer.
///
/// On RX (smoltcp → guest) the backend writes a zeroed header of this length
/// ahead of each frame. On TX (guest → smoltcp) the NetWorker passes the
/// header length to `write_frame` itself, and the backend strips that many
/// bytes. The sum below spells out the header's fields.
const VIRTIO_NET_HDR_LEN: usize = 1 // flags
    + 1 // gso_type
    + 2 // hdr_len
    + 2 // gso_size
    + 2 // csum_start
    + 2 // csum_offset
    + 2; // num_buffers
32
33//--------------------------------------------------------------------------------------------------
34// Types
35//--------------------------------------------------------------------------------------------------
36
37/// Network backend that bridges libkrun's NetWorker to smoltcp via lock-free
38/// queues.
39///
40/// - **TX path** (`write_frame`): strips the virtio-net header, pushes the
41///   ethernet frame to `tx_ring`, wakes the smoltcp poll thread.
42/// - **RX path** (`read_frame`): pops a frame from `rx_ring`, prepends a
43///   zeroed virtio-net header for the guest.
44/// - **Wake source**: returns `rx_wake`'s pollable fd on Unix or waitable
45///   event handle on Windows so the NetWorker can detect new frames.
46pub struct SmoltcpBackend {
47    shared: Arc<SharedState>,
48}
49
50//--------------------------------------------------------------------------------------------------
51// Methods
52//--------------------------------------------------------------------------------------------------
53
54impl SmoltcpBackend {
55    /// Create a new backend connected to the given shared state.
56    pub fn new(shared: Arc<SharedState>) -> Self {
57        Self { shared }
58    }
59}
60
61//--------------------------------------------------------------------------------------------------
62// Trait Implementations
63//--------------------------------------------------------------------------------------------------
64
65impl NetBackend for SmoltcpBackend {
66    /// Guest is sending a frame. Strip the virtio-net header and enqueue
67    /// the raw ethernet frame for smoltcp.
68    fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> {
69        let ethernet_frame = buf[hdr_len..].to_vec();
70        let frame_len = ethernet_frame.len();
71
72        if self.shared.tx_ring.push(ethernet_frame).is_err() {
73            // This backend exposes a wake pipe to libkrun, not a real writable
74            // socket. Returning NothingWritten would make the virtio worker
75            // undo the TX pop and wait for write readiness that cannot signal
76            // tx_ring capacity. Treat overflow like a lossy NIC queue instead:
77            // drop the frame and let upper layers retransmit if needed.
78            tracing::debug!("dropping guest network frame because tx_ring is full");
79            return Ok(());
80        }
81
82        self.shared.add_tx_bytes(frame_len);
83        self.shared.tx_wake.wake();
84        Ok(())
85    }
86
87    /// Deliver a frame from smoltcp to the guest. Prepends a zeroed
88    /// virtio-net header.
89    fn read_frame(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
90        self.shared.rx_wake.drain();
91
92        let frame = self.shared.rx_ring.pop().ok_or(ReadError::NothingRead)?;
93
94        let total_len = VIRTIO_NET_HDR_LEN + frame.len();
95        if total_len > buf.len() {
96            // Frame too large for the buffer — drop it to avoid panicking.
97            tracing::debug!(
98                frame_len = frame.len(),
99                buf_len = buf.len(),
100                "dropping oversized frame from rx_ring"
101            );
102            return Err(ReadError::NothingRead);
103        }
104
105        // Prepend zeroed virtio-net header.
106        buf[..VIRTIO_NET_HDR_LEN].fill(0);
107        buf[VIRTIO_NET_HDR_LEN..total_len].copy_from_slice(&frame);
108
109        Ok(total_len)
110    }
111
112    /// No partial writes — queue push is atomic.
113    fn has_unfinished_write(&self) -> bool {
114        false
115    }
116
117    /// No partial writes — nothing to finish.
118    fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> {
119        Ok(())
120    }
121
122    /// File descriptor for NetWorker's epoll. Becomes readable when
123    /// `rx_ring` has frames for the guest (i.e. when smoltcp's
124    /// `SmoltcpDevice::transmit()` pushes a frame and wakes `rx_wake`).
125    #[cfg(unix)]
126    fn raw_socket_fd(&self) -> RawFd {
127        self.shared.rx_wake.as_raw_fd()
128    }
129
130    /// Waitable event source for NetWorker on Windows.
131    #[cfg(windows)]
132    fn event_source(&self, token: EventToken) -> EventSource {
133        EventSource::waitable_handle(self.shared.rx_wake.as_raw_handle(), token)
134    }
135}
136
137//--------------------------------------------------------------------------------------------------
138// Tests
139//--------------------------------------------------------------------------------------------------
140
141#[cfg(all(test, unix))]
142mod tests {
143    use std::sync::Arc;
144
145    use super::*;
146
147    #[test]
148    fn read_frame_drains_rx_wake_pipe() {
149        let shared = Arc::new(SharedState::new(4));
150        let mut backend = SmoltcpBackend::new(shared.clone());
151        let mut buf = [0u8; 64];
152
153        assert!(shared.push_rx_frame_and_wake(vec![0xaa, 0xbb]));
154        assert!(fd_is_readable(backend.raw_socket_fd()));
155
156        let n = backend.read_frame(&mut buf).expect("frame should be read");
157        assert_eq!(n, VIRTIO_NET_HDR_LEN + 2);
158        assert_eq!(&buf[VIRTIO_NET_HDR_LEN..n], &[0xaa, 0xbb]);
159        assert!(!fd_is_readable(backend.raw_socket_fd()));
160
161        assert!(shared.push_rx_frame_and_wake(vec![0xcc]));
162        assert!(fd_is_readable(backend.raw_socket_fd()));
163    }
164
165    #[test]
166    fn write_frame_enqueues_guest_frame_and_wakes_poll_loop() {
167        let shared = Arc::new(SharedState::new(1));
168        let mut backend = SmoltcpBackend::new(shared.clone());
169        let mut buf = vec![0u8; VIRTIO_NET_HDR_LEN + 3];
170        buf[VIRTIO_NET_HDR_LEN..].copy_from_slice(&[0xaa, 0xbb, 0xcc]);
171
172        backend
173            .write_frame(VIRTIO_NET_HDR_LEN, &mut buf)
174            .expect("accepted frame should be queued");
175
176        assert_eq!(shared.tx_bytes(), 3);
177        assert!(fd_is_readable(shared.tx_wake.as_raw_fd()));
178        assert_eq!(shared.tx_ring.pop(), Some(vec![0xaa, 0xbb, 0xcc]));
179    }
180
181    #[test]
182    fn write_frame_drops_guest_frame_when_tx_ring_is_full() {
183        let shared = Arc::new(SharedState::new(1));
184        shared.tx_ring.push(vec![0x11]).unwrap();
185        let mut backend = SmoltcpBackend::new(shared.clone());
186        let mut buf = vec![0u8; VIRTIO_NET_HDR_LEN + 2];
187        buf[VIRTIO_NET_HDR_LEN..].copy_from_slice(&[0xaa, 0xbb]);
188
189        backend
190            .write_frame(VIRTIO_NET_HDR_LEN, &mut buf)
191            .expect("overflow should not stall the virtio TX queue");
192
193        assert_eq!(shared.tx_bytes(), 0);
194        assert_eq!(shared.tx_ring.pop(), Some(vec![0x11]));
195        assert_eq!(shared.tx_ring.pop(), None);
196    }
197
198    fn fd_is_readable(fd: RawFd) -> bool {
199        let mut pfd = libc::pollfd {
200            fd,
201            events: libc::POLLIN,
202            revents: 0,
203        };
204
205        // SAFETY: `pfd` points to a valid pollfd for a live file descriptor.
206        let ret = unsafe { libc::poll(&mut pfd, 1, 0) };
207        assert!(ret >= 0, "poll failed: {}", std::io::Error::last_os_error());
208
209        ret == 1 && pfd.revents & libc::POLLIN != 0
210    }
211}