use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::fmt;
use crate::base64;
use crate::validate::{
escape_quotes, is_qdtext_char, is_quoted_pair_char, is_token_char, is_valid_token, trim_ows,
trim_ows_start,
};
/// Errors produced while parsing or constructing the authentication header
/// values handled by this module (Basic, Digest and Bearer credentials and
/// challenges).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthError {
    /// The header value was empty after trimming optional whitespace.
    Empty,
    /// The value did not have the expected overall shape: nothing after the
    /// scheme, an unrecognised scheme, no auth-params at all, or a Basic
    /// challenge without a `realm`.
    InvalidFormat,
    /// The value does not start with the `Basic` scheme.
    NotBasicScheme,
    /// The value does not start with the `Digest` scheme.
    NotDigestScheme,
    /// The value does not start with the `Bearer` scheme.
    NotBearerScheme,
    /// Basic credentials were not valid base64.
    Base64DecodeError,
    /// Decoded Basic credentials were not valid UTF-8.
    Utf8Error,
    /// Decoded Basic credentials contained no `:` separator.
    MissingColon,
    /// An auth-param was syntactically malformed.
    InvalidParameter,
    /// A parameter required by the scheme (e.g. Digest `realm`/`nonce`) was absent.
    MissingParameter,
    /// A token68 credential (Basic / Bearer) contained invalid characters.
    InvalidToken,
    /// The same auth-param name appeared more than once (case-insensitively).
    DuplicateParameter,
    /// A Basic user-id passed to `BasicAuth::new` contained `:`.
    ColonInUserId,
    /// A Basic user-id or password contained a byte in `0x00..=0x1F` or `0x7F`.
    ControlCharacter,
    /// A Basic challenge carried a `charset` other than `UTF-8`.
    InvalidCharset,
    /// Digest credentials contained both `username` and `username*`.
    ConflictingUsernameField,
    /// A Digest `username*` value was not a decodable UTF-8 ext-value.
    InvalidUsernameExtValue,
    /// More auth-params were supplied than the parser accepts.
    TooManyParameters,
}
impl fmt::Display for AuthError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AuthError::Empty => write!(f, "empty authorization header"),
AuthError::InvalidFormat => write!(f, "invalid authorization format"),
AuthError::NotBasicScheme => write!(f, "not basic authentication scheme"),
AuthError::NotDigestScheme => write!(f, "not digest authentication scheme"),
AuthError::NotBearerScheme => write!(f, "not bearer authentication scheme"),
AuthError::Base64DecodeError => write!(f, "base64 decode error"),
AuthError::Utf8Error => write!(f, "utf-8 decode error"),
AuthError::MissingColon => write!(f, "missing colon in credentials"),
AuthError::InvalidParameter => write!(f, "invalid auth parameter"),
AuthError::MissingParameter => write!(f, "missing required auth parameter"),
AuthError::InvalidToken => write!(f, "invalid auth token"),
AuthError::DuplicateParameter => write!(f, "duplicate auth parameter"),
AuthError::ColonInUserId => write!(f, "colon in user-id"),
AuthError::ControlCharacter => write!(f, "control character in credentials"),
AuthError::InvalidCharset => write!(f, "charset must be UTF-8"),
AuthError::ConflictingUsernameField => {
write!(
f,
"both username and username* present (RFC 7616 Section 3.4)"
)
}
AuthError::InvalidUsernameExtValue => write!(f, "invalid username* ext-value"),
AuthError::TooManyParameters => write!(f, "too many auth parameters"),
}
}
}
impl core::error::Error for AuthError {}

/// Upper bound on the number of `name=value` pairs accepted from a single
/// header by `parse_auth_params`; one more yields `AuthError::TooManyParameters`.
const MAX_AUTH_PARAMS: usize = 32;
/// Credentials for the HTTP `Basic` scheme: a user-id and a password that are
/// joined with `:` and base64-encoded in the header value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicAuth {
    username: String,
    password: String,
}

impl BasicAuth {
    /// Builds credentials from a user-id and a password.
    ///
    /// # Errors
    /// - [`AuthError::ColonInUserId`] if `username` contains `:`; the first
    ///   colon is the user-id/password separator on the wire.
    /// - [`AuthError::ControlCharacter`] if either value contains a byte in
    ///   `0x00..=0x1F` or `0x7F`.
    pub fn new(username: &str, password: &str) -> Result<Self, AuthError> {
        if username.contains(':') {
            return Err(AuthError::ColonInUserId);
        }
        Self::from_parts(username, password)
    }

