1use std::sync::Arc;
2
3use crate::errors::NanoGetError;
4use crate::request::Header;
5use crate::response::Response;
6use crate::url::Url;
7
/// Identifies which party issued an authentication challenge.
///
/// The target also determines which credential header answers it: the
/// built-in Basic handler in this module uses `Authorization` for
/// [`AuthTarget::Origin`] and `Proxy-Authorization` for [`AuthTarget::Proxy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthTarget {
    /// The origin server being requested.
    Origin,
    /// An HTTP proxy between the client and the origin.
    Proxy,
}
16
/// A single `name=value` parameter of an authentication challenge.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthParam {
    /// Parameter name exactly as it appeared in the header (case is not normalized).
    pub name: String,
    /// Parameter value. For quoted-string values the surrounding quotes are
    /// removed and backslash escapes are resolved by the parser in this module.
    pub value: String,
}
25
/// One authentication challenge parsed from a `WWW-Authenticate`-style header.
///
/// The parser in this module fills either `token68` or `params` (or neither,
/// for a bare scheme such as `Negotiate`), never both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Challenge {
    /// Authentication scheme as written (e.g. `Basic`); case is not normalized,
    /// so compare it case-insensitively.
    pub scheme: String,
    /// The token68 form of the challenge data, when present.
    pub token68: Option<String>,
    /// The auth-param form of the challenge data, in header order.
    pub params: Vec<AuthParam>,
}
36
/// Outcome of asking an [`AuthHandler`] to answer a set of challenges.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthDecision {
    /// Credential headers the handler wants sent in answer to the challenges
    /// (the built-in Basic handler returns a single `Authorization` or
    /// `Proxy-Authorization` header). Presumably the caller attaches them to a
    /// retried request — confirm in the caller.
    UseHeaders(Vec<Header>),
    /// The handler does not handle these challenges (for example, wrong target
    /// or unsupported scheme).
    NoMatch,
    /// The handler asks to stop authenticating; how this is surfaced is decided
    /// by the caller, outside this module.
    Abort,
}
47
/// Strategy for answering HTTP authentication challenges.
///
/// The trait is kept object-safe so handlers can be stored as
/// `Arc<dyn AuthHandler + Send + Sync>`.
pub trait AuthHandler {
    /// Decides how to answer `challenges` issued by `target`.
    ///
    /// * `target` – whether the challenges came from the origin or a proxy.
    /// * `url` – presumably the URL of the challenged request (confirm with caller).
    /// * `challenges` – challenges parsed from the response's authenticate headers.
    /// * `request` – presumably the request that was challenged (confirm with caller).
    /// * `response` – the challenging response.
    ///
    /// Return [`AuthDecision::NoMatch`] to decline, [`AuthDecision::UseHeaders`]
    /// to supply credentials, or [`AuthDecision::Abort`] to stop.
    ///
    /// # Errors
    ///
    /// Implementations may return any `NanoGetError`; for example, the built-in
    /// Basic handler propagates header-validation errors.
    fn respond(
        &self,
        target: AuthTarget,
        url: &Url,
        challenges: &[Challenge],
        request: &crate::request::Request,
        response: &Response,
    ) -> Result<AuthDecision, NanoGetError>;
}
63
/// Shared, thread-safe trait object used to store an [`AuthHandler`].
pub(crate) type DynAuthHandler = Arc<dyn AuthHandler + Send + Sync>;
65
/// Built-in [`AuthHandler`] that answers `Basic` challenges for a single
/// [`AuthTarget`] using precomputed credentials.
///
/// Only `Clone` is derived (no `Debug`), so the encoded credentials are not
/// exposed through `{:?}` formatting.
#[derive(Clone)]
pub(crate) struct BasicAuthHandler {
    // Full header value, `Basic <base64(username:password)>`, as built by `new`.
    header_value: String,
    // The only target this handler answers; other targets yield `NoMatch`.
    target: AuthTarget,
}
71
72impl BasicAuthHandler {
73 pub(crate) fn new(
74 username: impl Into<String>,
75 password: impl Into<String>,
76 target: AuthTarget,
77 ) -> Self {
78 Self {
79 header_value: basic_authorization_value(username.into(), password.into()),
80 target,
81 }
82 }
83
84 pub(crate) fn header_value(&self) -> &str {
85 &self.header_value
86 }
87}
88
89impl AuthHandler for BasicAuthHandler {
90 fn respond(
91 &self,
92 target: AuthTarget,
93 _url: &Url,
94 challenges: &[Challenge],
95 _request: &crate::request::Request,
96 _response: &Response,
97 ) -> Result<AuthDecision, NanoGetError> {
98 if target != self.target {
99 return Ok(AuthDecision::NoMatch);
100 }
101
102 if challenges
103 .iter()
104 .any(|challenge| challenge.scheme.eq_ignore_ascii_case("basic"))
105 {
106 let header_name = match target {
107 AuthTarget::Origin => "Authorization",
108 AuthTarget::Proxy => "Proxy-Authorization",
109 };
110 return Ok(AuthDecision::UseHeaders(vec![Header::new(
111 header_name,
112 self.header_value.clone(),
113 )?]));
114 }
115
116 Ok(AuthDecision::NoMatch)
117 }
118}
119
120pub(crate) fn basic_authorization_value(
121 username: impl Into<String>,
122 password: impl Into<String>,
123) -> String {
124 let credentials = format!("{}:{}", username.into(), password.into());
125 format!("Basic {}", base64_encode(credentials.as_bytes()))
126}
127
128pub(crate) fn parse_authenticate_headers(
129 headers: &[Header],
130 header_name: &str,
131) -> Result<Vec<Challenge>, NanoGetError> {
132 let values: Vec<&str> = headers
133 .iter()
134 .filter(|header| header.matches_name(header_name))
135 .map(Header::value)
136 .collect();
137
138 if values.is_empty() {
139 return Ok(Vec::new());
140 }
141
142 let mut challenges = Vec::new();
143 for value in values {
144 challenges.extend(parse_challenge_list(value)?);
145 }
146 Ok(challenges)
147}
148
149fn parse_challenge_list(value: &str) -> Result<Vec<Challenge>, NanoGetError> {
150 let bytes = value.as_bytes();
151 let mut index = 0usize;
152 let mut challenges = Vec::new();
153
154 while index < bytes.len() {
155 skip_ows_and_commas(bytes, &mut index);
156 if index >= bytes.len() {
157 break;
158 }
159
160 let scheme = parse_token(bytes, &mut index)
161 .ok_or_else(|| NanoGetError::MalformedChallenge(value.to_string()))?;
162 skip_spaces(bytes, &mut index);
163
164 let mut challenge = Challenge {
165 scheme,
166 token68: None,
167 params: Vec::new(),
168 };
169
170 if index < bytes.len() && bytes[index] != b',' {
171 if looks_like_auth_param(bytes, index) {
172 challenge.params = parse_auth_params(bytes, &mut index)?;
173 } else {
174 challenge.token68 = Some(parse_token68(bytes, &mut index)?);
175 }
176 }
177
178 challenges.push(challenge);
179
180 skip_spaces(bytes, &mut index);
181 if index < bytes.len() && bytes[index] == b',' {
182 index += 1;
183 }
184 }
185
186 Ok(challenges)
187}
188
189fn parse_auth_params(bytes: &[u8], index: &mut usize) -> Result<Vec<AuthParam>, NanoGetError> {
190 let mut params = Vec::new();
191
192 loop {
193 skip_spaces(bytes, index);
194 let name = parse_token(bytes, index).ok_or_else(|| {
195 NanoGetError::MalformedChallenge(String::from_utf8_lossy(bytes).into_owned())
196 })?;
197 skip_spaces(bytes, index);
198
199 if *index >= bytes.len() || bytes[*index] != b'=' {
200 return Err(NanoGetError::MalformedChallenge(
201 String::from_utf8_lossy(bytes).into_owned(),
202 ));
203 }
204 *index += 1;
205 skip_spaces(bytes, index);
206
207 let value = if *index < bytes.len() && bytes[*index] == b'"' {
208 parse_quoted_string(bytes, index)?
209 } else {
210 parse_token(bytes, index).ok_or_else(|| {
211 NanoGetError::MalformedChallenge(String::from_utf8_lossy(bytes).into_owned())
212 })?
213 };
214 params.push(AuthParam { name, value });
215
216 skip_spaces(bytes, index);
217 if *index >= bytes.len() || bytes[*index] != b',' {
218 break;
219 }
220
221 let lookahead = *index + 1;
222 let mut next_index = lookahead;
223 skip_spaces(bytes, &mut next_index);
224 if !looks_like_auth_param(bytes, next_index) {
225 break;
226 }
227 *index += 1;
228 }
229
230 Ok(params)
231}
232
233fn looks_like_auth_param(bytes: &[u8], mut index: usize) -> bool {
234 let token_start = index;
235 while index < bytes.len() && is_tchar(bytes[index]) {
236 index += 1;
237 }
238
239 if index == token_start {
240 return false;
241 }
242
243 while index < bytes.len() && bytes[index] == b' ' {
244 index += 1;
245 }
246
247 if index >= bytes.len() || bytes[index] != b'=' {
248 return false;
249 }
250
251 let mut after_equals = index + 1;
252 while after_equals < bytes.len() && bytes[after_equals] == b' ' {
253 after_equals += 1;
254 }
255
256 if after_equals >= bytes.len() {
257 return false;
258 }
259
260 if bytes[after_equals] == b'"' {
261 return true;
262 }
263
264 is_tchar(bytes[after_equals])
265}
266
267fn parse_token68(bytes: &[u8], index: &mut usize) -> Result<String, NanoGetError> {
268 let start = *index;
269 while *index < bytes.len() && is_token68(bytes[*index]) {
270 *index += 1;
271 }
272
273 if *index == start {
274 return Err(NanoGetError::MalformedChallenge(
275 String::from_utf8_lossy(bytes).into_owned(),
276 ));
277 }
278
279 Ok(String::from_utf8_lossy(&bytes[start..*index]).into_owned())
280}
281
282fn parse_token(bytes: &[u8], index: &mut usize) -> Option<String> {
283 let start = *index;
284 while *index < bytes.len() && is_tchar(bytes[*index]) {
285 *index += 1;
286 }
287
288 if *index == start {
289 None
290 } else {
291 Some(String::from_utf8_lossy(&bytes[start..*index]).into_owned())
292 }
293}
294
295fn parse_quoted_string(bytes: &[u8], index: &mut usize) -> Result<String, NanoGetError> {
296 if *index >= bytes.len() || bytes[*index] != b'"' {
297 return Err(NanoGetError::MalformedChallenge(
298 String::from_utf8_lossy(bytes).into_owned(),
299 ));
300 }
301 *index += 1;
302
303 let mut value = String::new();
304 while *index < bytes.len() {
305 match bytes[*index] {
306 b'\\' => {
307 *index += 1;
308 if *index >= bytes.len() {
309 return Err(NanoGetError::MalformedChallenge(
310 String::from_utf8_lossy(bytes).into_owned(),
311 ));
312 }
313 value.push(bytes[*index] as char);
314 *index += 1;
315 }
316 b'"' => {
317 *index += 1;
318 return Ok(value);
319 }
320 byte => {
321 value.push(byte as char);
322 *index += 1;
323 }
324 }
325 }
326
327 Err(NanoGetError::MalformedChallenge(
328 String::from_utf8_lossy(bytes).into_owned(),
329 ))
330}
331
/// Advances `*index` past optional whitespace (RFC 9110 `OWS`/`BWS`: SP or HTAB).
///
/// Previously only SP was skipped, so header values that legally used HTAB
/// around `=` or between list items were rejected as malformed. An `*index` at
/// or past the end of `bytes` is left unchanged.
fn skip_spaces(bytes: &[u8], index: &mut usize) {
    while matches!(bytes.get(*index), Some(b' ' | b'\t')) {
        *index += 1;
    }
}
337
/// Advances `*index` past any mix of whitespace (SP or HTAB) and commas.
///
/// Used between list elements, where RFC 9110 §5.6.1 allows optional
/// whitespace and requires recipients to ignore empty elements (`a, , b`).
/// HTAB is now treated as whitespace, consistent with OWS.
fn skip_ows_and_commas(bytes: &[u8], index: &mut usize) {
    while matches!(bytes.get(*index), Some(b' ' | b'\t' | b',')) {
        *index += 1;
    }
}
343
/// Reports whether `byte` is an RFC 9110 `tchar`: an ASCII letter or digit, or
/// one of the fifteen symbols `` !#$%&'*+-.^_`|~ ``.
fn is_tchar(byte: u8) -> bool {
    const SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    byte.is_ascii_alphanumeric() || SYMBOLS.contains(&byte)
}
363
/// Reports whether `byte` may appear in a token68 as accepted by this parser:
/// ASCII letters and digits, `-`, `.`, `_`, `~`, `+`, `/`, and `=`. `=` is
/// accepted in any position, not only as trailing padding.
fn is_token68(byte: u8) -> bool {
    matches!(
        byte,
        b'A'..=b'Z'
            | b'a'..=b'z'
            | b'0'..=b'9'
            | b'-'
            | b'.'
            | b'_'
            | b'~'
            | b'+'
            | b'/'
            | b'='
    )
}
367
/// Encodes `input` as standard (RFC 4648 §4) base64 with `=` padding.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(input.len() / 3 * 4 + 4);
    for group in input.chunks(3) {
        // Pack the 1-3 bytes into the top of a 24-bit word; missing bytes are zero.
        let bits = group
            .iter()
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
            << (8 * (3 - group.len()));
        let sextet = |shift: u32| char::from(ALPHABET[((bits >> shift) & 0x3f) as usize]);

        encoded.push(sextet(18));
        encoded.push(sextet(12));
        match group.len() {
            1 => encoded.push_str("=="),
            2 => {
                encoded.push(sextet(6));
                encoded.push('=');
            }
            _ => {
                encoded.push(sextet(6));
                encoded.push(sextet(0));
            }
        }
    }
    encoded
}
396
#[cfg(test)]
mod tests {
    // Unit tests for challenge parsing, Basic credential encoding, and the
    // `AuthHandler` implementations (built-in Basic and a local no-op).
    use std::sync::Arc;

