pdk_core/policy_context/
policy_violation.rs

1// Copyright (c) 2025, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
//! Information about whether a policy in the chain reached a scenario that can be considered
//! an expected error, e.g. a policy that checks credentials and finds them invalid.
7
8use std::{convert::Infallible, fmt};
9
10use classy::{
11    extract::{context::ConfigureContext, Extract, FromContext},
12    hl::{PropertyAccessor, StreamProperties},
13};
14
15use crate::policy_context::api::Metadata;
16
/// Byte placed between the fields of a serialized [`PolicyViolation`].
const POLICY_VIOLATION_SEPARATOR: u8 = b'/';
/// Number of fields in a serialized [`PolicyViolation`]: policy name, violation type tag,
/// client name and client id.
const POLICY_REPORT_FIELDS: usize = 4;
19
/// The type of policy violation.
///
/// The discriminant of each variant is the tag byte written into a serialized violation
/// report (`violation as u8`), so existing values must not be renumbered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PolicyViolationType {
    /// An expected error during policy execution (e.g. invalid credentials).
    Violation = 0,
    /// An unexpected error that is not part of the policy's business logic.
    Error = 1,
}

impl From<&[u8]> for PolicyViolationType {
    /// Decodes a violation type from the first byte of `value`.
    ///
    /// Only the tag byte `1` maps to [`PolicyViolationType::Error`]; any other byte, or an
    /// empty slice, falls back to [`PolicyViolationType::Violation`].
    fn from(value: &[u8]) -> Self {
        match value.first() {
            Some(&1) => Self::Error,
            _ => Self::Violation,
        }
    }
}

