1use std::net::IpAddr;
6use std::path::PathBuf;
7
8use ipnetwork::{Ipv4Network, Ipv6Network};
9
10use crate::config::{DnsConfig, InterfaceOverrides, NetworkConfig, PortProtocol, PublishedPort};
11use crate::dns::Nameserver;
12use crate::policy::{BuildError, NetworkPolicy};
13use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, ViolationAction};
14use crate::tls::TlsConfig;
15
/// Fluent builder for a [`NetworkConfig`].
///
/// Validating setters ([`NetworkBuilder::ipv4_pool`], [`NetworkBuilder::ipv6_pool`])
/// do not panic on bad input. They push a [`BuildError`] into `errors`, and
/// [`NetworkBuilder::build`] returns the first one recorded.
#[derive(Clone)]
pub struct NetworkBuilder {
    // Configuration under construction; returned as-is by `build` when no errors were recorded.
    config: NetworkConfig,
    // Validation failures recorded by setters, in call order.
    errors: Vec<BuildError>,
}
26
/// Builder for a [`DnsConfig`], normally used through the closure passed to
/// [`NetworkBuilder::dns`].
pub struct DnsBuilder {
    // DNS configuration under construction; starts from `DnsConfig::default()`.
    config: DnsConfig,
}
31
/// Builder for a [`TlsConfig`], normally used through the closure passed to
/// [`NetworkBuilder::tls`].
///
/// A new builder starts from `TlsConfig::default()` with `enabled` forced to `true`.
pub struct TlsBuilder {
    // TLS configuration under construction.
    config: TlsConfig,
}
36
/// Builder for a single [`SecretEntry`], normally used through the closure
/// passed to [`NetworkBuilder::secret`].
pub struct SecretBuilder {
    // Required: set via `.env()`; `build` panics if it is still `None`.
    env_var: Option<String>,
    // Required: set via `.value()`; `build` panics if it is still `None`.
    value: Option<String>,
    // Optional: `build` falls back to `$MSB_<env_var>` when unset.
    placeholder: Option<String>,
    // Host patterns accumulated by `allow_host*`; empty unless one of them is called.
    allowed_hosts: Vec<HostPattern>,
    // Injection toggles, starting from `SecretInjection::default()`.
    injection: SecretInjection,
    // Per-secret violation action; `None` unless `.on_violation()` is called.
    on_violation: Option<ViolationAction>,
    // Defaults to `true` in `SecretBuilder::new`.
    require_tls_identity: bool,
}
55
/// Builder for a [`ViolationAction`].
///
/// Blocking setters replace the current action outright. Passthrough setters
/// append to an existing `Passthrough` host list, or start a new one if the
/// current action is not `Passthrough`.
#[derive(Default)]
pub struct ViolationActionBuilder {
    // Action under construction; starts at `ViolationAction::default()`.
    action: ViolationAction,
}
61
62impl NetworkBuilder {
67 pub fn new() -> Self {
69 Self {
70 config: NetworkConfig::default(),
71 errors: Vec::new(),
72 }
73 }
74
75 pub fn from_config(config: NetworkConfig) -> Self {
77 Self {
78 config,
79 errors: Vec::new(),
80 }
81 }
82
83 pub fn enabled(mut self, enabled: bool) -> Self {
85 self.config.enabled = enabled;
86 self
87 }
88
89 pub fn port(self, host_port: u16, guest_port: u16) -> Self {
91 self.port_bind(
92 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
93 host_port,
94 guest_port,
95 )
96 }
97
98 pub fn port_udp(self, host_port: u16, guest_port: u16) -> Self {
100 self.port_udp_bind(
101 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
102 host_port,
103 guest_port,
104 )
105 }
106
107 pub fn port_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
109 self.add_port(host_bind, host_port, guest_port, PortProtocol::Tcp)
110 }
111
112 pub fn port_udp_bind(self, host_bind: IpAddr, host_port: u16, guest_port: u16) -> Self {
114 self.add_port(host_bind, host_port, guest_port, PortProtocol::Udp)
115 }
116
117 fn add_port(
118 mut self,
119 host_bind: IpAddr,
120 host_port: u16,
121 guest_port: u16,
122 protocol: PortProtocol,
123 ) -> Self {
124 self.config.ports.push(PublishedPort {
125 host_port,
126 guest_port,
127 protocol,
128 host_bind,
129 });
130 self
131 }
132
133 pub fn policy(mut self, policy: NetworkPolicy) -> Self {
135 self.config.policy = policy;
136 self
137 }
138
139 pub fn dns(mut self, f: impl FnOnce(DnsBuilder) -> DnsBuilder) -> Self {
148 self.config.dns = f(DnsBuilder::new()).build();
149 self
150 }
151
152 pub fn tls(mut self, f: impl FnOnce(TlsBuilder) -> TlsBuilder) -> Self {
154 self.config.tls = f(TlsBuilder::new()).build();
155 self
156 }
157
158 pub fn secret(mut self, f: impl FnOnce(SecretBuilder) -> SecretBuilder) -> Self {
168 self.config
169 .secrets
170 .secrets
171 .push(f(SecretBuilder::new()).build());
172 self
173 }
174
175 pub fn secret_env(
177 mut self,
178 env_var: impl Into<String>,
179 value: impl Into<String>,
180 placeholder: impl Into<String>,
181 allowed_host: impl Into<String>,
182 ) -> Self {
183 self.config.secrets.secrets.push(SecretEntry {
184 env_var: env_var.into(),
185 value: value.into(),
186 placeholder: placeholder.into(),
187 allowed_hosts: vec![HostPattern::Exact(allowed_host.into())],
188 injection: SecretInjection::default(),
189 on_violation: None,
190 require_tls_identity: true,
191 });
192 self
193 }
194
195 pub fn on_secret_violation(
197 mut self,
198 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
199 ) -> Self {
200 self.config.secrets.on_violation = f(ViolationActionBuilder::default()).build();
201 self
202 }
203
204 pub fn max_connections(mut self, max: usize) -> Self {
206 self.config.max_connections = Some(max);
207 self
208 }
209
210 pub fn interface(mut self, overrides: InterfaceOverrides) -> Self {
212 self.config.interface = overrides;
213 self
214 }
215
216 pub fn ipv4_pool(mut self, pool: Ipv4Network) -> Self {
220 if pool.prefix() > 30 {
221 self.errors.push(BuildError::InvalidIpv4Pool {
222 raw: pool.to_string(),
223 });
224 } else {
225 self.config.interface.ipv4_pool = Some(pool);
226 }
227 self
228 }
229
230 pub fn ipv6_pool(mut self, pool: Ipv6Network) -> Self {
234 if pool.prefix() > 64 {
235 self.errors.push(BuildError::InvalidIpv6Pool {
236 raw: pool.to_string(),
237 });
238 } else {
239 self.config.interface.ipv6_pool = Some(pool);
240 }
241 self
242 }
243
244 pub fn trust_host_cas(mut self, enabled: bool) -> Self {
250 self.config.trust_host_cas = enabled;
251 self
252 }
253
254 pub fn build(mut self) -> Result<NetworkConfig, BuildError> {
260 if let Some(err) = self.errors.drain(..).next() {
261 return Err(err);
262 }
263 Ok(self.config)
264 }
265}
266
267impl DnsBuilder {
268 pub fn new() -> Self {
270 Self {
271 config: DnsConfig::default(),
272 }
273 }
274
275 pub fn rebind_protection(mut self, enabled: bool) -> Self {
277 self.config.rebind_protection = enabled;
278 self
279 }
280
281 pub fn nameservers<I>(mut self, nameservers: I) -> Self
288 where
289 I: IntoIterator,
290 I::Item: Into<Nameserver>,
291 {
292 self.config.nameservers = nameservers.into_iter().map(Into::into).collect();
293 self
294 }
295
296 pub fn query_timeout_ms(mut self, ms: u64) -> Self {
298 self.config.query_timeout_ms = ms;
299 self
300 }
301
302 pub fn build(self) -> DnsConfig {
304 self.config
305 }
306}
307
308impl Default for DnsBuilder {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl TlsBuilder {
315 pub fn new() -> Self {
317 Self {
318 config: TlsConfig {
319 enabled: true,
320 ..TlsConfig::default()
321 },
322 }
323 }
324
325 pub fn bypass(mut self, pattern: impl Into<String>) -> Self {
327 self.config.bypass.push(pattern.into());
328 self
329 }
330
331 pub fn verify_upstream(mut self, verify: bool) -> Self {
333 self.config.verify_upstream = verify;
334 self
335 }
336
337 pub fn intercepted_ports(mut self, ports: Vec<u16>) -> Self {
339 self.config.intercepted_ports = ports;
340 self
341 }
342
343 pub fn block_quic(mut self, block: bool) -> Self {
345 self.config.block_quic_on_intercept = block;
346 self
347 }
348
349 pub fn upstream_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
354 self.config.upstream_ca_cert.push(path.into());
355 self
356 }
357
358 pub fn intercept_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
360 self.config.intercept_ca.cert_path = Some(path.into());
361 self
362 }
363
364 pub fn intercept_ca_key(mut self, path: impl Into<PathBuf>) -> Self {
366 self.config.intercept_ca.key_path = Some(path.into());
367 self
368 }
369
370 pub fn build(self) -> TlsConfig {
372 self.config
373 }
374}
375
376impl SecretBuilder {
377 pub fn new() -> Self {
379 Self {
380 env_var: None,
381 value: None,
382 placeholder: None,
383 allowed_hosts: Vec::new(),
384 injection: SecretInjection::default(),
385 on_violation: None,
386 require_tls_identity: true,
387 }
388 }
389
390 pub fn env(mut self, var: impl Into<String>) -> Self {
392 self.env_var = Some(var.into());
393 self
394 }
395
396 pub fn value(mut self, value: impl Into<String>) -> Self {
398 self.value = Some(value.into());
399 self
400 }
401
402 pub fn placeholder(mut self, placeholder: impl Into<String>) -> Self {
405 self.placeholder = Some(placeholder.into());
406 self
407 }
408
409 pub fn allow_host(mut self, host: impl Into<String>) -> Self {
411 self.allowed_hosts.push(HostPattern::Exact(host.into()));
412 self
413 }
414
415 pub fn allow_host_pattern(mut self, pattern: impl Into<String>) -> Self {
417 self.allowed_hosts
418 .push(HostPattern::Wildcard(pattern.into()));
419 self
420 }
421
422 pub fn allow_any_host_dangerous(mut self, i_understand_the_risk: bool) -> Self {
425 if i_understand_the_risk {
426 self.allowed_hosts.push(HostPattern::Any);
427 }
428 self
429 }
430
431 pub fn on_violation(
433 mut self,
434 f: impl FnOnce(ViolationActionBuilder) -> ViolationActionBuilder,
435 ) -> Self {
436 self.on_violation = Some(f(ViolationActionBuilder::default()).build());
437 self
438 }
439
440 pub fn require_tls_identity(mut self, enabled: bool) -> Self {
442 self.require_tls_identity = enabled;
443 self
444 }
445
446 pub fn inject_headers(mut self, enabled: bool) -> Self {
448 self.injection.headers = enabled;
449 self
450 }
451
452 pub fn inject_basic_auth(mut self, enabled: bool) -> Self {
454 self.injection.basic_auth = enabled;
455 self
456 }
457
458 pub fn inject_query(mut self, enabled: bool) -> Self {
460 self.injection.query_params = enabled;
461 self
462 }
463
464 pub fn inject_body(mut self, enabled: bool) -> Self {
466 self.injection.body = enabled;
467 self
468 }
469
470 pub fn build(self) -> SecretEntry {
475 let env_var = self.env_var.expect("SecretBuilder: .env() is required");
476 let value = self.value.expect("SecretBuilder: .value() is required");
477 let placeholder = self
478 .placeholder
479 .unwrap_or_else(|| format!("$MSB_{env_var}"));
480
481 SecretEntry {
482 env_var,
483 value,
484 placeholder,
485 allowed_hosts: self.allowed_hosts,
486 injection: self.injection,
487 on_violation: self.on_violation,
488 require_tls_identity: self.require_tls_identity,
489 }
490 }
491}
492
493impl ViolationActionBuilder {
494 pub fn new() -> Self {
496 Self::default()
497 }
498
499 pub fn from_action(action: ViolationAction) -> Self {
501 action.into()
502 }
503
504 pub fn block(mut self) -> Self {
506 self.action = ViolationAction::Block;
507 self
508 }
509
510 pub fn block_and_log(mut self) -> Self {
512 self.action = ViolationAction::BlockAndLog;
513 self
514 }
515
516 pub fn block_and_terminate(mut self) -> Self {
518 self.action = ViolationAction::BlockAndTerminate;
519 self
520 }
521
522 pub fn passthrough_host(mut self, host: impl Into<String>) -> Self {
524 self.push_passthrough_host(HostPattern::Exact(host.into()));
525 self
526 }
527
528 pub fn passthrough_host_pattern(mut self, pattern: impl Into<String>) -> Self {
530 self.push_passthrough_host(HostPattern::Wildcard(pattern.into()));
531 self
532 }
533
534 pub fn passthrough_all_hosts(mut self, i_understand_the_risk: bool) -> Self {
536 if i_understand_the_risk {
537 self.push_passthrough_host(HostPattern::Any);
538 }
539 self
540 }
541
542 fn push_passthrough_host(&mut self, host: HostPattern) {
544 match self.action {
545 ViolationAction::Passthrough(ref mut hosts) => hosts.push(host),
546 _ => self.action = ViolationAction::Passthrough(vec![host]),
547 }
548 }
549
550 pub fn build(self) -> ViolationAction {
552 self.action
553 }
554}
555
556impl Default for NetworkBuilder {
561 fn default() -> Self {
562 Self::new()
563 }
564}
565
566impl Default for TlsBuilder {
567 fn default() -> Self {
568 Self::new()
569 }
570}
571
572impl Default for SecretBuilder {
573 fn default() -> Self {
574 Self::new()
575 }
576}
577impl From<ViolationAction> for ViolationActionBuilder {
578 fn from(action: ViolationAction) -> Self {
579 Self { action }
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn network_builder_happy_path_returns_config() {
        let cfg = NetworkBuilder::new()
            .dns(|d| d.rebind_protection(false))
            .build()
            .unwrap();
        assert!(!cfg.dns.rebind_protection);
    }

    #[test]
    fn port_bind_sets_host_bind() {
        let bind = "0.0.0.0".parse().unwrap();
        let cfg = NetworkBuilder::new()
            .port_bind(bind, 8080, 80)
            .port_udp_bind(bind, 5353, 53)
            .build()
            .unwrap();

        assert_eq!(cfg.ports[0].host_bind, bind);
        assert_eq!(cfg.ports[0].host_port, 8080);
        assert_eq!(cfg.ports[0].guest_port, 80);
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(cfg.ports[1].host_bind, bind);
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn port_and_port_udp_default_to_ipv4_loopback() {
        let loopback = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
        let cfg = NetworkBuilder::new()
            .port(8080, 80)
            .port_udp(5353, 53)
            .build()
            .unwrap();

        assert_eq!(cfg.ports[0].host_bind, loopback);
        assert_eq!(cfg.ports[0].protocol, PortProtocol::Tcp);
        assert_eq!(cfg.ports[1].host_bind, loopback);
        assert_eq!(cfg.ports[1].protocol, PortProtocol::Udp);
    }

    #[test]
    fn ipv4_pool_accepts_prefix_up_to_30() {
        let pool: Ipv4Network = "10.0.0.0/30".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv4_pool(pool).build().unwrap();
        assert_eq!(cfg.interface.ipv4_pool, Some(pool));
    }

    #[test]
    fn ipv4_pool_rejects_prefix_longer_than_30() {
        let pool: Ipv4Network = "10.0.0.0/31".parse().unwrap();
        let result = NetworkBuilder::new().ipv4_pool(pool).build();
        assert!(matches!(
            &result,
            Err(BuildError::InvalidIpv4Pool { raw }) if raw == "10.0.0.0/31"
        ));
    }

    #[test]
    fn ipv6_pool_accepts_64_and_rejects_longer() {
        let ok: Ipv6Network = "fd00::/64".parse().unwrap();
        let cfg = NetworkBuilder::new().ipv6_pool(ok).build().unwrap();
        assert_eq!(cfg.interface.ipv6_pool, Some(ok));

        let too_long: Ipv6Network = "fd00::/65".parse().unwrap();
        let result = NetworkBuilder::new().ipv6_pool(too_long).build();
        assert!(matches!(result, Err(BuildError::InvalidIpv6Pool { .. })));
    }

    #[test]
    fn build_reports_first_recorded_error() {
        let bad_v6: Ipv6Network = "fd00::/80".parse().unwrap();
        let bad_v4: Ipv4Network = "10.0.0.0/32".parse().unwrap();
        let result = NetworkBuilder::new()
            .ipv6_pool(bad_v6)
            .ipv4_pool(bad_v4)
            .build();
        assert!(matches!(result, Err(BuildError::InvalidIpv6Pool { .. })));
    }

    #[test]
    fn secret_builder_defaults_placeholder_from_env_var() {
        let secret = SecretBuilder::new().env("TOKEN").value("v").build();
        assert_eq!(secret.placeholder, "$MSB_TOKEN");
        assert!(secret.require_tls_identity);
        assert!(secret.allowed_hosts.is_empty());
        assert_eq!(secret.on_violation, None);
    }

    #[test]
    fn secret_env_populates_all_fields() {
        let cfg = NetworkBuilder::new()
            .secret_env("TOKEN", "v", "$PH", "api.github.com")
            .build()
            .unwrap();
        let entry = &cfg.secrets.secrets[0];

        assert_eq!(entry.env_var, "TOKEN");
        assert_eq!(entry.value, "v");
        assert_eq!(entry.placeholder, "$PH");
        assert_eq!(
            entry.allowed_hosts,
            vec![HostPattern::Exact("api.github.com".into())]
        );
        assert_eq!(entry.on_violation, None);
        assert!(entry.require_tls_identity);
    }

    #[test]
    fn network_builder_sets_global_passthrough_action() {
        let cfg = NetworkBuilder::new()
            .on_secret_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build()
            .unwrap();

        assert_eq!(
            cfg.secrets.on_violation,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])
        );
    }

    #[test]
    fn secret_builder_sets_violation_action() {
        let secret = SecretBuilder::new()
            .env("TOKEN")
            .value("secret-value")
            .allow_host("api.github.com")
            .on_violation(|v| {
                v.passthrough_host("api.anthropic.com")
                    .passthrough_host_pattern("*.anthropic.com")
            })
            .build();

        assert_eq!(
            secret.on_violation,
            Some(ViolationAction::Passthrough(vec![
                HostPattern::Exact("api.anthropic.com".into()),
                HostPattern::Wildcard("*.anthropic.com".into()),
            ])),
        );
    }

    #[test]
    fn violation_action_builder_blocking_call_replaces_passthrough_policy() {
        let action = ViolationActionBuilder::default()
            .passthrough_host("google.com")
            .block_and_terminate()
            .passthrough_host("facebook.com")
            .build();

        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![HostPattern::Exact("facebook.com".into())])
        );
    }

    #[test]
    fn violation_action_builder_accumulates_passthrough_hosts() {
        let action = ViolationActionBuilder::default()
            .block()
            .passthrough_host("google.com")
            .passthrough_host("facebook.com")
            .build();

        assert_eq!(
            action,
            ViolationAction::Passthrough(vec![
                HostPattern::Exact("google.com".into()),
                HostPattern::Exact("facebook.com".into()),
            ]),
        );
    }

    #[test]
    fn violation_action_builder_from_action_round_trips() {
        let action = ViolationActionBuilder::from_action(ViolationAction::BlockAndLog).build();
        assert_eq!(action, ViolationAction::BlockAndLog);
    }
}
688}