Skip to main content

nano_get/
auth.rs

1use std::sync::Arc;
2
3use crate::errors::NanoGetError;
4use crate::request::Header;
5use crate::response::Response;
6use crate::url::Url;
7
/// Selects which authentication space a challenge (and the credentials answering it) belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthTarget {
    /// The origin server: challenges arrive in `WWW-Authenticate`, credentials go in
    /// `Authorization`.
    Origin,
    /// An HTTP proxy: challenges arrive in `Proxy-Authenticate`, credentials go in
    /// `Proxy-Authorization`.
    Proxy,
}
16
/// One `name=value` auth-param taken from a challenge.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AuthParam {
    /// The parameter name exactly as it appeared in the header (case is preserved).
    pub name: String,
    /// The parameter value. For quoted-string values, the enclosing quotes are removed and
    /// backslash escapes are resolved.
    pub value: String,
}
25
26/// Parsed authentication challenge.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct Challenge {
29    /// Authentication scheme token, for example `Basic`.
30    pub scheme: String,
31    /// Optional token68 payload.
32    pub token68: Option<String>,
33    /// Optional list of auth-params.
34    pub params: Vec<AuthParam>,
35}
36
37/// Authentication handler result for a challenge set.
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub enum AuthDecision {
40    /// Retry the request with these headers.
41    UseHeaders(Vec<Header>),
42    /// Do not handle this challenge.
43    NoMatch,
44    /// Stop authentication processing and return an error.
45    Abort,
46}
47
48/// Callback interface for custom authentication schemes.
49pub trait AuthHandler {
50    /// Chooses how to respond to a challenge set.
51    ///
52    /// Return [`AuthDecision::UseHeaders`] to retry with credentials, [`AuthDecision::NoMatch`]
53    /// to leave the response unchanged, or [`AuthDecision::Abort`] to stop with an error.
54    fn respond(
55        &self,
56        target: AuthTarget,
57        url: &Url,
58        challenges: &[Challenge],
59        request: &crate::request::Request,
60        response: &Response,
61    ) -> Result<AuthDecision, NanoGetError>;
62}
63
64pub(crate) type DynAuthHandler = Arc<dyn AuthHandler + Send + Sync>;
65
66#[derive(Clone)]
67pub(crate) struct BasicAuthHandler {
68    header_value: String,
69    target: AuthTarget,
70}
71
72impl BasicAuthHandler {
73    pub(crate) fn new(
74        username: impl Into<String>,
75        password: impl Into<String>,
76        target: AuthTarget,
77    ) -> Self {
78        Self {
79            header_value: basic_authorization_value(username.into(), password.into()),
80            target,
81        }
82    }
83
84    pub(crate) fn header_value(&self) -> &str {
85        &self.header_value
86    }
87}
88
89impl AuthHandler for BasicAuthHandler {
90    fn respond(
91        &self,
92        target: AuthTarget,
93        _url: &Url,
94        challenges: &[Challenge],
95        _request: &crate::request::Request,
96        _response: &Response,
97    ) -> Result<AuthDecision, NanoGetError> {
98        if target != self.target {
99            return Ok(AuthDecision::NoMatch);
100        }
101
102        if challenges
103            .iter()
104            .any(|challenge| challenge.scheme.eq_ignore_ascii_case("basic"))
105        {
106            let header_name = match target {
107                AuthTarget::Origin => "Authorization",
108                AuthTarget::Proxy => "Proxy-Authorization",
109            };
110            return Ok(AuthDecision::UseHeaders(vec![Header::new(
111                header_name,
112                self.header_value.clone(),
113            )?]));
114        }
115
116        Ok(AuthDecision::NoMatch)
117    }
118}
119
120pub(crate) fn basic_authorization_value(
121    username: impl Into<String>,
122    password: impl Into<String>,
123) -> String {
124    let credentials = format!("{}:{}", username.into(), password.into());
125    format!("Basic {}", base64_encode(credentials.as_bytes()))
126}
127
128pub(crate) fn parse_authenticate_headers(
129    headers: &[Header],
130    header_name: &str,
131) -> Result<Vec<Challenge>, NanoGetError> {
132    let values: Vec<&str> = headers
133        .iter()
134        .filter(|header| header.matches_name(header_name))
135        .map(Header::value)
136        .collect();
137
138    if values.is_empty() {
139        return Ok(Vec::new());
140    }
141
142    let mut challenges = Vec::new();
143    for value in values {
144        challenges.extend(parse_challenge_list(value)?);
145    }
146    Ok(challenges)
147}
148
149fn parse_challenge_list(value: &str) -> Result<Vec<Challenge>, NanoGetError> {
150    let bytes = value.as_bytes();
151    let mut index = 0usize;
152    let mut challenges = Vec::new();
153
154    while index < bytes.len() {
155        skip_ows_and_commas(bytes, &mut index);
156        if index >= bytes.len() {
157            break;
158        }
159
160        let scheme = parse_token(bytes, &mut index)
161            .ok_or_else(|| NanoGetError::MalformedChallenge(value.to_string()))?;
162        skip_spaces(bytes, &mut index);
163
164        let mut challenge = Challenge {
165            scheme,
166            token68: None,
167            params: Vec::new(),
168        };
169
170        if index < bytes.len() && bytes[index] != b',' {
171            if looks_like_auth_param(bytes, index) {
172                challenge.params = parse_auth_params(bytes, &mut index)?;
173            } else {
174                challenge.token68 = Some(parse_token68(bytes, &mut index)?);
175            }
176        }
177
178        challenges.push(challenge);
179
180        skip_spaces(bytes, &mut index);
181        if index < bytes.len() && bytes[index] == b',' {
182            index += 1;
183        }
184    }
185
186    Ok(challenges)
187}
188
189fn parse_auth_params(bytes: &[u8], index: &mut usize) -> Result<Vec<AuthParam>, NanoGetError> {
190    let mut params = Vec::new();
191
192    loop {
193        skip_spaces(bytes, index);
194        let name = parse_token(bytes, index).ok_or_else(|| {
195            NanoGetError::MalformedChallenge(String::from_utf8_lossy(bytes).into_owned())
196        })?;
197        skip_spaces(bytes, index);
198
199        if *index >= bytes.len() || bytes[*index] != b'=' {
200            return Err(NanoGetError::MalformedChallenge(
201                String::from_utf8_lossy(bytes).into_owned(),
202            ));
203        }
204        *index += 1;
205        skip_spaces(bytes, index);
206
207        let value = if *index < bytes.len() && bytes[*index] == b'"' {
208            parse_quoted_string(bytes, index)?
209        } else {
210            parse_token(bytes, index).ok_or_else(|| {
211                NanoGetError::MalformedChallenge(String::from_utf8_lossy(bytes).into_owned())
212            })?
213        };
214        params.push(AuthParam { name, value });
215
216        skip_spaces(bytes, index);
217        if *index >= bytes.len() || bytes[*index] != b',' {
218            break;
219        }
220
221        let lookahead = *index + 1;
222        let mut next_index = lookahead;
223        skip_spaces(bytes, &mut next_index);
224        if !looks_like_auth_param(bytes, next_index) {
225            break;
226        }
227        *index += 1;
228    }
229
230    Ok(params)
231}
232
233fn looks_like_auth_param(bytes: &[u8], mut index: usize) -> bool {
234    let token_start = index;
235    while index < bytes.len() && is_tchar(bytes[index]) {
236        index += 1;
237    }
238
239    if index == token_start {
240        return false;
241    }
242
243    while index < bytes.len() && bytes[index] == b' ' {
244        index += 1;
245    }
246
247    if index >= bytes.len() || bytes[index] != b'=' {
248        return false;
249    }
250
251    let mut after_equals = index + 1;
252    while after_equals < bytes.len() && bytes[after_equals] == b' ' {
253        after_equals += 1;
254    }
255
256    if after_equals >= bytes.len() {
257        return false;
258    }
259
260    if bytes[after_equals] == b'"' {
261        return true;
262    }
263
264    is_tchar(bytes[after_equals])
265}
266
267fn parse_token68(bytes: &[u8], index: &mut usize) -> Result<String, NanoGetError> {
268    let start = *index;
269    while *index < bytes.len() && is_token68(bytes[*index]) {
270        *index += 1;
271    }
272
273    if *index == start {
274        return Err(NanoGetError::MalformedChallenge(
275            String::from_utf8_lossy(bytes).into_owned(),
276        ));
277    }
278
279    Ok(String::from_utf8_lossy(&bytes[start..*index]).into_owned())
280}
281
282fn parse_token(bytes: &[u8], index: &mut usize) -> Option<String> {
283    let start = *index;
284    while *index < bytes.len() && is_tchar(bytes[*index]) {
285        *index += 1;
286    }
287
288    if *index == start {
289        None
290    } else {
291        Some(String::from_utf8_lossy(&bytes[start..*index]).into_owned())
292    }
293}
294
295fn parse_quoted_string(bytes: &[u8], index: &mut usize) -> Result<String, NanoGetError> {
296    if *index >= bytes.len() || bytes[*index] != b'"' {
297        return Err(NanoGetError::MalformedChallenge(
298            String::from_utf8_lossy(bytes).into_owned(),
299        ));
300    }
301    *index += 1;
302
303    let mut value = String::new();
304    while *index < bytes.len() {
305        match bytes[*index] {
306            b'\\' => {
307                *index += 1;
308                if *index >= bytes.len() {
309                    return Err(NanoGetError::MalformedChallenge(
310                        String::from_utf8_lossy(bytes).into_owned(),
311                    ));
312                }
313                value.push(bytes[*index] as char);
314                *index += 1;
315            }
316            b'"' => {
317                *index += 1;
318                return Ok(value);
319            }
320            byte => {
321                value.push(byte as char);
322                *index += 1;
323            }
324        }
325    }
326
327    Err(NanoGetError::MalformedChallenge(
328        String::from_utf8_lossy(bytes).into_owned(),
329    ))
330}
331
/// Advances `*index` past optional whitespace: `OWS` / `BWS` = `*( SP / HTAB )` (RFC 9110 §5.6.3).
///
/// HTAB is legal wherever HTTP allows optional whitespace. Skipping only SP made the parser
/// reject valid fields that used tabs.
fn skip_spaces(bytes: &[u8], index: &mut usize) {
    while *index < bytes.len() && matches!(bytes[*index], b' ' | b'\t') {
        *index += 1;
    }
}
337
/// Advances `*index` past optional whitespace (SP / HTAB) and commas.
///
/// This skips the separators and empty elements of an HTTP `#rule` list (RFC 9110 §5.6.1).
/// A tab after a comma (`Basic realm="a",\tBearer x`) is valid OWS and must not stop the
/// scan.
fn skip_ows_and_commas(bytes: &[u8], index: &mut usize) {
    while *index < bytes.len() && matches!(bytes[*index], b' ' | b'\t' | b',') {
        *index += 1;
    }
}
343
/// Reports whether `byte` is an RFC 9110 `tchar`: an ASCII letter or digit, or one of the
/// fifteen listed symbols.
fn is_tchar(byte: u8) -> bool {
    const SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    byte.is_ascii_alphanumeric() || SYMBOLS.contains(&byte)
}
363
/// Reports whether `byte` may appear in this parser's token68 runs: ASCII letters and digits,
/// `-`, `.`, `_`, `~`, `+`, `/`, and `=`.
fn is_token68(byte: u8) -> bool {
    match byte {
        b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' => true,
        b'-' | b'.' | b'_' | b'~' | b'+' | b'/' | b'=' => true,
        _ => false,
    }
}
367
/// Encodes `input` with the standard base64 alphabet (RFC 4648 §4), `=`-padded.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Pick the 6-bit group that sits `shift` bits up inside a 24-bit block.
    let sextet = |block: u32, shift: u32| ALPHABET[((block >> shift) & 0x3f) as usize] as char;

    // Every 3 input bytes (including a final partial group) become exactly 4 output chars.
    let mut encoded = String::with_capacity((input.len() + 2) / 3 * 4);
    for group in input.chunks(3) {
        let mut padded = [0u8; 3];
        padded[..group.len()].copy_from_slice(group);
        let block =
            u32::from(padded[0]) << 16 | u32::from(padded[1]) << 8 | u32::from(padded[2]);

        encoded.push(sextet(block, 18));
        encoded.push(sextet(block, 12));
        encoded.push(if group.len() >= 2 { sextet(block, 6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(block, 0) } else { '=' });
    }
    encoded
}
396
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::{
        basic_authorization_value, looks_like_auth_param, parse_auth_params,
        parse_authenticate_headers, parse_quoted_string, parse_token, AuthDecision, AuthHandler,
        AuthTarget, BasicAuthHandler, Challenge,
    };
    use crate::errors::NanoGetError;
    use crate::request::{Header, Request};
    use crate::response::{HttpVersion, Response};
    use crate::url::Url;

    /// A bodiless `401 Unauthorized` response used as the challenging response.
    fn unauthorized() -> Response {
        Response {
            version: HttpVersion::Http11,
            status_code: 401,
            reason_phrase: "Unauthorized".to_string(),
            headers: Vec::new(),
            trailers: Vec::new(),
            body: Vec::new(),
        }
    }

    /// A challenge that carries only a scheme: no token68 and no params.
    fn bare_challenge(scheme: &str) -> Challenge {
        Challenge {
            scheme: scheme.to_string(),
            token68: None,
            params: Vec::new(),
        }
    }

    /// Calls `handler` for `target` on `http://example.com` with one bare `scheme` challenge.
    fn respond_with(
        handler: &dyn AuthHandler,
        target: AuthTarget,
        scheme: &str,
    ) -> Result<AuthDecision, NanoGetError> {
        handler.respond(
            target,
            &Url::parse("http://example.com").unwrap(),
            &[bare_challenge(scheme)],
            &Request::get("http://example.com").unwrap(),
            &unauthorized(),
        )
    }

    /// Parses a single `WWW-Authenticate` field carrying `value`.
    fn parse_www_authenticate(value: &'static str) -> Result<Vec<Challenge>, NanoGetError> {
        parse_authenticate_headers(
            &[Header::unchecked("WWW-Authenticate", value)],
            "www-authenticate",
        )
    }

    #[test]
    fn parses_single_challenge() {
        let challenges = parse_www_authenticate("Basic realm=\"api\"").unwrap();
        assert_eq!(challenges.len(), 1);
        let basic = &challenges[0];
        assert_eq!(basic.scheme, "Basic");
        assert_eq!(basic.params[0].name, "realm");
        assert_eq!(basic.params[0].value, "api");
    }

    #[test]
    fn parses_multiple_challenges_in_one_field() {
        let challenges =
            parse_www_authenticate("Basic realm=\"api\", Bearer token68token").unwrap();
        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[1].scheme, "Bearer");
        assert_eq!(challenges[1].token68.as_deref(), Some("token68token"));
    }

    #[test]
    fn parses_multiple_header_fields() {
        let headers = [
            Header::unchecked("WWW-Authenticate", "Basic realm=\"one\""),
            Header::unchecked("WWW-Authenticate", "Digest realm=\"two\""),
        ];
        let challenges = parse_authenticate_headers(&headers, "www-authenticate").unwrap();
        assert_eq!(challenges.len(), 2);
    }

    #[test]
    fn parses_quoted_commas_and_escapes() {
        let challenges =
            parse_www_authenticate("Digest realm=\"a,b\", title=\"say \\\"hi\\\"\"").unwrap();
        let params = &challenges[0].params;
        assert_eq!(params[0].value, "a,b");
        assert_eq!(params[1].value, "say \"hi\"");
    }

    #[test]
    fn rejects_malformed_challenges() {
        assert!(parse_www_authenticate("Basic realm=\"unterminated").is_err());
    }

    #[test]
    fn encodes_basic_auth_values() {
        let cases = [
            ("user", "pass", "Basic dXNlcjpwYXNz"),
            ("user", "", "Basic dXNlcjo="),
            ("", "", "Basic Og=="),
        ];
        for &(username, password, expected) in &cases {
            assert_eq!(basic_authorization_value(username, password), expected);
        }
    }

    #[test]
    fn basic_handler_matches_basic_challenges() {
        let handler = BasicAuthHandler::new("user", "pass", AuthTarget::Origin);
        let decision = respond_with(&handler, AuthTarget::Origin, "Basic").unwrap();
        assert!(matches!(decision, AuthDecision::UseHeaders(_)));
    }

    #[test]
    fn basic_handler_propagates_header_validation_errors() {
        let handler = BasicAuthHandler {
            header_value: "line\nbreak".to_string(),
            target: AuthTarget::Origin,
        };
        let error = respond_with(&handler, AuthTarget::Origin, "Basic").unwrap_err();
        assert!(matches!(error, NanoGetError::InvalidHeaderValue(_)));
    }

    #[test]
    fn basic_handler_returns_no_match_for_other_target_or_scheme() {
        let handler = BasicAuthHandler::new("user", "pass", AuthTarget::Origin);

        let wrong_target = respond_with(&handler, AuthTarget::Proxy, "Basic").unwrap();
        assert!(matches!(wrong_target, AuthDecision::NoMatch));

        let wrong_scheme = respond_with(&handler, AuthTarget::Origin, "Digest").unwrap();
        assert!(matches!(wrong_scheme, AuthDecision::NoMatch));
    }

    #[test]
    fn parse_headers_handles_empty_and_malformed_token68_cases() {
        let empty = parse_authenticate_headers(&[], "www-authenticate").unwrap();
        assert!(empty.is_empty());

        let trailing = parse_www_authenticate("Basic realm=\"a\", ,").unwrap();
        assert_eq!(trailing.len(), 1);

        assert!(matches!(
            parse_www_authenticate("Bearer ?"),
            Err(NanoGetError::MalformedChallenge(_))
        ));

        let bare_scheme = parse_www_authenticate("Negotiate, Basic realm=\"api\"").unwrap();
        assert_eq!(bare_scheme[0].scheme, "Negotiate");
        assert!(bare_scheme[0].token68.is_none());
    }

    #[test]
    fn private_parser_helpers_cover_error_paths() {
        let auth_param_failures = [
            (&b"=oops"[..], 0usize),
            (&b"realm x"[..], 0),
            (&b"realm"[..], 5),
            (&b"realm= "[..], 0),
        ];
        for &(input, start) in &auth_param_failures {
            let mut index = start;
            let error = parse_auth_params(input, &mut index).unwrap_err();
            assert!(matches!(error, NanoGetError::MalformedChallenge(_)));
        }

        assert!(!looks_like_auth_param(b"token=   ", 0));
        assert!(looks_like_auth_param(b"token =\"x\"", 0));
        assert!(looks_like_auth_param(b"token =!", 0));

        let mut token_index = 0usize;
        assert!(parse_token(b"=", &mut token_index).is_none());

        let mut quoted_index = 0usize;
        let error = parse_quoted_string(b"token", &mut quoted_index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        let mut escaped_index = 0usize;
        let error = parse_quoted_string(br#""unterminated\"#, &mut escaped_index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));
    }

    /// Handler that never matches; used to check trait-object support and the default path.
    struct NoopHandler;

    impl AuthHandler for NoopHandler {
        fn respond(
            &self,
            _target: AuthTarget,
            _url: &Url,
            _challenges: &[Challenge],
            _request: &Request,
            _response: &Response,
        ) -> Result<AuthDecision, NanoGetError> {
            Ok(AuthDecision::NoMatch)
        }
    }

    #[test]
    fn auth_handlers_are_object_safe() {
        let _handler: Arc<dyn AuthHandler + Send + Sync> = Arc::new(NoopHandler);
    }

    #[test]
    fn noop_handler_returns_nomatch() {
        let decision = NoopHandler
            .respond(
                AuthTarget::Origin,
                &Url::parse("http://example.com").unwrap(),
                &[],
                &Request::get("http://example.com").unwrap(),
                &unauthorized(),
            )
            .unwrap();
        assert!(matches!(decision, AuthDecision::NoMatch));
    }
}