assay-core 2.17.0

High-performance evaluation framework for LLM agents (Core)
Documentation
use std::sync::{Mutex, OnceLock};

/// Whether outbound network access is permitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkPolicy {
    /// Outbound requests are permitted.
    Allow,
    /// Outbound requests are rejected; the payload is the human-readable
    /// reason that `check_outbound` includes in its error message.
    Deny(String),
}

/// Mutable data protected by the process-wide lock returned from `state()`.
#[derive(Debug)]
struct NetworkState {
    /// Currently active process-wide policy. It starts as
    /// `NetworkPolicy::Allow` and is changed only by `NetworkPolicyGuard`
    /// (on creation and on drop).
    policy: NetworkPolicy,
}

/// Returns the lazily initialised, process-wide network state.
///
/// The first call creates the mutex holding `NetworkPolicy::Allow`; every
/// call (including the first) returns a reference to that same mutex.
fn state() -> &'static Mutex<NetworkState> {
    static GLOBAL_STATE: OnceLock<Mutex<NetworkState>> = OnceLock::new();
    let initial = || {
        let policy = NetworkPolicy::Allow;
        Mutex::new(NetworkState { policy })
    };
    GLOBAL_STATE.get_or_init(initial)
}

/// Locks the global network state, tolerating a poisoned mutex.
///
/// If a thread panicked while holding the lock, the guard is recovered with
/// `PoisonError::into_inner` instead of propagating the panic.
fn lock_state() -> std::sync::MutexGuard<'static, NetworkState> {
    match state().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}

/// Guard that installs a process-wide [`NetworkPolicy`] and reinstates the
/// policy it replaced when dropped.
///
/// Each guard restores exactly the value it displaced. Overlapping guards
/// must therefore be dropped in reverse order of creation for the original
/// policy to come back.
///
/// Not binding the guard, or binding it to `_`, drops it immediately and
/// undoes the change at once. `#[must_use]` makes the first mistake a
/// compiler warning.
#[must_use = "the policy is reverted as soon as the guard is dropped; bind it to a named variable"]
pub struct NetworkPolicyGuard {
    /// Policy that was active before this guard was created.
    previous: NetworkPolicy,
}

impl NetworkPolicyGuard {
    /// Installs `policy` as the process-wide policy.
    ///
    /// Returns a guard that puts back the previously active policy when it
    /// is dropped.
    pub fn set(policy: NetworkPolicy) -> Self {
        // Swap the new policy in and take the old one out while holding the
        // lock once.
        let previous = std::mem::replace(&mut lock_state().policy, policy);
        Self { previous }
    }

    /// Convenience wrapper for `set(NetworkPolicy::Deny(reason))`.
    pub fn deny(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self::set(NetworkPolicy::Deny(reason))
    }
}

impl Drop for NetworkPolicyGuard {
    fn drop(&mut self) {
        let mut s = lock_state();
        s.policy = self.previous.clone();
    }
}

/// Checks whether an outbound network request to `target` is permitted.
///
/// # Errors
///
/// Returns an error mentioning `target` and the deny reason when the
/// effective policy is [`NetworkPolicy::Deny`]. The effective policy is the
/// `ASSAY_NETWORK_POLICY=deny` override if present, otherwise the guarded
/// process-wide policy.
pub fn check_outbound(target: &str) -> anyhow::Result<()> {
    let reason = match effective_policy() {
        NetworkPolicy::Allow => return Ok(()),
        NetworkPolicy::Deny(reason) => reason,
    };
    anyhow::bail!(
        "config error: outbound network blocked by policy (target={}): {}",
        target,
        reason
    )
}

/// Resolves the policy that currently applies to outbound requests.
///
/// If `ASSAY_NETWORK_POLICY` equals `deny` (ignoring surrounding whitespace
/// and ASCII case), it takes precedence. Any other value, an unset variable,
/// or a non-UTF-8 value falls back to the process-wide policy.
fn effective_policy() -> NetworkPolicy {
    let env_denies = std::env::var("ASSAY_NETWORK_POLICY")
        .is_ok_and(|raw| raw.trim().eq_ignore_ascii_case("deny"));
    if env_denies {
        NetworkPolicy::Deny("ASSAY_NETWORK_POLICY=deny".to_string())
    } else {
        lock_state().policy.clone()
    }
}

