use std::fmt;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
/// Result type returned by every fallible constructor in this module.
pub type GatekeepResult<T> = Result<T, GatekeepError>;

/// Validation failures produced while building identifiers, locales and policy records.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum GatekeepError {
    /// An identifier was empty or whitespace-only; `field` names which identifier.
    #[error("{field} must not be empty")]
    EmptyIdentifier { field: &'static str },

    /// A locale tag was rejected; `value` holds the rejected input.
    #[error("invalid locale tag: {value}")]
    InvalidLocale { value: String },

    /// A policy record was rejected; `reason` is a static explanation.
    #[error("policy record is invalid: {reason}")]
    InvalidPolicyRecord { reason: &'static str },
}
/// Converts `value` into an owned `String`, requiring at least one non-whitespace character.
///
/// Whitespace is judged with `char::is_whitespace`, the same rule `str::trim` uses.
///
/// # Errors
///
/// Returns [`GatekeepError::EmptyIdentifier`] tagged with `field` when `value` is empty or
/// consists solely of whitespace.
fn validate_identifier(field: &'static str, value: impl Into<String>) -> GatekeepResult<String> {
    let owned: String = value.into();
    if owned.chars().all(char::is_whitespace) {
        return Err(GatekeepError::EmptyIdentifier { field });
    }
    Ok(owned)
}
/// Validates a locale tag made of hyphen-separated ASCII alphanumeric subtags
/// (for example `en`, `en-US`, `zh-Hant-TW`).
///
/// Every subtag must be non-empty. This rejects empty input and tags with leading,
/// trailing or doubled hyphens (`-`, `en-`, `-en`, `en--US`), which a plain
/// per-byte character check would otherwise accept.
///
/// # Errors
///
/// Returns [`GatekeepError::InvalidLocale`] carrying the rejected input.
fn validate_locale(value: impl Into<String>) -> GatekeepResult<String> {
    let value = value.into();
    // `split` yields a single empty piece for "", so empty input fails here too.
    let valid = value.split('-').all(|subtag| {
        !subtag.is_empty() && subtag.bytes().all(|byte| byte.is_ascii_alphanumeric())
    });
    if valid {
        Ok(value)
    } else {
        Err(GatekeepError::InvalidLocale { value })
    }
}
/// Generates an owned, validated string identifier newtype.
///
/// `$name` is the generated type. `$field` is the field name reported in
/// `GatekeepError::EmptyIdentifier` when validation fails.
macro_rules! owned_id {
    ($name:ident, $field:literal) => {
        /// Owned identifier that is never empty and never whitespace-only.
        ///
        /// Every construction path except `from_trusted`, including deserialization,
        /// goes through `validate_identifier`.
        #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and wraps it.
            ///
            /// # Errors
            ///
            /// Returns `GatekeepError::EmptyIdentifier`, tagged with this type's field name,
            /// when `value` is empty or consists only of whitespace.
            pub fn new(value: impl Into<String>) -> GatekeepResult<Self> {
                validate_identifier($field, value).map(Self)
            }

            /// Wraps `value` without validating it.
            ///
            /// The caller is responsible for passing a value that `new` would accept.
            // The macro emits this for every identifier type, and not all of them call it.
            #[allow(dead_code)]
            pub(crate) fn from_trusted(value: impl Into<String>) -> Self {
                Self(value.into())
            }

            /// Borrows the identifier as a string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honours width, fill/alignment and precision flags (e.g. `{:<12}`),
                // so the identifier formats exactly like the underlying `str`;
                // `write_str` would silently ignore those flags.
                f.pad(&self.0)
            }
        }

        // Serialized as a plain string.
        impl Serialize for $name {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                self.0.serialize(serializer)
            }
        }

        // Deserialized from a string, with the same validation as `new`.
        impl<'de> Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                let value = String::deserialize(deserializer)?;
                Self::new(value).map_err(serde::de::Error::custom)
            }
        }
    };
}
/// Generates a `Copy` identifier backed by a `&'static str`, paired with the owned
/// identifier type `$owned`.
macro_rules! static_id {
    ($name:ident, $owned:ident) => {
        /// Identifier wrapping a string literal, checked by `assert_valid_static_id` on creation.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub struct $name(&'static str);

        impl $name {
            /// Borrows the wrapped literal.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                let Self(text) = self;
                text
            }

            /// Wraps `value` after checking it with `assert_valid_static_id`.
            ///
            /// # Panics
            ///
            /// Panics if `value` is empty or whitespace-only. In a `const` context, this
            /// surfaces as a compile-time error.
            #[must_use]
            pub const fn new(value: &'static str) -> Self {
                assert_valid_static_id(value);
                Self(value)
            }

            /// Converts to the paired owned identifier, validating through its `new`.
            pub fn to_owned_id(self) -> GatekeepResult<$owned> {
                let text = self.as_str();
                $owned::new(text)
            }
        }
    };
}
/// Const-evaluable check that a static identifier would also pass runtime validation.
///
/// Runtime validation rejects values whose `str::trim()` is empty, and `trim` strips every
/// character with the Unicode `White_Space` property. This check therefore decodes the UTF-8
/// input and applies that same set, instead of only ASCII space/tab/CR/LF. Without this,
/// values such as `"\u{3000}"` or `"\x0C"` would pass here and then fail in `to_owned_id`.
///
/// # Panics
///
/// Panics if `value` is empty or consists only of whitespace. In a `const` context, this
/// surfaces as a compile-time error.
const fn assert_valid_static_id(value: &str) {
    let bytes = value.as_bytes();
    assert!(!bytes.is_empty(), "static identity must not be empty");
    let mut index = 0;
    let mut has_non_whitespace = false;
    while index < bytes.len() {
        let (code_point, width) = decode_utf8_at(bytes, index);
        if !is_unicode_white_space(code_point) {
            has_non_whitespace = true;
            break;
        }
        index += width;
    }
    assert!(has_non_whitespace, "static identity must not be whitespace");
}

