use std::net::IpAddr;
use std::path::PathBuf;
use ipnetwork::{Ipv4Network, Ipv6Network};
use crate::config::{DnsConfig, InterfaceOverrides, NetworkConfig, PortProtocol, PublishedPort};
use crate::dns::Nameserver;
use crate::policy::{BuildError, NetworkPolicy};
use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, ViolationAction};
use crate::tls::TlsConfig;
/// Fluent builder for [`NetworkConfig`].
///
/// Every setter consumes the builder and returns it, so calls are chained. Setters that
/// validate their input (such as `ipv4_pool` / `ipv6_pool`) record failures instead of
/// panicking; the first recorded failure is returned from `build()`.
#[derive(Clone)]
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct NetworkBuilder {
    /// Configuration accumulated so far.
    config: NetworkConfig,
    /// Validation errors recorded by setters, in the order they occurred.
    errors: Vec<BuildError>,
}
/// Fluent builder for [`DnsConfig`].
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct DnsBuilder {
    /// DNS configuration accumulated so far.
    config: DnsConfig,
}
/// Fluent builder for [`TlsConfig`].
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct TlsBuilder {
    /// TLS configuration accumulated so far.
    config: TlsConfig,
}
/// Fluent builder for a single [`SecretEntry`].
///
/// Intentionally does not derive `Debug`: it holds the secret value in plain text.
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct SecretBuilder {
    /// Environment-variable name; must be set before building.
    env_var: Option<String>,
    /// Secret value; must be set before building.
    value: Option<String>,
    /// Placeholder text; when unset, `build()` derives one from `env_var`.
    placeholder: Option<String>,
    /// Host patterns accumulated via the `allow_*` setters.
    allowed_hosts: Vec<HostPattern>,
    /// Injection toggles (headers, basic auth, query params, body).
    injection: SecretInjection,
    /// Per-secret violation action; `None` unless `on_violation` was called.
    on_violation: Option<ViolationAction>,
    /// Copied verbatim into the built entry.
    require_tls_identity: bool,
}
/// Fluent builder for a [`ViolationAction`], starting from `ViolationAction::default()`.
#[derive(Default)]
#[must_use = "builder methods return the updated builder; discarding it loses the configuration"]
pub struct ViolationActionBuilder {
    /// Action accumulated so far.
    action: ViolationAction,
}
impl NetworkBuilder {
    /// Host address used by `port` / `port_udp`, which take no explicit bind address.
    const DEFAULT_HOST_BIND: IpAddr = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);

    /// Creates a builder starting from `NetworkConfig::default()`.
    pub fn new() -> Self {
        Self::from_config(NetworkConfig::default())
    }

    /// Creates a builder that starts from an existing configuration.
    pub fn from_config(config: NetworkConfig) -> Self {
        Self {
            config,
            errors: Vec::new(),
        }
    }

    /// Sets the `enabled` flag of the configuration.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.config.enabled = enabled;
        self
    }

    /// Publishes a TCP port, binding the host side to `127.0.0.1`.
    pub fn port(self, host_port: u16, guest_port: u16) -> Self {
        self.port_bind(Self::DEFAULT_HOST_BIND, host_port, guest_port)
    }

    /// Publishes a UDP port, binding the host side to `127.0.0.1`.
    pub fn port_udp(self, host_port: u16, guest_port: u16) -> Self {
        self.port_udp_bind(Self::DEFAULT_HOST_BIND, host_port, guest_port)
    }

    /// Publishes a TCP port with the host side bound to `host_bind`.
    pub fn port_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
        self.add_port(host_bind, host_port, guest_port, PortProtocol::Tcp)
    }

    /// Publishes a UDP port with the host side bound to `host_bind`.
    pub fn port_udp_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
        self.add_port(host_bind, host_port, guest_port, PortProtocol::Udp)
    }

    /// Appends a published-port entry. Ports accumulate; duplicates are not checked here.
    fn add_port(
        mut self,
        host_bind: IpAddr,
        host_port: u16,
        guest_port: u16,
        protocol: PortProtocol,
    ) -> Self {
        self.config.ports.push(PublishedPort {
            host_port,
            guest_port,
            protocol,
            host_bind,
        });
        self
    }

    /// Replaces the network policy.
    pub fn policy(mut self, policy: NetworkPolicy) -> Self {
        self.config.policy = policy;
        self
    }

    /// Replaces the DNS configuration with one built from a fresh `DnsBuilder`.
    ///
    /// The closure starts from `DnsConfig::default()`, not from the current DNS settings,
    /// so a second call (or a config passed to `from_config`) is fully overwritten.
    pub fn dns(mut self, f: impl FnOnce(DnsBuilder) -> DnsBuilder) -> Self {
        self.config.dns = f(DnsBuilder::new()).build();
        self
    }

    /// Replaces the TLS configuration with one built from a fresh `TlsBuilder`.
    ///
    /// As with `dns`, the closure starts from the builder's defaults, not from the
    /// current TLS settings.
    pub fn tls(mut self, f: impl FnOnce(TlsBuilder) -> TlsBuilder) -> Self {
        self.config.tls = f(TlsBuilder::new()).build();
        self
    }

    /// Appends a secret built from a fresh `SecretBuilder`.
    ///
    /// # Panics
    ///
    /// Panics if the closure does not set both `env` and `value` (see `SecretBuilder::build`).
    pub fn secret(mut self, f: impl FnOnce(SecretBuilder) -> SecretBuilder) -> Self {
        self.config
            .secrets
            .secrets
            .push(f(SecretBuilder::new()).build());
        self
    }

    /// Shorthand for a secret restricted to a single exact host.
    ///
    /// Goes through `SecretBuilder` so that every other field (injection settings,
    /// violation action, TLS-identity requirement) takes the same defaults as `secret`.
    pub fn secret_env(
        self,
        env_var: impl Into<String>,
        value: impl Into<String>,
        placeholder: impl Into<String>,
        allowed_host: impl Into<String>,
    ) -> Self {
        self.secret(|s| {
            s.env(env_var)
                .value(value)
                .placeholder(placeholder)
                .allow_host(allowed_host)
        })
    }

    /// Sets `secrets.on_violation`, the action stored on the secrets configuration as a
    /// whole (individual secrets may carry their own via `SecretBuilder::on_violation`).
    pub fn on_secret_violation(
        mut self,
        f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
    ) -> Self {
        self.config.secrets.on_violation = f(ViolationActionBuilder::default()).build();
        self
    }

    /// Sets `max_connections` to `Some(max)`.
    pub fn max_connections(mut self, max: usize) -> Self {
        self.config.max_connections = Some(max);
        self
    }

    /// Replaces the interface overrides wholesale, including any pools previously set
    /// with `ipv4_pool` / `ipv6_pool`.
    pub fn interface(mut self, overrides: InterfaceOverrides) -> Self {
        self.config.interface = overrides;
        self
    }

    /// Sets the IPv4 address pool.
    ///
    /// Pools with a prefix longer than /30 are rejected: the error is recorded (and
    /// reported by `build`) and the previously configured pool is left untouched.
    pub fn ipv4_pool(mut self, pool: Ipv4Network) -> Self {
        if pool.prefix() > 30 {
            self.errors.push(BuildError::InvalidIpv4Pool {
                raw: pool.to_string(),
            });
        } else {
            self.config.interface.ipv4_pool = Some(pool);
        }
        self
    }

    /// Sets the IPv6 address pool.
    ///
    /// Pools with a prefix longer than /64 are rejected: the error is recorded (and
    /// reported by `build`) and the previously configured pool is left untouched.
    pub fn ipv6_pool(mut self, pool: Ipv6Network) -> Self {
        if pool.prefix() > 64 {
            self.errors.push(BuildError::InvalidIpv6Pool {
                raw: pool.to_string(),
            });
        } else {
            self.config.interface.ipv6_pool = Some(pool);
        }
        self
    }

    /// Sets the `trust_host_cas` flag.
    pub fn trust_host_cas(mut self, enabled: bool) -> Self {
        self.config.trust_host_cas = enabled;
        self
    }

    /// Finishes the builder.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by a validating setter (`ipv4_pool`,
    /// `ipv6_pool`). Any further recorded errors are discarded.
    pub fn build(self) -> Result<NetworkConfig, BuildError> {
        let Self { config, errors } = self;
        match errors.into_iter().next() {
            Some(err) => Err(err),
            None => Ok(config),
        }
    }
}
impl DnsBuilder {
pub fn new() -> Self {
Self {
config: DnsConfig::default(),
}
}
pub fn rebind_protection(mut self, enabled: bool) -> Self {
self.config.rebind_protection = enabled;
self
}
pub fn nameservers<I>(mut self, nameservers: I) -> Self
where
I: IntoIterator,
I::Item: Into<Nameserver>,
{
self.config.nameservers = nameservers.into_iter().map(Into::into).collect();
self
}
pub fn query_timeout_ms(mut self, ms: u64) -> Self {
self.config.query_timeout_ms = ms;
self
}
pub fn build(self) -> DnsConfig {
self.config
}
}
impl Default for DnsBuilder {
fn default() -> Self {
Self::new()
}
}
impl TlsBuilder {
pub fn new() -> Self {
Self {
config: TlsConfig {
enabled: true,
..TlsConfig::default()
},
}
}
pub fn bypass(mut self, pattern: impl Into<String>) -> Self {
self.config.bypass.push(pattern.into());
self
}
pub fn verify_upstream(mut self, verify: bool) -> Self {
self.config.verify_upstream = verify;
self
}
pub fn intercepted_ports(mut self, ports: Vec<u16>) -> Self {
self.config.intercepted_ports = ports;
self
}
pub fn block_quic(mut self, block: bool) -> Self {
self.config.block_quic_on_intercept = block;
self
}
pub fn upstream_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
self.config.upstream_ca_cert.push(path.into());
self
}
pub fn intercept_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
self.config.intercept_ca.cert_path = Some(path.into());
self
}
pub fn intercept_ca_key(mut self, path: impl Into<PathBuf>) -> Self {
self.config.intercept_ca.key_path = Some(path.into());
self
}
pub fn build(self) -> TlsConfig {
self.config
}
}
impl SecretBuilder {
    /// Creates an empty builder.
    ///
    /// `env` and `value` must be supplied before `build`. Everything else has a default:
    /// no allowed hosts, `SecretInjection::default()`, no per-secret violation action,
    /// and `require_tls_identity = true`.
    pub fn new() -> Self {
        Self {
            require_tls_identity: true,
            on_violation: None,
            injection: SecretInjection::default(),
            allowed_hosts: Vec::new(),
            placeholder: None,
            value: None,
            env_var: None,
        }
    }