    /// Parses a header value of the form `Basic <token68>`.
    ///
    /// The scheme name is matched ASCII case-insensitively. The decoded text
    /// is split at its *first* `:`, so the password may itself contain colons.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty (or whitespace-only) value.
    /// - [`AuthError::NotBasicScheme`] if the value is not `Basic` followed by
    ///   whitespace.
    /// - [`AuthError::InvalidFormat`] if nothing follows the scheme.
    /// - [`AuthError::InvalidToken`] if the credentials are not token68.
    /// - [`AuthError::Base64DecodeError`] / [`AuthError::Utf8Error`] if decoding fails.
    /// - [`AuthError::MissingColon`] if the decoded text has no `:`.
    /// - [`AuthError::ControlCharacter`] if either part contains a control byte.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let trimmed = trim_ows(input);
        if trimmed.is_empty() {
            return Err(AuthError::Empty);
        }
        let Some(encoded) = strip_scheme(trimmed, "Basic") else {
            return Err(AuthError::NotBasicScheme);
        };
        if encoded.is_empty() {
            return Err(AuthError::InvalidFormat);
        }
        if !is_token68(encoded) {
            return Err(AuthError::InvalidToken);
        }
        let raw = base64::decode(encoded).map_err(|_| AuthError::Base64DecodeError)?;
        let text = String::from_utf8(raw).map_err(|_| AuthError::Utf8Error)?;
        let (username, password) = text.split_once(':').ok_or(AuthError::MissingColon)?;
        Self::from_parts(username, password)
    }

    /// Shared tail of [`BasicAuth::new`] and [`BasicAuth::parse`]: rejects
    /// control characters in either part, then stores owned copies.
    fn from_parts(username: &str, password: &str) -> Result<Self, AuthError> {
        if has_control_chars(username) || has_control_chars(password) {
            return Err(AuthError::ControlCharacter);
        }
        Ok(Self {
            username: username.to_string(),
            password: password.to_string(),
        })
    }

    /// The user-id part of the credentials.
    pub fn username(&self) -> &str {
        &self.username
    }

    /// The password part of the credentials.
    pub fn password(&self) -> &str {
        &self.password
    }

    /// Serializes the credentials as `Basic <base64(user-id ":" password)>`.
    pub fn to_header_value(&self) -> String {
        let user_pass = alloc::format!("{}:{}", self.username, self.password);
        let encoded = base64::encode(user_pass.as_bytes());
        alloc::format!("Basic {}", encoded)
    }
}

impl fmt::Display for BasicAuth {
    /// Writes the same text as [`BasicAuth::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A `Basic` scheme challenge: a mandatory `realm` plus an optional `charset`
/// parameter, which this type only accepts as `UTF-8`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    realm: String,
    charset: Option<String>,
}

impl WwwAuthenticate {
    /// Creates a `Basic` challenge for `realm` without a `charset` parameter.
    pub fn basic(realm: &str) -> Self {
        Self {
            realm: String::from(realm),
            charset: None,
        }
    }

    /// Returns this challenge with `charset="UTF-8"` attached.
    pub fn with_charset_utf8(self) -> Self {
        Self {
            charset: Some(String::from("UTF-8")),
            ..self
        }
    }

