use base64::{prelude::BASE64_STANDARD, Engine as _};
use pdk_core::classy::hl::{HeadersHandler, HeadersState, RequestHeadersState};
use pdk_core::logger;
use zeroize::ZeroizeOnDrop;
use super::credentials::{ClientId, ClientSecret};
/// Name of the request header (lowercase) that is looked up to find the credentials.
const AUTHORIZATION_HEADER: &str = "authorization";
/// Authentication scheme token expected as the first word of the header value.
const BASIC_AUTHORIZATION_SCHEMA: &str = "Basic";
/// Separator between the client id and the client secret in the decoded credentials.
const PAIR_SEPARATOR: &str = ":";
#[non_exhaustive]
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum BasicAuthError {
#[error("Basic authentication header not found.")]
HeaderNotFound,
#[error("Invalid basic auth header value format.")]
InvalidHeadeValueFormat,
#[error("Unknown auth schema {0}.")]
UnknownSchema(String),
#[error(transparent)]
InvalidBase64(InvalidBase64),
#[error(transparent)]
InvalidUtf8(InvalidUtf8),
#[error("Invalid credentials format.")]
InvalidCredentialsFormat,
}
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
#[error("Invalid Base64 Encoding")]
pub struct InvalidBase64(base64::DecodeError);
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
#[error("Invalid UTF-8 Encoding")]
pub struct InvalidUtf8(std::str::Utf8Error);
#[derive(ZeroizeOnDrop)]
struct AuthorizationHeader(String);
impl AuthorizationHeader {
fn from_handler(handler: &dyn HeadersHandler) -> Result<Self, BasicAuthError> {
handler
.header(AUTHORIZATION_HEADER)
.map(Self)
.ok_or_else(|| {
logger::debug!("Authorization header not present");
BasicAuthError::HeaderNotFound
})
}
fn auth_value(&self) -> Result<&str, BasicAuthError> {
let mut list = self.0.split_whitespace();
if let (Some(auth_type), Some(auth_value)) = (list.next(), list.next()) {
if auth_type != BASIC_AUTHORIZATION_SCHEMA {
return Err(BasicAuthError::UnknownSchema(auth_type.to_string()));
}
Ok(auth_value)
} else {
Err(BasicAuthError::InvalidHeadeValueFormat)
}
}
fn decode(&self) -> Result<DecodedHeader, BasicAuthError> {
let auth_value = self.auth_value()?;
let decoded_header = BASE64_STANDARD.decode(auth_value).map_err(|e| {
logger::debug!("There was a problem when trying to decoding auth header: {e}");
BasicAuthError::InvalidBase64(InvalidBase64(e))
})?;
Ok(DecodedHeader(decoded_header))
}
}
/// Base64-decoded credentials payload, expected to be `<client id>:<client secret>`.
///
/// Wiped from memory on drop, since it holds the plain-text credentials.
#[derive(ZeroizeOnDrop)]
struct DecodedHeader(Vec<u8>);

impl DecodedHeader {
    /// Borrows the decoded bytes as UTF-8 text.
    ///
    /// # Errors
    /// [`BasicAuthError::InvalidUtf8`] when the bytes are not valid UTF-8.
    fn as_utf8(&self) -> Result<&str, BasicAuthError> {
        match std::str::from_utf8(&self.0) {
            Ok(text) => Ok(text),
            Err(e) => {
                logger::debug!("There was a problem when trying to translate auth header: {e}");
                Err(BasicAuthError::InvalidUtf8(InvalidUtf8(e)))
            }
        }
    }

