Skip to main content

microsandbox_network/
shared.rs

1//! Shared state between the NetWorker thread, smoltcp poll thread, and tokio
2//! proxy tasks.
3//!
4//! All inter-thread communication flows through [`SharedState`], which holds
5//! lock-free frame queues and cross-platform [`WakePipe`] notifications.
6
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9    Arc, Mutex, OnceLock,
10    atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Default number of frames each direction's queue can hold (1024).
///
/// Chosen to match libkrun's virtio queue size.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1_024;
25
26//--------------------------------------------------------------------------------------------------
27// Types
28//--------------------------------------------------------------------------------------------------
29
30/// All shared state between the three threads:
31///
32/// - **NetWorker** (libkrun) — pushes guest frames to `tx_ring`, pops
33///   response frames from `rx_ring`.
34/// - **smoltcp poll thread** — pops from `tx_ring`, processes through smoltcp,
35///   pushes responses to `rx_ring`.
36/// - **tokio proxy tasks** — relay data between smoltcp sockets and real
37///   network connections.
38///
39/// Queue naming follows the **guest's perspective** (matching libkrun's
40/// convention): `tx_ring` = "transmit from guest", `rx_ring` = "receive at
41/// guest".
42pub struct SharedState {
43    /// Frames from guest → smoltcp (NetWorker writes, smoltcp reads).
44    pub tx_ring: ArrayQueue<Vec<u8>>,
45
46    /// Frames from smoltcp → guest (smoltcp writes, NetWorker reads).
47    pub rx_ring: ArrayQueue<Vec<u8>>,
48
49    /// Wakes NetWorker: "rx_ring has frames for the guest."
50    /// Written by `SmoltcpDevice::transmit()`. Read end polled by NetWorker's
51    /// epoll loop.
52    pub rx_wake: WakePipe,
53
54    /// Wakes smoltcp poll thread: "tx_ring has frames from the guest."
55    /// Written by `SmoltcpBackend::write_frame()`. Read end polled by the
56    /// poll loop.
57    pub tx_wake: WakePipe,
58
59    /// Wakes smoltcp poll thread: "proxy task has data to write to a smoltcp
60    /// socket." Written by proxy tasks via channels. Read end polled by the
61    /// poll loop.
62    pub proxy_wake: WakePipe,
63
64    /// Optional host-side termination hook used for fatal policy violations.
65    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
66
67    /// Resolved hostname index used to map destination IPs back to queried hostnames.
68    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
69
70    /// Per-sandbox gateway IPv4. Set once at boot; used by
71    /// `DestinationGroup::Host` rule matching and `host.microsandbox.internal`
72    /// DNS synthesis. `None` in isolated unit tests.
73    gateway_ipv4: OnceLock<Ipv4Addr>,
74
75    /// Per-sandbox gateway IPv6. Set once at boot. See `gateway_ipv4`.
76    gateway_ipv6: OnceLock<Ipv6Addr>,
77
78    /// Aggregate network byte counters at the guest/runtime boundary.
79    metrics: NetworkMetrics,
80}
81
/// Aggregate network byte counters shared with the runtime metrics sampler.
///
/// Both counters are plain atomics updated with relaxed ordering. They are
/// independent totals and are not synchronized with each other.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Total bytes sent by the guest into the runtime (guest -> runtime).
    tx_bytes: AtomicU64,
    /// Total bytes delivered by the runtime to the guest (runtime -> guest).
    rx_bytes: AtomicU64,
}
87
/// IP address family that a cached hostname resolution belongs to.
///
/// Variant order is significant for the derived `Ord`: `Ipv4` sorts before
/// `Ipv6`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ResolvedHostnameFamily {
    /// IPv4 results (DNS `A` answers).
    Ipv4,
    /// IPv6 results (DNS `AAAA` answers).
    Ipv6,
}
94
95/// Composite cache key for a single DNS resolution.
96///
97/// `family` partitions entries so that `A` and `AAAA` responses for the
98/// same hostname refresh independently instead of overwriting each other.
99#[derive(Clone, Debug, PartialEq, Eq, Hash)]
100struct ResolvedHostnameKey {
101    hostname: String,
102    family: ResolvedHostnameFamily,
103}
104
105//--------------------------------------------------------------------------------------------------
106// Methods
107//--------------------------------------------------------------------------------------------------
108
109impl SharedState {
110    /// Create shared state with the given queue capacity.
111    pub fn new(queue_capacity: usize) -> Self {
112        Self {
113            tx_ring: ArrayQueue::new(queue_capacity),
114            rx_ring: ArrayQueue::new(queue_capacity),
115            rx_wake: WakePipe::new(),
116            tx_wake: WakePipe::new(),
117            proxy_wake: WakePipe::new(),
118            termination_hook: Mutex::new(None),
119            resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120            gateway_ipv4: OnceLock::new(),
121            gateway_ipv6: OnceLock::new(),
122            metrics: NetworkMetrics::default(),
123        }
124    }
125
126    /// Set the per-sandbox gateway IPs. Called once at boot. Each family is
127    /// only published when active for this sandbox.
128    pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
129        if let Some(ipv4) = ipv4 {
130            let _ = self.gateway_ipv4.set(ipv4);
131        }
132        if let Some(ipv6) = ipv6 {
133            let _ = self.gateway_ipv6.set(ipv6);
134        }
135    }
136
137    /// Gateway IPv4 address, if set.
138    pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
139        self.gateway_ipv4.get().copied()
140    }
141
142    /// Gateway IPv6 address, if set.
143    pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
144        self.gateway_ipv6.get().copied()
145    }
146
147    /// Install a host-side termination hook.
148    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
149        *self.termination_hook.lock().unwrap() = Some(hook);
150    }
151
152    /// Trigger host-side termination if a hook is installed.
153    pub fn trigger_termination(&self) {
154        let hook = self.termination_hook.lock().unwrap().clone();
155        if let Some(hook) = hook {
156            hook();
157        }
158    }
159
160    /// Replace the resolved addresses for a hostname within the given address family.
161    pub fn cache_resolved_hostname(
162        &self,
163        domain: &str,
164        family: ResolvedHostnameFamily,
165        addrs: impl IntoIterator<Item = IpAddr>,
166        ttl: Duration,
167    ) {
168        let hostname = normalize_hostname(domain);
169        let key = ResolvedHostnameKey { hostname, family };
170        self.resolved_hostnames
171            .write()
172            .insert(key, addrs, ttl, Instant::now());
173    }
174
175    /// Clear the resolved addresses for a hostname within the given address family.
176    pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
177        let hostname = normalize_hostname(domain);
178        let key = ResolvedHostnameKey { hostname, family };
179        self.resolved_hostnames.write().remove(&key, Instant::now());
180    }
181
182    /// Returns `true` when any resolved hostname for `addr` satisfies `predicate`.
183    pub fn any_resolved_hostname(
184        &self,
185        addr: IpAddr,
186        mut predicate: impl FnMut(&str) -> bool,
187    ) -> bool {
188        self.resolved_hostnames
189            .read()
190            .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
191    }
192
193    /// Best-effort expiry maintenance for resolved hostnames.
194    ///
195    /// This runs outside the hot egress read path. If the index is currently
196    /// busy, cleanup is skipped and retried on the next maintenance pass.
197    pub fn cleanup_resolved_hostnames(&self) {
198        if let Some(mut idx) = self.resolved_hostnames.try_write() {
199            idx.evict_expired(Instant::now());
200        }
201    }
202
203    /// Increment the guest -> runtime byte counter.
204    pub fn add_tx_bytes(&self, bytes: usize) {
205        self.metrics
206            .tx_bytes
207            .fetch_add(bytes as u64, Ordering::Relaxed);
208    }
209
210    /// Increment the runtime -> guest byte counter.
211    pub fn add_rx_bytes(&self, bytes: usize) {
212        self.metrics
213            .rx_bytes
214            .fetch_add(bytes as u64, Ordering::Relaxed);
215    }
216
217    /// Push a runtime -> guest ethernet frame and update RX metrics on success.
218    pub(crate) fn push_rx_frame(&self, frame: Vec<u8>) -> bool {
219        let frame_len = frame.len();
220        if self.rx_ring.push(frame).is_err() {
221            return false;
222        }
223
224        self.add_rx_bytes(frame_len);
225        true
226    }
227
228    /// Push a runtime -> guest ethernet frame, update RX metrics, and wake libkrun.
229    pub(crate) fn push_rx_frame_and_wake(&self, frame: Vec<u8>) -> bool {
230        if !self.push_rx_frame(frame) {
231            return false;
232        }
233
234        self.rx_wake.wake();
235        true
236    }
237
238    /// Total bytes transmitted by the guest into the runtime.
239    pub fn tx_bytes(&self) -> u64 {
240        self.metrics.tx_bytes.load(Ordering::Relaxed)
241    }
242
243    /// Total bytes delivered by the runtime to the guest.
244    pub fn rx_bytes(&self) -> u64 {
245        self.metrics.rx_bytes.load(Ordering::Relaxed)
246    }
247}
248
249impl Default for NetworkMetrics {
250    fn default() -> Self {
251        Self {
252            tx_bytes: AtomicU64::new(0),
253            rx_bytes: AtomicU64::new(0),
254        }
255    }
256}
257
/// Canonicalize a hostname for cache lookups.
///
/// Removes every trailing `.` (the DNS root label) and lowercases ASCII
/// letters. Non-ASCII characters are left unchanged.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let mut canonical = domain.trim_end_matches('.').to_owned();
    canonical.make_ascii_lowercase();
    canonical
}
261
262//--------------------------------------------------------------------------------------------------
263// Tests
264//--------------------------------------------------------------------------------------------------
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        // Push frames to tx_ring.
        state.tx_ring.push(vec![1, 2, 3]).unwrap();
        state.tx_ring.push(vec![4, 5, 6]).unwrap();

        // Pop in FIFO order.
        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert_eq!(state.tx_ring.pop(), None);
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        state.rx_ring.push(vec![1]).unwrap();
        state.rx_ring.push(vec![2]).unwrap();
        // Queue is full — push returns the frame back.
        assert!(state.rx_ring.push(vec![3]).is_err());
    }

    #[test]
    fn push_rx_frame_counts_only_successful_pushes() {
        let state = SharedState::new(1);

        assert!(state.push_rx_frame(vec![1, 2, 3]));
        assert_eq!(state.rx_bytes(), 3);

        assert!(!state.push_rx_frame(vec![4, 5]));
        assert_eq!(state.rx_bytes(), 3);
    }

    #[test]
    fn byte_counters_accumulate_independently() {
        let state = SharedState::new(1);

        state.add_tx_bytes(10);
        state.add_tx_bytes(5);
        state.add_rx_bytes(7);

        assert_eq!(state.tx_bytes(), 15);
        assert_eq!(state.rx_bytes(), 7);
    }

    #[test]
    fn normalize_hostname_lowercases_and_strips_trailing_dots() {
        assert_eq!(normalize_hostname("Example.COM."), "example.com");
        assert_eq!(normalize_hostname("example.com"), "example.com");
        assert_eq!(normalize_hostname(""), "");
    }

    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        assert!(state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }

    #[test]
    fn clearing_one_family_keeps_the_other() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        // Clearing uses the same normalization as caching.
        state.clear_resolved_hostname("EXAMPLE.com.", ResolvedHostnameFamily::Ipv4);

        assert!(!state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
    }

    #[test]
    fn gateway_ips_are_write_once() {
        let state = SharedState::new(1);
        assert_eq!(state.gateway_ipv4(), None);
        assert_eq!(state.gateway_ipv6(), None);

        let first = Ipv4Addr::new(10, 0, 0, 1);
        state.set_gateway_ips(Some(first), None);
        state.set_gateway_ips(Some(Ipv4Addr::new(10, 0, 0, 2)), None);

        assert_eq!(state.gateway_ipv4(), Some(first));
        assert_eq!(state.gateway_ipv6(), None);
    }

    #[test]
    fn trigger_termination_invokes_installed_hook() {
        let state = SharedState::new(1);

        // Without a hook installed this must be a no-op.
        state.trigger_termination();

        let calls = Arc::new(AtomicU64::new(0));
        let counter = Arc::clone(&calls);
        state.set_termination_hook(Arc::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }));

        state.trigger_termination();
        state.trigger_termination();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}