use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::{
Arc, Mutex, OnceLock,
atomic::{AtomicU64, Ordering},
};
use std::time::{Duration, Instant};
use crossbeam_queue::ArrayQueue;
use microsandbox_utils::ttl_reverse_index::TtlReverseIndex;
pub use microsandbox_utils::wake_pipe::WakePipe;
use parking_lot::RwLock;
/// Default capacity (number of buffers, 2^10) suggested for the `queue_capacity`
/// argument of `SharedState::new`. This file does not reference it directly.
pub const DEFAULT_QUEUE_CAPACITY: usize = 1 << 10;
/// Callback stored by `SharedState::set_termination_hook` and invoked by
/// `SharedState::trigger_termination`.
type TerminationHook = Arc<dyn Fn() + Send + Sync>;

/// State accessed through shared references (every method takes `&self`).
///
/// Mutation goes through interior mutability: bounded `ArrayQueue` rings, a
/// `Mutex`, a `RwLock`, `OnceLock`s and atomics.
pub struct SharedState {
    /// Bounded ring of outgoing (`tx`) packet buffers.
    /// NOTE(review): direction relative to the guest is not established here.
    pub tx_ring: ArrayQueue<Vec<u8>>,
    /// Bounded ring of incoming (`rx`) packet buffers.
    pub rx_ring: ArrayQueue<Vec<u8>>,
    /// Wake pipe associated with the `rx` side.
    pub rx_wake: WakePipe,
    /// Wake pipe associated with the `tx` side.
    pub tx_wake: WakePipe,
    /// Wake pipe associated with the proxy.
    pub proxy_wake: WakePipe,
    /// Optional callback run by `trigger_termination`.
    termination_hook: Mutex<Option<TerminationHook>>,
    /// Reverse index from IP address to the (hostname, family) keys that
    /// resolved to it, with per-entry TTLs.
    resolved_hostnames: RwLock<TtlReverseIndex<ResolvedHostnameKey, IpAddr>>,
    /// IPv4 gateway address; written at most once via `set_gateway_ips`.
    gateway_ipv4: OnceLock<Ipv4Addr>,
    /// IPv6 gateway address; written at most once via `set_gateway_ips`.
    gateway_ipv6: OnceLock<Ipv6Addr>,
    /// Cumulative tx/rx byte counters.
    metrics: NetworkMetrics,
}
/// Cumulative byte counters for traffic recorded through `SharedState`.
///
/// The counters are plain atomics, so they can be bumped concurrently through
/// a shared reference. `Debug` prints a relaxed snapshot of each counter.
#[derive(Debug)]
pub struct NetworkMetrics {
    /// Total bytes added via `SharedState::add_tx_bytes`.
    tx_bytes: AtomicU64,
    /// Total bytes added via `SharedState::add_rx_bytes`.
    rx_bytes: AtomicU64,
}
/// Address family a cached hostname resolution belongs to.
///
/// IPv4 and IPv6 results for the same hostname are cached under separate keys,
/// so refreshing or clearing one family leaves the other untouched.
/// Variant order is significant: the derived `Ord` sorts `Ipv4` before `Ipv6`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ResolvedHostnameFamily {
    /// IPv4 results.
    Ipv4,
    /// IPv6 results.
    Ipv6,
}
/// Key of the resolved-hostname index: a normalized hostname paired with the
/// address family it was resolved for.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ResolvedHostnameKey {
    /// Hostname as produced by `normalize_hostname` (ASCII-lowercased,
    /// trailing dots stripped).
    hostname: String,
    /// Family of the addresses cached under this key.
    family: ResolvedHostnameFamily,
}
impl SharedState {
    /// Creates shared state whose `tx_ring` and `rx_ring` each hold up to
    /// `queue_capacity` buffers.
    ///
    /// Gateway addresses start unset, no termination hook is installed, the
    /// resolved-hostname index starts from its `Default`, and both byte
    /// counters start at zero.
    ///
    /// # Panics
    ///
    /// Panics if `queue_capacity` is zero (`ArrayQueue::new` rejects a zero
    /// capacity).
    pub fn new(queue_capacity: usize) -> Self {
        Self {
            tx_ring: ArrayQueue::new(queue_capacity),
            rx_ring: ArrayQueue::new(queue_capacity),
            rx_wake: WakePipe::new(),
            tx_wake: WakePipe::new(),
            proxy_wake: WakePipe::new(),
            termination_hook: Mutex::new(None),
            resolved_hostnames: RwLock::new(TtlReverseIndex::default()),
            gateway_ipv4: OnceLock::new(),
            gateway_ipv6: OnceLock::new(),
            metrics: NetworkMetrics::default(),
        }
    }