    /// Splits the decoded text at the first separator into a client id and a client secret.
    ///
    /// Only the first separator is used, so the secret itself may contain it.
    ///
    /// # Errors
    /// - [`BasicAuthError::InvalidUtf8`] propagated from [`Self::as_utf8`].
    /// - [`BasicAuthError::InvalidCredentialsFormat`] when the separator is missing.
    fn as_credentials(&self) -> Result<(ClientId, ClientSecret), BasicAuthError> {
        let text = self.as_utf8()?;
        text.split_once(PAIR_SEPARATOR)
            .map(|(id, secret)| {
                let client_id = ClientId::new(id.to_string());
                let client_secret = ClientSecret::new(secret.to_string());
                (client_id, client_secret)
            })
            .ok_or(BasicAuthError::InvalidCredentialsFormat)
    }
}
/// Extracts Basic authentication credentials from the headers exposed by `handler`.
///
/// The steps are: read the `authorization` header, check its scheme and Base64-decode the
/// value, then split the decoded text into client id and client secret.
///
/// # Errors
/// The [`BasicAuthError`] of the first step that fails.
fn credentials_from_handler(
    handler: &dyn HeadersHandler,
) -> Result<(ClientId, ClientSecret), BasicAuthError> {
    let header = AuthorizationHeader::from_handler(handler)?;
    let decoded = header.decode()?;
    decoded.as_credentials()
}
/// Reads HTTP Basic authentication credentials (client id and client secret) from the
/// request headers.
///
/// # Errors
/// Returns a [`BasicAuthError`] when the `authorization` header is missing or malformed,
/// uses a scheme other than Basic, or does not hold Base64-encoded UTF-8 `id:secret` text.
pub fn basic_auth_credentials(
    request_headers_state: &RequestHeadersState,
) -> Result<(ClientId, ClientSecret), BasicAuthError> {
    let handler = request_headers_state.handler();
    credentials_from_handler(handler)
}
#[cfg(test)]
mod tests {
    use pdk_core::classy::hl::HeadersHandler;

    use super::{credentials_from_handler, AUTHORIZATION_HEADER};
    use crate::api::basic_auth::BasicAuthError;

    /// Test double that only answers lookups of the `authorization` header.
    struct HandlerMock(Option<String>);

    impl HandlerMock {
        /// Handler for a request without an `authorization` header.
        fn absent() -> Self {
            HandlerMock(None)
        }

        /// Handler for a request whose `authorization` header is `value`.
        fn new(value: impl Into<String>) -> Self {
            HandlerMock(Some(value.into()))
        }
    }

    impl HeadersHandler for HandlerMock {
        fn headers(&self) -> Vec<(String, String)> {
            unreachable!()
        }

        fn header(&self, name: &str) -> Option<String> {
            if name == AUTHORIZATION_HEADER {
                self.0.clone()
            } else {
                None
            }
        }

        fn add_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_headers(&self, _: Vec<(&str, &str)>) {
            unreachable!()
        }

        fn remove_header(&self, _: &str) {
            unreachable!()
        }
    }

    /// Runs the extraction for an `authorization` header of `value`, expecting a failure.
    fn error_for(value: &str) -> BasicAuthError {
        credentials_from_handler(&HandlerMock::new(value)).unwrap_err()
    }

    #[test]
    fn header_not_found() {
        let error = credentials_from_handler(&HandlerMock::absent()).unwrap_err();
        assert_eq!(error, BasicAuthError::HeaderNotFound);
    }

    #[test]
    fn unknown_schema() {
        assert_eq!(
            error_for("Unknown aaaaa"),
            BasicAuthError::UnknownSchema("Unknown".to_string())
        );
    }

    #[test]
    fn invalid_header_format() {
        assert_eq!(error_for("Invalid"), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn invalid_base64() {
        assert!(matches!(
            error_for("Basic ####"),
            BasicAuthError::InvalidBase64(_)
        ));
    }

    #[test]
    fn invalid_utf8() {
        // "aaaa" decodes to bytes that are not valid UTF-8.
        assert!(matches!(
            error_for("Basic aaaa"),
            BasicAuthError::InvalidUtf8(_)
        ));
    }

    #[test]
    fn invalid_credentials_format() {
        // "c29tZSB1c2Vy" decodes to "some user", which has no separator.
        assert_eq!(
            error_for("Basic c29tZSB1c2Vy"),
            BasicAuthError::InvalidCredentialsFormat
        );
    }

    #[test]
    fn valid_credentials() {
        // "dXNlcjE6cGFzc3dvcmQx" decodes to "user1:password1".
        let handler = HandlerMock::new("Basic dXNlcjE6cGFzc3dvcmQx");
        let (id, secret) = credentials_from_handler(&handler).unwrap();
        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "password1");
    }
}