use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
/// Direction of the connection described by a [`DomainRequest`].
///
/// NOTE(review): this module does not define which peer initiates an
/// `Inbound` vs `Outbound` connection — confirm with the code that builds
/// requests.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionDirection {
    /// Inbound connection.
    Inbound,
    /// Outbound connection.
    Outbound,
}
/// One connection attempt, as handed to [`NetworkPolicy::check`] for a verdict.
#[derive(Debug, Clone)]
pub struct DomainRequest {
    /// Host the connection is aimed at (compared against allow-list entries).
    target: String,
    /// Port of the connection.
    port: u16,
    /// Whether the connection is inbound or outbound.
    direction: ConnectionDirection,
    /// Process id attached to the request (presumably the requesting process — confirm with callers).
    pid: u32,
}

impl DomainRequest {
    /// Bundles the connection details into a request; no validation is performed.
    pub fn new(target: String, port: u16, direction: ConnectionDirection, pid: u32) -> Self {
        Self {
            pid,
            direction,
            port,
            target,
        }
    }

    /// Process id supplied at construction.
    pub fn pid(&self) -> u32 {
        self.pid
    }

    /// Connection direction supplied at construction.
    pub fn direction(&self) -> ConnectionDirection {
        self.direction
    }

    /// Port supplied at construction.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Target host, borrowed from the request.
    pub fn target(&self) -> &str {
        self.target.as_str()
    }
}
/// Decides whether the connection described by a [`DomainRequest`] may proceed.
///
/// Implementors must be `Send + Sync + 'static` and `check` must return a
/// `Send` future, which lets a policy be shared across threads and its check
/// futures be moved between threads.
pub trait NetworkPolicy: Send + Sync + 'static {
    /// Resolves to `true` to allow the request and `false` to reject it.
    fn check(&self, request: &DomainRequest) -> impl Future<Output = bool> + Send;
}
/// Policy that rejects every request, whatever its contents.
#[derive(Clone, Copy, Debug, Default)]
pub struct DenyAll;

impl NetworkPolicy for DenyAll {
    /// Always resolves to `false`; the request is not inspected.
    async fn check(&self, _: &DomainRequest) -> bool {
        false
    }
}
/// Policy that permits every request, whatever its contents.
///
/// Derives `Default` just like `DenyAll` does, so either built-in policy can
/// be used where a `Default` bound is required.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAll;

impl NetworkPolicy for AllowAll {
    /// Always resolves to `true`; the request is not inspected.
    async fn check(&self, _request: &DomainRequest) -> bool {
        true
    }
}
/// Policy that permits only hosts named in an explicit allow list.
///
/// Each entry is either an exact hostname (`"example.com"`) or a wildcard
/// pattern (`"*.example.com"`). A wildcard matches any host with at least one
/// non-empty label in front of the suffix (`api.example.com`,
/// `a.b.example.com`). It does not match the bare suffix (`example.com`), a
/// host with an empty leading label (`.example.com`), or a host that merely
/// ends with the same characters (`badexample.com`).
///
/// Entries and targets are compared ASCII case-insensitively, and a single
/// trailing root dot is ignored. DNS treats `Example.COM.` and `example.com`
/// as the same name (RFC 4343).
#[derive(Debug, Clone)]
pub struct AllowList {
    /// Exact hostnames and `*.`-wildcards, stored in the form produced by
    /// [`AllowList::normalize`].
    allowed: HashSet<String>,
}

