pdk_core/policy_context/
policy_violation.rs1use std::{convert::Infallible, fmt};
9
10use classy::{
11 extract::{context::ConfigureContext, Extract, FromContext},
12 hl::{PropertyAccessor, StreamProperties},
13};
14
15use crate::policy_context::api::Metadata;
16
/// Byte written between consecutive fields of a serialized policy violation.
const POLICY_VIOLATION_SEPARATOR: u8 = b'/';
/// Number of fields in a serialized policy violation (policy name, type tag,
/// client name, client id). The deserializer passes it to `splitn` as the
/// maximum number of pieces; the serializer also adds it to its buffer
/// capacity, where it covers the three separators plus the one tag byte.
const POLICY_REPORT_FIELDS: usize = 4;
19
/// Type tag carried by a [`PolicyViolation`].
///
/// The explicit discriminants are the byte values written into the serialized
/// report (the serializer casts the variant with `as u8`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PolicyViolationType {
    /// Tag `0`. Also what any unrecognized or missing tag byte decodes to.
    Violation = 0,
    /// Tag `1`.
    // NOTE(review): presumably marks a failure of the policy itself rather
    // than a rejected request; confirm against the gateway documentation.
    Error = 1,
}
28
29impl From<&[u8]> for PolicyViolationType {
30 fn from(value: &[u8]) -> Self {
31 let tag: u8 = value.first().copied().unwrap_or(0);
32 match tag {
33 1 => PolicyViolationType::Error,
34 _ => PolicyViolationType::Violation,
35 }
36 }
37}
38
39impl fmt::Display for PolicyViolationType {
40 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41 write!(f, "{self:?}")
42 }
43}
44
/// A policy violation report: the policy it belongs to, its type tag and,
/// optionally, the client application it was raised for.
#[derive(Debug, Clone)]
pub struct PolicyViolation {
    /// Name of the policy the report belongs to.
    policy_name: String,
    /// Type tag of the report.
    violation: PolicyViolationType,
    /// Client application name, when one was supplied.
    client_name: Option<String>,
    /// Client application id, when one was supplied.
    client_id: Option<String>,
}
53
54impl PolicyViolation {
55 pub fn new(
56 policy_name: String,
57 violation: PolicyViolationType,
58 client_name: Option<String>,
59 client_id: Option<String>,
60 ) -> Self {
61 PolicyViolation {
62 policy_name,
63 violation,
64 client_name,
65 client_id,
66 }
67 }
68
69 pub fn get_policy_name(&self) -> &str {
71 &self.policy_name
72 }
73
74 pub fn get_policy_violation(&self) -> PolicyViolationType {
76 self.violation
77 }
78
79 pub fn get_client_name(&self) -> Option<&str> {
81 self.client_name.as_deref()
82 }
83
84 pub fn get_client_id(&self) -> Option<&str> {
86 self.client_id.as_deref()
87 }
88}
89
/// Handle for reading and writing the policy violation report kept in the
/// `policy_violation` stream property.
pub struct PolicyViolations {
    /// Property store the serialized report is read from and written to.
    stream_properties: Box<dyn PropertyAccessor>,
    /// Policy name stamped on every violation generated through this handle.
    policy_name: String,
}
95
/// Property path under which the serialized policy violation is stored.
const POLICY_VIOLATION_PROPERTY: &[&str] = &["policy_violation"];
97
98impl PolicyViolations {
99 pub fn new<T: PropertyAccessor + 'static>(stream_properties: T, policy_name: String) -> Self {
100 Self {
101 stream_properties: Box::new(stream_properties),
102 policy_name,
103 }
104 }
105
106 pub fn policy_violation(&self) -> Option<PolicyViolation> {
108 self.stream_properties
109 .read_property(POLICY_VIOLATION_PROPERTY)
110 .as_deref()
111 .and_then(deserialize_policy_violation)
112
113 }
115
116 pub fn generate_policy_violation(&self) {
118 let policy_violation = PolicyViolation::new(
119 self.policy_name.clone(),
120 PolicyViolationType::Violation,
121 None,
122 None,
123 );
124 self.report(policy_violation);
125 }
126
127 pub fn generate_policy_violation_for_client_app<T: Into<String>, K: Into<String>>(
129 &self,
130 client_name: T,
131 client_id: K,
132 ) {
133 let policy_violation = PolicyViolation::new(
134 self.policy_name.clone(),
135 PolicyViolationType::Violation,
136 Some(client_name.into()),
137 Some(client_id.into()),
138 );
139 self.report(policy_violation);
140 }
141
142 fn report(&self, policy_violation: PolicyViolation) {
143 let serialized_report = serialize_policy_violation(policy_violation);
144 self.stream_properties
145 .set_property(POLICY_VIOLATION_PROPERTY, Some(&serialized_report))
146 }
147}
148
149impl FromContext<ConfigureContext> for PolicyViolations {
150 type Error = Infallible;
151
152 fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
153 let metadata: Metadata = context.extract()?;
154 let stream_properties: StreamProperties = context.extract()?;
155 Ok(PolicyViolations::new(
156 stream_properties,
157 metadata.policy_metadata.policy_name,
158 ))
159 }
160}
161
162fn serialize_policy_violation(report: PolicyViolation) -> Vec<u8> {
163 let name = report.get_client_name().unwrap_or("");
164 let id = report.get_client_id().unwrap_or("");
165 let mut ser_bytes: Vec<u8> = Vec::with_capacity(
166 report.get_policy_name().len() + name.len() + id.len() + POLICY_REPORT_FIELDS,
167 );
168 let policy_violation_separator_bytes = &[POLICY_VIOLATION_SEPARATOR];
169 ser_bytes.extend_from_slice(report.get_policy_name().as_bytes());
170 ser_bytes.extend_from_slice(policy_violation_separator_bytes);
171 ser_bytes.extend_from_slice(&[report.get_policy_violation() as u8]);
172 ser_bytes.extend_from_slice(policy_violation_separator_bytes);
173 ser_bytes.extend_from_slice(name.as_bytes());
174 ser_bytes.extend_from_slice(policy_violation_separator_bytes);
175 ser_bytes.extend_from_slice(id.as_bytes());
176
177 ser_bytes
178}
179
180fn deserialize_policy_violation(report: &[u8]) -> Option<PolicyViolation> {
181 let mut parts = report.splitn(POLICY_REPORT_FIELDS, |b| *b == POLICY_VIOLATION_SEPARATOR);
182
183 let policy_name = parts.next()?;
185 let policy_violation = parts.next()?;
186 let client_name_bytes = parts.next()?;
187 let client_name = if client_name_bytes.is_empty() {
188 None
189 } else {
190 Some(String::from_utf8_lossy(client_name_bytes).into_owned())
191 };
192 let client_id_bytes = parts.next()?;
193 let client_id = if client_id_bytes.is_empty() {
194 None
195 } else {
196 Some(String::from_utf8_lossy(client_id_bytes).into_owned())
197 };
198 Some(PolicyViolation {
199 policy_name: String::from_utf8_lossy(policy_name).into_owned(),
200 violation: PolicyViolationType::from(policy_violation),
201 client_name,
202 client_id,
203 })
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `report` and immediately parses the bytes back.
    fn round_trip(report: PolicyViolation) -> PolicyViolation {
        let bytes = serialize_policy_violation(report);
        deserialize_policy_violation(&bytes).unwrap()
    }

    // A fully populated report survives a serialize/deserialize round trip.
    #[test]
    fn successful_serialization() {
        let restored = round_trip(PolicyViolation {
            policy_name: String::from("a_policy_name"),
            violation: PolicyViolationType::Error,
            client_name: Some("app name".to_string()),
            client_id: Some("id".to_string()),
        });

        assert_eq!(restored.policy_name, "a_policy_name");
        assert_eq!(restored.violation, PolicyViolationType::Error);
        assert_eq!(restored.client_name, Some("app name".to_string()));
        assert_eq!(restored.client_id, Some("id".to_string()));
    }

    // Absent client data is serialized as empty fields and restored as `None`.
    #[test]
    fn missing_name_can_be_deserialized() {
        let restored = round_trip(PolicyViolation {
            policy_name: String::from("a_policy_name"),
            violation: PolicyViolationType::Error,
            client_name: None,
            client_id: None,
        });

        assert_eq!(restored.policy_name, "a_policy_name");
        assert_eq!(restored.violation, PolicyViolationType::Error);
        assert_eq!(restored.client_name, None);
        assert_eq!(restored.client_id, None);
    }

    // Input holding only a policy name lacks the remaining fields and is rejected.
    #[test]
    fn cant_deserialize_with_missing_violation_type() {
        let wrong_serialization = String::from("the_policy_name");
        let outcome = deserialize_policy_violation(wrong_serialization.as_bytes());
        assert!(outcome.is_none());
    }
}