pdk-contracts-lib 1.9.1-alpha.2

PDK Contracts Library
Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use base64::{prelude::BASE64_STANDARD, Engine as _};
use pdk_core::classy::hl::{HeadersHandler, HeadersState, RequestHeadersState};
use pdk_core::logger;

use zeroize::ZeroizeOnDrop;

use super::credentials::{ClientId, ClientSecret};

/// Name of the request header that is read to obtain the credentials.
const AUTHORIZATION_HEADER: &str = "authorization";
/// Authentication scheme expected as the first token of the header value.
const BASIC_AUTHORIZATION_SCHEMA: &str = "Basic";
/// Separator between the client id and the client secret in the decoded credentials.
/// Only the first occurrence splits the pair, so the secret itself may contain it.
const PAIR_SEPARATOR: &str = ":";

/// Error returned when [basic_auth_credentials()] fails.
///
/// Each variant corresponds to one stage of extracting credentials from the
/// `authorization` header; the first stage that fails determines the variant.
#[non_exhaustive]
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum BasicAuthError {
    /// Basic authentication header not found.
    ///
    /// Returned when the request carries no `authorization` header.
    #[error("Basic authentication header not found.")]
    HeaderNotFound,

    /// Invalid basic auth header value format.
    ///
    /// Returned when the header value does not contain at least two
    /// whitespace-separated tokens (a scheme followed by a credentials token).
    // NOTE(review): the variant name misspells "Header"; it is public API, so renaming
    // it would break callers that match on it.
    #[error("Invalid basic auth header value format.")]
    InvalidHeadeValueFormat,

    /// Unknown auth schema.
    ///
    /// Returned when the scheme token is not the Basic scheme; carries the
    /// scheme token that was actually received.
    #[error("Unknown auth schema {0}.")]
    UnknownSchema(String),

    /// Invalid Base-64 encoding.
    ///
    /// Returned when the credentials token cannot be decoded as standard-alphabet Base-64.
    #[error(transparent)]
    InvalidBase64(InvalidBase64),

    /// Invalid UTF-8 encoding.
    ///
    /// Returned when the Base-64-decoded bytes are not valid UTF-8.
    #[error(transparent)]
    InvalidUtf8(InvalidUtf8),

    /// Invalid credentials format.
    ///
    /// Returned when the decoded credentials text contains no `:` separator
    /// between the client id and the client secret.
    #[error("Invalid credentials format.")]
    InvalidCredentialsFormat,
}

/// Represents an invalid Base-64 encoding.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
#[error("Invalid Base64 Encoding")]
pub struct InvalidBase64(base64::DecodeError);

/// Represents an invalid UTF-8 encoding.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
#[error("Invalid UTF-8 Encoding")]
pub struct InvalidUtf8(std::str::Utf8Error);

/// Owned copy of the raw `authorization` header value.
///
/// `ZeroizeOnDrop` wipes the buffer when the value is dropped, so the encoded
/// credentials do not linger in memory after use.
#[derive(ZeroizeOnDrop)]
struct AuthorizationHeader(String);

impl AuthorizationHeader {
    /// Reads the `authorization` header from `handler`.
    ///
    /// # Errors
    ///
    /// Returns [`BasicAuthError::HeaderNotFound`] when the header is absent.
    fn from_handler(handler: &dyn HeadersHandler) -> Result<Self, BasicAuthError> {
        handler
            .header(AUTHORIZATION_HEADER)
            .map(Self)
            .ok_or_else(|| {
                logger::debug!("Authorization header not present");
                BasicAuthError::HeaderNotFound
            })
    }

