1use crate::config::{CompiledEndpointRules, InjectMode, RouteConfig};
9use crate::error::{ProxyError, Result};
10use base64::Engine;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::debug;
14use zeroize::Zeroizing;
15
16pub struct LoadedCredential {
18 pub inject_mode: InjectMode,
20 pub upstream: String,
22 pub raw_credential: Zeroizing<String>,
24
25 pub header_name: String,
28 pub header_value: Zeroizing<String>,
30
31 pub path_pattern: Option<String>,
34 pub path_replacement: Option<String>,
36
37 pub query_param_name: Option<String>,
40
41 pub endpoint_rules: CompiledEndpointRules,
45
46 pub tls_connector: Option<tokio_rustls::TlsConnector>,
51}
52
53impl std::fmt::Debug for LoadedCredential {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("LoadedCredential")
58 .field("inject_mode", &self.inject_mode)
59 .field("upstream", &self.upstream)
60 .field("raw_credential", &"[REDACTED]")
61 .field("header_name", &self.header_name)
62 .field("header_value", &"[REDACTED]")
63 .field("path_pattern", &self.path_pattern)
64 .field("path_replacement", &self.path_replacement)
65 .field("query_param_name", &self.query_param_name)
66 .field("endpoint_rules", &self.endpoint_rules)
67 .field("has_custom_tls_ca", &self.tls_connector.is_some())
68 .finish()
69 }
70}
71
72#[derive(Debug)]
74pub struct CredentialStore {
75 credentials: HashMap<String, LoadedCredential>,
77}
78
79impl CredentialStore {
80 pub fn load(routes: &[RouteConfig]) -> Result<Self> {
90 let mut credentials = HashMap::new();
91
92 for route in routes {
93 if let Some(ref key) = route.credential_key {
94 debug!(
95 "Loading credential for route prefix: {} (mode: {:?})",
96 route.prefix, route.inject_mode
97 );
98
99 let secret = match nono::keystore::load_secret_by_ref(KEYRING_SERVICE, key) {
100 Ok(s) => s,
101 Err(nono::NonoError::SecretNotFound(msg)) => {
102 debug!(
103 "Credential '{}' not available, skipping route: {}",
104 route.prefix, msg
105 );
106 continue;
107 }
108 Err(e) => return Err(ProxyError::Credential(e.to_string())),
109 };
110
111 let effective_format = if route.inject_header != "Authorization"
117 && route.credential_format == "Bearer {}"
118 {
119 "{}".to_string()
120 } else {
121 route.credential_format.clone()
122 };
123
124 let header_value = match route.inject_mode {
125 InjectMode::Header => Zeroizing::new(effective_format.replace("{}", &secret)),
126 InjectMode::BasicAuth => {
127 let encoded =
129 base64::engine::general_purpose::STANDARD.encode(secret.as_bytes());
130 Zeroizing::new(format!("Basic {}", encoded))
131 }
132 InjectMode::UrlPath | InjectMode::QueryParam => Zeroizing::new(String::new()),
134 };
135
136 let tls_connector = match route.tls_ca {
138 Some(ref ca_path) => {
139 debug!(
140 "Building TLS connector with custom CA for route '{}': {}",
141 route.prefix, ca_path
142 );
143 Some(build_tls_connector_with_ca(ca_path)?)
144 }
145 None => None,
146 };
147
148 credentials.insert(
149 route.prefix.clone(),
150 LoadedCredential {
151 inject_mode: route.inject_mode.clone(),
152 upstream: route.upstream.clone(),
153 raw_credential: secret,
154 header_name: route.inject_header.clone(),
155 header_value,
156 path_pattern: route.path_pattern.clone(),
157 path_replacement: route.path_replacement.clone(),
158 query_param_name: route.query_param_name.clone(),
159 endpoint_rules: CompiledEndpointRules::compile(&route.endpoint_rules)
160 .map_err(|e| {
161 ProxyError::Credential(format!("route '{}': {}", route.prefix, e))
162 })?,
163 tls_connector,
164 },
165 );
166 }
167 }
168
169 Ok(Self { credentials })
170 }
171
172 #[must_use]
174 pub fn empty() -> Self {
175 Self {
176 credentials: HashMap::new(),
177 }
178 }
179
180 #[must_use]
182 pub fn get(&self, prefix: &str) -> Option<&LoadedCredential> {
183 self.credentials.get(prefix)
184 }
185
186 #[must_use]
188 pub fn is_empty(&self) -> bool {
189 self.credentials.is_empty()
190 }
191
192 #[must_use]
194 pub fn len(&self) -> usize {
195 self.credentials.len()
196 }
197
198 #[must_use]
200 pub fn loaded_prefixes(&self) -> std::collections::HashSet<String> {
201 self.credentials.keys().cloned().collect()
202 }
203
204 #[must_use]
208 pub fn is_credential_upstream(&self, host_port: &str) -> bool {
209 let normalised = host_port.to_lowercase();
210 self.credentials.values().any(|cred| {
211 extract_host_port(&cred.upstream)
212 .map(|hp| hp == normalised)
213 .unwrap_or(false)
214 })
215 }
216
217 #[must_use]
221 pub fn credential_upstream_hosts(&self) -> std::collections::HashSet<String> {
222 self.credentials
223 .values()
224 .filter_map(|cred| extract_host_port(&cred.upstream))
225 .collect()
226 }
227}
228
229fn extract_host_port(url: &str) -> Option<String> {
234 let parsed = url::Url::parse(url).ok()?;
235 let host = parsed.host_str()?;
236 let default_port = match parsed.scheme() {
237 "https" => 443,
238 "http" => 80,
239 _ => return None,
240 };
241 let port = parsed.port().unwrap_or(default_port);
242 Some(format!("{}:{}", host.to_lowercase(), port))
243}
244
245fn build_tls_connector_with_ca(ca_path: &str) -> Result<tokio_rustls::TlsConnector> {
251 let ca_path = std::path::Path::new(ca_path);
252
253 let ca_pem = Zeroizing::new(std::fs::read(ca_path).map_err(|e| {
254 if e.kind() == std::io::ErrorKind::NotFound {
255 ProxyError::Config(format!(
256 "CA certificate file not found: '{}'",
257 ca_path.display()
258 ))
259 } else {
260 ProxyError::Config(format!(
261 "failed to read CA certificate '{}': {}",
262 ca_path.display(),
263 e
264 ))
265 }
266 })?);
267
268 let mut root_store = rustls::RootCertStore::empty();
269
270 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
272
273 let certs: Vec<_> = rustls_pemfile::certs(&mut ca_pem.as_slice())
275 .collect::<std::result::Result<Vec<_>, _>>()
276 .map_err(|e| {
277 ProxyError::Config(format!(
278 "failed to parse CA certificate '{}': {}",
279 ca_path.display(),
280 e
281 ))
282 })?;
283
284 if certs.is_empty() {
285 return Err(ProxyError::Config(format!(
286 "CA certificate file '{}' contains no valid PEM certificates",
287 ca_path.display()
288 )));
289 }
290
291 for cert in certs {
292 root_store.add(cert).map_err(|e| {
293 ProxyError::Config(format!(
294 "invalid CA certificate in '{}': {}",
295 ca_path.display(),
296 e
297 ))
298 })?;
299 }
300
301 let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
302 rustls::crypto::ring::default_provider(),
303 ))
304 .with_safe_default_protocol_versions()
305 .map_err(|e| ProxyError::Config(format!("TLS config error: {}", e)))?
306 .with_root_certificates(root_store)
307 .with_no_client_auth();
308
309 Ok(tokio_rustls::TlsConnector::from(Arc::new(tls_config)))
310}
311
312const KEYRING_SERVICE: &str = nono::keystore::DEFAULT_SERVICE;
315
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a header-mode credential with no path/query settings, no endpoint
    /// rules and no custom TLS, so stores can be built without the keystore.
    fn header_credential(
        upstream: &str,
        header_name: &str,
        raw: &str,
        header_value: &str,
    ) -> LoadedCredential {
        LoadedCredential {
            inject_mode: InjectMode::Header,
            upstream: upstream.to_string(),
            raw_credential: Zeroizing::new(raw.to_string()),
            header_name: header_name.to_string(),
            header_value: Zeroizing::new(header_value.to_string()),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
            tls_connector: None,
        }
    }

    /// Wraps a single credential in a store under `prefix`.
    fn store_with(prefix: &str, cred: LoadedCredential) -> CredentialStore {
        let mut credentials = HashMap::new();
        credentials.insert(prefix.to_string(), cred);
        CredentialStore { credentials }
    }

    /// Writes `contents` to a temporary `ca.pem` and runs the CA loader on it.
    fn build_from_pem(contents: &str) -> Result<tokio_rustls::TlsConnector> {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("ca.pem");
        std::fs::write(&ca_path, contents).unwrap();
        build_tls_connector_with_ca(ca_path.to_str().unwrap())
    }

    #[test]
    fn test_empty_credential_store() {
        let store = CredentialStore::empty();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get("/openai").is_none());
    }

    #[test]
    fn test_loaded_credential_debug_redacts_secrets() {
        let cred = header_credential(
            "https://api.openai.com",
            "Authorization",
            "sk-secret-12345",
            "Bearer sk-secret-12345",
        );

        let debug_output = format!("{:?}", cred);

        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output should contain [REDACTED], got: {}",
            debug_output
        );
        assert!(
            !debug_output.contains("sk-secret-12345"),
            "Debug output must not contain the real secret"
        );
        assert!(
            !debug_output.contains("Bearer sk-secret"),
            "Debug output must not contain the formatted secret"
        );
        // Non-secret fields remain visible for diagnostics.
        assert!(debug_output.contains("api.openai.com"));
        assert!(debug_output.contains("Authorization"));
    }

    #[test]
    fn test_extract_host_port_https_no_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com"),
            Some("api.openai.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_https_with_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com:8443"),
            Some("api.openai.com:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_https_explicit_default_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com:443"),
            Some("api.openai.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_with_port() {
        assert_eq!(
            extract_host_port("http://internal:4096"),
            Some("internal:4096".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_default_port() {
        assert_eq!(
            extract_host_port("http://internal-service"),
            Some("internal-service:80".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_normalises_case() {
        assert_eq!(
            extract_host_port("https://GitLab-PRD.Home.Example.COM"),
            Some("gitlab-prd.home.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_with_path() {
        assert_eq!(
            extract_host_port("https://api.example.com/v1/endpoint"),
            Some("api.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_ipv6() {
        assert_eq!(
            extract_host_port("https://[::1]:8443"),
            Some("[::1]:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_no_scheme() {
        assert_eq!(extract_host_port("api.openai.com"), None);
    }

    #[test]
    fn test_extract_host_port_unsupported_scheme() {
        assert_eq!(extract_host_port("ftp://example.com"), None);
    }

    #[test]
    fn test_is_credential_upstream() {
        let store = store_with(
            "gitlab",
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        );

        assert!(store.is_credential_upstream("gitlab.example.com:443"));
        assert!(store.is_credential_upstream("GitLab.Example.com:443"));
        assert!(!store.is_credential_upstream("gitlab.example.com:80"));
        assert!(!store.is_credential_upstream("unrelated.example.com:443"));
    }

    #[test]
    fn test_is_credential_upstream_empty_store() {
        let store = CredentialStore::empty();
        assert!(!store.is_credential_upstream("anything:443"));
    }

    #[test]
    fn test_credential_upstream_hosts() {
        let store = store_with(
            "gitlab",
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        );

        let hosts = store.credential_upstream_hosts();
        assert!(hosts.contains("gitlab.example.com:443"));
        assert_eq!(hosts.len(), 1);
    }

    #[test]
    fn test_load_no_credential_routes() {
        let routes = vec![RouteConfig {
            prefix: "/test".to_string(),
            upstream: "https://example.com".to_string(),
            credential_key: None,
            inject_mode: InjectMode::Header,
            inject_header: "Authorization".to_string(),
            credential_format: "Bearer {}".to_string(),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            env_var: None,
            endpoint_rules: vec![],
            tls_ca: None,
        }];
        // Routes without a credential_key never touch the keystore.
        let store = CredentialStore::load(&routes)
            .unwrap_or_else(|e| panic!("routes without a credential_key must load: {}", e));
        assert!(store.is_empty());
    }

    /// Placeholder certificate: structurally PEM/DER, but its key and signature
    /// bytes appear to be zero-filled, so it is not a usable trust anchor.
    const TEST_CA_PEM: &str = "\
-----BEGIN CERTIFICATE-----
MIIBnjCCAUWgAwIBAgIUT0bpOJJvHdOdZt+gW1stR8VBgXowCgYIKoZIzj0EAwIw
FzEVMBMGA1UEAwwMbm9uby10ZXN0LWNhMCAXDTI1MDEwMTAwMDAwMFoYDzIxMjQx
MjA3MDAwMDAwWjAXMRUwEwYDVQQDDAxub25vLXRlc3QtY2EwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAAR8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAo1MwUTAdBgNVHQ4EFgQUAAAAAAAAAAAAAAAAAAAAAAAA
AAAAMB8GA1UdIwQYMBaAFAAAAAAAAAAAAAAAAAAAAAAAAAAAADAPBgNVHRMBAf8E
BTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----";

    #[test]
    fn test_build_tls_connector_with_valid_ca() {
        // Depending on how strictly the placeholder is parsed, the loader may
        // accept it or reject it with a CA-certificate config error. Either is
        // acceptable; any other error variant is not.
        match build_from_pem(TEST_CA_PEM) {
            Ok(_connector) => {}
            Err(ProxyError::Config(msg)) => {
                assert!(msg.contains("CA certificate"), "unexpected error: {}", msg);
            }
            Err(e) => panic!("unexpected error type: {}", e),
        }
    }

    #[test]
    fn test_build_tls_connector_missing_file() {
        let result = build_tls_connector_with_ca("/nonexistent/path/ca.pem");
        let err = result
            .err()
            .expect("should fail for missing file")
            .to_string();
        assert!(
            err.contains("CA certificate file not found"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_build_tls_connector_empty_pem() {
        let err = build_from_pem("not a certificate\n")
            .err()
            .expect("should fail for invalid PEM")
            .to_string();
        assert!(
            err.contains("no valid PEM certificates"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_build_tls_connector_empty_file() {
        let err = build_from_pem("")
            .err()
            .expect("should fail for empty file")
            .to_string();
        assert!(
            err.contains("no valid PEM certificates"),
            "unexpected error: {}",
            err
        );
    }
}