Skip to main content

assay_core/providers/
network.rs

1use std::sync::{Mutex, OnceLock};
2
/// Process-wide decision on whether outbound network access is permitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkPolicy {
    /// Outbound connections are permitted.
    Allow,
    /// Outbound connections are refused; the payload is the human-readable
    /// reason that `check_outbound` appends to its error message.
    Deny(String),
}
8
9#[derive(Debug)]
10struct NetworkState {
11    policy: NetworkPolicy,
12}
13
14fn state() -> &'static Mutex<NetworkState> {
15    static STATE: OnceLock<Mutex<NetworkState>> = OnceLock::new();
16    STATE.get_or_init(|| {
17        Mutex::new(NetworkState {
18            policy: NetworkPolicy::Allow,
19        })
20    })
21}
22
23fn lock_state() -> std::sync::MutexGuard<'static, NetworkState> {
24    state()
25        .lock()
26        .unwrap_or_else(|poisoned| poisoned.into_inner())
27}
28
29pub struct NetworkPolicyGuard {
30    previous: NetworkPolicy,
31}
32
33impl NetworkPolicyGuard {
34    pub fn set(policy: NetworkPolicy) -> Self {
35        let mut s = lock_state();
36        let previous = s.policy.clone();
37        s.policy = policy;
38        Self { previous }
39    }
40
41    pub fn deny(reason: impl Into<String>) -> Self {
42        Self::set(NetworkPolicy::Deny(reason.into()))
43    }
44}
45
46impl Drop for NetworkPolicyGuard {
47    fn drop(&mut self) {
48        let mut s = lock_state();
49        s.policy = self.previous.clone();
50    }
51}
52
53pub fn check_outbound(target: &str) -> anyhow::Result<()> {
54    let policy = effective_policy();
55    match policy {
56        NetworkPolicy::Allow => Ok(()),
57        NetworkPolicy::Deny(reason) => anyhow::bail!(
58            "config error: outbound network blocked by policy (target={}): {}",
59            target,
60            reason
61        ),
62    }
63}
64
65fn effective_policy() -> NetworkPolicy {
66    if let Ok(raw) = std::env::var("ASSAY_NETWORK_POLICY") {
67        let mode = raw.trim().to_ascii_lowercase();
68        if mode == "deny" {
69            return NetworkPolicy::Deny("ASSAY_NETWORK_POLICY=deny".to_string());
70        }
71    }
72    let s = lock_state();
73    s.policy.clone()
74}
75
/// Process-wide async-aware mutex used to serialise tests that mutate the
/// global network policy or the `ASSAY_NETWORK_POLICY` environment variable.
#[cfg(test)]
fn test_serial_lock() -> &'static tokio::sync::Mutex<()> {
    static SERIAL: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
    SERIAL.get_or_init(|| tokio::sync::Mutex::new(()))
}
81
/// Acquires the test serialisation lock from synchronous test code.
///
/// Tokio documents that `blocking_lock` panics when called from within an
/// async runtime. Async tests should use `lock_test_serial_async` instead.
#[cfg(test)]
pub(crate) fn lock_test_serial() -> tokio::sync::MutexGuard<'static, ()> {
    let serial = test_serial_lock();
    serial.blocking_lock()
}
86
/// Acquires the test serialisation lock from async test code.
#[cfg(test)]
pub(crate) async fn lock_test_serial_async() -> tokio::sync::MutexGuard<'static, ()> {
    let serial = test_serial_lock();
    serial.lock().await
}
91
#[cfg(test)]
mod tests {
    use super::*;

    const ENV_KEY: &str = "ASSAY_NETWORK_POLICY";

    /// Snapshot of `ASSAY_NETWORK_POLICY` that is written back on drop.
    ///
    /// Restoring from `Drop` means the variable is put back even when an
    /// assertion panics. Otherwise a failing test would leak e.g. `deny` into
    /// every later test; the tokio serial lock does not poison, so those
    /// tests would still run. `var_os` keeps non-UTF-8 values intact.
    struct EnvRestore {
        previous: Option<std::ffi::OsString>,
    }

    impl EnvRestore {
        fn capture() -> Self {
            Self {
                previous: std::env::var_os(ENV_KEY),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(ENV_KEY, value),
                None => std::env::remove_var(ENV_KEY),
            }
        }
    }

    // In each test `_serial` is bound first so it is dropped last. The env
    // var and policy are restored before the next test can take the lock.

    #[test]
    fn scoped_deny_blocks_and_restores() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(ENV_KEY);
        {
            let _guard = NetworkPolicyGuard::deny("test deny");
            let err = check_outbound("test-target").unwrap_err().to_string();
            assert!(err.contains("outbound network blocked by policy"));
            assert!(err.contains("test-target"));
            assert!(err.contains("test deny"));
        }
        check_outbound("test-target").unwrap();
    }

    #[test]
    fn env_deny_overrides_scoped_allow() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(ENV_KEY, "deny");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("ASSAY_NETWORK_POLICY=deny"));
    }

    #[test]
    fn env_deny_is_trimmed_and_case_insensitive() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(ENV_KEY, "  DeNy\n");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("ASSAY_NETWORK_POLICY=deny"));
    }

    #[test]
    fn nested_guards_restore_in_lifo_order() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(ENV_KEY);
        let outer = NetworkPolicyGuard::deny("outer reason");
        {
            let _inner = NetworkPolicyGuard::deny("inner reason");
            let err = check_outbound("t").unwrap_err().to_string();
            assert!(err.contains("inner reason"));
        }
        let err = check_outbound("t").unwrap_err().to_string();
        assert!(err.contains("outer reason"));
        drop(outer);
        check_outbound("t").unwrap();
    }
}