    /// Sets the environment-variable name (required).
    pub fn env(mut self, var: impl Into<String>) -> Self {
        let name: String = var.into();
        self.env_var = Some(name);
        self
    }

    /// Sets the secret value (required).
    pub fn value(mut self, value: impl Into<String>) -> Self {
        let secret: String = value.into();
        self.value = Some(secret);
        self
    }

    /// Overrides the placeholder; otherwise `build` uses `$MSB_<env_var>`.
    pub fn placeholder(mut self, placeholder: impl Into<String>) -> Self {
        let text: String = placeholder.into();
        self.placeholder = Some(text);
        self
    }

    /// Adds an exact host to the allow-list.
    pub fn allow_host(mut self, host: impl Into<String>) -> Self {
        let pattern = HostPattern::Exact(host.into());
        self.allowed_hosts.push(pattern);
        self
    }

    /// Adds a wildcard host pattern to the allow-list.
    pub fn allow_host_pattern(mut self, pattern: impl Into<String>) -> Self {
        let pattern = HostPattern::Wildcard(pattern.into());
        self.allowed_hosts.push(pattern);
        self
    }

    /// Adds `HostPattern::Any` to the allow-list, but only when the caller passes `true`.
    /// Passing `false` leaves the builder unchanged.
    pub fn allow_any_host_dangerous(mut self, i_understand_the_risk: bool) -> Self {
        if !i_understand_the_risk {
            return self;
        }
        self.allowed_hosts.push(HostPattern::Any);
        self
    }