    /// Splits the header value as `<scheme> <credentials>` and returns the credentials token.
    ///
    /// The scheme is matched case-insensitively (`Basic`, `basic` and `BASIC` are all
    /// accepted), because HTTP authentication scheme names are case-insensitive tokens.
    /// Any tokens after the credentials token are ignored.
    ///
    /// # Errors
    ///
    /// - [`BasicAuthError::InvalidHeadeValueFormat`] if the value has fewer than two
    ///   whitespace-separated tokens.
    /// - [`BasicAuthError::UnknownSchema`] if the scheme is not Basic; the received
    ///   scheme token is carried in the error.
    fn auth_value(&self) -> Result<&str, BasicAuthError> {
        let mut tokens = self.0.split_whitespace();

        let (Some(auth_type), Some(auth_value)) = (tokens.next(), tokens.next()) else {
            return Err(BasicAuthError::InvalidHeadeValueFormat);
        };

        // RFC 7235 §2.1 / RFC 7617 §2: the auth-scheme is a case-insensitive token.
        if !auth_type.eq_ignore_ascii_case(BASIC_AUTHORIZATION_SCHEMA) {
            return Err(BasicAuthError::UnknownSchema(auth_type.to_string()));
        }

        Ok(auth_value)
    }

    /// Decodes the credentials token from standard-alphabet Base-64.
    ///
    /// The decoded bytes are returned in a [`DecodedHeader`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::auth_value`], and returns
    /// [`BasicAuthError::InvalidBase64`] if the token is not valid Base-64.
    fn decode(&self) -> Result<DecodedHeader, BasicAuthError> {
        let auth_value = self.auth_value()?;

        let decoded_header = BASE64_STANDARD.decode(auth_value).map_err(|e| {
            logger::debug!("There was a problem when trying to decoding auth header: {e}");

            BasicAuthError::InvalidBase64(InvalidBase64(e))
        })?;

        Ok(DecodedHeader(decoded_header))
    }
}

/// Base-64-decoded bytes of the credentials token, expected to be `<client_id>:<client_secret>`.
///
/// `ZeroizeOnDrop` wipes the buffer when the value is dropped, so the plaintext
/// credentials do not linger in memory after use.
#[derive(ZeroizeOnDrop)]
struct DecodedHeader(Vec<u8>);

impl DecodedHeader {
    /// Borrows the decoded bytes as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns [`BasicAuthError::InvalidUtf8`] when the bytes are not valid UTF-8.
    fn as_utf8(&self) -> Result<&str, BasicAuthError> {
        match std::str::from_utf8(&self.0) {
            Ok(text) => Ok(text),
            Err(e) => {
                logger::debug!("There was a problem when trying to translate auth header: {e}");

                Err(BasicAuthError::InvalidUtf8(InvalidUtf8(e)))
            }
        }
    }

    /// Splits the decoded text at the first separator into `(client id, client secret)`.
    ///
    /// Everything after the first separator, including further separators, becomes the secret.
    ///
    /// # Errors
    ///
    /// Propagates [`BasicAuthError::InvalidUtf8`] from [`Self::as_utf8`], and returns
    /// [`BasicAuthError::InvalidCredentialsFormat`] when no separator is present.
    fn as_credentials(&self) -> Result<(ClientId, ClientSecret), BasicAuthError> {
        let text = self.as_utf8()?;

        let (id, secret) = text
            .split_once(PAIR_SEPARATOR)
            .ok_or(BasicAuthError::InvalidCredentialsFormat)?;

        Ok((
            ClientId::new(id.to_owned()),
            ClientSecret::new(secret.to_owned()),
        ))
    }
}

/// Runs the whole extraction pipeline against an arbitrary headers handler.
///
/// The steps are: read the `authorization` header, Base-64-decode its credentials
/// token, and split the result into `(client id, client secret)`. Taking
/// `&dyn HeadersHandler` lets tests drive the pipeline with a mock handler.
fn credentials_from_handler(
    handler: &dyn HeadersHandler,
) -> Result<(ClientId, ClientSecret), BasicAuthError> {
    let header = AuthorizationHeader::from_handler(handler)?;
    let decoded = header.decode()?;
    decoded.as_credentials()
}

