1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
8use std::sync::{
9 Arc, Mutex, OnceLock,
10 atomic::{AtomicU64, Ordering},
11};
12use std::time::{Duration, Instant};
13
14use crossbeam_queue::ArrayQueue;
15use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
16pub use microsandbox_utils::wake_pipe::WakePipe;
17use parking_lot::RwLock;
18
/// Default capacity (1024 buffers) for the bounded TX/RX rings; a suitable argument for
/// `SharedState::new` when callers have no specific sizing requirement.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1 << 10;
25
26pub struct SharedState {
43 pub tx_ring: ArrayQueue<Vec<u8>>,
45
46 pub rx_ring: ArrayQueue<Vec<u8>>,
48
49 pub rx_wake: WakePipe,
53
54 pub tx_wake: WakePipe,
58
59 pub proxy_wake: WakePipe,
63
64 termination_hook: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
66
67 resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
69
70 gateway_ipv4: OnceLock<Ipv4Addr>,
74
75 gateway_ipv6: OnceLock<Ipv6Addr>,
77
78 metrics: NetworkMetrics,
80}
81
/// Cumulative byte counters for traffic passing through `SharedState`.
///
/// The counters are updated and read with `Ordering::Relaxed`, so they carry no ordering
/// relationship with each other or with any other memory access.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Total bytes recorded via `SharedState::add_tx_bytes`.
    tx_bytes: AtomicU64,
    /// Total bytes recorded via `SharedState::add_rx_bytes`.
    rx_bytes: AtomicU64,
}
87
/// Address family of a cached hostname resolution.
///
/// Declaration order matters for the derived `PartialOrd`/`Ord`: `Ipv4` sorts before `Ipv6`.
#[derive(Debug, Clone, Copy)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ResolvedHostnameFamily {
    /// Resolutions that produced IPv4 addresses.
    Ipv4,
    /// Resolutions that produced IPv6 addresses.
    Ipv6,
}
94
95#[derive(Clone, Debug, PartialEq, Eq, Hash)]
100struct ResolvedHostnameKey {
101 hostname: String,
102 family: ResolvedHostnameFamily,
103}
104
105impl SharedState {
110 pub fn new(queue_capacity: usize) -> Self {
112 Self {
113 tx_ring: ArrayQueue::new(queue_capacity),
114 rx_ring: ArrayQueue::new(queue_capacity),
115 rx_wake: WakePipe::new(),
116 tx_wake: WakePipe::new(),
117 proxy_wake: WakePipe::new(),
118 termination_hook: Mutex::new(None),
119 resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
120 gateway_ipv4: OnceLock::new(),
121 gateway_ipv6: OnceLock::new(),
122 metrics: NetworkMetrics::default(),
123 }
124 }
125
126 pub fn set_gateway_ips(&self, ipv4: Ipv4Addr, ipv6: Ipv6Addr) {
128 let _ = self.gateway_ipv4.set(ipv4);
129 let _ = self.gateway_ipv6.set(ipv6);
130 }
131
132 pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
134 self.gateway_ipv4.get().copied()
135 }
136
137 pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
139 self.gateway_ipv6.get().copied()
140 }
141
142 pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
144 *self.termination_hook.lock().unwrap() = Some(hook);
145 }
146
147 pub fn trigger_termination(&self) {
149 let hook = self.termination_hook.lock().unwrap().clone();
150 if let Some(hook) = hook {
151 hook();
152 }
153 }
154
155 pub fn cache_resolved_hostname(
157 &self,
158 domain: &str,
159 family: ResolvedHostnameFamily,
160 addrs: impl IntoIterator<Item = IpAddr>,
161 ttl: Duration,
162 ) {
163 let hostname = normalize_hostname(domain);
164 let key = ResolvedHostnameKey { hostname, family };
165 self.resolved_hostnames
166 .write()
167 .insert(key, addrs, ttl, Instant::now());
168 }
169
170 pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
172 let hostname = normalize_hostname(domain);
173 let key = ResolvedHostnameKey { hostname, family };
174 self.resolved_hostnames.write().remove(&key, Instant::now());
175 }
176
177 pub fn any_resolved_hostname(
179 &self,
180 addr: IpAddr,
181 mut predicate: impl FnMut(&str) -> bool,
182 ) -> bool {
183 self.resolved_hostnames
184 .read()
185 .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
186 }
187
188 pub fn cleanup_resolved_hostnames(&self) {
193 if let Some(mut idx) = self.resolved_hostnames.try_write() {
194 idx.evict_expired(Instant::now());
195 }
196 }
197
198 pub fn add_tx_bytes(&self, bytes: usize) {
200 self.metrics
201 .tx_bytes
202 .fetch_add(bytes as u64, Ordering::Relaxed);
203 }
204
205 pub fn add_rx_bytes(&self, bytes: usize) {
207 self.metrics
208 .rx_bytes
209 .fetch_add(bytes as u64, Ordering::Relaxed);
210 }
211
212 pub fn tx_bytes(&self) -> u64 {
214 self.metrics.tx_bytes.load(Ordering::Relaxed)
215 }
216
217 pub fn rx_bytes(&self) -> u64 {
219 self.metrics.rx_bytes.load(Ordering::Relaxed)
220 }
221}
222
223impl Default for NetworkMetrics {
224 fn default() -> Self {
225 Self {
226 tx_bytes: AtomicU64::new(0),
227 rx_bytes: AtomicU64::new(0),
228 }
229 }
230}
231
/// Canonicalizes a hostname for use as a cache key.
///
/// Every trailing `.` is removed, and ASCII letters are folded to lowercase. Non-ASCII
/// characters are left untouched.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let without_trailing_dots = domain.trim_end_matches('.');
    without_trailing_dots
        .chars()
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);

        state.tx_ring.push(vec![1, 2, 3]).unwrap();
        state.tx_ring.push(vec![4, 5, 6]).unwrap();

        // The ring is FIFO.
        assert_eq!(state.tx_ring.pop(), Some(vec![1, 2, 3]));
        assert_eq!(state.tx_ring.pop(), Some(vec![4, 5, 6]));
        assert_eq!(state.tx_ring.pop(), None);
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);

        state.rx_ring.push(vec![1]).unwrap();
        state.rx_ring.push(vec![2]).unwrap();
        // A push beyond the configured capacity is rejected.
        assert!(state.rx_ring.push(vec![3]).is_err());
    }

    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        state.cache_resolved_hostname(
            "Example.com.",
            ResolvedHostnameFamily::Ipv4,
            [v4],
            Duration::from_secs(30),
        );
        state.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv6,
            [v6],
            Duration::from_secs(30),
        );

        assert!(state.any_resolved_hostname(v4, |h| h == "example.com"));
        assert!(state.any_resolved_hostname(v6, |h| h == "example.com"));
        assert!(!state.any_resolved_hostname(v4, |h| h == "other.example"));
    }

    #[test]
    fn normalize_hostname_strips_trailing_dots_and_lowercases() {
        assert_eq!(normalize_hostname("Example.COM."), "example.com");
        assert_eq!(normalize_hostname("example.com"), "example.com");
        assert_eq!(normalize_hostname("a.b.."), "a.b");
        assert_eq!(normalize_hostname("."), "");
    }

    #[test]
    fn byte_counters_accumulate() {
        let state = SharedState::new(4);
        assert_eq!(state.tx_bytes(), 0);
        assert_eq!(state.rx_bytes(), 0);

        state.add_tx_bytes(10);
        state.add_tx_bytes(5);
        state.add_rx_bytes(7);

        assert_eq!(state.tx_bytes(), 15);
        assert_eq!(state.rx_bytes(), 7);
    }

    #[test]
    fn gateway_ips_are_set_once() {
        let state = SharedState::new(4);
        assert_eq!(state.gateway_ipv4(), None);
        assert_eq!(state.gateway_ipv6(), None);

        let v4 = Ipv4Addr::new(10, 0, 0, 1);
        let v6 = Ipv6Addr::LOCALHOST;
        state.set_gateway_ips(v4, v6);
        // A second call must not overwrite the first values.
        state.set_gateway_ips(Ipv4Addr::new(10, 0, 0, 2), Ipv6Addr::UNSPECIFIED);

        assert_eq!(state.gateway_ipv4(), Some(v4));
        assert_eq!(state.gateway_ipv6(), Some(v6));
    }

    #[test]
    fn termination_hook_runs_on_each_trigger() {
        let state = SharedState::new(4);
        // Without a hook, triggering is a no-op.
        state.trigger_termination();

        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        state.set_termination_hook(Arc::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }));

        state.trigger_termination();
        state.trigger_termination();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}