Skip to main content

nono_proxy/
credential.rs

1//! Credential loading and management for reverse proxy mode.
2//!
3//! Loads API credentials from the system keystore or 1Password at proxy startup.
4//! Credentials are stored in `Zeroizing<String>` and injected into
5//! requests via headers, URL paths, query parameters, or Basic Auth.
6//! The sandboxed agent never sees the real credentials.
7
8use crate::config::{CompiledEndpointRules, InjectMode, RouteConfig};
9use crate::error::{ProxyError, Result};
10use base64::Engine;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::debug;
14use zeroize::Zeroizing;
15
16/// A loaded credential ready for injection.
17pub struct LoadedCredential {
18    /// Injection mode
19    pub inject_mode: InjectMode,
20    /// Upstream URL (e.g., "https://api.openai.com")
21    pub upstream: String,
22    /// Raw credential value from keystore (for modes that need it directly)
23    pub raw_credential: Zeroizing<String>,
24
25    // --- Header mode ---
26    /// Header name to inject (e.g., "Authorization")
27    pub header_name: String,
28    /// Formatted header value (e.g., "Bearer sk-...")
29    pub header_value: Zeroizing<String>,
30
31    // --- URL path mode ---
32    /// Pattern to match in incoming path (with {} placeholder)
33    pub path_pattern: Option<String>,
34    /// Pattern for outgoing path (with {} placeholder)
35    pub path_replacement: Option<String>,
36
37    // --- Query param mode ---
38    /// Query parameter name
39    pub query_param_name: Option<String>,
40
41    // --- L7 endpoint filtering ---
42    /// Pre-compiled endpoint rules for method+path filtering.
43    /// Compiled once at load time to avoid per-request glob compilation.
44    pub endpoint_rules: CompiledEndpointRules,
45
46    // --- Custom CA TLS ---
47    /// Per-route TLS connector with custom CA trust, if configured.
48    /// Built once at startup from the route's `tls_ca` certificate file.
49    /// When `None`, the shared default connector (webpki roots only) is used.
50    pub tls_connector: Option<tokio_rustls::TlsConnector>,
51}
52
53/// Custom Debug impl that redacts secret values to prevent accidental leakage
54/// in logs, panic messages, or debug output.
55impl std::fmt::Debug for LoadedCredential {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("LoadedCredential")
58            .field("inject_mode", &self.inject_mode)
59            .field("upstream", &self.upstream)
60            .field("raw_credential", &"[REDACTED]")
61            .field("header_name", &self.header_name)
62            .field("header_value", &"[REDACTED]")
63            .field("path_pattern", &self.path_pattern)
64            .field("path_replacement", &self.path_replacement)
65            .field("query_param_name", &self.query_param_name)
66            .field("endpoint_rules", &self.endpoint_rules)
67            .field("has_custom_tls_ca", &self.tls_connector.is_some())
68            .finish()
69    }
70}
71
72/// Credential store for all configured routes.
73#[derive(Debug)]
74pub struct CredentialStore {
75    /// Map from route prefix to loaded credential
76    credentials: HashMap<String, LoadedCredential>,
77}
78
79impl CredentialStore {
80    /// Load credentials for all configured routes from the system keystore.
81    ///
82    /// Routes without a `credential_key` are skipped (no credential injection).
83    /// Routes whose credential is not found (e.g. unset env var) are skipped
84    /// with a warning — this allows profiles to declare optional credentials
85    /// without failing when they are unavailable.
86    ///
87    /// Returns an error only for hard failures (keystore access errors,
88    /// config parse errors, non-UTF-8 values).
89    pub fn load(routes: &[RouteConfig]) -> Result<Self> {
90        let mut credentials = HashMap::new();
91
92        for route in routes {
93            if let Some(ref key) = route.credential_key {
94                debug!(
95                    "Loading credential for route prefix: {} (mode: {:?})",
96                    route.prefix, route.inject_mode
97                );
98
99                let secret = match nono::keystore::load_secret_by_ref(KEYRING_SERVICE, key) {
100                    Ok(s) => s,
101                    Err(nono::NonoError::SecretNotFound(msg)) => {
102                        debug!(
103                            "Credential '{}' not available, skipping route: {}",
104                            route.prefix, msg
105                        );
106                        continue;
107                    }
108                    Err(e) => return Err(ProxyError::Credential(e.to_string())),
109                };
110
111                // Format header value based on mode.
112                // When inject_header is not "Authorization" (e.g., "PRIVATE-TOKEN",
113                // "X-API-Key"), the credential is injected as-is unless the user
114                // explicitly set a custom format. The default "Bearer {}" only
115                // makes sense for the Authorization header.
116                let effective_format = if route.inject_header != "Authorization"
117                    && route.credential_format == "Bearer {}"
118                {
119                    "{}".to_string()
120                } else {
121                    route.credential_format.clone()
122                };
123
124                let header_value = match route.inject_mode {
125                    InjectMode::Header => Zeroizing::new(effective_format.replace("{}", &secret)),
126                    InjectMode::BasicAuth => {
127                        // Base64 encode the credential for Basic auth
128                        let encoded =
129                            base64::engine::general_purpose::STANDARD.encode(secret.as_bytes());
130                        Zeroizing::new(format!("Basic {}", encoded))
131                    }
132                    // For url_path and query_param, header_value is not used
133                    InjectMode::UrlPath | InjectMode::QueryParam => Zeroizing::new(String::new()),
134                };
135
136                // Build per-route TLS connector if a custom CA is configured
137                let tls_connector = match route.tls_ca {
138                    Some(ref ca_path) => {
139                        debug!(
140                            "Building TLS connector with custom CA for route '{}': {}",
141                            route.prefix, ca_path
142                        );
143                        Some(build_tls_connector_with_ca(ca_path)?)
144                    }
145                    None => None,
146                };
147
148                credentials.insert(
149                    route.prefix.clone(),
150                    LoadedCredential {
151                        inject_mode: route.inject_mode.clone(),
152                        upstream: route.upstream.clone(),
153                        raw_credential: secret,
154                        header_name: route.inject_header.clone(),
155                        header_value,
156                        path_pattern: route.path_pattern.clone(),
157                        path_replacement: route.path_replacement.clone(),
158                        query_param_name: route.query_param_name.clone(),
159                        endpoint_rules: CompiledEndpointRules::compile(&route.endpoint_rules)
160                            .map_err(|e| {
161                                ProxyError::Credential(format!("route '{}': {}", route.prefix, e))
162                            })?,
163                        tls_connector,
164                    },
165                );
166            }
167        }
168
169        Ok(Self { credentials })
170    }
171
172    /// Create an empty credential store (no credential injection).
173    #[must_use]
174    pub fn empty() -> Self {
175        Self {
176            credentials: HashMap::new(),
177        }
178    }
179
180    /// Get a credential for a route prefix, if configured.
181    #[must_use]
182    pub fn get(&self, prefix: &str) -> Option<&LoadedCredential> {
183        self.credentials.get(prefix)
184    }
185
186    /// Check if any credentials are loaded.
187    #[must_use]
188    pub fn is_empty(&self) -> bool {
189        self.credentials.is_empty()
190    }
191
192    /// Number of loaded credentials.
193    #[must_use]
194    pub fn len(&self) -> usize {
195        self.credentials.len()
196    }
197
198    /// Returns the set of route prefixes that have loaded credentials.
199    #[must_use]
200    pub fn loaded_prefixes(&self) -> std::collections::HashSet<String> {
201        self.credentials.keys().cloned().collect()
202    }
203
204    /// Check whether `host_port` (e.g. `"gitlab.example.com:443"`) matches
205    /// any credential upstream. Used to block CONNECT tunnels that would
206    /// bypass L7 path filtering.
207    #[must_use]
208    pub fn is_credential_upstream(&self, host_port: &str) -> bool {
209        let normalised = host_port.to_lowercase();
210        self.credentials.values().any(|cred| {
211            extract_host_port(&cred.upstream)
212                .map(|hp| hp == normalised)
213                .unwrap_or(false)
214        })
215    }
216
217    /// Return the set of normalised `host:port` strings for all credential
218    /// upstreams. Used to compute smart `NO_PROXY` — hosts in this set must
219    /// NOT be bypassed because they need reverse proxy credential injection.
220    #[must_use]
221    pub fn credential_upstream_hosts(&self) -> std::collections::HashSet<String> {
222        self.credentials
223            .values()
224            .filter_map(|cred| extract_host_port(&cred.upstream))
225            .collect()
226    }
227}
228
229/// Extract and normalise `host:port` from a URL string.
230///
231/// Defaults to port 443 for `https://` and 80 for `http://` when no
232/// explicit port is present. Returns `None` if the URL cannot be parsed.
233fn extract_host_port(url: &str) -> Option<String> {
234    let parsed = url::Url::parse(url).ok()?;
235    let host = parsed.host_str()?;
236    let default_port = match parsed.scheme() {
237        "https" => 443,
238        "http" => 80,
239        _ => return None,
240    };
241    let port = parsed.port().unwrap_or(default_port);
242    Some(format!("{}:{}", host.to_lowercase(), port))
243}
244
245/// Build a `TlsConnector` that trusts the system roots plus a custom CA certificate.
246///
247/// The CA file must be PEM-encoded and contain at least one certificate.
248/// Returns an error if the file cannot be read, contains no valid certificates,
249/// or the TLS configuration fails.
250fn build_tls_connector_with_ca(ca_path: &str) -> Result<tokio_rustls::TlsConnector> {
251    let ca_path = std::path::Path::new(ca_path);
252
253    let ca_pem = Zeroizing::new(std::fs::read(ca_path).map_err(|e| {
254        if e.kind() == std::io::ErrorKind::NotFound {
255            ProxyError::Config(format!(
256                "CA certificate file not found: '{}'",
257                ca_path.display()
258            ))
259        } else {
260            ProxyError::Config(format!(
261                "failed to read CA certificate '{}': {}",
262                ca_path.display(),
263                e
264            ))
265        }
266    })?);
267
268    let mut root_store = rustls::RootCertStore::empty();
269
270    // Add system roots first
271    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
272
273    // Parse and add custom CA certificates from PEM file
274    let certs: Vec<_> = rustls_pemfile::certs(&mut ca_pem.as_slice())
275        .collect::<std::result::Result<Vec<_>, _>>()
276        .map_err(|e| {
277            ProxyError::Config(format!(
278                "failed to parse CA certificate '{}': {}",
279                ca_path.display(),
280                e
281            ))
282        })?;
283
284    if certs.is_empty() {
285        return Err(ProxyError::Config(format!(
286            "CA certificate file '{}' contains no valid PEM certificates",
287            ca_path.display()
288        )));
289    }
290
291    for cert in certs {
292        root_store.add(cert).map_err(|e| {
293            ProxyError::Config(format!(
294                "invalid CA certificate in '{}': {}",
295                ca_path.display(),
296                e
297            ))
298        })?;
299    }
300
301    let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
302        rustls::crypto::ring::default_provider(),
303    ))
304    .with_safe_default_protocol_versions()
305    .map_err(|e| ProxyError::Config(format!("TLS config error: {}", e)))?
306    .with_root_certificates(root_store)
307    .with_no_client_auth();
308
309    Ok(tokio_rustls::TlsConnector::from(Arc::new(tls_config)))
310}
311
312/// The keyring service name used by nono for all credentials.
313/// Uses the same constant as `nono::keystore::DEFAULT_SERVICE` to ensure consistency.
314const KEYRING_SERVICE: &str = nono::keystore::DEFAULT_SERVICE;
315
#[cfg(test)]
// Tests may unwrap freely: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a header-mode credential with no path/query injection, no
    /// endpoint rules, and the default TLS connector.
    fn header_credential(
        upstream: &str,
        header_name: &str,
        raw: &str,
        header_value: &str,
    ) -> LoadedCredential {
        LoadedCredential {
            inject_mode: InjectMode::Header,
            upstream: upstream.to_string(),
            raw_credential: Zeroizing::new(raw.to_string()),
            header_name: header_name.to_string(),
            header_value: Zeroizing::new(header_value.to_string()),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
            tls_connector: None,
        }
    }

    /// Store holding a single GitLab credential under the prefix `gitlab`.
    fn gitlab_store() -> CredentialStore {
        let mut credentials = HashMap::new();
        credentials.insert(
            "gitlab".to_string(),
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        );
        CredentialStore { credentials }
    }

    #[test]
    fn test_empty_credential_store() {
        let store = CredentialStore::empty();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get("/openai").is_none());
    }

    #[test]
    fn test_loaded_credential_debug_redacts_secrets() {
        // Security: Debug output must NEVER contain real secret values.
        // This prevents accidental leakage in logs, panic messages, or
        // tracing output at debug level.
        let cred = header_credential(
            "https://api.openai.com",
            "Authorization",
            "sk-secret-12345",
            "Bearer sk-secret-12345",
        );

        let debug_output = format!("{:?}", cred);

        // Must contain REDACTED markers
        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output should contain [REDACTED], got: {}",
            debug_output
        );
        // Must NOT contain the actual secret
        assert!(
            !debug_output.contains("sk-secret-12345"),
            "Debug output must not contain the real secret"
        );
        assert!(
            !debug_output.contains("Bearer sk-secret"),
            "Debug output must not contain the formatted secret"
        );
        // Non-secret fields should still be visible
        assert!(debug_output.contains("api.openai.com"));
        assert!(debug_output.contains("Authorization"));
    }

    #[test]
    fn test_extract_host_port_https_no_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com"),
            Some("api.openai.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_https_with_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com:8443"),
            Some("api.openai.com:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_explicit_port() {
        assert_eq!(
            extract_host_port("http://internal:4096"),
            Some("internal:4096".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_default_port() {
        assert_eq!(
            extract_host_port("http://internal-service"),
            Some("internal-service:80".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_normalises_case() {
        assert_eq!(
            extract_host_port("https://GitLab-PRD.Home.Example.COM"),
            Some("gitlab-prd.home.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_with_path() {
        assert_eq!(
            extract_host_port("https://api.example.com/v1/endpoint"),
            Some("api.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_no_scheme() {
        assert_eq!(extract_host_port("api.openai.com"), None);
    }

    #[test]
    fn test_is_credential_upstream() {
        let store = gitlab_store();

        assert!(store.is_credential_upstream("gitlab.example.com:443"));
        assert!(!store.is_credential_upstream("unrelated.example.com:443"));
    }

    #[test]
    fn test_is_credential_upstream_empty_store() {
        let store = CredentialStore::empty();
        assert!(!store.is_credential_upstream("anything:443"));
    }

    #[test]
    fn test_credential_upstream_hosts() {
        let store = gitlab_store();

        let hosts = store.credential_upstream_hosts();
        assert!(hosts.contains("gitlab.example.com:443"));
        assert_eq!(hosts.len(), 1);
    }

    #[test]
    fn test_load_no_credential_routes() {
        let routes = vec![RouteConfig {
            prefix: "/test".to_string(),
            upstream: "https://example.com".to_string(),
            credential_key: None,
            inject_mode: InjectMode::Header,
            inject_header: "Authorization".to_string(),
            credential_format: "Bearer {}".to_string(),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            env_var: None,
            endpoint_rules: vec![],
            tls_ca: None,
        }];
        let store = CredentialStore::load(&routes);
        assert!(store.is_ok());
        let store = store.unwrap_or_else(|_| CredentialStore::empty());
        assert!(store.is_empty());
    }

    /// PEM-shaped dummy CA certificate (subject `CN=nono-test-ca`). Its public
    /// key, key identifiers, and signature bytes are zero-filled placeholders
    /// (the long runs of `A` in the base64), so it is not a real, verifiable
    /// certificate. Tests use it only to exercise the file-reading and
    /// error-mapping plumbing.
    const TEST_CA_PEM: &str = "\
-----BEGIN CERTIFICATE-----
MIIBnjCCAUWgAwIBAgIUT0bpOJJvHdOdZt+gW1stR8VBgXowCgYIKoZIzj0EAwIw
FzEVMBMGA1UEAwwMbm9uby10ZXN0LWNhMCAXDTI1MDEwMTAwMDAwMFoYDzIxMjQx
MjA3MDAwMDAwWjAXMRUwEwYDVQQDDAxub25vLXRlc3QtY2EwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAAR8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAo1MwUTAdBgNVHQ4EFgQUAAAAAAAAAAAAAAAAAAAAAAAA
AAAAMB8GA1UdIwQYMBaAFAAAAAAAAAAAAAAAAAAAAAAAAAAAADAPBgNVHRMBAf8E
BTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END CERTIFICATE-----";

    #[test]
    fn test_build_tls_connector_with_dummy_ca() {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("ca.pem");
        std::fs::write(&ca_path, TEST_CA_PEM).unwrap();

        // The fixture's placeholder bytes may be rejected during PEM parsing
        // or when added to the root store. Either way the function must return
        // a CA-related `Config` error, never panic or yield another variant.
        // The file was just written, so a file-I/O error would be a bug.
        let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
        match result {
            Ok(connector) => {
                // Connector was built — custom CA was accepted
                drop(connector);
            }
            Err(ProxyError::Config(msg)) => {
                assert!(msg.contains("CA certificate"), "unexpected error: {}", msg);
                assert!(
                    !msg.contains("not found") && !msg.contains("failed to read"),
                    "unexpected file I/O error: {}",
                    msg
                );
            }
            Err(e) => panic!("unexpected error type: {}", e),
        }
    }

    #[test]
    fn test_build_tls_connector_missing_file() {
        let result = build_tls_connector_with_ca("/nonexistent/path/ca.pem");
        let err = result
            .err()
            .expect("should fail for missing file")
            .to_string();
        assert!(
            err.contains("CA certificate file not found"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_build_tls_connector_non_pem_content() {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("not-a-cert.pem");
        std::fs::write(&ca_path, "not a certificate\n").unwrap();

        let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
        let err = result
            .err()
            .expect("should fail for invalid PEM")
            .to_string();
        assert!(
            err.contains("no valid PEM certificates"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_build_tls_connector_empty_file() {
        let dir = tempfile::tempdir().unwrap();
        let ca_path = dir.path().join("empty.pem");
        std::fs::write(&ca_path, "").unwrap();

        let result = build_tls_connector_with_ca(ca_path.to_str().unwrap());
        let err = result
            .err()
            .expect("should fail for empty file")
            .to_string();
        assert!(
            err.contains("no valid PEM certificates"),
            "unexpected error: {}",
            err
        );
    }
}