1use crate::config::{CompiledEndpointRules, RouteConfig};
14use crate::error::{ProxyError, Result};
15use nono::undo::{NetworkAuditAuthMechanism, NetworkAuditInjectionMode};
16use rustls::pki_types::pem::PemObject;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::debug;
20use zeroize::Zeroizing;
21
/// A configured route after loading, with its settings resolved into the
/// forms consulted when handling traffic for the route.
pub struct LoadedRoute {
    /// Upstream URL exactly as configured for the route.
    pub upstream: String,

    /// Lowercased `host:port` derived from `upstream` (the port defaults to 443
    /// for `https` and 80 for `http`). `None` when `upstream` does not parse as
    /// an `http`/`https` URL with a host; upstream lookups never match such a
    /// route.
    pub upstream_host_port: Option<String>,

    /// Method/path allow-list compiled from the route's `endpoint_rules`.
    pub endpoint_rules: CompiledEndpointRules,

    /// Route-specific TLS connector. Built only when the route sets `tls_ca`,
    /// `tls_client_cert`, or `tls_client_key`; `None` otherwise.
    pub tls_connector: Option<tokio_rustls::TlsConnector>,

    /// True when the route injects a managed credential or has at least one
    /// endpoint rule.
    pub requires_intercept: bool,

    /// True when the route config sets a `credential_key` or an OAuth2 block.
    pub requires_managed_credential: bool,

    /// Audit auth-mechanism label derived from the route config; `None` when
    /// the route has neither OAuth2 nor a `credential_key`.
    pub managed_auth_mechanism: Option<NetworkAuditAuthMechanism>,

