1use std::sync::Arc;
2
3use http::header::{PROXY_AUTHENTICATE, WWW_AUTHENTICATE};
4use http::{Extensions, HeaderMap, Method, Request, StatusCode, Uri, Version};
5
6use crate::{BoxFuture, RequestBody, WireError};
7
8pub trait Authenticator: Send + Sync + 'static {
10 fn authenticate(
11 &self,
12 ctx: AuthContext,
13 ) -> BoxFuture<Result<Option<Request<RequestBody>>, WireError>>;
14}
15
16impl<T> Authenticator for Arc<T>
17where
18 T: Authenticator + ?Sized,
19{
20 fn authenticate(
21 &self,
22 ctx: AuthContext,
23 ) -> BoxFuture<Result<Option<Request<RequestBody>>, WireError>> {
24 (**self).authenticate(ctx)
25 }
26}
27
/// Identifies which party issued an authentication challenge.
///
/// [`AuthContext::challenges`] uses this to choose the response header it parses:
/// `Origin` reads `WWW-Authenticate`, `Proxy` reads `Proxy-Authenticate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthKind {
    /// Challenge carried in `WWW-Authenticate`.
    Origin,
    /// Challenge carried in `Proxy-Authenticate`.
    Proxy,
}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct AuthChallenge {
37 scheme: String,
38 token68: Option<String>,
39 parameters: Vec<AuthChallengeParam>,
40}
41
42impl AuthChallenge {
43 pub fn scheme(&self) -> &str {
45 &self.scheme
46 }
47
48 pub fn token68(&self) -> Option<&str> {
50 self.token68.as_deref()
51 }
52
53 pub fn parameters(&self) -> &[AuthChallengeParam] {
55 &self.parameters
56 }
57
58 pub fn parameter(&self, name: &str) -> Option<&str> {
60 self.parameters
61 .iter()
62 .find(|parameter| parameter.name.eq_ignore_ascii_case(name))
63 .map(|parameter| parameter.value.as_str())
64 }
65
66 pub fn realm(&self) -> Option<&str> {
68 self.parameter("realm")
69 }
70}
71
/// One `name=value` auth-param belonging to an [`AuthChallenge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthChallengeParam {
    /// Parameter name with its original casing; lookups compare it case-insensitively.
    name: String,
    /// Parameter value; a quoted-string value has its quotes and escapes removed.
    value: String,
}
78
79impl AuthChallengeParam {
80 pub fn name(&self) -> &str {
82 &self.name
83 }
84
85 pub fn value(&self) -> &str {
87 &self.value
88 }
89}
90
91pub struct AuthContext {
93 kind: AuthKind,
94 request_method: Method,
95 request_uri: Uri,
96 request_version: Version,
97 request_headers: HeaderMap,
98 request_extensions: Extensions,
99 request_body: Option<RequestBody>,
100 response_status: StatusCode,
101 response_headers: HeaderMap,
102 total_attempt: u32,
103 retry_count: u32,
104 redirect_count: u32,
105 auth_count: u32,
106}
107
108impl AuthContext {
109 #[allow(clippy::too_many_arguments)]
110 pub fn new(
111 kind: AuthKind,
112 request_method: Method,
113 request_uri: Uri,
114 request_version: Version,
115 request_headers: HeaderMap,
116 request_extensions: Extensions,
117 request_body: Option<RequestBody>,
118 response_status: StatusCode,
119 response_headers: HeaderMap,
120 total_attempt: u32,
121 retry_count: u32,
122 redirect_count: u32,
123 auth_count: u32,
124 ) -> Self {
125 Self {
126 kind,
127 request_method,
128 request_uri,
129 request_version,
130 request_headers,
131 request_extensions,
132 request_body,
133 response_status,
134 response_headers,
135 total_attempt,
136 retry_count,
137 redirect_count,
138 auth_count,
139 }
140 }
141
142 pub fn kind(&self) -> AuthKind {
144 self.kind
145 }
146
147 pub fn request_method(&self) -> &Method {
149 &self.request_method
150 }
151
152 pub fn request_uri(&self) -> &Uri {
154 &self.request_uri
155 }
156
157 pub fn request_headers(&self) -> &HeaderMap {
159 &self.request_headers
160 }
161
162 pub fn response_status(&self) -> StatusCode {
164 self.response_status
165 }
166
167 pub fn response_headers(&self) -> &HeaderMap {
169 &self.response_headers
170 }
171
172 pub fn challenges(&self) -> Vec<AuthChallenge> {
178 let header = match self.kind {
179 AuthKind::Origin => WWW_AUTHENTICATE,
180 AuthKind::Proxy => PROXY_AUTHENTICATE,
181 };
182 parse_auth_challenges(
183 self.response_headers
184 .get_all(header)
185 .iter()
186 .filter_map(|value| value.to_str().ok()),
187 )
188 }
189
190 pub fn total_attempt(&self) -> u32 {
192 self.total_attempt
193 }
194
195 pub fn retry_count(&self) -> u32 {
197 self.retry_count
198 }
199
200 pub fn redirect_count(&self) -> u32 {
202 self.redirect_count
203 }
204
205 pub fn auth_count(&self) -> u32 {
207 self.auth_count
208 }
209
210 pub fn is_replayable(&self) -> bool {
212 self.request_body.is_some()
213 }
214
215 pub fn try_clone_request(&self) -> Option<Request<RequestBody>> {
217 let body = self
218 .request_body
219 .as_ref()
220 .and_then(RequestBody::try_clone)?;
221 let mut request = Request::builder()
222 .method(self.request_method.clone())
223 .uri(self.request_uri.clone())
224 .version(self.request_version)
225 .body(body)
226 .ok()?;
227 *request.headers_mut() = self.request_headers.clone();
228 *request.extensions_mut() = self.request_extensions.clone();
229 Some(request)
230 }
231}
232
233fn parse_auth_challenges<'a>(values: impl IntoIterator<Item = &'a str>) -> Vec<AuthChallenge> {
234 let mut challenges = Vec::new();
235 for value in values {
236 challenges.extend(parse_auth_challenge_header(value));
237 }
238 challenges
239}
240
241fn parse_auth_challenge_header(value: &str) -> Vec<AuthChallenge> {
242 let mut challenges = Vec::new();
243 let mut current: Option<AuthChallenge> = None;
244
245 for part in split_top_level_commas(value) {
246 let part = part.trim();
247 if part.is_empty() {
248 continue;
249 }
250
251 if let Some(parameter) = parse_auth_param(part) {
252 if let Some(challenge) = current.as_mut() {
253 challenge.parameters.push(parameter);
254 continue;
255 }
256 }
257
258 if let Some(challenge) = current.take() {
259 challenges.push(challenge);
260 }
261 current = parse_challenge_start(part);
262 }
263
264 if let Some(challenge) = current {
265 challenges.push(challenge);
266 }
267
268 challenges
269}
270
271fn parse_challenge_start(value: &str) -> Option<AuthChallenge> {
272 let (scheme, rest) = parse_token(value)?;
273 if !rest.is_empty() && !rest.starts_with(char::is_whitespace) {
274 return None;
275 }
276
277 let rest = rest.trim();
278 let mut challenge = AuthChallenge {
279 scheme: scheme.to_owned(),
280 token68: None,
281 parameters: Vec::new(),
282 };
283 if rest.is_empty() {
284 return Some(challenge);
285 }
286
287 if is_token68(rest) {
288 challenge.token68 = Some(rest.to_owned());
289 } else if rest.contains('=') {
290 challenge.parameters.extend(parse_auth_params(rest));
291 }
292
293 Some(challenge)
294}
295
296fn parse_auth_params(value: &str) -> Vec<AuthChallengeParam> {
297 split_top_level_commas(value)
298 .into_iter()
299 .filter_map(parse_auth_param)
300 .collect()
301}
302
303fn parse_auth_param(value: &str) -> Option<AuthChallengeParam> {
304 let (name, rest) = parse_token(value.trim())?;
305 let rest = rest.trim_start();
306 let rest = rest.strip_prefix('=')?.trim_start();
307 let (value, remaining) = parse_auth_param_value(rest)?;
308 remaining.trim().is_empty().then_some(AuthChallengeParam {
309 name: name.to_owned(),
310 value,
311 })
312}
313
314fn parse_auth_param_value(value: &str) -> Option<(String, &str)> {
315 if value.starts_with('"') {
316 return parse_quoted_string(value);
317 }
318
319 let (value, rest) = parse_token(value)?;
320 Some((value.to_owned(), rest))
321}
322
/// Parses an HTTP quoted-string at the start of `value`.
///
/// On success returns the unescaped contents (a backslash makes the following character
/// literal) and the text after the closing quote.
///
/// Returns `None` when `value` does not begin with `"` or the string is never closed.
fn parse_quoted_string(value: &str) -> Option<(String, &str)> {
    // Require the opening quote rather than blindly skipping the first character, so
    // unquoted input is rejected instead of being parsed as if it were quoted.
    let inner = value.strip_prefix('"')?;
    let mut out = String::new();
    let mut escaped = false;
    for (index, ch) in inner.char_indices() {
        if escaped {
            out.push(ch);
            escaped = false;
            continue;
        }

        match ch {
            '\\' => escaped = true,
            '"' => return Some((out, &inner[index + ch.len_utf8()..])),
            _ => out.push(ch),
        }
    }
    None
}
341
/// Splits `value` on every comma that is not inside a quoted-string.
///
/// A backslash escapes the next character only while inside quotes, so `\"` does not close a
/// quoted value. Pieces are returned untrimmed, and there is always at least one piece
/// (possibly empty).
fn split_top_level_commas(value: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut piece_start = 0;
    let mut quoted = false;
    let mut skip_next = false;

    for (pos, ch) in value.char_indices() {
        if skip_next {
            skip_next = false;
        } else if ch == '"' {
            quoted = !quoted;
        } else if ch == '\\' && quoted {
            skip_next = true;
        } else if ch == ',' && !quoted {
            pieces.push(&value[piece_start..pos]);
            piece_start = pos + ch.len_utf8();
        }
    }

    pieces.push(&value[piece_start..]);
    pieces
}
368
369fn parse_token(value: &str) -> Option<(&str, &str)> {
370 let end = value
371 .char_indices()
372 .take_while(|(_, ch)| is_token_char(*ch))
373 .map(|(index, ch)| index + ch.len_utf8())
374 .last()?;
375 Some((&value[..end], &value[end..]))
376}
377
/// Reports whether `ch` is an HTTP `tchar`: an ASCII letter or digit, or one of
/// ``!#$%&'*+-.^_`|~``.
fn is_token_char(ch: char) -> bool {
    const TCHAR_SYMBOLS: &str = "!#$%&'*+-.^_`|~";
    ch.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(ch)
}
398
/// Reports whether `value` is an RFC 7235 token68.
///
/// A token68 is one or more characters from `A-Z a-z 0-9 - . _ ~ + /`, followed by any
/// number of trailing `=` padding characters.
fn is_token68(value: &str) -> bool {
    let body = value.trim_end_matches('=');
    !body.is_empty()
        && body
            .chars()
            .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '.' | '_' | '~' | '+' | '/'))
}
416
// Unit tests for challenge parsing and for `AuthContext::challenges` header selection.
#[cfg(test)]
mod tests {
    use http::header::{PROXY_AUTHENTICATE, WWW_AUTHENTICATE};
    use http::{HeaderMap, Method, Request, StatusCode, Version};

