use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::path::Path;
use std::sync::Arc;
use hudsucker::certificate_authority::RcgenAuthority;
use hudsucker::hyper::{Request, Response, StatusCode};
use hudsucker::rcgen::{CertificateParams, KeyPair};
use hudsucker::{Body, HttpContext, HttpHandler, Proxy, RequestOrResponse};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use crate::policy::{http_acl_check, HttpRule};
/// Shared, lock-protected map from a client's socket address (as the proxy
/// sees it in `HttpContext::client_addr`) to the IP address recorded as that
/// client's original destination.
///
/// NOTE(review): the map is populated outside this file (it is exposed via
/// `HttpAclProxyHandle::orig_dest`); confirm the recorder's keying matches.
pub type OrigDestMap =
    std::sync::Arc<std::sync::RwLock<std::collections::HashMap<SocketAddr, IpAddr>>>;
/// One memoised DNS answer: the addresses a host resolved to and the instant
/// after which they must be looked up again.
struct DnsCacheEntry {
    /// Every IP address returned by the lookup (may be empty).
    ips: Vec<IpAddr>,
    /// Expiry deadline; an entry is only served while this lies in the future.
    expires: std::time::Instant,
}
/// hudsucker request handler that enforces the sandlock HTTP ACL.
///
/// Every field is reference-counted, so all clones of the handler share the
/// same rule sets, destination map and DNS cache.
#[derive(Clone)]
struct AclHandler {
    /// Allow list passed to `http_acl_check`.
    allow_rules: Arc<Vec<HttpRule>>,
    /// Deny list passed to `http_acl_check`.
    deny_rules: Arc<Vec<HttpRule>>,
    /// Client socket address -> original destination IP, used to validate the
    /// host a request claims.
    orig_dest: OrigDestMap,
    /// Host name -> resolved addresses plus expiry; async mutex because it is
    /// locked from async code.
    dns_cache: Arc<tokio::sync::Mutex<HashMap<String, DnsCacheEntry>>>,
}
/// Lifetime of a successful DNS resolution in the handler's cache (30 s).
const DNS_CACHE_TTL: std::time::Duration = std::time::Duration::from_millis(30_000);
impl AclHandler {
    /// Resolves `host` to its IP addresses, memoising successful lookups for
    /// [`DNS_CACHE_TTL`].
    ///
    /// Returns `None` when resolution fails; failures are not cached, so the
    /// next call retries. The cache lock is not held across the lookup itself.
    /// Expired entries are pruned whenever a fresh answer is stored, so host
    /// names chosen by the sandboxed process (e.g. via wildcard DNS) cannot
    /// grow the cache without bound.
    async fn resolve_cached(&self, host: &str) -> Option<Vec<IpAddr>> {
        {
            let cache = self.dns_cache.lock().await;
            if let Some(entry) = cache.get(host) {
                if entry.expires > std::time::Instant::now() {
                    return Some(entry.ips.clone());
                }
            }
        }

        // `lookup_host` wants a socket address; port 0 is a placeholder and
        // only the IPs are kept.
        let lookup = format!("{host}:0");
        let resolved = tokio::net::lookup_host(&lookup).await.ok()?;
        let ips: Vec<IpAddr> = resolved.map(|sa| sa.ip()).collect();

        let now = std::time::Instant::now();
        let mut cache = self.dns_cache.lock().await;
        cache.retain(|_, entry| entry.expires > now);
        cache.insert(
            host.to_string(),
            DnsCacheEntry {
                ips: ips.clone(),
                expires: now + DNS_CACHE_TTL,
            },
        );
        Some(ips)
    }