    /// Records the gateway addresses; the first value set for each family wins.
    ///
    /// A `None` argument leaves that family untouched. Once a family has a
    /// value, later attempts to set it are silently ignored.
    pub fn set_gateway_ips(&self, ipv4: Option<Ipv4Addr>, ipv6: Option<Ipv6Addr>) {
        if let Some(ipv4) = ipv4 {
            // `OnceLock::set` errors when a value is already present; the error
            // is discarded so the first recorded address is kept.
            let _ = self.gateway_ipv4.set(ipv4);
        }
        if let Some(ipv6) = ipv6 {
            let _ = self.gateway_ipv6.set(ipv6);
        }
    }

    /// Returns the IPv4 gateway address, if one has been recorded.
    pub fn gateway_ipv4(&self) -> Option<Ipv4Addr> {
        self.gateway_ipv4.get().copied()
    }

    /// Returns the IPv6 gateway address, if one has been recorded.
    pub fn gateway_ipv6(&self) -> Option<Ipv6Addr> {
        self.gateway_ipv6.get().copied()
    }

    /// Installs `hook` as the callback run by [`Self::trigger_termination`],
    /// replacing any previously installed hook.
    pub fn set_termination_hook(&self, hook: Arc<dyn Fn() + Send + Sync>) {
        // The guarded value is only ever swapped or cloned as a whole, so a
        // poisoned lock still holds a consistent `Option`; recover instead of panicking.
        let previous = self
            .termination_hook
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .replace(hook);
        // The guard was released at the end of the statement above. Dropping the
        // old hook here means any code its captured state runs on drop cannot
        // deadlock on (or poison) this mutex.
        drop(previous);
    }

    /// Runs the installed termination hook, if any.
    ///
    /// The hook is cloned out and the lock released before it is called, so
    /// the hook may itself call [`Self::set_termination_hook`] or
    /// `trigger_termination` without deadlocking.
    pub fn trigger_termination(&self) {
        let hook = self
            .termination_hook
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone();
        if let Some(hook) = hook {
            hook();
        }
    }

    /// Caches `addrs` as the `family` resolution of `domain`, valid for `ttl`.
    ///
    /// `domain` is normalized with `normalize_hostname`, so `"Example.com."`
    /// and `"example.com"` share one entry. The current time is read after the
    /// write lock is acquired.
    pub fn cache_resolved_hostname(
        &self,
        domain: &str,
        family: ResolvedHostnameFamily,
        addrs: impl IntoIterator<Item = IpAddr>,
        ttl: Duration,
    ) {
        let hostname = normalize_hostname(domain);
        let key = ResolvedHostnameKey { hostname, family };
        self.resolved_hostnames
            .write()
            .insert(key, addrs, ttl, Instant::now());
    }

    /// Removes the cached `family` resolution of `domain` (normalized with
    /// `normalize_hostname`) from the index.
    pub fn clear_resolved_hostname(&self, domain: &str, family: ResolvedHostnameFamily) {
        let hostname = normalize_hostname(domain);
        let key = ResolvedHostnameKey { hostname, family };
        self.resolved_hostnames.write().remove(&key, Instant::now());
    }

    /// Returns whether `predicate` accepts any hostname that the index
    /// associates with `addr` as of now.
    ///
    /// Hostnames passed to `predicate` are in normalized form. The predicate
    /// runs while the index's read lock is held. It must not call methods on
    /// this `SharedState` that take the write lock (`cache_resolved_hostname`,
    /// `clear_resolved_hostname`), or it will deadlock.
    /// NOTE(review): expiry filtering is delegated to
    /// `TtlReverseIndex::member_matches`; confirm it skips expired entries.
    pub fn any_resolved_hostname(
        &self,
        addr: IpAddr,
        mut predicate: impl FnMut(&str) -> bool,
    ) -> bool {
        self.resolved_hostnames
            .read()
            .member_matches(&addr, Instant::now(), |key| predicate(&key.hostname))
    }