/// Lazily created async mutex shared by test helpers, used to serialise
/// tests that mutate the global policy or `ASSAY_NETWORK_POLICY`.
#[cfg(test)]
fn test_serial_lock() -> &'static tokio::sync::Mutex<()> {
    static SERIAL: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
    SERIAL.get_or_init(|| tokio::sync::Mutex::new(()))
}

/// Acquires the test serialisation lock from synchronous `#[test]` code.
///
/// Uses tokio's `blocking_lock`, which must not be called from within an
/// async execution context. Async tests should use
/// `lock_test_serial_async`.
#[cfg(test)]
pub(crate) fn lock_test_serial() -> tokio::sync::MutexGuard<'static, ()> {
    let serial = test_serial_lock();
    serial.blocking_lock()
}

/// Acquires the test serialisation lock from async test code.
#[cfg(test)]
pub(crate) async fn lock_test_serial_async() -> tokio::sync::MutexGuard<'static, ()> {
    let serial = test_serial_lock();
    serial.lock().await
}

#[cfg(test)]
mod tests {
    use super::*;

    const ENV_KEY: &str = "ASSAY_NETWORK_POLICY";

    /// Snapshots `ASSAY_NETWORK_POLICY` and restores it on drop.
    ///
    /// This puts the environment back even when an assertion panics, so a
    /// failing test cannot leak e.g. `deny` into later tests.
    ///
    /// Declare it after the serial-lock guard: locals drop in reverse order,
    /// so the variable is restored while the lock is still held. `var_os` is
    /// used so non-UTF-8 values survive the round trip.
    struct EnvRestore {
        previous: Option<std::ffi::OsString>,
    }

    impl EnvRestore {
        fn capture() -> Self {
            Self {
                previous: std::env::var_os(ENV_KEY),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(ENV_KEY, value),
                None => std::env::remove_var(ENV_KEY),
            }
        }
    }

    #[test]
    fn scoped_deny_blocks_and_restores() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(ENV_KEY);
        {
            let _guard = NetworkPolicyGuard::deny("test deny");
            let err = check_outbound("test-target").unwrap_err().to_string();
            assert!(err.contains("outbound network blocked by policy"));
            assert!(err.contains("test-target"));
            assert!(err.contains("test deny"));
        }
        check_outbound("test-target").unwrap();
    }

    #[test]
    fn nested_guards_restore_in_reverse_order() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        std::env::remove_var(ENV_KEY);
        let _base = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        {
            let _outer = NetworkPolicyGuard::deny("outer deny");
            {
                let _inner = NetworkPolicyGuard::set(NetworkPolicy::Allow);
                check_outbound("nested-target").unwrap();
            }
            let err = check_outbound("nested-target").unwrap_err().to_string();
            assert!(err.contains("outer deny"));
        }
        check_outbound("nested-target").unwrap();
    }

    #[test]
    fn env_deny_overrides_scoped_allow() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(ENV_KEY, "deny");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("ASSAY_NETWORK_POLICY=deny"));
    }

    #[test]
    fn env_deny_is_trimmed_and_case_insensitive() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        let _guard = NetworkPolicyGuard::set(NetworkPolicy::Allow);
        std::env::set_var(ENV_KEY, "  DeNy\n");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("ASSAY_NETWORK_POLICY=deny"));
    }

    #[test]
    fn env_values_other_than_deny_are_ignored() {
        let _serial = lock_test_serial();
        let _env = EnvRestore::capture();
        {
            let _allow = NetworkPolicyGuard::set(NetworkPolicy::Allow);
            std::env::set_var(ENV_KEY, "allow");
            check_outbound("env-target").unwrap();
            std::env::set_var(ENV_KEY, "something-else");
            check_outbound("env-target").unwrap();
        }
        // A non-`deny` env value must not loosen a scoped deny either.
        let _deny = NetworkPolicyGuard::deny("scoped deny");
        std::env::set_var(ENV_KEY, "allow");
        let err = check_outbound("env-target").unwrap_err().to_string();
        assert!(err.contains("scoped deny"));
    }
}