1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9 Arc, Mutex, OnceLock,
10 atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
// NOTE(review): not referenced in this chunk; presumably the capacity callers
// pass to `SharedState::new` — confirm at the call sites.
/// Default queue capacity: 1024 entries (2^10).
pub const DEFAULT_QUEUE_CAPACITY: usize = 1 << 10;
25
26pub struct SharedState {
43 pub tx_ring: ArrayQueue<Vec<u8>>,
45
46 pub rx_ring: ArrayQueue<Vec<u8>>,
48
49 pub rx_wake: WakePipe,
53
54 pub tx_wake: WakePipe,
58
59 pub proxy_wake: WakePipe,
63
64 termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
66
67 resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
69
70 gateway_ipv4: OnceLock<Ipv4Addr>,
74
75 gateway_ipv6: OnceLock<Ipv6Addr>,
77
78 metrics: NetworkMetrics,
80}
81
/// Running byte totals for the two traffic directions.
///
/// The counters are atomics so they can be bumped through a shared reference;
/// `SharedState` exposes them via its `add_*_bytes` / `*_bytes` methods.
/// `Debug` is derived so the public type can be logged and inspected.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Total bytes recorded for the `tx` direction.
    tx_bytes: AtomicU64,
    /// Total bytes recorded for the `rx` direction.
    rx_bytes: AtomicU64,
}
87
/// Address family under which a hostname resolution is cached.
///
/// The family is part of the cache key, so IPv4 and IPv6 results for the same
/// hostname are stored under separate entries.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ResolvedHostnameFamily {
    /// IPv4 addresses.
    Ipv4,
    /// IPv6 addresses.
    Ipv6,
}
94
95#[derive(Clone, Debug, PartialEq, Eq, Hash)]
100struct ResolvedHostnameKey {
101 hostname: String,
102 family: ResolvedHostnameFamily,
103}
104
105impl SharedState {
110 pub fn new(queue_capacity: usize) -> Self {
112 Self {
113 tx_ring: ArrayQueue::new(queue_capacity),
114 rx_ring: ArrayQueue::new(queue_capacity),
115 rx_wake: WakePipe::new(),
116 tx_wake: WakePipe::new(),
117 proxy_wake: WakePipe::new(),
118 termination_hook: Mutex::new(None),
119 resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120 gateway_ipv4: OnceLock::new(),
121 gateway_ipv6: OnceLock::new(),
122 metrics: NetworkMetrics::default(),
123 }
124 }
125
126 pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
129 if let Some(ipv4) = ipv4 {
130 let _ = self.gateway_ipv4.set(ipv4);
131 }
132 if let Some(ipv6) = ipv6 {
133 let _ = self.gateway_ipv6.set(ipv6);
134 }
135 }
136
137 pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
139 self.gateway_ipv4.get().copied()
140 }
141
142 pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
144 self.gateway_ipv6.get().copied()
145 }
146
147 pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
149 *self.termination_hook.lock().unwrap() = Some(hook);
150 }
151
152 pub fn trigger_termination(&self) {
154 let hook = self.termination_hook.lock().unwrap().clone();
155 if let Some(hook) = hook {
156 hook();
157 }
158 }
159
160 pub fn cache_resolved_hostname(
162 &self,
163 domain: &str,
164 family: ResolvedHostnameFamily,
165 addrs: impl IntoIterator<Item = IpAddr>,
166 ttl: Duration,
167 ) {
168 let hostname = normalize_hostname(domain);
169 let key = ResolvedHostnameKey { hostname, family };
170 self.resolved_hostnames
171 .write()
172 .insert(key, addrs, ttl, Instant::now());
173 }
174
175 pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
177 let hostname = normalize_hostname(domain);
178 let key = ResolvedHostnameKey { hostname, family };
179 self.resolved_hostnames.write().remove(&key, Instant::now());
180 }
181
182 pub fn any_resolved_hostname(
184 &self,
185 addr: IpAddr,
186 mut predicate: impl FnMut(&str) -> bool,
187 ) -> bool {
188 self.resolved_hostnames
189 .read()
190 .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
191 }
192
193 pub fn cleanup_resolved_hostnames(&self) {
198 if let Some(mut idx) = self.resolved_hostnames.try_write() {
199 idx.evict_expired(Instant::now());
200 }
201 }
202
203 pub fn add_tx_bytes(&self, bytes: usize) {
205 self.metrics
206 .tx_bytes
207 .fetch_add(bytes as u64, Ordering::Relaxed);
208 }
209
210 pub fn add_rx_bytes(&self, bytes: usize) {
212 self.metrics
213 .rx_bytes
214 .fetch_add(bytes as u64, Ordering::Relaxed);
215 }
216
217 pub fn tx_bytes(&self) -> u64 {
219 self.metrics.tx_bytes.load(Ordering::Relaxed)
220 }
221
222 pub fn rx_bytes(&self) -> u64 {
224 self.metrics.rx_bytes.load(Ordering::Relaxed)
225 }
226}
227
228impl Default for NetworkMetrics {
229 fn default() -> Self {
230 Self {
231 tx_bytes: AtomicU64::new(0),
232 rx_bytes: AtomicU64::new(0),
233 }
234 }
235}
236
/// Canonicalizes a hostname for use as a cache key.
///
/// Removes every trailing `.` (so `example.com.` and `example.com` map to the
/// same key) and lowercases ASCII letters; non-ASCII characters are left as-is.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let without_trailing_dots = domain.trim_end_matches('.');
    without_trailing_dots
        .chars()
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
240
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        state.tx_ring.push(vec![1, 2, 3]).unwrap();
        state.tx_ring.push(vec![4, 5, 6]).unwrap();

        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert_eq!(state.tx_ring.pop(), None);
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        state.rx_ring.push(vec![1]).unwrap();
        state.rx_ring.push(vec![2]).unwrap();
        assert!(state.rx_ring.push(vec![3]).is_err());
    }

    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        assert!(state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }

    #[test]
    fn normalize_hostname_strips_trailing_dots_and_lowercases() {
        assert_eq!(normalize_hostname("Example.COM."), "example.com");
        assert_eq!(normalize_hostname("a.b.."), "a.b");
        assert_eq!(normalize_hostname(""), "");
    }

    #[test]
    fn gateway_ips_are_first_write_wins() {
        let state = SharedState::new(1);
        assert_eq!(state.gateway_ipv4(), None);
        assert_eq!(state.gateway_ipv6(), None);

        let first_v4 = Ipv4Addr::new(10, 0, 0, 1);
        state.set_gateway_ips(Some(first_v4), None);
        assert_eq!(state.gateway_ipv4(), Some(first_v4));
        assert_eq!(state.gateway_ipv6(), None);

        // A second IPv4 value is ignored; the still-unset IPv6 slot is filled.
        state.set_gateway_ips(Some(Ipv4Addr::new(10, 0, 0, 2)), Some(Ipv6Addr::LOCALHOST));
        assert_eq!(state.gateway_ipv4(), Some(first_v4));
        assert_eq!(state.gateway_ipv6(), Some(Ipv6Addr::LOCALHOST));
    }

    #[test]
    fn termination_hook_is_optional_and_replaceable() {
        let state = SharedState::new(1);
        // No hook installed: triggering is a no-op.
        state.trigger_termination();

        let first = Arc::new(AtomicUsize::new(0));
        let second = Arc::new(AtomicUsize::new(0));

        let first_counter = Arc::clone(&first);
        state.set_termination_hook(Arc::new(move || {
            first_counter.fetch_add(1, Ordering::SeqCst);
        }));
        let second_counter = Arc::clone(&second);
        state.set_termination_hook(Arc::new(move || {
            second_counter.fetch_add(1, Ordering::SeqCst);
        }));

        state.trigger_termination();
        assert_eq!(first.load(Ordering::SeqCst), 0);
        assert_eq!(second.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn byte_counters_accumulate() {
        let state = SharedState::new(1);
        assert_eq!(state.tx_bytes(), 0);
        assert_eq!(state.rx_bytes(), 0);

        state.add_tx_bytes(10);
        state.add_tx_bytes(5);
        state.add_rx_bytes(7);

        assert_eq!(state.tx_bytes(), 15);
        assert_eq!(state.rx_bytes(), 7);
    }
}