    use super::{
        basic_authorization_value, looks_like_auth_param, parse_auth_params,
        parse_authenticate_headers, parse_quoted_string, parse_token, AuthDecision, AuthHandler,
        AuthTarget, BasicAuthHandler, Challenge,
    };
    use crate::errors::NanoGetError;
    use crate::request::{Header, Request};
    use crate::response::{HttpVersion, Response};
    use crate::url::Url;

    // A lone Basic challenge; the lookup name is lowercase while the header is
    // mixed-case, so this also expects case-insensitive name matching.
    #[test]
    fn parses_single_challenge() {
        let headers = vec![Header::unchecked("WWW-Authenticate", "Basic realm=\"api\"")];
        let challenges = parse_authenticate_headers(&headers, "www-authenticate").unwrap();
        assert_eq!(challenges.len(), 1);
        assert_eq!(challenges[0].scheme, "Basic");
        assert_eq!(challenges[0].params[0].name, "realm");
        assert_eq!(challenges[0].params[0].value, "api");
    }

    // Two challenges in one field: an auth-param challenge, then a token68 one.
    #[test]
    fn parses_multiple_challenges_in_one_field() {
        let headers = vec![Header::unchecked(
            "WWW-Authenticate",
            "Basic realm=\"api\", Bearer token68token",
        )];
        let challenges = parse_authenticate_headers(&headers, "www-authenticate").unwrap();
        assert_eq!(challenges.len(), 2);
        assert_eq!(challenges[1].scheme, "Bearer");
        assert_eq!(challenges[1].token68.as_deref(), Some("token68token"));
    }

