1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9 Arc, Mutex, OnceLock,
10 atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
/// Ready-made `queue_capacity` for [`SharedState::new`]: 1024 frames per ring.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1 << 10;
25
/// State bundle holding the TX/RX frame rings and their wake pipes, an
/// optional termination callback, a TTL-bounded index of resolved hostnames,
/// the write-once gateway addresses and cumulative byte counters.
pub struct SharedState {
    /// Bounded TX frame queue; its capacity is the `queue_capacity` passed to
    /// [`SharedState::new`].
    pub tx_ring: ArrayQueue<Vec<u8>>,

    /// Bounded RX frame queue filled by `push_rx_frame` /
    /// `push_rx_frame_and_wake`; same capacity as `tx_ring`.
    pub rx_ring: ArrayQueue<Vec<u8>>,

    /// Wake pipe signalled by `push_rx_frame_and_wake` after a frame has been
    /// accepted into `rx_ring`.
    pub rx_wake: WakePipe,

    /// Wake pipe for the TX side. It is not signalled anywhere in this file.
    /// NOTE(review): confirm which components write to and wait on it.
    pub tx_wake: WakePipe,

    /// Wake pipe for the proxy side. It is not signalled anywhere in this file.
    /// NOTE(review): confirm which components write to and wait on it.
    pub proxy_wake: WakePipe,

    /// Callback run by [`SharedState::trigger_termination`]; installed or
    /// replaced by [`SharedState::set_termination_hook`].
    termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,

    /// Index associating (normalized hostname, family) keys with the IP
    /// addresses they resolved to, each insertion carrying its own TTL; it is
    /// queried by address in [`SharedState::any_resolved_hostname`].
    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,

    /// IPv4 gateway address; set at most once via
    /// [`SharedState::set_gateway_ips`].
    gateway_ipv4: OnceLock<Ipv4Addr>,

    /// IPv6 gateway address; set at most once via
    /// [`SharedState::set_gateway_ips`].
    gateway_ipv6: OnceLock<Ipv6Addr>,