    /// Parses a challenge such as `Basic realm="x", charset="UTF-8"`.
    ///
    /// Parameter names are compared after `parse_auth_params` lower-cases them;
    /// parameters other than `realm` and `charset` are ignored.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::NotBasicScheme`] if the value is not `Basic` followed by whitespace.
    /// - [`AuthError::InvalidFormat`] if there are no parameters or no `realm`.
    /// - [`AuthError::InvalidCharset`] if `charset` is anything but `UTF-8`
    ///   (ASCII case-insensitive).
    /// - Any error from `parse_auth_params`.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let trimmed = trim_ows(input);
        if trimmed.is_empty() {
            return Err(AuthError::Empty);
        }
        let raw_params = match strip_scheme(trimmed, "Basic") {
            None => return Err(AuthError::NotBasicScheme),
            Some("") => return Err(AuthError::InvalidFormat),
            Some(rest) => rest,
        };
        let mut realm = None;
        let mut charset = None;
        for (key, value) in parse_auth_params(raw_params)? {
            match key.as_str() {
                "realm" => realm = Some(value),
                "charset" if value.eq_ignore_ascii_case("UTF-8") => charset = Some(value),
                "charset" => return Err(AuthError::InvalidCharset),
                _ => {}
            }
        }
        match realm {
            Some(realm) => Ok(Self { realm, charset }),
            None => Err(AuthError::InvalidFormat),
        }
    }

    /// The protection space named by the challenge.
    pub fn realm(&self) -> &str {
        &self.realm
    }

    /// The `charset` parameter as received or set, if any.
    pub fn charset(&self) -> Option<&str> {
        self.charset.as_deref()
    }

    /// Serializes the challenge; identical to the `Display` output.
    pub fn to_header_value(&self) -> String {
        self.to_string()
    }
}

impl fmt::Display for WwwAuthenticate {
    /// Writes `Basic realm="..."`, followed by `, charset="..."` when set;
    /// both values go through `escape_quotes`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Basic realm=\"{}\"", escape_quotes(&self.realm))?;
        match self.charset.as_deref() {
            Some(cs) => write!(f, ", charset=\"{}\"", escape_quotes(cs)),
            None => Ok(()),
        }
    }
}
/// Credentials for the HTTP `Digest` scheme (RFC 7616).
///
/// Parameters are stored in header order with lower-cased names (as produced
/// by `parse_auth_params`) and unquoted/unescaped values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DigestAuth {
    params: Vec<(String, String)>,
}

impl DigestAuth {
    /// Parses a value such as
    /// `Digest username="u", realm="r", nonce="n", uri="/", response="..."`.
    ///
    /// Exactly one of `username` / `username*` must be present. When
    /// `username*` is used, it must be a decodable UTF-8 ext-value.
    /// `realm`, `nonce`, `uri` and `response` are required.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::NotDigestScheme`] if the value is not `Digest` followed by whitespace.
    /// - [`AuthError::InvalidFormat`] if nothing follows the scheme.
    /// - [`AuthError::ConflictingUsernameField`] if both username forms are present.
    /// - [`AuthError::MissingParameter`] if no username form or another required
    ///   parameter is absent.
    /// - [`AuthError::InvalidUsernameExtValue`] if `username*` cannot be decoded.
    /// - Any error from `parse_auth_params`.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let input = trim_ows(input);
        if input.is_empty() {
            return Err(AuthError::Empty);
        }
        let raw_params = strip_scheme(input, "Digest").ok_or(AuthError::NotDigestScheme)?;
        if raw_params.is_empty() {
            return Err(AuthError::InvalidFormat);
        }
        let params = parse_auth_params(raw_params)?;
        let has_username = params.iter().any(|(n, _)| n == "username");
        let username_ext = params
            .iter()
            .find(|(n, _)| n == "username*")
            .map(|(_, v)| v.as_str());
        match (has_username, username_ext) {
            (true, Some(_)) => return Err(AuthError::ConflictingUsernameField),
            (false, None) => return Err(AuthError::MissingParameter),
            // Only validated here; the decoded form is produced on demand by
            // `username_decoded`.
            (false, Some(raw)) => {
                decode_username_ext_value(raw)?;
            }
            (true, None) => {}
        }
        if !has_required_params(&params, &["realm", "nonce", "uri", "response"]) {
            return Err(AuthError::MissingParameter);
        }
        Ok(DigestAuth { params })
    }

    /// Looks up a parameter by name, ignoring ASCII case.
    ///
    /// Stored names are already lower-case, so a case-insensitive comparison
    /// gives the same result as lower-casing `name` without allocating.
    pub fn param(&self, name: &str) -> Option<&str> {
        self.params
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// The plain `username` parameter. Returns `None` when only `username*` was sent.
    pub fn username(&self) -> Option<&str> {
        self.param("username")
    }

    /// The user name from either form: `username` verbatim if present,
    /// otherwise `username*` decoded from its ext-value.
    pub fn username_decoded(&self) -> Option<String> {
        match self.param("username") {
            Some(plain) => Some(plain.to_string()),
            None => decode_username_ext_value(self.param("username*")?).ok(),
        }
    }

    /// The `realm` parameter.
    pub fn realm(&self) -> Option<&str> {
        self.param("realm")
    }

    /// The `nonce` parameter.
    pub fn nonce(&self) -> Option<&str> {
        self.param("nonce")
    }

    /// The `uri` parameter.
    pub fn uri(&self) -> Option<&str> {
        self.param("uri")
    }

    /// The `response` parameter.
    pub fn response(&self) -> Option<&str> {
        self.param("response")
    }

    /// Serializes as `Digest name=value, ...`. Values are quoted only when
    /// `needs_quoting` says so.
    pub fn to_header_value(&self) -> String {
        alloc::format!("Digest {}", format_auth_params(&self.params))
    }
}

