use std::net::IpAddr;
use std::path::PathBuf;
use crate::config::{DnsConfig, InterfaceOverrides, NetworkConfig, PortProtocol, PublishedPort};
use crate::dns::Nameserver;
use crate::policy::{BuildError, DomainName, NetworkPolicy};
use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, ViolationAction};
use crate::tls::TlsConfig;
/// Fluent builder for [`NetworkConfig`].
///
/// Every setter consumes the builder and returns it. Errors raised by nested builders
/// (currently only DNS) are collected instead of panicking, and
/// [`NetworkBuilder::build`] returns the first one.
#[derive(Debug)]
#[must_use = "builder methods return a new builder; dropping it discards the configuration"]
pub struct NetworkBuilder {
    /// Configuration accumulated so far.
    config: NetworkConfig,
    /// Errors recorded by nested builders, in the order they were recorded.
    errors: Vec<BuildError>,
}
/// Fluent builder for [`DnsConfig`].
///
/// Invalid inputs (e.g. malformed blocked domains) are recorded rather than causing a
/// panic, and [`DnsBuilder::build`] returns the first recorded error.
#[derive(Debug)]
#[must_use = "builder methods return a new builder; dropping it discards the configuration"]
pub struct DnsBuilder {
    /// Configuration accumulated so far.
    config: DnsConfig,
    /// Validation errors, in the order they were recorded.
    errors: Vec<BuildError>,
}
/// Fluent builder for [`TlsConfig`]; it is infallible, so `build` returns the config directly.
#[must_use = "builder methods return a new builder; dropping it discards the configuration"]
pub struct TlsBuilder {
    /// Configuration accumulated so far.
    config: TlsConfig,
}
/// Fluent builder for a single [`SecretEntry`].
///
/// Intentionally does not derive `Debug`: it may hold the secret's raw value.
#[must_use = "builder methods return a new builder; dropping it discards the configuration"]
pub struct SecretBuilder {
    /// Required; also used to derive the default placeholder.
    env_var: Option<String>,
    /// Required secret value.
    value: Option<String>,
    /// Optional; defaults to `$MSB_<env_var>` at build time.
    placeholder: Option<String>,
    /// Host patterns the secret is allowed for; starts empty.
    allowed_hosts: Vec<HostPattern>,
    /// Injection switches; starts at `SecretInjection::default()`.
    injection: SecretInjection,
    /// Copied verbatim into the built entry; starts `true`.
    require_tls_identity: bool,
}
impl NetworkBuilder {
    /// Creates a builder seeded with `NetworkConfig::default()`.
    pub fn new() -> Self {
        Self::from_config(NetworkConfig::default())
    }

    /// Creates a builder that starts from an existing `config`.
    ///
    /// Scalar setters overwrite the corresponding field; list setters such as
    /// [`port`](Self::port) and [`secret`](Self::secret) append to what is already there.
    pub fn from_config(config: NetworkConfig) -> Self {
        Self {
            config,
            errors: Vec::new(),
        }
    }

