// microsandbox_network/shared.rs
//! Shared state between the NetWorker thread, smoltcp poll thread, and tokio
//! proxy tasks.
//!
//! All inter-thread communication flows through [`SharedState`], which holds
//! lock-free frame queues and cross-platform [`WakePipe`] notifications.

use crossbeam_queue::ArrayQueue;
pub use microsandbox_utils::wake_pipe::WakePipe;
use std::sync::{
    Arc, Mutex,
    atomic::{AtomicU64, Ordering},
};

//--------------------------------------------------------------------------------------------------
// Constants
//--------------------------------------------------------------------------------------------------

/// Default frame queue capacity. Matches libkrun's virtio queue size.
///
/// Intended as the `queue_capacity` argument to [`SharedState::new`] when a
/// caller has no reason to pick a custom size.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1024;

//--------------------------------------------------------------------------------------------------
// Types
//--------------------------------------------------------------------------------------------------

/// All shared state between the three threads:
///
/// - **NetWorker** (libkrun) — pushes guest frames to `tx_ring`, pops
///   response frames from `rx_ring`.
/// - **smoltcp poll thread** — pops from `tx_ring`, processes through smoltcp,
///   pushes responses to `rx_ring`.
/// - **tokio proxy tasks** — relay data between smoltcp sockets and real
///   network connections.
///
/// Queue naming follows the **guest's perspective** (matching libkrun's
/// convention): `tx_ring` = "transmit from guest", `rx_ring` = "receive at
/// guest".
///
/// Both rings are bounded ([`ArrayQueue`]); a push onto a full ring fails and
/// hands the rejected frame back to the producer, so producers must be
/// prepared to drop or retry.
pub struct SharedState {
    /// Frames from guest → smoltcp (NetWorker writes, smoltcp reads).
    pub tx_ring: ArrayQueue<Vec<u8>>,

    /// Frames from smoltcp → guest (smoltcp writes, NetWorker reads).
    pub rx_ring: ArrayQueue<Vec<u8>>,

    /// Wakes NetWorker: "rx_ring has frames for the guest."
    /// Written by `SmoltcpDevice::transmit()`. Read end polled by NetWorker's
    /// epoll loop.
    pub rx_wake: WakePipe,

    /// Wakes smoltcp poll thread: "tx_ring has frames from the guest."
    /// Written by `SmoltcpBackend::write_frame()`. Read end polled by the
    /// poll loop.
    pub tx_wake: WakePipe,

    /// Wakes smoltcp poll thread: "proxy task has data to write to a smoltcp
    /// socket." Written by proxy tasks via channels. Read end polled by the
    /// poll loop.
    pub proxy_wake: WakePipe,

    /// Optional host-side termination hook used for fatal policy violations.
    /// Installed via [`SharedState::set_termination_hook`] and invoked by
    /// [`SharedState::trigger_termination`].
    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,

    /// Aggregate network byte counters at the guest/runtime boundary.
    /// Updated via `add_tx_bytes`/`add_rx_bytes`, read via
    /// `tx_bytes`/`rx_bytes`.
    metrics: NetworkMetrics,
}

/// Aggregate network byte counters shared with the runtime metrics sampler.
///
/// All accesses use `Ordering::Relaxed`: these are monotonic totals read for
/// reporting only, never used for cross-thread synchronization.
pub struct NetworkMetrics {
    // Total bytes sent by the guest into the runtime (guest -> runtime).
    tx_bytes: AtomicU64,
    // Total bytes delivered by the runtime to the guest (runtime -> guest).
    rx_bytes: AtomicU64,
}

//--------------------------------------------------------------------------------------------------
// Methods
//--------------------------------------------------------------------------------------------------

76impl SharedState {
77    /// Create shared state with the given queue capacity.
78    pub fn new(queue_capacity: usize) -> Self {
79        Self {
80            tx_ring: ArrayQueue::new(queue_capacity),
81            rx_ring: ArrayQueue::new(queue_capacity),
82            rx_wake: WakePipe::new(),
83            tx_wake: WakePipe::new(),
84            proxy_wake: WakePipe::new(),
85            termination_hook: Mutex::new(None),
86            metrics: NetworkMetrics::default(),
87        }
88    }
89
90    /// Install a host-side termination hook.
91    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
92        *self.termination_hook.lock().unwrap() = Some(hook);
93    }
94
95    /// Trigger host-side termination if a hook is installed.
96    pub fn trigger_termination(&self) {
97        let hook = self.termination_hook.lock().unwrap().clone();
98        if let Some(hook) = hook {
99            hook();
100        }
101    }
102
103    /// Increment the guest -> runtime byte counter.
104    pub fn add_tx_bytes(&self, bytes: usize) {
105        self.metrics
106            .tx_bytes
107            .fetch_add(bytes as u64, Ordering::Relaxed);
108    }
109
110    /// Increment the runtime -> guest byte counter.
111    pub fn add_rx_bytes(&self, bytes: usize) {
112        self.metrics
113            .rx_bytes
114            .fetch_add(bytes as u64, Ordering::Relaxed);
115    }
116
117    /// Total bytes transmitted by the guest into the runtime.
118    pub fn tx_bytes(&self) -> u64 {
119        self.metrics.tx_bytes.load(Ordering::Relaxed)
120    }
121
122    /// Total bytes delivered by the runtime to the guest.
123    pub fn rx_bytes(&self) -> u64 {
124        self.metrics.rx_bytes.load(Ordering::Relaxed)
125    }
126}
127
128impl Default for NetworkMetrics {
129    fn default() -> Self {
130        Self {
131            tx_bytes: AtomicU64::new(0),
132            rx_bytes: AtomicU64::new(0),
133        }
134    }
135}
136
//--------------------------------------------------------------------------------------------------
// Tests
//--------------------------------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        // Load two frames into the guest->smoltcp ring.
        for frame in [vec![1u8, 2, 3], vec![4, 5, 6]] {
            state.tx_ring.push(frame).unwrap();
        }

        // They come back out in FIFO order, then the ring is empty.
        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert!(state.tx_ring.pop().is_none());
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        assert!(state.rx_ring.push(vec![1]).is_ok());
        assert!(state.rx_ring.push(vec![2]).is_ok());
        // At capacity — the rejected frame is returned inside the Err.
        assert!(state.rx_ring.push(vec![3]).is_err());
    }
}