impl fmt::Display for DigestAuth {
    /// Writes the same text as [`DigestAuth::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A `Digest` scheme challenge.
///
/// Parameters are stored in header order with lower-cased names (as produced
/// by `parse_auth_params`) and unquoted/unescaped values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DigestChallenge {
    params: Vec<(String, String)>,
}

impl DigestChallenge {
    /// Parses a challenge such as `Digest realm="r", nonce="n", qop="auth"`.
    /// `realm` and `nonce` are required.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::NotDigestScheme`] if the value is not `Digest` followed by whitespace.
    /// - [`AuthError::InvalidFormat`] if nothing follows the scheme.
    /// - [`AuthError::MissingParameter`] if `realm` or `nonce` is absent.
    /// - Any error from `parse_auth_params`.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let input = trim_ows(input);
        if input.is_empty() {
            return Err(AuthError::Empty);
        }
        let raw_params = strip_scheme(input, "Digest").ok_or(AuthError::NotDigestScheme)?;
        if raw_params.is_empty() {
            return Err(AuthError::InvalidFormat);
        }
        let params = parse_auth_params(raw_params)?;
        if !has_required_params(&params, &["realm", "nonce"]) {
            return Err(AuthError::MissingParameter);
        }
        Ok(DigestChallenge { params })
    }

    /// Looks up a parameter by name, ignoring ASCII case.
    ///
    /// Stored names are already lower-case, so a case-insensitive comparison
    /// gives the same result as lower-casing `name` without allocating.
    pub fn param(&self, name: &str) -> Option<&str> {
        self.params
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// The `realm` parameter.
    pub fn realm(&self) -> Option<&str> {
        self.param("realm")
    }

    /// The `nonce` parameter.
    pub fn nonce(&self) -> Option<&str> {
        self.param("nonce")
    }

    /// Serializes as `Digest name=value, ...`. Values are quoted only when
    /// `needs_quoting` says so.
    pub fn to_header_value(&self) -> String {
        alloc::format!("Digest {}", format_auth_params(&self.params))
    }
}

impl fmt::Display for DigestChallenge {
    /// Writes the same text as [`DigestChallenge::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// Credentials for the `Bearer` scheme: a single token68 string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BearerToken {
    token: String,
}

impl BearerToken {
    /// Parses a value of the form `Bearer <token68>`; the scheme is matched
    /// ASCII case-insensitively.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::NotBearerScheme`] if the leading scheme token (the text up
    ///   to the first SP/HTAB) is not `Bearer`, e.g. `Basic ...` or `BearerX ...`.
    /// - [`AuthError::InvalidFormat`] if the scheme is `Bearer` but no token follows.
    /// - [`AuthError::InvalidToken`] if the token is not token68.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let input = trim_ows(input);
        if input.is_empty() {
            return Err(AuthError::Empty);
        }
        // Compare the whole scheme token, not just its first six bytes;
        // otherwise `BearerX abc` would be misreported as a malformed Bearer
        // header instead of a different scheme.
        let scheme_end = input
            .find(|c: char| c == ' ' || c == '\t')
            .unwrap_or(input.len());
        if !input[..scheme_end].eq_ignore_ascii_case("Bearer") {
            return Err(AuthError::NotBearerScheme);
        }
        // `strip_scheme` yields `None` for a bare "Bearer", i.e. the token is missing.
        let token = strip_scheme(input, "Bearer").unwrap_or("");
        if token.is_empty() {
            return Err(AuthError::InvalidFormat);
        }
        if !is_token68(token) {
            return Err(AuthError::InvalidToken);
        }
        Ok(BearerToken {
            token: token.to_string(),
        })
    }

    /// The raw token, exactly as received.
    pub fn token(&self) -> &str {
        &self.token
    }

    /// Serializes as `Bearer <token>`.
    pub fn to_header_value(&self) -> String {
        alloc::format!("Bearer {}", self.token)
    }
}

