use classy::{
extract::{context::ConfigureContext, Extract, FromContext},
hl::{PropertyAccessor, StreamProperties},
};
use serde_derive::{Deserialize, Serialize};
use std::{convert::Infallible, fmt};
use crate::policy_context::api::Metadata;
/// Byte written between the fields of a serialized policy-violation report.
///
/// Field values are not escaped, so this byte must not appear inside them for the
/// encoding to be unambiguous.
const POLICY_VIOLATION_SEPARATOR: u8 = b'/';
/// Number of fields in a serialized report:
/// policy name, violation tag, client name, client id.
const POLICY_REPORT_FIELDS: usize = 4;
/// Kind of policy violation recorded on a stream.
///
/// In the byte encoding used for the stream property, the discriminant is written as a
/// single tag byte (`Violation` = 0, `Error` = 1).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyViolationType {
/// Tag byte 0. Also the value produced when decoding an unrecognized or missing tag.
Violation = 0,
/// Tag byte 1.
Error = 1,
}
impl From<&[u8]> for PolicyViolationType {
fn from(value: &[u8]) -> Self {
let tag: u8 = value.first().copied().unwrap_or(0);
match tag {
1 => PolicyViolationType::Error,
_ => PolicyViolationType::Violation,
}
}
}
impl fmt::Display for PolicyViolationType {
    /// Writes the variant name (`Violation` or `Error`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Violation => "Violation",
            Self::Error => "Error",
        };
        f.write_str(label)
    }
}
/// A single policy-violation record as stored on (and read back from) the stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyViolation {
/// Name of the policy that produced the violation.
policy_name: String,
/// Kind of violation.
violation: PolicyViolationType,
/// Name of the client application the violation is attributed to, if any.
client_name: Option<String>,
/// Identifier of the client application the violation is attributed to, if any.
client_id: Option<String>,
}
impl PolicyViolation {
pub fn new(
policy_name: String,
violation: PolicyViolationType,
client_name: Option<String>,
client_id: Option<String>,
) -> Self {
PolicyViolation {
policy_name,
violation,
client_name,
client_id,
}
}
pub fn get_policy_name(&self) -> &str {
&self.policy_name
}
pub fn get_policy_violation(&self) -> PolicyViolationType {
self.violation
}
pub fn get_client_name(&self) -> Option<&str> {
self.client_name.as_deref()
}
pub fn get_client_id(&self) -> Option<&str> {
self.client_id.as_deref()
}
}
/// Reads and records policy violations in the stream properties on behalf of one policy.
pub struct PolicyViolations {
/// Accessor used to read and write the policy-violation property.
stream_properties: Box<dyn PropertyAccessor>,
/// Policy name stamped on every violation this instance generates.
policy_name: String,
}
/// Property path under which the serialized policy violation is read and written.
const POLICY_VIOLATION_PROPERTY: &[&str] = &["policy_violation"];
impl PolicyViolations {
pub fn new<T: PropertyAccessor + 'static>(stream_properties: T, policy_name: String) -> Self {
Self {
stream_properties: Box::new(stream_properties),
policy_name,
}
}
pub fn policy_violation(&self) -> Option<PolicyViolation> {
self.stream_properties
.read_property(POLICY_VIOLATION_PROPERTY)
.as_deref()
.and_then(deserialize_policy_violation)
}
pub fn generate_policy_violation(&self) {
let policy_violation = PolicyViolation::new(
self.policy_name.clone(),
PolicyViolationType::Violation,
None,
None,
);
self.report(policy_violation);
}
pub fn generate_policy_violation_for_client_app<T: Into<String>, K: Into<String>>(
&self,
client_name: T,
client_id: K,
) {
let policy_violation = PolicyViolation::new(
self.policy_name.clone(),
PolicyViolationType::Violation,
Some(client_name.into()),
Some(client_id.into()),
);
self.report(policy_violation);
}
fn report(&self, policy_violation: PolicyViolation) {
let serialized_report = serialize_policy_violation(policy_violation);
self.stream_properties
.set_property(POLICY_VIOLATION_PROPERTY, Some(&serialized_report))
}
}
impl FromContext<ConfigureContext> for PolicyViolations {
    type Error = Infallible;

