1use std::collections::HashMap;
2
/// A place in an HTTP request where a bearer credential may be carried.
///
/// A slice of these, in priority order, is what [`extract_token_multi`] walks.
#[derive(PartialEq, Eq, Clone)]
pub enum CredentialSource {
    /// The standard `Authorization` request header.
    AuthorizationHeader,
    /// A query-string parameter.
    QueryParam {
        /// Name of the parameter; compared against the raw (not percent-decoded) key.
        param: String,
    },
    /// A cookie sent in the `Cookie` request header.
    Cookie {
        /// Name of the cookie that carries the token.
        name: String,
    },
}
13
14impl CredentialSource {
15 pub fn variant_name(&self) -> &'static str {
17 match self {
18 CredentialSource::AuthorizationHeader => "AuthorizationHeader",
19 CredentialSource::QueryParam { .. } => "QueryParam",
20 CredentialSource::Cookie { .. } => "Cookie",
21 }
22 }
23}
24
25impl std::fmt::Debug for CredentialSource {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 CredentialSource::AuthorizationHeader => f.write_str("AuthorizationHeader"),
29 CredentialSource::QueryParam { param } => {
30 write!(f, "QueryParam {{ param: {:?} }}", param) }
32 CredentialSource::Cookie { name } => {
33 write!(f, "Cookie {{ name: {:?} }}", name) }
35 }
36 }
37}
38
39#[derive(Clone)]
41pub struct ExtractedToken {
42 pub token: String,
43 pub source: CredentialSource,
44}
45
46impl std::fmt::Debug for ExtractedToken {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("ExtractedToken")
49 .field("token", &"[REDACTED]")
50 .field("source", &self.source)
51 .finish()
52 }
53}
54
55fn percent_decode_str(s: &str) -> String {
57 percent_encoding::percent_decode_str(s)
58 .decode_utf8()
59 .unwrap_or_else(|_| s.into())
60 .into_owned()
61}
62
63pub fn extract_token_from_header(headers: &http::HeaderMap) -> Option<String> {
68 let value = headers.get(http::header::AUTHORIZATION)?;
69 crate::extract_bearer_token(value.to_str().ok()?)
70 .ok()
71 .flatten()
72}
73
74pub fn extract_token_from_query(uri: &http::Uri, param: &str) -> Option<String> {
79 let query = uri.query()?;
80 let pairs = parse_query_string(query);
81 pairs.get(param).map(|v| percent_decode_str(v))
82}
83
84pub fn extract_token_from_cookie(headers: &http::HeaderMap, cookie_name: &str) -> Option<String> {
89 let cookie_header = headers.get(http::header::COOKIE)?;
90 let cookie_str = cookie_header.to_str().ok()?;
91 parse_cookie_header(cookie_str)
92 .get(cookie_name)
93 .map(|v| percent_decode_str(v))
94}
95
96pub fn extract_token_multi(
98 headers: &http::HeaderMap,
99 uri: &http::Uri,
100 sources: &[CredentialSource],
101) -> Option<ExtractedToken> {
102 for source in sources {
103 match source {
104 CredentialSource::AuthorizationHeader => {
105 if let Some(token) = extract_token_from_header(headers) {
106 return Some(ExtractedToken {
107 token,
108 source: source.clone(),
109 });
110 }
111 }
112 CredentialSource::QueryParam { param } => {
113 if let Some(token) = extract_token_from_query(uri, param) {
114 return Some(ExtractedToken {
115 token,
116 source: source.clone(),
117 });
118 }
119 }
120 CredentialSource::Cookie { name } => {
121 if let Some(token) = extract_token_from_cookie(headers, name) {
122 return Some(ExtractedToken {
123 token,
124 source: source.clone(),
125 });
126 }
127 }
128 }
129 }
130 None
131}
132
133pub fn redact_query_params(uri: &http::Uri, sensitive_params: &[&str]) -> String {
137 let query = match uri.query() {
138 Some(q) => q,
139 None => return uri.to_string(),
140 };
141
142 let redacted: Vec<String> = query
143 .split('&')
144 .map(|pair| {
145 if let Some((key, _value)) = pair.split_once('=') {
146 if sensitive_params.iter().any(|s| s == &key) {
147 format!("{}=[REDACTED]", key)
148 } else {
149 pair.to_string()
150 }
151 } else {
152 pair.to_string()
153 }
154 })
155 .collect();
156
157 let base = uri.path();
158 if redacted.is_empty() {
159 base.to_string()
160 } else {
161 format!("{}?{}", base, redacted.join("&"))
162 }
163}
164
/// Splits a raw query string into a key → value map without percent-decoding.
///
/// Only `key=value` pairs are kept (split at the first `=`), and pairs without `=` are
/// skipped. When a key repeats, the last occurrence wins.
fn parse_query_string(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter_map(|segment| segment.split_once('='))
        .map(|(k, v)| (k.to_owned(), v.to_owned()))
        .collect()
}
176
/// Parses a `Cookie` request-header value into a name → value map.
///
/// The header is split on `;`, and names and values have surrounding whitespace trimmed.
/// Fragments without an `=` are ignored. Values are returned still percent-encoded, so
/// callers decode them.
///
/// Two details follow common server practice for RFC 6265 cookies:
/// * A value wrapped in one pair of double quotes (`name="value"`) is unwrapped. RFC 6265
///   §4.1.1's `cookie-value` grammar allows this, and the quotes must not end up in the
///   extracted credential.
/// * When a name appears more than once, the **first** occurrence wins. User agents list
///   cookies with more specific paths first (RFC 6265 §5.4), and Go's `Request.Cookie`
///   also takes the first. Keeping the last one would let a broader-scoped cookie shadow
///   the specific one.
fn parse_cookie_header(cookie_str: &str) -> HashMap<String, String> {
    let mut map = HashMap::new();
    for pair in cookie_str.split(';') {
        if let Some((key, value)) = pair.trim().split_once('=') {
            let value = value.trim();
            // Unwrap an optional DQUOTE pair; a lone `"` is left as-is.
            let value = value
                .strip_prefix('"')
                .and_then(|inner| inner.strip_suffix('"'))
                .unwrap_or(value);
            map.entry(key.trim().to_string())
                .or_insert_with(|| value.to_string());
        }
    }
    map
}
187
// Unit tests for credential-source formatting, per-source extraction, multi-source
// fallback, query redaction and percent-decoding.
#[cfg(test)]
mod tests {
    use super::*;

