1use std::net::IpAddr;
6use std::path::PathBuf;
7
8use ipnetwork::{Ipv4Network, Ipv6Network};
9
10use crate::config::{DnsConfig, InterfaceOverrides, NetworkConfig, PortProtocol, PublishedPort};
11use crate::dns::Nameserver;
12use crate::policy::{BuildError, NetworkPolicy};
13use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, ViolationAction};
14use crate::tls::{ScopedUpstreamCaCert, ScopedVerifyUpstream, TlsConfig};
15
16#[derive(Clone)]
22pub struct NetworkBuilder {
23 config: NetworkConfig,
24 errors: Vec<BuildError>,
25}
26
27pub struct DnsBuilder {
29 config: DnsConfig,
30}
31
32pub struct TlsBuilder {
34 config: TlsConfig,
35}
36
37pub struct SecretBuilder {
47 env_var: Option<String>,
48 value: Option<String>,
49 placeholder: Option<String>,
50 allowed_hosts: Vec<HostPattern>,
51 injection: SecretInjection,
52 on_violation: Option<ViolationAction>,
53 require_tls_identity: bool,
54}
55
56#[derive(Default)]
58pub struct ViolationActionBuilder {
59 action: ViolationAction,
60}
61
62impl NetworkBuilder {
67 pub fn new() -> Self {
69 Self {
70 config: NetworkConfig::default(),
71 errors: Vec::new(),
72 }
73 }
74
75 pub fn from_config(config: NetworkConfig) -> Self {
77 Self {
78 config,
79 errors: Vec::new(),
80 }
81 }
82
83 pub fn enabled(mut self, enabled: bool) -> Self {
85 self.config.enabled = enabled;
86 self
87 }
88
89 pub fn port(self, host_port: u16, guest_port: u16) -> Self {
91 self.port_bind(
92 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
93 host_port,
94 guest_port,
95 )
96 }
97
98 pub fn port_udp(self, host_port: u16, guest_port: u16) -> Self {
100 self.port_udp_bind(
101 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
102 host_port,
103 guest_port,
104 )
105 }
106
107 pub fn port_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
109 self.add_port(host_bind, host_port, guest_port, PortProtocol::Tcp)
110 }
111
112 pub fn port_udp_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
114 self.add_port(host_bind, host_port, guest_port, PortProtocol::Udp)
115 }
116
117 fn add_port(
118 mut self,
119 host_bind: IpAddr,
120 host_port: u16,
121 guest_port: u16,
122 protocol: PortProtocol,
123 ) -> Self {
124 self.config.ports.push(PublishedPort {
125 host_port,
126 guest_port,
127 protocol,
128 host_bind,
129 });
130 self
131 }
132
133 pub fn policy(mut self, policy: NetworkPolicy) -> Self {
135 self.config.policy = policy;
136 self
137 }
138
139 pub fn dns(mut self, f: impl FnOnce(DnsBuilder) -> DnsBuilder) -> Self {
148 self.config.dns = f(DnsBuilder::new()).build();
149 self
150 }
151
152 pub fn tls(mut self, f: impl FnOnce(TlsBuilder) -> TlsBuilder) -> Self {
154 self.config.tls = f(TlsBuilder::new()).build();
155 self
156 }
157
158 pub fn secret(self, f: impl FnOnce(SecretBuilder) -> SecretBuilder) -> Self {
168 self.secret_entry(f(SecretBuilder::new()).build())
169 }
170
171 pub fn secret_entry(mut self, entry: SecretEntry) -> Self {
173 self.config.secrets.secrets.push(entry);
174 self
175 }
176
177 pub fn secret_env(
179 mut self,
180 env_var: impl Into<String>,
181 value: impl Into<String>,
182 placeholder: impl Into<String>,
183 allowed_host: impl Into<String>,
184 ) -> Self {
185 self.config.secrets.secrets.push(SecretEntry {
186 env_var: env_var.into(),
187 value: value.into(),
188 placeholder: placeholder.into(),
189 allowed_hosts: vec![HostPattern::Exact(allowed_host.into())],
190 injection: SecretInjection::default(),
191 on_violation: None,
192 require_tls_identity: true,
193 });
194 self
195 }
196
197 pub fn on_secret_violation(
199 mut self,
200 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
201 ) -> Self {
202 self.config.secrets.on_violation = f(ViolationActionBuilder::default()).build();
203 self
204 }
205
206 pub fn max_connections(mut self, max: usize) -> Self {
208 self.config.max_connections = Some(max);
209 self
210 }
211
212 pub fn interface(mut self, overrides: InterfaceOverrides) -> Self {
214 self.config.interface = overrides;
215 self
216 }
217
218 pub fn ipv4_pool(mut self, pool: Ipv4Network) -> Self {
222 if pool.prefix() > 30 {
223 self.errors.push(BuildError::InvalidIpv4Pool {
224 raw: pool.to_string(),
225 });
226 } else {
227 self.config.interface.ipv4_pool = Some(pool);
228 }
229 self
230 }
231
232 pub fn ipv6_pool(mut self, pool: Ipv6Network) -> Self {
236 if pool.prefix() > 64 {
237 self.errors.push(BuildError::InvalidIpv6Pool {
238 raw: pool.to_string(),
239 });
240 } else {
241 self.config.interface.ipv6_pool = Some(pool);
242 }
243 self
244 }
245
246 pub fn trust_host_cas(mut self, enabled: bool) -> Self {
252 self.config.trust_host_cas = enabled;
253 self
254 }
255
256 pub fn build(mut self) -> Result<NetworkConfig, BuildError> {
262 if let Some(err) = self.errors.drain(..).next() {
263 return Err(err);
264 }
265 self.config.secrets.validate()?;
266 Ok(self.config)
267 }
268}
269
270impl DnsBuilder {
271 pub fn new() -> Self {
273 Self {
274 config: DnsConfig::default(),
275 }
276 }
277
278 pub fn rebind_protection(mut self, enabled: bool) -> Self {
280 self.config.rebind_protection = enabled;
281 self
282 }
283
284 pub fn nameservers<I>(mut self, nameservers: I) -> Self
291 where
292 I: IntoIterator,
293 I::Item: Into<Nameserver>,
294 {
295 self.config.nameservers = nameservers.into_iter().map(Into::into).collect();
296 self
297 }
298
299 pub fn query_timeout_ms(mut self, ms: u64) -> Self {
301 self.config.query_timeout_ms = ms;
302 self
303 }
304
305 pub fn build(self) -> DnsConfig {
307 self.config
308 }
309}
310
311impl Default for DnsBuilder {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317impl TlsBuilder {
318 pub fn new() -> Self {
320 Self {
321 config: TlsConfig {
322 enabled: true,
323 ..TlsConfig::default()
324 },
325 }
326 }
327
328 pub fn bypass(mut self, pattern: impl Into<String>) -> Self {
330 self.config.bypass.push(pattern.into());
331 self
332 }
333
334 pub fn verify_upstream(mut self, verify: bool) -> Self {
336 self.config.verify_upstream = verify;
337 self
338 }
339
340 pub fn verify_upstream_for(mut self, pattern: impl Into<String>, verify: bool) -> Self {
346 self.config
347 .scoped_verify_upstream
348 .push(ScopedVerifyUpstream {
349 pattern: pattern.into(),
350 verify,
351 });
352 self
353 }
354
355 pub fn intercepted_ports(mut self, ports: Vec<u16>) -> Self {
357 self.config.intercepted_ports = ports;
358 self
359 }
360
361 pub fn block_quic(mut self, block: bool) -> Self {
363 self.config.block_quic_on_intercept = block;
364 self
365 }
366
367 pub fn upstream_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
372 self.config.upstream_ca_cert.push(path.into());
373 self
374 }
375
376 pub fn upstream_ca_cert_for(
383 mut self,
384 pattern: impl Into<String>,
385 path: impl Into<PathBuf>,
386 ) -> Self {
387 self.config
388 .scoped_upstream_ca_cert
389 .push(ScopedUpstreamCaCert {
390 pattern: pattern.into(),
391 path: path.into(),
392 });
393 self
394 }
395
396 pub fn intercept_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
398 self.config.intercept_ca.cert_path = Some(path.into());
399 self
400 }
401
402 pub fn intercept_ca_key(mut self, path: impl Into<PathBuf>) -> Self {
404 self.config.intercept_ca.key_path = Some(path.into());
405 self
406 }
407
408 pub fn build(self) -> TlsConfig {
410 self.config
411 }
412}
413
414impl SecretBuilder {
415 pub fn new() -> Self {
417 Self {
418 env_var: None,
419 value: None,
420 placeholder: None,
421 allowed_hosts: Vec::new(),
422 injection: SecretInjection::default(),
423 on_violation: None,
424 require_tls_identity: true,
425 }
426 }
427
428 pub fn env(mut self, var: impl Into<String>) -> Self {
433 self.env_var = Some(var.into());
434 self
435 }
436
437 pub fn value(mut self, value: impl Into<String>) -> Self {
439 self.value = Some(value.into());
440 self
441 }
442
443 pub fn placeholder(mut self, placeholder: impl Into<String>) -> Self {
449 self.placeholder = Some(placeholder.into());
450 self
451 }
452
453 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
455 self.allowed_hosts.push(HostPattern::Exact(host.into()));
456 self
457 }
458
459 pub fn allow_host_pattern(mut self, pattern: impl Into<String>) -> Self {
461 self.allowed_hosts
462 .push(HostPattern::Wildcard(pattern.into()));
463 self
464 }
465
466 pub fn allow_any_host_dangerous(mut self, i_understand_the_risk: bool) -> Self {
469 if i_understand_the_risk {
470 self.allowed_hosts.push(HostPattern::Any);
471 }
472 self
473 }
474
475 pub fn on_violation(
477 mut self,
478 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
479 ) -> Self {
480 self.on_violation = Some(f(ViolationActionBuilder::default()).build());
481 self
482 }
483
484 pub fn require_tls_identity(mut self, enabled: bool) -> Self {
486 self.require_tls_identity = enabled;
487 self
488 }
489
490 pub fn inject_headers(mut self, enabled: bool) -> Self {
492 self.injection.headers = enabled;
493 self
494 }
495
496 pub fn inject_basic_auth(mut self, enabled: bool) -> Self {
498 self.injection.basic_auth = enabled;
499 self
500 }
501
502 pub fn inject_query(mut self, enabled: bool) -> Self {
504 self.injection.query_params = enabled;
505 self
506 }
507
508 pub fn inject_body(mut self, enabled: bool) -> Self {
515 self.injection.body = enabled;
516 self
517 }
518
519 pub fn build(self) -> SecretEntry {
524 let env_var = self.env_var.expect("SecretBuilder: .env() is required");
525 let value = self.value.expect("SecretBuilder: .value() is required");
526 assert!(
527 !self.allowed_hosts.is_empty(),
528 "SecretBuilder: at least one allowed host is required; use .allow_any_host_dangerous(true) for an explicit any-host secret"
529 );
530 let placeholder = self
531 .placeholder
532 .unwrap_or_else(|| format!("$MSB_{env_var}"));
533
534 SecretEntry {
535 env_var,
536 value,
537 placeholder,
538 allowed_hosts: self.allowed_hosts,
539 injection: self.injection,
540 on_violation: self.on_violation,
541 require_tls_identity: self.require_tls_identity,
542 }
543 }
544}
545
546impl ViolationActionBuilder {
547 pub fn new() -> Self {
549 Self::default()
550 }
551
552 pub fn from_action(action: ViolationAction) -> Self {
554 action.into()
555 }
556
557 pub fn block(mut self) -> Self {
559 self.action = ViolationAction::Block;
560 self
561 }
562
563 pub fn block_and_log(mut self) -> Self {
565 self.action = ViolationAction::BlockAndLog;
566 self
567 }
568
569 pub fn block_and_terminate(mut self) -> Self {
571 self.action = ViolationAction::BlockAndTerminate;
572 self
573 }
574
575 pub fn passthrough_host(mut self, host: impl Into<String>) -> Self {
577 self.push_passthrough_host(HostPattern::Exact(host.into()));
578 self
579 }
580
581 pub fn passthrough_host_pattern(mut self, pattern: impl Into<String>) -> Self {
583 self.push_passthrough_host(HostPattern::Wildcard(pattern.into()));
584 self
585 }
586
587 pub fn passthrough_all_hosts(mut self, i_understand_the_risk: bool) -> Self {
589 if i_understand_the_risk {
590 self.push_passthrough_host(HostPattern::Any);
591 }
592 self
593 }
594
595 fn push_passthrough_host(&mut self, host: HostPattern) {
597 match self.action {
598 ViolationAction::Passthrough(ref mut hosts) => hosts.push(host),
599 _ => self.action = ViolationAction::Passthrough(vec![host]),
600 }
601 }
602
603 pub fn build(self) -> ViolationAction {
605 self.action
606 }
607}
608
609impl Default for NetworkBuilder {
614 fn default() -> Self {
615 Self::new()
616 }
617}
618
619impl Default for TlsBuilder {
620 fn default() -> Self {
621 Self::new()
622 }
623}
624
625impl Default for SecretBuilder {
626 fn default() -> Self {
627 Self::new()
628 }
629}
630impl From<ViolationAction> for ViolationActionBuilder {
631 fn from(action: ViolationAction) -> Self {
632 Self { action }
633 }
634}
635
#[cfg(test)]
mod tests {
    // Unit tests for the builders in this module, including the deferred-error
    // and default-value paths.
    use super::*;

