Skip to main content

microsandbox_network/
backend.rs

1//! `SmoltcpBackend` — libkrun [`NetBackend`] implementation that bridges the
2//! NetWorker thread to the smoltcp poll thread via lock-free queues.
3//!
4//! The NetWorker calls [`write_frame()`](NetBackend::write_frame) when the
5//! guest sends a frame and [`read_frame()`](NetBackend::read_frame) to deliver
6//! frames back to the guest. Frames flow through [`SharedState`]'s
7//! `tx_ring`/`rx_ring` queues with [`WakePipe`](crate::shared::WakePipe)
8//! notifications. Unix libkrun registers [`raw_socket_fd`](NetBackend::raw_socket_fd)
9//! in edge-triggered mode, while Windows libkrun waits on an event source. Reads
10//! must drain the wake primitive before returning.
11
12#[cfg(unix)]
13use std::os::fd::RawFd;
14use std::sync::Arc;
15
16use msb_krun::backends::net::{NetBackend, ReadError, WriteError};
17#[cfg(windows)]
18use msb_krun_utils::event::{EventSource, EventToken};
19
20use crate::shared::SharedState;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Length in bytes of the `virtio_net_hdr_v1` header that libkrun's
/// NetWorker places in front of every frame buffer.
///
/// The header is laid out as `flags: u8`, `gso_type: u8`, `hdr_len: u16`,
/// `gso_size: u16`, `csum_start: u16`, `csum_offset: u16` and
/// `num_buffers: u16`, which adds up to 12 bytes. On TX (guest → smoltcp)
/// the backend removes the header; on RX (smoltcp → guest) it writes an
/// all-zero one in front of the frame.
const VIRTIO_NET_HDR_LEN: usize = 1 + 1 + 2 + 2 + 2 + 2 + 2;
32
33//--------------------------------------------------------------------------------------------------
34// Types
35//--------------------------------------------------------------------------------------------------
36
37/// Network backend that bridges libkrun's NetWorker to smoltcp via lock-free
38/// queues.
39///
40/// - **TX path** (`write_frame`): strips the virtio-net header, pushes the
41///   ethernet frame to `tx_ring`, wakes the smoltcp poll thread.
42/// - **RX path** (`read_frame`): pops a frame from `rx_ring`, prepends a
43///   zeroed virtio-net header for the guest.
44/// - **Wake source**: returns `rx_wake`'s pollable fd on Unix or waitable
45///   event handle on Windows so the NetWorker can detect new frames.
46pub struct SmoltcpBackend {
47    shared: Arc<SharedState>,
48}
49
50//--------------------------------------------------------------------------------------------------
51// Methods
52//--------------------------------------------------------------------------------------------------
53
54impl SmoltcpBackend {
55    /// Create a new backend connected to the given shared state.
56    pub fn new(shared: Arc<SharedState>) -> Self {
57        Self { shared }
58    }
59}
60
61//--------------------------------------------------------------------------------------------------
62// Trait Implementations
63//--------------------------------------------------------------------------------------------------
64
65impl NetBackend for SmoltcpBackend {
66    /// Guest is sending a frame. Strip the virtio-net header and enqueue
67    /// the raw ethernet frame for smoltcp.
68    fn write_frame(&mut self, hdr_len: usize, buf: &mut [u8]) -> Result<(), WriteError> {
69        let ethernet_frame = buf[hdr_len..].to_vec();
70        self.shared.add_tx_bytes(ethernet_frame.len());
71        self.shared
72            .tx_ring
73            .push(ethernet_frame)
74            .map_err(|_| WriteError::NothingWritten)?;
75        self.shared.tx_wake.wake();
76        Ok(())
77    }
78
79    /// Deliver a frame from smoltcp to the guest. Prepends a zeroed
80    /// virtio-net header.
81    fn read_frame(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
82        self.shared.rx_wake.drain();
83
84        let frame = self.shared.rx_ring.pop().ok_or(ReadError::NothingRead)?;
85
86        let total_len = VIRTIO_NET_HDR_LEN + frame.len();
87        if total_len > buf.len() {
88            // Frame too large for the buffer — drop it to avoid panicking.
89            tracing::debug!(
90                frame_len = frame.len(),
91                buf_len = buf.len(),
92                "dropping oversized frame from rx_ring"
93            );
94            return Err(ReadError::NothingRead);
95        }
96
97        // Prepend zeroed virtio-net header.
98        buf[..VIRTIO_NET_HDR_LEN].fill(0);
99        buf[VIRTIO_NET_HDR_LEN..total_len].copy_from_slice(&frame);
100
101        Ok(total_len)
102    }
103
104    /// No partial writes — queue push is atomic.
105    fn has_unfinished_write(&self) -> bool {
106        false
107    }
108
109    /// No partial writes — nothing to finish.
110    fn try_finish_write(&mut self, _hdr_len: usize, _buf: &[u8]) -> Result<(), WriteError> {
111        Ok(())
112    }
113
114    /// File descriptor for NetWorker's epoll. Becomes readable when
115    /// `rx_ring` has frames for the guest (i.e. when smoltcp's
116    /// `SmoltcpDevice::transmit()` pushes a frame and wakes `rx_wake`).
117    #[cfg(unix)]
118    fn raw_socket_fd(&self) -> RawFd {
119        self.shared.rx_wake.as_raw_fd()
120    }
121
122    /// Waitable event source for NetWorker on Windows.
123    #[cfg(windows)]
124    fn event_source(&self, token: EventToken) -> EventSource {
125        EventSource::waitable_handle(self.shared.rx_wake.as_raw_handle(), token)
126    }
127}
128
129//--------------------------------------------------------------------------------------------------
130// Tests
131//--------------------------------------------------------------------------------------------------
132
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    /// Non-blocking check of whether `fd` currently reports `POLLIN`.
    fn fd_is_readable(fd: RawFd) -> bool {
        let mut entry = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };

        // SAFETY: `entry` is a valid, exclusively borrowed pollfd, the count
        // passed is exactly one, and the descriptor belongs to a live
        // `SharedState`.
        let ready = unsafe { libc::poll(&mut entry, 1, 0) };
        assert!(ready >= 0, "poll failed: {}", std::io::Error::last_os_error());

        ready == 1 && (entry.revents & libc::POLLIN) != 0
    }

    /// Reading a frame must leave the rx wake pipe drained. A later push
    /// must make it readable again.
    #[test]
    fn read_frame_drains_rx_wake_pipe() {
        let shared = Arc::new(SharedState::new(4));
        let mut backend = SmoltcpBackend::new(Arc::clone(&shared));
        let mut buf = [0u8; 64];

        assert!(shared.push_rx_frame_and_wake(vec![0xaa, 0xbb]));
        assert!(fd_is_readable(backend.raw_socket_fd()));

        let written = backend.read_frame(&mut buf).expect("frame should be read");
        assert_eq!(written, VIRTIO_NET_HDR_LEN + 2);
        assert_eq!(&buf[VIRTIO_NET_HDR_LEN..written], &[0xaa, 0xbb]);
        assert!(!fd_is_readable(backend.raw_socket_fd()));

        assert!(shared.push_rx_frame_and_wake(vec![0xcc]));
        assert!(fd_is_readable(backend.raw_socket_fd()));
    }
}
170}