    use super::{is_token_char, parse_auth_challenge_header, AuthContext, AuthKind};
    use crate::RequestBody;

    // RFC 7235's multi-challenge example: auth-params (including an escaped quoted-string)
    // attach to the preceding scheme until the next scheme begins.
    #[test]
    fn parses_rfc7235_multi_challenge_example() {
        let challenges = parse_auth_challenge_header(
            r#"Newauth realm="apps", type=1, title="Login to \"apps\"", Basic realm="simple""#,
        );

        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[0].scheme(), "Newauth");
        assert_eq!(challenges[0].realm(), Some("apps"));
        assert_eq!(challenges[0].parameter("type"), Some("1"));
        assert_eq!(challenges[0].parameter("title"), Some("Login to \"apps\""));
        assert_eq!(challenges[1].scheme(), "Basic");
        assert_eq!(challenges[1].realm(), Some("simple"));
    }

    // `insert` followed by `append` produces two separate WWW-Authenticate fields; both
    // must be parsed, in order, and the first carries a token68 with `=` padding.
    #[test]
    fn parses_token68_and_multiple_header_fields() {
        let mut headers = HeaderMap::new();
        headers.insert(WWW_AUTHENTICATE, "Bearer abcDEF123+/==".parse().unwrap());
        headers.append(WWW_AUTHENTICATE, r#"Basic realm="simple""#.parse().unwrap());
        let ctx = test_context(AuthKind::Origin, headers);

        let challenges = ctx.challenges();
        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[0].scheme(), "Bearer");
        assert_eq!(challenges[0].token68(), Some("abcDEF123+/=="));
        assert_eq!(challenges[1].scheme(), "Basic");
        assert_eq!(challenges[1].realm(), Some("simple"));
    }

    // A proxy context must ignore WWW-Authenticate even when it is present.
    #[test]
    fn proxy_context_reads_proxy_authenticate_only() {
        let mut headers = HeaderMap::new();
        headers.insert(WWW_AUTHENTICATE, r#"Basic realm="origin""#.parse().unwrap());
        headers.insert(
            PROXY_AUTHENTICATE,
            r#"Digest realm="proxy", nonce="n""#.parse().unwrap(),
        );
        let ctx = test_context(AuthKind::Proxy, headers);

        let challenges = ctx.challenges();
        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Digest");
        assert_eq!(challenges[0].realm(), Some("proxy"));
        assert_eq!(challenges[0].parameter("nonce"), Some("n"));
    }

    // Commas inside quoted-strings must not split the challenge.
    #[test]
    fn keeps_commas_inside_quoted_parameter_values() {
        let challenges =
            parse_auth_challenge_header(r#"Bearer realm="api, v1", scope="read,write""#);

        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Bearer");
        assert_eq!(challenges[0].realm(), Some("api, v1"));
        assert_eq!(challenges[0].parameter("scope"), Some("read,write"));
    }

    // A malformed leading piece is skipped. A field value containing byte 0xFF is accepted
    // by `HeaderValue::from_bytes` but rejected by `to_str`, so it yields no challenges.
    #[test]
    fn skips_malformed_and_non_utf8_challenge_fields() {
        let challenges = parse_auth_challenge_header(r#"=bad, Basic realm="simple""#);
        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Basic");

        let mut headers = HeaderMap::new();
        headers.insert(
            WWW_AUTHENTICATE,
            http::HeaderValue::from_bytes(b"\xff").unwrap(),
        );
        let ctx = test_context(AuthKind::Origin, headers);
        assert!(ctx.challenges().is_empty());
    }

    // Every `tchar` is accepted; delimiters, whitespace and control characters are not.
    #[test]
    fn token_chars_match_http_tchar() {
        for ch in
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&'*+-.^_`|~".chars()
        {
            assert!(is_token_char(ch), "{ch:?} should be accepted");
        }

        for ch in "()<>@,;:\\\"/[]?={} \t\r\n".chars() {
            assert!(!is_token_char(ch), "{ch:?} should be rejected");
        }
    }

    // Schemes and param names containing `/` are not tokens: the bad challenge is dropped,
    // and the bad param is neither attached to `Basic` nor turned into a challenge.
    #[test]
    fn rejects_invalid_token_chars_in_challenge_scheme_and_param_names() {
        let challenges = parse_auth_challenge_header(
            r#"Bad/Scheme realm="ignored", Basic realm="simple", bad/name="ignored""#,
        );

        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Basic");
        assert_eq!(challenges[0].realm(), Some("simple"));
        assert_eq!(challenges[0].parameter("bad/name"), None);
    }

    // Builds a context for an empty-bodied HTTP/1.1 GET to http://example.com/ that was
    // answered with 401 and the given response headers; counters are (1, 0, 0, 0).
    fn test_context(kind: AuthKind, response_headers: HeaderMap) -> AuthContext {
        let request = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/")
            .body(RequestBody::empty())
            .expect("request");
        AuthContext::new(
            kind,
            request.method().clone(),
            request.uri().clone(),
            Version::HTTP_11,
            request.headers().clone(),
            request.extensions().clone(),
            request.body().try_clone(),
            StatusCode::UNAUTHORIZED,
            response_headers,
            1,
            0,
            0,
            0,
        )
    }
}