1use std::net::IpAddr;
6use std::path::PathBuf;
7
8use ipnetwork::{Ipv4Network, Ipv6Network};
9
10use crate::config::{DnsConfig, InterfaceOverrides, NetworkConfig, PortProtocol, PublishedPort};
11use crate::dns::Nameserver;
12use crate::policy::{BuildError, NetworkPolicy};
13use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, ViolationAction};
14use crate::tls::TlsConfig;
15
16#[derive(Clone)]
22pub struct NetworkBuilder {
23 config: NetworkConfig,
24 errors: Vec<BuildError>,
25}
26
27pub struct DnsBuilder {
29 config: DnsConfig,
30}
31
32pub struct TlsBuilder {
34 config: TlsConfig,
35}
36
37pub struct SecretBuilder {
47 env_var: Option<String>,
48 value: Option<String>,
49 placeholder: Option<String>,
50 allowed_hosts: Vec<HostPattern>,
51 injection: SecretInjection,
52 on_violation: Option<ViolationAction>,
53 require_tls_identity: bool,
54}
55
56#[derive(Default)]
58pub struct ViolationActionBuilder {
59 action: ViolationAction,
60}
61
62impl NetworkBuilder {
67 pub fn new() -> Self {
69 Self {
70 config: NetworkConfig::default(),
71 errors: Vec::new(),
72 }
73 }
74
75 pub fn from_config(config: NetworkConfig) -> Self {
77 Self {
78 config,
79 errors: Vec::new(),
80 }
81 }
82
83 pub fn enabled(mut self, enabled: bool) -> Self {
85 self.config.enabled = enabled;
86 self
87 }
88
89 pub fn port(self, host_port: u16, guest_port: u16) -> Self {
91 self.port_bind(
92 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
93 host_port,
94 guest_port,
95 )
96 }
97
98 pub fn port_udp(self, host_port: u16, guest_port: u16) -> Self {
100 self.port_udp_bind(
101 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
102 host_port,
103 guest_port,
104 )
105 }
106
107 pub fn port_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
109 self.add_port(host_bind, host_port, guest_port, PortProtocol::Tcp)
110 }
111
112 pub fn port_udp_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
114 self.add_port(host_bind, host_port, guest_port, PortProtocol::Udp)
115 }
116
117 fn add_port(
118 mut self,
119 host_bind: IpAddr,
120 host_port: u16,
121 guest_port: u16,
122 protocol: PortProtocol,
123 ) -> Self {
124 self.config.ports.push(PublishedPort {
125 host_port,
126 guest_port,
127 protocol,
128 host_bind,
129 });
130 self
131 }
132
133 pub fn policy(mut self, policy: NetworkPolicy) -> Self {
135 self.config.policy = policy;
136 self
137 }
138
139 pub fn dns(mut self, f: impl FnOnce(DnsBuilder) -> DnsBuilder) -> Self {
148 self.config.dns = f(DnsBuilder::new()).build();
149 self
150 }
151
152 pub fn tls(mut self, f: impl FnOnce(TlsBuilder) -> TlsBuilder) -> Self {
154 self.config.tls = f(TlsBuilder::new()).build();
155 self
156 }
157
158 pub fn secret(self, f: impl FnOnce(SecretBuilder) -> SecretBuilder) -> Self {
168 self.secret_entry(f(SecretBuilder::new()).build())
169 }
170
171 pub fn secret_entry(mut self, entry: SecretEntry) -> Self {
173 self.config.secrets.secrets.push(entry);
174 self
175 }
176
177 pub fn secret_env(
179 mut self,
180 env_var: impl Into<String>,
181 value: impl Into<String>,
182 placeholder: impl Into<String>,
183 allowed_host: impl Into<String>,
184 ) -> Self {
185 self.config.secrets.secrets.push(SecretEntry {
186 env_var: env_var.into(),
187 value: value.into(),
188 placeholder: placeholder.into(),
189 allowed_hosts: vec![HostPattern::Exact(allowed_host.into())],
190 injection: SecretInjection::default(),
191 on_violation: None,
192 require_tls_identity: true,
193 });
194 self
195 }
196
197 pub fn on_secret_violation(
199 mut self,
200 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
201 ) -> Self {
202 self.config.secrets.on_violation = f(ViolationActionBuilder::default()).build();
203 self
204 }
205
206 pub fn max_connections(mut self, max: usize) -> Self {
208 self.config.max_connections = Some(max);
209 self
210 }
211
212 pub fn interface(mut self, overrides: InterfaceOverrides) -> Self {
214 self.config.interface = overrides;
215 self
216 }
217
218 pub fn ipv4_pool(mut self, pool: Ipv4Network) -> Self {
222 if pool.prefix() > 30 {
223 self.errors.push(BuildError::InvalidIpv4Pool {
224 raw: pool.to_string(),
225 });
226 } else {
227 self.config.interface.ipv4_pool = Some(pool);
228 }
229 self
230 }
231
232 pub fn ipv6_pool(mut self, pool: Ipv6Network) -> Self {
236 if pool.prefix() > 64 {
237 self.errors.push(BuildError::InvalidIpv6Pool {
238 raw: pool.to_string(),
239 });
240 } else {
241 self.config.interface.ipv6_pool = Some(pool);
242 }
243 self
244 }
245
246 pub fn trust_host_cas(mut self, enabled: bool) -> Self {
252 self.config.trust_host_cas = enabled;
253 self
254 }
255
256 pub fn build(mut self) -> Result<NetworkConfig, BuildError> {
262 if let Some(err) = self.errors.drain(..).next() {
263 return Err(err);
264 }
265 self.config.secrets.validate()?;
266 Ok(self.config)
267 }
268}
269
270impl DnsBuilder {
271 pub fn new() -> Self {
273 Self {
274 config: DnsConfig::default(),
275 }
276 }
277
278 pub fn rebind_protection(mut self, enabled: bool) -> Self {
280 self.config.rebind_protection = enabled;
281 self
282 }
283
284 pub fn nameservers<I>(mut self, nameservers: I) -> Self
291 where
292 I: IntoIterator,
293 I::Item: Into<Nameserver>,
294 {
295 self.config.nameservers = nameservers.into_iter().map(Into::into).collect();
296 self
297 }
298
299 pub fn query_timeout_ms(mut self, ms: u64) -> Self {
301 self.config.query_timeout_ms = ms;
302 self
303 }
304
305 pub fn build(self) -> DnsConfig {
307 self.config
308 }
309}
310
311impl Default for DnsBuilder {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317impl TlsBuilder {
318 pub fn new() -> Self {
320 Self {
321 config: TlsConfig {
322 enabled: true,
323 ..TlsConfig::default()
324 },
325 }
326 }
327
328 pub fn bypass(mut self, pattern: impl Into<String>) -> Self {
330 self.config.bypass.push(pattern.into());
331 self
332 }
333
334 pub fn verify_upstream(mut self, verify: bool) -> Self {
336 self.config.verify_upstream = verify;
337 self
338 }
339
340 pub fn intercepted_ports(mut self, ports: Vec<u16>) -> Self {
342 self.config.intercepted_ports = ports;
343 self
344 }
345
346 pub fn block_quic(mut self, block: bool) -> Self {
348 self.config.block_quic_on_intercept = block;
349 self
350 }
351
352 pub fn upstream_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
357 self.config.upstream_ca_cert.push(path.into());
358 self
359 }
360
361 pub fn intercept_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
363 self.config.intercept_ca.cert_path = Some(path.into());
364 self
365 }
366
367 pub fn intercept_ca_key(mut self, path: impl Into<PathBuf>) -> Self {
369 self.config.intercept_ca.key_path = Some(path.into());
370 self
371 }
372
373 pub fn build(self) -> TlsConfig {
375 self.config
376 }
377}
378
379impl SecretBuilder {
380 pub fn new() -> Self {
382 Self {
383 env_var: None,
384 value: None,
385 placeholder: None,
386 allowed_hosts: Vec::new(),
387 injection: SecretInjection::default(),
388 on_violation: None,
389 require_tls_identity: true,
390 }
391 }
392
393 pub fn env(mut self, var: impl Into<String>) -> Self {
398 self.env_var = Some(var.into());
399 self
400 }
401
402 pub fn value(mut self, value: impl Into<String>) -> Self {
404 self.value = Some(value.into());
405 self
406 }
407
408 pub fn placeholder(mut self, placeholder: impl Into<String>) -> Self {
414 self.placeholder = Some(placeholder.into());
415 self
416 }
417
418 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
420 self.allowed_hosts.push(HostPattern::Exact(host.into()));
421 self
422 }
423
424 pub fn allow_host_pattern(mut self, pattern: impl Into<String>) -> Self {
426 self.allowed_hosts
427 .push(HostPattern::Wildcard(pattern.into()));
428 self
429 }
430
431 pub fn allow_any_host_dangerous(mut self, i_understand_the_risk: bool) -> Self {
434 if i_understand_the_risk {
435 self.allowed_hosts.push(HostPattern::Any);
436 }
437 self
438 }
439
440 pub fn on_violation(
442 mut self,
443 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
444 ) -> Self {
445 self.on_violation = Some(f(ViolationActionBuilder::default()).build());
446 self
447 }
448
449 pub fn require_tls_identity(mut self, enabled: bool) -> Self {
451 self.require_tls_identity = enabled;
452 self
453 }
454
455 pub fn inject_headers(mut self, enabled: bool) -> Self {
457 self.injection.headers = enabled;
458 self
459 }
460
461 pub fn inject_basic_auth(mut self, enabled: bool) -> Self {
463 self.injection.basic_auth = enabled;
464 self
465 }
466
467 pub fn inject_query(mut self, enabled: bool) -> Self {
469 self.injection.query_params = enabled;
470 self
471 }
472
473 pub fn inject_body(mut self, enabled: bool) -> Self {
480 self.injection.body = enabled;
481 self
482 }
483
484 pub fn build(self) -> SecretEntry {
489 let env_var = self.env_var.expect("SecretBuilder: .env() is required");
490 let value = self.value.expect("SecretBuilder: .value() is required");
491 assert!(
492 !self.allowed_hosts.is_empty(),
493 "SecretBuilder: at least one allowed host is required; use .allow_any_host_dangerous(true) for an explicit any-host secret"
494 );
495 let placeholder = self
496 .placeholder
497 .unwrap_or_else(|| format!("$MSB_{env_var}"));
498
499 SecretEntry {
500 env_var,
501 value,
502 placeholder,
503 allowed_hosts: self.allowed_hosts,
504 injection: self.injection,
505 on_violation: self.on_violation,
506 require_tls_identity: self.require_tls_identity,
507 }
508 }
509}
510
511impl ViolationActionBuilder {
512 pub fn new() -> Self {
514 Self::default()
515 }
516
517 pub fn from_action(action: ViolationAction) -> Self {
519 action.into()
520 }
521
522 pub fn block(mut self) -> Self {
524 self.action = ViolationAction::Block;
525 self
526 }
527
528 pub fn block_and_log(mut self) -> Self {
530 self.action = ViolationAction::BlockAndLog;
531 self
532 }
533
534 pub fn block_and_terminate(mut self) -> Self {
536 self.action = ViolationAction::BlockAndTerminate;
537 self
538 }
539
540 pub fn passthrough_host(mut self, host: impl Into<String>) -> Self {
542 self.push_passthrough_host(HostPattern::Exact(host.into()));
543 self
544 }
545
546 pub fn passthrough_host_pattern(mut self, pattern: impl Into<String>) -> Self {
548 self.push_passthrough_host(HostPattern::Wildcard(pattern.into()));
549 self
550 }
551
552 pub fn passthrough_all_hosts(mut self, i_understand_the_risk: bool) -> Self {
554 if i_understand_the_risk {
555 self.push_passthrough_host(HostPattern::Any);
556 }
557 self
558 }
559
560 fn push_passthrough_host(&mut self, host: HostPattern) {
562 match self.action {
563 ViolationAction::Passthrough(ref mut hosts) => hosts.push(host),
564 _ => self.action = ViolationAction::Passthrough(vec![host]),
565 }
566 }
567
568 pub fn build(self) -> ViolationAction {
570 self.action
571 }
572}
573
574impl Default for NetworkBuilder {
579 fn default() -> Self {
580 Self::new()
581 }
582}
583
584impl Default for TlsBuilder {
585 fn default() -> Self {
586 Self::new()
587 }
588}
589
590impl Default for SecretBuilder {
591 fn default() -> Self {
592 Self::new()
593 }
594}
595impl From<ViolationAction> for ViolationActionBuilder {
596 fn from(action: ViolationAction) -> Self {
597 Self { action }
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Host patterns shared by the passthrough tests.
    fn anthropic_hosts() -> Vec<HostPattern> {
        vec![
            HostPattern::Exact("api.anthropic.com".into()),
            HostPattern::Wildcard("*.anthropic.com".into()),
        ]
    }

    /// Builder callback that produces a passthrough over `anthropic_hosts()`.
    fn allow_anthropic(v: ViolationActionBuilder) -> ViolationActionBuilder {
        v.passthrough_host("api.anthropic.com")
            .passthrough_host_pattern("*.anthropic.com")
    }

    #[test]
    fn network_builder_happy_path_returns_config() {
        let built = NetworkBuilder::new()
            .dns(|dns| dns.rebind_protection(false))
            .build();
        let cfg = built.unwrap();
        assert!(!cfg.dns.rebind_protection);
    }

    #[test]
    fn port_bind_sets_host_bind() {
        let bind: IpAddr = "0.0.0.0".parse().unwrap();
        let cfg = NetworkBuilder::new()
            .port_bind(bind, 8080, 80)
            .port_udp_bind(bind, 5353, 53)
            .build()
            .unwrap();

        // The TCP mapping was added first.
        let tcp = &cfg.ports[0];
        assert_eq!(tcp.host_bind, bind);
        assert_eq!(tcp.host_port, 8080);
        assert_eq!(tcp.guest_port, 80);
        assert_eq!(tcp.protocol, PortProtocol::Tcp);

        // The UDP mapping follows it.
        let udp = &cfg.ports[1];
        assert_eq!(udp.host_bind, bind);
        assert_eq!(udp.protocol, PortProtocol::Udp);
    }

    #[test]
    fn network_builder_sets_global_passthrough_action() {
        let cfg = NetworkBuilder::new()
            .on_secret_violation(allow_anthropic)
            .build()
            .unwrap();

        let expected = ViolationAction::Passthrough(anthropic_hosts());
        assert_eq!(cfg.secrets.on_violation, expected);
    }

    #[test]
    fn secret_builder_sets_violation_action() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_host("api.github.com")
            .on_violation(allow_anthropic)
            .build();

        let expected = Some(ViolationAction::Passthrough(anthropic_hosts()));
        assert_eq!(secret.on_violation, expected);
    }

    #[test]
    #[should_panic(expected = "SecretBuilder: at least one allowed host is required")]
    fn secret_builder_rejects_empty_allowed_hosts() {
        // No allow_* call, so build() must panic.
        let incomplete = SecretBuilder::new().env("TOKEN").value("secret-value");
        let _ = incomplete.build();
    }

    #[test]
    fn network_builder_rejects_invalid_secret_config() {
        // The `=` in the variable name is what validation is expected to reject.
        let bad_entry = SecretEntry {
            env_var: "API=KEY".into(),
            value: "secret-value".into(),
            placeholder: "$MSB_API_KEY".into(),
            allowed_hosts: vec![HostPattern::Exact("api.example.com".into())],
            injection: SecretInjection::default(),
            on_violation: None,
            require_tls_identity: true,
        };

        let err = NetworkBuilder::new()
            .secret_entry(bad_entry)
            .build()
            .unwrap_err();

        let message = err.to_string();
        assert!(message.contains("env_var must not contain `=`"));
    }

    #[test]
    fn violation_action_builder_blocking_call_replaces_passthrough_policy() {
        // The blocking call discards "google.com"; the later passthrough starts a new list.
        let action = ViolationActionBuilder::default()
            .passthrough_host("google.com")
            .block_and_terminate()
            .passthrough_host("facebook.com")
            .build();

        let expected =
            ViolationAction::Passthrough(vec![HostPattern::Exact("facebook.com".into())]);
        assert_eq!(action, expected);
    }

    #[test]
    fn violation_action_builder_accumulates_passthrough_hosts() {
        let action = ViolationActionBuilder::default()
            .block()
            .passthrough_host("google.com")
            .passthrough_host("facebook.com")
            .build();

        let expected = ViolationAction::Passthrough(vec![
            HostPattern::Exact("google.com".into()),
            HostPattern::Exact("facebook.com".into()),
        ]);
        assert_eq!(action, expected);
    }
}