    /// Checks that `claimed_host` is consistent with the destination IP
    /// recorded in `orig_dest` for `client_addr`.
    ///
    /// * No recorded destination: returns `true` (nothing to compare against).
    /// * IP literal (IPv6 optionally in `[...]`): must equal the recorded IP.
    /// * Host name: must resolve (via [`Self::resolve_cached`]) to a set of
    ///   addresses containing the recorded IP; a failed lookup returns `false`.
    ///
    /// Addresses are compared in canonical form, so an IPv4-mapped IPv6
    /// address (`::ffff:a.b.c.d`) matches the plain IPv4 address.
    async fn verify_host(&self, client_addr: &SocketAddr, claimed_host: &str) -> bool {
        // The std RwLock guard is confined to this scope so it is released
        // before any `.await` below. A poisoned lock still holds a usable map.
        let orig_ip = {
            let map = self.orig_dest.read().unwrap_or_else(|e| e.into_inner());
            map.get(client_addr).copied()
        };
        let Some(orig_ip) = orig_ip else {
            return true;
        };
        let orig_ip = orig_ip.to_canonical();

        // `Uri::host()` keeps the brackets around IPv6 literals; strip them so
        // the literal parses as an address instead of going through DNS.
        let literal = claimed_host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(claimed_host);
        if let Ok(ip) = literal.parse::<IpAddr>() {
            return ip.to_canonical() == orig_ip;
        }

        match self.resolve_cached(claimed_host).await {
            Some(ips) => ips.iter().any(|ip| ip.to_canonical() == orig_ip),
            None => false,
        }
    }
}
impl HttpHandler for AclHandler {
    /// Enforces the sandlock HTTP ACL on every intercepted request.
    ///
    /// The request host comes from the URI authority when present and
    /// otherwise from the `Host` header with any `:port` removed. That host
    /// must first match the destination recorded in `orig_dest` for this
    /// client (see `verify_host`), then pass `http_acl_check` for the method,
    /// host and path; either failure yields `403 Forbidden`. Allowed requests
    /// without an authority get an absolute `http://` URI rebuilt from the
    /// `Host` header before being forwarded.
    async fn handle_request(
        &mut self,
        ctx: &HttpContext,
        req: Request<Body>,
    ) -> RequestOrResponse {
        let method = req.method().as_str().to_string();
        let host = req
            .uri()
            .host()
            .map(|h| h.to_string())
            .or_else(|| {
                req.headers()
                    .get("host")
                    .and_then(|v| v.to_str().ok())
                    .map(|h| host_without_port(h).to_string())
            })
            .unwrap_or_default();
        let path = req.uri().path().to_string();

        // The `orig_dest` entry for this client is deliberately NOT removed
        // here. One TCP connection can carry many requests (HTTP/1.1
        // keep-alive, or a CONNECT followed by the tunnelled requests), all
        // with the same `client_addr`. Removing the entry after the first
        // request made `verify_host` fail open for every later request on that
        // connection - including an immediate retry after a 403.
        // NOTE(review): entries now persist until the code that records them
        // overwrites or removes them; confirm it overwrites when a new
        // connection reuses a source address.
        if !self.verify_host(&ctx.client_addr, &host).await {
            return forbidden_response(
                "Blocked by sandlock: Host header does not match connection destination",
            );
        }

        if !http_acl_check(&self.allow_rules, &self.deny_rules, &method, &host, &path) {
            return forbidden_response("Blocked by sandlock HTTP ACL policy");
        }

        let mut req = req;
        if req.uri().authority().is_none() {
            // Origin-form request: rebuild an absolute URI from the Host
            // header so the proxy client knows where to forward it.
            let host_port = req
                .headers()
                .get("host")
                .and_then(|v| v.to_str().ok())
                .unwrap_or_default()
                .to_string();
            if !host_port.is_empty() {
                let path_and_query = req
                    .uri()
                    .path_and_query()
                    .map(|pq| pq.as_str())
                    .unwrap_or("/");
                if let Ok(uri) = format!("http://{host_port}{path_and_query}").parse() {
                    *req.uri_mut() = uri;
                }
            }
        }
        req.into()
    }
}

/// Returns the host part of a `Host` header value, dropping any `:port`.
///
/// Bracketed IPv6 literals (`[::1]:8080`) keep their brackets, matching the
/// form `Uri::host()` yields; a plain `split(':')` would cut them to `"["`.
fn host_without_port(host_port: &str) -> &str {
    if host_port.starts_with('[') {
        match host_port.find(']') {
            Some(end) => &host_port[..=end],
            None => host_port,
        }
    } else {
        host_port.split(':').next().unwrap_or(host_port)
    }
}