impl fmt::Display for BearerToken {
    /// Writes the same text as [`BearerToken::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A `Bearer` scheme challenge carrying at least one auth-param.
///
/// Parameters are stored in header order with lower-cased names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BearerChallenge {
    params: Vec<(String, String)>,
}

impl BearerChallenge {
    /// Parses a challenge such as `Bearer realm="example", error="invalid_token"`.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::NotBearerScheme`] if the value is not `Bearer` followed by whitespace.
    /// - [`AuthError::InvalidFormat`] if nothing follows the scheme.
    /// - Any error from `parse_auth_params`.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let trimmed = trim_ows(input);
        if trimmed.is_empty() {
            return Err(AuthError::Empty);
        }
        match strip_scheme(trimmed, "Bearer") {
            None => Err(AuthError::NotBearerScheme),
            Some("") => Err(AuthError::InvalidFormat),
            Some(raw) => parse_auth_params(raw).map(|params| Self { params }),
        }
    }

    /// Looks up a parameter by name, ignoring ASCII case. Stored names are
    /// already lower-case, so no allocation is needed.
    pub fn param(&self, name: &str) -> Option<&str> {
        self.params
            .iter()
            .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then_some(value.as_str()))
    }

    /// Serializes as `Bearer name=value, ...`.
    pub fn to_header_value(&self) -> String {
        alloc::format!("Bearer {}", format_auth_params(&self.params))
    }
}

impl fmt::Display for BearerChallenge {
    /// Writes the same text as [`BearerChallenge::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A parsed `Authorization` header value in one of the supported schemes.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Authorization {
    Basic(BasicAuth),
    Digest(DigestAuth),
    Bearer(BearerToken),
}

impl Authorization {
    /// Picks the scheme-specific parser from the leading scheme name
    /// (ASCII case-insensitive) and delegates to it.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::InvalidFormat`] if no supported scheme followed by
    ///   whitespace is found.
    /// - Any error from the selected scheme parser.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let trimmed = trim_ows(input);
        if trimmed.is_empty() {
            return Err(AuthError::Empty);
        }
        let uses_scheme = |scheme: &str| strip_scheme(trimmed, scheme).is_some();
        if uses_scheme("Basic") {
            BasicAuth::parse(trimmed).map(Self::Basic)
        } else if uses_scheme("Digest") {
            DigestAuth::parse(trimmed).map(Self::Digest)
        } else if uses_scheme("Bearer") {
            BearerToken::parse(trimmed).map(Self::Bearer)
        } else {
            Err(AuthError::InvalidFormat)
        }
    }

    /// Serializes using the wrapped scheme's own `to_header_value`.
    pub fn to_header_value(&self) -> String {
        match self {
            Self::Basic(inner) => inner.to_header_value(),
            Self::Digest(inner) => inner.to_header_value(),
            Self::Bearer(inner) => inner.to_header_value(),
        }
    }
}

impl fmt::Display for Authorization {
    /// Writes the same text as [`Authorization::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A parsed authentication challenge in one of the supported schemes.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthChallenge {
    Basic(WwwAuthenticate),
    Digest(DigestChallenge),
    Bearer(BearerChallenge),
}

impl AuthChallenge {
    /// Picks the scheme-specific challenge parser from the leading scheme
    /// name (ASCII case-insensitive) and delegates to it.
    ///
    /// # Errors
    /// - [`AuthError::Empty`] for an empty value.
    /// - [`AuthError::InvalidFormat`] if no supported scheme followed by
    ///   whitespace is found.
    /// - Any error from the selected scheme parser.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let trimmed = trim_ows(input);
        if trimmed.is_empty() {
            return Err(AuthError::Empty);
        }
        let uses_scheme = |scheme: &str| strip_scheme(trimmed, scheme).is_some();
        if uses_scheme("Basic") {
            WwwAuthenticate::parse(trimmed).map(Self::Basic)
        } else if uses_scheme("Digest") {
            DigestChallenge::parse(trimmed).map(Self::Digest)
        } else if uses_scheme("Bearer") {
            BearerChallenge::parse(trimmed).map(Self::Bearer)
        } else {
            Err(AuthError::InvalidFormat)
        }
    }

    /// Serializes using the wrapped scheme's own `to_header_value`.
    pub fn to_header_value(&self) -> String {
        match self {
            Self::Basic(inner) => inner.to_header_value(),
            Self::Digest(inner) => inner.to_header_value(),
            Self::Bearer(inner) => inner.to_header_value(),
        }
    }
}

