pdk-core 1.7.0-alpha.0

PDK Core
Documentation
// Copyright (c) 2025, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

//! Information about whether a policy in the chain reached a scenario that can be considered
//! an expected error, e.g. a policy that checks credentials found them to be invalid.

use classy::{
    extract::{context::ConfigureContext, Extract, FromContext},
    hl::{PropertyAccessor, StreamProperties},
};
use serde_derive::{Deserialize, Serialize};
use std::{convert::Infallible, fmt};

use crate::policy_context::api::Metadata;

/// Byte written between consecutive fields of a serialized policy violation report.
const POLICY_VIOLATION_SEPARATOR: u8 = b'/';
/// Number of fields in a serialized report, in wire order:
/// policy name, violation tag byte, client name, client id.
const POLICY_REPORT_FIELDS: usize = 4;

/// The type of policy violation.
///
/// The explicit discriminants double as the tag byte of the serialized violation
/// report (the report writer casts the variant with `as u8`), so they must stay
/// stable. `#[repr(u8)]` makes the compiler reject any future discriminant that
/// would not fit in that single byte, instead of letting the cast silently truncate it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum PolicyViolationType {
    /// A normal, expected error during policy execution (e.g. a policy that checks
    /// credentials found them to be invalid).
    Violation = 0,
    /// An unexpected error that is not part of the policy's business logic.
    Error = 1,
}

impl From<&[u8]> for PolicyViolationType {
    fn from(value: &[u8]) -> Self {
        let tag: u8 = value.first().copied().unwrap_or(0);
        match tag {
            1 => PolicyViolationType::Error,
            _ => PolicyViolationType::Violation,
        }
    }
}

impl fmt::Display for PolicyViolationType {
    /// Renders the variant name exactly as its `Debug` representation
    /// (`"Violation"` or `"Error"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
/// A single policy violation record: the policy that reported it, the kind of
/// violation, and (optionally) the client application involved.
///
/// Field order is observable through the serde derives (fields are serialized in
/// declaration order), so avoid reordering these fields.
pub struct PolicyViolation {
    /// Name of the policy that emitted the violation.
    policy_name: String,
    /// Whether this is an expected violation or an unexpected error.
    violation: PolicyViolationType,
    /// Name of the client application that triggered the violation, if known.
    client_name: Option<String>,
    /// ID of the client application that triggered the violation, if known.
    client_id: Option<String>,
}

impl PolicyViolation {
    pub fn new(
        policy_name: String,
        violation: PolicyViolationType,
        client_name: Option<String>,
        client_id: Option<String>,
    ) -> Self {
        PolicyViolation {
            policy_name,
            violation,
            client_name,
            client_id,
        }
    }

    /// Get the name of the policy that emitted the violation.
    pub fn get_policy_name(&self) -> &str {
        &self.policy_name
    }

    /// Get the type of error.
    pub fn get_policy_violation(&self) -> PolicyViolationType {
        self.violation
    }

    /// The client name that triggered the violation.
    pub fn get_client_name(&self) -> Option<&str> {
        self.client_name.as_deref()
    }

    /// The client ID that triggered the violation.
    pub fn get_client_id(&self) -> Option<&str> {
        self.client_id.as_deref()
    }
}

/// Interface to access and modify the policy violation associated with the current request.
///
/// The violation is stored, serialized, in a stream property; violations generated
/// through this handle are attributed to `policy_name`.
pub struct PolicyViolations {
    /// Policy name recorded on violations generated through this handle.
    policy_name: String,
    /// Accessor used to read and write the violation stream property.
    stream_properties: Box<dyn PropertyAccessor>,
}

/// Identifier (segment list) of the stream property that holds the serialized violation.
const POLICY_VIOLATION_PROPERTY: &[&str] = &["policy_violation"];

impl PolicyViolations {
    pub fn new<T: PropertyAccessor + 'static>(stream_properties: T, policy_name: String) -> Self {
        Self {
            stream_properties: Box::new(stream_properties),
            policy_name,
        }
    }