/// Extracts a pair of credentials from a Basic-Auth header.
///
/// The request's `authorization` header is read, its Base-64 credentials token is
/// decoded, and the decoded text is split at the first `:` into
/// `(client id, client secret)`.
///
/// # Errors
///
/// Returns a [`BasicAuthError`] describing the first step that failed. The possible
/// causes are a missing header, a malformed header value, a non-Basic scheme, invalid
/// Base-64, invalid UTF-8, or a missing `:` separator.
pub fn basic_auth_credentials(
    request_headers_state: &RequestHeadersState,
) -> Result<(ClientId, ClientSecret), BasicAuthError> {
    let handler = request_headers_state.handler();
    credentials_from_handler(handler)
}

#[cfg(test)]
mod tests {
    use pdk_core::classy::hl::HeadersHandler;

    // Import everything from the parent module so the tests don't depend on
    // where this module sits in the crate hierarchy.
    use super::{credentials_from_handler, BasicAuthError, AUTHORIZATION_HEADER};

    /// Headers handler stub that only answers lookups of the `authorization` header.
    /// Every other operation is unreachable in these tests.
    struct HandlerMock(Option<String>);

    impl HandlerMock {
        fn absent() -> Self {
            Self(None)
        }

        fn new(value: impl Into<String>) -> Self {
            Self(Some(value.into()))
        }
    }

    impl HeadersHandler for HandlerMock {
        fn headers(&self) -> Vec<(String, String)> {
            unreachable!()
        }

        fn header(&self, name: &str) -> Option<String> {
            (name == AUTHORIZATION_HEADER)
                .then(|| self.0.clone())
                .flatten()
        }

        fn add_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_headers(&self, _: Vec<(&str, &str)>) {
            unreachable!()
        }

        fn remove_header(&self, _: &str) {
            unreachable!()
        }
    }

    #[test]
    fn header_not_found() {
        let result = credentials_from_handler(&HandlerMock::absent());

        assert_eq!(result.unwrap_err(), BasicAuthError::HeaderNotFound);
    }

    #[test]
    fn unknown_schema() {
        let result = credentials_from_handler(&HandlerMock::new("Unknown aaaaa"));

        assert_eq!(
            result.unwrap_err(),
            BasicAuthError::UnknownSchema("Unknown".to_string())
        );
    }

    #[test]
    fn invalid_header_format() {
        let result = credentials_from_handler(&HandlerMock::new("Invalid"));

        assert_eq!(result.unwrap_err(), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn scheme_without_credentials() {
        let result = credentials_from_handler(&HandlerMock::new("Basic"));

        assert_eq!(result.unwrap_err(), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn empty_header_value() {
        let result = credentials_from_handler(&HandlerMock::new(""));

        assert_eq!(result.unwrap_err(), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn invalid_base64() {
        let result = credentials_from_handler(&HandlerMock::new("Basic ####"));

        assert!(matches!(
            result.unwrap_err(),
            BasicAuthError::InvalidBase64(_)
        ));
    }

    #[test]
    fn invalid_utf8() {
        let result = credentials_from_handler(&HandlerMock::new("Basic aaaa"));

        assert!(matches!(
            result.unwrap_err(),
            BasicAuthError::InvalidUtf8(_)
        ));
    }

    #[test]
    fn invalid_credentials_format() {
        let result = credentials_from_handler(&HandlerMock::new("Basic c29tZSB1c2Vy"));

        assert_eq!(
            result.unwrap_err(),
            BasicAuthError::InvalidCredentialsFormat
        );
    }

    #[test]
    fn valid_credentials() {
        let (id, secret) =
            credentials_from_handler(&HandlerMock::new("Basic dXNlcjE6cGFzc3dvcmQx")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "password1");
    }

    #[test]
    fn secret_may_contain_separator() {
        // "dXNlcjE6cGE6c3M=" is Base-64 for "user1:pa:ss"; only the first ':' splits.
        let (id, secret) =
            credentials_from_handler(&HandlerMock::new("Basic dXNlcjE6cGE6c3M=")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "pa:ss");
    }

    #[test]
    fn empty_secret() {
        // "dXNlcjE6" is Base-64 for "user1:".
        let (id, secret) = credentials_from_handler(&HandlerMock::new("Basic dXNlcjE6")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "");
    }
}