    // Challenges split across repeated header fields are all collected.
    #[test]
    fn parses_multiple_header_fields() {
        let headers = vec![
            Header::unchecked("WWW-Authenticate", "Basic realm=\"one\""),
            Header::unchecked("WWW-Authenticate", "Digest realm=\"two\""),
        ];
        let challenges = parse_authenticate_headers(&headers, "www-authenticate").unwrap();
        assert_eq!(challenges.len(), 2);
    }

    // Commas inside quoted strings do not split the list; `\"` escapes are resolved.
    #[test]
    fn parses_quoted_commas_and_escapes() {
        let headers = vec![Header::unchecked(
            "WWW-Authenticate",
            "Digest realm=\"a,b\", title=\"say \\\"hi\\\"\"",
        )];
        let challenges = parse_authenticate_headers(&headers, "www-authenticate").unwrap();
        assert_eq!(challenges[0].params[0].value, "a,b");
        assert_eq!(challenges[0].params[1].value, "say \"hi\"");
    }

    // An unterminated quoted string fails the whole parse.
    #[test]
    fn rejects_malformed_challenges() {
        let headers = vec![Header::unchecked(
            "WWW-Authenticate",
            "Basic realm=\"unterminated",
        )];
        assert!(parse_authenticate_headers(&headers, "www-authenticate").is_err());
    }

