Skip to main content

microsandbox_network/
shared.rs

//! Shared state between the NetWorker thread, smoltcp poll thread, and tokio
//! proxy tasks.
//!
//! All inter-thread communication flows through [`SharedState`], which holds
//! lock-free frame queues and cross-platform [`WakePipe`] notifications.
6
7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9    Arc, Mutex, OnceLock,
10    atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
19//--------------------------------------------------------------------------------------------------
20// Constants
21//--------------------------------------------------------------------------------------------------
22
/// Default frame queue capacity. Matches libkrun's virtio queue size.
///
/// The unit is frames, not bytes. Presumably this is the value callers pass
/// as `queue_capacity` to [`SharedState::new`], which sizes both rings with
/// it — no caller is visible in this file.
///
/// NOTE(review): the libkrun queue size is not verifiable from this file;
/// re-check this value if libkrun's virtio queue size changes.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1024;
25
26//--------------------------------------------------------------------------------------------------
27// Types
28//--------------------------------------------------------------------------------------------------
29
/// All state shared between the three execution contexts of the network
/// stack:
///
/// - **NetWorker** (libkrun) — pushes guest frames to `tx_ring`, pops
///   response frames from `rx_ring`.
/// - **smoltcp poll thread** — pops from `tx_ring`, processes through smoltcp,
///   pushes responses to `rx_ring`.
/// - **tokio proxy tasks** — relay data between smoltcp sockets and real
///   network connections.
///
/// Queue naming follows the **guest's perspective** (matching libkrun's
/// convention): `tx_ring` = "transmit from guest", `rx_ring` = "receive at
/// guest".
///
/// The rings and wake pipes are public fields that callers use directly; the
/// remaining fields are private and reached only through the methods on this
/// type, all of which take `&self`.
pub struct SharedState {
    /// Frames from guest → smoltcp (NetWorker writes, smoltcp reads).
    pub tx_ring: ArrayQueue<Vec<u8>>,

    /// Frames from smoltcp → guest (smoltcp writes, NetWorker reads).
    pub rx_ring: ArrayQueue<Vec<u8>>,

    /// Wakes NetWorker: "rx_ring has frames for the guest."
    /// Written by `SmoltcpDevice::transmit()`. Read end polled by NetWorker's
    /// epoll loop.
    pub rx_wake: WakePipe,

    /// Wakes smoltcp poll thread: "tx_ring has frames from the guest."
    /// Written by `SmoltcpBackend::write_frame()`. Read end polled by the
    /// poll loop.
    pub tx_wake: WakePipe,

    /// Wakes smoltcp poll thread: "proxy task has data to write to a smoltcp
    /// socket." Written by proxy tasks via channels. Read end polled by the
    /// poll loop.
    pub proxy_wake: WakePipe,

    /// Optional host-side termination hook used for fatal policy violations.
    ///
    /// `None` until `set_termination_hook` installs one; invoked by
    /// `trigger_termination`. Kept behind a mutex so it can be installed
    /// through a shared reference after construction.
    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,

    /// Resolved hostname index used to map destination IPs back to queried hostnames.
    ///
    /// Keyed by (normalized hostname, address family); the members are the
    /// addresses cached for that key.
    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,

    /// Per-sandbox gateway IPv4. Set once at boot; used by
    /// `DestinationGroup::Host` rule matching and `host.microsandbox.internal`
    /// DNS synthesis. `None` in isolated unit tests.
    ///
    /// Stored in a `OnceLock`: the first value written stays for the lifetime
    /// of this state.
    gateway_ipv4: OnceLock<Ipv4Addr>,

    /// Per-sandbox gateway IPv6. Set once at boot. See `gateway_ipv4`.
    gateway_ipv6: OnceLock<Ipv6Addr>,

