1use crate::config::{InjectMode, RouteConfig};
13use crate::error::{ProxyError, Result};
14use crate::oauth2::{OAuth2ExchangeConfig, TokenCache};
15use base64::Engine;
16use std::collections::HashMap;
17use tokio_rustls::TlsConnector;
18use tracing::{debug, warn};
19use zeroize::Zeroizing;
20
21pub struct LoadedCredential {
27 pub inject_mode: InjectMode,
29 pub proxy_inject_mode: InjectMode,
31 pub raw_credential: Zeroizing<String>,
33
34 pub header_name: String,
37 pub proxy_header_name: String,
39 pub header_value: Zeroizing<String>,
41
42 pub path_pattern: Option<String>,
45 pub proxy_path_pattern: Option<String>,
47 pub path_replacement: Option<String>,
49
50 pub query_param_name: Option<String>,
53 pub proxy_query_param_name: Option<String>,
55}
56
57impl std::fmt::Debug for LoadedCredential {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("LoadedCredential")
62 .field("inject_mode", &self.inject_mode)
63 .field("proxy_inject_mode", &self.proxy_inject_mode)
64 .field("raw_credential", &"[REDACTED]")
65 .field("header_name", &self.header_name)
66 .field("proxy_header_name", &self.proxy_header_name)
67 .field("header_value", &"[REDACTED]")
68 .field("path_pattern", &self.path_pattern)
69 .field("proxy_path_pattern", &self.proxy_path_pattern)
70 .field("path_replacement", &self.path_replacement)
71 .field("query_param_name", &self.query_param_name)
72 .field("proxy_query_param_name", &self.proxy_query_param_name)
73 .finish()
74 }
75}
76
/// A route authenticated through OAuth2 rather than a static credential.
#[derive(Debug)]
pub struct OAuth2Route {
    /// Token cache built from the route's OAuth2 token URL, client id/secret and
    /// scope.
    pub cache: TokenCache,
    /// Upstream URL, copied from the route configuration's `upstream`.
    pub upstream: String,
}
85
/// All credentials loaded for the configured routes, looked up by route prefix.
///
/// The derived `Debug` relies on `LoadedCredential`'s hand-written `Debug`,
/// which redacts secrets.
#[derive(Debug)]
pub struct CredentialStore {
    /// Static credentials, keyed by route prefix with leading and trailing `/`
    /// trimmed.
    credentials: HashMap<String, LoadedCredential>,
    /// OAuth2-backed routes, keyed by route prefix.
    oauth2_routes: HashMap<String, OAuth2Route>,
}
94
95impl CredentialStore {
96 pub fn load(routes: &[RouteConfig], tls_connector: &TlsConnector) -> Result<Self> {
111 let mut credentials = HashMap::new();
112 let mut oauth2_routes = HashMap::new();
113
114 for route in routes {
115 let normalized_prefix = route.prefix.trim_matches('/').to_string();
119 if let Some(ref key) = route.credential_key {
120 debug!(
121 "Loading credential for route prefix: {} (mode: {:?})",
122 normalized_prefix, route.inject_mode
123 );
124
125 let secret = match nono::keystore::load_secret_by_ref(KEYRING_SERVICE, key) {
126 Ok(s) => s,
127 Err(nono::NonoError::SecretNotFound(_)) => {
128 let hint = if !key.contains("://") && cfg!(target_os = "macos") {
129 format!(
130 " To add it to the macOS keychain: security add-generic-password -s \"nono\" -a \"{}\" -w",
131 key
132 )
133 } else {
134 String::new()
135 };
136 warn!(
137 "Credential '{}' not found for route '{}' — requests will proceed without credential injection.{}",
138 key, normalized_prefix, hint
139 );
140 continue;
141 }
142 Err(e) => return Err(ProxyError::Credential(e.to_string())),
143 };
144
145 let effective_format = if route.inject_header != "Authorization"
151 && route.credential_format == "Bearer {}"
152 {
153 "{}".to_string()
154 } else {
155 route.credential_format.clone()
156 };
157
158 let header_value = match route.inject_mode {
159 InjectMode::Header => Zeroizing::new(effective_format.replace("{}", &secret)),
160 InjectMode::BasicAuth => {
161 let encoded =
163 base64::engine::general_purpose::STANDARD.encode(secret.as_bytes());
164 Zeroizing::new(format!("Basic {}", encoded))
165 }
166 InjectMode::UrlPath | InjectMode::QueryParam => Zeroizing::new(String::new()),
168 };
169
170 credentials.insert(
171 normalized_prefix.clone(),
172 LoadedCredential {
173 inject_mode: route.inject_mode.clone(),
174 proxy_inject_mode: route
175 .proxy
176 .as_ref()
177 .and_then(|p| p.inject_mode.clone())
178 .unwrap_or_else(|| route.inject_mode.clone()),
179 raw_credential: secret,
180 header_name: route.inject_header.clone(),
181 proxy_header_name: route
182 .proxy
183 .as_ref()
184 .and_then(|p| p.inject_header.clone())
185 .unwrap_or_else(|| route.inject_header.clone()),
186 header_value,
187 path_pattern: route.path_pattern.clone(),
188 proxy_path_pattern: route
189 .proxy
190 .as_ref()
191 .and_then(|p| p.path_pattern.clone())
192 .or_else(|| route.path_pattern.clone()),
193 path_replacement: route.path_replacement.clone(),
194 query_param_name: route.query_param_name.clone(),
195 proxy_query_param_name: route
196 .proxy
197 .as_ref()
198 .and_then(|p| p.query_param_name.clone())
199 .or_else(|| route.query_param_name.clone()),
200 },
201 );
202 continue;
203 }
204
205 if let Some(ref oauth2) = route.oauth2 {
207 debug!(
208 "Loading OAuth2 credential for route prefix: {}",
209 route.prefix
210 );
211
212 let client_id =
213 match nono::keystore::load_secret_by_ref(KEYRING_SERVICE, &oauth2.client_id) {
214 Ok(s) => s,
215 Err(nono::NonoError::SecretNotFound(msg)) => {
216 debug!(
217 "OAuth2 client_id not available for route '{}': {}",
218 route.prefix, msg
219 );
220 continue;
221 }
222 Err(e) => return Err(ProxyError::Credential(e.to_string())),
223 };
224
225 let client_secret = match nono::keystore::load_secret_by_ref(
226 KEYRING_SERVICE,
227 &oauth2.client_secret,
228 ) {
229 Ok(s) => s,
230 Err(nono::NonoError::SecretNotFound(msg)) => {
231 debug!(
232 "OAuth2 client_secret not available for route '{}': {}",
233 route.prefix, msg
234 );
235 continue;
236 }
237 Err(e) => return Err(ProxyError::Credential(e.to_string())),
238 };
239
240 let config = OAuth2ExchangeConfig {
241 token_url: oauth2.token_url.clone(),
242 client_id,
243 client_secret,
244 scope: oauth2.scope.clone(),
245 };
246
247 match TokenCache::new(config, tls_connector.clone()) {
248 Ok(cache) => {
249 oauth2_routes.insert(
250 route.prefix.clone(),
251 OAuth2Route {
252 cache,
253 upstream: route.upstream.clone(),
254 },
255 );
256 }
257 Err(e) => {
258 debug!(
259 "OAuth2 token exchange failed for route '{}': {}, skipping",
260 route.prefix, e
261 );
262 continue;
263 }
264 }
265 }
266 }
267
268 Ok(Self {
269 credentials,
270 oauth2_routes,
271 })
272 }
273
274 #[must_use]
276 pub fn empty() -> Self {
277 Self {
278 credentials: HashMap::new(),
279 oauth2_routes: HashMap::new(),
280 }
281 }
282
283 #[must_use]
285 pub fn get(&self, prefix: &str) -> Option<&LoadedCredential> {
286 self.credentials.get(prefix)
287 }
288
289 #[must_use]
291 pub fn get_oauth2(&self, prefix: &str) -> Option<&OAuth2Route> {
292 self.oauth2_routes.get(prefix)
293 }
294
295 #[must_use]
297 pub fn is_empty(&self) -> bool {
298 self.credentials.is_empty() && self.oauth2_routes.is_empty()
299 }
300
301 #[must_use]
303 pub fn len(&self) -> usize {
304 self.credentials.len() + self.oauth2_routes.len()
305 }
306
307 #[must_use]
310 pub fn loaded_prefixes(&self) -> std::collections::HashSet<String> {
311 self.credentials
312 .keys()
313 .chain(self.oauth2_routes.keys())
314 .cloned()
315 .collect()
316 }
317}
318
/// Keystore service name passed to every `load_secret_by_ref` call in this
/// module; it is the value of `nono::keystore::DEFAULT_SERVICE`.
const KEYRING_SERVICE: &str = nono::keystore::DEFAULT_SERVICE;
322
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex, MutexGuard};
    use std::time::Duration;

    /// Serializes every test that mutates process-wide environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, recovering the guard if a previous holder panicked.
    ///
    /// Without this, one failing env-based test poisons the mutex and every later
    /// env-based test fails with an unrelated `PoisonError`.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets environment variables for the guard's lifetime and restores the
    /// previous values (or removes the variables) on drop.
    struct EnvVarGuard {
        original: Vec<(&'static str, Option<String>)>,
    }

    #[allow(clippy::disallowed_methods)]
    impl EnvVarGuard {
        fn set_all(vars: &[(&'static str, &str)]) -> Self {
            let original = vars
                .iter()
                .map(|(key, _)| (*key, std::env::var(key).ok()))
                .collect::<Vec<_>>();

            for (key, value) in vars {
                std::env::set_var(key, value);
            }

            Self { original }
        }
    }

    #[allow(clippy::disallowed_methods)]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for (key, value) in self.original.iter().rev() {
                match value {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    fn test_tls_connector() -> TlsConnector {
        let mut root_store = rustls::RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        let tls_config = rustls::ClientConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .unwrap()
        .with_root_certificates(root_store)
        .with_no_client_auth();
        TlsConnector::from(Arc::new(tls_config))
    }

    /// Header-mode route with no credential source. Tests override the fields
    /// they care about with struct-update syntax.
    fn base_route(prefix: &str) -> RouteConfig {
        RouteConfig {
            prefix: prefix.to_string(),
            upstream: "https://example.com".to_string(),
            credential_key: None,
            inject_mode: InjectMode::Header,
            inject_header: "Authorization".to_string(),
            credential_format: "Bearer {}".to_string(),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            proxy: None,
            env_var: None,
            endpoint_rules: vec![],
            tls_ca: None,
            tls_client_cert: None,
            tls_client_key: None,
            oauth2: None,
        }
    }

    /// Load `routes` while `var` is set to `value`, holding the env lock throughout.
    fn load_with_env(var: &'static str, value: &str, routes: &[RouteConfig]) -> CredentialStore {
        let _lock = lock_env();
        let _env = EnvVarGuard::set_all(&[(var, value)]);
        CredentialStore::load(routes, &test_tls_connector()).unwrap()
    }

    /// Store holding a single pre-seeded OAuth2 route under `"my-api"`.
    fn store_with_oauth2_route() -> CredentialStore {
        let cache = make_test_token_cache("test-token", Duration::from_secs(3600));
        let mut oauth2_routes = HashMap::new();
        oauth2_routes.insert(
            "my-api".to_string(),
            OAuth2Route {
                cache,
                upstream: "https://api.example.com".to_string(),
            },
        );
        CredentialStore {
            credentials: HashMap::new(),
            oauth2_routes,
        }
    }

    fn make_test_token_cache(token: &str, ttl: Duration) -> TokenCache {
        let config = OAuth2ExchangeConfig {
            token_url: "https://127.0.0.1:1/oauth/token".to_string(),
            client_id: Zeroizing::new("test-client".to_string()),
            client_secret: Zeroizing::new("test-secret".to_string()),
            scope: String::new(),
        };

        TokenCache::new_from_parts(config, test_tls_connector(), token, ttl)
    }

    #[test]
    fn test_empty_credential_store() {
        let store = CredentialStore::empty();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get("openai").is_none());
        assert!(store.get("/openai").is_none());
        assert!(store.get_oauth2("/openai").is_none());
    }

    #[test]
    fn test_loaded_credential_debug_redacts_secrets() {
        let cred = LoadedCredential {
            inject_mode: InjectMode::Header,
            proxy_inject_mode: InjectMode::Header,
            raw_credential: Zeroizing::new("sk-secret-12345".to_string()),
            header_name: "Authorization".to_string(),
            proxy_header_name: "Authorization".to_string(),
            header_value: Zeroizing::new("Bearer sk-secret-12345".to_string()),
            path_pattern: None,
            proxy_path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            proxy_query_param_name: None,
        };

        let debug_output = format!("{:?}", cred);

        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output should contain [REDACTED], got: {}",
            debug_output
        );
        assert!(
            !debug_output.contains("sk-secret-12345"),
            "Debug output must not contain the real secret"
        );
        assert!(
            !debug_output.contains("Bearer sk-secret"),
            "Debug output must not contain the formatted secret"
        );
        assert!(debug_output.contains("Authorization"));
    }

    #[test]
    fn test_load_no_credential_routes() {
        let store = CredentialStore::load(&[base_route("/test")], &test_tls_connector()).unwrap();
        assert!(store.is_empty());
    }

    #[test]
    fn test_load_header_credential_is_keyed_by_normalized_prefix() {
        let routes = vec![RouteConfig {
            credential_key: Some("env://TEST_NONO_PROXY_HEADER_CRED".to_string()),
            ..base_route("/my-api/")
        }];

        let store = load_with_env("TEST_NONO_PROXY_HEADER_CRED", "s3cret", &routes);

        let cred = store.get("my-api").unwrap();
        assert_eq!(cred.header_value.as_str(), "Bearer s3cret");
        assert_eq!(cred.raw_credential.as_str(), "s3cret");
        assert_eq!(cred.proxy_header_name, "Authorization");
        assert!(store.get("/my-api/").is_none());
    }

    #[test]
    fn test_load_custom_header_drops_default_bearer_prefix() {
        let routes = vec![RouteConfig {
            credential_key: Some("env://TEST_NONO_PROXY_CUSTOM_CRED".to_string()),
            inject_header: "X-Api-Key".to_string(),
            ..base_route("my-api")
        }];

        let store = load_with_env("TEST_NONO_PROXY_CUSTOM_CRED", "s3cret", &routes);

        let cred = store.get("my-api").unwrap();
        assert_eq!(cred.header_name, "X-Api-Key");
        assert_eq!(cred.header_value.as_str(), "s3cret");
    }

    #[test]
    fn test_load_basic_auth_encodes_secret() {
        let routes = vec![RouteConfig {
            credential_key: Some("env://TEST_NONO_PROXY_BASIC_CRED".to_string()),
            inject_mode: InjectMode::BasicAuth,
            ..base_route("my-api")
        }];

        let store = load_with_env("TEST_NONO_PROXY_BASIC_CRED", "user:pass", &routes);

        let cred = store.get("my-api").unwrap();
        assert_eq!(cred.header_value.as_str(), "Basic dXNlcjpwYXNz");
    }

    #[test]
    fn test_get_oauth2_returns_none_for_non_oauth2_routes() {
        let store = CredentialStore::empty();
        assert!(store.get_oauth2("openai").is_none());
        assert!(store.get_oauth2("my-api").is_none());
    }

    #[test]
    fn test_is_empty_false_with_only_oauth2_routes() {
        let store = store_with_oauth2_route();

        assert!(
            !store.is_empty(),
            "store with OAuth2 routes should not be empty"
        );
        assert_eq!(store.len(), 1);
        assert!(store.get_oauth2("my-api").is_some());
        assert!(store.get("my-api").is_none());
    }

    #[test]
    fn test_loaded_prefixes_includes_oauth2() {
        let prefixes = store_with_oauth2_route().loaded_prefixes();
        assert!(prefixes.contains("my-api"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_load_oauth2_unreachable_endpoint_skips_route() {
        use crate::config::OAuth2Config;

        let _lock = lock_env();
        let _env = EnvVarGuard::set_all(&[
            ("TEST_OAUTH2_CLIENT_ID", "test-client"),
            ("TEST_OAUTH2_CLIENT_SECRET", "test-secret"),
        ]);
        let routes = vec![RouteConfig {
            upstream: "https://api.example.com".to_string(),
            env_var: Some("MY_API_KEY".to_string()),
            oauth2: Some(OAuth2Config {
                // Nothing is expected to listen on loopback port 1, so the token
                // exchange should fail quickly.
                token_url: "https://127.0.0.1:1/oauth/token".to_string(),
                // env:// references keep the test away from the real keystore.
                client_id: "env://TEST_OAUTH2_CLIENT_ID".to_string(),
                client_secret: "env://TEST_OAUTH2_CLIENT_SECRET".to_string(),
                scope: String::new(),
            }),
            ..base_route("my-api")
        }];

        let store = CredentialStore::load(&routes, &test_tls_connector());

        assert!(
            store.is_ok(),
            "load should not fail on unreachable OAuth2 endpoint"
        );
        let store = store.unwrap();

        assert!(
            store.is_empty(),
            "unreachable OAuth2 endpoint should result in skipped route"
        );
        assert!(store.get_oauth2("my-api").is_none());
    }
}