    /// Audit injection-mode label derived from the route config; `None` when
    /// the route has neither OAuth2 nor a `credential_key`.
    pub managed_injection_mode: Option<NetworkAuditInjectionMode>,
}
67
68impl std::fmt::Debug for LoadedRoute {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("LoadedRoute")
71 .field("upstream", &self.upstream)
72 .field("upstream_host_port", &self.upstream_host_port)
73 .field("endpoint_rules", &self.endpoint_rules)
74 .field("has_custom_tls_ca", &self.tls_connector.is_some())
75 .field("requires_intercept", &self.requires_intercept)
76 .field(
77 "requires_managed_credential",
78 &self.requires_managed_credential,
79 )
80 .field("managed_auth_mechanism", &self.managed_auth_mechanism)
81 .field("managed_injection_mode", &self.managed_injection_mode)
82 .finish()
83 }
84}
85
86fn auth_mechanism_for_route(route: &RouteConfig) -> Option<NetworkAuditAuthMechanism> {
87 if route.oauth2.is_some() {
88 return Some(NetworkAuditAuthMechanism::PhantomHeader);
89 }
90
91 if route.credential_key.is_some() {
92 let proxy_mode = route
93 .proxy
94 .as_ref()
95 .and_then(|p| p.inject_mode.clone())
96 .unwrap_or_else(|| route.inject_mode.clone());
97 return Some(match proxy_mode {
98 crate::config::InjectMode::Header | crate::config::InjectMode::BasicAuth => {
99 NetworkAuditAuthMechanism::PhantomHeader
100 }
101 crate::config::InjectMode::UrlPath => NetworkAuditAuthMechanism::PhantomPath,
102 crate::config::InjectMode::QueryParam => NetworkAuditAuthMechanism::PhantomQuery,
103 });
104 }
105
106 None
107}
108
109fn injection_mode_for_route(route: &RouteConfig) -> Option<NetworkAuditInjectionMode> {
110 if route.oauth2.is_some() {
111 return Some(NetworkAuditInjectionMode::OAuth2);
112 }
113
114 if route.credential_key.is_some() {
115 return Some(match route.inject_mode {
116 crate::config::InjectMode::Header => NetworkAuditInjectionMode::Header,
117 crate::config::InjectMode::UrlPath => NetworkAuditInjectionMode::UrlPath,
118 crate::config::InjectMode::QueryParam => NetworkAuditInjectionMode::QueryParam,
119 crate::config::InjectMode::BasicAuth => NetworkAuditInjectionMode::BasicAuth,
120 });
121 }
122
123 None
124}
125
/// The set of loaded routes, keyed by normalised prefix (the configured prefix
/// with leading and trailing `/` trimmed).
#[derive(Debug)]
pub struct RouteStore {
    /// Normalised prefix -> loaded route.
    routes: HashMap<String, LoadedRoute>,
}
135
136impl RouteStore {
137 pub fn load(routes: &[RouteConfig]) -> Result<Self> {
143 let mut loaded = HashMap::new();
144
145 let base_root_store = build_base_root_store();
146
147 for route in routes {
148 let normalized_prefix = route.prefix.trim_matches('/').to_string();
149
150 debug!(
151 "Loading route '{}' -> {}",
152 normalized_prefix, route.upstream
153 );
154
155 let endpoint_rules = CompiledEndpointRules::compile(&route.endpoint_rules)
156 .map_err(|e| ProxyError::Config(format!("route '{}': {}", normalized_prefix, e)))?;
157
158 let tls_connector = if route.tls_ca.is_some()
159 || route.tls_client_cert.is_some()
160 || route.tls_client_key.is_some()
161 {
162 debug!(
163 "Building TLS connector for route '{}' (ca={}, client_cert={})",
164 normalized_prefix,
165 route.tls_ca.is_some(),
166 route.tls_client_cert.is_some(),
167 );
168 Some(build_tls_connector(
169 &base_root_store,
170 route.tls_ca.as_deref(),
171 route.tls_client_cert.as_deref(),
172 route.tls_client_key.as_deref(),
173 )?)
174 } else {
175 None
176 };
177
178 let upstream_host_port = extract_host_port(&route.upstream);
179
180 let requires_managed_credential =
187 route.credential_key.is_some() || route.oauth2.is_some();
188 let requires_intercept =
189 requires_managed_credential || !route.endpoint_rules.is_empty();
190 let managed_auth_mechanism = auth_mechanism_for_route(route);
191 let managed_injection_mode = injection_mode_for_route(route);
192
193 loaded.insert(
194 normalized_prefix,
195 LoadedRoute {
196 upstream: route.upstream.clone(),
197 upstream_host_port,
198 endpoint_rules,
199 tls_connector,
200 requires_intercept,
201 requires_managed_credential,
202 managed_auth_mechanism,
203 managed_injection_mode,
204 },
205 );
206 }
207
208 Ok(Self { routes: loaded })
209 }
210
211 #[must_use]
213 pub fn empty() -> Self {
214 Self {
215 routes: HashMap::new(),
216 }
217 }
218
219 #[must_use]
221 pub fn get(&self, prefix: &str) -> Option<&LoadedRoute> {
222 self.routes.get(prefix)
223 }
224
225 #[must_use]
227 pub fn is_empty(&self) -> bool {
228 self.routes.is_empty()
229 }
230
231 #[must_use]
233 pub fn len(&self) -> usize {
234 self.routes.len()
235 }
236
237 #[must_use]
241 pub fn is_route_upstream(&self, host_port: &str) -> bool {
242 let normalised = host_port.to_lowercase();
243 self.routes.values().any(|route| {
244 route
245 .upstream_host_port
246 .as_ref()
247 .is_some_and(|hp| *hp == normalised)
248 })
249 }
250
251 #[must_use]
256 pub fn lookup_by_upstream(&self, host_port: &str) -> Option<(&str, &LoadedRoute)> {
257 let normalised = host_port.to_lowercase();
258 self.routes.iter().find_map(|(prefix, route)| {
259 route
260 .upstream_host_port
261 .as_ref()
262 .filter(|hp| **hp == normalised)
263 .map(|_| (prefix.as_str(), route))
264 })
265 }
266
267 #[must_use]
270 pub fn lookup_all_by_upstream(&self, host_port: &str) -> Vec<(&str, &LoadedRoute)> {
271 let normalised = host_port.to_lowercase();
272 let mut matches: Vec<_> = self
273 .routes
274 .iter()
275 .filter(|(_, route)| {
276 route
277 .upstream_host_port
278 .as_ref()
279 .is_some_and(|hp| *hp == normalised)
280 })
281 .map(|(prefix, route)| (prefix.as_str(), route))
282 .collect();
283 matches.sort_by_key(|(prefix, _)| *prefix);
284 matches
285 }
286
287 #[must_use]
289 pub fn has_intercept_route(&self, host_port: &str) -> bool {
290 let normalised = host_port.to_lowercase();
291 self.routes.values().any(|route| {
292 route
293 .upstream_host_port
294 .as_ref()
295 .is_some_and(|hp| *hp == normalised)
296 && route.requires_intercept
297 })
298 }
299
300 #[must_use]
302 pub fn route_upstream_hosts(&self) -> std::collections::HashSet<String> {
303 self.routes
304 .values()
305 .filter_map(|route| route.upstream_host_port.clone())
306 .collect()
307 }
308}
309
310impl LoadedRoute {
311 #[must_use]
314 pub fn missing_managed_credential(
315 &self,
316 has_static_credential: bool,
317 has_oauth2: bool,
318 ) -> bool {
319 self.requires_managed_credential && !has_static_credential && !has_oauth2
320 }
321}
322
323fn extract_host_port(url: &str) -> Option<String> {
328 let parsed = url::Url::parse(url).ok()?;
329 let host = parsed.host_str()?;
330 let default_port = match parsed.scheme() {
331 "https" => 443,
332 "http" => 80,
333 _ => return None,
334 };
335 let port = parsed.port().unwrap_or(default_port);
336 Some(format!("{}:{}", host.to_lowercase(), port))
337}
338
339fn read_pem_file(path: &std::path::Path, label: &str) -> Result<Zeroizing<Vec<u8>>> {
346 std::fs::read(path)
347 .map(Zeroizing::new)
348 .map_err(|e| match e.kind() {
349 std::io::ErrorKind::NotFound => {
350 ProxyError::Config(format!("{} file not found: '{}'", label, path.display()))
351 }
352 std::io::ErrorKind::PermissionDenied => ProxyError::Config(format!(
353 "{} permission denied: '{}' (check that nono can read this file)",
354 label,
355 path.display()
356 )),
357 _ => ProxyError::Config(format!(
358 "failed to read {} '{}': {}",
359 label,
360 path.display(),
361 e
362 )),
363 })
364}
365
366fn build_base_root_store() -> rustls::RootCertStore {
370 let mut store = rustls::RootCertStore::empty();
371 store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
372 let native = rustls_native_certs::load_native_certs();
373 for cert in native.certs {
374 if let Err(e) = store.add(cert) {
375 debug!("skipping unparseable native cert: {e}");
376 }
377 }
378 store
379}
380
381fn build_tls_connector(
384 base_root_store: &rustls::RootCertStore,
385 ca_path: Option<&str>,
386 client_cert_path: Option<&str>,
387 client_key_path: Option<&str>,
388) -> Result<tokio_rustls::TlsConnector> {
389 let mut root_store = base_root_store.clone();
390
391 if let Some(ca_path) = ca_path {
393 let ca_path = std::path::Path::new(ca_path);
394 let ca_pem = read_pem_file(ca_path, "CA certificate")?;
395
396 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_slice_iter(ca_pem.as_ref())
397 .collect::<std::result::Result<Vec<_>, _>>()
398 .map_err(|e| {
399 ProxyError::Config(format!(
400 "failed to parse CA certificate '{}': {}",
401 ca_path.display(),
402 e
403 ))
404 })?;
405
406 if certs.is_empty() {
407 return Err(ProxyError::Config(format!(
408 "CA certificate file '{}' contains no valid PEM certificates",
409 ca_path.display()
410 )));
411 }
412
413 for cert in certs {
414 root_store.add(cert).map_err(|e| {
415 ProxyError::Config(format!(
416 "invalid CA certificate in '{}': {}",
417 ca_path.display(),
418 e
419 ))
420 })?;
421 }
422 }
423
424 let builder = rustls::ClientConfig::builder_with_provider(Arc::new(
425 rustls::crypto::ring::default_provider(),
426 ))
427 .with_safe_default_protocol_versions()
428 .map_err(|e| ProxyError::Config(format!("TLS config error: {}", e)))?
429 .with_root_certificates(root_store);
430
431 let tls_config = match (client_cert_path, client_key_path) {
433 (Some(cert_path), Some(key_path)) => {
434 let cert_path = std::path::Path::new(cert_path);
435 let key_path = std::path::Path::new(key_path);
436
437 let cert_pem = read_pem_file(cert_path, "client certificate")?;
438 let key_pem = read_pem_file(key_path, "client key")?;
439
440 let cert_chain: Vec<rustls::pki_types::CertificateDer> =
441 rustls::pki_types::CertificateDer::pem_slice_iter(cert_pem.as_ref())
442 .collect::<std::result::Result<Vec<_>, _>>()
443 .map_err(|e| {
444 ProxyError::Config(format!(
445 "failed to parse client certificate '{}': {}",
446 cert_path.display(),
447 e
448 ))
449 })?;
450
451 if cert_chain.is_empty() {
452 return Err(ProxyError::Config(format!(
453 "client certificate file '{}' contains no valid PEM certificates",
454 cert_path.display()
455 )));
456 }
457
458 let private_key = rustls::pki_types::PrivateKeyDer::from_pem_slice(key_pem.as_ref())
459 .map_err(|e| match e {
460 rustls::pki_types::pem::Error::NoItemsFound => ProxyError::Config(format!(
461 "client key file '{}' contains no valid PEM private key",
462 key_path.display()
463 )),
464 _ => ProxyError::Config(format!(
465 "failed to parse client key '{}': {}",
466 key_path.display(),
467 e
468 )),
469 })?;
470
471 builder
472 .with_client_auth_cert(cert_chain, private_key)
473 .map_err(|e| {
474 ProxyError::Config(format!(
475 "invalid client certificate/key pair ('{}', '{}'): {}",
476 cert_path.display(),
477 key_path.display(),
478 e
479 ))
480 })?
481 }
482 (Some(_), None) => {
483 return Err(ProxyError::Config(
484 "tls_client_cert is set but tls_client_key is missing".to_string(),
485 ));
486 }
487 (None, Some(_)) => {
488 return Err(ProxyError::Config(
489 "tls_client_key is set but tls_client_cert is missing".to_string(),
490 ));
491 }
492 (None, None) => builder.with_no_client_auth(),
493 };
494
495 let mut tls_config = tls_config;
504 if client_cert_path.is_some() {
505 tls_config.resumption = rustls::client::Resumption::disabled();
506 }
507
508 Ok(tokio_rustls::TlsConnector::from(Arc::new(tls_config)))
509}
510
/// Test helper: builds a connector that trusts the certificates in `ca_path`
/// on top of a fresh base root store, with no client certificate.
#[cfg(test)]
fn build_tls_connector_with_ca(ca_path: &str) -> Result<tokio_rustls::TlsConnector> {
    build_tls_connector(&build_base_root_store(), Some(ca_path), None, None)
}
517
518#[cfg(test)]
519#[allow(clippy::unwrap_used)]
520mod tests {
521 use super::*;
522 use crate::config::EndpointRule;
523
524 #[test]
525 fn test_empty_route_store() {
526 let store = RouteStore::empty();
527 assert!(store.is_empty());
528 assert_eq!(store.len(), 0);
529 assert!(store.get("openai").is_none());
530 }
531
532 #[test]
533 fn test_load_routes_without_credentials() {
534 let routes = vec![RouteConfig {
536 prefix: "/openai".to_string(),
537 upstream: "https://api.openai.com".to_string(),
538 credential_key: None,
539 inject_mode: Default::default(),
540 inject_header: "Authorization".to_string(),
541 credential_format: "Bearer {}".to_string(),
542 path_pattern: None,
543 path_replacement: None,
544 query_param_name: None,
545 proxy: None,
546 env_var: None,
547 endpoint_rules: vec![
548 EndpointRule {
549 method: "POST".to_string(),
550 path: "/v1/chat/completions".to_string(),
551 },
552 EndpointRule {
553 method: "GET".to_string(),
554 path: "/v1/models".to_string(),
555 },
556 ],
557 tls_ca: None,
558 tls_client_cert: None,
559 tls_client_key: None,
560 oauth2: None,
561 }];
562
563 let store = RouteStore::load(&routes).unwrap();
564 assert_eq!(store.len(), 1);
565
566 let route = store.get("openai").unwrap();
567 assert_eq!(route.upstream, "https://api.openai.com");
568 assert!(route
569 .endpoint_rules
570 .is_allowed("POST", "/v1/chat/completions"));
571 assert!(route.endpoint_rules.is_allowed("GET", "/v1/models"));
572 assert!(!route
573 .endpoint_rules
574 .is_allowed("DELETE", "/v1/files/file-123"));
575 }
576
577 #[test]
578 fn test_load_routes_normalises_prefix() {
579 let routes = vec![RouteConfig {
580 prefix: "/anthropic/".to_string(),
581 upstream: "https://api.anthropic.com".to_string(),
582 credential_key: None,
583 inject_mode: Default::default(),
584 inject_header: "Authorization".to_string(),
585 credential_format: "Bearer {}".to_string(),
586 path_pattern: None,
587 path_replacement: None,
588 query_param_name: None,
589 proxy: None,
590 env_var: None,
591 endpoint_rules: vec![],
592 tls_ca: None,
593 tls_client_cert: None,
594 tls_client_key: None,
595 oauth2: None,
596 }];
597
598 let store = RouteStore::load(&routes).unwrap();
599 assert!(store.get("anthropic").is_some());
600 assert!(store.get("/anthropic/").is_none());
601 }
602
603 #[test]
604 fn test_is_route_upstream() {
605 let routes = vec![RouteConfig {
606 prefix: "openai".to_string(),
607 upstream: "https://api.openai.com".to_string(),
608 credential_key: None,
609 inject_mode: Default::default(),
610 inject_header: "Authorization".to_string(),
611 credential_format: "Bearer {}".to_string(),
612 path_pattern: None,
613 path_replacement: None,
614 query_param_name: None,
615 proxy: None,
616 env_var: None,
617 endpoint_rules: vec![],
618 tls_ca: None,
619 tls_client_cert: None,
620 tls_client_key: None,
621 oauth2: None,
622 }];
623
624 let store = RouteStore::load(&routes).unwrap();
625 assert!(store.is_route_upstream("api.openai.com:443"));
626 assert!(!store.is_route_upstream("github.com:443"));
627 }
628
629 #[test]
630 fn test_route_upstream_hosts() {
631 let routes = vec![
632 RouteConfig {
633 prefix: "openai".to_string(),
634 upstream: "https://api.openai.com".to_string(),
635 credential_key: None,
636 inject_mode: Default::default(),
637 inject_header: "Authorization".to_string(),
638 credential_format: "Bearer {}".to_string(),
639 path_pattern: None,
640 path_replacement: None,
641 query_param_name: None,
642 proxy: None,
643 env_var: None,
644 endpoint_rules: vec![],
645 tls_ca: None,
646 tls_client_cert: None,
647 tls_client_key: None,
648 oauth2: None,
649 },
650 RouteConfig {
651 prefix: "anthropic".to_string(),
652 upstream: "https://api.anthropic.com".to_string(),
653 credential_key: None,
654 inject_mode: Default::default(),
655 inject_header: "Authorization".to_string(),
656 credential_format: "Bearer {}".to_string(),
657 path_pattern: None,
658 path_replacement: None,
659 query_param_name: None,
660 proxy: None,
661 env_var: None,
662 endpoint_rules: vec![],
663 tls_ca: None,
664 tls_client_cert: None,
665 tls_client_key: None,
666 oauth2: None,
667 },
668 ];
669
670 let store = RouteStore::load(&routes).unwrap();
671 let hosts = store.route_upstream_hosts();
672 assert!(hosts.contains("api.openai.com:443"));
673 assert!(hosts.contains("api.anthropic.com:443"));
674 assert_eq!(hosts.len(), 2);
675 }
676
677 #[test]
678 fn test_extract_host_port_https() {
679 assert_eq!(
680 extract_host_port("https://api.openai.com"),
681 Some("api.openai.com:443".to_string())
682 );
683 }
684
685 #[test]
686 fn test_extract_host_port_with_port() {
687 assert_eq!(
688 extract_host_port("https://api.example.com:8443"),
689 Some("api.example.com:8443".to_string())
690 );
691 }
692
693 #[test]
694 fn test_extract_host_port_http() {
695 assert_eq!(
696 extract_host_port("http://internal-service"),
697 Some("internal-service:80".to_string())
698 );
699 }
700
701 #[test]
702 fn test_extract_host_port_normalises_case() {
703 assert_eq!(
704 extract_host_port("https://API.Example.COM"),
705 Some("api.example.com:443".to_string())
706 );
707 }
708
709 #[test]
710 fn test_loaded_route_debug() {
711 let route = LoadedRoute {
712 upstream: "https://api.openai.com".to_string(),
713 upstream_host_port: Some("api.openai.com:443".to_string()),
714 endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
715 tls_connector: None,
716 requires_intercept: false,
717 requires_managed_credential: false,
718 managed_auth_mechanism: None,
719 managed_injection_mode: None,
720 };
721 let debug_output = format!("{:?}", route);
722 assert!(debug_output.contains("api.openai.com"));
723 assert!(debug_output.contains("has_custom_tls_ca"));
724 assert!(debug_output.contains("requires_intercept"));
725 assert!(debug_output.contains("requires_managed_credential"));
726 assert!(debug_output.contains("managed_auth_mechanism"));
727 assert!(debug_output.contains("managed_injection_mode"));
728 }
729
730 #[test]
731 fn test_requires_intercept_credential_only() {
732 let routes = vec![RouteConfig {
733 prefix: "openai".to_string(),
734 upstream: "https://api.openai.com".to_string(),
735 credential_key: Some("openai_api_key".to_string()),
736 inject_mode: Default::default(),
737 inject_header: "Authorization".to_string(),
738 credential_format: "Bearer {}".to_string(),
739 path_pattern: None,
740 path_replacement: None,
741 query_param_name: None,
742 proxy: None,
743 env_var: None,
744 endpoint_rules: vec![],
745 tls_ca: None,
746 tls_client_cert: None,
747 tls_client_key: None,
748 oauth2: None,
749 }];
750 let store = RouteStore::load(&routes).unwrap();
751 let hit = store.lookup_by_upstream("api.openai.com:443").unwrap();
752 assert!(store.has_intercept_route("api.openai.com:443"));
753 assert!(hit.1.requires_managed_credential);
754 assert_eq!(
755 hit.1.managed_auth_mechanism,
756 Some(NetworkAuditAuthMechanism::PhantomHeader)
757 );
758 assert_eq!(
759 hit.1.managed_injection_mode,
760 Some(NetworkAuditInjectionMode::Header)
761 );
762 assert!(!store.has_intercept_route("api.example.com:443"));
763 }
764
765 #[test]
766 fn test_requires_intercept_endpoint_rules_only() {
767 let routes = vec![RouteConfig {
770 prefix: "internal".to_string(),
771 upstream: "https://internal.example.com".to_string(),
772 credential_key: None,
773 inject_mode: Default::default(),
774 inject_header: "Authorization".to_string(),
775 credential_format: "Bearer {}".to_string(),
776 path_pattern: None,
777 path_replacement: None,
778 query_param_name: None,
779 proxy: None,
780 env_var: None,
781 endpoint_rules: vec![EndpointRule {
782 method: "GET".to_string(),
783 path: "/v1/items".to_string(),
784 }],
785 tls_ca: None,
786 tls_client_cert: None,
787 tls_client_key: None,
788 oauth2: None,
789 }];
790 let store = RouteStore::load(&routes).unwrap();
791 let hit = store
792 .lookup_by_upstream("internal.example.com:443")
793 .unwrap();
794 assert!(store.has_intercept_route("internal.example.com:443"));
795 assert!(!hit.1.requires_managed_credential);
796 }
797
798 #[test]
799 fn test_requires_intercept_declarative_only() {
800 let routes = vec![RouteConfig {
803 prefix: "alias".to_string(),
804 upstream: "https://aliased.example.com".to_string(),
805 credential_key: None,
806 inject_mode: Default::default(),
807 inject_header: "Authorization".to_string(),
808 credential_format: "Bearer {}".to_string(),
809 path_pattern: None,
810 path_replacement: None,
811 query_param_name: None,
812 proxy: None,
813 env_var: None,
814 endpoint_rules: vec![],
815 tls_ca: None,
816 tls_client_cert: None,
817 tls_client_key: None,
818 oauth2: None,
819 }];
820 let store = RouteStore::load(&routes).unwrap();
821 assert!(store.is_route_upstream("aliased.example.com:443"));
822 assert!(!store.has_intercept_route("aliased.example.com:443"));
823 }
824
825 #[test]
826 fn test_missing_managed_credential_policy() {
827 let managed = LoadedRoute {
828 upstream: "https://api.openai.com".to_string(),
829 upstream_host_port: Some("api.openai.com:443".to_string()),
830 endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
831 tls_connector: None,
832 requires_intercept: true,
833 requires_managed_credential: true,
834 managed_auth_mechanism: Some(NetworkAuditAuthMechanism::PhantomHeader),
835 managed_injection_mode: Some(NetworkAuditInjectionMode::Header),
836 };
837 assert!(managed.missing_managed_credential(false, false));
838 assert!(!managed.missing_managed_credential(true, false));
839 assert!(!managed.missing_managed_credential(false, true));
840
841 let l7_only = LoadedRoute {
842 upstream: "https://internal.example.com".to_string(),
843 upstream_host_port: Some("internal.example.com:443".to_string()),
844 endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
845 tls_connector: None,
846 requires_intercept: true,
847 requires_managed_credential: false,
848 managed_auth_mechanism: None,
849 managed_injection_mode: None,
850 };
851 assert!(!l7_only.missing_managed_credential(false, false));
852 }
853
854 #[test]
855 fn test_lookup_by_upstream_returns_prefix() {
856 let routes = vec![RouteConfig {
857 prefix: "openai".to_string(),
858 upstream: "https://api.openai.com".to_string(),
859 credential_key: Some("openai_api_key".to_string()),
860 inject_mode: Default::default(),
861 inject_header: "Authorization".to_string(),
862 credential_format: "Bearer {}".to_string(),
863 path_pattern: None,
864 path_replacement: None,
865 query_param_name: None,
866 proxy: None,
867 env_var: None,
868 endpoint_rules: vec![],
869 tls_ca: None,
870 tls_client_cert: None,
871 tls_client_key: None,
872 oauth2: None,
873 }];
874 let store = RouteStore::load(&routes).unwrap();
875 let hit = store.lookup_by_upstream("api.openai.com:443").unwrap();
876 assert_eq!(hit.0, "openai");
877 assert!(hit.1.requires_intercept);
878 assert!(hit.1.requires_managed_credential);
879 assert!(store.lookup_by_upstream("api.example.com:443").is_none());
880 }
881
882 #[test]
883 fn test_lookup_all_by_upstream_returns_multiple_routes() {
884 let routes = vec![
885 RouteConfig {
886 prefix: "github_org_a".to_string(),
887 upstream: "https://github.com".to_string(),
888 credential_key: Some("env://GH_TOKEN_A".to_string()),
889 inject_mode: Default::default(),
890 inject_header: "Authorization".to_string(),
891 credential_format: "Bearer {}".to_string(),
892 path_pattern: None,
893 path_replacement: None,
894 query_param_name: None,
895 proxy: None,
896 env_var: Some("GH_TOKEN_A".to_string()),
897 endpoint_rules: vec![crate::config::EndpointRule {
898 method: "*".to_string(),
899 path: "/org-a/**".to_string(),
900 }],
901 tls_ca: None,
902 tls_client_cert: None,
903 tls_client_key: None,
904 oauth2: None,
905 },
906 RouteConfig {
907 prefix: "github_org_b".to_string(),
908 upstream: "https://github.com".to_string(),
909 credential_key: Some("env://GH_TOKEN_B".to_string()),
910 inject_mode: Default::default(),
911 inject_header: "Authorization".to_string(),
912 credential_format: "Bearer {}".to_string(),
913 path_pattern: None,
914 path_replacement: None,
915 query_param_name: None,
916 proxy: None,
917 env_var: Some("GH_TOKEN_B".to_string()),
918 endpoint_rules: vec![crate::config::EndpointRule {
919 method: "*".to_string(),
920 path: "/org-b/**".to_string(),
921 }],
922 tls_ca: None,
923 tls_client_cert: None,
924 tls_client_key: None,
925 oauth2: None,
926 },
927 ];
928 let store = RouteStore::load(&routes).unwrap();
929
930 let all = store.lookup_all_by_upstream("github.com:443");
931 assert_eq!(all.len(), 2, "both routes share the same upstream");
932
933 let prefixes: Vec<&str> = all.iter().map(|(p, _)| *p).collect();
934 assert!(prefixes.contains(&"github_org_a"));
935 assert!(prefixes.contains(&"github_org_b"));
936
937 let (_, route_a) = all.iter().find(|(p, _)| *p == "github_org_a").unwrap();
938 assert!(route_a.endpoint_rules.is_allowed("GET", "/org-a/repo"));
939 assert!(!route_a.endpoint_rules.is_allowed("GET", "/org-b/repo"));
940
941 let (_, route_b) = all.iter().find(|(p, _)| *p == "github_org_b").unwrap();
942 assert!(route_b.endpoint_rules.is_allowed("GET", "/org-b/repo"));
943 assert!(!route_b.endpoint_rules.is_allowed("GET", "/org-a/repo"));
944
945 assert!(store.has_intercept_route("github.com:443"));
946 assert!(store.is_route_upstream("github.com:443"));
947 assert!(store.lookup_all_by_upstream("other.com:443").is_empty());
948 }
949
950 #[test]
956 fn test_route_selection_multi_org_profile() {
957 fn gh_route(prefix: &str, env: &str, path: &str) -> RouteConfig {
959 RouteConfig {
960 prefix: prefix.to_string(),
961 upstream: "https://github.com".to_string(),
962 credential_key: Some(format!("env://{env}")),
963 inject_mode: Default::default(),
964 inject_header: "Authorization".to_string(),
965 credential_format: "Bearer {}".to_string(),
966 path_pattern: None,
967 path_replacement: None,
968 query_param_name: None,
969 proxy: None,
970 env_var: Some(env.to_string()),
971 endpoint_rules: vec![crate::config::EndpointRule {
972 method: "*".to_string(),
973 path: path.to_string(),
974 }],
975 tls_ca: None,
976 tls_client_cert: None,
977 tls_client_key: None,
978 oauth2: None,
979 }
980 }
981
982 #[derive(Debug, PartialEq)]
983 enum Selection<'a> {
984 Route(&'a str),
985 Passthrough,
986 Ambiguous(Vec<&'a str>),
987 }
988
989 fn select<'a>(
990 candidates: &'a [(&'a str, &'a LoadedRoute)],
991 method: &str,
992 path: &str,
993 ) -> Selection<'a> {
994 let mut matches: Vec<&str> = Vec::new();
995 let mut catch_all: Option<&str> = None;
996 for (prefix, route) in candidates {
997 if route.endpoint_rules.is_empty() {
998 if catch_all.is_none() {
999 catch_all = Some(*prefix);
1000 }
1001 } else if route.endpoint_rules.is_allowed(method, path) {
1002 matches.push(prefix);
1003 }
1004 }
1005 if matches.len() > 1 {
1006 Selection::Ambiguous(matches)
1007 } else if let Some(svc) = matches.into_iter().next().or(catch_all) {
1008 Selection::Route(svc)
1009 } else {
1010 Selection::Passthrough
1011 }
1012 }
1013
1014 let routes = vec![
1016 gh_route("github_https_org_a", "GH_TOKEN_A", "/org-a/**"),
1017 gh_route("github_https_org_b", "GH_TOKEN_B", "/org-b/**"),
1018 ];
1019 let store = RouteStore::load(&routes).unwrap();
1020 let candidates = store.lookup_all_by_upstream("github.com:443");
1021 assert_eq!(candidates.len(), 2);
1022
1023 assert_eq!(
1025 select(&candidates, "GET", "/org-a/repo.git/info/refs"),
1026 Selection::Route("github_https_org_a")
1027 );
1028 assert_eq!(
1030 select(&candidates, "GET", "/org-b/repo.git/info/refs"),
1031 Selection::Route("github_https_org_b")
1032 );
1033 assert_eq!(
1035 select(&candidates, "GET", "/always-further/nono.git/info/refs"),
1036 Selection::Passthrough
1037 );
1038 assert_eq!(
1040 select(
1041 &candidates,
1042 "POST",
1043 "/always-further/nono.git/git-upload-pack"
1044 ),
1045 Selection::Passthrough
1046 );
1047
1048 let routes_with_catchall = vec![
1050 gh_route("github_https_org_a", "GH_TOKEN_A", "/org-a/**"),
1051 gh_route("github_https_org_b", "GH_TOKEN_B", "/org-b/**"),
1052 gh_route("github_https_all", "GH_TOKEN_A", "/**"),
1053 ];
1054 let store2 = RouteStore::load(&routes_with_catchall).unwrap();
1055 let candidates2 = store2.lookup_all_by_upstream("github.com:443");
1056 assert_eq!(candidates2.len(), 3);
1057
1058 assert_eq!(
1060 select(&candidates2, "GET", "/org-a/repo.git/info/refs"),
1061 Selection::Ambiguous(vec!["github_https_all", "github_https_org_a"])
1062 );
1063 assert_eq!(
1065 select(&candidates2, "GET", "/always-further/nono.git/info/refs"),
1066 Selection::Route("github_https_all")
1067 );
1068 }
1069
1070 const TEST_CA_PEM: &str = "\
1074-----BEGIN CERTIFICATE-----
1075MIIBnjCCAUWgAwIBAgIUT0bpOJJvHdOdZt+gW1stR8VBgXowCgYIKoZIzj0EAwIw
1076FzEVMBMGA1UEAwwMbm9uby10ZXN0LWNhMCAXDTI1MDEwMTAwMDAwMFoYDzIxMjQx
1077MjA3MDAwMDAwWjAXMRUwEwYDVQQDDAxub25vLXRlc3QtY2EwWTATBgcqhkjOPQIB
1078BggqhkjOPQMBBwNCAAR8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
1079AAAAAAAAAAAAAAAAAAAAo1MwUTAdBgNVHQ4EFgQUAAAAAAAAAAAAAAAAAAAAAAAA
1080AAAAMB8GA1UdIwQYMBaAFAAAAAAAAAAAAAAAAAAAAAAAAAAAADAPBgNVHRMBAf8E
1081BTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
1082AAAAAAAICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
1083-----END CERTIFICATE-----";
1084
1085 #[test]
1086 fn test_build_tls_connector_with_valid_ca() {
1087 let dir = tempfile::tempdir().unwrap();
1088 let ca_path = dir.path().join("ca.pem");
1089 std::fs::write(&ca_path, TEST_CA_PEM).unwrap();
1090
1091 let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
1092 match result {
1093 Ok(connector) => {
1094 drop(connector);
1095 }
1096 Err(ProxyError::Config(msg)) => {
1097 assert!(
1098 msg.contains("invalid CA certificate") || msg.contains("CA certificate"),
1099 "unexpected error: {}",
1100 msg
1101 );
1102 }
1103 Err(e) => panic!("unexpected error type: {}", e),
1104 }
1105 }
1106
1107 #[test]
1108 fn test_build_tls_connector_missing_file() {
1109 let result = build_tls_connector_with_ca("/nonexistent/path/ca.pem");
1110 let err = result
1111 .err()
1112 .expect("should fail for missing file")
1113 .to_string();
1114 assert!(
1115 err.contains("CA certificate file not found"),
1116 "unexpected error: {}",
1117 err
1118 );
1119 }
1120
1121 #[test]
1122 fn test_build_tls_connector_empty_pem() {
1123 let dir = tempfile::tempdir().unwrap();
1124 let ca_path = dir.path().join("empty.pem");
1125 std::fs::write(&ca_path, "not a certificate\n").unwrap();
1126
1127 let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
1128 let err = result
1129 .err()
1130 .expect("should fail for invalid PEM")
1131 .to_string();
1132 assert!(
1133 err.contains("no valid PEM certificates"),
1134 "unexpected error: {}",
1135 err
1136 );
1137 }
1138
1139 const TEST_CLIENT_CERT_PEM: &str = "\
1145-----BEGIN CERTIFICATE-----
1146MIIBijCCATGgAwIBAgIUEoEb+0z+4CTRCzN98MqeTEXgdO8wCgYIKoZIzj0EAwIw
1147GzEZMBcGA1UEAwwQbm9uby10ZXN0LWNsaWVudDAeFw0yNjA0MTAwMDIwNTdaFw0z
1148NjA0MDcwMDIwNTdaMBsxGTAXBgNVBAMMEG5vbm8tdGVzdC1jbGllbnQwWTATBgcq
1149hkjOPQIBBggqhkjOPQMBBwNCAASt6g2Zt0STlgF+wZ64JzdDRlpPeNr1h56ZLEEq
1150HfVWFhJWIKRSabtxYPV/VJyMv+lo3L0QwSKsouHs3dtF1zVQo1MwUTAdBgNVHQ4E
1151FgQUTiHidg8uqgrJ1qlaVvR+XSebAlEwHwYDVR0jBBgwFoAUTiHidg8uqgrJ1qla
1152VvR+XSebAlEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiA9PwBU
1153f832cQkGS9cyYaU7Ij5U8Rcy/g4J7Ckf2nKX3gIgG0aarAFcIzAi5VpxbCwEScnr
1154m0lHTyp6E7ut7llwMBY=
1155-----END CERTIFICATE-----";
1156
1157 const TEST_CLIENT_KEY_PEM: &str = "\
1158-----BEGIN PRIVATE KEY-----
1159MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgskOkyJkTwlMZkm/L
1160eEleLY6bARaHFnqauYJqxNoJWvihRANCAASt6g2Zt0STlgF+wZ64JzdDRlpPeNr1
1161h56ZLEEqHfVWFhJWIKRSabtxYPV/VJyMv+lo3L0QwSKsouHs3dtF1zVQ
1162-----END PRIVATE KEY-----";
1163
1164 #[test]
1165 fn test_build_tls_connector_cert_without_key_errors() {
1166 let dir = tempfile::tempdir().unwrap();
1167 let cert_path = dir.path().join("client.crt");
1168 std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
1169
1170 let base = build_base_root_store();
1171 let result = build_tls_connector(&base, None, Some(cert_path.to_str().unwrap()), None);
1172 let err = result
1173 .err()
1174 .expect("should fail with half-pair")
1175 .to_string();
1176 assert!(
1177 err.contains("tls_client_cert is set but tls_client_key is missing"),
1178 "unexpected error: {}",
1179 err
1180 );
1181 }
1182
1183 #[test]
1184 fn test_build_tls_connector_key_without_cert_errors() {
1185 let dir = tempfile::tempdir().unwrap();
1186 let key_path = dir.path().join("client.key");
1187 std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
1188
1189 let base = build_base_root_store();
1190 let result = build_tls_connector(&base, None, None, Some(key_path.to_str().unwrap()));
1191 let err = result
1192 .err()
1193 .expect("should fail with half-pair")
1194 .to_string();
1195 assert!(
1196 err.contains("tls_client_key is set but tls_client_cert is missing"),
1197 "unexpected error: {}",
1198 err
1199 );
1200 }
1201
1202 #[test]
1203 fn test_build_tls_connector_missing_client_cert_file() {
1204 let dir = tempfile::tempdir().unwrap();
1205 let key_path = dir.path().join("client.key");
1206 std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
1207
1208 let base = build_base_root_store();
1209 let result = build_tls_connector(
1210 &base,
1211 None,
1212 Some("/nonexistent/client.crt"),
1213 Some(key_path.to_str().unwrap()),
1214 );
1215 let err = result.err().expect("should fail").to_string();
1216 assert!(
1217 err.contains("client certificate file not found"),
1218 "unexpected error: {}",
1219 err
1220 );
1221 }
1222
1223 #[test]
1224 fn test_build_tls_connector_missing_client_key_file() {
1225 let dir = tempfile::tempdir().unwrap();
1226 let cert_path = dir.path().join("client.crt");
1227 std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
1228
1229 let base = build_base_root_store();
1230 let result = build_tls_connector(
1231 &base,
1232 None,
1233 Some(cert_path.to_str().unwrap()),
1234 Some("/nonexistent/client.key"),
1235 );
1236 let err = result.err().expect("should fail").to_string();
1237 assert!(
1238 err.contains("client key file not found"),
1239 "unexpected error: {}",
1240 err
1241 );
1242 }
1243
1244 #[test]
1245 #[cfg(unix)]
1246 fn test_build_tls_connector_permission_denied() {
1247 use std::os::unix::fs::PermissionsExt;
1248 let dir = tempfile::tempdir().unwrap();
1249 let cert_path = dir.path().join("client.crt");
1250 std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
1251 std::fs::set_permissions(&cert_path, std::fs::Permissions::from_mode(0o000)).unwrap();
1253
1254 if std::fs::read(&cert_path).is_ok() {
1256 return;
1257 }
1258
1259 let base = build_base_root_store();
1260 let result = build_tls_connector(
1261 &base,
1262 None,
1263 Some(cert_path.to_str().unwrap()),
1264 Some("/nonexistent/key"),
1265 );
1266 let err = result.err().expect("should fail").to_string();
1267 assert!(
1268 err.contains("permission denied"),
1269 "expected permission denied error, got: {}",
1270 err
1271 );
1272 }
1273
1274 #[test]
1275 fn test_build_tls_connector_empty_client_cert_pem() {
1276 let dir = tempfile::tempdir().unwrap();
1277 let cert_path = dir.path().join("client.crt");
1278 let key_path = dir.path().join("client.key");
1279 std::fs::write(&cert_path, "not a certificate\n").unwrap();
1280 std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
1281
1282 let base = build_base_root_store();
1283 let result = build_tls_connector(
1284 &base,
1285 None,
1286 Some(cert_path.to_str().unwrap()),
1287 Some(key_path.to_str().unwrap()),
1288 );
1289 let err = result.err().expect("should fail").to_string();
1290 assert!(
1291 err.contains("no valid PEM certificates"),
1292 "unexpected error: {}",
1293 err
1294 );
1295 }
1296
1297 #[test]
1298 fn test_build_tls_connector_empty_client_key_pem() {
1299 let dir = tempfile::tempdir().unwrap();
1301 let cert_path = dir.path().join("client.crt");
1302 let key_path = dir.path().join("client.key");
1303 std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
1304 std::fs::write(&key_path, "not a key\n").unwrap();
1305
1306 let base = build_base_root_store();
1307 let result = build_tls_connector(
1308 &base,
1309 None,
1310 Some(cert_path.to_str().unwrap()),
1311 Some(key_path.to_str().unwrap()),
1312 );
1313 let err = result
1314 .err()
1315 .expect("should fail with invalid PEM")
1316 .to_string();
1317 assert!(err.contains("client key"), "unexpected error: {}", err);
1318 }
1319
1320 #[test]
1321 fn test_route_store_loads_mtls_route() {
1322 let dir = tempfile::tempdir().unwrap();
1324 let cert_path = dir.path().join("client.crt");
1325 let key_path = dir.path().join("client.key");
1326 std::fs::write(&cert_path, TEST_CLIENT_CERT_PEM).unwrap();
1327 std::fs::write(&key_path, TEST_CLIENT_KEY_PEM).unwrap();
1328
1329 let routes = vec![RouteConfig {
1330 prefix: "k8s".to_string(),
1331 upstream: "https://192.168.64.1:6443".to_string(),
1332 credential_key: None,
1333 inject_mode: Default::default(),
1334 inject_header: "Authorization".to_string(),
1335 credential_format: "Bearer {}".to_string(),
1336 path_pattern: None,
1337 path_replacement: None,
1338 query_param_name: None,
1339 proxy: None,
1340 env_var: None,
1341 endpoint_rules: vec![],
1342 tls_ca: None,
1343 tls_client_cert: Some(cert_path.to_str().unwrap().to_string()),
1344 tls_client_key: Some(key_path.to_str().unwrap().to_string()),
1345 oauth2: None,
1346 }];
1347
1348 let store = RouteStore::load(&routes).expect("should load mTLS route");
1349 let route = store.get("k8s").unwrap();
1350 assert!(
1351 route.tls_connector.is_some(),
1352 "connector must be built when tls_client_cert/key are set"
1353 );
1354 }
1355}