impl fmt::Display for AuthChallenge {
    /// Writes the same text as [`AuthChallenge::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A `Proxy-Authorization` header value; parsed and serialized exactly like
/// [`Authorization`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyAuthorization(Authorization);

impl ProxyAuthorization {
    /// Parses the value with [`Authorization::parse`] and wraps the result.
    ///
    /// # Errors
    /// Whatever [`Authorization::parse`] returns.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let inner = Authorization::parse(input)?;
        Ok(Self(inner))
    }

    /// The wrapped credentials.
    pub fn authorization(&self) -> &Authorization {
        &self.0
    }

    /// Serializes the wrapped credentials.
    pub fn to_header_value(&self) -> String {
        let Self(inner) = self;
        inner.to_header_value()
    }
}

impl fmt::Display for ProxyAuthorization {
    /// Writes the same text as [`ProxyAuthorization::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// A `Proxy-Authenticate` header value; parsed and serialized exactly like
/// [`AuthChallenge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyAuthenticate(AuthChallenge);

impl ProxyAuthenticate {
    /// Parses the value with [`AuthChallenge::parse`] and wraps the result.
    ///
    /// # Errors
    /// Whatever [`AuthChallenge::parse`] returns.
    pub fn parse(input: &str) -> Result<Self, AuthError> {
        let inner = AuthChallenge::parse(input)?;
        Ok(Self(inner))
    }

    /// The wrapped challenge.
    pub fn challenge(&self) -> &AuthChallenge {
        &self.0
    }

