1use crate::config::{CompiledEndpointRules, InjectMode, RouteConfig};
9use crate::error::{ProxyError, Result};
10use base64::Engine;
11use std::collections::HashMap;
12use tracing::debug;
13use zeroize::Zeroizing;
14
/// A secret loaded for one proxy route, together with the route settings
/// needed to inject it into upstream requests.
///
/// Both secret-bearing fields are wrapped in [`Zeroizing`] so their contents
/// are wiped when the value is dropped.
pub struct LoadedCredential {
    /// How the credential is injected; copied from the route's `inject_mode`.
    pub inject_mode: InjectMode,
    /// Upstream URL the route forwards to (e.g. `https://api.openai.com`).
    pub upstream: String,
    /// The secret exactly as returned by the keystore. Redacted in `Debug` output.
    pub raw_credential: Zeroizing<String>,

    /// Name of the header the credential is injected into; copied from the
    /// route's `inject_header`.
    pub header_name: String,
    /// Pre-rendered header value for `Header` and `BasicAuth` modes (empty for
    /// `UrlPath` and `QueryParam`). Redacted in `Debug` output.
    pub header_value: Zeroizing<String>,

    /// Path pattern copied from the route config.
    /// NOTE(review): presumably consumed by `UrlPath` mode outside this file — confirm.
    pub path_pattern: Option<String>,
    /// Replacement paired with `path_pattern`, copied from the route config.
    /// NOTE(review): presumably consumed by `UrlPath` mode — confirm.
    pub path_replacement: Option<String>,

    /// Query parameter name copied from the route config.
    /// NOTE(review): presumably consumed by `QueryParam` mode — confirm.
    pub query_param_name: Option<String>,