    /// Sets `NetworkConfig::enabled`.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.config.enabled = enabled;
        self
    }

    /// Appends a TCP [`PublishedPort`] mapping `host_port` to `guest_port`,
    /// with `host_bind` set to `127.0.0.1`.
    pub fn port(self, host_port: u16, guest_port: u16) -> Self {
        self.add_port(host_port, guest_port, PortProtocol::Tcp)
    }

    /// UDP counterpart of [`port`](Self::port); `host_bind` is also `127.0.0.1`.
    pub fn port_udp(self, host_port: u16, guest_port: u16) -> Self {
        self.add_port(host_port, guest_port, PortProtocol::Udp)
    }

    /// Shared implementation of [`port`](Self::port) and [`port_udp`](Self::port_udp).
    fn add_port(mut self, host_port: u16, guest_port: u16, protocol: PortProtocol) -> Self {
        self.config.ports.push(PublishedPort {
            host_port,
            guest_port,
            protocol,
            // Ports added through this builder always get the IPv4 loopback address;
            // there is no setter here for choosing another one.
            host_bind: IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
        });
        self
    }

    /// Replaces the network policy.
    pub fn policy(mut self, policy: NetworkPolicy) -> Self {
        self.config.policy = policy;
        self
    }

    /// Configures DNS through a closure that receives a fresh `DnsBuilder::new()`.
    ///
    /// On success the result replaces `config.dns` wholesale, so calling this twice keeps
    /// only the second configuration. On failure `config.dns` is left untouched and the
    /// error is recorded for [`build`](Self::build), which returns the first recorded error.
    pub fn dns(mut self, f: impl FnOnce(DnsBuilder) -> DnsBuilder) -> Self {
        match f(DnsBuilder::new()).build() {
            Ok(config) => self.config.dns = config,
            Err(err) => self.errors.push(err),
        }
        self
    }

    /// Configures TLS through a closure that receives `TlsBuilder::new()`.
    ///
    /// `TlsBuilder::new()` sets `enabled: true` and `TlsBuilder` has no setter for it, so
    /// calling this method always yields an enabled TLS config. The result replaces
    /// `config.tls` wholesale.
    pub fn tls(mut self, f: impl FnOnce(TlsBuilder) -> TlsBuilder) -> Self {
        self.config.tls = f(TlsBuilder::new()).build();
        self
    }

    /// Appends a secret produced by a [`SecretBuilder`] closure.
    ///
    /// # Panics
    ///
    /// Panics if the closure leaves `.env()` or `.value()` unset; see
    /// [`SecretBuilder::build`].
    pub fn secret(mut self, f: impl FnOnce(SecretBuilder) -> SecretBuilder) -> Self {
        let entry = f(SecretBuilder::new()).build();
        self.config.secrets.secrets.push(entry);
        self
    }

    /// Shorthand for [`secret`](Self::secret) with a single exact allowed host.
    ///
    /// Equivalent to
    /// `self.secret(|s| s.env(env_var).value(value).placeholder(placeholder).allow_host(allowed_host))`,
    /// so it picks up exactly the defaults of [`SecretBuilder::new`] (default injection,
    /// `require_tls_identity = true`). Both required fields are supplied, so this never panics.
    pub fn secret_env(
        self,
        env_var: impl Into<String>,
        value: impl Into<String>,
        placeholder: impl Into<String>,
        allowed_host: impl Into<String>,
    ) -> Self {
        // Route through `SecretBuilder` so secret defaults live in exactly one place.
        self.secret(|s| {
            s.env(env_var)
                .value(value)
                .placeholder(placeholder)
                .allow_host(allowed_host)
        })
    }

    /// Sets `secrets.on_violation`.
    pub fn on_secret_violation(mut self, action: ViolationAction) -> Self {
        self.config.secrets.on_violation = action;
        self
    }

    /// Sets `max_connections` to `Some(max)`.
    pub fn max_connections(mut self, max: usize) -> Self {
        self.config.max_connections = Some(max);
        self
    }

    /// Replaces the interface overrides.
    pub fn interface(mut self, overrides: InterfaceOverrides) -> Self {
        self.config.interface = overrides;
        self
    }

    /// Sets `trust_host_cas`.
    pub fn trust_host_cas(mut self, enabled: bool) -> Self {
        self.config.trust_host_cas = enabled;
        self
    }

    /// Finalizes the configuration.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by a nested builder (currently only
    /// [`dns`](Self::dns)); any later errors are discarded.
    pub fn build(self) -> Result<NetworkConfig, BuildError> {
        match self.errors.into_iter().next() {
            Some(err) => Err(err),
            None => Ok(self.config),
        }
    }
}
impl DnsBuilder {
pub fn new() -> Self {
Self {
config: DnsConfig::default(),
errors: Vec::new(),
}
}
pub fn block_domain(mut self, domain: impl Into<String>) -> Self {
let raw: String = domain.into();
match raw.parse::<DomainName>() {
Ok(name) => self.config.blocked_domains.push(name.into()),
Err(source) => self
.errors
.push(BuildError::InvalidBlockedDomain { raw, source }),
}
self
}
pub fn block_domain_suffix(mut self, suffix: impl Into<String>) -> Self {
let raw: String = suffix.into();
match raw.parse::<DomainName>() {
Ok(name) => self.config.blocked_suffixes.push(name.into()),
Err(source) => self
.errors
.push(BuildError::InvalidBlockedDomainSuffix { raw, source }),
}
self
}
pub fn rebind_protection(mut self, enabled: bool) -> Self {
self.config.rebind_protection = enabled;
self
}
pub fn nameservers<I>(mut self, nameservers: I) -> Self
where
I: IntoIterator,
I::Item: Into<Nameserver>,
{
self.config.nameservers = nameservers.into_iter().map(Into::into).collect();
self
}
pub fn query_timeout_ms(mut self, ms: u64) -> Self {
self.config.query_timeout_ms = ms;
self
}
pub fn build(mut self) -> Result<DnsConfig, BuildError> {
if let Some(err) = self.errors.drain(..).next() {
return Err(err);
}
Ok(self.config)
}
}
impl Default for DnsBuilder {
fn default() -> Self {
Self::new()
}
}
impl TlsBuilder {
pub fn new() -> Self {
Self {
config: TlsConfig {
enabled: true,
..TlsConfig::default()
},
}
}
pub fn bypass(mut self, pattern: impl Into<String>) -> Self {
self.config.bypass.push(pattern.into());
self
}
pub fn verify_upstream(mut self, verify: bool) -> Self {
self.config.verify_upstream = verify;
self
}
pub fn intercepted_ports(mut self, ports: Vec<u16>) -> Self {
self.config.intercepted_ports = ports;
self
}
pub fn block_quic(mut self, block: bool) -> Self {
self.config.block_quic_on_intercept = block;
self
}
pub fn upstream_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
self.config.upstream_ca_cert.push(path.into());
self
}
pub fn intercept_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
self.config.intercept_ca.cert_path = Some(path.into());
self
}
pub fn intercept_ca_key(mut self, path: impl Into<PathBuf>) -> Self {
self.config.intercept_ca.key_path = Some(path.into());
self
}
pub fn build(self) -> TlsConfig {
self.config
}
}
impl SecretBuilder {
pub fn new() -> Self {
Self {
env_var: None,
value: None,
placeholder: None,
allowed_hosts: Vec::new(),
injection: SecretInjection::default(),
require_tls_identity: true,
}
}
pub fn env(mut self, var: impl Into<String>) -> Self {
self.env_var = Some(var.into());
self
}
pub fn value(mut self, value: impl Into<String>) -> Self {
self.value = Some(value.into());
self
}
pub fn placeholder(mut self, placeholder: impl Into<String>) -> Self {
self.placeholder = Some(placeholder.into());
self
}
pub fn allow_host(mut self, host: impl Into<String>) -> Self {
self.allowed_hosts.push(HostPattern::Exact(host.into()));
self
}
pub fn allow_host_pattern(mut self, pattern: impl Into<String>) -> Self {
self.allowed_hosts
.push(HostPattern::Wildcard(pattern.into()));
self
}
pub fn allow_any_host_dangerous(mut self, i_understand_the_risk: bool) -> Self {
if i_understand_the_risk {
self.allowed_hosts.push(HostPattern::Any);
}
self
}
pub fn require_tls_identity(mut self, enabled: bool) -> Self {
self.require_tls_identity = enabled;
self
}
pub fn inject_headers(mut self, enabled: bool) -> Self {
self.injection.headers = enabled;
self
}
pub fn inject_basic_auth(mut self, enabled: bool) -> Self {
self.injection.basic_auth = enabled;
self
}
pub fn inject_query(mut self, enabled: bool) -> Self {
self.injection.query_params = enabled;
self
}
pub fn inject_body(mut self, enabled: bool) -> Self {
self.injection.body = enabled;
self
}
pub fn build(self) -> SecretEntry {
let env_var = self.env_var.expect("SecretBuilder: .env() is required");
let value = self.value.expect("SecretBuilder: .value() is required");
let placeholder = self
.placeholder
.unwrap_or_else(|| format!("$MSB_{env_var}"));
SecretEntry {
env_var,
value,
placeholder,
allowed_hosts: self.allowed_hosts,
injection: self.injection,
require_tls_identity: self.require_tls_identity,
}
}
}
impl Default for NetworkBuilder {
fn default() -> Self {
Self::new()
}
}
impl Default for TlsBuilder {
fn default() -> Self {
Self::new()
}
}
impl Default for SecretBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn block_domain_happy_path() {
        let cfg = DnsBuilder::new()
            .block_domain("evil.com")
            .block_domain_suffix(".tracking.example")
            .build()
            .unwrap();
        assert_eq!(cfg.blocked_domains.len(), 1);
        assert_eq!(cfg.blocked_suffixes.len(), 1);
    }

    #[test]
    fn block_domain_invalid_surfaces_at_build() {
        let result = DnsBuilder::new().block_domain("not a domain!").build();
        match result {
            Err(BuildError::InvalidBlockedDomain { raw, .. }) => {
                assert_eq!(raw, "not a domain!");
            }
            other => panic!("expected InvalidBlockedDomain, got {other:?}"),
        }
    }

    #[test]
    fn block_domain_suffix_invalid_surfaces_at_build() {
        let result = DnsBuilder::new()
            .block_domain_suffix("...invalid!!!")
            .build();
        match result {
            Err(BuildError::InvalidBlockedDomainSuffix { raw, .. }) => {
                assert_eq!(raw, "...invalid!!!");
            }
            other => panic!("expected InvalidBlockedDomainSuffix, got {other:?}"),
        }
    }

    #[test]
    fn block_domain_first_error_wins() {
        let result = DnsBuilder::new()
            .block_domain("first bad!")
            .block_domain("second bad!")
            .build();
        match result {
            Err(BuildError::InvalidBlockedDomain { raw, .. }) => {
                assert_eq!(raw, "first bad!");
            }
            other => panic!("expected first-error InvalidBlockedDomain, got {other:?}"),
        }
    }

    #[test]
    fn dns_error_cascades_through_network_builder() {
        let result = NetworkBuilder::new()
            .dns(|d| d.block_domain("not a domain!"))
            .build();
        match result {
            Err(BuildError::InvalidBlockedDomain { raw, .. }) => {
                assert_eq!(raw, "not a domain!");
            }
            other => panic!("expected cascaded InvalidBlockedDomain, got {other:?}"),
        }
    }

    #[test]
    fn network_builder_happy_path_returns_config() {
        let cfg = NetworkBuilder::new()
            .dns(|d| d.block_domain("evil.com"))
            .build()
            .unwrap();
        assert_eq!(cfg.dns.blocked_domains.len(), 1);
    }

    // The ports added through the builder are loopback-only and keep their protocol.
    #[test]
    fn published_ports_bind_to_loopback() {
        let cfg = NetworkBuilder::new()
            .port(8080, 80)
            .port_udp(5353, 53)
            .build()
            .unwrap();
        assert_eq!(cfg.ports.len(), 2);
        let loopback = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
        assert!(cfg.ports.iter().all(|p| p.host_bind == loopback));
        assert!(matches!(cfg.ports[0].protocol, PortProtocol::Tcp));
        assert!(matches!(cfg.ports[1].protocol, PortProtocol::Udp));
        assert_eq!((cfg.ports[1].host_port, cfg.ports[1].guest_port), (5353, 53));
    }

    #[test]
    fn tls_closure_enables_tls() {
        let cfg = NetworkBuilder::new()
            .tls(|t| t.verify_upstream(false))
            .build()
            .unwrap();
        assert!(cfg.tls.enabled);
        assert!(!cfg.tls.verify_upstream);
    }

    // `secret_env` must produce the same defaults as `SecretBuilder`.
    #[test]
    fn secret_env_uses_secret_builder_defaults() {
        let cfg = NetworkBuilder::new()
            .secret_env("API_KEY", "s3cr3t", "$API_KEY", "api.example.com")
            .build()
            .unwrap();
        assert_eq!(cfg.secrets.secrets.len(), 1);
        let entry = &cfg.secrets.secrets[0];
        assert_eq!(entry.env_var, "API_KEY");
        assert_eq!(entry.value, "s3cr3t");
        assert_eq!(entry.placeholder, "$API_KEY");
        assert!(entry.require_tls_identity);
        assert!(matches!(
            entry.allowed_hosts.as_slice(),
            [HostPattern::Exact(h)] if h == "api.example.com"
        ));
    }

    #[test]
    fn secret_builder_defaults_placeholder_from_env_var() {
        let entry = SecretBuilder::new().env("TOKEN").value("v").build();
        assert_eq!(entry.placeholder, "$MSB_TOKEN");
        assert!(entry.require_tls_identity);
        assert!(entry.allowed_hosts.is_empty());
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: .env() is required")]
    fn secret_builder_without_env_panics() {
        let _ = SecretBuilder::new().value("v").build();
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: .value() is required")]
    fn secret_builder_without_value_panics() {
        let _ = SecretBuilder::new().env("TOKEN").build();
    }

    #[test]
    fn allow_any_host_requires_explicit_opt_in() {
        let declined = SecretBuilder::new()
            .env("TOKEN")
            .value("v")
            .allow_any_host_dangerous(false)
            .build();
        assert!(declined.allowed_hosts.is_empty());

        let accepted = SecretBuilder::new()
            .env("TOKEN")
            .value("v")
            .allow_any_host_dangerous(true)
            .build();
        assert!(matches!(accepted.allowed_hosts.as_slice(), [HostPattern::Any]));
    }
}