    /// Cumulative TX/RX byte counters.
    metrics: NetworkMetrics,
}
81
/// Cumulative byte counters, updated with relaxed atomic adds through
/// `SharedState::add_tx_bytes` / `SharedState::add_rx_bytes`.
pub struct NetworkMetrics {
    /// Total bytes recorded for the TX direction.
    tx_bytes: AtomicU64,
    /// Total bytes recorded for the RX direction.
    rx_bytes: AtomicU64,
}
87
/// Address family of a cached hostname resolution. It is part of the cache
/// key, so IPv4 and IPv6 results for the same hostname are stored and cleared
/// independently.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ResolvedHostnameFamily {
    /// IPv4 results.
    Ipv4,
    /// IPv6 results.
    Ipv6,
}
94
/// Key of the resolved-hostname index: a hostname plus its address family.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ResolvedHostnameKey {
    /// Hostname as produced by `normalize_hostname` (trailing dots removed,
    /// ASCII-lowercased).
    hostname: String,
    /// Address family the cached addresses belong to.
    family: ResolvedHostnameFamily,
}
104
105impl SharedState {
110 pub fn new(queue_capacity: usize) -> Self {
112 Self {
113 tx_ring: ArrayQueue::new(queue_capacity),
114 rx_ring: ArrayQueue::new(queue_capacity),
115 rx_wake: WakePipe::new(),
116 tx_wake: WakePipe::new(),
117 proxy_wake: WakePipe::new(),
118 termination_hook: Mutex::new(None),
119 resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120 gateway_ipv4: OnceLock::new(),
121 gateway_ipv6: OnceLock::new(),
122 metrics: NetworkMetrics::default(),
123 }
124 }
125
126 pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
129 if let Some(ipv4) = ipv4 {
130 let _ = self.gateway_ipv4.set(ipv4);
131 }
132 if let Some(ipv6) = ipv6 {
133 let _ = self.gateway_ipv6.set(ipv6);
134 }
135 }
136
137 pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
139 self.gateway_ipv4.get().copied()
140 }
141
142 pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
144 self.gateway_ipv6.get().copied()
145 }
146
147 pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
149 *self.termination_hook.lock().unwrap() = Some(hook);
150 }
151
152 pub fn trigger_termination(&self) {
154 let hook = self.termination_hook.lock().unwrap().clone();
155 if let Some(hook) = hook {
156 hook();
157 }
158 }
159
160 pub fn cache_resolved_hostname(
162 &self,
163 domain: &str,
164 family: ResolvedHostnameFamily,
165 addrs: impl IntoIterator<Item = IpAddr>,
166 ttl: Duration,
167 ) {
168 let hostname = normalize_hostname(domain);
169 let key = ResolvedHostnameKey { hostname, family };
170 self.resolved_hostnames
171 .write()
172 .insert(key, addrs, ttl, Instant::now());
173 }
174
175 pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
177 let hostname = normalize_hostname(domain);
178 let key = ResolvedHostnameKey { hostname, family };
179 self.resolved_hostnames.write().remove(&key, Instant::now());
180 }
181
182 pub fn any_resolved_hostname(
184 &self,
185 addr: IpAddr,
186 mut predicate: impl FnMut(&str) -> bool,
187 ) -> bool {
188 self.resolved_hostnames
189 .read()
190 .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
191 }
192
193 pub fn cleanup_resolved_hostnames(&self) {
198 if let Some(mut idx) = self.resolved_hostnames.try_write() {
199 idx.evict_expired(Instant::now());
200 }
201 }
202
203 pub fn add_tx_bytes(&self, bytes: usize) {
205 self.metrics
206 .tx_bytes
207 .fetch_add(bytes as u64, Ordering::Relaxed);
208 }
209
210 pub fn add_rx_bytes(&self, bytes: usize) {
212 self.metrics
213 .rx_bytes
214 .fetch_add(bytes as u64, Ordering::Relaxed);
215 }
216
217 pub(crate) fn push_rx_frame(&self, frame: Vec<u8>) -> bool {
219 let frame_len = frame.len();
220 if self.rx_ring.push(frame).is_err() {
221 return false;
222 }
223
224 self.add_rx_bytes(frame_len);
225 true
226 }
227
228 pub(crate) fn push_rx_frame_and_wake(&self, frame: Vec<u8>) -> bool {
230 if !self.push_rx_frame(frame) {
231 return false;
232 }
233
234 self.rx_wake.wake();
235 true
236 }
237
238 pub fn tx_bytes(&self) -> u64 {
240 self.metrics.tx_bytes.load(Ordering::Relaxed)
241 }
242
243 pub fn rx_bytes(&self) -> u64 {
245 self.metrics.rx_bytes.load(Ordering::Relaxed)
246 }
247}
248
249impl Default for NetworkMetrics {
250 fn default() -> Self {
251 Self {
252 tx_bytes: AtomicU64::new(0),
253 rx_bytes: AtomicU64::new(0),
254 }
255 }
256}
257
/// Canonicalizes a hostname for cache keys and comparisons.
///
/// Every trailing `.` is stripped, so `"example.com."` and `"example.com"`
/// compare equal. ASCII letters are lowercased. Non-ASCII characters are left
/// unchanged.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let without_root_dots = domain.trim_end_matches('.');
    without_root_dots.to_ascii_lowercase()
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames come out of the TX ring in FIFO order, and the ring is empty afterwards.
    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);
        let frames = [vec![1, 2, 3], vec![4, 5, 6]];

        for frame in &frames {
            state.tx_ring.push(frame.clone()).unwrap();
        }
        for expected in frames {
            assert_eq!(state.tx_ring.pop(), Some(expected));
        }
        assert!(state.tx_ring.pop().is_none());
    }

    /// A push beyond the configured capacity is rejected.
    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        for byte in [1u8, 2] {
            state.rx_ring.push(vec![byte]).unwrap();
        }
        let overflow = state.rx_ring.push(vec![3]);
        assert!(overflow.is_err());
    }

    /// RX bytes are counted only for frames the ring actually accepted.
    #[test]
    fn push_rx_frame_counts_only_successful_pushes() {
        let state = SharedState::new(1);

        let accepted = state.push_rx_frame(vec![1, 2, 3]);
        assert!(accepted);
        assert_eq!(state.rx_bytes(), 3);

        // The ring holds a single frame, so this push is rejected and not counted.
        let accepted = state.push_rx_frame(vec![4, 5]);
        assert!(!accepted);
        assert_eq!(state.rx_bytes(), 3);
    }

    /// IPv4 and IPv6 entries for the same (normalized) hostname coexist.
    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let ttl = Duration::from_secs(30);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname("Example.com.", ResolvedHostnameFamily::Ipv4, [v4], ttl);
        state.cache_resolved_hostname("example.com", ResolvedHostnameFamily::Ipv6, [v6], ttl);

        let is_example = |h: &str| h == "example.com";
        assert!(state.any_resolved_hostname(v4, is_example));
        assert!(state.any_resolved_hostname(v6, is_example));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }
}