    /// Returns the existing policy violation associated to the current request.
    pub fn policy_violation(&self) -> Option<PolicyViolation> {
        self.stream_properties
            .read_property(POLICY_VIOLATION_PROPERTY)
            .as_deref()
            .and_then(deserialize_policy_violation)

        // TODO W-17473828: Check native policy violation existence if no regular policy violation.
    }

    /// Generates a new policy violation for the current request. If one was already generated, it is overridden
    pub fn generate_policy_violation(&self) {
        let policy_violation = PolicyViolation::new(
            self.policy_name.clone(),
            PolicyViolationType::Violation,
            None,
            None,
        );
        self.report(policy_violation);
    }

    /// Generates a new policy violation for the current request additionally informing the associated client app. If one was already generated, it is overridden
    pub fn generate_policy_violation_for_client_app<T: Into<String>, K: Into<String>>(
        &self,
        client_name: T,
        client_id: K,
    ) {
        let policy_violation = PolicyViolation::new(
            self.policy_name.clone(),
            PolicyViolationType::Violation,
            Some(client_name.into()),
            Some(client_id.into()),
        );
        self.report(policy_violation);
    }

    fn report(&self, policy_violation: PolicyViolation) {
        let serialized_report = serialize_policy_violation(policy_violation);
        self.stream_properties
            .set_property(POLICY_VIOLATION_PROPERTY, Some(&serialized_report))
    }
}

impl FromContext<ConfigureContext> for PolicyViolations {
    type Error = Infallible;

    /// Builds a [`PolicyViolations`] handle from the configure context, using the
    /// stream properties for storage and the policy name from the policy metadata.
    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
        // Extraction order (metadata first, then stream properties) is kept as-is.
        let metadata: Metadata = context.extract()?;
        let stream_properties: StreamProperties = context.extract()?;
        let policy_name = metadata.policy_metadata.policy_name;
        Ok(Self::new(stream_properties, policy_name))
    }
}

/// Encodes a violation as `policy_name/TAG/client_name/client_id`, where `TAG` is the
/// single raw byte of the [`PolicyViolationType`] discriminant and absent client
/// fields are written as empty strings.
///
/// NOTE(review): the format has no escaping, so a `/` inside the policy name or the
/// client name shifts the following fields when the report is read back.
fn serialize_policy_violation(report: PolicyViolation) -> Vec<u8> {
    let client_name = report.get_client_name().unwrap_or_default();
    let client_id = report.get_client_id().unwrap_or_default();
    let tag = [report.get_policy_violation() as u8];

    // Wire order; `join` places exactly one separator between adjacent fields.
    let fields: [&[u8]; POLICY_REPORT_FIELDS] = [
        report.get_policy_name().as_bytes(),
        &tag,
        client_name.as_bytes(),
        client_id.as_bytes(),
    ];
    fields.join(&POLICY_VIOLATION_SEPARATOR)
}