    /// Builds a reporter from the configure context.
    ///
    /// The policy name comes from the extracted `Metadata`; reads and writes go through the
    /// extracted `StreamProperties`.
    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
        let metadata: Metadata = context.extract()?;
        let policy_name = metadata.policy_metadata.policy_name;
        let stream_properties: StreamProperties = context.extract()?;
        Ok(Self::new(stream_properties, policy_name))
    }
}
/// Encodes `report` into the byte layout stored in the policy-violation stream property.
///
/// Layout: `policy_name / tag / client_name / client_id`, joined with
/// `POLICY_VIOLATION_SEPARATOR`. The tag is the `PolicyViolationType` discriminant as one
/// byte. Absent client fields are written as empty segments. Field values are written
/// verbatim (no escaping), so a separator byte inside a field is not protected.
fn serialize_policy_violation(report: PolicyViolation) -> Vec<u8> {
    let violation_tag = [report.get_policy_violation() as u8];
    let fields: [&[u8]; POLICY_REPORT_FIELDS] = [
        report.get_policy_name().as_bytes(),
        &violation_tag,
        report.get_client_name().unwrap_or("").as_bytes(),
        report.get_client_id().unwrap_or("").as_bytes(),
    ];
    // `join` sizes the output exactly: sum of field lengths plus one byte per separator.
    fields.join(&POLICY_VIOLATION_SEPARATOR)
}
/// Decodes a report in the layout `policy_name / tag / client_name / client_id`.
///
/// Returns `None` when fewer than four separator-delimited segments are present. Empty
/// client segments decode to `None`, and text is decoded as lossy UTF-8. The tag is
/// interpreted by `PolicyViolationType::from`.
///
/// The encoder does not escape field values, so a separator inside a field makes the
/// layout ambiguous. The policy name and tag are taken from the front. The client id is
/// taken after the *last* separator, so a client name that contains the separator
/// (e.g. `"team/app"`) is decoded intact instead of being split into the id.
/// NOTE(review): this assumes client ids never contain the separator — confirm against
/// whatever issues client ids.
fn deserialize_policy_violation(report: &[u8]) -> Option<PolicyViolation> {
    /// Field-separator predicate shared by the forward and reverse splits.
    fn is_separator(byte: &u8) -> bool {
        *byte == POLICY_VIOLATION_SEPARATOR
    }

    /// Decodes an optional text field; an empty segment means "absent".
    fn non_empty_text(bytes: &[u8]) -> Option<String> {
        (!bytes.is_empty()).then(|| String::from_utf8_lossy(bytes).into_owned())
    }

    // Split off the two leading fields. The third item is the whole client part,
    // separators included.
    let mut leading = report.splitn(POLICY_REPORT_FIELDS - 1, is_separator);
    let policy_name = leading.next()?;
    let policy_violation = leading.next()?;
    let client = leading.next()?;

    // Split the client part at its last separator into `client_name / client_id`.
    // No separator means the client id field is missing, so the report is malformed.
    let mut trailing = client.rsplitn(2, is_separator);
    let client_id = trailing.next()?;
    let client_name = trailing.next()?;

    Some(PolicyViolation {
        policy_name: String::from_utf8_lossy(policy_name).into_owned(),
        violation: PolicyViolationType::from(policy_violation),
        client_name: non_empty_text(client_name),
        client_id: non_empty_text(client_id),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `report` and decodes it again, failing the test if decoding fails.
    fn round_trip(report: PolicyViolation) -> PolicyViolation {
        let encoded = serialize_policy_violation(report);
        deserialize_policy_violation(&encoded)
            .expect("a freshly serialized report must deserialize")
    }

    #[test]
    fn successful_serialization() {
        let decoded = round_trip(PolicyViolation::new(
            "a_policy_name".to_string(),
            PolicyViolationType::Error,
            Some("app name".to_string()),
            Some("id".to_string()),
        ));

        assert_eq!(decoded.get_policy_name(), "a_policy_name");
        assert_eq!(decoded.get_policy_violation(), PolicyViolationType::Error);
        assert_eq!(decoded.get_client_name(), Some("app name"));
        assert_eq!(decoded.get_client_id(), Some("id"));
    }

    #[test]
    fn missing_name_can_be_deserialized() {
        let decoded = round_trip(PolicyViolation::new(
            "a_policy_name".to_string(),
            PolicyViolationType::Error,
            None,
            None,
        ));

        assert_eq!(decoded.get_policy_name(), "a_policy_name");
        assert_eq!(decoded.get_policy_violation(), PolicyViolationType::Error);
        assert_eq!(decoded.get_client_name(), None);
        assert_eq!(decoded.get_client_id(), None);
    }

    #[test]
    fn cant_deserialize_with_missing_violation_type() {
        // Only the policy name, no separators: too few fields to form a report.
        let decoded = deserialize_policy_violation(b"the_policy_name");
        assert!(decoded.is_none());
    }
}