impl AllowList {
    /// Builds an allow list from exact hostnames and/or `*.`-prefixed wildcards.
    pub fn new(domains: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            allowed: domains
                .into_iter()
                .map(Into::<String>::into)
                .map(|domain| Self::normalize(&domain))
                .collect(),
        }
    }

    /// Canonical comparison form: one trailing `.` (fully-qualified root)
    /// removed, then ASCII-lowercased.
    fn normalize(name: &str) -> String {
        name.strip_suffix('.').unwrap_or(name).to_ascii_lowercase()
    }

    /// Returns `true` if `target` equals an exact entry or falls under a
    /// wildcard entry (see the type-level docs for the wildcard rules).
    fn matches(&self, target: &str) -> bool {
        let target = Self::normalize(target);
        if self.allowed.contains(&target) {
            return true;
        }
        self.allowed
            .iter()
            // "*.example.com" -> ".example.com". The leading dot is kept so the
            // match happens at a label boundary ("badexample.com" is rejected).
            .filter_map(|pattern| pattern.strip_prefix('*'))
            .filter(|suffix| suffix.starts_with('.'))
            // Strictly longer than the suffix means at least one character of
            // leading label exists, so ".example.com" is rejected.
            .any(|suffix| target.len() > suffix.len() && target.ends_with(suffix))
    }
}
impl NetworkPolicy for AllowList {
    /// Allows the request when its target host matches an exact or wildcard
    /// entry of the list.
    async fn check(&self, request: &DomainRequest) -> bool {
        let host = request.target();
        self.matches(host)
    }
}
/// Policy that delegates each decision to a caller-supplied handler.
///
/// `handler` is called with the request, and the output of the future it
/// returns is the verdict. The trait bounds live on the `impl` blocks that
/// need them, not on the struct, so the type can be named in generic code
/// without repeating them.
pub struct CustomPolicy<F, Fut> {
    handler: F,
    // `fn() -> Fut` records that this type produces `Fut` values without
    // owning one. As a result, `Send`/`Sync` and drop-check depend only on `F`.
    _marker: PhantomData<fn() -> Fut>,
}
impl<F, Fut> CustomPolicy<F, Fut>
where
    F: Fn(&DomainRequest) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = bool> + Send + 'static,
{
    /// Wraps `handler` so it can be used as a [`NetworkPolicy`].
    ///
    /// The handler receives each request and must return a `Send + 'static`
    /// future that resolves to the verdict (`true` = allow).
    pub fn new(handler: F) -> Self {
        Self {
            _marker: PhantomData,
            handler,
        }
    }
}
impl<F, Fut> NetworkPolicy for CustomPolicy<F, Fut>
where
    Fut: Future<Output = bool> + Send + 'static,
    F: Fn(&DomainRequest) -> Fut + Send + Sync + 'static,
{
    /// Runs the wrapped handler on the request and awaits its verdict.
    ///
    /// The handler is invoked inside the async body, so it runs on the first
    /// poll rather than when `check` is called.
    async fn check(&self, request: &DomainRequest) -> bool {
        let handler = &self.handler;
        let pending_verdict = handler(request);
        pending_verdict.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an outbound request with a fixed pid, varying host and port.
    fn outbound(target: &str, port: u16) -> DomainRequest {
        DomainRequest::new(target.to_string(), port, ConnectionDirection::Outbound, 1234)
    }

    #[test]
    fn test_deny_all() {
        smol::block_on(async {
            assert!(!DenyAll.check(&outbound("example.com", 443)).await);
        });
    }

    #[test]
    fn test_allow_all() {
        smol::block_on(async {
            assert!(AllowAll.check(&outbound("example.com", 443)).await);
        });
    }

    #[test]
    fn test_domain_request_accessors() {
        let request = DomainRequest::new(
            "example.com".to_string(),
            8080,
            ConnectionDirection::Inbound,
            42,
        );
        assert_eq!(request.target(), "example.com");
        assert_eq!(request.port(), 8080);
        assert_eq!(request.direction(), ConnectionDirection::Inbound);
        assert_eq!(request.pid(), 42);
    }

    #[test]
    fn test_allow_list_exact() {
        let policy = AllowList::new(["example.com", "api.test.com"]);
        assert!(policy.matches("example.com"));
        assert!(policy.matches("api.test.com"));
        assert!(!policy.matches("other.com"));
        assert!(!policy.matches("sub.example.com"));
    }

    #[test]
    fn test_allow_list_wildcard() {
        let policy = AllowList::new(["*.example.com"]);
        assert!(policy.matches("api.example.com"));
        assert!(policy.matches("sub.api.example.com"));
        // A wildcard requires at least one extra label in front of the suffix.
        assert!(!policy.matches("example.com"));
        // The suffix must match at a label boundary, not as a raw string suffix.
        assert!(!policy.matches("badexample.com"));
        assert!(!policy.matches("other.com"));
    }

    #[test]
    fn test_allow_list_check() {
        smol::block_on(async {
            let policy = AllowList::new(["example.com"]);
            assert!(policy.check(&outbound("example.com", 443)).await);
            assert!(!policy.check(&outbound("other.com", 443)).await);
        });
    }

    #[test]
    fn test_custom_policy() {
        smol::block_on(async {
            let policy = CustomPolicy::new(|request: &DomainRequest| {
                let allowed = request.port() == 443;
                async move { allowed }
            });
            assert!(policy.check(&outbound("example.com", 443)).await);
            assert!(!policy.check(&outbound("example.com", 80)).await);
        });
    }
}