pdk-unit 1.8.0

PDK Unit Test Framework
Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use crate::{Backend, UnitHttpMessage, UnitHttpRequest, UnitHttpResponse};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use std::cell::RefCell;
use std::collections::HashMap;

/// Credential pairs grouped by the LDAP config they were registered under.
///
/// The `None` key collects pairs registered without any config.
type CredentialsByConfig = HashMap<Option<UnitLdapConfig>, Vec<(String, String)>>;

/// Mock LDAP authentication backend used by unit tests.
///
/// Stores the `(user, pass)` pairs registered through [`LdapBackend::add_data`].
#[derive(Default)]
pub struct LdapBackend {
    configs: RefCell<CredentialsByConfig>,
}

/// LDAP server configuration used to scope credential pairs registered with [`UnitTest::add_ldap_data`].
///
/// A credential pair registered under a `UnitLdapConfig` only matches when every field equals
/// the corresponding LDAP connection parameter used by the policy. Fields that are never set
/// keep their default value: the empty string, or `false` for the subtree flag.
///
/// Build a value with the chained setters, starting from [`Default::default`]:
///
/// ```ignore
/// let config = UnitLdapConfig::default()
///     .server_url("ldap://ldap.example.com:389")
///     .server_user_dn("cn=admin,dc=example,dc=com")
///     .server_user_password("secret")
///     .search_base("ou=users,dc=example,dc=com")
///     .search_filter("(uid={0})")
///     .search_in_subtree();
/// ```
#[derive(Default, PartialEq, Eq, Hash)]
pub struct UnitLdapConfig {
    /// URL of the LDAP server.
    server_url: String,
    /// Distinguished name used to bind to the server.
    server_user_dn: String,
    /// Password paired with `server_user_dn` for the bind.
    server_user_password: String,
    /// Base DN the user search starts from.
    search_base: String,
    /// Filter used to locate the authenticating user.
    search_filter: String,
    /// Whether the search recurses through the subtree instead of a single level.
    search_in_subtree: bool,
}

impl UnitLdapConfig {
    /// Returns this config with the LDAP server URL replaced (e.g. `ldap://ldap.example.com:389`).
    pub fn server_url(self, url: impl Into<String>) -> Self {
        Self {
            server_url: url.into(),
            ..self
        }
    }

    /// Returns this config with the bind distinguished name replaced (e.g. `cn=admin,dc=example,dc=com`).
    pub fn server_user_dn(self, dn: impl Into<String>) -> Self {
        Self {
            server_user_dn: dn.into(),
            ..self
        }
    }

    /// Returns this config with the bind password replaced; it is used together with
    /// [`server_user_dn`](Self::server_user_dn) to bind to the LDAP server.
    pub fn server_user_password(self, pass: impl Into<String>) -> Self {
        Self {
            server_user_password: pass.into(),
            ..self
        }
    }

    /// Returns this config with the user-search base DN replaced (e.g. `ou=users,dc=example,dc=com`).
    pub fn search_base(self, base: impl Into<String>) -> Self {
        Self {
            search_base: base.into(),
            ..self
        }
    }

    /// Returns this config with the user-search filter replaced (e.g. `(uid={0})`).
    pub fn search_filter(self, filter: impl Into<String>) -> Self {
        Self {
            search_filter: filter.into(),
            ..self
        }
    }

    /// Returns this config with recursive subtree searching switched on, instead of a
    /// single-level search.
    pub fn search_in_subtree(self) -> Self {
        Self {
            search_in_subtree: true,
            ..self
        }
    }
}

impl LdapBackend {
    /// Registers a `user`/`pass` credential pair under `config`.
    ///
    /// Passing `None` stores the pair in the config-less bucket. Repeated calls with an equal
    /// `config` append to the same list of pairs.
    pub fn add_data<U: Into<String>, P: Into<String>>(
        &self,
        config: Option<UnitLdapConfig>,
        user: U,
        pass: P,
    ) {
        let pair = (user.into(), pass.into());
        let mut by_config = self.configs.borrow_mut();
        by_config.entry(config).or_default().push(pair);
    }
}