    /// Sets this secret's own violation action, built from a default `ViolationActionBuilder`.
    pub fn on_violation(
        mut self,
        f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
    ) -> Self {
        let action = f(ViolationActionBuilder::default()).build();
        self.on_violation = Some(action);
        self
    }

    /// Sets the `require_tls_identity` flag (default `true`).
    pub fn require_tls_identity(mut self, enabled: bool) -> Self {
        self.require_tls_identity = enabled;
        self
    }

    /// Toggles `injection.headers`.
    pub fn inject_headers(mut self, enabled: bool) -> Self {
        self.injection.headers = enabled;
        self
    }

    /// Toggles `injection.basic_auth`.
    pub fn inject_basic_auth(mut self, enabled: bool) -> Self {
        self.injection.basic_auth = enabled;
        self
    }

    /// Toggles `injection.query_params`.
    pub fn inject_query(mut self, enabled: bool) -> Self {
        self.injection.query_params = enabled;
        self
    }

    /// Toggles `injection.body`.
    pub fn inject_body(mut self, enabled: bool) -> Self {
        self.injection.body = enabled;
        self
    }

    /// Produces the [`SecretEntry`].
    ///
    /// # Panics
    ///
    /// Panics if `env` was never set, and otherwise if `value` was never set (checked in
    /// that order).
    pub fn build(self) -> SecretEntry {
        let Self {
            env_var,
            value,
            placeholder,
            allowed_hosts,
            injection,
            on_violation,
            require_tls_identity,
        } = self;
        let env_var = env_var.expect("SecretBuilder: .env() is required");
        let value = value.expect("SecretBuilder: .value() is required");
        let placeholder = placeholder.unwrap_or_else(|| format!("$MSB_{env_var}"));
        SecretEntry {
            env_var,
            value,
            placeholder,
            allowed_hosts,
            injection,
            on_violation,
            require_tls_identity,
        }
    }
}
impl ViolationActionBuilder {
    /// Creates a builder holding `ViolationAction::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a builder that starts from `action`.
    pub fn from_action(action: ViolationAction) -> Self {
        Self { action }
    }