    // Covers all three base64 padding cases (none, one `=`, two `=`).
    #[test]
    fn encodes_basic_auth_values() {
        assert_eq!(
            basic_authorization_value("user", "pass"),
            "Basic dXNlcjpwYXNz"
        );
        assert_eq!(basic_authorization_value("user", ""), "Basic dXNlcjo=");
        assert_eq!(basic_authorization_value("", ""), "Basic Og==");
    }

    #[test]
    fn basic_handler_matches_basic_challenges() {
        let handler = BasicAuthHandler::new("user", "pass", AuthTarget::Origin);
        let response = Response {
            version: HttpVersion::Http11,
            status_code: 401,
            reason_phrase: "Unauthorized".to_string(),
            headers: Vec::new(),
            trailers: Vec::new(),
            body: Vec::new(),
        };
        let decision = handler
            .respond(
                AuthTarget::Origin,
                &Url::parse("http://example.com").unwrap(),
                &[Challenge {
                    scheme: "Basic".to_string(),
                    token68: None,
                    params: Vec::new(),
                }],
                &Request::get("http://example.com").unwrap(),
                &response,
            )
            .unwrap();
        assert!(matches!(decision, AuthDecision::UseHeaders(_)));
    }

    // Builds the handler directly with a value containing a newline so that
    // header construction inside `respond` fails and the error is propagated.
    #[test]
    fn basic_handler_propagates_header_validation_errors() {
        let handler = BasicAuthHandler {
            header_value: "line\nbreak".to_string(),
            target: AuthTarget::Origin,
        };
        let response = Response {
            version: HttpVersion::Http11,
            status_code: 401,
            reason_phrase: "Unauthorized".to_string(),
            headers: Vec::new(),
            trailers: Vec::new(),
            body: Vec::new(),
        };
        let error = handler
            .respond(
                AuthTarget::Origin,
                &Url::parse("http://example.com").unwrap(),
                &[Challenge {
                    scheme: "Basic".to_string(),
                    token68: None,
                    params: Vec::new(),
                }],
                &Request::get("http://example.com").unwrap(),
                &response,
            )
            .unwrap_err();
        assert!(matches!(error, NanoGetError::InvalidHeaderValue(_)));
    }