/// Builds a `403 Forbidden` response whose body is `reason`.
fn forbidden_response(reason: &'static str) -> RequestOrResponse {
    Response::builder()
        .status(StatusCode::FORBIDDEN)
        .body(Body::from(reason))
        .expect("failed to build 403 response")
        .into()
}
/// Handle to a running HTTP ACL proxy, as returned by `spawn_http_acl_proxy`.
///
/// Dropping the handle tells the proxy to shut down gracefully.
pub struct HttpAclProxyHandle {
    /// Loopback address the proxy is listening on.
    pub addr: SocketAddr,
    /// Shared client-address -> original-destination map that the proxy
    /// consults to validate the host each request claims.
    pub orig_dest: OrigDestMap,
    /// Graceful-shutdown trigger; taken (set to `None`) when it is fired.
    shutdown_tx: Option<oneshot::Sender<()>>,
}
impl Drop for HttpAclProxyHandle {
    /// Fires the graceful-shutdown signal, at most once.
    ///
    /// A failed send means the receiving end was already dropped, so there is
    /// nobody left to notify and the error is ignored.
    fn drop(&mut self) {
        let Some(sender) = self.shutdown_tx.take() else {
            return;
        };
        let _ = sender.send(());
    }
}
/// Generates a fresh key pair and a self-signed, unconstrained CA certificate.
///
/// # Errors
///
/// Returns an `io::ErrorKind::Other` error if key generation or self-signing
/// fails.
fn dummy_ca() -> std::io::Result<(KeyPair, hudsucker::rcgen::Certificate)> {
    use hudsucker::rcgen::{BasicConstraints, IsCa};

    let key_pair = KeyPair::generate()
        .map_err(|err| std::io::Error::other(format!("keygen failed: {err}")))?;

    // `CertificateParams` is non-exhaustive, so start from the defaults and
    // flip only the CA flag.
    let mut ca_params = CertificateParams::default();
    ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);

    let certificate = ca_params
        .self_signed(&key_pair)
        .map_err(|err| std::io::Error::other(format!("self-sign failed: {err}")))?;
    Ok((key_pair, certificate))
}
/// Process-wide throwaway CA, generated on first use and kept as
/// `(key PEM bytes, certificate PEM bytes)`; a generation failure is cached too.
static DUMMY_CA: std::sync::LazyLock<std::io::Result<(Vec<u8>, Vec<u8>)>> =
    std::sync::LazyLock::new(|| {
        dummy_ca().map(|(key_pair, certificate)| {
            let key_pem = key_pair.serialize_pem().into_bytes();
            let cert_pem = certificate.pem().into_bytes();
            (key_pem, cert_pem)
        })
    });
pub async fn spawn_http_acl_proxy(
allow: Vec<HttpRule>,
deny: Vec<HttpRule>,
ca_cert: Option<&Path>,
ca_key: Option<&Path>,
) -> std::io::Result<HttpAclProxyHandle> {
let (key_pair, cert) = if let (Some(cert_path), Some(key_path)) = (ca_cert, ca_key) {
let key_pem = std::fs::read_to_string(key_path).map_err(|e| {
std::io::Error::new(e.kind(), format!("failed to read --https-key {:?}: {e}", key_path))
})?;
let cert_pem = std::fs::read_to_string(cert_path).map_err(|e| {
std::io::Error::new(e.kind(), format!("failed to read --https-ca {:?}: {e}", cert_path))
})?;
let kp = KeyPair::from_pem(&key_pem).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid CA key: {e}"))
})?;
let params = CertificateParams::from_ca_cert_pem(&cert_pem).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, format!("invalid CA cert: {e}"))
})?;
let cert = params.self_signed(&kp).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, format!("CA cert error: {e}"))
})?;
(kp, cert)
} else {
let (key_pem, cert_pem) = DUMMY_CA.as_ref().map_err(|e| {
std::io::Error::new(e.kind(), format!("dummy CA init failed: {e}"))
})?;
let kp = KeyPair::from_pem(std::str::from_utf8(key_pem).unwrap()).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("dummy CA key: {e}"))
})?;
let params = CertificateParams::from_ca_cert_pem(std::str::from_utf8(cert_pem).unwrap())
.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("dummy CA cert: {e}"))
})?;
let cert = params.self_signed(&kp).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("dummy CA sign: {e}"))
})?;
(kp, cert)
};
let ca = RcgenAuthority::new(key_pair, cert, 1_000);
let orig_dest: OrigDestMap = Arc::new(std::sync::RwLock::new(HashMap::new()));
let handler = AclHandler {
allow_rules: Arc::new(allow),
deny_rules: Arc::new(deny),
orig_dest: Arc::clone(&orig_dest),
dns_cache: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
};
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let proxy = Proxy::builder()
.with_listener(listener)
.with_rustls_client()
.with_ca(ca)
.with_http_handler(handler)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.build();
tokio::spawn(async move {
if let Err(e) = proxy.start().await {
eprintln!("sandlock HTTP ACL proxy error: {e}");
}
});
Ok(HttpAclProxyHandle {
addr,
orig_dest,
shutdown_tx: Some(shutdown_tx),
})
}