impl fmt::Display for PolicyViolationType {
    /// Writes the bare variant name (`Violation` or `Error`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::Violation => "Violation",
            Self::Error => "Error",
        };
        f.write_str(name)
    }
}
44
45#[derive(Debug, Clone)]
46/// The data regarding the policy violation.
47pub struct PolicyViolation {
48    policy_name: String,
49    violation: PolicyViolationType,
50    client_name: Option<String>,
51    client_id: Option<String>,
52}
53
54impl PolicyViolation {
55    pub fn new(
56        policy_name: String,
57        violation: PolicyViolationType,
58        client_name: Option<String>,
59        client_id: Option<String>,
60    ) -> Self {
61        PolicyViolation {
62            policy_name,
63            violation,
64            client_name,
65            client_id,
66        }
67    }
68
69    /// Get the name of the policy that emitted the violation.
70    pub fn get_policy_name(&self) -> &str {
71        &self.policy_name
72    }
73
74    /// Get the type of error.
75    pub fn get_policy_violation(&self) -> PolicyViolationType {
76        self.violation
77    }
78
79    /// The client name that triggered the violation.
80    pub fn get_client_name(&self) -> Option<&str> {
81        self.client_name.as_deref()
82    }
83
84    /// The client ID that triggered the violation.
85    pub fn get_client_id(&self) -> Option<&str> {
86        self.client_id.as_deref()
87    }
88}
89
90/// Interface to access and modify the violation associated to the current request.
91pub struct PolicyViolations {
92    stream_properties: Box<dyn PropertyAccessor>,
93    policy_name: String,
94}
95
96const POLICY_VIOLATION_PROPERTY: &[&str] = &["policy_violation"];
97
98impl PolicyViolations {
99    pub fn new<T: PropertyAccessor + 'static>(stream_properties: T, policy_name: String) -> Self {
100        Self {
101            stream_properties: Box::new(stream_properties),
102            policy_name,
103        }
104    }
105
106    /// Returns the existing policy violation associated to the current request.
107    pub fn policy_violation(&self) -> Option<PolicyViolation> {
108        self.stream_properties
109            .read_property(POLICY_VIOLATION_PROPERTY)
110            .as_deref()
111            .and_then(deserialize_policy_violation)
112
113        // TODO W-17473828: Check native policy violation existence if no regular policy violation.
114    }
115
116    /// Generates a new policy violation for the current request. If one was already generated, it is overridden
117    pub fn generate_policy_violation(&self) {
118        let policy_violation = PolicyViolation::new(
119            self.policy_name.clone(),
120            PolicyViolationType::Violation,
121            None,
122            None,
123        );
124        self.report(policy_violation);
125    }
126
127    /// Generates a new policy violation for the current request additionally informing the associated client app. If one was already generated, it is overridden
128    pub fn generate_policy_violation_for_client_app<T: Into<String>, K: Into<String>>(
129        &self,
130        client_name: T,
131        client_id: K,
132    ) {
133        let policy_violation = PolicyViolation::new(
134            self.policy_name.clone(),
135            PolicyViolationType::Violation,
136            Some(client_name.into()),
137            Some(client_id.into()),
138        );
139        self.report(policy_violation);
140    }
141
142    fn report(&self, policy_violation: PolicyViolation) {
143        let serialized_report = serialize_policy_violation(policy_violation);
144        self.stream_properties
145            .set_property(POLICY_VIOLATION_PROPERTY, Some(&serialized_report))
146    }
147}
148
149impl FromContext<ConfigureContext> for PolicyViolations {
150    type Error = Infallible;
151
152    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
153        let metadata: Metadata = context.extract()?;
154        let stream_properties: StreamProperties = context.extract()?;
155        Ok(PolicyViolations::new(
156            stream_properties,
157            metadata.policy_metadata.policy_name,
158        ))
159    }
160}
161
162fn serialize_policy_violation(report: PolicyViolation) -> Vec<u8> {
163    let name = report.get_client_name().unwrap_or("");
164    let id = report.get_client_id().unwrap_or("");
165    let mut ser_bytes: Vec<u8> = Vec::with_capacity(
166        report.get_policy_name().len() + name.len() + id.len() + POLICY_REPORT_FIELDS,
167    );
168    let policy_violation_separator_bytes = &[POLICY_VIOLATION_SEPARATOR];
169    ser_bytes.extend_from_slice(report.get_policy_name().as_bytes());
170    ser_bytes.extend_from_slice(policy_violation_separator_bytes);
171    ser_bytes.extend_from_slice(&[report.get_policy_violation() as u8]);
172    ser_bytes.extend_from_slice(policy_violation_separator_bytes);
173    ser_bytes.extend_from_slice(name.as_bytes());
174    ser_bytes.extend_from_slice(policy_violation_separator_bytes);
175    ser_bytes.extend_from_slice(id.as_bytes());
176
177    ser_bytes
178}
179
180fn deserialize_policy_violation(report: &[u8]) -> Option<PolicyViolation> {
181    let mut parts = report.splitn(POLICY_REPORT_FIELDS, |b| *b == POLICY_VIOLATION_SEPARATOR);
182
183    // Deserialization of fields is based on order used in `serialize_policy_violation`
184    let policy_name = parts.next()?;
185    let policy_violation = parts.next()?;
186    let client_name_bytes = parts.next()?;
187    let client_name = if client_name_bytes.is_empty() {
188        None
189    } else {
190        Some(String::from_utf8_lossy(client_name_bytes).into_owned())
191    };
192    let client_id_bytes = parts.next()?;
193    let client_id = if client_id_bytes.is_empty() {
194        None
195    } else {
196        Some(String::from_utf8_lossy(client_id_bytes).into_owned())
197    };
198    Some(PolicyViolation {
199        policy_name: String::from_utf8_lossy(policy_name).into_owned(),
200        violation: PolicyViolationType::from(policy_violation),
201        client_name,
202        client_id,
203    })
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a report attributed to `a_policy_name` with the given type and client data.
    fn sample_report(
        violation: PolicyViolationType,
        client_name: Option<&str>,
        client_id: Option<&str>,
    ) -> PolicyViolation {
        PolicyViolation::new(
            String::from("a_policy_name"),
            violation,
            client_name.map(str::to_string),
            client_id.map(str::to_string),
        )
    }

    /// Serializes `report` and parses it back; a freshly serialized report must always parse.
    fn round_trip(report: PolicyViolation) -> PolicyViolation {
        let serialization = serialize_policy_violation(report);
        deserialize_policy_violation(&serialization)
            .expect("a freshly serialized report should deserialize")
    }

    #[test]
    fn successful_serialization() {
        let deserialized_report = round_trip(sample_report(
            PolicyViolationType::Error,
            Some("app name"),
            Some("id"),
        ));
        assert_eq!(deserialized_report.policy_name, "a_policy_name");
        assert_eq!(deserialized_report.violation, PolicyViolationType::Error);
        assert_eq!(
            deserialized_report.client_name,
            Some("app name".to_string())
        );
        assert_eq!(deserialized_report.client_id, Some("id".to_string()));
    }

    #[test]
    fn missing_name_can_be_deserialized() {
        let deserialized_report = round_trip(sample_report(PolicyViolationType::Error, None, None));
        assert_eq!(deserialized_report.policy_name, "a_policy_name");
        assert_eq!(deserialized_report.violation, PolicyViolationType::Error);
        assert_eq!(deserialized_report.client_name, None);
        assert_eq!(deserialized_report.client_id, None);
    }

    #[test]
    fn violation_type_survives_round_trip() {
        // `Violation` is encoded as a 0 tag byte, which the tests above never exercise.
        let deserialized_report = round_trip(sample_report(
            PolicyViolationType::Violation,
            Some("app"),
            Some("id"),
        ));
        assert_eq!(deserialized_report.violation, PolicyViolationType::Violation);
    }

    #[test]
    fn serialized_layout_is_stable() {
        // The byte layout is shared through a stream property, so pin it exactly.
        let full = serialize_policy_violation(sample_report(
            PolicyViolationType::Error,
            Some("app"),
            Some("id"),
        ));
        assert_eq!(full, b"a_policy_name/\x01/app/id");

        let empty =
            serialize_policy_violation(sample_report(PolicyViolationType::Violation, None, None));
        assert_eq!(empty, b"a_policy_name/\x00//");
    }

    #[test]
    fn client_id_may_contain_separator() {
        // The client id is the last field, so `splitn` keeps any separators inside it.
        let deserialized_report = round_trip(sample_report(
            PolicyViolationType::Error,
            Some("app"),
            Some("a/b"),
        ));
        assert_eq!(deserialized_report.client_name.as_deref(), Some("app"));
        assert_eq!(deserialized_report.client_id.as_deref(), Some("a/b"));
    }

    #[test]
    fn unknown_violation_tag_defaults_to_violation() {
        let deserialized_report =
            deserialize_policy_violation(b"a_policy_name/\x07/app/id").unwrap();
        assert_eq!(deserialized_report.violation, PolicyViolationType::Violation);
    }

    #[test]
    fn cant_deserialize_with_missing_violation_type() {
        let wrong_serialization = String::from("the_policy_name");
        let deserialized_report = deserialize_policy_violation(wrong_serialization.as_bytes());
        assert!(deserialized_report.is_none());
    }

    #[test]
    fn cant_deserialize_with_missing_client_fields() {
        assert!(deserialize_policy_violation(b"a_policy_name/\x01/app").is_none());
        assert!(deserialize_policy_violation(b"").is_none());
    }

    #[test]
    fn violation_type_displays_variant_name() {
        assert_eq!(PolicyViolationType::Violation.to_string(), "Violation");
        assert_eq!(PolicyViolationType::Error.to_string(), "Error");
    }
}