pdk-core 1.7.0-alpha.0

PDK Core
Documentation
// Copyright (c) 2025, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

//! Utilities to access and share authentication data between filters.

use std::convert::Infallible;

use crate::policy_context::AUTHENTICATION_PROPERTY;
use classy::extract::context::FilterContext;
use classy::extract::{Extract, FromContext};
use classy::hl::context::{RequestContext, ResponseContext};
use classy::hl::StreamProperties;
use classy::stream::PropertyAccessor;
use log::warn;
use pdk_script::{AuthenticationBinding, IntoValue, Value};
use rmp_serde::Serializer;
use serde::{Deserialize, Serialize};

/// Read/write access to the [`AuthenticationData`] that filters share with each other.
pub trait AuthenticationHandler {
    /// Returns the authentication data currently available, or `None` if there is none.
    fn authentication(&self) -> Option<AuthenticationData>;

    /// Replaces the current authentication data with `authentication`.
    fn set_authentication(&self, authentication: Option<&AuthenticationData>);
}

/// Default [`AuthenticationHandler`], which stores the authentication data through a
/// [`PropertyAccessor`].
pub struct Authentication {
    /// Accessor used to read and write the serialized authentication property.
    property_accessor: Box<dyn PropertyAccessor>,
}

impl FromContext<FilterContext> for Authentication {
    type Error = Infallible;

    /// Builds an [`Authentication`] backed by the [`StreamProperties`] of the filter context.
    fn from_context(context: &FilterContext) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}

impl FromContext<RequestContext> for Authentication {
    type Error = Infallible;

    /// Builds an [`Authentication`] backed by the [`StreamProperties`] of the request context.
    fn from_context(context: &RequestContext) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}

impl<C> FromContext<ResponseContext<C>> for Authentication {
    type Error = Infallible;

    /// Builds an [`Authentication`] backed by the [`StreamProperties`] of the response context.
    fn from_context(context: &ResponseContext<C>) -> Result<Self, Self::Error> {
        let properties: StreamProperties = context.extract()?;
        Ok(Self::new(properties))
    }
}

impl Authentication {
    /// Creates a handler that reads and writes the authentication property through
    /// `property_accessor`.
    pub fn new<K>(property_accessor: K) -> Self
    where
        K: PropertyAccessor + 'static,
    {
        let boxed: Box<dyn PropertyAccessor> = Box::new(property_accessor);
        Self {
            property_accessor: boxed,
        }
    }
}

impl AuthenticationHandler for Authentication {
    /// Reads the authentication property and decodes it.
    ///
    /// Returns `None` when the property is absent or its bytes cannot be decoded
    /// (the decoder logs a warning in that case).
    fn authentication(&self) -> Option<AuthenticationData> {
        self.property_accessor
            .read_property(AUTHENTICATION_PROPERTY)
            .and_then(|raw| AuthenticationStreamSerializer::deserialize(raw.as_slice()))
    }

    /// Encodes `authentication` and writes it to the authentication property.
    ///
    /// If `authentication` is `None`, or encoding fails, `None` is passed to the
    /// property accessor as the new value.
    fn set_authentication(&self, authentication: Option<&AuthenticationData>) {
        let encoded: Option<Vec<u8>> = match authentication {
            Some(data) => AuthenticationStreamSerializer::serialize(data),
            None => None,
        };
        self.property_accessor
            .set_property(AUTHENTICATION_PROPERTY, encoded.as_deref());
    }
}

/// The data describing an authentication, shared between filters.
///
/// NOTE(review): this type is encoded with `rmp_serde::Serializer::new`, whose default
/// configuration appears to write structs positionally (as MessagePack arrays). If so,
/// the field order below is part of the format exchanged with filters written in other
/// languages. Confirm before reordering or inserting fields.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AuthenticationData {
    /// The main value used for authentication.
    pub principal: Option<String>,
    /// The client id associated with the authentication method used.
    pub client_id: Option<String>,
    /// The client name associated with the authentication method used.
    pub client_name: Option<String>,
    /// Additional properties of the authentication.
    pub properties: Value,
}

impl AuthenticationData {
    /// Builds authentication data from its parts.
    ///
    /// `properties` is converted into a [`Value`] via [`IntoValue`].
    pub fn new<K: IntoValue>(
        principal: Option<String>,
        client_id: Option<String>,
        client_name: Option<String>,
        properties: K,
    ) -> Self {
        let properties = properties.into_value();
        Self {
            principal,
            client_id,
            client_name,
            properties,
        }
    }
}

impl AuthenticationBinding for AuthenticationData {
    fn client_id(&self) -> Option<String> {
        self.client_id.clone()
    }

    fn client_name(&self) -> Option<String> {
        self.client_name.clone()
    }

    fn principal(&self) -> Option<String> {
        self.principal.clone()
    }

    fn properties(&self) -> Option<Value> {
        Some(self.properties.clone())
    }
}

/// Serializes and deserializes [`AuthenticationData`] objects so that they can be propagated
/// between policies.
///
/// The chosen serialization format is MessagePack. Using a cross-language format allows the
/// object to be propagated between filters written in any language.
struct AuthenticationStreamSerializer;