    /// Switches to `ViolationAction::Block`, discarding any passthrough hosts.
    pub fn block(self) -> Self {
        self.replace_action(ViolationAction::Block)
    }

    /// Switches to `ViolationAction::BlockAndLog`, discarding any passthrough hosts.
    pub fn block_and_log(self) -> Self {
        self.replace_action(ViolationAction::BlockAndLog)
    }

    /// Switches to `ViolationAction::BlockAndTerminate`, discarding any passthrough hosts.
    pub fn block_and_terminate(self) -> Self {
        self.replace_action(ViolationAction::BlockAndTerminate)
    }

    /// Adds an exact host to the passthrough list (see `push_passthrough_host`).
    pub fn passthrough_host(mut self, host: impl Into<String>) -> Self {
        let pattern = HostPattern::Exact(host.into());
        self.push_passthrough_host(pattern);
        self
    }

    /// Adds a wildcard pattern to the passthrough list (see `push_passthrough_host`).
    pub fn passthrough_host_pattern(mut self, pattern: impl Into<String>) -> Self {
        let pattern = HostPattern::Wildcard(pattern.into());
        self.push_passthrough_host(pattern);
        self
    }

    /// Adds `HostPattern::Any` to the passthrough list, but only when the caller passes
    /// `true`. Passing `false` leaves the builder unchanged.
    pub fn passthrough_all_hosts(mut self, i_understand_the_risk: bool) -> Self {
        if !i_understand_the_risk {
            return self;
        }
        self.push_passthrough_host(HostPattern::Any);
        self
    }