    /// Aggregate network byte counters at the guest/runtime boundary.
    ///
    /// Updated via `add_tx_bytes` / `add_rx_bytes` and read via `tx_bytes` /
    /// `rx_bytes`.
    metrics: NetworkMetrics,
}
81
/// Aggregate network byte counters shared with the runtime metrics sampler.
///
/// Both counters are atomics, so they can be updated through a shared
/// reference without taking a lock. `Debug` is derived so this public type
/// can appear in diagnostics and in `Debug` output of types that embed it.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Running total of bytes sent by the guest into the runtime.
    tx_bytes: AtomicU64,
    /// Running total of bytes delivered by the runtime to the guest.
    rx_bytes: AtomicU64,
}
87
/// Address family for resolved hostname entries.
///
/// Passed to `SharedState::cache_resolved_hostname` and
/// `SharedState::clear_resolved_hostname` to select which family's entry for
/// a hostname is replaced or cleared. Nothing in this file checks that the
/// cached addresses actually belong to the stated family; callers are
/// responsible for passing a matching pair.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ResolvedHostnameFamily {
    /// Entries holding IPv4 addresses (DNS `A` responses).
    Ipv4,
    /// Entries holding IPv6 addresses (DNS `AAAA` responses).
    Ipv6,
}
94
/// Composite cache key for a single DNS resolution.
///
/// `family` partitions entries so that `A` and `AAAA` responses for the
/// same hostname refresh independently instead of overwriting each other.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ResolvedHostnameKey {
    /// Hostname as produced by `normalize_hostname` (trailing dots removed,
    /// ASCII letters lowercased), so differently-spelled queries share a key.
    hostname: String,
    /// Address family this entry covers.
    family: ResolvedHostnameFamily,
}
104
105//--------------------------------------------------------------------------------------------------
106// Methods
107//--------------------------------------------------------------------------------------------------
108
109impl SharedState {
110    /// Create shared state with the given queue capacity.
111    pub fn new(queue_capacity: usize) -> Self {
112        Self {
113            tx_ring: ArrayQueue::new(queue_capacity),
114            rx_ring: ArrayQueue::new(queue_capacity),
115            rx_wake: WakePipe::new(),
116            tx_wake: WakePipe::new(),
117            proxy_wake: WakePipe::new(),
118            termination_hook: Mutex::new(None),
119            resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120            gateway_ipv4: OnceLock::new(),
121            gateway_ipv6: OnceLock::new(),
122            metrics: NetworkMetrics::default(),
123        }
124    }
125
126    /// Set the per-sandbox gateway IPs. Called once at boot. Each family is
127    /// only published when active for this sandbox.
128    pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
129        if let Some(ipv4) = ipv4 {
130            let _ = self.gateway_ipv4.set(ipv4);
131        }
132        if let Some(ipv6) = ipv6 {
133            let _ = self.gateway_ipv6.set(ipv6);
134        }
135    }
136
137    /// Gateway IPv4 address, if set.
138    pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
139        self.gateway_ipv4.get().copied()
140    }
141
142    /// Gateway IPv6 address, if set.
143    pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
144        self.gateway_ipv6.get().copied()
145    }
146
147    /// Install a host-side termination hook.
148    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
149        *self.termination_hook.lock().unwrap() = Some(hook);
150    }
151
152    /// Trigger host-side termination if a hook is installed.
153    pub fn trigger_termination(&self) {
154        let hook = self.termination_hook.lock().unwrap().clone();
155        if let Some(hook) = hook {
156            hook();
157        }
158    }
159
160    /// Replace the resolved addresses for a hostname within the given address family.
161    pub fn cache_resolved_hostname(
162        &self,
163        domain: &str,
164        family: ResolvedHostnameFamily,
165        addrs: impl IntoIterator<Item = IpAddr>,
166        ttl: Duration,
167    ) {
168        let hostname = normalize_hostname(domain);
169        let key = ResolvedHostnameKey { hostname, family };
170        self.resolved_hostnames
171            .write()
172            .insert(key, addrs, ttl, Instant::now());
173    }
174
175    /// Clear the resolved addresses for a hostname within the given address family.
176    pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
177        let hostname = normalize_hostname(domain);
178        let key = ResolvedHostnameKey { hostname, family };
179        self.resolved_hostnames.write().remove(&key, Instant::now());
180    }
181
182    /// Returns `true` when any resolved hostname for `addr` satisfies `predicate`.
183    pub fn any_resolved_hostname(
184        &self,
185        addr: IpAddr,
186        mut predicate: impl FnMut(&str) -> bool,
187    ) -> bool {
188        self.resolved_hostnames
189            .read()
190            .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
191    }
192
193    /// Best-effort expiry maintenance for resolved hostnames.
194    ///
195    /// This runs outside the hot egress read path. If the index is currently
196    /// busy, cleanup is skipped and retried on the next maintenance pass.
197    pub fn cleanup_resolved_hostnames(&self) {
198        if let Some(mut idx) = self.resolved_hostnames.try_write() {
199            idx.evict_expired(Instant::now());
200        }
201    }
202
203    /// Increment the guest -> runtime byte counter.
204    pub fn add_tx_bytes(&self, bytes: usize) {
205        self.metrics
206            .tx_bytes
207            .fetch_add(bytes as u64, Ordering::Relaxed);
208    }
209
210    /// Increment the runtime -> guest byte counter.
211    pub fn add_rx_bytes(&self, bytes: usize) {
212        self.metrics
213            .rx_bytes
214            .fetch_add(bytes as u64, Ordering::Relaxed);
215    }
216
217    /// Total bytes transmitted by the guest into the runtime.
218    pub fn tx_bytes(&self) -> u64 {
219        self.metrics.tx_bytes.load(Ordering::Relaxed)
220    }
221
222    /// Total bytes delivered by the runtime to the guest.
223    pub fn rx_bytes(&self) -> u64 {
224        self.metrics.rx_bytes.load(Ordering::Relaxed)
225    }
226}
227
228impl Default for NetworkMetrics {
229    fn default() -> Self {
230        Self {
231            tx_bytes: AtomicU64::new(0),
232            rx_bytes: AtomicU64::new(0),
233        }
234    }
235}
236
/// Canonicalize a DNS name for use as a cache key.
///
/// Removes every trailing `.` (so a fully-qualified name and its relative
/// form compare equal) and lowercases ASCII letters. Non-ASCII characters
/// are left unchanged.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let mut normalized = domain.trim_end_matches('.').to_owned();
    normalized.make_ascii_lowercase();
    normalized
}
240
241//--------------------------------------------------------------------------------------------------
242// Tests
243//--------------------------------------------------------------------------------------------------
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames pushed to a ring come back out in FIFO order.
    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        // Push frames to tx_ring.
        state.tx_ring.push(vec![1, 2, 3]).unwrap();
        state.tx_ring.push(vec![4, 5, 6]).unwrap();

        // Pop in FIFO order.
        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert_eq!(state.tx_ring.pop(), None);
    }

    /// A ring rejects pushes beyond the capacity given to `SharedState::new`.
    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        state.rx_ring.push(vec![1]).unwrap();
        state.rx_ring.push(vec![2]).unwrap();
        // Queue is full — push returns the frame back.
        assert!(state.rx_ring.push(vec![3]).is_err());
    }

    /// Caching the IPv6 entry for a hostname must not disturb its IPv4 entry.
    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        assert!(state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }

    /// Hostnames are keyed without trailing dots and in ASCII lowercase.
    #[test]
    fn normalize_hostname_strips_dots_and_lowercases() {
        assert_eq!(normalize_hostname("Example.COM."), "example.com");
        assert_eq!(normalize_hostname("a.b.."), "a.b");
        assert_eq!(normalize_hostname(""), "");
    }

    /// The tx and rx counters start at zero and accumulate independently.
    #[test]
    fn byte_counters_accumulate_independently() {
        let state = SharedState::new(4);
        assert_eq!(state.tx_bytes(), 0);
        assert_eq!(state.rx_bytes(), 0);

        state.add_tx_bytes(3);
        state.add_tx_bytes(4);
        state.add_rx_bytes(10);

        assert_eq!(state.tx_bytes(), 7);
        assert_eq!(state.rx_bytes(), 10);
    }

    /// Each gateway family is write-once; `None` leaves a family untouched.
    #[test]
    fn gateway_ips_are_write_once_per_family() {
        let state = SharedState::new(4);
        assert_eq!(state.gateway_ipv4(), None);
        assert_eq!(state.gateway_ipv6(), None);

        let first_v4 = Ipv4Addr::new(192, 168, 1, 1);
        state.set_gateway_ips(Some(first_v4), None);
        assert_eq!(state.gateway_ipv4(), Some(first_v4));
        assert_eq!(state.gateway_ipv6(), None);

        // A second IPv4 value is ignored; the still-unset IPv6 is accepted.
        let v6 = Ipv6Addr::LOCALHOST;
        state.set_gateway_ips(Some(Ipv4Addr::new(10, 0, 0, 1)), Some(v6));
        assert_eq!(state.gateway_ipv4(), Some(first_v4));
        assert_eq!(state.gateway_ipv6(), Some(v6));
    }

    /// Triggering without a hook is a no-op; with a hook, every trigger runs it.
    #[test]
    fn termination_hook_runs_on_each_trigger() {
        let state = SharedState::new(4);

        // No hook installed yet: must not panic.
        state.trigger_termination();

        let calls = Arc::new(AtomicU64::new(0));
        let counter = Arc::clone(&calls);
        state.set_termination_hook(Arc::new(move || {
            counter.fetch_add(1, Ordering::Relaxed);
        }));

        state.trigger_termination();
        state.trigger_termination();
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }
}
296}