    /// Serializes the wrapped challenge.
    pub fn to_header_value(&self) -> String {
        let Self(inner) = self;
        inner.to_header_value()
    }
}

impl fmt::Display for ProxyAuthenticate {
    /// Writes the same text as [`ProxyAuthenticate::to_header_value`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_header_value())
    }
}
/// If `input` (after leading whitespace) begins with `scheme` (ASCII
/// case-insensitive) followed by at least one SP/HTAB, returns the remainder
/// with its leading whitespace removed; otherwise `None`.
///
/// A bare scheme with nothing after it yields `None`. A scheme followed only
/// by whitespace yields `Some("")`.
fn strip_scheme<'a>(input: &'a str, scheme: &str) -> Option<&'a str> {
    let input = trim_ows_start(input);
    let split_at = scheme.len();
    // At least one byte must follow the scheme name to act as the separator.
    if input.len() <= split_at {
        return None;
    }
    let head = input.get(..split_at)?;
    let tail = input.get(split_at..)?;
    if !head.eq_ignore_ascii_case(scheme) {
        return None;
    }
    match tail.as_bytes().first() {
        Some(b' ' | b'\t') => Some(trim_ows_start(tail)),
        _ => None,
    }
}
/// Parses the comma-separated `name=value` auth-param list that follows a
/// scheme, e.g. `realm="x", qop=auth`.
///
/// - Names must be non-empty runs of token characters and are stored
///   ASCII-lower-cased.
/// - Values are either a token (checked with `is_valid_token`) or a
///   quoted-string. Quoted-strings are returned without the surrounding quotes
///   and with backslash escapes removed.
/// - SP/HTAB is allowed around `=` and `,`. Empty list elements (e.g. `a=1,,b=2`)
///   are skipped.
///
/// Returns the pairs in input order.
///
/// # Errors
/// - `AuthError::InvalidParameter` for a malformed element: no name, no `=`,
///   missing or invalid value, a disallowed character or unterminated
///   quoted-string, or anything other than `,` after a value.
/// - `AuthError::DuplicateParameter` if a name repeats (case-insensitively).
/// - `AuthError::TooManyParameters` if more than `MAX_AUTH_PARAMS` pairs appear.
/// - `AuthError::InvalidFormat` if the list contains no pairs at all.
fn parse_auth_params(input: &str) -> Result<Vec<(String, String)>, AuthError> {
    let mut params = Vec::new();
    // `i` is a byte offset into `input`; every place it stops is on an ASCII
    // byte or at the end, so slicing `input` at `i` stays on char boundaries.
    let bytes = input.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        while i < bytes.len() && is_ows(bytes[i]) {
            i += 1;
        }
        // Empty list element: skip the extra comma.
        if i < bytes.len() && bytes[i] == b',' {
            i += 1;
            continue;
        }
        while i < bytes.len() && is_ows(bytes[i]) {
            i += 1;
        }
        if i >= bytes.len() {
            break;
        }
        // Parameter name: one or more token characters.
        let name_start = i;
        while i < bytes.len() && is_token_char(bytes[i]) {
            i += 1;
        }
        if i == name_start {
            return Err(AuthError::InvalidParameter);
        }
        let name = &input[name_start..i];
        // Optional whitespace, then a mandatory `=`, then optional whitespace.
        while i < bytes.len() && is_ows(bytes[i]) {
            i += 1;
        }
        if i >= bytes.len() || bytes[i] != b'=' {
            return Err(AuthError::InvalidParameter);
        }
        i += 1;
        while i < bytes.len() && is_ows(bytes[i]) {
            i += 1;
        }
        if i >= bytes.len() {
            return Err(AuthError::InvalidParameter);
        }
        let value = if bytes[i] == b'"' {
            // Quoted-string: step past the opening quote and walk chars, since
            // qdtext may contain non-ASCII.
            i += 1;
            let inner = &input[i..];
            let mut iter = inner.chars();
            let mut value = String::new();
            // Bytes of `inner` consumed, including the closing quote, so `i`
            // can be advanced past the whole quoted-string afterwards.
            let mut consumed: usize = 0;
            let mut closed = false;
            while let Some(c) = iter.next() {
                if c == '"' {
                    consumed += 1;
                    closed = true;
                    break;
                } else if c == '\\' {
                    // quoted-pair: drop the backslash and keep the escaped char.
                    consumed += 1;
                    let next_c = iter.next().ok_or(AuthError::InvalidParameter)?;
                    if !is_quoted_pair_char(next_c) {
                        return Err(AuthError::InvalidParameter);
                    }
                    consumed += next_c.len_utf8();
                    value.push(next_c);
                } else {
                    if !is_qdtext_char(c) {
                        return Err(AuthError::InvalidParameter);
                    }
                    consumed += c.len_utf8();
                    value.push(c);
                }
            }
            if !closed {
                return Err(AuthError::InvalidParameter);
            }
            i += consumed;
            value
        } else {
            // Bare token: runs until whitespace or the next comma.
            let value_start = i;
            while i < bytes.len() && !is_ows(bytes[i]) && bytes[i] != b',' {
                i += 1;
            }
            let token = &input[value_start..i];
            if token.is_empty() || !is_valid_token(token) {
                return Err(AuthError::InvalidParameter);
            }
            token.to_string()
        };
        // Names are compared and stored lower-cased; callers rely on this
        // when looking parameters up.
        let key = name.to_ascii_lowercase();
        if params.iter().any(|(n, _)| n == &key) {
            return Err(AuthError::DuplicateParameter);
        }
        if params.len() >= MAX_AUTH_PARAMS {
            return Err(AuthError::TooManyParameters);
        }
        params.push((key, value));
        // After a value, only whitespace and then `,` or end of input may follow.
        while i < bytes.len() && is_ows(bytes[i]) {
            i += 1;
        }
        if i < bytes.len() {
            if bytes[i] == b',' {
                i += 1;
            } else {
                return Err(AuthError::InvalidParameter);
            }
        }
    }
    if params.is_empty() {
        return Err(AuthError::InvalidFormat);
    }
    Ok(params)
}
/// Returns `true` when every name in `required` (compared after ASCII
/// lower-casing) is present among the keys of `params`.
///
/// Keys in `params` are expected to be lower-case already, as produced by
/// `parse_auth_params`.
fn has_required_params(params: &[(String, String)], required: &[&str]) -> bool {
    for wanted in required {
        let wanted = wanted.to_ascii_lowercase();
        if !params.iter().any(|(key, _)| *key == wanted) {
            return false;
        }
    }
    true
}
/// Renders `params` as `name=value` pairs joined by `", "`.
///
/// A value is emitted as a quoted-string (via `escape_quotes`) when
/// `needs_quoting` reports it cannot be sent as a bare token.
fn format_auth_params(params: &[(String, String)]) -> String {
    params
        .iter()
        .map(|(name, value)| {
            if needs_quoting(value) {
                alloc::format!("{}=\"{}\"", name, escape_quotes(value))
            } else {
                alloc::format!("{}={}", name, value)
            }
        })
        .collect::<Vec<String>>()
        .join(", ")
}
/// Returns `true` when `value` cannot be written as a bare token: it is empty
/// or contains any byte that `is_token_char` rejects.
fn needs_quoting(value: &str) -> bool {
    let is_plain_token = !value.is_empty() && value.bytes().all(|b| is_token_char(b));
    !is_plain_token
}
/// Returns `true` if `value` is a token68: one or more token68 characters
/// followed by an optional run of `=` padding.
///
/// `=` is accepted only at the end, and padding alone (or an empty string)
/// is rejected.
fn is_token68(value: &str) -> bool {
    let body = value.trim_end_matches('=');
    !body.is_empty() && body.bytes().all(is_token68_char)
}