    /// Best-effort eviction of expired hostname entries.
    ///
    /// Uses `try_write`: if the index lock is currently held, this call does
    /// nothing rather than blocking.
    pub fn cleanup_resolved_hostnames(&self) {
        if let Some(mut idx) = self.resolved_hostnames.try_write() {
            idx.evict_expired(Instant::now());
        }
    }

    /// Adds `bytes` to the transmitted-byte counter.
    ///
    /// The update uses `Ordering::Relaxed`: it is a statistic and does not
    /// order any other memory access. The counter wraps on overflow.
    pub fn add_tx_bytes(&self, bytes: usize) {
        // `usize` is at most 64 bits on Rust's supported targets, so this widening is lossless.
        self.metrics
            .tx_bytes
            .fetch_add(bytes as u64, Ordering::Relaxed);
    }

    /// Adds `bytes` to the received-byte counter (relaxed, wrapping).
    pub fn add_rx_bytes(&self, bytes: usize) {
        self.metrics
            .rx_bytes
            .fetch_add(bytes as u64, Ordering::Relaxed);
    }

    /// Returns the total transmitted bytes recorded so far (relaxed load).
    pub fn tx_bytes(&self) -> u64 {
        self.metrics.tx_bytes.load(Ordering::Relaxed)
    }

    /// Returns the total received bytes recorded so far (relaxed load).
    pub fn rx_bytes(&self) -> u64 {
        self.metrics.rx_bytes.load(Ordering::Relaxed)
    }
}
impl Default for NetworkMetrics {
    /// Returns metrics with both byte counters starting at zero.
    fn default() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            tx_bytes: zeroed(),
            rx_bytes: zeroed(),
        }
    }
}
/// Canonicalizes a hostname for cache keys and comparisons.
///
/// Strips every trailing `.` (so `"example.com."` and `"example.com"` match)
/// and lowercases ASCII letters only. Non-ASCII characters are left unchanged.
pub(crate) fn normalize_hostname(domain: &str) -> String {
    let without_trailing_dots = domain.trim_end_matches('.');
    without_trailing_dots.to_ascii_lowercase()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shared_state_queue_push_pop() {
        let state = SharedState::new(4);
        let frames = [vec![1, 2, 3], vec![4, 5, 6]];

        // Enqueue both frames, then expect them back in FIFO order.
        for frame in frames.iter().cloned() {
            state.tx_ring.push(frame).unwrap();
        }
        for expected in frames {
            assert_eq!(state.tx_ring.pop(), Some(expected));
        }
        assert_eq!(state.tx_ring.pop(), None);
    }

    #[test]
    fn shared_state_queue_full() {
        let state = SharedState::new(2);
        (1..=2u8).for_each(|n| state.rx_ring.push(vec![n]).unwrap());

        // A third push into a capacity-2 ring must be rejected.
        let overflow = state.rx_ring.push(vec![3]);
        assert!(overflow.is_err());
    }

    #[test]
    fn resolved_hostnames_are_isolated_per_family() {
        let state = SharedState::new(4);
        let ttl = Duration::from_secs(30);
        let v4: IpAddr = "1.1.1.1".parse().unwrap();
        let v6: IpAddr = "2606:4700:4700::1111".parse().unwrap();

        // Same host, spelled differently, cached once per address family.
        state.cache_resolved_hostname("Example.com.", ResolvedHostnameFamily::Ipv4, [v4], ttl);
        state.cache_resolved_hostname("example.com", ResolvedHostnameFamily::Ipv6, [v6], ttl);

        let is_example = |host: &str| host == "example.com";
        assert!(state.any_resolved_hostname(v4, is_example));
        assert!(state.any_resolved_hostname(v6, is_example));
        assert!(!state.any_resolved_hostname(v4, |host| host == "other.example"));
    }
}