    // The Basic handler declines both a mismatched target and a non-Basic scheme.
    #[test]
    fn basic_handler_returns_no_match_for_other_target_or_scheme() {
        let handler = BasicAuthHandler::new("user", "pass", AuthTarget::Origin);
        let response = Response {
            version: HttpVersion::Http11,
            status_code: 401,
            reason_phrase: "Unauthorized".to_string(),
            headers: Vec::new(),
            trailers: Vec::new(),
            body: Vec::new(),
        };
        let request = Request::get("http://example.com").unwrap();
        let url = Url::parse("http://example.com").unwrap();

        let wrong_target = handler
            .respond(
                AuthTarget::Proxy,
                &url,
                &[Challenge {
                    scheme: "Basic".to_string(),
                    token68: None,
                    params: Vec::new(),
                }],
                &request,
                &response,
            )
            .unwrap();
        assert!(matches!(wrong_target, AuthDecision::NoMatch));

        let wrong_scheme = handler
            .respond(
                AuthTarget::Origin,
                &url,
                &[Challenge {
                    scheme: "Digest".to_string(),
                    token68: None,
                    params: Vec::new(),
                }],
                &request,
                &response,
            )
            .unwrap();
        assert!(matches!(wrong_scheme, AuthDecision::NoMatch));
    }