/// Returns `true` for the non-padding token68 characters:
/// ASCII letters and digits plus `-` `.` `_` `~` `+` `/`.
fn is_token68_char(b: u8) -> bool {
    b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~' | b'+' | b'/')
}
/// Returns `true` for optional-whitespace bytes: space (SP) or horizontal tab (HTAB).
fn is_ows(b: u8) -> bool {
    matches!(b, b' ' | b'\t')
}
/// Returns `true` if `s` contains an ASCII control byte (`0x00..=0x1F` or `0x7F`).
///
/// Multi-byte UTF-8 sequences never match, because all of their bytes are `>= 0x80`.
fn has_control_chars(s: &str) -> bool {
    s.bytes().any(|byte| byte.is_ascii_control())
}
/// Decodes an ext-value of the form `charset'[language]'value-chars`, as
/// carried by the Digest `username*` parameter.
///
/// Only the `UTF-8` charset (ASCII case-insensitive) is accepted. The language
/// tag between the two apostrophes is skipped without validation.
/// `value-chars` may contain attr-chars and `%XX` escapes, and the decoded
/// bytes must be valid UTF-8.
///
/// # Errors
/// Every failure yields `AuthError::InvalidUsernameExtValue`.
fn decode_username_ext_value(input: &str) -> Result<String, AuthError> {
    /// Value of one ASCII hex digit, or `None` for any other byte.
    fn hex_value(b: u8) -> Option<u8> {
        // `to_digit(16)` returns at most 15, so the narrowing cast is lossless.
        char::from(b).to_digit(16).map(|d| d as u8)
    }

    let invalid = || AuthError::InvalidUsernameExtValue;

    let (charset, after_charset) = input.split_once('\'').ok_or_else(invalid)?;
    if !charset.eq_ignore_ascii_case("UTF-8") {
        return Err(invalid());
    }
    let (_language, encoded) = after_charset.split_once('\'').ok_or_else(invalid)?;

    let mut decoded = Vec::with_capacity(encoded.len());
    let mut bytes = encoded.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'%' => {
                let hi = bytes.next().and_then(hex_value).ok_or_else(invalid)?;
                let lo = bytes.next().and_then(hex_value).ok_or_else(invalid)?;
                decoded.push((hi << 4) | lo);
            }
            b if is_attr_char(b) => decoded.push(b),
            _ => return Err(invalid()),
        }
    }
    String::from_utf8(decoded).map_err(|_| invalid())
}

/// Returns `true` for attr-chars: ASCII letters and digits plus
/// ``!#$&+-.^_`|~``. The bytes `%`, `'` and `*` are excluded.
fn is_attr_char(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b"!#$&+-.^_`|~".contains(&b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `=` padding is only valid as a trailing run after at least one token68
    /// character. Empty and padding-only inputs are rejected.
    #[test]
    fn test_token68_equals_only_at_end() {
        for accepted in ["Zm8=", "Zg==", "abc"] {
            assert!(is_token68(accepted), "expected {accepted:?} to be token68");
        }
        for rejected in ["a=b", "a=b=", "=", "==", ""] {
            assert!(!is_token68(rejected), "expected {rejected:?} to be rejected");
        }
    }
}