/// Decodes the UTF-8 scalar value starting at `index` and returns it with its byte width.
///
/// `bytes` must come from a `&str`, and `index` must be a char boundary. The sequence is then
/// guaranteed to be well-formed, so no validation is performed here.
const fn decode_utf8_at(bytes: &[u8], index: usize) -> (u32, usize) {
    let lead = bytes[index];
    if lead < 0x80 {
        (lead as u32, 1)
    } else if lead < 0xE0 {
        let code_point = (((lead & 0x1F) as u32) << 6) | ((bytes[index + 1] & 0x3F) as u32);
        (code_point, 2)
    } else if lead < 0xF0 {
        let code_point = (((lead & 0x0F) as u32) << 12)
            | (((bytes[index + 1] & 0x3F) as u32) << 6)
            | ((bytes[index + 2] & 0x3F) as u32);
        (code_point, 3)
    } else {
        let code_point = (((lead & 0x07) as u32) << 18)
            | (((bytes[index + 1] & 0x3F) as u32) << 12)
            | (((bytes[index + 2] & 0x3F) as u32) << 6)
            | ((bytes[index + 3] & 0x3F) as u32);
        (code_point, 4)
    }
}

/// Returns whether `code_point` has the Unicode `White_Space` property. This is the set
/// `char::is_whitespace`, and therefore `str::trim`, treats as whitespace.
const fn is_unicode_white_space(code_point: u32) -> bool {
    matches!(
        code_point,
        0x09..=0x0D
            | 0x20
            | 0x85
            | 0xA0
            | 0x1680
            | 0x2000..=0x200A
            | 0x2028
            | 0x2029
            | 0x202F
            | 0x205F
            | 0x3000
    )
}
// Each owned identifier is listed next to its compile-time counterpart, where one exists.
// The string literal is the field name reported in `GatekeepError::EmptyIdentifier`.
owned_id!(FactId, "fact_id");
static_id!(StaticFactId, FactId);

owned_id!(ClauseLabel, "clause_label");
static_id!(StaticClauseLabel, ClauseLabel);

owned_id!(ObligationId, "obligation_id");
static_id!(StaticObligationId, ObligationId);

owned_id!(ParamKey, "param_key");
static_id!(StaticParamKey, ParamKey);

owned_id!(ReasonCode, "reason_code");
static_id!(StaticReasonCode, ReasonCode);

owned_id!(RequestId, "request_id");
static_id!(StaticRequestId, RequestId);

owned_id!(TenantId, "tenant_id");
static_id!(StaticTenantId, TenantId);

// Owned-only identifiers (no static counterpart is generated).
owned_id!(PolicyHash, "policy_hash");
owned_id!(PolicyId, "policy_id");
/// Validated locale tag; construction always goes through `validate_locale`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Locale(String);

impl Locale {
    /// Validates `value` as a locale tag and wraps it.
    ///
    /// # Errors
    ///
    /// Returns `GatekeepError::InvalidLocale` when `validate_locale` rejects the input.
    pub fn new(value: impl Into<String>) -> GatekeepResult<Self> {
        let tag = validate_locale(value)?;
        Ok(Self(tag))
    }

    /// Borrows the tag as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        let Self(tag) = self;
        tag
    }
}
// A locale serializes as its bare tag string.
impl Serialize for Locale {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.0.as_str())
    }
}
impl<'de> Deserialize<'de> for Locale {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
Self::new(value).map_err(serde::de::Error::custom)
}
}
/// Implemented by fact types to expose a compile-time identifier.
pub trait Fact {
    /// Identifier of this fact type.
    const ID: StaticFactId;
}

/// Implemented by obligation types to expose a compile-time identifier.
pub trait ObligationSpec {
    /// Identifier of this obligation type.
    const ID: StaticObligationId;
}
/// Reference to a subject, made of a `kind` and an `id`; neither may be blank.
///
/// Serializes as an object with `kind` and `id` fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct SubjectRef {
    kind: String,
    id: String,
}

impl SubjectRef {
    /// Builds a reference after validating both parts, checking `kind` first.
    ///
    /// # Errors
    ///
    /// Returns `GatekeepError::EmptyIdentifier` with field `"subject_kind"` or `"subject_id"`
    /// when the corresponding part is empty or whitespace-only.
    pub fn new(kind: impl Into<String>, id: impl Into<String>) -> GatekeepResult<Self> {
        let kind = validate_identifier("subject_kind", kind)?;
        let id = validate_identifier("subject_id", id)?;
        Ok(Self { kind, id })
    }

    /// The subject's kind.
    #[must_use]
    pub fn kind(&self) -> &str {
        self.kind.as_str()
    }

    /// The subject's identifier.
    #[must_use]
    pub fn id(&self) -> &str {
        self.id.as_str()
    }
}
impl<'de> Deserialize<'de> for SubjectRef {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct SubjectRefRecord {
kind: String,
id: String,
}
let record = SubjectRefRecord::deserialize(deserializer)?;
Self::new(record.kind, record.id).map_err(serde::de::Error::custom)
}
}