impl AuthenticationStreamSerializer {
    pub fn deserialize(bytes: &[u8]) -> Option<AuthenticationData> {
        match rmp_serde::decode::from_read(bytes) {
            Ok(authentication) => Some(authentication),
            Err(err) => {
                warn!("Unexpected error deserializing Authentication object: {err}");
                None
            }
        }
    }

    pub fn serialize(authentication: &AuthenticationData) -> Option<Vec<u8>> {
        let mut buf = Vec::new();
        let result = authentication.serialize(&mut Serializer::new(&mut buf));

        match result {
            Ok(_) => Some(buf),
            Err(err) => {
                warn!("Unexpected error serializing Authentication object: {err}");
                None
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;

    use super::*;
    use classy::proxy_wasm::types::Bytes;

    const KEY_1: &str = "key1";
    const KEY_2: &str = "key2";

    const VALUE: &str = "value2";

    const PRINCIPAL: &str = "principal";
    const CLIENT_ID: &str = "client_id";
    const CLIENT_NAME: &str = "client_name";

    /// In-memory [`PropertyAccessor`] keyed by the full property path.
    ///
    /// A stored `None` models a property that was explicitly set to no value.
    #[derive(Default)]
    struct MockPropertyAccessor {
        properties: RefCell<HashMap<Vec<String>, Option<Bytes>>>,
    }

    impl MockPropertyAccessor {
        /// Converts a borrowed property path into the owned key used by the map.
        fn key(path: &[&str]) -> Vec<String> {
            path.iter().map(|&segment| segment.to_owned()).collect()
        }
    }

    impl PropertyAccessor for MockPropertyAccessor {
        fn read_property(&self, path: &[&str]) -> Option<Bytes> {
            // Borrow instead of `RefCell::take`: taking replaces the whole map with an empty
            // one, so every read after the first would see no properties at all.
            self.properties
                .borrow()
                .get(&Self::key(path))
                .cloned()
                .flatten()
        }

        fn set_property(&self, path: &[&str], value: Option<&[u8]>) {
            let bytes = value.map(Bytes::from);
            self.properties.borrow_mut().insert(Self::key(path), bytes);
        }
    }

    #[test]
    fn serialize_and_deserialize_authentication_to_bytes() {
        let auth = create_authentication();
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        auth_handler.set_authentication(Some(&auth));
        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        match auth.unwrap().properties {
            Value::Object(obj) => assert_eq!(obj.len(), 2),
            other => panic!("expected an object of properties, got {other:?}"),
        }
    }

    #[test]
    fn handler_get_empty() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let auth = auth_handler.authentication();

        assert!(auth.is_none())
    }

    #[test]
    fn handler_new_authentication_creates_auth_when_no_previous_data() {
        let property_accessor = MockPropertyAccessor::default();
        let auth_handler = Authentication::new(property_accessor);

        let new_auth = AuthenticationData::new(
            Some(PRINCIPAL.to_string()),
            Some(CLIENT_ID.to_string()),
            Some(CLIENT_NAME.to_string()),
            HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ]),
        );

        auth_handler.set_authentication(Some(&new_auth));

        let auth = auth_handler.authentication();

        assert_authentication(auth.clone());
        assert_eq!(new_auth, auth.unwrap());
    }

    #[test]
    fn handler_reads_are_repeatable() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));

        // Reading must not consume the stored data.
        assert_eq!(auth_handler.authentication(), Some(create_authentication()));
        assert_eq!(auth_handler.authentication(), Some(create_authentication()));
    }

    #[test]
    fn handler_set_none_clears_authentication() {
        let auth_handler = Authentication::new(MockPropertyAccessor::default());
        auth_handler.set_authentication(Some(&create_authentication()));
        assert!(auth_handler.authentication().is_some());

        auth_handler.set_authentication(None);

        assert!(auth_handler.authentication().is_none());
    }

    fn assert_authentication(auth: Option<AuthenticationData>) {
        assert!(auth.is_some());
        let unwrapped = auth.unwrap();
        assert_eq!(unwrapped.principal, Some(PRINCIPAL.to_string()));
        assert_eq!(unwrapped.client_id, Some(CLIENT_ID.to_string()));
        assert_eq!(unwrapped.client_name, Some(CLIENT_NAME.to_string()));
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_1),
            Some(&Value::Bool(true))
        );
        assert_eq!(
            unwrapped.properties.as_object().unwrap().get(KEY_2),
            Some(&Value::String(VALUE.to_string()))
        );
    }

    fn create_authentication() -> AuthenticationData {
        AuthenticationData {
            principal: Some(PRINCIPAL.to_string()),
            client_id: Some(CLIENT_ID.to_string()),
            client_name: Some(CLIENT_NAME.to_string()),
            properties: HashMap::from([
                (KEY_1.to_string(), Value::Bool(true)),
                (KEY_2.to_string(), Value::String(VALUE.to_string())),
            ])
            .into_value(),
        }
    }
}