    // Edge cases: no headers, trailing empty list elements, an invalid token68
    // character, and a bare scheme followed by another challenge.
    #[test]
    fn parse_headers_handles_empty_and_malformed_token68_cases() {
        let empty = parse_authenticate_headers(&[], "www-authenticate").unwrap();
        assert!(empty.is_empty());

        let trailing = vec![Header::unchecked(
            "WWW-Authenticate",
            "Basic realm=\"a\", ,",
        )];
        let challenges = parse_authenticate_headers(&trailing, "www-authenticate").unwrap();
        assert_eq!(challenges.len(), 1);

        let malformed = vec![Header::unchecked("WWW-Authenticate", "Bearer ?")];
        assert!(matches!(
            parse_authenticate_headers(&malformed, "www-authenticate"),
            Err(NanoGetError::MalformedChallenge(_))
        ));

        let bare_scheme = vec![Header::unchecked(
            "WWW-Authenticate",
            "Negotiate, Basic realm=\"api\"",
        )];
        let challenges = parse_authenticate_headers(&bare_scheme, "www-authenticate").unwrap();
        assert_eq!(challenges[0].scheme, "Negotiate");
        assert!(challenges[0].token68.is_none());
    }

    // Drives the private helpers directly to reach error branches that are
    // awkward to trigger through `parse_authenticate_headers`.
    #[test]
    fn private_parser_helpers_cover_error_paths() {
        let mut index = 0usize;
        let error = parse_auth_params(b"=oops", &mut index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        let mut index = 0usize;
        let error = parse_auth_params(b"realm x", &mut index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        let mut index = 5usize;
        let error = parse_auth_params(b"realm", &mut index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        let mut index = 0usize;
        let error = parse_auth_params(b"realm= ", &mut index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        let bytes = b"token= ";
        assert!(!looks_like_auth_param(bytes, 0));
        let bytes = b"token =\"x\"";
        assert!(looks_like_auth_param(bytes, 0));
        let bytes = b"token =!";
        assert!(looks_like_auth_param(bytes, 0));

        let mut token_index = 0usize;
        assert!(parse_token(b"=", &mut token_index).is_none());

        let mut quoted_index = 0usize;
        let error = parse_quoted_string(b"token", &mut quoted_index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));

        // A trailing backslash with nothing to escape is malformed.
        let mut escaped_index = 0usize;
        let error = parse_quoted_string(br#""unterminated\"#, &mut escaped_index).unwrap_err();
        assert!(matches!(error, NanoGetError::MalformedChallenge(_)));
    }

    // Minimal handler used to check object safety and the default-decline path.
    struct NoopHandler;

    impl AuthHandler for NoopHandler {
        fn respond(
            &self,
            _target: AuthTarget,
            _url: &Url,
            _challenges: &[Challenge],
            _request: &Request,
            _response: &Response,
        ) -> Result<AuthDecision, NanoGetError> {
            Ok(AuthDecision::NoMatch)
        }
    }

    // Compiles only if `AuthHandler` can be used as a `Send + Sync` trait object.
    #[test]
    fn auth_handlers_are_object_safe() {
        let _handler: Arc<dyn AuthHandler + Send + Sync> = Arc::new(NoopHandler);
    }

    #[test]
    fn noop_handler_returns_nomatch() {
        let handler = NoopHandler;
        let decision = handler
            .respond(
                AuthTarget::Origin,
                &Url::parse("http://example.com").unwrap(),
                &[],
                &Request::get("http://example.com").unwrap(),
                &Response {
                    version: HttpVersion::Http11,
                    status_code: 401,
                    reason_phrase: "Unauthorized".to_string(),
                    headers: Vec::new(),
                    trailers: Vec::new(),
                    body: Vec::new(),
                },
            )
            .unwrap();
        assert!(matches!(decision, AuthDecision::NoMatch));
    }
}