    #[test]
    fn network_builder_happy_path_returns_config() {
        let cfg = NetworkBuilder::new()
            .dns(|d| d.rebind_protection(false))
            .build()
            .unwrap();
        assert!(!cfg.dns.rebind_protection);
    }

    #[test]
    fn port_bind_sets_host_bind() {
        let bind = "0.0.0.0".parse().unwrap();
        let cfg = NetworkBuilder::new()
            .port_bind(bind, 8080, 80)
            .port_udp_bind(bind, 5353, 53)
            .build()
            .unwrap();

        assert_eq!(cfg.ports[0].host_bind, bind);
        assert_eq!(cfg.ports[0].host_port, 8080);
        assert_eq!(cfg.ports[0].guest_port, 80);
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(cfg.ports[1].host_bind, bind);
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn port_helpers_default_to_loopback() {
        let cfg = NetworkBuilder::new()
            .port(8080, 80)
            .port_udp(5353, 53)
            .build()
            .unwrap();

        assert_eq!(
            cfg.ports[0].host_bind,
            IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
        );
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(
            cfg.ports[1].host_bind,
            IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
        );
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn ipv4_pool_accepts_prefix_up_to_30() {
        let pool: Ipv4Network = "10.0.0.0/30".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv4_pool(pool).build().unwrap();
        assert_eq!(cfg.interface.ipv4_pool, Some(pool));
    }

    #[test]
    fn ipv4_pool_rejects_prefix_longer_than_30() {
        let pool: Ipv4Network = "10.0.0.0/31".parse().unwrap();
        let err = NetworkBuilder::new().ipv4_pool(pool).build().unwrap_err();
        assert!(matches!(err, BuildError::InvalidIpv4Pool { .. }));
    }

    #[test]
    fn ipv6_pool_accepts_prefix_up_to_64() {
        let pool: Ipv6Network = "fd00::/64".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv6_pool(pool).build().unwrap();
        assert_eq!(cfg.interface.ipv6_pool, Some(pool));
    }

    #[test]
    fn ipv6_pool_rejects_prefix_longer_than_64() {
        let pool: Ipv6Network = "fd00::/65".parse().unwrap();
        let err = NetworkBuilder::new().ipv6_pool(pool).build().unwrap_err();
        assert!(matches!(err, BuildError::InvalidIpv6Pool { .. }));
    }

    #[test]
    fn build_reports_first_recorded_error() {
        let err = NetworkBuilder::new()
            .ipv6_pool("fd00::/96".parse().unwrap())
            .ipv4_pool("10.0.0.0/32".parse().unwrap())
            .build()
            .unwrap_err();
        assert!(matches!(err, BuildError::InvalidIpv6Pool { .. }));
    }

    #[test]
    fn tls_builder_enables_tls() {
        let cfg = NetworkBuilder::new()
            .tls(|t| t.bypass("*.internal"))
            .build()
            .unwrap();
        assert!(cfg.tls.enabled);
        assert_eq!(cfg.tls.bypass, vec!["*.internal".to_string()]);
    }

    #[test]
    fn network_builder_sets_global_passthrough_action() {
        let cfg = NetworkBuilder::new()
            .on_secret_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build()
            .unwrap();

        assert_eq!(
            cfg.secrets.on_violation,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])
        );
    }

