Skip to main content

microsandbox_network/
shared.rs

//! Shared state between the NetWorker thread, smoltcp poll thread, and tokio
//! proxy tasks.
//!
//! All inter-thread communication flows through [`SharedState`], which holds
//! lock-free frame queues and cross-platform [`WakePipe`] notifications.
6
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9    Arc, Mutex, OnceLock,
10    atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
19use crate::addr::normalize_ip_addr;
20
21//--------------------------------------------------------------------------------------------------
22// Constants
23//--------------------------------------------------------------------------------------------------
24
/// Number of frames a ring holds when no explicit capacity is requested.
/// Chosen to match the virtio queue size used by libkrun.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1024;
27
28//--------------------------------------------------------------------------------------------------
29// Types
30//--------------------------------------------------------------------------------------------------
31
32/// All shared state between the three threads:
33///
34/// - **NetWorker** (libkrun) — pushes guest frames to `tx_ring`, pops
35///   response frames from `rx_ring`.
36/// - **smoltcp poll thread** — pops from `tx_ring`, processes through smoltcp,
37///   pushes responses to `rx_ring`.
38/// - **tokio proxy tasks** — relay data between smoltcp sockets and real
39///   network connections.
40///
41/// Queue naming follows the **guest's perspective** (matching libkrun's
42/// convention): `tx_ring` = "transmit from guest", `rx_ring` = "receive at
43/// guest".
44pub struct SharedState {
45    /// Frames from guest → smoltcp (NetWorker writes, smoltcp reads).
46    pub tx_ring: ArrayQueue<Vec<u8>>,
47
48    /// Frames from smoltcp → guest (smoltcp writes, NetWorker reads).
49    pub rx_ring: ArrayQueue<Vec<u8>>,
50
51    /// Wakes NetWorker: "rx_ring has frames for the guest."
52    /// Written by `SmoltcpDevice::transmit()`. Read end polled by NetWorker's
53    /// epoll loop.
54    pub rx_wake: WakePipe,
55
56    /// Wakes smoltcp poll thread: "tx_ring has frames from the guest."
57    /// Written by `SmoltcpBackend::write_frame()`. Read end polled by the
58    /// poll loop.
59    pub tx_wake: WakePipe,
60
61    /// Wakes smoltcp poll thread: "proxy task has data to write to a smoltcp
62    /// socket." Written by proxy tasks via channels. Read end polled by the
63    /// poll loop.
64    pub proxy_wake: WakePipe,
65
66    /// Optional host-side termination hook used for fatal policy violations.
67    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
68
69    /// Resolved hostname index used to map destination IPs back to queried hostnames.
70    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
71
72    /// Per-sandbox gateway IPv4. Set once at boot; used by
73    /// `DestinationGroup::Host` rule matching and `host.microsandbox.internal`
74    /// DNS synthesis. `None` in isolated unit tests.
75    gateway_ipv4: OnceLock<Ipv4Addr>,
76
77    /// Per-sandbox gateway IPv6. Set once at boot. See `gateway_ipv4`.
78    gateway_ipv6: OnceLock<Ipv6Addr>,
79
80    /// Aggregate network byte counters at the guest/runtime boundary.
81    metrics: NetworkMetrics,
82}
83
/// Aggregate network byte counters shared with the runtime metrics sampler.
///
/// Both counters are atomics, so they can be bumped and read through a
/// shared reference without additional locking. `Debug` is derived so the
/// current counter values can be included in logs and assertion output.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Bytes counted in the guest -> runtime direction.
    tx_bytes: AtomicU64,
    /// Bytes counted in the runtime -> guest direction.
    rx_bytes: AtomicU64,
}
89
/// Which IP address family a resolved-hostname entry belongs to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ResolvedHostnameFamily {
    /// Entry holds IPv4 addresses.
    Ipv4,
    /// Entry holds IPv6 addresses.
    Ipv6,
}
96
97/// Composite cache key for a single DNS resolution.
98///
99/// `family` partitions entries so that `A` and `AAAA` responses for the
100/// same hostname refresh independently instead of overwriting each other.
101#[derive(Clone, Debug, PartialEq, Eq, Hash)]
102struct ResolvedHostnameKey {
103    hostname: String,
104    family: ResolvedHostnameFamily,
105}
106
107//--------------------------------------------------------------------------------------------------
108// Methods
109//--------------------------------------------------------------------------------------------------
110
111impl SharedState {
112    /// Create shared state with the given queue capacity.
113    pub fn new(queue_capacity: usize) -> Self {
114        Self {
115            tx_ring: ArrayQueue::new(queue_capacity),
116            rx_ring: ArrayQueue::new(queue_capacity),
117            rx_wake: WakePipe::new(),
118            tx_wake: WakePipe::new(),
119            proxy_wake: WakePipe::new(),
120            termination_hook: Mutex::new(None),
121            resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
122            gateway_ipv4: OnceLock::new(),
123            gateway_ipv6: OnceLock::new(),
124            metrics: NetworkMetrics::default(),
125        }
126    }
127
128    /// Set the per-sandbox gateway IPs. Called once at boot. Each family is
129    /// only published when active for this sandbox.
130    pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
131        if let Some(ipv4) = ipv4 {
132            let _ = self.gateway_ipv4.set(ipv4);
133        }
134        if let Some(ipv6) = ipv6 {
135            let _ = self.gateway_ipv6.set(ipv6);
136        }
137    }
138
139    /// Gateway IPv4 address, if set.
140    pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
141        self.gateway_ipv4.get().copied()
142    }
143
144    /// Gateway IPv6 address, if set.
145    pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
146        self.gateway_ipv6.get().copied()
147    }
148
149    /// Install a host-side termination hook.
150    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
151        *self.termination_hook.lock().unwrap() = Some(hook);
152    }
153
154    /// Trigger host-side termination if a hook is installed.
155    pub fn trigger_termination(&self) {
156        let hook = self.termination_hook.lock().unwrap().clone();
157        if let Some(hook) = hook {
158            hook();
159        }
160    }
161
162    /// Replace the resolved addresses for a hostname within the given address family.
163    pub fn cache_resolved_hostname(
164        &self,
165        domain: &str,
166        family: ResolvedHostnameFamily,
167        addrs: impl IntoIterator<Item = IpAddr>,
168        ttl: Duration,
169    ) {
170        let hostname = normalize_hostname(domain);
171        let key = ResolvedHostnameKey { hostname, family };
172        let addrs = addrs.into_iter().map(normalize_ip_addr);
173        self.resolved_hostnames
174            .write()
175            .insert(key, addrs, ttl, Instant::now());
176    }
177
178    /// Clear the resolved addresses for a hostname within the given address family.
179    pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
180        let hostname = normalize_hostname(domain);
181        let key = ResolvedHostnameKey { hostname, family };
182        self.resolved_hostnames.write().remove(&key, Instant::now());
183    }
184
185    /// Returns `true` when any resolved hostname for `addr` satisfies `predicate`.
186    pub fn any_resolved_hostname(
187        &self,
188        addr: IpAddr,
189        mut predicate: impl FnMut(&str) -> bool,
190    ) -> bool {
191        let addr = normalize_ip_addr(addr);
192
193        self.resolved_hostnames
194            .read()
195            .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
196    }
197
198    /// Best-effort expiry maintenance for resolved hostnames.
199    ///
200    /// This runs outside the hot egress read path. If the index is currently
201    /// busy, cleanup is skipped and retried on the next maintenance pass.
202    pub fn cleanup_resolved_hostnames(&self) {
203        if let Some(mut idx) = self.resolved_hostnames.try_write() {
204            idx.evict_expired(Instant::now());
205        }
206    }
207
208    /// Increment the guest -> runtime byte counter.
209    pub fn add_tx_bytes(&self, bytes: usize) {
210        self.metrics
211            .tx_bytes
212            .fetch_add(bytes as u64, Ordering::Relaxed);
213    }
214
215    /// Increment the runtime -> guest byte counter.
216    pub fn add_rx_bytes(&self, bytes: usize) {
217        self.metrics
218            .rx_bytes
219            .fetch_add(bytes as u64, Ordering::Relaxed);
220    }
221
222    /// Push a runtime -> guest ethernet frame and update RX metrics on success.
223    pub(crate) fn push_rx_frame(&self, frame: Vec<u8>) -> bool {
224        let frame_len = frame.len();
225        if self.rx_ring.push(frame).is_err() {
226            return false;
227        }
228
229        self.add_rx_bytes(frame_len);
230        true
231    }
232
233    /// Push a runtime -> guest ethernet frame, update RX metrics, and wake libkrun.
234    pub(crate) fn push_rx_frame_and_wake(&self, frame: Vec<u8>) -> bool {
235        if !self.push_rx_frame(frame) {
236            return false;
237        }
238
239        self.rx_wake.wake();
240        true
241    }
242
243    /// Total bytes transmitted by the guest into the runtime.
244    pub fn tx_bytes(&self) -> u64 {
245        self.metrics.tx_bytes.load(Ordering::Relaxed)
246    }
247
248    /// Total bytes delivered by the runtime to the guest.
249    pub fn rx_bytes(&self) -> u64 {
250        self.metrics.rx_bytes.load(Ordering::Relaxed)
251    }
252}
253
254impl Default for NetworkMetrics {
255    fn default() -> Self {
256        Self {
257            tx_bytes: AtomicU64::new(0),
258            rx_bytes: AtomicU64::new(0),
259        }
260    }
261}
262
/// Canonicalize a hostname for use as a lookup key.
///
/// Every trailing `.` is stripped and ASCII letters are lowercased;
/// non-ASCII characters are left as they are.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let without_trailing_dots = domain.trim_end_matches('.');
    without_trailing_dots.to_ascii_lowercase()
}
266
267//--------------------------------------------------------------------------------------------------
268// Tests
269//--------------------------------------------------------------------------------------------------
270
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL used for every cached hostname in these tests.
    const TTL: Duration = Duration::from_secs(30);

    /// `tx_ring` returns frames in insertion order and is empty afterwards.
    #[test]
    fn shared_state_queue_push_pop() {
        let shared = SharedState::new(4);
        let frames = [vec![1u8, 2, 3], vec![4, 5, 6]];

        for frame in frames.iter().cloned() {
            shared.tx_ring.push(frame).unwrap();
        }

        for expected in frames {
            assert_eq!(shared.tx_ring.pop(), Some(expected));
        }
        assert_eq!(shared.tx_ring.pop(), None);
    }

    /// Pushing beyond the ring capacity is rejected.
    #[test]
    fn shared_state_queue_full() {
        let shared = SharedState::new(2);

        assert!(shared.rx_ring.push(vec![1]).is_ok());
        assert!(shared.rx_ring.push(vec![2]).is_ok());

        // Ring is at capacity now, so the frame is handed back as an error.
        let overflow = shared.rx_ring.push(vec![3]);
        assert!(overflow.is_err());
    }

    /// The RX byte counter only moves when the frame was actually queued.
    #[test]
    fn push_rx_frame_counts_only_successful_pushes() {
        let shared = SharedState::new(1);

        let accepted = shared.push_rx_frame(vec![1, 2, 3]);
        assert!(accepted);
        assert_eq!(shared.rx_bytes(), 3);

        let rejected = !shared.push_rx_frame(vec![4, 5]);
        assert!(rejected);
        assert_eq!(shared.rx_bytes(), 3);
    }

    /// IPv4 and IPv6 entries for one hostname coexist.
    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let shared = SharedState::new(4);
        let ipv4_addr: IpAddr = "1.1.1.1".parse().unwrap();
        let ipv6_addr: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        shared.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [ipv4_addr],
            TTL,
        );
        shared.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [ipv6_addr],
            TTL,
        );

        let is_example = |host: &str| host == "example.com";
        assert!(shared.any_resolved_hostname(ipv4_addr, is_example));
        assert!(shared.any_resolved_hostname(ipv6_addr, is_example));
        assert!(!shared.any_resolved_hostname(ipv4_addr, |host| host == "other.example"));
    }

    /// An IPv4-mapped IPv6 address and its plain IPv4 form hit the same entry.
    #[test]
    fn resolved_hostnames_normalize_ipv4_mapped_ipv6() {
        let shared = SharedState::new(4);
        let mapped: IpAddr = "::ffff:169.254.169.254".parse().unwrap();
        let embedded: IpAddr = "169.254.169.254".parse().unwrap();

        shared.cache_resolved_hostname(
            "metadata.example",
            ResolvedHostnameFamily::Ipv6,
            [mapped],
            TTL,
        );

        for probe in [embedded, mapped] {
            assert!(shared.any_resolved_hostname(probe, |host| host == "metadata.example"));
        }
    }
}
350}