Skip to main content

openwire_core/
auth.rs

1use std::sync::Arc;
2
3use http::header::{PROXY_AUTHENTICATE, WWW_AUTHENTICATE};
4use http::{Extensions, HeaderMap, Method, Request, StatusCode, Uri, Version};
5
6use crate::{BoxFuture, RequestBody, WireError};
7
/// Produces authenticated follow-up requests for authentication challenges.
///
/// Implementors must be `Send + Sync + 'static`, as required by the
/// supertrait bounds on this trait.
pub trait Authenticator: Send + Sync + 'static {
    /// Inspects the challenge described by `ctx` and asynchronously decides on
    /// a follow-up request.
    ///
    /// The context is passed by value, so the implementation owns it for the
    /// duration of the returned future.
    ///
    /// The future resolves to `Ok(Some(request))` with the follow-up request
    /// to send, or to `Err(_)` when producing one failed.
    // NOTE(review): `Ok(None)` presumably means "do not follow up"; confirm
    // against the call site that drives this trait.
    fn authenticate(
        &self,
        ctx: AuthContext,
    ) -> BoxFuture<Result<Option<Request<RequestBody>>, WireError>>;
}
15
16impl<T> Authenticator for Arc<T>
17where
18    T: Authenticator + ?Sized,
19{
20    fn authenticate(
21        &self,
22        ctx: AuthContext,
23    ) -> BoxFuture<Result<Option<Request<RequestBody>>, WireError>> {
24        (**self).authenticate(ctx)
25    }
26}
27
/// Identifies which kind of authentication challenge is being handled.
///
/// `Hash` is derived alongside the equality traits so the kind can be used as
/// a key in hash-based collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthKind {
    /// An origin challenge; the challenges are read from `WWW-Authenticate`.
    Origin,
    /// A proxy challenge; the challenges are read from `Proxy-Authenticate`.
    Proxy,
}
33
/// A parsed RFC 7235 authentication challenge.
///
/// Fields are private; read them through the accessor methods on this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthChallenge {
    // Authentication scheme token, stored with the casing found in the header.
    scheme: String,
    // token68 payload, when the text after the scheme matched token68 syntax.
    token68: Option<String>,
    // Auth parameters in the order they were parsed from the header.
    parameters: Vec<AuthChallengeParam>,
}
41
42impl AuthChallenge {
43    /// Returns the case-preserving authentication scheme token.
44    pub fn scheme(&self) -> &str {
45        &self.scheme
46    }
47
48    /// Returns the token68 payload for schemes that use token68 syntax.
49    pub fn token68(&self) -> Option<&str> {
50        self.token68.as_deref()
51    }
52
53    /// Returns parsed auth parameters in header order.
54    pub fn parameters(&self) -> &[AuthChallengeParam] {
55        &self.parameters
56    }
57
58    /// Returns the first parameter value matching `name`, ignoring ASCII case.
59    pub fn parameter(&self, name: &str) -> Option<&str> {
60        self.parameters
61            .iter()
62            .find(|parameter| parameter.name.eq_ignore_ascii_case(name))
63            .map(|parameter| parameter.value.as_str())
64    }
65
66    /// Returns the challenge realm parameter when present.
67    pub fn realm(&self) -> Option<&str> {
68        self.parameter("realm")
69    }
70}
71
/// A parsed auth-param from an RFC 7235 challenge.
///
/// Fields are private; read them through [`AuthChallengeParam::name`] and
/// [`AuthChallengeParam::value`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthChallengeParam {
    // Parameter name, stored with the casing found in the header.
    name: String,
    // Parameter value with surrounding quotes and backslash escapes removed.
    value: String,
}
78
79impl AuthChallengeParam {
80    /// Returns the case-preserving parameter name.
81    pub fn name(&self) -> &str {
82        &self.name
83    }
84
85    /// Returns the unquoted parameter value.
86    pub fn value(&self) -> &str {
87        &self.value
88    }
89}
90
/// Owned authentication challenge context passed to an [`Authenticator`].
///
/// It bundles the parts of the challenged request, the status and headers of
/// the challenging response, and the follow-up counters for the call.
pub struct AuthContext {
    // Whether this is an origin or a proxy challenge; selects the header that
    // `challenges` reads.
    kind: AuthKind,
    // Method of the challenged request.
    request_method: Method,
    // URI of the challenged request.
    request_uri: Uri,
    // HTTP version applied to requests rebuilt by `try_clone_request`.
    request_version: Version,
    // Headers of the challenged request.
    request_headers: HeaderMap,
    // Extensions copied onto requests rebuilt by `try_clone_request`.
    request_extensions: Extensions,
    // Body of the challenged request; `is_replayable` reports whether it is
    // present.
    request_body: Option<RequestBody>,
    // Status of the response that triggered authentication.
    response_status: StatusCode,
    // Headers of the response that triggered authentication.
    response_headers: HeaderMap,
    // Current total attempt number for the logical call.
    total_attempt: u32,
    // Retries accumulated before this auth decision.
    retry_count: u32,
    // Redirects accumulated before this auth decision.
    redirect_count: u32,
    // Auth follow-ups completed before this auth decision.
    auth_count: u32,
}
107
108impl AuthContext {
109    #[allow(clippy::too_many_arguments)]
110    pub fn new(
111        kind: AuthKind,
112        request_method: Method,
113        request_uri: Uri,
114        request_version: Version,
115        request_headers: HeaderMap,
116        request_extensions: Extensions,
117        request_body: Option<RequestBody>,
118        response_status: StatusCode,
119        response_headers: HeaderMap,
120        total_attempt: u32,
121        retry_count: u32,
122        redirect_count: u32,
123        auth_count: u32,
124    ) -> Self {
125        Self {
126            kind,
127            request_method,
128            request_uri,
129            request_version,
130            request_headers,
131            request_extensions,
132            request_body,
133            response_status,
134            response_headers,
135            total_attempt,
136            retry_count,
137            redirect_count,
138            auth_count,
139        }
140    }
141
142    /// Returns whether this is an origin or proxy authentication challenge.
143    pub fn kind(&self) -> AuthKind {
144        self.kind
145    }
146
147    /// Returns the method of the challenged request.
148    pub fn request_method(&self) -> &Method {
149        &self.request_method
150    }
151
152    /// Returns the URI of the challenged request.
153    pub fn request_uri(&self) -> &Uri {
154        &self.request_uri
155    }
156
157    /// Returns the request headers of the challenged request.
158    pub fn request_headers(&self) -> &HeaderMap {
159        &self.request_headers
160    }
161
162    /// Returns the response status that triggered authentication.
163    pub fn response_status(&self) -> StatusCode {
164        self.response_status
165    }
166
167    /// Returns the response headers that triggered authentication.
168    pub fn response_headers(&self) -> &HeaderMap {
169        &self.response_headers
170    }
171
172    /// Parses the applicable authentication challenges from the response.
173    ///
174    /// Origin contexts read `WWW-Authenticate`; proxy contexts read
175    /// `Proxy-Authenticate`. Invalid, non-UTF-8, or non-token challenge fields
176    /// are skipped.
177    pub fn challenges(&self) -> Vec<AuthChallenge> {
178        let header = match self.kind {
179            AuthKind::Origin => WWW_AUTHENTICATE,
180            AuthKind::Proxy => PROXY_AUTHENTICATE,
181        };
182        parse_auth_challenges(
183            self.response_headers
184                .get_all(header)
185                .iter()
186                .filter_map(|value| value.to_str().ok()),
187        )
188    }
189
190    /// Returns the current total attempt number for the logical call.
191    pub fn total_attempt(&self) -> u32 {
192        self.total_attempt
193    }
194
195    /// Returns the retry count accumulated before this auth decision.
196    pub fn retry_count(&self) -> u32 {
197        self.retry_count
198    }
199
200    /// Returns the redirect count accumulated before this auth decision.
201    pub fn redirect_count(&self) -> u32 {
202        self.redirect_count
203    }
204
205    /// Returns the completed auth follow-up count before this auth decision.
206    pub fn auth_count(&self) -> u32 {
207        self.auth_count
208    }
209
210    /// Returns whether the challenged request body can be replayed.
211    pub fn is_replayable(&self) -> bool {
212        self.request_body.is_some()
213    }
214
215    /// Clones the challenged request when its body is replayable.
216    pub fn try_clone_request(&self) -> Option<Request<RequestBody>> {
217        let body = self
218            .request_body
219            .as_ref()
220            .and_then(RequestBody::try_clone)?;
221        let mut request = Request::builder()
222            .method(self.request_method.clone())
223            .uri(self.request_uri.clone())
224            .version(self.request_version)
225            .body(body)
226            .ok()?;
227        *request.headers_mut() = self.request_headers.clone();
228        *request.extensions_mut() = self.request_extensions.clone();
229        Some(request)
230    }
231}
232
233fn parse_auth_challenges<'a>(values: impl IntoIterator<Item = &'a str>) -> Vec<AuthChallenge> {
234    let mut challenges = Vec::new();
235    for value in values {
236        challenges.extend(parse_auth_challenge_header(value));
237    }
238    challenges
239}
240
241fn parse_auth_challenge_header(value: &str) -> Vec<AuthChallenge> {
242    let mut challenges = Vec::new();
243    let mut current: Option<AuthChallenge> = None;
244
245    for part in split_top_level_commas(value) {
246        let part = part.trim();
247        if part.is_empty() {
248            continue;
249        }
250
251        if let Some(parameter) = parse_auth_param(part) {
252            if let Some(challenge) = current.as_mut() {
253                challenge.parameters.push(parameter);
254                continue;
255            }
256        }
257
258        if let Some(challenge) = current.take() {
259            challenges.push(challenge);
260        }
261        current = parse_challenge_start(part);
262    }
263
264    if let Some(challenge) = current {
265        challenges.push(challenge);
266    }
267
268    challenges
269}
270
271fn parse_challenge_start(value: &str) -> Option<AuthChallenge> {
272    let (scheme, rest) = parse_token(value)?;
273    if !rest.is_empty() && !rest.starts_with(char::is_whitespace) {
274        return None;
275    }
276
277    let rest = rest.trim();
278    let mut challenge = AuthChallenge {
279        scheme: scheme.to_owned(),
280        token68: None,
281        parameters: Vec::new(),
282    };
283    if rest.is_empty() {
284        return Some(challenge);
285    }
286
287    if is_token68(rest) {
288        challenge.token68 = Some(rest.to_owned());
289    } else if rest.contains('=') {
290        challenge.parameters.extend(parse_auth_params(rest));
291    }
292
293    Some(challenge)
294}
295
296fn parse_auth_params(value: &str) -> Vec<AuthChallengeParam> {
297    split_top_level_commas(value)
298        .into_iter()
299        .filter_map(parse_auth_param)
300        .collect()
301}
302
303fn parse_auth_param(value: &str) -> Option<AuthChallengeParam> {
304    let (name, rest) = parse_token(value.trim())?;
305    let rest = rest.trim_start();
306    let rest = rest.strip_prefix('=')?.trim_start();
307    let (value, remaining) = parse_auth_param_value(rest)?;
308    remaining.trim().is_empty().then_some(AuthChallengeParam {
309        name: name.to_owned(),
310        value,
311    })
312}
313
314fn parse_auth_param_value(value: &str) -> Option<(String, &str)> {
315    if value.starts_with('"') {
316        return parse_quoted_string(value);
317    }
318
319    let (value, rest) = parse_token(value)?;
320    Some((value.to_owned(), rest))
321}
322
/// Parses a quoted string from the start of `value`.
///
/// The first character is assumed to be the opening quote and is skipped
/// without being inspected. A backslash makes the following character
/// literal. Returns the unescaped contents and the text after the closing
/// quote, or `None` when the string is never closed.
fn parse_quoted_string(value: &str) -> Option<(String, &str)> {
    let mut unescaped = String::new();
    let mut chars = value.char_indices().skip(1);

    while let Some((offset, ch)) = chars.next() {
        match ch {
            // The closing quote is one byte wide, so the rest starts right after it.
            '"' => return Some((unescaped, &value[offset + 1..])),
            // A backslash at the very end leaves the string unterminated.
            '\\' => {
                let (_, literal) = chars.next()?;
                unescaped.push(literal);
            }
            other => unescaped.push(other),
        }
    }

    None
}
341
/// Splits `value` on commas that are not inside a double-quoted string.
///
/// Inside quotes a backslash escapes the next character, so an escaped quote
/// does not end the quoted section. Pieces are returned untrimmed and may be
/// empty; the result always holds at least one piece.
fn split_top_level_commas(value: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut piece_start = 0;
    let mut quoted = false;
    let mut chars = value.char_indices();

    while let Some((offset, ch)) = chars.next() {
        if quoted {
            match ch {
                // Consume the escaped character so it cannot close the quote.
                '\\' => {
                    let _ = chars.next();
                }
                '"' => quoted = false,
                _ => {}
            }
        } else if ch == '"' {
            quoted = true;
        } else if ch == ',' {
            pieces.push(&value[piece_start..offset]);
            // A comma is one byte wide, so the next piece starts right after it.
            piece_start = offset + 1;
        }
    }

    pieces.push(&value[piece_start..]);
    pieces
}
368
369fn parse_token(value: &str) -> Option<(&str, &str)> {
370    let end = value
371        .char_indices()
372        .take_while(|(_, ch)| is_token_char(*ch))
373        .map(|(index, ch)| index + ch.len_utf8())
374        .last()?;
375    Some((&value[..end], &value[end..]))
376}
377
/// Reports whether `ch` may appear in a token: an ASCII letter or digit, or
/// one of the punctuation characters ``!#$%&'*+-.^_`|~``.
fn is_token_char(ch: char) -> bool {
    const TOKEN_PUNCTUATION: [char; 15] = [
        '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~',
    ];

    ch.is_ascii_alphanumeric() || TOKEN_PUNCTUATION.contains(&ch)
}
398
/// Reports whether `value` has token68 shape: one or more characters from
/// `A-Z a-z 0-9 - . _ ~ + /`, followed by any number of `=` padding
/// characters and nothing else.
fn is_token68(value: &str) -> bool {
    // Strip the padding first; any `=` left over is in an invalid position
    // and is rejected by the character check below.
    let payload = value.trim_end_matches('=');

    !payload.is_empty()
        && payload.chars().all(|ch| {
            matches!(
                ch,
                'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '.' | '_' | '~' | '+' | '/'
            )
        })
}
416
// Unit tests for challenge parsing and for header selection in
// `AuthContext::challenges`.
#[cfg(test)]
mod tests {
    use http::header::{PROXY_AUTHENTICATE, WWW_AUTHENTICATE};
    use http::{HeaderMap, Method, Request, StatusCode, Version};