    /// Overwrites the current action.
    fn replace_action(mut self, action: ViolationAction) -> Self {
        self.action = action;
        self
    }

    /// Extends an existing `Passthrough` list, or starts a fresh one-element list when the
    /// current action is anything else (so an earlier blocking choice is replaced).
    fn push_passthrough_host(&mut self, host: HostPattern) {
        if let ViolationAction::Passthrough(hosts) = &mut self.action {
            hosts.push(host);
        } else {
            self.action = ViolationAction::Passthrough(vec![host]);
        }
    }

    /// Returns the accumulated action.
    pub fn build(self) -> ViolationAction {
        let Self { action } = self;
        action
    }
}
impl Default for NetworkBuilder {
fn default() -> Self {
Self::new()
}
}
impl Default for TlsBuilder {
fn default() -> Self {
Self::new()
}
}
impl Default for SecretBuilder {
fn default() -> Self {
Self::new()
}
}
impl From<ViolationAction> for ViolationActionBuilder {
fn from(action: ViolationAction) -> Self {
Self { action }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn network_builder_happy_path_returns_config() {
        let cfg = NetworkBuilder::new()
            .dns(|d| d.rebind_protection(false))
            .build()
            .unwrap();
        assert!(!cfg.dns.rebind_protection);
    }

    #[test]
    fn port_bind_sets_host_bind() {
        let bind = "0.0.0.0".parse().unwrap();
        let cfg = NetworkBuilder::new()
            .port_bind(bind, 8080, 80)
            .port_udp_bind(bind, 5353, 53)
            .build()
            .unwrap();
        assert_eq!(cfg.ports[0].host_bind, bind);
        assert_eq!(cfg.ports[0].host_port, 8080);
        assert_eq!(cfg.ports[0].guest_port, 80);
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(cfg.ports[1].host_bind, bind);
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn port_without_bind_defaults_to_localhost() {
        let localhost = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
        let cfg = NetworkBuilder::new()
            .port(8080, 80)
            .port_udp(5353, 53)
            .build()
            .unwrap();
        assert_eq!(cfg.ports.len(), 2);
        assert_eq!(cfg.ports[0].host_bind, localhost);
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(cfg.ports[1].host_bind, localhost);
        assert_eq!(cfg.ports[1].host_port, 5353);
        assert_eq!(cfg.ports[1].guest_port, 53);
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn ipv4_pool_accepts_prefix_up_to_30() {
        let pool: Ipv4Network = "10.0.0.0/30".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv4_pool(pool).build().unwrap();
        assert_eq!(cfg.interface.ipv4_pool, Some(pool));
    }

    #[test]
    fn ipv4_pool_rejects_prefix_longer_than_30() {
        let pool: Ipv4Network = "10.0.0.0/31".parse().unwrap();
        let result = NetworkBuilder::new().ipv4_pool(pool).build();
        assert!(matches!(
            result,
            Err(BuildError::InvalidIpv4Pool { raw }) if raw == "10.0.0.0/31"
        ));
    }

    #[test]
    fn ipv4_pool_error_survives_a_later_valid_pool() {
        let bad: Ipv4Network = "10.0.0.0/32".parse().unwrap();
        let good: Ipv4Network = "10.0.0.0/24".parse().unwrap();
        let result = NetworkBuilder::new().ipv4_pool(bad).ipv4_pool(good).build();
        assert!(matches!(result, Err(BuildError::InvalidIpv4Pool { .. })));
    }

    #[test]
    fn ipv6_pool_accepts_prefix_up_to_64() {
        let pool: Ipv6Network = "fd00::/64".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv6_pool(pool).build().unwrap();
        assert_eq!(cfg.interface.ipv6_pool, Some(pool));
    }

    #[test]
    fn ipv6_pool_rejects_prefix_longer_than_64() {
        let pool: Ipv6Network = "fd00::/65".parse().unwrap();
        let result = NetworkBuilder::new().ipv6_pool(pool).build();
        assert!(matches!(result, Err(BuildError::InvalidIpv6Pool { .. })));
    }

    #[test]
    fn tls_builder_starts_enabled() {
        let cfg = NetworkBuilder::new().tls(|t| t).build().unwrap();
        assert!(cfg.tls.enabled);
    }

    #[test]
    fn secret_env_adds_single_exact_host_secret() {
        let cfg = NetworkBuilder::new()
            .secret_env("TOKEN", "secret-value", "$TOKEN", "api.github.com")
            .build()
            .unwrap();
        assert_eq!(cfg.secrets.secrets.len(), 1);
        let secret = &cfg.secrets.secrets[0];
        assert_eq!(secret.env_var, "TOKEN");
        assert_eq!(secret.value, "secret-value");
        assert_eq!(secret.placeholder, "$TOKEN");
        assert_eq!(
            secret.allowed_hosts,
            vec![HostPattern::Exact("api.github.com".into())]
        );
        assert!(secret.on_violation.is_none());
        assert!(secret.require_tls_identity);
    }

    #[test]
    fn secret_builder_defaults_placeholder_from_env_var() {
        let secret = SecretBuilder::new().env("TOKEN").value("v").build();
        assert_eq!(secret.placeholder, "$MSB_TOKEN");
        assert!(secret.allowed_hosts.is_empty());
        assert!(secret.on_violation.is_none());
        assert!(secret.require_tls_identity);
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: .env() is required")]
    fn secret_builder_without_env_panics() {
        let _ = SecretBuilder::new().value("v").build();
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: .value() is required")]
    fn secret_builder_without_value_panics() {
        let _ = SecretBuilder::new().env("TOKEN").build();
    }

    #[test]
    fn allow_any_host_requires_explicit_opt_in() {
        let declined = SecretBuilder::new()
            .env("TOKEN")
            .value("v")
            .allow_any_host_dangerous(false)
            .build();
        assert!(declined.allowed_hosts.is_empty());

        let accepted = SecretBuilder::new()
            .env("TOKEN")
            .value("v")
            .allow_any_host_dangerous(true)
            .build();
        assert_eq!(accepted.allowed_hosts, vec![HostPattern::Any]);
    }

    #[test]
    fn network_builder_sets_global_passthrough_action() {
        let cfg = NetworkBuilder::new()
            .on_secret_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build()
            .unwrap();
        assert_eq!(
            cfg.secrets.on_violation,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])
        );
    }

    #[test]
    fn secret_builder_sets_violation_action() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_host("api.github.com")
            .on_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build();
        assert_eq!(
            secret.on_violation,
            Some(ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])),
        );
    }

    #[test]
    fn violation_action_builder_blocking_call_replaces_passthrough_policy() {
        let action = ViolationActionBuilder::default()
            .passthrough_host("google.com")
            .block_and_terminate()
            .passthrough_host("facebook.com")
            .build();
        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![HostPattern::Exact("facebook.com".into())])
        );
    }

    #[test]
    fn violation_action_builder_accumulates_passthrough_hosts() {
        let action = ViolationActionBuilder::default()
            .block()
            .passthrough_host("google.com")
            .passthrough_host("facebook.com")
            .build();
        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("google.com".into()),
                HostPattern::Exact("facebook.com".into()),
            ]),
        );
    }

    #[test]
    fn passthrough_all_hosts_requires_explicit_opt_in() {
        let declined = ViolationActionBuilder::from_action(ViolationAction::Block)
            .passthrough_all_hosts(false)
            .build();
        assert_eq!(declined, ViolationAction::Block);

        let accepted = ViolationActionBuilder::from_action(ViolationAction::Block)
            .passthrough_all_hosts(true)
            .build();
        assert_eq!(accepted, ViolationAction::Passthrough(vec![HostPattern::Any]));
    }
}