    #[test]
    fn secret_builder_sets_violation_action() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_host("api.github.com")
            .on_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build();

        assert_eq!(
            secret.on_violation,
            Some(ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])),
        );
    }

    #[test]
    fn secret_builder_defaults_placeholder_from_env_var() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_host("api.github.com")
            .build();

        assert_eq!(secret.placeholder, "$MSB_TOKEN");
        assert!(secret.require_tls_identity);
        assert_eq!(secret.on_violation, None);
    }

    #[test]
    fn secret_builder_any_host_is_opt_in() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_any_host_dangerous(true)
            .build();

        assert_eq!(secret.allowed_hosts, vec![HostPattern::Any]);
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: at least one allowed host is required")]
    fn secret_builder_any_host_false_adds_nothing() {
        let _ = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_any_host_dangerous(false)
            .build();
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: at least one allowed host is required")]
    fn secret_builder_rejects_empty_allowed_hosts() {
        let _ = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .build();
    }

    #[test]
    fn network_builder_rejects_invalid_secret_config() {
        let err = NetworkBuilder::new()
            .secret_entry(SecretEntry {
                env_var: "API=KEY".into(),
                value: "secret-value".into(),
                placeholder: "$MSB_API_KEY".into(),
                allowed_hosts: vec![HostPattern::Exact("api.example.com".into())],
                injection: SecretInjection::default(),
                on_violation: None,
                require_tls_identity: true,
            })
            .build()
            .unwrap_err();

        assert!(err.to_string().contains("env_var must not contain `=`"));
    }

    #[test]
    fn violation_action_builder_blocking_call_replaces_passthrough_policy() {
        let action = ViolationActionBuilder::default()
            .passthrough_host("google.com")
            .block_and_terminate()
            .passthrough_host("facebook.com")
            .build();

        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![HostPattern::Exact("facebook.com".into())])
        );
    }

    #[test]
    fn violation_action_builder_accumulates_passthrough_hosts() {
        let action = ViolationActionBuilder::default()
            .block()
            .passthrough_host("google.com")
            .passthrough_host("facebook.com")
            .build();

        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("google.com".into()),
                HostPattern::Exact("facebook.com".into()),
            ]),
        );
    }

    #[test]
    fn violation_action_builder_all_hosts_requires_opt_in() {
        let untouched = ViolationActionBuilder::default()
            .passthrough_all_hosts(false)
            .build();
        assert_eq!(untouched, ViolationAction::default());

        let everything = ViolationActionBuilder::default()
            .passthrough_all_hosts(true)
            .build();
        assert_eq!(everything, ViolationAction::Passthrough(vec![HostPattern::Any]));
    }
}