    use super::{is_token_char, parse_auth_challenge_header, AuthContext, AuthKind};
    use crate::RequestBody;

    // One field value carrying two challenges, including a quoted value with
    // escaped quotes, is split into two challenges with the right parameters.
    #[test]
    fn parses_rfc7235_multi_challenge_example() {
        let challenges = parse_auth_challenge_header(
            r#"Newauth realm="apps", type=1, title="Login to \"apps\"", Basic realm="simple""#,
        );

        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[0].scheme(), "Newauth");
        assert_eq!(challenges[0].realm(), Some("apps"));
        assert_eq!(challenges[0].parameter("type"), Some("1"));
        assert_eq!(challenges[0].parameter("title"), Some("Login to \"apps\""));
        assert_eq!(challenges[1].scheme(), "Basic");
        assert_eq!(challenges[1].realm(), Some("simple"));
    }

    // A token68 payload with padding is kept verbatim, and challenges from
    // repeated header fields are returned in field order.
    #[test]
    fn parses_token68_and_multiple_header_fields() {
        let mut headers = HeaderMap::new();
        headers.insert(WWW_AUTHENTICATE, "Bearer abcDEF123+/==".parse().unwrap());
        headers.append(WWW_AUTHENTICATE, r#"Basic realm="simple""#.parse().unwrap());
        let ctx = test_context(AuthKind::Origin, headers);

        let challenges = ctx.challenges();
        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[0].scheme(), "Bearer");
        assert_eq!(challenges[0].token68(), Some("abcDEF123+/=="));
        assert_eq!(challenges[1].scheme(), "Basic");
        assert_eq!(challenges[1].realm(), Some("simple"));
    }

    // A proxy context ignores `WWW-Authenticate` even when it is present.
    #[test]
    fn proxy_context_reads_proxy_authenticate_only() {
        let mut headers = HeaderMap::new();
        headers.insert(WWW_AUTHENTICATE, r#"Basic realm="origin""#.parse().unwrap());
        headers.insert(
            PROXY_AUTHENTICATE,
            r#"Digest realm="proxy", nonce="n""#.parse().unwrap(),
        );
        let ctx = test_context(AuthKind::Proxy, headers);

        let challenges = ctx.challenges();
        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Digest");
        assert_eq!(challenges[0].realm(), Some("proxy"));
        assert_eq!(challenges[0].parameter("nonce"), Some("n"));
    }

    // Commas inside quoted values do not split parameters or challenges.
    #[test]
    fn keeps_commas_inside_quoted_parameter_values() {
        let challenges =
            parse_auth_challenge_header(r#"Bearer realm="api, v1", scope="read,write""#);

        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Bearer");
        assert_eq!(challenges[0].realm(), Some("api, v1"));
        assert_eq!(challenges[0].parameter("scope"), Some("read,write"));
    }

    // A segment that cannot start a challenge is dropped, and a header value
    // that is not valid text yields no challenges at all.
    #[test]
    fn skips_malformed_and_non_utf8_challenge_fields() {
        let challenges = parse_auth_challenge_header(r#"=bad, Basic realm="simple""#);
        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Basic");

        let mut headers = HeaderMap::new();
        headers.insert(
            WWW_AUTHENTICATE,
            http::HeaderValue::from_bytes(b"\xff").unwrap(),
        );
        let ctx = test_context(AuthKind::Origin, headers);
        assert!(ctx.challenges().is_empty());
    }

    // `is_token_char` accepts exactly the listed token characters and rejects
    // the listed delimiters and whitespace.
    #[test]
    fn token_chars_match_http_tchar() {
        for ch in
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&'*+-.^_`|~".chars()
        {
            assert!(is_token_char(ch), "{ch:?} should be accepted");
        }

        for ch in "()<>@,;:\\\"/[]?={} \t\r\n".chars() {
            assert!(!is_token_char(ch), "{ch:?} should be rejected");
        }
    }

    // A scheme or parameter name containing `/` is rejected, while the
    // well-formed challenge between them is still returned.
    #[test]
    fn rejects_invalid_token_chars_in_challenge_scheme_and_param_names() {
        let challenges = parse_auth_challenge_header(
            r#"Bad/Scheme realm="ignored", Basic realm="simple", bad/name="ignored""#,
        );

        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme(), "Basic");
        assert_eq!(challenges[0].realm(), Some("simple"));
        assert_eq!(challenges[0].parameter("bad/name"), None);
    }

    // Builds a context for a GET to `http://example.com/` that was answered
    // with 401 and the given response headers; only the total attempt counter
    // is non-zero.
    fn test_context(kind: AuthKind, response_headers: HeaderMap) -> AuthContext {
        let request = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/")
            .body(RequestBody::empty())
            .expect("request");
        AuthContext::new(
            kind,
            request.method().clone(),
            request.uri().clone(),
            Version::HTTP_11,
            request.headers().clone(),
            request.extensions().clone(),
            request.body().try_clone(),
            StatusCode::UNAUTHORIZED,
            response_headers,
            1,
            0,
            0,
            0,
        )
    }
}