Skip to main content

microsandbox_network/
shared.rs

//! Shared state between the NetWorker thread, smoltcp poll thread, and tokio
//! proxy tasks.
//!
//! All inter-thread communication flows through [`SharedState`], which holds
//! lock-free frame queues and cross-platform [`WakePipe`] notifications.
6
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9    Arc, Mutex, OnceLock,
10    atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Default number of frames each frame queue can hold.
///
/// The value is chosen to match libkrun's virtio queue size.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1024;
25
26//--------------------------------------------------------------------------------------------------
27// Types
28//--------------------------------------------------------------------------------------------------
29
30/// All shared state between the three threads:
31///
32/// - **NetWorker** (libkrun) — pushes guest frames to `tx_ring`, pops
33///   response frames from `rx_ring`.
34/// - **smoltcp poll thread** — pops from `tx_ring`, processes through smoltcp,
35///   pushes responses to `rx_ring`.
36/// - **tokio proxy tasks** — relay data between smoltcp sockets and real
37///   network connections.
38///
39/// Queue naming follows the **guest's perspective** (matching libkrun's
40/// convention): `tx_ring` = "transmit from guest", `rx_ring` = "receive at
41/// guest".
42pub struct SharedState {
43    /// Frames from guest → smoltcp (NetWorker writes, smoltcp reads).
44    pub tx_ring: ArrayQueue<Vec<u8>>,
45
46    /// Frames from smoltcp → guest (smoltcp writes, NetWorker reads).
47    pub rx_ring: ArrayQueue<Vec<u8>>,
48
49    /// Wakes NetWorker: "rx_ring has frames for the guest."
50    /// Written by `SmoltcpDevice::transmit()`. Read end polled by NetWorker's
51    /// epoll loop.
52    pub rx_wake: WakePipe,
53
54    /// Wakes smoltcp poll thread: "tx_ring has frames from the guest."
55    /// Written by `SmoltcpBackend::write_frame()`. Read end polled by the
56    /// poll loop.
57    pub tx_wake: WakePipe,
58
59    /// Wakes smoltcp poll thread: "proxy task has data to write to a smoltcp
60    /// socket." Written by proxy tasks via channels. Read end polled by the
61    /// poll loop.
62    pub proxy_wake: WakePipe,
63
64    /// Optional host-side termination hook used for fatal policy violations.
65    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
66
67    /// Resolved hostname index used to map destination IPs back to queried hostnames.
68    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
69
70    /// Per-sandbox gateway IPv4. Set once at boot; used by
71    /// `DestinationGroup::Host` rule matching and `host.microsandbox.internal`
72    /// DNS synthesis. `None` in isolated unit tests.
73    gateway_ipv4: OnceLock<Ipv4Addr>,
74
75    /// Per-sandbox gateway IPv6. Set once at boot. See `gateway_ipv4`.
76    gateway_ipv6: OnceLock<Ipv6Addr>,
77
78    /// Aggregate network byte counters at the guest/runtime boundary.
79    metrics: NetworkMetrics,
80}
81
/// Aggregate network byte counters shared with the runtime metrics sampler.
///
/// Derives `Debug`, which prints the current counter values, so the struct can
/// be included in logs and diagnostics.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Total bytes transmitted by the guest into the runtime.
    tx_bytes: AtomicU64,
    /// Total bytes delivered by the runtime to the guest.
    rx_bytes: AtomicU64,
}
87
/// Address family that a resolved-hostname cache entry belongs to.
///
/// Variant order is significant: it defines the derived `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ResolvedHostnameFamily {
    /// IPv4 addresses (`A` responses).
    Ipv4,
    /// IPv6 addresses (`AAAA` responses).
    Ipv6,
}
94
95/// Composite cache key for a single DNS resolution.
96///
97/// `family` partitions entries so that `A` and `AAAA` responses for the
98/// same hostname refresh independently instead of overwriting each other.
99#[derive(Clone, Debug, PartialEq, Eq, Hash)]
100struct ResolvedHostnameKey {
101    hostname: String,
102    family: ResolvedHostnameFamily,
103}
104
105//--------------------------------------------------------------------------------------------------
106// Methods
107//--------------------------------------------------------------------------------------------------
108
109impl SharedState {
110    /// Create shared state with the given queue capacity.
111    pub fn new(queue_capacity: usize) -> Self {
112        Self {
113            tx_ring: ArrayQueue::new(queue_capacity),
114            rx_ring: ArrayQueue::new(queue_capacity),
115            rx_wake: WakePipe::new(),
116            tx_wake: WakePipe::new(),
117            proxy_wake: WakePipe::new(),
118            termination_hook: Mutex::new(None),
119            resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120            gateway_ipv4: OnceLock::new(),
121            gateway_ipv6: OnceLock::new(),
122            metrics: NetworkMetrics::default(),
123        }
124    }
125
126    /// Set the per-sandbox gateway IPs. Called once at boot.
127    pub fn set_gateway_ips(&self, ipv4: Ipv4Addr, ipv6: Ipv6Addr) {
128        let _ = self.gateway_ipv4.set(ipv4);
129        let _ = self.gateway_ipv6.set(ipv6);
130    }
131
132    /// Gateway IPv4 address, if set.
133    pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
134        self.gateway_ipv4.get().copied()
135    }
136
137    /// Gateway IPv6 address, if set.
138    pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
139        self.gateway_ipv6.get().copied()
140    }
141
142    /// Install a host-side termination hook.
143    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
144        *self.termination_hook.lock().unwrap() = Some(hook);
145    }
146
147    /// Trigger host-side termination if a hook is installed.
148    pub fn trigger_termination(&self) {
149        let hook = self.termination_hook.lock().unwrap().clone();
150        if let Some(hook) = hook {
151            hook();
152        }
153    }
154
155    /// Replace the resolved addresses for a hostname within the given address family.
156    pub fn cache_resolved_hostname(
157        &self,
158        domain: &str,
159        family: ResolvedHostnameFamily,
160        addrs: impl IntoIterator<Item = IpAddr>,
161        ttl: Duration,
162    ) {
163        let hostname = normalize_hostname(domain);
164        let key = ResolvedHostnameKey { hostname, family };
165        self.resolved_hostnames
166            .write()
167            .insert(key, addrs, ttl, Instant::now());
168    }
169
170    /// Clear the resolved addresses for a hostname within the given address family.
171    pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
172        let hostname = normalize_hostname(domain);
173        let key = ResolvedHostnameKey { hostname, family };
174        self.resolved_hostnames.write().remove(&key, Instant::now());
175    }
176
177    /// Returns `true` when any resolved hostname for `addr` satisfies `predicate`.
178    pub fn any_resolved_hostname(
179        &self,
180        addr: IpAddr,
181        mut predicate: impl FnMut(&str) -> bool,
182    ) -> bool {
183        self.resolved_hostnames
184            .read()
185            .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
186    }
187
188    /// Best-effort expiry maintenance for resolved hostnames.
189    ///
190    /// This runs outside the hot egress read path. If the index is currently
191    /// busy, cleanup is skipped and retried on the next maintenance pass.
192    pub fn cleanup_resolved_hostnames(&self) {
193        if let Some(mut idx) = self.resolved_hostnames.try_write() {
194            idx.evict_expired(Instant::now());
195        }
196    }
197
198    /// Increment the guest -> runtime byte counter.
199    pub fn add_tx_bytes(&self, bytes: usize) {
200        self.metrics
201            .tx_bytes
202            .fetch_add(bytes as u64, Ordering::Relaxed);
203    }
204
205    /// Increment the runtime -> guest byte counter.
206    pub fn add_rx_bytes(&self, bytes: usize) {
207        self.metrics
208            .rx_bytes
209            .fetch_add(bytes as u64, Ordering::Relaxed);
210    }
211
212    /// Total bytes transmitted by the guest into the runtime.
213    pub fn tx_bytes(&self) -> u64 {
214        self.metrics.tx_bytes.load(Ordering::Relaxed)
215    }
216
217    /// Total bytes delivered by the runtime to the guest.
218    pub fn rx_bytes(&self) -> u64 {
219        self.metrics.rx_bytes.load(Ordering::Relaxed)
220    }
221}
222
223impl Default for NetworkMetrics {
224    fn default() -> Self {
225        Self {
226            tx_bytes: AtomicU64::new(0),
227            rx_bytes: AtomicU64::new(0),
228        }
229    }
230}
231
/// Canonicalize a DNS name for use as a resolved-hostname cache key.
///
/// ASCII letters are lowercased (non-ASCII characters are left as-is), and
/// every trailing `.` is removed, so `"Example.COM."` and `"example.com"`
/// produce the same key.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let mut canonical = domain.to_ascii_lowercase();
    // Lowercasing never adds or removes '.', so trimming afterwards is equivalent,
    // and truncating to the trimmed length always lands on a char boundary.
    let kept = canonical.trim_end_matches('.').len();
    canonical.truncate(kept);
    canonical
}
235
236//--------------------------------------------------------------------------------------------------
237// Tests
238//--------------------------------------------------------------------------------------------------
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        // Push frames to tx_ring.
        state.tx_ring.push(vec![1, 2, 3]).unwrap();
        state.tx_ring.push(vec![4, 5, 6]).unwrap();

        // Pop in FIFO order.
        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert_eq!(state.tx_ring.pop(), None);
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        state.rx_ring.push(vec![1]).unwrap();
        state.rx_ring.push(vec![2]).unwrap();
        // Queue is full: push hands the rejected frame back to the caller.
        assert_eq!(state.rx_ring.push(vec![3]), Err(vec![3]));
        // Frames already queued are unaffected by the rejected push.
        assert_eq!(state.rx_ring.pop(), Some(vec![1]));
        assert_eq!(state.rx_ring.pop(), Some(vec![2]));
    }

    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        assert!(state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }

    #[test]
    fn gateway_ips_are_write_once() {
        let state = SharedState::new(1);
        assert_eq!(state.gateway_ipv4(), None);
        assert_eq!(state.gateway_ipv6(), None);

        let v4 = Ipv4Addr::new(10, 0, 0, 1);
        let v6 = Ipv6Addr::LOCALHOST;
        state.set_gateway_ips(v4, v6);
        // A second call must not overwrite the boot-time values.
        state.set_gateway_ips(Ipv4Addr::new(10, 0, 0, 2), Ipv6Addr::UNSPECIFIED);

        assert_eq!(state.gateway_ipv4(), Some(v4));
        assert_eq!(state.gateway_ipv6(), Some(v6));
    }

    #[test]
    fn termination_hook_runs_only_when_installed() {
        let state = SharedState::new(1);
        // No hook installed yet: triggering is a silent no-op.
        state.trigger_termination();

        let fired = Arc::new(AtomicU64::new(0));
        let first = Arc::clone(&fired);
        state.set_termination_hook(Arc::new(move || {
            first.fetch_add(1, Ordering::SeqCst);
        }));
        state.trigger_termination();
        assert_eq!(fired.load(Ordering::SeqCst), 1);

        // Installing another hook replaces the previous one.
        let second = Arc::clone(&fired);
        state.set_termination_hook(Arc::new(move || {
            second.fetch_add(10, Ordering::SeqCst);
        }));
        state.trigger_termination();
        assert_eq!(fired.load(Ordering::SeqCst), 11);
    }

    #[test]
    fn byte_counters_accumulate_independently() {
        let state = SharedState::new(1);
        assert_eq!(state.tx_bytes(), 0);
        assert_eq!(state.rx_bytes(), 0);

        state.add_tx_bytes(100);
        state.add_tx_bytes(20);
        state.add_rx_bytes(7);

        assert_eq!(state.tx_bytes(), 120);
        assert_eq!(state.rx_bytes(), 7);
    }

    #[test]
    fn normalize_hostname_lowercases_and_strips_trailing_dots() {
        assert_eq!(normalize_hostname("Example.COM."), "example.com");
        assert_eq!(normalize_hostname("example.com"), "example.com");
        assert_eq!(normalize_hostname(""), "");
    }
}