use bytes::Bytes;
use super::auth_types::{AuthBlockConfidence, AuthChallenge, AuthErrorBodySignal, AuthScheme};
/// Largest error body, in bytes, that `parse_auth_error_body` will inspect;
/// longer bodies are ignored without being parsed.
const BODY_CAP_BYTES: usize = 8_192;
/// Parses a `WWW-Authenticate` header value into an [`AuthChallenge`].
///
/// The leading token is the auth scheme (mapped case-insensitively by
/// `scheme_from_token`). The `name=value` parameters that follow are scanned,
/// and `realm`, `error`, `error_description` and `scope` are extracted.
/// Parameter names are matched case-insensitively; if a name repeats, the
/// last occurrence wins. Other parameters are ignored.
///
/// Parameter scanning stops at the first item that is not a `name=value`
/// pair. In practice only the first challenge of a multi-challenge header is
/// captured.
///
/// Returns `None` for a blank header value.
#[must_use]
pub fn parse_www_authenticate(header_value: &str) -> Option<AuthChallenge> {
    let header_value = header_value.trim();
    if header_value.is_empty() {
        return None;
    }
    let (token, param_text) = split_scheme(header_value)?;

    let mut realm = None;
    let mut error = None;
    let mut error_description = None;
    let mut scope = None;
    for (name, value) in parse_challenge_params(param_text) {
        let slot = match name.as_str() {
            "realm" => &mut realm,
            "error" => &mut error,
            "error_description" => &mut error_description,
            "scope" => &mut scope,
            _ => continue,
        };
        // Overwrite on repeat so the last occurrence wins.
        *slot = Some(value);
    }

    Some(AuthChallenge {
        scheme: scheme_from_token(token),
        realm,
        error,
        error_description,
        scope,
    })
}
/// Splits `s` at its first whitespace character into `(scheme, rest)`.
///
/// Both parts are trimmed, and `rest` is empty when `s` has no whitespace.
/// Returns `None` when the scheme part is empty, for example when `s` is
/// empty or starts with whitespace.
fn split_scheme(s: &str) -> Option<(&str, &str)> {
    let (head, tail) = s.split_once(char::is_whitespace).unwrap_or((s, ""));
    let scheme = head.trim();
    (!scheme.is_empty()).then_some((scheme, tail.trim()))
}
/// Maps an auth-scheme token to [`AuthScheme`], ignoring ASCII case.
///
/// `bearer`, `basic` and `digest` are recognised. Any other token maps to
/// `AuthScheme::Other`.
fn scheme_from_token(token: &str) -> AuthScheme {
    if token.eq_ignore_ascii_case("bearer") {
        AuthScheme::Bearer
    } else if token.eq_ignore_ascii_case("basic") {
        AuthScheme::Basic
    } else if token.eq_ignore_ascii_case("digest") {
        AuthScheme::Digest
    } else {
        AuthScheme::Other
    }
}
/// Parses the `name=value` auth-params that follow the scheme in a challenge.
///
/// Pairs may be separated by commas and/or whitespace. Values may be bare
/// tokens or quoted strings; quoted strings are unescaped by
/// `read_quoted_string`. Optional whitespace is accepted on both sides of `=`,
/// as the RFC 7235 grammar allows (`token BWS "=" BWS value`). Names are
/// ASCII-lowercased.
///
/// Parsing stops at the first item that is not a `name=value` pair, such as
/// the scheme token of a following challenge, and the pairs read so far are
/// returned.
fn parse_challenge_params(rest: &str) -> Vec<(String, String)> {
    let mut out = Vec::new();
    let mut cursor = rest;
    loop {
        cursor = cursor.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
        if cursor.is_empty() {
            break;
        }
        let Some((key, after_key)) = read_param_key(cursor) else {
            break;
        };
        // Whitespace before "=" is legal (`realm = "x"`). Without trimming here
        // such a parameter, and every one after it, would be silently dropped.
        // A following challenge ("Basic realm=...") still stops parsing,
        // because its scheme is followed by another token rather than "=".
        let Some(after_eq) = after_key.trim_start().strip_prefix('=') else {
            break;
        };
        let (value, after_value) = read_param_value(after_eq);
        out.push((key.to_ascii_lowercase(), value));
        cursor = after_value;
    }
    out
}
/// Splits off the parameter name at the start of `s`.
///
/// The name is everything up to the first `=` or whitespace character.
/// Returns `(name, remainder)`, where the remainder begins at that delimiter,
/// or `None` if the name would be empty.
fn read_param_key(s: &str) -> Option<(&str, &str)> {
    let is_delim = |c: char| c == '=' || c.is_whitespace();
    match s.find(is_delim) {
        Some(0) => None,
        Some(idx) => Some(s.split_at(idx)),
        None if s.is_empty() => None,
        None => Some((s, "")),
    }
}
/// Reads one auth-param value, skipping leading whitespace first.
///
/// A value that starts with `"` is read as a quoted string by
/// `read_quoted_string`. Otherwise the value is a bare token that runs until
/// the next `,` or whitespace character. Returns `(value, remainder)`.
fn read_param_value(s: &str) -> (String, &str) {
    let s = s.trim_start();
    match s.strip_prefix('"') {
        Some(quoted) => read_quoted_string(quoted),
        None => {
            let token_end = s
                .find(|c: char| c == ',' || c.is_whitespace())
                .unwrap_or(s.len());
            let (token, rest) = s.split_at(token_end);
            (String::from(token), rest)
        }
    }
}
/// Reads the body of a quoted string whose opening `"` has already been
/// consumed.
///
/// A backslash makes the next character literal; a trailing lone backslash is
/// dropped. Returns the unescaped contents and the text after the closing
/// quote. If no closing quote is found, the whole input is consumed and the
/// remainder is empty.
fn read_quoted_string(s: &str) -> (String, &str) {
    let mut value = String::new();
    let mut escaped = false;
    for (idx, ch) in s.char_indices() {
        if escaped {
            value.push(ch);
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return (value, &s[idx + ch.len_utf8()..]);
        } else {
            value.push(ch);
        }
    }
    (value, "")
}
/// Tries to extract an authentication-failure signal from an HTTP error body.
///
/// Only non-empty bodies of at most `BODY_CAP_BYTES` bytes are inspected;
/// larger ones are ignored without being parsed. The body format is chosen
/// from the media type in `content_type`. Parameters such as `charset` and
/// surrounding whitespace are ignored, and matching is ASCII
/// case-insensitive:
///
/// * `application/json`, or any type with the `+json` structured-syntax
///   suffix (e.g. `application/problem+json`), is parsed as JSON.
/// * `application/x-www-form-urlencoded` is parsed as `key=value` pairs.
///
/// Returns `None` when:
///
/// * the content type is missing or unsupported,
/// * the body is empty or over the cap, or
/// * no recognised auth error code is found.
#[must_use]
pub fn parse_auth_error_body(
    content_type: Option<&str>,
    body: &Bytes,
) -> Option<AuthErrorBodySignal> {
    if body.is_empty() || body.len() > BODY_CAP_BYTES {
        return None;
    }
    // Compare the media-type essence exactly rather than by raw prefix.
    // The raw prefix wrongly accepted e.g. `application/jsonfoo` and rejected
    // values with leading whitespace.
    let essence = content_type?
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    if essence == "application/json" || essence.ends_with("+json") {
        return parse_json_error(body);
    }
    if essence == "application/x-www-form-urlencoded" {
        return parse_form_error(body);
    }
    None
}
fn parse_json_error(body: &Bytes) -> Option<AuthErrorBodySignal> {
let v: serde_json::Value = serde_json::from_slice(body).ok()?;
let candidates = ["error", "code", "message", "error_description"];
for key in candidates {
if let Some(s) = v.get(key).and_then(serde_json::Value::as_str) {
if let Some(sig) = recognise_auth_error(s) {
return Some(sig);
}
}
}
None
}
/// Looks for a recognised auth error code in an
/// `application/x-www-form-urlencoded` body.
///
/// Pairs are separated by `&`, and segments without `=` are skipped. Keys are
/// consulted in priority order: every `error` pair first, then every `code`
/// pair. This mirrors the field priority used for JSON bodies, so an
/// OAuth-style `error` value wins over a generic `code` wherever it appears
/// in the body. Keys and values are compared as raw text; no percent-decoding
/// is performed.
///
/// Returns `None` for non-UTF-8 bodies or when no value is recognised.
fn parse_form_error(body: &Bytes) -> Option<AuthErrorBodySignal> {
    let s = std::str::from_utf8(body).ok()?;
    for wanted in ["error", "code"] {
        let hit = s
            .split('&')
            .filter_map(|pair| pair.split_once('='))
            .filter(|&(key, _)| key == wanted)
            .find_map(|(_, value)| recognise_auth_error(value));
        if hit.is_some() {
            return hit;
        }
    }
    None
}
/// Classifies an error code string as an auth-failure signal.
///
/// The code is trimmed and ASCII-lowercased, then looked up in fixed
/// vocabularies:
///
/// * Strong: token/credential/grant failures.
/// * Medium: generic "unauthorized"-style codes.
/// * Weak: `forbidden`.
///
/// Returns `None` for anything else. On success, the signal's `code` is the
/// normalised (trimmed, lowercased) string.
#[must_use]
fn recognise_auth_error(code: &str) -> Option<AuthErrorBodySignal> {
    const STRONG: &[&str] = &[
        "invalid_token",
        "expired_token",
        "insufficient_scope",
        "invalid_credentials",
        "invalid_grant",
    ];
    const MEDIUM: &[&str] = &[
        "unauthorized",
        "unauthenticated",
        "authentication_required",
        "login_required",
        "access_denied",
    ];
    const WEAK: &[&str] = &["forbidden"];

    let normalised = code.trim().to_ascii_lowercase();
    let key = normalised.as_str();
    let confidence = if STRONG.contains(&key) {
        AuthBlockConfidence::Strong
    } else if MEDIUM.contains(&key) {
        AuthBlockConfidence::Medium
    } else if WEAK.contains(&key) {
        AuthBlockConfidence::Weak
    } else {
        return None;
    };
    Some(AuthErrorBodySignal {
        code: normalised,
        confidence,
    })
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use proptest::prelude::*;

    use super::{parse_auth_error_body, AuthBlockConfidence};

    /// Wraps a string as a response body.
    fn form_body(s: &str) -> Bytes {
        Bytes::copy_from_slice(s.as_bytes())
    }

    /// Content type that routes bodies to the form-urlencoded parser.
    fn form_ct() -> &'static str {
        "application/x-www-form-urlencoded"
    }

    #[test]
    fn form_error_malformed_leading_segment_skipped() {
        let body = form_body("foo&error=invalid_token");
        let sig = parse_auth_error_body(Some(form_ct()), &body)
            .expect("expected Some, got None for body 'foo&error=invalid_token'");
        assert_eq!(sig.code, "invalid_token");
        assert_eq!(sig.confidence, AuthBlockConfidence::Strong);
    }

    #[test]
    fn form_error_only_malformed_segments_returns_none() {
        let body = form_body("foo&bar");
        let result = parse_auth_error_body(Some(form_ct()), &body);
        assert!(
            result.is_none(),
            "expected None for body with no key=value pairs"
        );
    }

    #[test]
    fn form_error_no_equals_at_all_returns_none() {
        let body = form_body("foobar");
        let result = parse_auth_error_body(Some(form_ct()), &body);
        assert!(result.is_none(), "expected None for body 'foobar'");
    }

    #[test]
    fn form_error_malformed_trailing_segment_still_finds_key() {
        let body = form_body("error=invalid_token&junk");
        let sig = parse_auth_error_body(Some(form_ct()), &body)
            .expect("expected Some for 'error=invalid_token&junk'");
        assert_eq!(sig.code, "invalid_token");
    }

    proptest! {
        #[test]
        fn form_error_finds_error_key_despite_junk_pairs(
            prefix_junk in prop::collection::vec("[a-z]{1,8}", 0usize..=3),
            suffix_junk in prop::collection::vec("[a-z]{1,8}", 0usize..=3),
        ) {
            let mut parts: Vec<String> = prefix_junk;
            parts.push("error=invalid_token".to_owned());
            parts.extend(suffix_junk);
            let raw = parts.join("&");
            let body = form_body(&raw);
            let code = parse_auth_error_body(Some(form_ct()), &body).map(|sig| sig.code);
            prop_assert_eq!(
                code.as_deref(),
                Some("invalid_token"),
                "expected error=invalid_token to be found; body={:?}",
                raw
            );
        }
    }
}