    /// Endpoint rules compiled from the route's `endpoint_rules`.
    pub endpoint_rules: CompiledEndpointRules,
}
45
46impl std::fmt::Debug for LoadedCredential {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("LoadedCredential")
51 .field("inject_mode", &self.inject_mode)
52 .field("upstream", &self.upstream)
53 .field("raw_credential", &"[REDACTED]")
54 .field("header_name", &self.header_name)
55 .field("header_value", &"[REDACTED]")
56 .field("path_pattern", &self.path_pattern)
57 .field("path_replacement", &self.path_replacement)
58 .field("query_param_name", &self.query_param_name)
59 .field("endpoint_rules", &self.endpoint_rules)
60 .finish()
61 }
62}
63
/// Credentials loaded for proxy routes, keyed by route prefix.
///
/// The derived `Debug` delegates to `LoadedCredential`'s hand-written `Debug`,
/// which redacts the secret-bearing fields.
#[derive(Debug)]
pub struct CredentialStore {
    /// Route prefix -> credential loaded for that route.
    credentials: HashMap<String, LoadedCredential>,
}
70
71impl CredentialStore {
72 pub fn load(routes: &[RouteConfig]) -> Result<Self> {
82 let mut credentials = HashMap::new();
83
84 for route in routes {
85 if let Some(ref key) = route.credential_key {
86 debug!(
87 "Loading credential for route prefix: {} (mode: {:?})",
88 route.prefix, route.inject_mode
89 );
90
91 let secret = match nono::keystore::load_secret_by_ref(KEYRING_SERVICE, key) {
92 Ok(s) => s,
93 Err(nono::NonoError::SecretNotFound(msg)) => {
94 debug!(
95 "Credential '{}' not available, skipping route: {}",
96 route.prefix, msg
97 );
98 continue;
99 }
100 Err(e) => return Err(ProxyError::Credential(e.to_string())),
101 };
102
103 let effective_format = if route.inject_header != "Authorization"
109 && route.credential_format == "Bearer {}"
110 {
111 "{}".to_string()
112 } else {
113 route.credential_format.clone()
114 };
115
116 let header_value = match route.inject_mode {
117 InjectMode::Header => Zeroizing::new(effective_format.replace("{}", &secret)),
118 InjectMode::BasicAuth => {
119 let encoded =
121 base64::engine::general_purpose::STANDARD.encode(secret.as_bytes());
122 Zeroizing::new(format!("Basic {}", encoded))
123 }
124 InjectMode::UrlPath | InjectMode::QueryParam => Zeroizing::new(String::new()),
126 };
127
128 credentials.insert(
129 route.prefix.clone(),
130 LoadedCredential {
131 inject_mode: route.inject_mode.clone(),
132 upstream: route.upstream.clone(),
133 raw_credential: secret,
134 header_name: route.inject_header.clone(),
135 header_value,
136 path_pattern: route.path_pattern.clone(),
137 path_replacement: route.path_replacement.clone(),
138 query_param_name: route.query_param_name.clone(),
139 endpoint_rules: CompiledEndpointRules::compile(&route.endpoint_rules)
140 .map_err(|e| {
141 ProxyError::Credential(format!("route '{}': {}", route.prefix, e))
142 })?,
143 },
144 );
145 }
146 }
147
148 Ok(Self { credentials })
149 }
150
151 #[must_use]
153 pub fn empty() -> Self {
154 Self {
155 credentials: HashMap::new(),
156 }
157 }
158
159 #[must_use]
161 pub fn get(&self, prefix: &str) -> Option<&LoadedCredential> {
162 self.credentials.get(prefix)
163 }
164
165 #[must_use]
167 pub fn is_empty(&self) -> bool {
168 self.credentials.is_empty()
169 }
170
171 #[must_use]
173 pub fn len(&self) -> usize {
174 self.credentials.len()
175 }
176
177 #[must_use]
179 pub fn loaded_prefixes(&self) -> std::collections::HashSet<String> {
180 self.credentials.keys().cloned().collect()
181 }
182
183 #[must_use]
187 pub fn is_credential_upstream(&self, host_port: &str) -> bool {
188 let normalised = host_port.to_lowercase();
189 self.credentials.values().any(|cred| {
190 extract_host_port(&cred.upstream)
191 .map(|hp| hp == normalised)
192 .unwrap_or(false)
193 })
194 }
195
196 #[must_use]
200 pub fn credential_upstream_hosts(&self) -> std::collections::HashSet<String> {
201 self.credentials
202 .values()
203 .filter_map(|cred| extract_host_port(&cred.upstream))
204 .collect()
205 }
206}
207
208fn extract_host_port(url: &str) -> Option<String> {
213 let parsed = url::Url::parse(url).ok()?;
214 let host = parsed.host_str()?;
215 let default_port = match parsed.scheme() {
216 "https" => 443,
217 "http" => 80,
218 _ => return None,
219 };
220 let port = parsed.port().unwrap_or(default_port);
221 Some(format!("{}:{}", host.to_lowercase(), port))
222}
223
/// Keystore service name that route credentials are loaded from (nono's default service).
const KEYRING_SERVICE: &str = nono::keystore::DEFAULT_SERVICE;
227
#[cfg(test)]
#[allow(clippy::unwrap_used)] // compiling an empty rule set cannot fail in tests
mod tests {
    use super::*;

    /// Builds a header-mode credential with no path/query rewriting and no
    /// endpoint rules.
    fn header_credential(
        upstream: &str,
        header_name: &str,
        secret: &str,
        header_value: &str,
    ) -> LoadedCredential {
        LoadedCredential {
            inject_mode: InjectMode::Header,
            upstream: upstream.to_string(),
            raw_credential: Zeroizing::new(secret.to_string()),
            header_name: header_name.to_string(),
            header_value: Zeroizing::new(header_value.to_string()),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            endpoint_rules: CompiledEndpointRules::compile(&[]).unwrap(),
        }
    }

    /// Builds a store directly from `(prefix, credential)` pairs, bypassing the keystore.
    fn store_with(entries: Vec<(&str, LoadedCredential)>) -> CredentialStore {
        CredentialStore {
            credentials: entries
                .into_iter()
                .map(|(prefix, cred)| (prefix.to_string(), cred))
                .collect(),
        }
    }

