// assay_core/providers/network.rs
use std::sync::{Mutex, OnceLock};

/// Whether outbound network access is currently permitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkPolicy {
    /// Outbound connections may proceed; [`check_outbound`] returns `Ok(())`.
    Allow,
    /// Outbound connections are refused. The payload is a human-readable reason that
    /// [`check_outbound`] includes in its error message.
    Deny(String),
}

9#[derive(Debug)]
10struct NetworkState {
11 policy: NetworkPolicy,
12}
13
14fn state() -> &'static Mutex<NetworkState> {
15 static STATE: OnceLock<Mutex<NetworkState>> = OnceLock::new();
16 STATE.get_or_init(|| {
17 Mutex::new(NetworkState {
18 policy: NetworkPolicy::Allow,
19 })
20 })
21}
22
23fn lock_state() -> std::sync::MutexGuard<'static, NetworkState> {
24 state()
25 .lock()
26 .unwrap_or_else(|poisoned| poisoned.into_inner())
27}
28
29pub struct NetworkPolicyGuard {
30 previous: NetworkPolicy,
31}
32
33impl NetworkPolicyGuard {
34 pub fn set(policy: NetworkPolicy) -> Self {
35 let mut s = lock_state();
36 let previous = s.policy.clone();
37 s.policy = policy;
38 Self { previous }
39 }
40
41 pub fn deny(reason: impl Into<String>) -> Self {
42 Self::set(NetworkPolicy::Deny(reason.into()))
43 }
44}
45
46impl Drop for NetworkPolicyGuard {
47 fn drop(&mut self) {
48 let mut s = lock_state();
49 s.policy = self.previous.clone();
50 }
51}
52
53pub fn check_outbound(target: &str) -> anyhow::Result<()> {
54 let policy = effective_policy();
55 match policy {
56 NetworkPolicy::Allow => Ok(()),
57 NetworkPolicy::Deny(reason) => anyhow::bail!(
58 "config error: outbound network blocked by policy (target={}): {}",
59 target,
60 reason
61 ),
62 }
63}
64
65fn effective_policy() -> NetworkPolicy {
66 if let Ok(raw) = std::env::var("ASSAY_NETWORK_POLICY") {
67 let mode = raw.trim().to_ascii_lowercase();
68 if mode == "deny" {
69 return NetworkPolicy::Deny("ASSAY_NETWORK_POLICY=deny".to_string());
70 }
71 }
72 let s = lock_state();
73 s.policy.clone()
74}
75
/// Process-wide async mutex used to serialise tests that touch the global network policy
/// or the `ASSAY_NETWORK_POLICY` environment variable.
/// NOTE(review): exposed crate-wide via the `lock_test_serial*` helpers, presumably so
/// other test modules can share it — confirm.
#[cfg(test)]
fn test_serial_lock() -> &'static tokio::sync::Mutex<()> {
    static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
    let make = || tokio::sync::Mutex::new(());
    LOCK.get_or_init(make)
}

/// Acquires the test serialisation lock from synchronous code, blocking the current thread
/// until it is free.
///
/// NOTE(review): tokio documents `blocking_lock` as panicking inside an async execution
/// context; async tests should call [`lock_test_serial_async`] instead.
#[cfg(test)]
pub(crate) fn lock_test_serial() -> tokio::sync::MutexGuard<'static, ()> {
    let lock = test_serial_lock();
    lock.blocking_lock()
}

/// Acquires the test serialisation lock from async code by awaiting it.
#[cfg(test)]
pub(crate) async fn lock_test_serial_async() -> tokio::sync::MutexGuard<'static, ()> {
    let lock = test_serial_lock();
    lock.lock().await
}

#[cfg(test)]
mod tests {
    use super::*;

    // The policy and the environment variable are process-global, and the test harness runs
    // tests on parallel threads. Every test therefore takes the serial lock first.

    const POLICY_ENV: &str = "ASSAY_NETWORK_POLICY";

    /// Snapshots `ASSAY_NETWORK_POLICY` and restores it on drop, even if the test panics, so a
    /// failing assertion cannot leak a modified value into later tests.
    ///
    /// Declare it *after* the serial-lock guard. Locals drop in reverse order, so the variable
    /// is then restored while the lock is still held. `var_os` is used so that a non-UTF-8
    /// value is also restored faithfully.
    struct EnvRestore {
        previous: Option<std::ffi::OsString>,
    }

    impl EnvRestore {
        fn capture() -> Self {
            Self {
                previous: std::env::var_os(POLICY_ENV),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(POLICY_ENV, value),
                None => std::env::remove_var(POLICY_ENV),
            }
        }
    }

    #[test]
    fn scoped_deny_blocks_and_restores() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(POLICY_ENV);
        {
            let _guard = NetworkPolicyGuard::deny("test deny");
            let err = check_outbound("test-target").unwrap_err().to_string();
            assert!(err.contains("outbound network blocked by policy"));
            assert!(err.contains("test-target"));
            assert!(err.contains("test deny"));
        }
        check_outbound("test-target").unwrap();
    }

    #[test]
    fn nested_guards_restore_in_reverse_order() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(POLICY_ENV);
        let _base = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        {
            let _outer = NetworkPolicyGuard::deny("outer reason");
            {
                let _inner = NetworkPolicyGuard::deny("inner reason");
                let err = check_outbound("nested").unwrap_err().to_string();
                assert!(err.contains("inner reason"));
            }
            let err = check_outbound("nested").unwrap_err().to_string();
            assert!(err.contains("outer reason"));
        }
        check_outbound("nested").unwrap();
    }

    #[test]
    fn env_deny_overrides_scoped_allow() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(POLICY_ENV, "deny");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("ASSAY_NETWORK_POLICY=deny"));
    }

    #[test]
    fn env_deny_ignores_case_and_surrounding_whitespace() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(POLICY_ENV, "  DeNy\n");
        assert!(check_outbound("env-target").is_err());
        std::env::set_var(POLICY_ENV, "allow");
        check_outbound("env-target").unwrap();
    }
}