impl Backend for LdapBackend {
    /// Answers a simulated LDAP authentication request.
    ///
    /// The LDAP connection parameters are rebuilt from the `x-flex-authentication-ldap-*`
    /// request headers, with header names compared case-insensitively. When
    /// `x-flex-authentication-ldap-url` is absent, no config is derived and only credentials
    /// registered without a config are eligible.
    ///
    /// The credentials come from a `Basic` `Authorization` header. The scheme name is matched
    /// case-insensitively, and whitespace around the token is ignored.
    ///
    /// Responds with:
    /// - `401` if the `Authorization` header is missing or does not use the `Basic` scheme, or
    ///   if the decoded credentials match no registered pair;
    /// - `400` if the token cannot be decoded into a `user:pass` pair;
    /// - `200` if the pair was registered for the derived config or with no config (`None`).
    fn call(&self, req: UnitHttpRequest) -> UnitHttpResponse {
        // Returns the credential token of a `Basic` authorization value, or `None` for any
        // other scheme. Auth-scheme names are case-insensitive in HTTP, and the scheme may be
        // followed by more than one space before the token.
        fn basic_token(value: &str) -> Option<&str> {
            let (scheme, token) = value.trim_start().split_once(' ')?;
            scheme
                .eq_ignore_ascii_case("Basic")
                .then_some(token.trim())
        }

        let headers = req.headers();

        // HTTP header names are case-insensitive. Normalise them once so the lookups below can
        // use lowercase literals.
        let headers: HashMap<String, String> = headers
            .iter()
            .map(|(k, v)| (k.to_ascii_lowercase(), v.clone()))
            .collect();

        // Missing optional parameters become the empty string, which is also the value an
        // unset `UnitLdapConfig` field holds, so the two compare equal.
        let header = |name: &str| headers.get(name).cloned().unwrap_or_default();

        let config = headers
            .get("x-flex-authentication-ldap-url")
            .map(|url| UnitLdapConfig {
                server_url: url.clone(),
                server_user_dn: header("x-flex-authentication-ldap-bind-dn"),
                server_user_password: header("x-flex-authentication-ldap-bind-pass"),
                search_base: header("x-flex-authentication-ldap-search-base"),
                search_filter: header("x-flex-authentication-ldap-search-filter"),
                search_in_subtree: headers
                    .get("x-flex-authentication-ldap-search-in-subtree")
                    .is_some_and(|v| v == "true"),
            });

        let Some(credential) = headers
            .get("authorization")
            .map(String::as_str)
            .and_then(basic_token)
        else {
            return UnitHttpResponse::new(401);
        };

        let Some((user, pass)) = base64_decode_credentials(credential) else {
            return UnitHttpResponse::new(400);
        };

        let registered = self.configs.borrow();
        // Pairs registered for this exact config are eligible, and so are pairs registered
        // with no config at all. When `config` is `None`, both lookups hit the same bucket;
        // that is redundant but does not change the result.
        let found = registered
            .get(&config)
            .into_iter()
            .chain(registered.get(&None))
            .flatten()
            .any(|(u, p)| *u == user && *p == pass);

        UnitHttpResponse::new(if found { 200 } else { 401 })
    }
}

/// Decodes an HTTP Basic credential token into a `(user, pass)` pair.
///
/// The token is decoded with the `STANDARD` base64 engine and must be valid UTF-8. The text is
/// split at the first `:`, so the password part may itself contain colons.
///
/// Returns `None` if the base64 decoding fails, if the bytes are not UTF-8, or if the decoded
/// text contains no `:`.
fn base64_decode_credentials(encoded: &str) -> Option<(String, String)> {
    let raw = BASE64.decode(encoded.as_bytes()).ok()?;
    let text = String::from_utf8(raw).ok()?;
    text.split_once(':')
        .map(|(user, pass)| (user.to_owned(), pass.to_owned()))
}