    #[test]
    fn test_empty_credential_store() {
        let store = CredentialStore::empty();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get("/openai").is_none());
    }

    #[test]
    fn test_loaded_credential_debug_redacts_secrets() {
        let cred = header_credential(
            "https://api.openai.com",
            "Authorization",
            "sk-secret-12345",
            "Bearer sk-secret-12345",
        );

        let debug_output = format!("{:?}", cred);

        assert!(
            debug_output.contains("[REDACTED]"),
            "Debug output should contain [REDACTED], got: {}",
            debug_output
        );
        assert!(
            !debug_output.contains("sk-secret-12345"),
            "Debug output must not contain the real secret"
        );
        assert!(
            !debug_output.contains("Bearer sk-secret"),
            "Debug output must not contain the formatted secret"
        );
        assert!(debug_output.contains("api.openai.com"));
        assert!(debug_output.contains("Authorization"));
    }

    #[test]
    fn test_extract_host_port_https_no_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com"),
            Some("api.openai.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_https_with_port() {
        assert_eq!(
            extract_host_port("https://api.openai.com:8443"),
            Some("api.openai.com:8443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_explicit_port() {
        assert_eq!(
            extract_host_port("http://internal:4096"),
            Some("internal:4096".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_http_default_port() {
        assert_eq!(
            extract_host_port("http://internal-service"),
            Some("internal-service:80".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_normalises_case() {
        assert_eq!(
            extract_host_port("https://GitLab-PRD.Home.Example.COM"),
            Some("gitlab-prd.home.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_with_path() {
        assert_eq!(
            extract_host_port("https://api.example.com/v1/endpoint"),
            Some("api.example.com:443".to_string())
        );
    }

    #[test]
    fn test_extract_host_port_no_scheme() {
        assert_eq!(extract_host_port("api.openai.com"), None);
    }

    #[test]
    fn test_is_credential_upstream() {
        let store = store_with(vec![(
            "gitlab",
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        )]);

        assert!(store.is_credential_upstream("gitlab.example.com:443"));
        assert!(!store.is_credential_upstream("unrelated.example.com:443"));
    }

    #[test]
    fn test_is_credential_upstream_is_case_insensitive_and_port_sensitive() {
        let store = store_with(vec![(
            "gitlab",
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        )]);

        assert!(store.is_credential_upstream("GitLab.Example.COM:443"));
        assert!(!store.is_credential_upstream("gitlab.example.com:80"));
        assert!(!store.is_credential_upstream("gitlab.example.com"));
    }

    #[test]
    fn test_is_credential_upstream_empty_store() {
        let store = CredentialStore::empty();
        assert!(!store.is_credential_upstream("anything:443"));
    }

    #[test]
    fn test_credential_upstream_hosts() {
        let store = store_with(vec![(
            "gitlab",
            header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "token", "token"),
        )]);

        let hosts = store.credential_upstream_hosts();
        assert!(hosts.contains("gitlab.example.com:443"));
        assert_eq!(hosts.len(), 1);
    }

    #[test]
    fn test_credential_upstream_hosts_deduplicates_equivalent_upstreams() {
        let store = store_with(vec![
            (
                "gitlab",
                header_credential("https://gitlab.example.com", "PRIVATE-TOKEN", "a", "a"),
            ),
            (
                "gitlab-api",
                header_credential("https://gitlab.example.com:443/api", "PRIVATE-TOKEN", "b", "b"),
            ),
        ]);

        let hosts = store.credential_upstream_hosts();
        assert_eq!(hosts.len(), 1);
        assert!(hosts.contains("gitlab.example.com:443"));
    }

    #[test]
    fn test_loaded_prefixes() {
        let store = store_with(vec![
            ("/a", header_credential("https://a.example.com", "X-Token", "a", "a")),
            ("/b", header_credential("https://b.example.com", "X-Token", "b", "b")),
        ]);

        let prefixes = store.loaded_prefixes();
        assert_eq!(prefixes.len(), 2);
        assert!(prefixes.contains("/a"));
        assert!(prefixes.contains("/b"));
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_load_no_credential_routes() {
        let routes = vec![RouteConfig {
            prefix: "/test".to_string(),
            upstream: "https://example.com".to_string(),
            credential_key: None,
            inject_mode: InjectMode::Header,
            inject_header: "Authorization".to_string(),
            credential_format: "Bearer {}".to_string(),
            path_pattern: None,
            path_replacement: None,
            query_param_name: None,
            env_var: None,
            endpoint_rules: vec![],
        }];

        // Routes without a credential_key never touch the keystore, so this must succeed.
        let store = match CredentialStore::load(&routes) {
            Ok(store) => store,
            Err(_) => panic!("loading routes without credential keys should succeed"),
        };
        assert!(store.is_empty());
    }
}