/// Decodes a report written by `serialize_policy_violation`.
///
/// Returns `None` when fewer than four separator-delimited fields are present.
/// Text fields are decoded as UTF-8 with invalid sequences replaced, empty client
/// fields become `None`, and the tag field is interpreted by
/// `PolicyViolationType::from`.
fn deserialize_policy_violation(report: &[u8]) -> Option<PolicyViolation> {
    /// Empty bytes mean "absent"; anything else is decoded lossily into an owned string.
    fn optional_field(bytes: &[u8]) -> Option<String> {
        if bytes.is_empty() {
            return None;
        }
        Some(String::from_utf8_lossy(bytes).into_owned())
    }

    // Field order mirrors the serializer. `splitn` stops after the last field, so any
    // separator bytes inside the trailing client id are kept verbatim.
    let mut fields = report.splitn(POLICY_REPORT_FIELDS, |&byte| {
        byte == POLICY_VIOLATION_SEPARATOR
    });
    let (policy_name, tag, client_name, client_id) =
        (fields.next()?, fields.next()?, fields.next()?, fields.next()?);

    Some(PolicyViolation {
        policy_name: String::from_utf8_lossy(policy_name).into_owned(),
        violation: PolicyViolationType::from(tag),
        client_name: optional_field(client_name),
        client_id: optional_field(client_id),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn successful_serialization() {
        let report = PolicyViolation {
            policy_name: String::from("a_policy_name"),
            violation: PolicyViolationType::Error,
            client_name: Some("app name".to_string()),
            client_id: Some("id".to_string()),
        };
        let serialization = serialize_policy_violation(report);
        let deserialized_report = deserialize_policy_violation(&serialization).unwrap();
        assert_eq!(deserialized_report.policy_name, "a_policy_name");
        assert_eq!(deserialized_report.violation, PolicyViolationType::Error);
        assert_eq!(
            deserialized_report.client_name,
            Some("app name".to_string())
        );
        assert_eq!(deserialized_report.client_id, Some("id".to_string()));
    }

    #[test]
    fn missing_name_can_be_deserialized() {
        let report = PolicyViolation {
            policy_name: String::from("a_policy_name"),
            violation: PolicyViolationType::Error,
            client_name: None,
            client_id: None,
        };
        let serialization = serialize_policy_violation(report);
        let deserialized_report = deserialize_policy_violation(&serialization).unwrap();
        assert_eq!(deserialized_report.policy_name, "a_policy_name");
        assert_eq!(deserialized_report.violation, PolicyViolationType::Error);
        assert_eq!(deserialized_report.client_name, None);
        assert_eq!(deserialized_report.client_id, None);
    }

    #[test]
    fn cant_deserialize_with_missing_violation_type() {
        let wrong_serialization = String::from("the_policy_name");
        let deserialized_report = deserialize_policy_violation(wrong_serialization.as_bytes());
        assert!(deserialized_report.is_none());
    }

    /// The `Violation` variant round-trips, and the wire bytes match the documented layout.
    #[test]
    fn violation_type_round_trips_with_exact_bytes() {
        let report = PolicyViolation::new(
            String::from("a_policy_name"),
            PolicyViolationType::Violation,
            None,
            Some("id".to_string()),
        );
        let serialization = serialize_policy_violation(report);
        assert_eq!(serialization, b"a_policy_name/\x00//id".to_vec());
        let deserialized_report = deserialize_policy_violation(&serialization).unwrap();
        assert_eq!(deserialized_report.policy_name, "a_policy_name");
        assert_eq!(
            deserialized_report.violation,
            PolicyViolationType::Violation
        );
        assert_eq!(deserialized_report.client_name, None);
        assert_eq!(deserialized_report.client_id, Some("id".to_string()));
    }

    /// Only the first three separators split fields, so the trailing client id keeps any '/'.
    #[test]
    fn client_id_may_contain_separator() {
        let deserialized_report =
            deserialize_policy_violation(b"a_policy_name/\x01/app/a/b").unwrap();
        assert_eq!(deserialized_report.client_name, Some("app".to_string()));
        assert_eq!(deserialized_report.client_id, Some("a/b".to_string()));
    }

    /// A report cut off before the client id field is rejected.
    #[test]
    fn cant_deserialize_with_missing_client_fields() {
        assert!(deserialize_policy_violation(b"a_policy_name/\x01/app").is_none());
    }

    /// Only a leading `1` byte decodes as `Error`; everything else, including an empty tag,
    /// decodes as `Violation`.
    #[test]
    fn unknown_or_missing_tag_defaults_to_violation() {
        assert_eq!(
            PolicyViolationType::from(&b""[..]),
            PolicyViolationType::Violation
        );
        assert_eq!(
            PolicyViolationType::from(&b"\x00"[..]),
            PolicyViolationType::Violation
        );
        assert_eq!(
            PolicyViolationType::from(&b"\x01"[..]),
            PolicyViolationType::Error
        );
        assert_eq!(
            PolicyViolationType::from(&b"\x07"[..]),
            PolicyViolationType::Violation
        );
    }

    #[test]
    fn violation_type_displays_variant_name() {
        assert_eq!(PolicyViolationType::Error.to_string(), "Error");
        assert_eq!(PolicyViolationType::Violation.to_string(), "Violation");
    }
}