    // --- `CredentialSource` Debug formatting -------------------------------------------

    #[test]
    fn debug_authorization_header_shows_variant_name() {
        let source = CredentialSource::AuthorizationHeader;
        assert_eq!(format!("{:?}", source), "AuthorizationHeader");
    }

    #[test]
    fn debug_query_param_shows_param_name() {
        let source = CredentialSource::QueryParam {
            param: "token".to_string(),
        };
        assert_eq!(format!("{:?}", source), "QueryParam { param: \"token\" }");
    }

    #[test]
    fn debug_cookie_shows_cookie_name() {
        let source = CredentialSource::Cookie {
            name: "session".to_string(),
        };
        assert_eq!(format!("{:?}", source), "Cookie { name: \"session\" }");
    }

    // Compares Debug renderings of the original and its clone.
    #[test]
    fn credential_source_clone() {
        let original = CredentialSource::QueryParam {
            param: "access_token".to_string(),
        };
        let cloned = original.clone();
        assert_eq!(format!("{:?}", original), format!("{:?}", cloned));
    }

    // --- Authorization header extraction -------------------------------------------------

    #[test]
    fn extract_header_valid_bearer() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            "Bearer mytoken123".parse().unwrap(),
        );
        let token = extract_token_from_header(&headers);
        assert_eq!(token, Some("mytoken123".to_string()));
    }

    #[test]
    fn extract_header_missing_returns_none() {
        let headers = http::HeaderMap::new();
        let token = extract_token_from_header(&headers);
        assert!(token.is_none());
    }

    // Expects `crate::extract_bearer_token` to yield no token for a non-Bearer scheme.
    #[test]
    fn extract_header_non_bearer_returns_none() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::AUTHORIZATION, "Basic abc123".parse().unwrap());
        let token = extract_token_from_header(&headers);
        assert!(token.is_none());
    }

    // --- Query-parameter extraction ----------------------------------------------------

    #[test]
    fn extract_query_valid_token() {
        let uri: http::Uri = "/ws?token=abc123".parse().unwrap();
        let token = extract_token_from_query(&uri, "token");
        assert_eq!(token, Some("abc123".to_string()));
    }

    #[test]
    fn extract_query_missing_param_returns_none() {
        let uri: http::Uri = "/ws?other=value".parse().unwrap();
        let token = extract_token_from_query(&uri, "token");
        assert!(token.is_none());
    }

    #[test]
    fn extract_query_no_query_string_returns_none() {
        let uri: http::Uri = "/ws".parse().unwrap();
        let token = extract_token_from_query(&uri, "token");
        assert!(token.is_none());
    }

    // The value is percent-decoded (`%2B` -> `+`).
    #[test]
    fn extract_query_percent_encoded() {
        let uri: http::Uri = "/ws?token=abc%2Bdef".parse().unwrap();
        let token = extract_token_from_query(&uri, "token");
        assert_eq!(token, Some("abc+def".to_string()));
    }

    #[test]
    fn extract_query_multiple_params() {
        let uri: http::Uri = "/ws?foo=bar&token=secret&baz=qux".parse().unwrap();
        let token = extract_token_from_query(&uri, "token");
        assert_eq!(token, Some("secret".to_string()));
    }

    // --- Cookie extraction -------------------------------------------------------------

    #[test]
    fn extract_cookie_valid() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "session=cookie_token_123".parse().unwrap(),
        );
        let token = extract_token_from_cookie(&headers, "session");
        assert_eq!(token, Some("cookie_token_123".to_string()));
    }

    #[test]
    fn extract_cookie_missing_returns_none() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::COOKIE, "other=value".parse().unwrap());
        let token = extract_token_from_cookie(&headers, "session");
        assert!(token.is_none());
    }

    #[test]
    fn extract_cookie_no_cookie_header_returns_none() {
        let headers = http::HeaderMap::new();
        let token = extract_token_from_cookie(&headers, "session");
        assert!(token.is_none());
    }

    #[test]
    fn extract_cookie_multiple_cookies() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "foo=bar; auth_token=mycookie; baz=qux".parse().unwrap(),
        );
        let token = extract_token_from_cookie(&headers, "auth_token");
        assert_eq!(token, Some("mycookie".to_string()));
    }

    // Whitespace around a cookie value is trimmed.
    #[test]
    fn extract_cookie_with_spaces() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "foo=bar; auth_token=spaced_token ; baz=qux"
                .parse()
                .unwrap(),
        );
        let token = extract_token_from_cookie(&headers, "auth_token");
        assert_eq!(token, Some("spaced_token".to_string()));
    }

    // --- Multi-source fallback ---------------------------------------------------------

    #[test]
    fn multi_falls_back_from_header_to_query() {
        let headers = http::HeaderMap::new();
        let uri: http::Uri = "/ws?token=query_token".parse().unwrap();
        let sources = vec![
            CredentialSource::AuthorizationHeader,
            CredentialSource::QueryParam {
                param: "token".to_string(),
            },
        ];
        let result = extract_token_multi(&headers, &uri, &sources);
        assert!(result.is_some());
        let extracted = result.unwrap();
        assert_eq!(extracted.token, "query_token");
        assert!(matches!(
            extracted.source,
            CredentialSource::QueryParam { .. }
        ));
    }

    // When several sources carry a token, the earliest one in `sources` is used.
    #[test]
    fn multi_prefers_first_matching_source() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            "Bearer header_token".parse().unwrap(),
        );
        let uri: http::Uri = "/ws?token=query_token".parse().unwrap();
        let sources = vec![
            CredentialSource::AuthorizationHeader,
            CredentialSource::QueryParam {
                param: "token".to_string(),
            },
        ];
        let result = extract_token_multi(&headers, &uri, &sources);
        assert!(result.is_some());
        let extracted = result.unwrap();
        assert_eq!(extracted.token, "header_token");
        assert!(matches!(
            extracted.source,
            CredentialSource::AuthorizationHeader
        ));
    }

    #[test]
    fn multi_falls_back_to_cookie() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "session=cookie_token".parse().unwrap(),
        );
        let uri: http::Uri = "/ws".parse().unwrap();
        let sources = vec![
            CredentialSource::AuthorizationHeader,
            CredentialSource::QueryParam {
                param: "token".to_string(),
            },
            CredentialSource::Cookie {
                name: "session".to_string(),
            },
        ];
        let result = extract_token_multi(&headers, &uri, &sources);
        assert!(result.is_some());
        let extracted = result.unwrap();
        assert_eq!(extracted.token, "cookie_token");
        assert!(matches!(extracted.source, CredentialSource::Cookie { .. }));
    }

    #[test]
    fn multi_returns_none_when_all_fail() {
        let headers = http::HeaderMap::new();
        let uri: http::Uri = "/ws".parse().unwrap();
        let sources = vec![
            CredentialSource::AuthorizationHeader,
            CredentialSource::QueryParam {
                param: "token".to_string(),
            },
            CredentialSource::Cookie {
                name: "session".to_string(),
            },
        ];
        let result = extract_token_multi(&headers, &uri, &sources);
        assert!(result.is_none());
    }

    // --- Query-parameter redaction for logging -------------------------------------------

    #[test]
    fn redact_single_sensitive_param() {
        let uri: http::Uri = "/ws?token=secret&foo=bar".parse().unwrap();
        let redacted = redact_query_params(&uri, &["token"]);
        assert_eq!(redacted, "/ws?token=[REDACTED]&foo=bar");
    }

    #[test]
    fn redact_multiple_sensitive_params() {
        let uri: http::Uri = "/ws?token=secret&password=pass123&foo=bar".parse().unwrap();
        let redacted = redact_query_params(&uri, &["token", "password"]);
        assert_eq!(redacted, "/ws?token=[REDACTED]&password=[REDACTED]&foo=bar");
    }

    #[test]
    fn redact_no_sensitive_params_in_uri() {
        let uri: http::Uri = "/ws?foo=bar&baz=qux".parse().unwrap();
        let redacted = redact_query_params(&uri, &["token"]);
        assert_eq!(redacted, "/ws?foo=bar&baz=qux");
    }

    #[test]
    fn redact_no_query_string_returns_uri_as_is() {
        let uri: http::Uri = "/ws".parse().unwrap();
        let redacted = redact_query_params(&uri, &["token"]);
        assert_eq!(redacted, "/ws");
    }

    // --- Percent-decoding helper ---------------------------------------------------------

    #[test]
    fn percent_decode_plus_sign() {
        assert_eq!(percent_decode_str("hello%2Bworld"), "hello+world");
    }

    #[test]
    fn percent_decode_space() {
        assert_eq!(percent_decode_str("hello%20world"), "hello world");
    }

    #[test]
    fn percent_decode_no_encoding_returns_original() {
        assert_eq!(percent_decode_str("plaintext"), "plaintext");
    }

    // --- Secret hygiene --------------------------------------------------------------------

    // The hand-written Debug impl must never print the raw token.
    #[test]
    fn extracted_token_debug_redacts_token() {
        let token = ExtractedToken {
            token: "super-secret-jwt-value".to_string(),
            source: CredentialSource::AuthorizationHeader,
        };
        let debug = format!("{token:?}");
        assert!(!debug.contains("super-secret-jwt-value"));
        assert!(debug.contains("[REDACTED]"));
    }
}