use crate::error::{Error, Result};
use ipnetwork::IpNetwork;
use std::net::IpAddr;
/// Maximum nesting depth, as reported by `Constraint::depth`, that a constraint tree may
/// have (`Constraint::validate_depth` rejects anything deeper). The same limit also caps
/// recursion while deserializing (see `DepthGuard::new`).
pub const MAX_CONSTRAINT_DEPTH: u32 = 32;
use glob::Pattern as GlobPattern;
use regex::Regex as RegexPattern;
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::collections::{BTreeMap, HashMap};
// Per-thread count of `Constraint` deserializations currently in progress, i.e. how deeply
// the decoder is nested right now. `DepthGuard::new` increments it and dropping the guard
// decrements it.
thread_local! {
static DESERIALIZATION_DEPTH: Cell<usize> = const { Cell::new(0) };
}
/// Scope guard for one level of nested `Constraint` deserialization.
///
/// Constructing it increments the thread-local `DESERIALIZATION_DEPTH` counter, and its
/// `Drop` impl decrements the counter again. The counter therefore tracks how many decodes
/// are in flight on this thread.
struct DepthGuard;
impl DepthGuard {
    /// Enters one more nesting level.
    ///
    /// If the counter already exceeds `MAX_CONSTRAINT_DEPTH`, it is left untouched and a
    /// serde `custom` error is returned.
    fn new<E: serde::de::Error>() -> std::result::Result<Self, E> {
        let limit = MAX_CONSTRAINT_DEPTH as usize;
        DESERIALIZATION_DEPTH.with(|counter| {
            let current = counter.get();
            if current <= limit {
                counter.set(current + 1);
                Ok(DepthGuard)
            } else {
                Err(E::custom(format!(
                    "constraint recursion depth exceeded maximum of {}",
                    MAX_CONSTRAINT_DEPTH
                )))
            }
        })
    }
}
impl Drop for DepthGuard {
    /// Leaves the nesting level that `DepthGuard::new` entered.
    fn drop(&mut self) {
        DESERIALIZATION_DEPTH.with(|counter| counter.set(counter.get() - 1));
    }
}
/// Wire-format type IDs, written as the first element of every serialized `Constraint`.
///
/// Decoding dispatches on these numbers, so they must stay stable. IDs without a matching
/// variant (including `RESERVED_INT_RANGE`) decode as `Constraint::Unknown`.
pub mod constraint_type_id {
    /// `Constraint::Exact`
    pub const EXACT: u8 = 1;
    /// `Constraint::Pattern`
    pub const PATTERN: u8 = 2;
    /// `Constraint::Range`
    pub const RANGE: u8 = 3;
    /// `Constraint::OneOf`
    pub const ONE_OF: u8 = 4;
    /// `Constraint::Regex`
    pub const REGEX: u8 = 5;
    /// Reserved; no variant is encoded with this ID.
    pub const RESERVED_INT_RANGE: u8 = 6;
    /// `Constraint::NotOneOf`
    pub const NOT_ONE_OF: u8 = 7;
    /// `Constraint::Cidr`
    pub const CIDR: u8 = 8;
    /// `Constraint::UrlPattern`
    pub const URL_PATTERN: u8 = 9;
    /// `Constraint::Contains`
    pub const CONTAINS: u8 = 10;
    /// `Constraint::Subset`
    pub const SUBSET: u8 = 11;
    /// `Constraint::All`
    pub const ALL: u8 = 12;
    /// `Constraint::Any`
    pub const ANY: u8 = 13;
    /// `Constraint::Not`
    pub const NOT: u8 = 14;
    /// `Constraint::Cel`
    pub const CEL: u8 = 15;
    /// `Constraint::Wildcard`
    pub const WILDCARD: u8 = 16;
    /// `Constraint::Subpath`
    pub const SUBPATH: u8 = 17;
    /// `Constraint::UrlSafe`
    pub const URL_SAFE: u8 = 18;
    /// `Constraint::Shlex` (deliberately outside the contiguous 1..=18 block).
    pub const SHLEX: u8 = 128;
}
/// A single constraint on a value, or a combinator over nested constraints.
///
/// It is serialized as a two-element `[type_id, payload]` tuple whose IDs come from
/// `constraint_type_id` (see the hand-written `Serialize`/`Deserialize` impls).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum Constraint {
/// Matches every value.
Wildcard(Wildcard),
/// Glob match on string values.
Pattern(Pattern),
/// Regular-expression match on string values.
Regex(RegexConstraint),
/// Equality with a single value.
Exact(Exact),
/// Membership in an allow-list.
OneOf(OneOf),
/// Non-membership in a deny-list.
NotOneOf(NotOneOf),
/// Numeric interval.
Range(Range),
/// IP address inside a network.
Cidr(Cidr),
/// URL restricted by scheme / host / port / path.
UrlPattern(UrlPattern),
Contains(Contains),
Subset(Subset),
/// Combinator over a list of nested constraints.
All(All),
/// Combinator over a list of nested constraints.
Any(Any),
/// Combinator wrapping exactly one nested constraint.
Not(Not),
Cel(CelConstraint),
Subpath(Subpath),
UrlSafe(UrlSafe),
Shlex(Shlex),
/// A constraint whose type ID this build does not recognize. The raw payload bytes are
/// kept so the value can be serialized again; `matches` on it always returns an error.
Unknown { type_id: u8, payload: Vec<u8> },
}
impl Serialize for Constraint {
    /// Encodes the constraint as a two-element tuple `[type_id, payload]`.
    ///
    /// `type_id` comes from `constraint_type_id`. `Unknown` constraints write back the type
    /// ID and raw payload bytes they carry.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use constraint_type_id::*;
        use serde::ser::SerializeTuple;

        /// Writes one `type_id` followed by its payload into the open tuple.
        fn pair<T, P>(tup: &mut T, type_id: u8, payload: &P) -> std::result::Result<(), T::Error>
        where
            T: SerializeTuple,
            P: Serialize + ?Sized,
        {
            tup.serialize_element(&type_id)?;
            tup.serialize_element(payload)
        }

        let mut tup = serializer.serialize_tuple(2)?;
        match self {
            Constraint::Exact(v) => pair(&mut tup, EXACT, v)?,
            Constraint::Pattern(v) => pair(&mut tup, PATTERN, v)?,
            Constraint::Range(v) => pair(&mut tup, RANGE, v)?,
            Constraint::OneOf(v) => pair(&mut tup, ONE_OF, v)?,
            Constraint::Regex(v) => pair(&mut tup, REGEX, v)?,
            Constraint::NotOneOf(v) => pair(&mut tup, NOT_ONE_OF, v)?,
            Constraint::Cidr(v) => pair(&mut tup, CIDR, v)?,
            Constraint::UrlPattern(v) => pair(&mut tup, URL_PATTERN, v)?,
            Constraint::Contains(v) => pair(&mut tup, CONTAINS, v)?,
            Constraint::Subset(v) => pair(&mut tup, SUBSET, v)?,
            Constraint::All(v) => pair(&mut tup, ALL, v)?,
            Constraint::Any(v) => pair(&mut tup, ANY, v)?,
            Constraint::Not(v) => pair(&mut tup, NOT, v)?,
            Constraint::Cel(v) => pair(&mut tup, CEL, v)?,
            Constraint::Wildcard(v) => pair(&mut tup, WILDCARD, v)?,
            Constraint::Subpath(v) => pair(&mut tup, SUBPATH, v)?,
            Constraint::UrlSafe(v) => pair(&mut tup, URL_SAFE, v)?,
            Constraint::Shlex(v) => pair(&mut tup, SHLEX, v)?,
            Constraint::Unknown { type_id, payload } => {
                // Emit the bytes as a byte string (not a sequence of integers) so they round-trip
                // through the `ByteBuf` read in `Deserialize`.
                pair(&mut tup, *type_id, serde_bytes::Bytes::new(payload))?
            }
        }
        tup.end()
    }
}
impl<'de> serde::Deserialize<'de> for Constraint {
    /// Decodes the `[type_id, payload]` tuple produced by `Serialize`.
    ///
    /// Nesting is bounded twice:
    /// - a thread-local `DepthGuard` aborts runaway recursion while decoding;
    /// - the finished value is re-checked with `validate_depth`.
    ///
    /// Type IDs without a matching variant are kept as `Constraint::Unknown` with their raw
    /// payload bytes (empty if the payload element is absent).
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use constraint_type_id::*;
        use serde::de::{Error as DeError, SeqAccess, Visitor};

        let _guard = DepthGuard::new::<D::Error>()?;

        struct ConstraintVisitor;

        impl ConstraintVisitor {
            /// Reads the mandatory second tuple element and decodes it as the variant payload.
            fn payload<'de, A, T>(&self, seq: &mut A) -> std::result::Result<T, A::Error>
            where
                A: SeqAccess<'de>,
                T: serde::Deserialize<'de>,
            {
                seq.next_element()?
                    .ok_or_else(|| A::Error::invalid_length(1, self))
            }
        }

        impl<'de> Visitor<'de> for ConstraintVisitor {
            type Value = Constraint;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a constraint array [type_id, value]")
            }

            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let type_id: u8 = seq
                    .next_element()?
                    .ok_or_else(|| A::Error::invalid_length(0, &self))?;
                let constraint = match type_id {
                    EXACT => Constraint::Exact(self.payload(&mut seq)?),
                    PATTERN => Constraint::Pattern(self.payload(&mut seq)?),
                    RANGE => Constraint::Range(self.payload(&mut seq)?),
                    ONE_OF => Constraint::OneOf(self.payload(&mut seq)?),
                    REGEX => Constraint::Regex(self.payload(&mut seq)?),
                    NOT_ONE_OF => Constraint::NotOneOf(self.payload(&mut seq)?),
                    CIDR => Constraint::Cidr(self.payload(&mut seq)?),
                    URL_PATTERN => Constraint::UrlPattern(self.payload(&mut seq)?),
                    CONTAINS => Constraint::Contains(self.payload(&mut seq)?),
                    SUBSET => Constraint::Subset(self.payload(&mut seq)?),
                    ALL => Constraint::All(self.payload(&mut seq)?),
                    ANY => Constraint::Any(self.payload(&mut seq)?),
                    NOT => Constraint::Not(self.payload(&mut seq)?),
                    CEL => Constraint::Cel(self.payload(&mut seq)?),
                    WILDCARD => Constraint::Wildcard(self.payload(&mut seq)?),
                    SUBPATH => Constraint::Subpath(self.payload(&mut seq)?),
                    URL_SAFE => Constraint::UrlSafe(self.payload(&mut seq)?),
                    SHLEX => Constraint::Shlex(self.payload(&mut seq)?),
                    // Forward compatibility: keep unrecognized payloads as opaque bytes.
                    _ => Constraint::Unknown {
                        type_id,
                        payload: seq
                            .next_element::<serde_bytes::ByteBuf>()?
                            .map(serde_bytes::ByteBuf::into_vec)
                            .unwrap_or_default(),
                    },
                };
                Ok(constraint)
            }
        }

        // `Serialize` writes a 2-tuple, so read a 2-tuple. For self-describing formats
        // (JSON, CBOR) this is the same as a sequence, but length-implicit formats need the
        // matching call to decode what was encoded.
        let constraint = deserializer.deserialize_tuple(2, ConstraintVisitor)?;
        constraint
            .validate_depth()
            .map_err(serde::de::Error::custom)?;
        Ok(constraint)
    }
}
impl Constraint {
pub fn depth(&self) -> u32 {
match self {
Constraint::Wildcard(_)
| Constraint::Pattern(_)
| Constraint::Regex(_)
| Constraint::Exact(_)
| Constraint::OneOf(_)
| Constraint::NotOneOf(_)
| Constraint::Range(_)
| Constraint::Cidr(_)
| Constraint::UrlPattern(_)
| Constraint::Contains(_)
| Constraint::Subset(_)
| Constraint::Cel(_)
| Constraint::Subpath(_)
| Constraint::UrlSafe(_)
| Constraint::Shlex(_)
| Constraint::Unknown { .. } => 0,
Constraint::All(all) => {
1 + all.constraints.iter().map(|c| c.depth()).max().unwrap_or(0)
}
Constraint::Any(any) => {
1 + any.constraints.iter().map(|c| c.depth()).max().unwrap_or(0)
}
Constraint::Not(not) => 1 + not.constraint.depth(),
}
}
pub fn validate_depth(&self) -> Result<()> {
let depth = self.depth();
if depth > MAX_CONSTRAINT_DEPTH {
Err(Error::ConstraintDepthExceeded {
depth,
max: MAX_CONSTRAINT_DEPTH,
})
} else {
Ok(())
}
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
match self {
Constraint::Wildcard(_) => Ok(true), Constraint::Pattern(p) => p.matches(value),
Constraint::Regex(r) => r.matches(value),
Constraint::Exact(e) => e.matches(value),
Constraint::OneOf(o) => o.matches(value),
Constraint::NotOneOf(n) => n.matches(value),
Constraint::Range(r) => r.matches(value),
Constraint::Cidr(c) => c.matches(value),
Constraint::UrlPattern(u) => u.matches(value),
Constraint::Contains(c) => c.matches(value),
Constraint::Subset(s) => s.matches(value),
Constraint::All(a) => a.matches(value),
Constraint::Any(a) => a.matches(value),
Constraint::Not(n) => n.matches(value),
Constraint::Cel(c) => c.matches(value),
Constraint::Subpath(s) => s.matches(value),
Constraint::UrlSafe(u) => u.matches(value),
Constraint::Shlex(sh) => sh.matches(value),
Constraint::Unknown { type_id, .. } => Err(Error::ConstraintNotSatisfied {
field: "constraint".to_string(),
reason: format!("unknown constraint type ID {}", type_id),
}),
}
}
pub fn validate_attenuation(&self, child: &Constraint) -> Result<()> {
match (self, child) {
(Constraint::Wildcard(_), _) => Ok(()),
(_, Constraint::Wildcard(_)) => Err(Error::WildcardExpansion {
parent_type: self.type_name().to_string(),
}),
(Constraint::Pattern(parent), Constraint::Pattern(child_pat)) => {
parent.validate_attenuation(child_pat)
}
(Constraint::Pattern(parent), Constraint::Exact(child_exact)) => {
if parent.matches(&child_exact.value)? {
Ok(())
} else {
Err(Error::ValueNotInParentSet {
value: format!("{:?}", child_exact.value),
})
}
}
(Constraint::Regex(parent), Constraint::Regex(child_regex)) => {
parent.validate_attenuation(child_regex)
}
(Constraint::Regex(parent), Constraint::Exact(child_exact)) => {
if parent.matches(&child_exact.value)? {
Ok(())
} else {
Err(Error::ValueNotInParentSet {
value: format!("{:?}", child_exact.value),
})
}
}
(Constraint::Exact(parent), Constraint::Exact(child)) => {
if parent.value == child.value {
Ok(())
} else {
Err(Error::ExactValueMismatch {
parent: format!("{:?}", parent.value),
child: format!("{:?}", child.value),
})
}
}
(Constraint::OneOf(parent), Constraint::OneOf(child)) => {
parent.validate_attenuation(child)
}
(Constraint::OneOf(parent), Constraint::Exact(child)) => {
if parent.contains(&child.value) {
Ok(())
} else {
Err(Error::ValueNotInParentSet {
value: format!("{:?}", child.value),
})
}
}
(Constraint::OneOf(_), Constraint::NotOneOf(_)) => {
Err(Error::IncompatibleConstraintTypes {
parent_type: "OneOf".to_string(),
child_type: "NotOneOf".to_string(),
})
}
(Constraint::NotOneOf(parent), Constraint::NotOneOf(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Range(parent), Constraint::Range(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Range(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_number() {
Some(n) if parent.contains_value(n) => Ok(()),
Some(n) => Err(Error::ValueNotInRange {
value: n,
min: parent.min,
max: parent.max,
}),
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "Range".to_string(),
child_type: "Exact (non-numeric)".to_string(),
}),
}
}
(Constraint::Cidr(parent), Constraint::Cidr(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Cidr(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_str() {
Some(ip_str) => {
if parent.contains_ip(ip_str)? {
Ok(())
} else {
Err(Error::IpNotInCidr {
ip: ip_str.to_string(),
cidr: parent.network.to_string(),
})
}
}
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "Cidr".to_string(),
child_type: "Exact (non-string)".to_string(),
}),
}
}
(Constraint::UrlPattern(parent), Constraint::UrlPattern(child)) => {
parent.validate_attenuation(child)
}
(Constraint::UrlPattern(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_str() {
Some(url_str) => {
if parent.matches_url(url_str)? {
Ok(())
} else {
Err(Error::UrlMismatch {
reason: format!(
"URL '{}' does not match pattern '{}'",
url_str, parent
),
})
}
}
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "UrlPattern".to_string(),
child_type: "Exact (non-string)".to_string(),
}),
}
}
(Constraint::Contains(parent), Constraint::Contains(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Subset(parent), Constraint::Subset(child)) => {
parent.validate_attenuation(child)
}
(Constraint::All(parent), Constraint::All(child)) => parent.validate_attenuation(child),
(Constraint::Cel(parent), Constraint::Cel(child)) => parent.validate_attenuation(child),
(Constraint::Subpath(parent), Constraint::Subpath(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Subpath(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_str() {
Some(path_str) => {
if parent.contains_path(path_str)? {
Ok(())
} else {
Err(Error::PathNotContained {
path: path_str.to_string(),
root: parent.root.clone(),
})
}
}
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "Subpath".to_string(),
child_type: "Exact (non-string)".to_string(),
}),
}
}
(Constraint::UrlSafe(parent), Constraint::UrlSafe(child)) => {
parent.validate_attenuation(child)
}
(Constraint::UrlSafe(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_str() {
Some(url_str) => {
if parent.is_safe(url_str)? {
Ok(())
} else {
Err(Error::UrlNotSafe {
url: url_str.to_string(),
reason: "URL blocked by UrlSafe constraint".to_string(),
})
}
}
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "UrlSafe".to_string(),
child_type: "Exact (non-string)".to_string(),
}),
}
}
(Constraint::Shlex(parent), Constraint::Shlex(child)) => {
parent.validate_attenuation(child)
}
(Constraint::Shlex(parent), Constraint::Exact(child_exact)) => {
match child_exact.value.as_str() {
Some(cmd_str) => {
if parent.matches(&ConstraintValue::String(cmd_str.to_string()))? {
Ok(())
} else {
Err(Error::ConstraintNotSatisfied {
field: "command".to_string(),
reason: format!(
"command '{}' rejected by Shlex constraint",
cmd_str
),
})
}
}
None => Err(Error::IncompatibleConstraintTypes {
parent_type: "Shlex".to_string(),
child_type: "Exact (non-string)".to_string(),
}),
}
}
_ => Err(Error::IncompatibleConstraintTypes {
parent_type: self.type_name().to_string(),
child_type: child.type_name().to_string(),
}),
}
}
pub fn type_name(&self) -> &'static str {
match self {
Constraint::Wildcard(_) => "Wildcard",
Constraint::Pattern(_) => "Pattern",
Constraint::Regex(_) => "Regex",
Constraint::Exact(_) => "Exact",
Constraint::OneOf(_) => "OneOf",
Constraint::Cidr(_) => "Cidr",
Constraint::UrlPattern(_) => "UrlPattern",
Constraint::NotOneOf(_) => "NotOneOf",
Constraint::Range(_) => "Range",
Constraint::Contains(_) => "Contains",
Constraint::Subset(_) => "Subset",
Constraint::All(_) => "All",
Constraint::Any(_) => "Any",
Constraint::Not(_) => "Not",
Constraint::Cel(_) => "Cel",
Constraint::Subpath(_) => "Subpath",
Constraint::UrlSafe(_) => "UrlSafe",
Constraint::Shlex(_) => "Shlex",
Constraint::Unknown { .. } => "Unknown",
}
}
}
/// Dynamically typed value that constraints are evaluated against.
///
/// The enum is serialized `untagged`, so serde tries the variants top to bottom when
/// deserializing and their declaration order is significant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ConstraintValue {
String(String),
Integer(i64),
Float(f64),
Boolean(bool),
List(Vec<ConstraintValue>),
Object(BTreeMap<String, ConstraintValue>),
Null,
}
impl ConstraintValue {
    /// Borrows the text of a `String` value; `None` for every other variant.
    pub fn as_str(&self) -> Option<&str> {
        if let ConstraintValue::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Numeric view of the value: integers are widened to `f64` and floats are returned
    /// unchanged. Every other variant yields `None`.
    pub fn as_number(&self) -> Option<f64> {
        match *self {
            ConstraintValue::Integer(whole) => Some(whole as f64),
            ConstraintValue::Float(real) => Some(real),
            _ => None,
        }
    }

    /// Borrows the items of a `List` value; `None` for every other variant.
    pub fn as_list(&self) -> Option<&Vec<ConstraintValue>> {
        if let ConstraintValue::List(items) = self {
            Some(items)
        } else {
            None
        }
    }
}
impl From<&str> for ConstraintValue {
fn from(s: &str) -> Self {
ConstraintValue::String(s.to_string())
}
}
impl From<String> for ConstraintValue {
fn from(s: String) -> Self {
ConstraintValue::String(s)
}
}
impl From<i64> for ConstraintValue {
fn from(n: i64) -> Self {
ConstraintValue::Integer(n)
}
}
impl From<i32> for ConstraintValue {
fn from(n: i32) -> Self {
ConstraintValue::Integer(n as i64)
}
}
impl From<f64> for ConstraintValue {
fn from(n: f64) -> Self {
ConstraintValue::Float(n)
}
}
impl From<bool> for ConstraintValue {
fn from(b: bool) -> Self {
ConstraintValue::Boolean(b)
}
}
impl<T: Into<ConstraintValue>> From<Vec<T>> for ConstraintValue {
fn from(v: Vec<T>) -> Self {
ConstraintValue::List(v.into_iter().map(Into::into).collect())
}
}
impl std::fmt::Display for ConstraintValue {
    /// Human-readable rendering. Strings are written unquoted; lists are written as
    /// `[a, b]` and objects as `{k: v}` (keys in `BTreeMap` order).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConstraintValue::String(text) => write!(f, "{}", text),
            ConstraintValue::Integer(whole) => write!(f, "{}", whole),
            ConstraintValue::Float(real) => write!(f, "{}", real),
            ConstraintValue::Boolean(flag) => write!(f, "{}", flag),
            ConstraintValue::Null => f.write_str("null"),
            ConstraintValue::List(items) => {
                f.write_str("[")?;
                let mut separator = "";
                for item in items {
                    write!(f, "{}{}", separator, item)?;
                    separator = ", ";
                }
                f.write_str("]")
            }
            ConstraintValue::Object(entries) => {
                f.write_str("{")?;
                let mut separator = "";
                for (key, item) in entries {
                    write!(f, "{}{}: {}", separator, key, item)?;
                    separator = ", ";
                }
                f.write_str("}")
            }
        }
    }
}
/// Constraint that accepts every value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct Wildcard;
impl Wildcard {
pub fn new() -> Self {
Self
}
pub fn matches(&self, _value: &ConstraintValue) -> Result<bool> {
Ok(true)
}
}
impl From<Wildcard> for Constraint {
fn from(w: Wildcard) -> Self {
Constraint::Wildcard(w)
}
}
/// Glob-style constraint on string values, using the `glob` crate's syntax.
///
/// `compiled` is a cache that serde skips, so a deserialized `Pattern` starts with `None`
/// and compiles on demand. Equality therefore compares only the source text. Deriving
/// `PartialEq` would also compare the cache, and a `Pattern::new(..)` value would never
/// equal its own serialize/deserialize round trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pattern {
    /// Glob source text; the only serialized field.
    pub pattern: String,
    #[serde(skip)]
    compiled: Option<GlobPattern>,
}

impl PartialEq for Pattern {
    fn eq(&self, other: &Self) -> bool {
        self.pattern == other.pattern
    }
}
impl Pattern {
    /// Builds a glob constraint, compiling `pattern` eagerly so syntax errors surface here.
    ///
    /// # Errors
    /// `Error::InvalidPattern` when `pattern` is not a valid glob.
    pub fn new(pattern: &str) -> Result<Self> {
        let compiled =
            GlobPattern::new(pattern).map_err(|e| Error::InvalidPattern(e.to_string()))?;
        Ok(Self {
            pattern: pattern.to_string(),
            compiled: Some(compiled),
        })
    }

    /// Returns the raw glob source.
    pub fn as_str(&self) -> &str {
        &self.pattern
    }

    /// Returns an owned compiled glob: a clone of the cache, or a fresh compile when the
    /// cache is empty (e.g. after deserialization).
    fn get_compiled(&self) -> Result<GlobPattern> {
        if let Some(ref compiled) = self.compiled {
            Ok(compiled.clone())
        } else {
            GlobPattern::new(&self.pattern).map_err(|e| Error::InvalidPattern(e.to_string()))
        }
    }

    /// Glob-matches string values; every other variant is a non-match.
    ///
    /// # Errors
    /// `Error::InvalidPattern` if the pattern is uncached and does not compile. This is
    /// checked before the value's type is inspected.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        // Borrow the cached glob instead of cloning it on every call; only uncached values
        // pay for a compile.
        let compiled = match &self.compiled {
            Some(glob) => std::borrow::Cow::Borrowed(glob),
            None => std::borrow::Cow::Owned(self.get_compiled()?),
        };
        match value {
            ConstraintValue::String(s) => Ok(compiled.matches(s)),
            _ => Ok(false),
        }
    }

    /// Accepts `child` only when it can be shown to match a subset of what `self` matches.
    ///
    /// The analysis is deliberately narrow:
    /// - identical patterns are accepted;
    /// - a single trailing `*` (prefix) or leading `*` (suffix) parent is compared against
    ///   a child of the same shape, or against a child without `*`;
    /// - everything else is rejected.
    ///
    /// # Errors
    /// `Error::PatternExpanded` whenever narrowing cannot be proven.
    pub fn validate_attenuation(&self, child: &Pattern) -> Result<()> {
        if self.pattern == child.pattern {
            return Ok(());
        }
        let narrows = match (self.pattern_type(), child.pattern_type()) {
            // The parent prefix is a complete run of glob tokens (the trailing `*` cannot sit
            // inside a `[...]`), so a child that repeats it verbatim tokenizes it identically.
            (PatternType::Prefix(parent_prefix), PatternType::Prefix(child_prefix)) => {
                child_prefix.starts_with(parent_prefix)
            }
            (PatternType::Prefix(parent_prefix), PatternType::Exact) => {
                child.pattern.starts_with(parent_prefix)
            }
            (PatternType::Suffix(parent_suffix), PatternType::Suffix(child_suffix)) => {
                Self::ends_with_whole_tokens(child_suffix, parent_suffix)
            }
            (PatternType::Suffix(parent_suffix), PatternType::Exact) => {
                Self::ends_with_whole_tokens(&child.pattern, parent_suffix)
            }
            // Exact parents only admit the identical pattern (handled above). Complex globs
            // and prefix/suffix mixes are not analysed, so they are rejected.
            _ => false,
        };
        if narrows {
            Ok(())
        } else {
            Err(Error::PatternExpanded {
                parent: self.pattern.clone(),
                child: child.pattern.clone(),
            })
        }
    }

    /// Reports whether glob text `text` ends with `suffix` at a glob token boundary.
    ///
    /// A plain `ends_with` is unsound for globs. Under parent `*b]`, the child `*[ab]`
    /// textually ends with `b]`, yet `[ab]` is a single character class that also matches
    /// `a`. Bracket expressions are the only multi-character tokens in single-`*` globs, so
    /// the split point is safe whenever no `[` precedes it; otherwise refuse conservatively.
    fn ends_with_whole_tokens(text: &str, suffix: &str) -> bool {
        match text.strip_suffix(suffix) {
            Some(head) => !head.contains('['),
            None => false,
        }
    }

    /// Classifies the pattern by the number and position of its `*` characters.
    fn pattern_type(&self) -> PatternType<'_> {
        let star_count = self.pattern.matches('*').count();
        match star_count {
            0 => PatternType::Exact,
            1 => {
                if self.pattern.ends_with('*') {
                    PatternType::Prefix(&self.pattern[..self.pattern.len() - 1])
                } else if self.pattern.starts_with('*') {
                    PatternType::Suffix(&self.pattern[1..])
                } else {
                    PatternType::Complex
                }
            }
            _ => PatternType::Complex,
        }
    }
}
/// Coarse shape of a glob, as used by `Pattern::validate_attenuation`. The classification
/// looks only at how many `*` characters the pattern contains and where they are.
#[derive(Debug)]
enum PatternType<'a> {
/// No `*` at all (other glob syntax such as `?` or `[...]` may still be present).
Exact,
/// Exactly one `*`, at the end; holds the text before it.
Prefix(&'a str),
/// Exactly one `*`, at the start; holds the text after it.
Suffix(&'a str),
/// Anything else: several `*`, or a single `*` in the middle.
Complex,
}
impl From<Pattern> for Constraint {
fn from(p: Pattern) -> Self {
Constraint::Pattern(p)
}
}
/// Regular-expression constraint on string values.
///
/// `compiled` is a cache that serde skips, so deserialized values start without it and
/// compile on demand. Equality compares only `pattern` (see the manual `PartialEq`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegexConstraint {
/// Regex source text; the only serialized field.
pub pattern: String,
#[serde(skip)]
compiled: Option<RegexPattern>,
}
impl PartialEq for RegexConstraint {
    /// Two regex constraints are equal when their source text is equal; the compiled
    /// cache is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.pattern.as_str() == other.pattern.as_str()
    }
}
impl RegexConstraint {
    /// Compiles `pattern` eagerly so syntax errors are reported at construction.
    ///
    /// # Errors
    /// `Error::InvalidPattern` ("invalid regex: ...") when the pattern does not compile.
    pub fn new(pattern: &str) -> Result<Self> {
        let compiled = RegexPattern::new(pattern)
            .map_err(|e| Error::InvalidPattern(format!("invalid regex: {}", e)))?;
        Ok(Self {
            pattern: pattern.to_string(),
            compiled: Some(compiled),
        })
    }

    /// Returns an owned compiled regex: a clone of the cache, or a fresh compile when the
    /// cache is empty (e.g. after deserialization).
    fn get_compiled(&self) -> Result<RegexPattern> {
        if let Some(ref compiled) = self.compiled {
            Ok(compiled.clone())
        } else {
            RegexPattern::new(&self.pattern)
                .map_err(|e| Error::InvalidPattern(format!("invalid regex: {}", e)))
        }
    }

    /// Tests string values against the regex (unanchored `is_match`). Every other variant
    /// is a non-match.
    ///
    /// # Errors
    /// `Error::InvalidPattern` if the pattern is uncached and does not compile.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        // Borrow the cached regex rather than cloning it per call; a clone would not share
        // the original's match caches.
        let compiled = match &self.compiled {
            Some(re) => std::borrow::Cow::Borrowed(re),
            None => std::borrow::Cow::Owned(self.get_compiled()?),
        };
        match value {
            ConstraintValue::String(s) => Ok(compiled.is_match(s)),
            _ => Ok(false),
        }
    }

    /// Only an identical pattern counts as a valid narrowing; regex containment is not
    /// analysed.
    ///
    /// # Errors
    /// `Error::MonotonicityViolation` for any differing child pattern.
    pub fn validate_attenuation(&self, child: &RegexConstraint) -> Result<()> {
        if self.pattern == child.pattern {
            return Ok(());
        }
        Err(Error::MonotonicityViolation(
            "regex attenuation requires pattern match; use Exact for specific values".to_string(),
        ))
    }
}
impl From<RegexConstraint> for Constraint {
fn from(r: RegexConstraint) -> Self {
Constraint::Regex(r)
}
}
/// Constraint satisfied only by a value equal to `value`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Exact {
pub value: ConstraintValue,
}
impl Exact {
    /// Builds an equality constraint from anything convertible into a `ConstraintValue`.
    pub fn new(value: impl Into<ConstraintValue>) -> Self {
        let value = value.into();
        Self { value }
    }

    /// True exactly when `value` equals the stored value.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        let equal = value == &self.value;
        Ok(equal)
    }
}
impl From<Exact> for Constraint {
fn from(e: Exact) -> Self {
Constraint::Exact(e)
}
}
/// Allow-list constraint: satisfied by any value contained in `values`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OneOf {
pub values: Vec<ConstraintValue>,
}
impl OneOf {
pub fn new<S: Into<String>>(values: impl IntoIterator<Item = S>) -> Self {
Self {
values: values
.into_iter()
.map(|s| ConstraintValue::String(s.into()))
.collect(),
}
}
pub fn from_values(values: impl IntoIterator<Item = impl Into<ConstraintValue>>) -> Self {
Self {
values: values.into_iter().map(Into::into).collect(),
}
}
pub fn contains(&self, value: &ConstraintValue) -> bool {
self.values.contains(value)
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
Ok(self.contains(value))
}
pub fn validate_attenuation(&self, child: &OneOf) -> Result<()> {
for v in &child.values {
if !self.contains(v) {
return Err(Error::ValueNotInParentSet {
value: format!("{:?}", v),
});
}
}
Ok(())
}
}
impl From<OneOf> for Constraint {
fn from(o: OneOf) -> Self {
Constraint::OneOf(o)
}
}
/// Deny-list constraint: satisfied by any value not contained in `excluded`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct NotOneOf {
pub excluded: Vec<ConstraintValue>,
}
impl NotOneOf {
pub fn new(excluded: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self {
excluded: excluded
.into_iter()
.map(|s| ConstraintValue::String(s.into()))
.collect(),
}
}
pub fn from_values(excluded: impl IntoIterator<Item = impl Into<ConstraintValue>>) -> Self {
Self {
excluded: excluded.into_iter().map(Into::into).collect(),
}
}
pub fn is_excluded(&self, value: &ConstraintValue) -> bool {
self.excluded.contains(value)
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
Ok(!self.is_excluded(value))
}
pub fn validate_attenuation(&self, child: &NotOneOf) -> Result<()> {
for v in &self.excluded {
if !child.excluded.contains(v) {
return Err(Error::ExclusionRemoved {
value: format!("{:?}", v),
});
}
}
Ok(())
}
}
impl From<NotOneOf> for Constraint {
fn from(n: NotOneOf) -> Self {
Constraint::NotOneOf(n)
}
}
/// Numeric interval constraint.
///
/// A `None` bound is unbounded on that side. The `*_inclusive` flags choose `>=`/`<=`
/// (inclusive) or `>`/`<` (exclusive) at each end. `Range::new` rejects NaN bounds, but
/// these fields are public and the derived `Deserialize` does not re-check them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Range {
pub min: Option<f64>,
pub max: Option<f64>,
pub min_inclusive: bool,
pub max_inclusive: bool,
}
impl Range {
pub fn new(min: Option<f64>, max: Option<f64>) -> Result<Self> {
if let Some(m) = min {
if m.is_nan() {
return Err(Error::InvalidRange("min cannot be NaN".to_string()));
}
}
if let Some(m) = max {
if m.is_nan() {
return Err(Error::InvalidRange("max cannot be NaN".to_string()));
}
}
Ok(Self {
min,
max,
min_inclusive: true,
max_inclusive: true,
})
}
pub fn max(max: f64) -> Result<Self> {
Self::new(None, Some(max))
}
pub fn min(min: f64) -> Result<Self> {
Self::new(Some(min), None)
}
pub fn between(min: f64, max: f64) -> Result<Self> {
Self::new(Some(min), Some(max))
}
pub fn min_exclusive(mut self) -> Self {
self.min_inclusive = false;
self
}
pub fn max_exclusive(mut self) -> Self {
self.max_inclusive = false;
self
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
let n = match value.as_number() {
Some(n) => n,
None => return Ok(false),
};
if n.is_nan() {
return Ok(false);
}
let min_ok = match self.min {
None => true,
Some(min) if self.min_inclusive => n >= min,
Some(min) => n > min,
};
let max_ok = match self.max {
None => true,
Some(max) if self.max_inclusive => n <= max,
Some(max) => n < max,
};
Ok(min_ok && max_ok)
}
pub fn validate_attenuation(&self, child: &Range) -> Result<()> {
match (self.min, child.min) {
(Some(parent_min), Some(child_min)) => {
if child_min < parent_min {
return Err(Error::RangeExpanded {
bound: "min".to_string(),
parent_value: parent_min,
child_value: child_min,
});
}
if child_min == parent_min && !self.min_inclusive && child.min_inclusive {
return Err(Error::RangeInclusivityExpanded {
bound: "min".to_string(),
value: parent_min,
parent_inclusive: false,
child_inclusive: true,
});
}
}
(Some(parent_min), None) => {
return Err(Error::RangeExpanded {
bound: "min".to_string(),
parent_value: parent_min,
child_value: f64::NEG_INFINITY,
});
}
_ => {}
}
match (self.max, child.max) {
(Some(parent_max), Some(child_max)) => {
if child_max > parent_max {
return Err(Error::RangeExpanded {
bound: "max".to_string(),
parent_value: parent_max,
child_value: child_max,
});
}
if child_max == parent_max && !self.max_inclusive && child.max_inclusive {
return Err(Error::RangeInclusivityExpanded {
bound: "max".to_string(),
value: parent_max,
parent_inclusive: false,
child_inclusive: true,
});
}
}
(Some(parent_max), None) => {
return Err(Error::RangeExpanded {
bound: "max".to_string(),
parent_value: parent_max,
child_value: f64::INFINITY,
});
}
_ => {}
}
Ok(())
}
pub fn contains_value(&self, value: f64) -> bool {
if value.is_nan() {
return false;
}
let min_ok = match self.min {
None => true,
Some(min) if self.min_inclusive => value >= min,
Some(min) => value > min,
};
let max_ok = match self.max {
None => true,
Some(max) if self.max_inclusive => value <= max,
Some(max) => value < max,
};
min_ok && max_ok
}
}
impl From<Range> for Constraint {
fn from(r: Range) -> Self {
Constraint::Range(r)
}
}
/// IP-network constraint.
///
/// `network` is the parsed form used for matching. `cidr_string` keeps the exact text the
/// constraint was built from, and that text is what gets serialized and displayed.
#[derive(Debug, Clone, PartialEq)]
pub struct Cidr {
pub network: IpNetwork,
pub cidr_string: String,
}
impl Serialize for Cidr {
    /// Serializes as the original CIDR text.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&self.cidr_string)
    }
}
impl<'de> Deserialize<'de> for Cidr {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Cidr::new(&s).map_err(serde::de::Error::custom)
}
}
impl Cidr {
pub fn new(cidr: &str) -> Result<Self> {
let network = cidr.parse::<IpNetwork>().map_err(|e| Error::InvalidCidr {
cidr: cidr.to_string(),
reason: e.to_string(),
})?;
Ok(Self {
network,
cidr_string: cidr.to_string(),
})
}
pub fn contains_ip(&self, ip_str: &str) -> Result<bool> {
let ip = ip_str
.parse::<IpAddr>()
.map_err(|e| Error::InvalidIpAddress {
ip: ip_str.to_string(),
reason: e.to_string(),
})?;
Ok(self.network.contains(ip))
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
match value.as_str() {
Some(ip_str) => self.contains_ip(ip_str),
None => Ok(false),
}
}
pub fn validate_attenuation(&self, child: &Cidr) -> Result<()> {
let parent_net = self.network;
let child_net = child.network;
if !parent_net.contains(child_net.network()) {
return Err(Error::CidrNotSubnet {
parent: self.cidr_string.clone(),
child: child.cidr_string.clone(),
});
}
if !parent_net.contains(child_net.broadcast()) {
return Err(Error::CidrNotSubnet {
parent: self.cidr_string.clone(),
child: child.cidr_string.clone(),
});
}
Ok(())
}
}
impl std::fmt::Display for Cidr {
    /// Renders as `Cidr(<original text>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Cidr(")?;
        f.write_str(&self.cidr_string)?;
        f.write_str(")")
    }
}
impl From<Cidr> for Constraint {
fn from(c: Cidr) -> Self {
Constraint::Cidr(c)
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct UrlPattern {
pub pattern: String,
pub schemes: Vec<String>,
pub host_pattern: Option<String>,
pub port: Option<u16>,
pub path_pattern: Option<String>,
}
impl Serialize for UrlPattern {
    /// Serializes as the original pattern text only.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&self.pattern)
    }
}
impl<'de> Deserialize<'de> for UrlPattern {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
UrlPattern::new(&s).map_err(serde::de::Error::custom)
}
}
impl UrlPattern {
/// Parses a URL pattern such as `https://*.example.com/api/*` or `*://host:8080`.
///
/// The pattern text is split into its components:
/// - `*://` yields an empty (any) scheme list; otherwise the scheme is lower-cased;
/// - `*.` in the host is a subdomain wildcard, and a bare `*` host matches any host;
/// - `*` in the path is a glob, and a path of `/` (or none) leaves the path unrestricted.
///
/// # Errors
/// `Error::InvalidUrl` when the pattern has no scheme, uses the reserved placeholder text,
/// or does not parse as a URL.
pub fn new(pattern: &str) -> Result<Self> {
    // Wildcards are swapped for placeholder text before `url::Url::parse` sees the
    // pattern, then swapped back in the parsed host and path.
    const HOST_PLACEHOLDER: &str = "__tenuo_host_wildcard__";
    const PATH_PLACEHOLDER: &str = "__tenuo_path_wildcard__";
    if pattern.contains(HOST_PLACEHOLDER) || pattern.contains(PATH_PLACEHOLDER) {
        return Err(Error::InvalidUrl {
            url: pattern.to_string(),
            reason: "pattern contains reserved internal sequence".to_string(),
        });
    }
    let (schemes, url_str) = if pattern.starts_with("*://") {
        (vec![], pattern.replacen("*://", "https://", 1))
    } else {
        let scheme_end = pattern.find("://").ok_or_else(|| Error::InvalidUrl {
            url: pattern.to_string(),
            reason: "missing scheme (expected 'scheme://')".to_string(),
        })?;
        let scheme = &pattern[..scheme_end];
        (vec![scheme.to_lowercase()], pattern.to_string())
    };
    let host_token = format!("{}.", HOST_PLACEHOLDER);
    let path_token = format!("/{}", PATH_PLACEHOLDER);
    let parse_str = url_str.replace("*.", &host_token).replace("/*", &path_token);
    let parsed = url::Url::parse(&parse_str).map_err(|e| Error::InvalidUrl {
        url: pattern.to_string(),
        reason: e.to_string(),
    })?;
    // The substitutions are purely textual, so each placeholder can end up in either
    // component:
    // - `*.` may sit in the path (`/files/*.pdf`);
    // - the `/*` of `https://*` swallows the authority's second slash and lands in the host.
    // Restore both placeholders in both components; otherwise those wildcards survive as
    // literal placeholder text that nothing matches.
    let restore = |component: &str| -> String {
        component
            .replace(&host_token, "*.")
            .replace(PATH_PLACEHOLDER, "*")
    };
    let host_pattern = parsed.host_str().map(|h| restore(h));
    let port = parsed.port();
    let path = parsed.path();
    let path_pattern = if path.is_empty() || path == "/" {
        None
    } else {
        Some(restore(path))
    };
    Ok(Self {
        pattern: pattern.to_string(),
        schemes,
        host_pattern,
        port,
        path_pattern,
    })
}
pub fn matches_url(&self, url_str: &str) -> Result<bool> {
let parsed = url::Url::parse(url_str).map_err(|e| Error::InvalidUrl {
url: url_str.to_string(),
reason: e.to_string(),
})?;
if !self.schemes.is_empty() && !self.schemes.contains(&parsed.scheme().to_lowercase()) {
return Ok(false);
}
if let Some(host_pattern) = &self.host_pattern {
let host = parsed.host_str().unwrap_or("");
if !Self::matches_host_pattern(host_pattern, host) {
return Ok(false);
}
}
if let Some(required_port) = self.port {
let actual_port = parsed.port().unwrap_or_else(|| match parsed.scheme() {
"https" => 443,
"http" => 80,
_ => 0,
});
if actual_port != required_port {
return Ok(false);
}
}
if let Some(path_pattern) = &self.path_pattern {
let path = parsed.path();
if !Self::matches_path_pattern(path_pattern, path) {
return Ok(false);
}
}
Ok(true)
}
/// Matches `host` against a host pattern, ASCII case-insensitively throughout.
///
/// - `*` matches any host.
/// - `*.example.com` matches `example.com` itself and any host ending in `.example.com`.
/// - Anything else must equal the host.
///
/// The wildcard branch used to compare case-sensitively while the exact branch ignored
/// case. Host names are case-insensitive, and hosts of non-special URL schemes are not
/// lower-cased by the parser, so both branches now ignore ASCII case.
fn matches_host_pattern(pattern: &str, host: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    match pattern.strip_prefix("*.") {
        Some(suffix) => {
            if host.eq_ignore_ascii_case(suffix) {
                return true;
            }
            // Byte-wise comparison: the suffix is compared as raw bytes, so there is no
            // risk of slicing `host` inside a multi-byte character.
            let (host_bytes, suffix_bytes) = (host.as_bytes(), suffix.as_bytes());
            host_bytes.len() > suffix_bytes.len()
                && host_bytes[host_bytes.len() - suffix_bytes.len() - 1] == b'.'
                && host_bytes[host_bytes.len() - suffix_bytes.len()..]
                    .eq_ignore_ascii_case(suffix_bytes)
        }
        None => pattern.eq_ignore_ascii_case(host),
    }
}
/// Matches a URL path against a path pattern.
///
/// `*` and `/*` accept every path. Otherwise the pattern is compiled as a
/// glob; if it is not a valid glob it is compared literally.
fn matches_path_pattern(pattern: &str, path: &str) -> bool {
    match pattern {
        "*" | "/*" => true,
        _ => match GlobPattern::new(pattern) {
            Ok(compiled) => compiled.matches(path),
            Err(_) => pattern == path,
        },
    }
}
/// Constraint entry point: string values are checked with
/// [`UrlPattern::matches_url`]; every other value kind is rejected.
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
    value
        .as_str()
        .map_or(Ok(false), |url_str| self.matches_url(url_str))
}
/// Verifies that `child` accepts no URL this pattern rejects.
///
/// Each component the parent restricts (schemes, host, port, path) must be
/// restricted at least as tightly by the child; a child leaving such a
/// component open is reported with `"*"` (or `None` for the port).
///
/// # Errors
///
/// `UrlSchemeExpanded`, `UrlHostExpanded`, `UrlPortExpanded` or
/// `UrlPathExpanded`, for the first component found to be widened.
pub fn validate_attenuation(&self, child: &UrlPattern) -> Result<()> {
    if !self.schemes.is_empty() {
        let widened_scheme = if child.schemes.is_empty() {
            Some("*".to_string())
        } else {
            child
                .schemes
                .iter()
                .find(|scheme| !self.schemes.contains(*scheme))
                .cloned()
        };
        if let Some(child_scheme) = widened_scheme {
            return Err(Error::UrlSchemeExpanded {
                parent: self.schemes.join(","),
                child: child_scheme,
            });
        }
    }

    if let Some(parent_host) = &self.host_pattern {
        let widened_host = match &child.host_pattern {
            Some(child_host) if Self::is_host_subset(parent_host, child_host) => None,
            Some(child_host) => Some(child_host.clone()),
            None => Some("*".to_string()),
        };
        if let Some(child_host) = widened_host {
            return Err(Error::UrlHostExpanded {
                parent: parent_host.clone(),
                child: child_host,
            });
        }
    }

    if let Some(parent_port) = self.port {
        if child.port != Some(parent_port) {
            return Err(Error::UrlPortExpanded {
                parent: Some(parent_port),
                child: child.port,
            });
        }
    }

    if let Some(parent_path) = &self.path_pattern {
        let widened_path = match &child.path_pattern {
            Some(child_path) if Self::is_path_subset(parent_path, child_path) => None,
            Some(child_path) => Some(child_path.clone()),
            None => Some("*".to_string()),
        };
        if let Some(child_path) = widened_path {
            return Err(Error::UrlPathExpanded {
                parent: parent_path.clone(),
                child: child_path,
            });
        }
    }

    Ok(())
}
/// Returns `true` if every host matched by the `child` host pattern is also
/// matched by the `parent` host pattern.
///
/// `*` covers everything. `*.suffix` covers `suffix`, any `x.suffix`, and any
/// child wildcard whose suffix is `suffix` or ends in `.suffix`. Exact
/// parents require an ASCII case-insensitive exact child.
fn is_host_subset(parent: &str, child: &str) -> bool {
    if parent == "*" {
        return true;
    }
    match parent.strip_prefix("*.") {
        None => parent.eq_ignore_ascii_case(child),
        Some(parent_suffix) => {
            let dotted_suffix = format!(".{}", parent_suffix);
            child == parent_suffix
                || child.ends_with(&dotted_suffix)
                || child.strip_prefix("*.").map_or(false, |child_suffix| {
                    child_suffix == parent_suffix || child_suffix.ends_with(&dotted_suffix)
                })
        }
    }
}
/// Returns `true` if every path matched by the `child` path pattern is also
/// matched by the `parent` path pattern.
///
/// `*` and `/*` cover everything. A parent ending in `/*` covers any child
/// that starts with the parent's prefix up to and including the final `/`.
///
/// Any other parent is matched by `matches_path_pattern` as a glob (or
/// literally). It is only covered by an identical child. The previous version
/// treated such parents as directory prefixes, so parent `/api` accepted
/// child `/api/x` (or even `/api/*`). Those children allow requests the parent
/// itself rejects, which breaks monotonicity.
fn is_path_subset(parent: &str, child: &str) -> bool {
    if parent == "*" || parent == "/*" {
        return true;
    }
    if let Some(parent_prefix) = parent.strip_suffix('*').filter(|p| p.ends_with('/')) {
        // The parent's trailing `*` crosses `/`, so it accepts every path with
        // this literal prefix; any child pattern starting with it is narrower.
        return child.starts_with(parent_prefix);
    }
    parent == child
}
}
/// Renders as `UrlPattern(<original pattern text>)`.
impl std::fmt::Display for UrlPattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("UrlPattern(")?;
        f.write_str(&self.pattern)?;
        f.write_str(")")
    }
}
impl From<UrlPattern> for Constraint {
fn from(u: UrlPattern) -> Self {
Constraint::UrlPattern(u)
}
}
/// Restricts a string argument to filesystem paths inside an absolute root.
///
/// The constructors (`Subpath::new`, `Subpath::with_options`) validate that
/// `root` is absolute and normalize it. Values produced by deserialization
/// go through the derived `Deserialize` and are taken as-is.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Subpath {
/// Absolute root directory that accepted paths must live under.
pub root: String,
/// Compare paths case-sensitively; `true` when absent from serialized data.
#[serde(default = "default_true")]
pub case_sensitive: bool,
/// Also accept the root path itself; `true` when absent from serialized data.
#[serde(default = "default_true")]
pub allow_equal: bool,
}
/// Serde default for boolean fields that should be enabled when omitted.
fn default_true() -> bool {
    !bool::default()
}
impl Subpath {
pub fn new(root: impl Into<String>) -> Result<Self> {
let root = root.into();
if !Self::is_absolute(&root) {
return Err(Error::InvalidPath {
path: root,
reason: "root must be an absolute path".to_string(),
});
}
Ok(Self {
root: Self::normalize_path(&root),
case_sensitive: true,
allow_equal: true,
})
}
pub fn with_options(
root: impl Into<String>,
case_sensitive: bool,
allow_equal: bool,
) -> Result<Self> {
let root = root.into();
if !Self::is_absolute(&root) {
return Err(Error::InvalidPath {
path: root,
reason: "root must be an absolute path".to_string(),
});
}
let mut normalized = Self::normalize_path(&root);
if !case_sensitive {
normalized = normalized.to_lowercase();
}
Ok(Self {
root: normalized,
case_sensitive,
allow_equal,
})
}
fn is_absolute(path: &str) -> bool {
if path.starts_with('/') {
return true;
}
if path.len() >= 3 {
let bytes = path.as_bytes();
if bytes[0].is_ascii_alphabetic()
&& bytes[1] == b':'
&& (bytes[2] == b'\\' || bytes[2] == b'/')
{
return true;
}
}
false
}
fn normalize_path(path: &str) -> String {
let mut components: Vec<&str> = Vec::new();
let (prefix, rest) = if let Some(stripped) = path.strip_prefix('/') {
("/", stripped)
} else if path.len() >= 2 && path.as_bytes()[1] == b':' {
let sep_pos =
if path.len() > 2 && (path.as_bytes()[2] == b'\\' || path.as_bytes()[2] == b'/') {
3
} else {
2
};
(&path[..sep_pos], &path[sep_pos..])
} else {
("", path)
};
for component in rest.split(['/', '\\']) {
match component {
"" | "." => continue, ".." => {
components.pop();
}
_ => components.push(component),
}
}
let mut result = prefix.to_string();
for (i, component) in components.iter().enumerate() {
if (i > 0 || !prefix.is_empty()) && !result.ends_with('/') && !result.ends_with('\\') {
result.push('/');
}
result.push_str(component);
}
if result.is_empty() {
result = prefix.to_string();
}
result
}
pub fn contains_path(&self, path: &str) -> Result<bool> {
if path.contains('\0') {
return Ok(false);
}
if !Self::is_absolute(path) {
return Ok(false);
}
let mut normalized = Self::normalize_path(path);
if !self.case_sensitive {
normalized = normalized.to_lowercase();
}
if self.allow_equal && normalized == self.root {
return Ok(true);
}
let root_with_sep = format!("{}/", self.root.trim_end_matches('/'));
Ok(normalized.starts_with(&root_with_sep))
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
match value.as_str() {
Some(path_str) => self.contains_path(path_str),
None => Ok(false),
}
}
pub fn validate_attenuation(&self, child: &Subpath) -> Result<()> {
if !self.contains_path(&child.root)? {
return Err(Error::PathNotContained {
path: child.root.clone(),
root: self.root.clone(),
});
}
if !self.case_sensitive && child.case_sensitive {
return Err(Error::MonotonicityViolation(
"cannot attenuate case-insensitive to case-sensitive".to_string(),
));
}
if !self.allow_equal && child.allow_equal {
return Err(Error::MonotonicityViolation(
"cannot attenuate allow_equal=false to allow_equal=true".to_string(),
));
}
Ok(())
}
}
/// Renders as `Subpath(<root>)` when both options hold their defaults, and
/// spells the options out otherwise.
impl std::fmt::Display for Subpath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let uses_defaults = self.case_sensitive && self.allow_equal;
        if uses_defaults {
            return write!(f, "Subpath({})", self.root);
        }
        write!(
            f,
            "Subpath({}, case_sensitive={}, allow_equal={})",
            self.root, self.case_sensitive, self.allow_equal
        )
    }
}
impl From<Subpath> for Constraint {
fn from(s: Subpath) -> Self {
Constraint::Subpath(s)
}
}
/// SSRF-oriented URL constraint: accepts a URL only if its scheme, host and
/// port pass the configured allow/deny lists and IP-class blocks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UrlSafe {
/// Accepted schemes, compared case-insensitively (default: `http`, `https`).
#[serde(default = "default_schemes")]
pub schemes: Vec<String>,
/// If set, non-IP hosts must match one of these (exact or `*.suffix`).
#[serde(default)]
pub allow_domains: Option<Vec<String>>,
/// Hosts matching any of these (exact or `*.suffix`) are rejected.
#[serde(default)]
pub deny_domains: Option<Vec<String>>,
/// If set, the explicit or scheme-default port must be in this list.
#[serde(default)]
pub allow_ports: Option<Vec<u16>>,
/// Reject private-range IP literals.
#[serde(default = "default_true")]
pub block_private: bool,
/// Reject loopback IP literals and `localhost`-style names.
#[serde(default = "default_true")]
pub block_loopback: bool,
/// Reject known cloud metadata endpoints.
#[serde(default = "default_true")]
pub block_metadata: bool,
/// Reject reserved / multicast / broadcast IP literals.
#[serde(default = "default_true")]
pub block_reserved: bool,
/// Reject hosts under internal-looking TLDs (off by default).
#[serde(default)]
pub block_internal_tlds: bool,
}
/// Schemes a `UrlSafe` accepts when none are configured: `http` and `https`.
fn default_schemes() -> Vec<String> {
    ["http", "https"].iter().map(|s| s.to_string()).collect()
}
/// Hostnames and IP literals of well-known cloud instance-metadata services.
/// `UrlSafe::is_safe` compares these against the lower-cased request host
/// when `block_metadata` is set.
const METADATA_HOSTS: &[&str] = &[
    "169.254.169.254",
    "metadata.google.internal",
    "metadata.goog",
    "100.100.100.200",
];
/// Host suffixes rejected when `UrlSafe::block_internal_tlds` is set. A host
/// is rejected if it ends with an entry, or equals the entry without its
/// leading dot.
const INTERNAL_TLDS: &[&str] = &[
    ".internal",
    ".local",
    ".localhost",
    ".lan",
    ".corp",
    ".home",
    ".svc",
    ".default",
];
impl UrlSafe {
pub fn new() -> Self {
Self {
schemes: default_schemes(),
allow_domains: None,
deny_domains: None,
allow_ports: None,
block_private: true,
block_loopback: true,
block_metadata: true,
block_reserved: true,
block_internal_tlds: false,
}
}
pub fn with_domains(domains: Vec<impl Into<String>>) -> Self {
Self {
schemes: default_schemes(),
allow_domains: Some(domains.into_iter().map(Into::into).collect()),
deny_domains: None,
allow_ports: None,
block_private: true,
block_loopback: true,
block_metadata: true,
block_reserved: true,
block_internal_tlds: false,
}
}
pub fn is_safe(&self, url: &str) -> Result<bool> {
use url::Url;
if url.contains('\0') {
return Ok(false);
}
let parsed = match Url::parse(url) {
Ok(u) => u,
Err(_) => return Ok(false),
};
let scheme = parsed.scheme().to_lowercase();
if !self.schemes.iter().any(|s| s.to_lowercase() == scheme) {
return Ok(false);
}
let host = match parsed.host_str() {
Some(h) if !h.is_empty() => h,
_ => return Ok(false), };
let host = urlencoding_decode(host).to_lowercase();
if let Some(ref allowed_ports) = self.allow_ports {
if let Some(port) = parsed.port_or_known_default() {
if !allowed_ports.contains(&port) {
return Ok(false);
}
}
}
if self.block_loopback && (host == "localhost" || host == "localhost.localdomain") {
return Ok(false);
}
if self.block_internal_tlds {
for tld in INTERNAL_TLDS {
if host.ends_with(tld) || host == tld[1..] {
return Ok(false);
}
}
}
if self.block_metadata && METADATA_HOSTS.contains(&host.as_str()) {
return Ok(false);
}
if let Some(ip) = self.parse_ip(&host) {
if !self.check_ip_safe(&ip) {
return Ok(false);
}
} else {
if self.looks_like_ambiguous_ip(&host) {
return Ok(false);
}
if let Some(ref domains) = self.allow_domains {
if !self.check_domain_allowed(&host, domains) {
return Ok(false);
}
}
}
if let Some(ref denied) = self.deny_domains {
if self.check_domain_allowed(&host, denied) {
return Ok(false);
}
}
Ok(true)
}
fn parse_ip(&self, host: &str) -> Option<IpAddr> {
let host = host.trim_start_matches('[').trim_end_matches(']');
if let Ok(ip) = host.parse::<IpAddr>() {
return Some(ip);
}
if host.chars().all(|c| c.is_ascii_digit()) {
if let Ok(int_val) = host.parse::<u32>() {
return Some(IpAddr::V4(std::net::Ipv4Addr::from(int_val)));
}
}
if host.to_lowercase().starts_with("0x") {
if let Ok(int_val) = u32::from_str_radix(&host[2..], 16) {
return Some(IpAddr::V4(std::net::Ipv4Addr::from(int_val)));
}
}
if host.starts_with('0') && host.contains('.') {
let parts: Vec<&str> = host.split('.').collect();
if parts.len() == 4 {
let all_numeric = parts
.iter()
.all(|p| !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()));
if all_numeric {
let has_leading_zeros = parts.iter().any(|p| p.len() > 1 && p.starts_with('0'));
if has_leading_zeros {
let is_clear_octal = parts
.iter()
.all(|p| p.chars().all(|c| ('0'..='7').contains(&c)));
if is_clear_octal {
let mut octets = [0u8; 4];
let mut octal_valid = true;
let mut decimal_same = true;
for (i, part) in parts.iter().enumerate() {
let octal_val = if part.starts_with('0') && part.len() > 1 {
u8::from_str_radix(part, 8).ok()
} else {
part.parse::<u8>().ok()
};
let decimal_val = part.parse::<u8>().ok();
if let Some(ov) = octal_val {
octets[i] = ov;
if decimal_val != Some(ov) {
decimal_same = false;
}
} else {
octal_valid = false;
break;
}
}
if octal_valid {
if !decimal_same {
return None;
}
return Some(IpAddr::V4(std::net::Ipv4Addr::from(octets)));
}
}
return None;
}
}
}
}
None
}
fn check_ip_safe(&self, ip: &IpAddr) -> bool {
let ip = match ip {
IpAddr::V6(v6) => {
if let Some(mapped) = v6.to_ipv4_mapped() {
IpAddr::V4(mapped)
}
else {
let segments = v6.segments();
if segments[0..6].iter().all(|&s| s == 0) {
let octets = v6.octets();
let ipv4 =
std::net::Ipv4Addr::new(octets[12], octets[13], octets[14], octets[15]);
if ipv4.octets() != [0, 0, 0, 0] && ipv4.octets() != [0, 0, 0, 1] {
IpAddr::V4(ipv4)
} else {
*ip
}
} else {
*ip
}
}
}
_ => *ip,
};
if self.block_loopback && ip.is_loopback() {
return false;
}
if self.block_private {
if let IpAddr::V4(v4) = ip {
let octets = v4.octets();
if octets[0] == 10 {
return false;
}
if octets[0] == 172 && (16..=31).contains(&octets[1]) {
return false;
}
if octets[0] == 192 && octets[1] == 168 {
return false;
}
}
if let IpAddr::V6(v6) = ip {
let segments = v6.segments();
if (segments[0] & 0xfe00) == 0xfc00 {
return false;
}
if (segments[0] & 0xffc0) == 0xfe80 {
return false;
}
}
}
if self.block_reserved {
if let IpAddr::V4(v4) = ip {
let octets = v4.octets();
if octets[0] == 0 {
return false;
}
if (224..=239).contains(&octets[0]) {
return false;
}
if octets == [255, 255, 255, 255] {
return false;
}
}
}
if self.block_metadata {
if let IpAddr::V4(v4) = ip {
let octets = v4.octets();
if octets[0] == 169 && octets[1] == 254 {
return false;
}
}
}
true
}
fn check_domain_allowed(&self, host: &str, domains: &[String]) -> bool {
for pattern in domains {
let pattern = pattern.to_lowercase();
if pattern.starts_with("*.") {
let suffix = &pattern[1..]; if host.ends_with(suffix) || host == &pattern[2..] {
return true;
}
} else if host == pattern {
return true;
}
}
false
}
fn looks_like_ambiguous_ip(&self, host: &str) -> bool {
let parts: Vec<&str> = host.split('.').collect();
if parts.len() != 4 {
return false;
}
let all_numeric = parts
.iter()
.all(|p| !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()));
if !all_numeric {
return false;
}
for part in &parts {
if part.len() > 1 && part.starts_with('0') {
return true;
}
}
false
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
match value.as_str() {
Some(url_str) => self.is_safe(url_str),
None => Ok(false),
}
}
pub fn validate_attenuation(&self, child: &UrlSafe) -> Result<()> {
for scheme in &child.schemes {
if !self
.schemes
.iter()
.any(|s| s.to_lowercase() == scheme.to_lowercase())
{
return Err(Error::MonotonicityViolation(format!(
"child scheme '{}' not in parent schemes",
scheme
)));
}
}
if self.block_private && !child.block_private {
return Err(Error::MonotonicityViolation(
"cannot disable block_private".to_string(),
));
}
if self.block_loopback && !child.block_loopback {
return Err(Error::MonotonicityViolation(
"cannot disable block_loopback".to_string(),
));
}
if self.block_metadata && !child.block_metadata {
return Err(Error::MonotonicityViolation(
"cannot disable block_metadata".to_string(),
));
}
if self.block_reserved && !child.block_reserved {
return Err(Error::MonotonicityViolation(
"cannot disable block_reserved".to_string(),
));
}
if self.block_internal_tlds && !child.block_internal_tlds {
return Err(Error::MonotonicityViolation(
"cannot disable block_internal_tlds".to_string(),
));
}
if let Some(ref parent_domains) = self.allow_domains {
match &child.allow_domains {
None => {
return Err(Error::MonotonicityViolation(
"child must have domain allowlist if parent does".to_string(),
));
}
Some(child_domains) => {
for cd in child_domains {
if !parent_domains.iter().any(|pd| {
let pd = pd.to_lowercase();
let cd = cd.to_lowercase();
if pd == cd {
return true;
}
if pd.starts_with("*.") && cd.ends_with(&pd[1..]) {
return true;
}
false
}) {
return Err(Error::MonotonicityViolation(format!(
"child domain '{}' not covered by parent allowlist",
cd
)));
}
}
}
}
}
if let Some(ref parent_denied) = self.deny_domains {
match &child.deny_domains {
None => {
return Err(Error::MonotonicityViolation(
"child must have deny_domains if parent does".to_string(),
));
}
Some(child_denied) => {
for pd in parent_denied {
let pd_lower = pd.to_lowercase();
if !child_denied.iter().any(|cd| cd.to_lowercase() == pd_lower) {
return Err(Error::MonotonicityViolation(format!(
"child removes denied domain '{}' from parent",
pd
)));
}
}
}
}
}
Ok(())
}
}
impl Default for UrlSafe {
fn default() -> Self {
Self::new()
}
}
/// Renders as `UrlSafe(...)`, listing only the options that differ from
/// [`UrlSafe::new`].
///
/// `allow_ports` and `block_reserved=false` are now included. Previously two
/// policies differing only in those fields rendered identically.
impl std::fmt::Display for UrlSafe {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opts = Vec::new();
        if self.schemes != default_schemes() {
            opts.push(format!("schemes={:?}", self.schemes));
        }
        if let Some(ref domains) = self.allow_domains {
            opts.push(format!("allow_domains={:?}", domains));
        }
        if let Some(ref domains) = self.deny_domains {
            opts.push(format!("deny_domains={:?}", domains));
        }
        if let Some(ref ports) = self.allow_ports {
            opts.push(format!("allow_ports={:?}", ports));
        }
        if !self.block_private {
            opts.push("block_private=false".to_string());
        }
        if !self.block_loopback {
            opts.push("block_loopback=false".to_string());
        }
        if !self.block_metadata {
            opts.push("block_metadata=false".to_string());
        }
        if !self.block_reserved {
            opts.push("block_reserved=false".to_string());
        }
        if self.block_internal_tlds {
            opts.push("block_internal_tlds=true".to_string());
        }
        if opts.is_empty() {
            write!(f, "UrlSafe()")
        } else {
            write!(f, "UrlSafe({})", opts.join(", "))
        }
    }
}
impl From<UrlSafe> for Constraint {
fn from(u: UrlSafe) -> Self {
Constraint::UrlSafe(u)
}
}
/// Characters whose presence anywhere in a command makes `Shlex` reject it.
const SHELL_DANGEROUS_CHARS: &[char] = &[
    // NUL, line breaks and other control characters
    '\0', '\n', '\r', '\x0b', '\x0c', '\x07', '\x08', '\x7f',
    // shell metacharacters: expansion, substitution, pipes, background,
    // command separators, redirection, subshells
    '$', '`', '|', '&', ';', '<', '>', '(', ')',
];
/// Command-string constraint: the command must be free of shell
/// metacharacters and its first token must name an allowlisted binary.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Shlex {
/// Allowed binaries, compared against the first token or its basename.
pub allow: Vec<String>,
}
impl Shlex {
pub fn new(allow: Vec<impl Into<String>>) -> Self {
Self {
allow: allow.into_iter().map(Into::into).collect(),
}
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
let cmd = match value.as_str() {
Some(s) => s,
None => return Ok(false),
};
if cmd.is_empty() {
return Ok(false);
}
for ch in cmd.chars() {
if SHELL_DANGEROUS_CHARS.contains(&ch) {
return Ok(false);
}
}
let tokens: Vec<&str> = cmd.split_whitespace().collect();
if tokens.is_empty() {
return Ok(false);
}
let binary = tokens[0];
if binary.contains("..") {
return Ok(false);
}
let bin_name = binary.rsplit('/').next().unwrap_or(binary);
if !self.allow.iter().any(|a| a == binary || a == bin_name) {
return Ok(false);
}
Ok(true)
}
pub fn validate_attenuation(&self, child: &Shlex) -> Result<()> {
for bin in &child.allow {
if !self.allow.contains(bin) {
return Err(Error::MonotonicityViolation(format!(
"child allows binary '{}' not in parent allowlist",
bin
)));
}
}
Ok(())
}
}
/// Renders as `Shlex(allow=[...])` using the allowlist's `Debug` form.
impl std::fmt::Display for Shlex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Shlex(allow=")?;
        write!(f, "{:?}", self.allow)?;
        f.write_str(")")
    }
}
impl From<Shlex> for Constraint {
fn from(s: Shlex) -> Self {
Constraint::Shlex(s)
}
}
/// Percent-decodes `s`.
///
/// A `%` followed by two ASCII hex digits becomes that byte. Any other `%` is
/// kept literally. The decoded bytes are interpreted as UTF-8, and invalid
/// sequences become U+FFFD.
///
/// This fixes two problems in the previous version:
/// - It decoded each byte as a Latin-1 `char`, turning `%C3%A9` into `Ã©`.
/// - It used `u8::from_str_radix`, which accepts a leading `+`, so `%+1`
///   decoded to U+0001.
fn urlencoding_decode(s: &str) -> String {
    fn hex_value(b: u8) -> Option<u8> {
        // `to_digit(16)` only accepts ASCII 0-9 / a-f / A-F.
        (b as char).to_digit(16).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let hi = bytes.get(i + 1).copied().and_then(hex_value);
            let lo = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// List constraint: the argument list must contain every value in
/// `required`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Contains {
/// Values that must all be present in the argument list.
pub required: Vec<ConstraintValue>,
}
impl Contains {
    /// Builds the constraint from any iterable of convertible values.
    pub fn new(required: impl IntoIterator<Item = impl Into<ConstraintValue>>) -> Self {
        let required = required.into_iter().map(Into::into).collect();
        Self { required }
    }

    /// Accepts a list value that includes every required value. Non-list
    /// values are rejected.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        Ok(match value.as_list() {
            Some(list) => self.required.iter().all(|needed| list.contains(needed)),
            None => false,
        })
    }

    /// A child may add required values but must keep all of the parent's.
    ///
    /// # Errors
    ///
    /// [`Error::RequiredValueRemoved`] for the first dropped value.
    pub fn validate_attenuation(&self, child: &Contains) -> Result<()> {
        if let Some(missing) = self.required.iter().find(|v| !child.required.contains(*v)) {
            return Err(Error::RequiredValueRemoved {
                value: format!("{:?}", missing),
            });
        }
        Ok(())
    }
}
impl From<Contains> for Constraint {
fn from(c: Contains) -> Self {
Constraint::Contains(c)
}
}
/// List constraint: every element of the argument list must appear in
/// `allowed`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Subset {
/// The universe of permitted list elements.
pub allowed: Vec<ConstraintValue>,
}
impl Subset {
pub fn new(allowed: impl IntoIterator<Item = impl Into<ConstraintValue>>) -> Self {
Self {
allowed: allowed.into_iter().map(Into::into).collect(),
}
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
let list = match value.as_list() {
Some(l) => l,
None => return Ok(false),
};
Ok(list.iter().all(|v| self.allowed.contains(v)))
}
pub fn validate_attenuation(&self, child: &Subset) -> Result<()> {
for v in &child.allowed {
if !self.allowed.contains(v) {
return Err(Error::MonotonicityViolation(format!(
"child allows {:?} which parent does not allow",
v
)));
}
}
Ok(())
}
}
impl From<Subset> for Constraint {
fn from(s: Subset) -> Self {
Constraint::Subset(s)
}
}
/// Conjunction: a value matches only if every inner constraint matches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct All {
/// Constraints that must all hold.
pub constraints: Vec<Constraint>,
}
impl All {
pub fn new(constraints: impl IntoIterator<Item = Constraint>) -> Self {
Self {
constraints: constraints.into_iter().collect(),
}
}
pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
for c in &self.constraints {
if !c.matches(value)? {
return Ok(false);
}
}
Ok(true)
}
pub fn validate_attenuation(&self, child: &All) -> Result<()> {
for parent_c in &self.constraints {
let found = child
.constraints
.iter()
.any(|child_c| parent_c.validate_attenuation(child_c).is_ok());
if !found {
return Err(Error::MonotonicityViolation(
"child All must include all parent constraints".to_string(),
));
}
}
Ok(())
}
}
impl From<All> for Constraint {
fn from(a: All) -> Self {
Constraint::All(a)
}
}
/// Disjunction: a value matches if at least one inner constraint matches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Any {
/// Alternative constraints; one matching is sufficient.
pub constraints: Vec<Constraint>,
}
impl Any {
    /// Collects the alternative constraints.
    pub fn new(constraints: impl IntoIterator<Item = Constraint>) -> Self {
        let constraints = constraints.into_iter().collect();
        Self { constraints }
    }

    /// Evaluates alternatives in order, stopping at the first one that
    /// matches or errors. An empty disjunction never matches.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        let decisive = self
            .constraints
            .iter()
            .map(|c| c.matches(value))
            .find(|outcome| !matches!(outcome, Ok(false)));
        match decisive {
            None => Ok(false),
            Some(Ok(_)) => Ok(true),
            Some(Err(e)) => Err(e),
        }
    }
}
impl From<Any> for Constraint {
fn from(a: Any) -> Self {
Constraint::Any(a)
}
}
/// Negation: a value matches when the inner constraint does not.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Not {
/// The constraint being negated (boxed because `Constraint` is recursive).
pub constraint: Box<Constraint>,
}
impl Not {
    /// Boxes and negates `constraint`.
    pub fn new(constraint: Constraint) -> Self {
        let constraint = Box::new(constraint);
        Self { constraint }
    }

    /// Inverts the inner result; errors from the inner constraint propagate
    /// unchanged.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        self.constraint.matches(value).map(|inner| !inner)
    }
}
impl From<Not> for Constraint {
fn from(n: Not) -> Self {
Constraint::Not(n)
}
}
/// Constraint expressed as a CEL boolean expression.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CelConstraint {
/// CEL source evaluated against the argument value.
pub expression: String,
/// Expression this one was attenuated from, if built via `attenuate`;
/// omitted from serialized output when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_expression: Option<String>,
}
impl CelConstraint {
    /// Wraps a CEL expression without compiling it; see
    /// [`CelConstraint::validate`].
    pub fn new(expression: impl Into<String>) -> Self {
        let expression = expression.into();
        Self {
            expression,
            parent_expression: None,
        }
    }

    /// Builds a child constraint of the form `(<parent>) && (<predicate>)`
    /// and records the parent's expression.
    pub fn attenuate(parent: &CelConstraint, additional_predicate: &str) -> Self {
        let expression = format!("({}) && ({})", parent.expression, additional_predicate);
        let parent_expression = Some(parent.expression.clone());
        Self {
            expression,
            parent_expression,
        }
    }

    /// Compiles the expression, surfacing any compile error.
    pub fn validate(&self) -> Result<()> {
        crate::cel::compile(&self.expression)?;
        Ok(())
    }

    /// Evaluates the expression with `value` as its context.
    pub fn matches(&self, value: &ConstraintValue) -> Result<bool> {
        crate::cel::evaluate_with_value_context(&self.expression, value)
    }

    /// Evaluates the expression with `value` plus extra named context
    /// variables.
    pub fn matches_with_context(
        &self,
        value: &ConstraintValue,
        context: &HashMap<String, ConstraintValue>,
    ) -> Result<bool> {
        crate::cel::evaluate(&self.expression, value, context)
    }

    /// Accepts a child that is textually the parent (modulo whitespace), or
    /// `(<parent>) && (<predicate>)` with a single parenthesized predicate
    /// that compiles.
    ///
    /// # Errors
    ///
    /// [`Error::MonotonicityViolation`] for a child of the wrong shape, or
    /// the compile error of the child expression.
    pub fn validate_attenuation(&self, child: &CelConstraint) -> Result<()> {
        let child_normalized = normalize_cel_whitespace(&child.expression);
        let parent_normalized = normalize_cel_whitespace(&self.expression);
        if child_normalized == parent_normalized {
            return Ok(());
        }
        let expected_prefix = format!("({})&&", parent_normalized);
        let remainder = match child_normalized.strip_prefix(expected_prefix.as_str()) {
            Some(rest) => rest,
            None => {
                return Err(Error::MonotonicityViolation(format!(
                    "child CEL must be '({}) && (<predicate>)', got '{}'",
                    self.expression, child.expression
                )));
            }
        };
        if !has_balanced_outer_parens(remainder) {
            return Err(Error::MonotonicityViolation(format!(
                "child CEL predicate must be parenthesized: '({}) && (<predicate>)', got '({}) && {}'",
                self.expression, self.expression, remainder
            )));
        }
        child.validate()?;
        Ok(())
    }
}
/// One lexical piece of CEL source, as needed for textual comparison.
enum CelLexSegment<'a> {
    /// Text outside string/bytes literals; each `//` comment has been
    /// replaced by a single space.
    Code(String),
    /// A complete string or bytes literal, verbatim (prefix and quotes
    /// included).
    Literal(&'a str),
}

/// Identifier/number characters: two of these must never be fused.
fn is_cel_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

fn is_cel_quote(c: char) -> bool {
    c == '"' || c == '\''
}

/// Splits `expr` into code and literal segments, following the CEL lexical
/// rules for string/bytes literals and `//` line comments.
///
/// The following literal forms are recognised:
/// - single and triple quotes (`"`, `'`, `"""`, `'''`);
/// - the `r`/`R` raw prefix, where backslashes do not escape;
/// - the `b`/`B` bytes prefix, optionally combined with a raw prefix.
///
/// Returns `None` when the text cannot be lexed: an unterminated literal, or
/// a line break inside a single-line literal.
fn lex_cel_segments(expr: &str) -> Option<Vec<CelLexSegment<'_>>> {
    let chars: Vec<(usize, char)> = expr.char_indices().collect();
    let char_at = |i: usize| chars.get(i).map(|&(_, c)| c);
    let byte_at = |i: usize| chars.get(i).map_or(expr.len(), |&(b, _)| b);
    let mut segments = Vec::new();
    let mut code = String::new();
    let mut i = 0;
    while i < chars.len() {
        let (pos, ch) = chars[i];
        if ch == '/' && char_at(i + 1) == Some('/') {
            while i < chars.len() && chars[i].1 != '\n' {
                i += 1;
            }
            code.push(' ');
            continue;
        }
        if !is_cel_quote(ch) {
            code.push(ch);
            i += 1;
            continue;
        }
        // The whole identifier directly before the quote is a literal prefix
        // only if it is exactly one of the CEL prefixes (longest-match lexing:
        // `foobr"x"` is the identifier `foobr` followed by a string).
        let word_start = code
            .char_indices()
            .rev()
            .take_while(|&(_, c)| is_cel_word_char(c))
            .last()
            .map_or(code.len(), |(idx, _)| idx);
        let word = &code[word_start..];
        let is_prefix = matches!(word, "r" | "R" | "b" | "B" | "br" | "bR" | "Br" | "BR");
        let prefix_len = if is_prefix { word.len() } else { 0 };
        let raw = is_prefix && (word.contains('r') || word.contains('R'));
        // Prefix characters are ASCII, so byte and char counts agree.
        code.truncate(code.len() - prefix_len);
        let start_byte = pos - prefix_len;

        let triple = char_at(i + 1) == Some(ch) && char_at(i + 2) == Some(ch);
        let mut j = if triple { i + 3 } else { i + 1 };
        let end = loop {
            let c = char_at(j)?;
            if triple {
                if c == ch && char_at(j + 1) == Some(ch) && char_at(j + 2) == Some(ch) {
                    break j + 3;
                }
            } else {
                if c == '\n' || c == '\r' {
                    return None;
                }
                if c == ch {
                    break j + 1;
                }
            }
            if c == '\\' && !raw {
                // The escaped character can never terminate the literal.
                let escaped = char_at(j + 1)?;
                if !triple && (escaped == '\n' || escaped == '\r') {
                    return None;
                }
                j += 2;
            } else {
                j += 1;
            }
        };
        if !code.is_empty() {
            segments.push(CelLexSegment::Code(std::mem::take(&mut code)));
        }
        segments.push(CelLexSegment::Literal(&expr[start_byte..byte_at(end)]));
        i = end;
    }
    if !code.is_empty() {
        segments.push(CelLexSegment::Code(code));
    }
    Some(segments)
}

/// Returns `true` if `prev` immediately followed by `next` would lex
/// differently than with whitespace between them. Such pairs form a longer
/// identifier, a prefixed or triple-quoted literal, or a `//` comment.
fn cel_chars_fuse(prev: char, next: char) -> bool {
    (is_cel_word_char(prev) && (is_cel_word_char(next) || is_cel_quote(next)))
        || (is_cel_quote(prev) && is_cel_quote(next))
        || (prev == '/' && next == '/')
}

/// Emits a single space if whitespace was skipped and dropping it would fuse
/// the last emitted character with `next`.
fn cel_separate(out: &mut String, pending_space: &mut bool, next: char) {
    if std::mem::take(pending_space) {
        if let Some(prev) = out.chars().next_back() {
            if cel_chars_fuse(prev, next) {
                out.push(' ');
            }
        }
    }
}

/// Canonicalizes CEL source for textual comparison.
///
/// The canonical form is built as follows:
/// - `//` comments are dropped;
/// - whitespace outside literals is removed, except a single space where
///   removal would merge two tokens (e.g. `x in y` stays `x in y`);
/// - string and bytes literals are kept verbatim.
///
/// The previous version deleted all whitespace, including inside string
/// literals. That made `cmd != "rm-rf"` compare equal to `cmd != "rm -rf"`,
/// and merged tokens across comments.
///
/// Text that cannot be lexed (e.g. an unterminated string) is returned
/// unchanged, so it never compares equal to a lexable expression's canonical
/// form.
fn normalize_cel_whitespace(expr: &str) -> String {
    let segments = match lex_cel_segments(expr) {
        Some(segments) => segments,
        None => return expr.to_string(),
    };
    let mut out = String::with_capacity(expr.len());
    let mut pending_space = false;
    for segment in &segments {
        match segment {
            CelLexSegment::Code(code) => {
                for c in code.chars() {
                    if c.is_whitespace() {
                        pending_space = true;
                        continue;
                    }
                    cel_separate(&mut out, &mut pending_space, c);
                    out.push(c);
                }
            }
            CelLexSegment::Literal(literal) => {
                if let Some(first) = literal.chars().next() {
                    cel_separate(&mut out, &mut pending_space, first);
                }
                out.push_str(literal);
            }
        }
    }
    out
}

/// Returns `true` if `s` is exactly one parenthesized group: it starts with
/// `(`, ends with the `)` that closes it, and the parentheses are balanced.
///
/// Parentheses inside string/bytes literals (including raw strings) and `//`
/// comments are ignored. Text that cannot be lexed is rejected. Previously,
/// literal parentheses were counted, so `(x=="(")||(y==")")` was accepted as
/// a single group.
fn has_balanced_outer_parens(s: &str) -> bool {
    let segments = match lex_cel_segments(s) {
        Some(segments) => segments,
        None => return false,
    };
    // Replace each literal with a neutral placeholder so only real
    // parentheses are counted; a trailing literal or comment then correctly
    // fails the `ends_with(')')` test.
    let mut masked = String::with_capacity(s.len());
    for segment in &segments {
        match segment {
            CelLexSegment::Code(code) => masked.push_str(code),
            CelLexSegment::Literal(_) => masked.push('0'),
        }
    }
    if !masked.starts_with('(') || !masked.ends_with(')') {
        return false;
    }
    let last = masked.len() - 1;
    let mut depth = 0i32;
    for (i, ch) in masked.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 && i != last {
                    return false;
                }
            }
            _ => {}
        }
    }
    depth == 0
}
impl From<CelConstraint> for Constraint {
fn from(c: CelConstraint) -> Self {
Constraint::Cel(c)
}
}
/// Serde `skip_serializing_if` predicate: skip the field when it is `false`.
fn is_false(b: &bool) -> bool {
    !b
}
/// Per-argument constraints for a call, keyed by argument name.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ConstraintSet {
/// Constraint for each named argument (sorted map).
constraints: BTreeMap<String, Constraint>,
/// When `false` and `constraints` is non-empty, arguments without a
/// constraint are rejected. Omitted from serialized output when `false`.
#[serde(default, skip_serializing_if = "is_false")]
allow_unknown: bool,
}
impl ConstraintSet {
    /// Creates an empty set with `allow_unknown` disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds or replaces the constraint for `field`.
    pub fn insert(&mut self, field: impl Into<String>, constraint: impl Into<Constraint>) {
        let (key, value) = (field.into(), constraint.into());
        self.constraints.insert(key, value);
    }

    /// Looks up the constraint for `field`.
    pub fn get(&self, field: &str) -> Option<&Constraint> {
        self.constraints.get(field)
    }

    /// Whether arguments without a constraint are tolerated.
    pub fn allow_unknown(&self) -> bool {
        self.allow_unknown
    }

    /// Sets whether arguments without a constraint are tolerated.
    pub fn set_allow_unknown(&mut self, allow: bool) {
        self.allow_unknown = allow;
    }

    /// Checks the nesting depth of every contained constraint.
    pub fn validate_depth(&self) -> Result<()> {
        for constraint in self.constraints.values() {
            constraint.validate_depth()?;
        }
        Ok(())
    }

    /// Checks `args` against the set.
    ///
    /// If the set is non-empty and `allow_unknown` is off, any argument
    /// without a constraint is rejected. Every constrained field must then be
    /// present and match.
    ///
    /// # Errors
    ///
    /// [`Error::ConstraintNotSatisfied`] naming the offending field, or any
    /// error raised while evaluating a constraint.
    pub fn matches(&self, args: &HashMap<String, ConstraintValue>) -> Result<()> {
        let zero_trust = !self.allow_unknown && !self.constraints.is_empty();
        if zero_trust {
            if let Some(unknown) = args
                .keys()
                .find(|key| !self.constraints.contains_key(key.as_str()))
            {
                return Err(Error::ConstraintNotSatisfied {
                    field: unknown.clone(),
                    reason: "unknown field not allowed (zero-trust mode)".to_string(),
                });
            }
        }
        for (field, constraint) in &self.constraints {
            let value = match args.get(field) {
                Some(value) => value,
                None => {
                    return Err(Error::ConstraintNotSatisfied {
                        field: field.clone(),
                        reason: "missing required argument".to_string(),
                    });
                }
            };
            if !constraint.matches(value)? {
                return Err(Error::ConstraintNotSatisfied {
                    field: field.clone(),
                    reason: "value does not match constraint".to_string(),
                });
            }
        }
        Ok(())
    }

    /// Verifies that `child` is at least as restrictive as `self`.
    ///
    /// The child must meet all of the following:
    /// - It must not enable `allow_unknown` if the parent has it disabled.
    /// - It must constrain every parent field with a valid attenuation.
    /// - When the parent is non-empty, it must not introduce new field names.
    ///
    /// # Errors
    ///
    /// [`Error::MonotonicityViolation`], or the error from a per-field
    /// attenuation check.
    pub fn validate_attenuation(&self, child: &ConstraintSet) -> Result<()> {
        if child.allow_unknown && !self.allow_unknown {
            return Err(Error::MonotonicityViolation(
                "child cannot enable allow_unknown when parent has it disabled".to_string(),
            ));
        }
        for (field, parent_constraint) in &self.constraints {
            match child.constraints.get(field) {
                Some(child_constraint) => {
                    parent_constraint.validate_attenuation(child_constraint)?;
                }
                None => {
                    return Err(Error::MonotonicityViolation(format!(
                        "child is missing constraint for field '{}' that parent has",
                        field
                    )));
                }
            }
        }
        if !self.constraints.is_empty() {
            if let Some(key) = child
                .constraints
                .keys()
                .find(|key| !self.constraints.contains_key(*key))
            {
                return Err(Error::MonotonicityViolation(format!(
                    "child adds argument key '{}' not present in parent's \
                     non-empty constraint map (keyset identity, I4)",
                    key
                )));
            }
        }
        Ok(())
    }

    /// Iterates `(field, constraint)` pairs in field-name order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Constraint)> {
        self.constraints.iter()
    }

    /// `true` if no field is constrained.
    pub fn is_empty(&self) -> bool {
        self.constraints.is_empty()
    }

    /// Number of constrained fields.
    pub fn len(&self) -> usize {
        self.constraints.len()
    }
}
impl FromIterator<(String, Constraint)> for ConstraintSet {
fn from_iter<T: IntoIterator<Item = (String, Constraint)>>(iter: T) -> Self {
Self {
constraints: iter.into_iter().collect(),
allow_unknown: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pattern_suffix_wildcard() {
let pattern = Pattern::new("staging-*").unwrap();
assert!(pattern.matches(&"staging-web".into()).unwrap());
assert!(pattern.matches(&"staging-api".into()).unwrap());
assert!(pattern.matches(&"staging-".into()).unwrap()); assert!(!pattern.matches(&"prod-web".into()).unwrap());
assert!(!pattern.matches(&"Staging-web".into()).unwrap()); }
#[test]
fn test_pattern_prefix_wildcard() {
let pattern = Pattern::new("*@company.com").unwrap();
assert!(pattern.matches(&"cfo@company.com".into()).unwrap());
assert!(pattern.matches(&"alice@company.com".into()).unwrap());
assert!(pattern.matches(&"@company.com".into()).unwrap()); assert!(!pattern.matches(&"hacker@evil.com".into()).unwrap());
assert!(!pattern.matches(&"cfo@company.com.evil.com".into()).unwrap());
}
#[test]
fn test_pattern_middle_wildcard() {
let pattern = Pattern::new("/data/*/file.txt").unwrap();
assert!(pattern.matches(&"/data/reports/file.txt".into()).unwrap());
assert!(pattern.matches(&"/data/x/file.txt".into()).unwrap());
assert!(!pattern.matches(&"/data/reports/other.txt".into()).unwrap());
assert!(!pattern.matches(&"/data/file.txt".into()).unwrap()); }
#[test]
fn test_pattern_multiple_wildcards() {
let pattern = Pattern::new("/*/reports/*.pdf").unwrap();
assert!(pattern.matches(&"/data/reports/q3.pdf".into()).unwrap());
assert!(pattern.matches(&"/home/reports/annual.pdf".into()).unwrap());
assert!(!pattern.matches(&"/data/reports/q3.txt".into()).unwrap());
assert!(!pattern.matches(&"/data/other/q3.pdf".into()).unwrap());
}
#[test]
fn test_pattern_bidirectional_wildcard() {
let pattern = Pattern::new("*-prod-*").unwrap();
assert!(pattern.matches(&"db-prod-primary".into()).unwrap());
assert!(pattern.matches(&"cache-prod-replica".into()).unwrap());
assert!(pattern.matches(&"-prod-".into()).unwrap()); assert!(!pattern.matches(&"db-staging-primary".into()).unwrap());
assert!(!pattern.matches(&"prod-only".into()).unwrap());
let pattern = Pattern::new("*safe*").unwrap();
assert!(pattern.matches(&"unsafe".into()).unwrap());
assert!(pattern.matches(&"safeguard".into()).unwrap());
assert!(pattern.matches(&"is-safe-mode".into()).unwrap());
assert!(!pattern.matches(&"danger".into()).unwrap());
}
#[test]
fn test_pattern_bidirectional_attenuation() {
let parent = Pattern::new("*-prod-*").unwrap();
let child_same = Pattern::new("*-prod-*").unwrap();
assert!(parent.validate_attenuation(&child_same).is_ok());
let child_prefix = Pattern::new("db-prod-*").unwrap();
assert!(parent.validate_attenuation(&child_prefix).is_err());
let child_suffix = Pattern::new("*-prod-primary").unwrap();
assert!(parent.validate_attenuation(&child_suffix).is_err());
let child_exact = Pattern::new("db-prod-primary").unwrap();
assert!(parent.validate_attenuation(&child_exact).is_err());
}
#[test]
fn test_pattern_complex_attenuation() {
let parent = Pattern::new("/data/*/file.txt").unwrap();
let child_same = Pattern::new("/data/*/file.txt").unwrap();
assert!(parent.validate_attenuation(&child_same).is_ok());
let child_different = Pattern::new("/data/reports/file.txt").unwrap();
assert!(parent.validate_attenuation(&child_different).is_err());
}
#[test]
fn test_pattern_single_wildcard() {
let pattern = Pattern::new("*").unwrap();
assert!(pattern.matches(&"anything".into()).unwrap());
assert!(pattern.matches(&"".into()).unwrap());
assert!(pattern.matches(&"foo/bar/baz".into()).unwrap());
}
#[test]
fn test_pattern_question_mark() {
let pattern = Pattern::new("file?.txt").unwrap();
assert!(pattern.matches(&"file1.txt".into()).unwrap());
assert!(pattern.matches(&"fileA.txt".into()).unwrap());
assert!(!pattern.matches(&"file12.txt".into()).unwrap());
assert!(!pattern.matches(&"file.txt".into()).unwrap());
}
#[test]
fn test_pattern_character_class() {
let pattern = Pattern::new("env-[psd]*").unwrap(); assert!(pattern.matches(&"env-prod".into()).unwrap());
assert!(pattern.matches(&"env-staging".into()).unwrap());
assert!(pattern.matches(&"env-dev".into()).unwrap());
assert!(!pattern.matches(&"env-test".into()).unwrap()); }
#[test]
fn test_pattern_no_wildcard() {
let pattern = Pattern::new("/data/file.txt").unwrap();
assert!(pattern.matches(&"/data/file.txt".into()).unwrap());
assert!(!pattern.matches(&"/data/other.txt".into()).unwrap());
}
#[test]
fn test_regex_matches() {
let regex = RegexConstraint::new(r"^prod-[a-z]+$").unwrap();
assert!(regex.matches(&"prod-web".into()).unwrap());
assert!(regex.matches(&"prod-api".into()).unwrap());
assert!(!regex.matches(&"prod-123".into()).unwrap());
assert!(!regex.matches(&"staging-web".into()).unwrap());
}
#[test]
fn test_exact_matches_various_types() {
let exact = Exact::new("hello");
assert!(exact.matches(&"hello".into()).unwrap());
assert!(!exact.matches(&"world".into()).unwrap());
let exact = Exact::new(42i64);
assert!(exact.matches(&42i64.into()).unwrap());
assert!(!exact.matches(&43i64.into()).unwrap());
let exact = Exact::new(true);
assert!(exact.matches(&true.into()).unwrap());
assert!(!exact.matches(&false.into()).unwrap());
}
/// `Range::between` includes both endpoints and rejects values outside them.
#[test]
fn test_range_matches() {
    let range = Range::between(10.0, 100.0).unwrap();
    for inside in [50i64, 10, 100] {
        assert!(range.matches(&inside.into()).unwrap());
    }
    for outside in [5i64, 150] {
        assert!(!range.matches(&outside.into()).unwrap());
    }
}

/// NaN bounds are refused (single-sided errors mention "NaN"), while finite
/// and infinite bounds are accepted.
#[test]
fn test_range_rejects_nan() {
    for single_sided in [Range::min(f64::NAN), Range::max(f64::NAN)] {
        let err = single_sided.err().expect("NaN bound must be rejected");
        assert!(err.to_string().contains("NaN"));
    }
    assert!(Range::between(f64::NAN, 100.0).is_err());
    assert!(Range::between(0.0, f64::NAN).is_err());

    let accepted = [
        Range::max(100.0),
        Range::min(0.0),
        Range::between(0.0, 100.0),
        Range::max(f64::INFINITY),
        Range::min(f64::NEG_INFINITY),
    ];
    for candidate in accepted {
        assert!(candidate.is_ok());
    }
}
/// `Contains` requires every listed element to be present in the value list.
#[test]
fn test_contains_constraint() {
    // Builds a list value of string elements.
    fn as_list(items: &[&str]) -> ConstraintValue {
        items
            .iter()
            .copied()
            .map(|item| ConstraintValue::String(item.to_string()))
            .collect::<Vec<_>>()
            .into()
    }
    let contains = Contains::new(["admin", "write"]);
    assert!(contains.matches(&as_list(&["admin", "write", "read"])).unwrap());
    assert!(!contains.matches(&as_list(&["admin", "read"])).unwrap());
}

/// `Subset` accepts a list only if each element is among the allowed values.
#[test]
fn test_subset_constraint() {
    // Builds a list value of string elements.
    fn as_list(items: &[&str]) -> ConstraintValue {
        items
            .iter()
            .copied()
            .map(|item| ConstraintValue::String(item.to_string()))
            .collect::<Vec<_>>()
            .into()
    }
    let subset = Subset::new(["read", "write", "admin"]);
    assert!(subset.matches(&as_list(&["read", "write"])).unwrap());
    assert!(!subset.matches(&as_list(&["read", "delete"])).unwrap());
}
/// `All` is a conjunction: here a value must be >= 0 and <= 100.
#[test]
fn test_all_constraint() {
    let bounded = All::new([
        Range::min(0.0).unwrap().into(),
        Range::max(100.0).unwrap().into(),
    ]);
    for (value, expected) in [(50i64, true), (-10, false), (150, false)] {
        assert_eq!(bounded.matches(&value.into()).unwrap(), expected);
    }
}

/// `Any` is a disjunction: matching one member is sufficient.
#[test]
fn test_any_constraint() {
    let privileged = Any::new([Exact::new("admin").into(), Exact::new("superuser").into()]);
    for role in ["admin", "superuser"] {
        assert!(privileged.matches(&role.into()).unwrap());
    }
    assert!(!privileged.matches(&"user".into()).unwrap());
}

/// `Not` inverts the verdict of its inner constraint.
#[test]
fn test_not_constraint() {
    let not_blocked = Not::new(Exact::new("blocked").into());
    let allowed = not_blocked.matches(&"allowed".into()).unwrap();
    let blocked = not_blocked.matches(&"blocked".into()).unwrap();
    assert!(allowed);
    assert!(!blocked);
}
/// Lowering a `max` bound is a valid narrowing; raising it is rejected.
#[test]
fn test_range_attenuation() {
    let ceiling_10k = Range::max(10000.0).unwrap();
    let ceiling_5k = Range::max(5000.0).unwrap();
    let ceiling_15k = Range::max(15000.0).unwrap();
    assert!(ceiling_10k.validate_attenuation(&ceiling_5k).is_ok());
    assert!(ceiling_10k.validate_attenuation(&ceiling_15k).is_err());
}

/// Turning exclusive bounds into inclusive ones at the same values admits the
/// endpoints, so the error must report that inclusivity expanded.
#[test]
fn test_range_inclusivity_cannot_expand() {
    let open_interval = Range::between(0.0, 10.0)
        .unwrap()
        .min_exclusive()
        .max_exclusive();
    let closed_interval = Range::between(0.0, 10.0).unwrap();
    match open_interval.validate_attenuation(&closed_interval) {
        Ok(_) => panic!(
            "Should reject: exclusive->inclusive at same bound expands permissions"
        ),
        Err(err) => assert!(err.to_string().contains("inclusivity expanded")),
    }
}

/// Turning inclusive bounds into exclusive ones only removes endpoints.
#[test]
fn test_range_inclusivity_can_narrow() {
    let closed_interval = Range::between(0.0, 10.0).unwrap();
    let open_interval = Range::between(0.0, 10.0)
        .unwrap()
        .min_exclusive()
        .max_exclusive();
    let verdict = closed_interval.validate_attenuation(&open_interval);
    assert!(
        verdict.is_ok(),
        "Should allow: inclusive->exclusive is valid narrowing"
    );
}

/// Inclusive child bounds are fine when they lie strictly inside the parent's
/// exclusive bounds.
#[test]
fn test_range_inclusivity_stricter_value_ok() {
    let open_parent = Range::between(0.0, 10.0)
        .unwrap()
        .min_exclusive()
        .max_exclusive();
    let inner_child = Range::between(1.0, 9.0).unwrap();
    let verdict = open_parent.validate_attenuation(&inner_child);
    assert!(
        verdict.is_ok(),
        "Should allow: child bounds strictly inside parent exclusive range"
    );
}
/// An `Exact` value inside an inclusive `Range`, endpoints included, is a
/// valid attenuation.
#[test]
fn test_range_to_exact_valid() {
    let parent = Constraint::Range(Range::between(0.0, 100.0).unwrap());
    let cases = [
        (50, "Should allow: Exact(50) is within Range(0, 100)"),
        (0, "Should allow: Exact(0) at inclusive min bound"),
        (100, "Should allow: Exact(100) at inclusive max bound"),
    ];
    for (value, reason) in cases {
        let child = Constraint::Exact(Exact::new(value));
        assert!(parent.validate_attenuation(&child).is_ok(), "{}", reason);
    }
}

/// An `Exact` value outside the parent `Range` is rejected.
#[test]
fn test_range_to_exact_invalid() {
    let parent = Constraint::Range(Range::between(0.0, 100.0).unwrap());
    let cases = [
        (-1, "Should reject: Exact(-1) below Range(0, 100)"),
        (150, "Should reject: Exact(150) above Range(0, 100)"),
    ];
    for (value, reason) in cases {
        let child = Constraint::Exact(Exact::new(value));
        assert!(parent.validate_attenuation(&child).is_err(), "{}", reason);
    }
}

/// With exclusive bounds the endpoints themselves are out of range for
/// `Exact`, while interior values are still accepted.
#[test]
fn test_range_exclusive_to_exact_boundary() {
    let parent = Constraint::Range(
        Range::between(0.0, 100.0)
            .unwrap()
            .min_exclusive()
            .max_exclusive(),
    );
    let endpoints = [
        (0, "Should reject: Exact(0) at exclusive min bound"),
        (100, "Should reject: Exact(100) at exclusive max bound"),
    ];
    for (value, reason) in endpoints {
        let child = Constraint::Exact(Exact::new(value));
        assert!(parent.validate_attenuation(&child).is_err(), "{}", reason);
    }
    let interior = Constraint::Exact(Exact::new(50));
    assert!(
        parent.validate_attenuation(&interior).is_ok(),
        "Should allow: Exact(50) inside exclusive range"
    );
}
/// A child `Subset` may drop allowed values but never introduce new ones.
#[test]
fn test_subset_attenuation() {
    let parent = Subset::new(["a", "b", "c"]);
    let shrunk = Subset::new(["a", "b"]);
    let introduces_d = Subset::new(["a", "d"]);
    assert!(parent.validate_attenuation(&shrunk).is_ok());
    assert!(parent.validate_attenuation(&introduces_d).is_err());
}

/// Each key of a constraint set is checked independently; here every child
/// constraint narrows its parent counterpart.
#[test]
fn test_constraint_set_validation() {
    let mut parent_set = ConstraintSet::new();
    parent_set.insert("cluster", Pattern::new("staging-*").unwrap());
    parent_set.insert("version", Pattern::new("1.28.*").unwrap());

    let mut child_set = ConstraintSet::new();
    child_set.insert("cluster", Exact::new("staging-web"));
    child_set.insert("version", Pattern::new("1.28.5").unwrap());

    let verdict = parent_set.validate_attenuation(&child_set);
    assert!(verdict.is_ok());
}
/// A glob parent narrows to an `Exact` literal only if the glob matches it.
#[test]
fn test_cross_type_attenuation_pattern_to_exact() {
    let parent = Constraint::Pattern(Pattern::new("staging-*").unwrap());
    let in_scope = Constraint::Exact(Exact::new("staging-web"));
    let out_of_scope = Constraint::Exact(Exact::new("prod-web"));
    assert!(parent.validate_attenuation(&in_scope).is_ok());
    assert!(parent.validate_attenuation(&out_of_scope).is_err());
}

/// A `OneOf` parent narrows to an `Exact` value only if that value is listed.
#[test]
fn test_cross_type_attenuation_oneof_to_exact() {
    let parent = Constraint::OneOf(OneOf::new(vec!["upgrade", "restart", "scale"]));
    let listed = Constraint::Exact(Exact::new("upgrade"));
    let unlisted = Constraint::Exact(Exact::new("delete"));
    assert!(parent.validate_attenuation(&listed).is_ok());
    assert!(parent.validate_attenuation(&unlisted).is_err());
}

/// Unsupported parent/child type pairs are rejected, with an
/// "incompatible constraint types" message for Pattern -> OneOf.
#[test]
fn test_cross_type_attenuation_incompatible_types() {
    let glob_parent = Constraint::Pattern(Pattern::new("*").unwrap());
    let oneof_child = Constraint::OneOf(OneOf::new(vec!["upgrade", "restart"]));
    let message = glob_parent
        .validate_attenuation(&oneof_child)
        .err()
        .expect("Pattern -> OneOf must be rejected")
        .to_string();
    assert!(message.contains("incompatible constraint types"));

    let range_parent = Constraint::Range(Range::max(1000.0).unwrap());
    let glob_child = Constraint::Pattern(Pattern::new("*").unwrap());
    assert!(range_parent.validate_attenuation(&glob_child).is_err());
}
/// A child may not add a key ("action") that a non-empty parent set does not
/// constrain.
#[test]
fn test_adding_key_to_nonempty_parent_rejected() {
    let mut parent = ConstraintSet::new();
    parent.insert("cluster", Pattern::new("staging-*").unwrap());

    let mut child = ConstraintSet::new();
    child.insert("cluster", Exact::new("staging-web"));
    child.insert("action", OneOf::new(vec!["upgrade", "restart"]));

    let outcome = parent.validate_attenuation(&child);
    assert!(outcome.is_err());
}

/// An empty parent set places no restrictions, so the child may add any keys.
#[test]
fn test_adding_key_to_empty_parent_allowed() {
    let unconstrained_parent = ConstraintSet::new();

    let mut child = ConstraintSet::new();
    child.insert("cluster", Exact::new("staging-web"));
    child.insert("action", OneOf::new(vec!["upgrade", "restart"]));

    let outcome = unconstrained_parent.validate_attenuation(&child);
    assert!(outcome.is_ok());
}
/// `Wildcard` matches values of every `ConstraintValue` kind.
#[test]
fn test_wildcard_matches_everything() {
    let wildcard = Wildcard::new();
    let samples = [
        ConstraintValue::String("anything".to_string()),
        ConstraintValue::Integer(42),
        ConstraintValue::Float(3.5),
        ConstraintValue::Boolean(true),
        ConstraintValue::List(vec![]),
    ];
    for sample in &samples {
        assert!(wildcard.matches(sample).unwrap());
    }
}

/// A `Wildcard` parent can be attenuated to any constraint type, including
/// another `Wildcard`.
#[test]
fn test_wildcard_can_attenuate_to_anything() {
    let parent = Constraint::Wildcard(Wildcard::new());
    let children = [
        Constraint::Pattern(Pattern::new("staging-*").unwrap()),
        Constraint::OneOf(OneOf::new(vec!["upgrade", "restart"])),
        Constraint::Range(Range::max(1000.0).unwrap()),
        Constraint::Exact(Exact::new("specific")),
        Constraint::Contains(Contains::new(vec!["admin"])),
        Constraint::Wildcard(Wildcard::new()),
    ];
    for child in &children {
        assert!(parent.validate_attenuation(child).is_ok());
    }
}

/// Narrower parents can never be widened back to `Wildcard`.
#[test]
fn test_cannot_attenuate_to_wildcard() {
    let wildcard_child = || Constraint::Wildcard(Wildcard::new());

    let glob_parent = Constraint::Pattern(Pattern::new("staging-*").unwrap());
    let message = glob_parent
        .validate_attenuation(&wildcard_child())
        .err()
        .expect("Pattern -> Wildcard must be rejected")
        .to_string();
    assert!(message.contains("cannot attenuate to Wildcard"));

    let oneof_parent = Constraint::OneOf(OneOf::new(vec!["a", "b"]));
    assert!(oneof_parent.validate_attenuation(&wildcard_child()).is_err());
}

/// Inside a constraint set, a `Wildcard` entry may be narrowed to a concrete
/// constraint.
#[test]
fn test_wildcard_in_constraint_set() {
    let mut parent_set = ConstraintSet::new();
    parent_set.insert("cluster", Pattern::new("staging-*").unwrap());
    parent_set.insert("action", Wildcard::new());

    let mut child_set = ConstraintSet::new();
    child_set.insert("cluster", Exact::new("staging-web"));
    child_set.insert("action", OneOf::new(vec!["upgrade", "restart"]));

    assert!(parent_set.validate_attenuation(&child_set).is_ok());
}
/// `NotOneOf` accepts everything except the listed values.
#[test]
fn test_notoneof_matches() {
    let denylist = NotOneOf::new(vec!["prod", "secure"]);
    let permits = |env: &str| {
        denylist
            .matches(&ConstraintValue::String(env.to_string()))
            .unwrap()
    };
    assert!(permits("staging"));
    assert!(permits("dev"));
    assert!(!permits("prod"));
    assert!(!permits("secure"));
}

/// Excluding more values narrows a `NotOneOf`, so it is allowed.
#[test]
fn test_notoneof_attenuation_can_add_exclusions() {
    let parent = NotOneOf::new(vec!["prod"]);
    let stricter = NotOneOf::new(vec!["prod", "secure"]);
    assert!(parent.validate_attenuation(&stricter).is_ok());
}

/// Dropping an exclusion widens a `NotOneOf`; the error says the child
/// "must still exclude" it.
#[test]
fn test_notoneof_attenuation_cannot_remove_exclusions() {
    let parent = NotOneOf::new(vec!["prod", "secure"]);
    let looser = NotOneOf::new(vec!["prod"]);
    let message = parent
        .validate_attenuation(&looser)
        .err()
        .expect("dropping an exclusion must be rejected")
        .to_string();
    assert!(message.contains("must still exclude"));
}
/// OneOf -> NotOneOf is rejected with `IncompatibleConstraintTypes`, even
/// when the exclusion lies inside the allowed set.
#[test]
fn test_oneof_to_notoneof_forbidden() {
    let parent = Constraint::OneOf(OneOf::new(vec!["a", "b", "c", "d"]));
    let child = Constraint::NotOneOf(NotOneOf::new(vec!["b"]));
    let failure = parent
        .validate_attenuation(&child)
        .err()
        .expect("OneOf -> NotOneOf must be rejected");
    match failure {
        Error::IncompatibleConstraintTypes {
            parent_type,
            child_type,
        } => {
            assert_eq!(parent_type, "OneOf");
            assert_eq!(child_type, "NotOneOf");
        }
        e => panic!("Expected IncompatibleConstraintTypes, got {:?}", e),
    }
}

/// The same type error applies when the exclusion covers every allowed value.
#[test]
fn test_oneof_to_notoneof_full_exclusion_also_forbidden() {
    let parent = Constraint::OneOf(OneOf::new(vec!["a", "b"]));
    let child = Constraint::NotOneOf(NotOneOf::new(vec!["a", "b"]));
    let failure = parent
        .validate_attenuation(&child)
        .err()
        .expect("OneOf -> NotOneOf must be rejected");
    match failure {
        Error::IncompatibleConstraintTypes {
            parent_type,
            child_type,
        } => {
            assert_eq!(parent_type, "OneOf");
            assert_eq!(child_type, "NotOneOf");
        }
        e => panic!("Expected IncompatibleConstraintTypes, got {:?}", e),
    }
}

/// A `Wildcard` parent may be attenuated to a `NotOneOf`.
#[test]
fn test_wildcard_to_notoneof() {
    let (parent, child) = (
        Constraint::Wildcard(Wildcard::new()),
        Constraint::NotOneOf(NotOneOf::new(vec!["prod", "secure"])),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Wrapped in `Constraint`, NotOneOf -> stricter NotOneOf is still accepted.
#[test]
fn test_notoneof_to_notoneof() {
    let (parent, child) = (
        Constraint::NotOneOf(NotOneOf::new(vec!["prod"])),
        Constraint::NotOneOf(NotOneOf::new(vec!["prod", "secure"])),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}
/// An IPv4 CIDR keeps its textual form.
#[test]
fn test_cidr_creation_ipv4() {
    let parsed = Cidr::new("10.0.0.0/8").unwrap();
    assert_eq!(parsed.cidr_string, "10.0.0.0/8");
}

/// An IPv6 CIDR keeps its textual form.
#[test]
fn test_cidr_creation_ipv6() {
    let parsed = Cidr::new("2001:db8::/32").unwrap();
    assert_eq!(parsed.cidr_string, "2001:db8::/32");
}

/// Garbage text, an over-long IPv4 prefix and an out-of-range octet are
/// all rejected.
#[test]
fn test_cidr_invalid() {
    for bad in ["not-a-cidr", "10.0.0.0/33", "256.0.0.0/8"] {
        assert!(Cidr::new(bad).is_err());
    }
}

/// IPv4 membership: addresses under 10/8 match, others do not.
#[test]
fn test_cidr_contains_ip() {
    let net = Cidr::new("10.0.0.0/8").unwrap();
    for member in ["10.0.0.1", "10.255.255.255", "10.1.2.3"] {
        assert!(net.contains_ip(member).unwrap());
    }
    for outsider in ["192.168.1.1", "11.0.0.1"] {
        assert!(!net.contains_ip(outsider).unwrap());
    }
}

/// IPv6 membership includes the last address of the prefix.
#[test]
fn test_cidr_contains_ip_ipv6() {
    let net = Cidr::new("2001:db8::/32").unwrap();
    assert!(net.contains_ip("2001:db8::1").unwrap());
    assert!(net
        .contains_ip("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff")
        .unwrap());
    assert!(!net.contains_ip("2001:db9::1").unwrap());
}

/// Through the `matches` API, string IPs are checked and non-strings simply
/// do not match.
#[test]
fn test_cidr_matches_constraint_value() {
    let lan = Cidr::new("192.168.0.0/16").unwrap();
    let inside = ConstraintValue::String("192.168.1.100".to_string());
    let outside = ConstraintValue::String("10.0.0.1".to_string());
    let not_a_string = ConstraintValue::Integer(123);
    assert!(lan.matches(&inside).unwrap());
    assert!(!lan.matches(&outside).unwrap());
    assert!(!lan.matches(&not_a_string).unwrap());
}
/// A /16 inside a /8 is a valid narrowing.
#[test]
fn test_cidr_attenuation_valid_subnet() {
    let (parent, child) = (
        Cidr::new("10.0.0.0/8").unwrap(),
        Cidr::new("10.1.0.0/16").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Re-issuing the identical network is allowed.
#[test]
fn test_cidr_attenuation_same_network() {
    let (parent, child) = (
        Cidr::new("10.0.0.0/8").unwrap(),
        Cidr::new("10.0.0.0/8").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// A longer prefix within the parent network is allowed.
#[test]
fn test_cidr_attenuation_narrower_prefix() {
    let (parent, child) = (
        Cidr::new("192.168.0.0/16").unwrap(),
        Cidr::new("192.168.1.0/24").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// A disjoint network is not an attenuation.
#[test]
fn test_cidr_attenuation_invalid_not_subset() {
    let (parent, child) = (
        Cidr::new("10.0.0.0/8").unwrap(),
        Cidr::new("192.168.0.0/16").unwrap(),
    );
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}

/// A shorter (wider) prefix than the parent is rejected.
#[test]
fn test_cidr_attenuation_invalid_wider() {
    let (parent, child) = (
        Cidr::new("10.1.0.0/16").unwrap(),
        Cidr::new("10.0.0.0/8").unwrap(),
    );
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}

/// Cidr -> Cidr narrowing also works through the `Constraint` wrapper.
#[test]
fn test_cidr_constraint_attenuation() {
    let parent = Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap());
    let child = Constraint::Cidr(Cidr::new("10.1.0.0/16").unwrap());
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}

/// A Cidr parent may be narrowed to an `Exact` address inside it.
#[test]
fn test_cidr_to_exact_attenuation() {
    let parent = Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap());
    let inside_ip = Constraint::Exact(Exact::new("10.1.2.3"));
    assert!(parent.validate_attenuation(&inside_ip).is_ok());
}

/// An `Exact` address outside the parent network is rejected.
#[test]
fn test_cidr_to_exact_attenuation_invalid() {
    let parent = Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap());
    let outside_ip = Constraint::Exact(Exact::new("192.168.1.1"));
    let verdict = parent.validate_attenuation(&outside_ip);
    assert!(verdict.is_err());
}

/// A `Wildcard` parent may be narrowed to any network.
#[test]
fn test_wildcard_to_cidr_attenuation() {
    let parent = Constraint::Wildcard(Wildcard::new());
    let network = Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap());
    assert!(parent.validate_attenuation(&network).is_ok());
}
/// A /32 contains exactly one address.
#[test]
fn test_cidr_single_ip_prefix32() {
    let host = Cidr::new("192.168.1.100/32").unwrap();
    assert!(host.contains_ip("192.168.1.100").unwrap());
    for neighbour in ["192.168.1.101", "192.168.1.99"] {
        assert!(!host.contains_ip(neighbour).unwrap());
    }
}

/// 0.0.0.0/0 contains every IPv4 address.
#[test]
fn test_cidr_all_ips_prefix0() {
    let everything = Cidr::new("0.0.0.0/0").unwrap();
    for ip in ["192.168.1.1", "10.0.0.1", "255.255.255.255"] {
        assert!(everything.contains_ip(ip).unwrap());
    }
}

/// An address of the other IP family is reported as not contained.
#[test]
fn test_cidr_ipv4_ipv6_mismatch() {
    let v4_net = Cidr::new("10.0.0.0/8").unwrap();
    let v6_net = Cidr::new("2001:db8::/32").unwrap();
    assert!(!v4_net.contains_ip("2001:db8::1").unwrap());
    assert!(!v6_net.contains_ip("10.0.0.1").unwrap());
}

/// Strings that are not valid IP addresses produce an error.
#[test]
fn test_cidr_invalid_ip_string() {
    let net = Cidr::new("10.0.0.0/8").unwrap();
    for garbage in ["not-an-ip", "", "256.0.0.1", "10.0.0"] {
        assert!(net.contains_ip(garbage).is_err());
    }
}
/// A Cidr serializes to JSON as a plain string and round-trips with working
/// membership checks.
#[test]
fn test_cidr_serialization_roundtrip() {
    let original = Cidr::new("192.168.0.0/16").unwrap();
    let encoded = serde_json::to_string(&original).unwrap();
    assert_eq!(encoded, "\"192.168.0.0/16\"");

    let restored: Cidr = serde_json::from_str(&encoded).unwrap();
    assert_eq!(restored.cidr_string, original.cidr_string);
    assert!(restored.contains_ip("192.168.1.1").unwrap());
    assert!(!restored.contains_ip("10.0.0.1").unwrap());
}

/// In CBOR, the byte after the array header carries the CIDR type id, and
/// decoding restores the network.
#[test]
fn test_cidr_constraint_serialization() {
    let constraint = Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap());
    let mut encoded = Vec::new();
    ciborium::ser::into_writer(&constraint, &mut encoded).unwrap();
    assert!(encoded.len() > 2);
    assert_eq!(encoded[1], constraint_type_id::CIDR);

    let decoded: Constraint = ciborium::de::from_reader(&encoded[..]).unwrap();
    match decoded {
        Constraint::Cidr(network) => assert_eq!(network.cidr_string, "10.0.0.0/8"),
        other => panic!("Expected Cidr constraint, got {:?}", other),
    }
}
/// Every constraint variant encodes as a two-element CBOR array whose first
/// element is its type id, and decodes back to the same variant.
#[test]
fn test_all_constraint_type_ids_wire_format() {
    use constraint_type_id::*;

    // (constraint, expected wire type id, label used in failure messages)
    let cases: Vec<(Constraint, u8, &str)> = vec![
        (Constraint::Exact(Exact::new("test")), EXACT, "Exact"),
        (
            Constraint::Pattern(Pattern::new("test-*").unwrap()),
            PATTERN,
            "Pattern",
        ),
        (
            Constraint::Range(Range::new(Some(0.0), Some(100.0)).unwrap()),
            RANGE,
            "Range",
        ),
        (
            Constraint::OneOf(OneOf::new(vec!["a".to_string(), "b".to_string()])),
            ONE_OF,
            "OneOf",
        ),
        (
            Constraint::Regex(RegexConstraint::new("^test$").unwrap()),
            REGEX,
            "Regex",
        ),
        (
            Constraint::NotOneOf(NotOneOf::new(vec!["x".to_string()])),
            NOT_ONE_OF,
            "NotOneOf",
        ),
        (
            Constraint::Cidr(Cidr::new("10.0.0.0/8").unwrap()),
            CIDR,
            "Cidr",
        ),
        (
            Constraint::UrlPattern(UrlPattern::new("https://example.com/*").unwrap()),
            URL_PATTERN,
            "UrlPattern",
        ),
        (
            Constraint::Contains(Contains::new(vec!["admin".to_string()])),
            CONTAINS,
            "Contains",
        ),
        (
            Constraint::Subset(Subset::new(vec!["a".to_string(), "b".to_string()])),
            SUBSET,
            "Subset",
        ),
        (
            Constraint::All(All {
                constraints: vec![Constraint::Exact(Exact::new("x"))],
            }),
            ALL,
            "All",
        ),
        (
            Constraint::Any(Any {
                constraints: vec![Constraint::Exact(Exact::new("y"))],
            }),
            ANY,
            "Any",
        ),
        (
            Constraint::Not(Not {
                constraint: Box::new(Constraint::Exact(Exact::new("z"))),
            }),
            NOT,
            "Not",
        ),
        (Constraint::Cel(CelConstraint::new("x > 0")), CEL, "Cel"),
        (Constraint::Wildcard(Wildcard::new()), WILDCARD, "Wildcard"),
        (
            Constraint::Subpath(Subpath::new("/data").unwrap()),
            SUBPATH,
            "Subpath",
        ),
        (Constraint::UrlSafe(UrlSafe::new()), URL_SAFE, "UrlSafe"),
    ];

    for (constraint, expected_type_id, name) in cases {
        let mut bytes = Vec::new();
        ciborium::ser::into_writer(&constraint, &mut bytes).unwrap();
        assert!(bytes.len() >= 2, "{}: too short", name);
        // 0x82 is the CBOR header for an array of two items.
        assert_eq!(bytes[0], 0x82, "{}: not a 2-element array", name);
        assert_eq!(bytes[1], expected_type_id, "{}: wrong type ID", name);

        let decoded: Constraint = ciborium::de::from_reader(&bytes[..]).unwrap();
        assert_eq!(
            std::mem::discriminant(&constraint),
            std::mem::discriminant(&decoded),
            "{}: discriminant mismatch after round-trip",
            name
        );
    }
}
/// An unrecognized type id decodes into `Constraint::Unknown`, keeping both
/// the id and the raw payload bytes.
#[test]
fn test_unknown_constraint_type_id() {
    let raw_payload: Vec<u8> = vec![1, 2, 3, 4];
    let mut wire = Vec::new();
    ciborium::ser::into_writer(
        &(200u8, serde_bytes::Bytes::new(&raw_payload)),
        &mut wire,
    )
    .unwrap();

    let decoded: Constraint = ciborium::de::from_reader(&wire[..]).unwrap();
    match decoded {
        Constraint::Unknown { type_id, payload } => {
            assert_eq!(type_id, 200);
            assert_eq!(payload, raw_payload);
        }
        other => panic!("Expected Unknown variant, got {:?}", other),
    }
}
/// An identical /32 is a valid attenuation of itself.
#[test]
fn test_cidr_attenuation_prefix32_to_prefix32() {
    let host = "10.1.2.3/32";
    let (parent, child) = (Cidr::new(host).unwrap(), Cidr::new(host).unwrap());
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// A /8 may be narrowed down to one of its single hosts.
#[test]
fn test_cidr_attenuation_to_single_ip() {
    let network = Cidr::new("10.0.0.0/8").unwrap();
    let single_host = Cidr::new("10.1.2.3/32").unwrap();
    assert!(network.validate_attenuation(&single_host).is_ok());
}

/// IPv6 prefixes narrow the same way as IPv4 ones.
#[test]
fn test_cidr_ipv6_attenuation() {
    let wide = Cidr::new("2001:db8::/32").unwrap();
    let narrow = Cidr::new("2001:db8:1::/48").unwrap();
    assert!(wide.validate_attenuation(&narrow).is_ok());
}

/// The first and last addresses of a /24 are members; their neighbours are not.
#[test]
fn test_cidr_boundary_ips() {
    let subnet = Cidr::new("192.168.1.0/24").unwrap();
    for edge in ["192.168.1.0", "192.168.1.255"] {
        assert!(subnet.contains_ip(edge).unwrap());
    }
    for beyond in ["192.168.0.255", "192.168.2.0"] {
        assert!(!subnet.contains_ip(beyond).unwrap());
    }
}
/// Parsing splits a URL pattern into scheme list, host and path.
#[test]
fn test_url_pattern_creation() {
    let parsed = UrlPattern::new("https://api.example.com/*").unwrap();
    assert_eq!(parsed.schemes, vec!["https"]);
    assert_eq!(parsed.host_pattern.as_deref(), Some("api.example.com"));
    assert_eq!(parsed.path_pattern.as_deref(), Some("/*"));
}

/// A `*` scheme is stored as an empty scheme list.
#[test]
fn test_url_pattern_wildcard_scheme() {
    let any_scheme = UrlPattern::new("*://example.com/api/*").unwrap();
    assert!(any_scheme.schemes.is_empty());
    assert_eq!(any_scheme.host_pattern.as_deref(), Some("example.com"));
}

/// An explicit port is captured.
#[test]
fn test_url_pattern_with_port() {
    let with_port = UrlPattern::new("https://api.example.com:8443/api/*").unwrap();
    assert_eq!(with_port.port, Some(8443));
}

/// A `*.` host prefix is kept verbatim in the host pattern.
#[test]
fn test_url_pattern_wildcard_host() {
    let subdomains = UrlPattern::new("https://*.example.com/*").unwrap();
    assert_eq!(subdomains.host_pattern.as_deref(), Some("*.example.com"));
}
/// Pins the intended meaning of a bare `*` host for when URLP-001 lands:
/// `https://*/*` must accept any host, so both URLs below are expected to
/// match. The `#[ignore]` reason states that the bare-wildcard form is not
/// yet supported.
#[test]
#[ignore = "URLP-001: Bare wildcard host not yet supported - see UrlPattern::new() for details"]
fn test_url_pattern_bare_wildcard_host() {
    let any_host = UrlPattern::new("https://*/*").unwrap();
    // A bare wildcard host is the widest host pattern. Asserting a non-match
    // here would describe a pattern that matches nothing at all, so the
    // expected result is a match.
    assert!(any_host.matches_url("https://example.com/path").unwrap());
    assert!(any_host.matches_url("https://evil.com/attack").unwrap());
}
/// Inputs without a `scheme://` part are rejected.
#[test]
fn test_url_pattern_invalid() {
    for malformed in ["not-a-url", "missing-scheme.com"] {
        assert!(UrlPattern::new(malformed).is_err());
    }
}

/// Scheme and host must match exactly; `/*` accepts any path.
#[test]
fn test_url_pattern_matches_basic() {
    let api = UrlPattern::new("https://api.example.com/*").unwrap();
    let cases = [
        ("https://api.example.com/v1/users", true),
        ("https://api.example.com/", true),
        ("http://api.example.com/v1", false),
        ("https://other.example.com/v1", false),
    ];
    for (url, expected) in cases {
        assert_eq!(api.matches_url(url).unwrap(), expected);
    }
}

/// A `*` scheme accepts both http and https.
#[test]
fn test_url_pattern_matches_wildcard_scheme() {
    let any_scheme = UrlPattern::new("*://api.example.com/*").unwrap();
    for url in ["https://api.example.com/v1", "http://api.example.com/v1"] {
        assert!(any_scheme.matches_url(url).unwrap());
    }
}

/// `*.example.com` accepts subdomains and the apex domain, but not other
/// domains.
#[test]
fn test_url_pattern_matches_wildcard_host() {
    let subdomains = UrlPattern::new("https://*.example.com/*").unwrap();
    let cases = [
        ("https://api.example.com/v1", true),
        ("https://www.example.com/v1", true),
        ("https://example.com/v1", true),
        ("https://api.other.com/v1", false),
    ];
    for (url, expected) in cases {
        assert_eq!(subdomains.matches_url(url).unwrap(), expected);
    }
}

/// An explicit port must match exactly; an omitted port does not match it.
#[test]
fn test_url_pattern_matches_port() {
    let pinned = UrlPattern::new("https://api.example.com:8443/*").unwrap();
    let cases = [
        ("https://api.example.com:8443/v1", true),
        ("https://api.example.com:443/v1", false),
        ("https://api.example.com/v1", false),
    ];
    for (url, expected) in cases {
        assert_eq!(pinned.matches_url(url).unwrap(), expected);
    }
}

/// A path prefix with a trailing `/*` restricts matches to that subtree.
#[test]
fn test_url_pattern_matches_path() {
    let v1_only = UrlPattern::new("https://api.example.com/api/v1/*").unwrap();
    let cases = [
        ("https://api.example.com/api/v1/users", true),
        ("https://api.example.com/api/v1/", true),
        ("https://api.example.com/api/v2/users", false),
        ("https://api.example.com/other", false),
    ];
    for (url, expected) in cases {
        assert_eq!(v1_only.matches_url(url).unwrap(), expected);
    }
}
/// Re-issuing an identical pattern is allowed.
#[test]
fn test_url_pattern_attenuation_same() {
    let (parent, child) = (
        UrlPattern::new("https://api.example.com/*").unwrap(),
        UrlPattern::new("https://api.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Narrowing the path is allowed.
#[test]
fn test_url_pattern_attenuation_narrower_path() {
    let (parent, child) = (
        UrlPattern::new("https://api.example.com/*").unwrap(),
        UrlPattern::new("https://api.example.com/api/v1/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Narrowing a wildcard host to a concrete subdomain is allowed.
#[test]
fn test_url_pattern_attenuation_narrower_host() {
    let (parent, child) = (
        UrlPattern::new("https://*.example.com/*").unwrap(),
        UrlPattern::new("https://api.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Pinning a wildcard scheme to https is allowed.
#[test]
fn test_url_pattern_attenuation_add_scheme() {
    let (parent, child) = (
        UrlPattern::new("*://api.example.com/*").unwrap(),
        UrlPattern::new("https://api.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// Switching to a scheme the parent does not allow is rejected.
#[test]
fn test_url_pattern_attenuation_invalid_scheme_expansion() {
    let (parent, child) = (
        UrlPattern::new("https://api.example.com/*").unwrap(),
        UrlPattern::new("http://api.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_err());
}

/// Widening a concrete scheme to `*` (an empty scheme list) is rejected.
#[test]
fn test_url_pattern_attenuation_wildcard_scheme_widening_blocked() {
    let https_only = UrlPattern::new("https://api.example.com/*").unwrap();
    let any_scheme = UrlPattern::new("*://api.example.com/*").unwrap();
    assert!(any_scheme.schemes.is_empty());
    assert!(https_only.validate_attenuation(&any_scheme).is_err());
}

/// Going from a `*` scheme to a specific one is allowed.
#[test]
fn test_url_pattern_attenuation_wildcard_to_specific_allowed() {
    let any_scheme = UrlPattern::new("*://api.example.com/*").unwrap();
    let https_only = UrlPattern::new("https://api.example.com/*").unwrap();
    assert!(any_scheme.validate_attenuation(&https_only).is_ok());
}

/// Widening a concrete host to a wildcard host is rejected.
#[test]
fn test_url_pattern_attenuation_invalid_host_expansion() {
    let (parent, child) = (
        UrlPattern::new("https://api.example.com/*").unwrap(),
        UrlPattern::new("https://*.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_err());
}

/// Widening the path prefix is rejected.
#[test]
fn test_url_pattern_attenuation_invalid_path_expansion() {
    let (parent, child) = (
        UrlPattern::new("https://api.example.com/api/v1/*").unwrap(),
        UrlPattern::new("https://api.example.com/*").unwrap(),
    );
    assert!(parent.validate_attenuation(&child).is_err());
}
/// UrlPattern narrowing also works through the `Constraint` wrapper.
#[test]
fn test_url_constraint_attenuation() {
    let wide = UrlPattern::new("https://*.example.com/*").unwrap();
    let narrow = UrlPattern::new("https://api.example.com/v1/*").unwrap();
    let parent = Constraint::UrlPattern(wide);
    let child = Constraint::UrlPattern(narrow);
    assert!(parent.validate_attenuation(&child).is_ok());
}

/// A UrlPattern parent may be narrowed to an `Exact` URL it matches.
#[test]
fn test_url_to_exact_attenuation() {
    let parent = Constraint::UrlPattern(UrlPattern::new("https://api.example.com/*").unwrap());
    let concrete = Constraint::Exact(Exact::new("https://api.example.com/v1/users"));
    assert!(parent.validate_attenuation(&concrete).is_ok());
}

/// An `Exact` URL on a different host is rejected.
#[test]
fn test_url_to_exact_attenuation_invalid() {
    let parent = Constraint::UrlPattern(UrlPattern::new("https://api.example.com/*").unwrap());
    let foreign = Constraint::Exact(Exact::new("https://other.example.com/v1"));
    assert!(parent.validate_attenuation(&foreign).is_err());
}

/// A `Wildcard` parent may be narrowed to a UrlPattern.
#[test]
fn test_wildcard_to_url_attenuation() {
    let parent = Constraint::Wildcard(Wildcard::new());
    let scoped = Constraint::UrlPattern(UrlPattern::new("https://api.example.com/*").unwrap());
    assert!(parent.validate_attenuation(&scoped).is_ok());
}

/// The JSON form contains the source pattern and round-trips it.
#[test]
fn test_url_pattern_serialization_roundtrip() {
    let original = UrlPattern::new("https://api.example.com/v1/*").unwrap();
    let encoded = serde_json::to_string(&original).unwrap();
    assert!(encoded.contains("https://api.example.com/v1/*"));
    let restored: UrlPattern = serde_json::from_str(&encoded).unwrap();
    assert_eq!(restored.pattern, original.pattern);
}

/// Inputs containing the internal wildcard placeholder tokens are rejected,
/// while an ordinary pattern is still accepted.
#[test]
fn test_url_pattern_placeholder_collision_rejected() {
    let smuggled = [
        "https://__tenuo_host_wildcard__.evil.com/*",
        "https://evil.com/__tenuo_path_wildcard__",
    ];
    for input in smuggled {
        assert!(UrlPattern::new(input).is_err());
    }
    assert!(UrlPattern::new("https://api.example.com/*").is_ok());
}
/// An `Unknown` constraint refuses both matching and attenuation.
#[test]
fn test_unknown_constraint_behavior() {
    let opaque = Constraint::Unknown {
        type_id: 55,
        payload: vec![0, 1, 2, 3],
    };
    let probe = ConstraintValue::String("test".into());
    assert!(opaque.matches(&probe).is_err());
    assert!(opaque.validate_attenuation(&opaque).is_err());
}
/// An empty constraint set accepts arbitrary, undeclared argument names.
#[test]
fn test_zero_trust_empty_constraint_set_allows_unknown_fields() {
    let open = ConstraintSet::new();
    let args = HashMap::from([
        (
            "url".to_string(),
            ConstraintValue::String("https://example.com".into()),
        ),
        ("timeout".to_string(), ConstraintValue::Integer(30)),
        (
            "anything".to_string(),
            ConstraintValue::String("whatever".into()),
        ),
    ]);
    assert!(open.matches(&args).is_ok());
}

/// Once one field is constrained, an undeclared field is rejected. The error
/// names the field and says it is unknown / not allowed.
#[test]
fn test_zero_trust_one_constraint_rejects_unknown_fields() {
    let mut closed = ConstraintSet::new();
    closed.insert("url", Pattern::new("https://*").unwrap());

    let mut args = HashMap::new();
    args.insert(
        "url".to_string(),
        ConstraintValue::String("https://example.com".into()),
    );
    assert!(closed.matches(&args).is_ok());

    args.insert("timeout".to_string(), ConstraintValue::Integer(30));
    let rendered = closed
        .matches(&args)
        .err()
        .expect("undeclared field must be rejected")
        .to_string();
    assert!(rendered.contains("timeout"));
    assert!(rendered.contains("unknown") || rendered.contains("not allowed"));
}

/// A `Wildcard` entry declares a field with any value. Other undeclared fields
/// are still rejected.
#[test]
fn test_zero_trust_wildcard_allows_any_value_for_field() {
    let mut declared = ConstraintSet::new();
    declared.insert("url", Pattern::new("https://*").unwrap());
    declared.insert("timeout", Wildcard::new());

    let mut args = HashMap::new();
    args.insert(
        "url".to_string(),
        ConstraintValue::String("https://example.com".into()),
    );
    args.insert("timeout".to_string(), ConstraintValue::Integer(9999));
    assert!(declared.matches(&args).is_ok());

    args.insert("retries".to_string(), ConstraintValue::Integer(3));
    assert!(declared.matches(&args).is_err());
}

/// `set_allow_unknown(true)` explicitly re-opens a constrained set to
/// undeclared fields.
#[test]
fn test_zero_trust_allow_unknown_explicit_opt_out() {
    let mut relaxed = ConstraintSet::new();
    relaxed.insert("url", Pattern::new("https://*").unwrap());
    relaxed.set_allow_unknown(true);

    let args = HashMap::from([
        (
            "url".to_string(),
            ConstraintValue::String("https://example.com".into()),
        ),
        ("timeout".to_string(), ConstraintValue::Integer(30)),
        (
            "anything".to_string(),
            ConstraintValue::String("whatever".into()),
        ),
    ]);
    assert!(relaxed.matches(&args).is_ok());
}
#[test]
fn test_zero_trust_allow_unknown_still_enforces_defined_constraints() {
let mut cs = ConstraintSet::new();
cs.insert("url", Pattern::new("https://*").unwrap());
cs.set_allow_unknown(true);
let mut args = HashMap::new();
args.insert(
"url".to_string(),
ConstraintValue::String("http://insecure.com".into()),
); args.insert(
"anything".to_string(),
ConstraintValue::String("allowed".into()),
);
assert!(cs.matches(&args).is_err());
}
#[test]
fn test_zero_trust_attenuation_allow_unknown_not_inherited() {
    let parent = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://*").unwrap());
        cs.set_allow_unknown(true);
        cs
    };
    let child = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://api.example.com/*").unwrap());
        cs
    };

    // The narrower child is a valid attenuation, but it does not pick up the
    // parent's allow_unknown flag...
    assert!(parent.validate_attenuation(&child).is_ok());
    assert!(!child.allow_unknown());

    // ...so the undeclared `timeout` field makes the child reject.
    let args = HashMap::from([
        (
            "url".to_string(),
            ConstraintValue::String("https://api.example.com/v1".into()),
        ),
        ("timeout".to_string(), ConstraintValue::Integer(30)),
    ]);
    assert!(child.matches(&args).is_err());
}
#[test]
fn test_zero_trust_attenuation_child_cannot_enable_allow_unknown() {
    // A closed parent may not hand out a child that opens up unknown fields,
    // even when the child's declared constraint is narrower.
    let parent = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://*").unwrap());
        cs
    };
    let child = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://api.example.com/*").unwrap());
        cs.set_allow_unknown(true);
        cs
    };
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}
#[test]
fn test_zero_trust_attenuation_parent_open_child_can_close() {
    // An empty parent set accepts a child that adds a `url` constraint.
    let parent = ConstraintSet::new();
    let child = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://*").unwrap());
        cs
    };
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[test]
fn test_zero_trust_serialization_roundtrip() {
    // The allow_unknown flag and declared constraints both survive a JSON
    // round trip.
    let original = {
        let mut cs = ConstraintSet::new();
        cs.insert("url", Pattern::new("https://*").unwrap());
        cs.set_allow_unknown(true);
        cs
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let restored: ConstraintSet = serde_json::from_str(&encoded).unwrap();

    assert!(restored.allow_unknown());
    assert!(restored.get("url").is_some());
}
#[test]
fn test_zero_trust_default_allow_unknown_is_false() {
    // Neither an empty set nor one with a constraint opts into unknown fields
    // by default.
    let empty = ConstraintSet::new();
    assert!(!empty.allow_unknown());

    let mut populated = ConstraintSet::new();
    populated.insert("url", Pattern::new("https://*").unwrap());
    assert!(!populated.allow_unknown());
}
#[test]
fn test_subpath_basic_containment() {
    let sp = Subpath::new("/data").unwrap();

    // The root itself and anything beneath it are contained.
    for inside in ["/data/file.txt", "/data/subdir/file.txt", "/data"] {
        assert!(sp.contains_path(inside).unwrap());
    }
    // Unrelated paths and a sibling sharing the `/data` prefix are not.
    for outside in ["/etc/passwd", "/data2/file.txt"] {
        assert!(!sp.contains_path(outside).unwrap());
    }
}
#[test]
fn test_subpath_traversal_blocking() {
    let sp = Subpath::new("/data").unwrap();

    // `..` segments that climb out of /data are rejected.
    for escaping in ["/data/../etc/passwd", "/data/subdir/../../etc/passwd"] {
        assert!(!sp.contains_path(escaping).unwrap());
    }
    // A `..` that stays inside /data is still accepted.
    let stays_inside = "/data/subdir/../file.txt";
    assert!(sp.contains_path(stays_inside).unwrap());
}
#[test]
fn test_subpath_null_bytes() {
    // An embedded NUL byte makes the path not contained, even under /data.
    let sp = Subpath::new("/data").unwrap();
    let with_nul = "/data/file\x00.txt";
    let contained = sp.contains_path(with_nul).unwrap();
    assert!(!contained);
}
#[test]
fn test_subpath_relative_path_rejected() {
    // Relative paths are never considered contained.
    let sp = Subpath::new("/data").unwrap();
    for relative in ["data/file.txt", "./file.txt"] {
        assert!(!sp.contains_path(relative).unwrap());
    }
}
#[test]
fn test_subpath_matches() {
    // `matches` accepts string values under the root and rejects both
    // outside paths and non-string values.
    let sp = Subpath::new("/data").unwrap();

    let inside = sp.matches(&"/data/file.txt".into()).unwrap();
    assert!(inside);

    let outside = sp.matches(&"/etc/passwd".into()).unwrap();
    assert!(!outside);

    let non_string = sp.matches(&123.into()).unwrap();
    assert!(!non_string);
}
#[test]
fn test_url_safe_basic() {
    // Ordinary public http/https URLs pass the default UrlSafe checks.
    let us = UrlSafe::new();
    for url in ["https://api.github.com/repos", "http://example.com/path"] {
        assert!(us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_blocks_loopback() {
    // IPv4, hostname and IPv6 loopback forms are all rejected.
    let us = UrlSafe::new();
    for url in ["http://127.0.0.1/", "http://localhost/", "http://[::1]/"] {
        assert!(!us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_blocks_private_ips() {
    // One address from each RFC 1918 range (10/8, 172.16/12, 192.168/16).
    let us = UrlSafe::new();
    let private_urls = [
        "http://10.0.0.1/admin",
        "http://172.16.0.1/",
        "http://192.168.1.1/admin",
    ];
    for url in private_urls {
        assert!(!us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_blocks_metadata() {
    // Cloud metadata endpoints, by link-local IP and by hostname.
    let us = UrlSafe::new();
    let metadata_urls = [
        "http://169.254.169.254/latest/meta-data/",
        "http://metadata.google.internal/",
    ];
    for url in metadata_urls {
        assert!(!us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_blocks_decimal_ip() {
    // 2130706433 == 0x7F000001, i.e. 127.0.0.1 written as a single integer.
    let us = UrlSafe::new();
    let decimal_loopback = "http://2130706433/";
    assert!(!us.is_safe(decimal_loopback).unwrap());
}
#[test]
fn test_url_safe_blocks_hex_ip() {
    // 0x7f000001 is 127.0.0.1 written as a single hexadecimal integer.
    let us = UrlSafe::new();
    let hex_loopback = "http://0x7f000001/";
    assert!(!us.is_safe(hex_loopback).unwrap());
}
#[test]
fn test_url_safe_empty_host() {
    let us = UrlSafe::new();

    // Accepted by the current implementation. NOTE(review): presumably the URL
    // parser treats `https:///path` as having host `path` rather than an empty
    // host; confirm this is the intended behavior.
    assert!(us.is_safe("https:///path").unwrap());

    // A scheme with no host, and a string that is not a URL, are rejected.
    for bad in ["http://", "not-a-url"] {
        assert!(!us.is_safe(bad).unwrap());
    }
}
#[test]
fn test_url_safe_null_bytes() {
    // A NUL byte embedded in the host must not be accepted.
    let us = UrlSafe::new();
    let nul_in_host = "https://evil.com\x00.trusted.com/";
    let safe = us.is_safe(nul_in_host).unwrap();
    assert!(!safe);
}
#[test]
fn test_url_safe_scheme_blocking() {
    // Non-http(s) schemes are rejected by default.
    let us = UrlSafe::new();
    let blocked = [
        "file:///etc/passwd",
        "gopher://evil.com/",
        "ftp://example.com/",
    ];
    for url in blocked {
        assert!(!us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_domain_allowlist() {
    // Exact and wildcard allow-list entries; anything else is refused.
    let us = UrlSafe::with_domains(vec!["api.github.com", "*.example.com"]);
    let cases = [
        ("https://api.github.com/repos", true),
        ("https://sub.example.com/path", true),
        ("https://other.com/", false),
    ];
    for (url, expected) in cases {
        assert_eq!(us.is_safe(url).unwrap(), expected);
    }
}
#[test]
fn test_url_safe_port_restriction() {
    // Only explicitly listed ports are permitted when allow_ports is set.
    let us = UrlSafe {
        allow_ports: Some(vec![443, 8443]),
        ..UrlSafe::new()
    };
    let cases = [
        ("https://example.com:443/", true),
        ("https://example.com:8443/", true),
        ("http://example.com:80/", false),
        ("https://example.com:8080/", false),
    ];
    for (url, allowed) in cases {
        assert_eq!(us.is_safe(url).unwrap(), allowed);
    }
}
#[test]
fn test_url_safe_matches() {
    // `matches` mirrors `is_safe` for strings and rejects non-string values.
    let us = UrlSafe::new();

    let public = us.matches(&"https://api.github.com/".into()).unwrap();
    assert!(public);

    let loopback = us.matches(&"http://127.0.0.1/".into()).unwrap();
    assert!(!loopback);

    let non_string = us.matches(&123.into()).unwrap();
    assert!(!non_string);
}
#[test]
fn test_url_safe_blocks_ipv4_compatible_ipv6() {
    // Blocked IPv4 targets must stay blocked when embedded in IPv6 literals,
    // both IPv4-compatible (::a.b.c.d) and IPv4-mapped (::ffff:a.b.c.d).
    let us = UrlSafe::new();
    let disguised = [
        "http://[::127.0.0.1]/",
        "http://[0:0:0:0:0:0:127.0.0.1]/",
        "http://[::10.0.0.1]/",
        "http://[::172.16.0.1]/",
        "http://[::192.168.1.1]/",
        "http://[::169.254.169.254]/",
        "http://[::ffff:127.0.0.1]/",
        "http://[::ffff:10.0.0.1]/",
        "http://[::1]/",
    ];
    for url in disguised {
        assert!(!us.is_safe(url).unwrap());
    }
}
#[test]
fn test_url_safe_octal_ip_normalization() {
    // Leading-zero octets are read as octal: 010 -> 8 (public),
    // 0177 -> 127 (loopback), 012 -> 10 (private). Hostnames that merely
    // start with digits are left alone.
    let us = UrlSafe::new();
    let cases = [
        ("http://010.0.0.1/", true),
        ("http://0177.0.0.1/", false),
        ("http://012.0.0.1/", false),
        ("http://10.0.0.1/", false),
        ("http://8.0.0.1/", true),
        ("https://example.com/", true),
        ("https://10example.com/", true),
    ];
    for (url, expected_safe) in cases {
        assert_eq!(us.is_safe(url).unwrap(), expected_safe);
    }
}
#[test]
fn test_url_safe_deny_domains_basic() {
    // Exact and wildcard deny entries; unrelated domains remain allowed.
    let us = UrlSafe {
        deny_domains: Some(vec!["evil.com".to_string(), "*.malware.org".to_string()]),
        ..UrlSafe::new()
    };
    let cases = [
        ("https://evil.com/payload", false),
        ("https://sub.malware.org/c2", false),
        ("https://example.com/", true),
    ];
    for (url, expected) in cases {
        assert_eq!(us.is_safe(url).unwrap(), expected);
    }
}
#[test]
fn test_url_safe_deny_domains_blocks_ip_addresses() {
    // Metadata blocking is switched off here, so the rejection has to come
    // from the deny_domains entry.
    let metadata_denied = UrlSafe {
        deny_domains: Some(vec!["169.254.169.254".to_string()]),
        block_metadata: false,
        ..UrlSafe::new()
    };
    assert!(!metadata_denied
        .is_safe("http://169.254.169.254/latest/meta-data/")
        .unwrap());
    assert!(metadata_denied.is_safe("https://example.com/").unwrap());

    // Private-range blocking is switched off, so only the exact denied IP is
    // rejected and its neighbour is allowed.
    let single_ip_denied = UrlSafe {
        deny_domains: Some(vec!["10.0.0.1".to_string()]),
        block_private: false,
        ..UrlSafe::new()
    };
    assert!(!single_ip_denied.is_safe("http://10.0.0.1/admin").unwrap());
    assert!(single_ip_denied.is_safe("http://10.0.0.2/admin").unwrap());
}
#[test]
fn test_url_safe_deny_overrides_allow() {
    // evil.example.com matches the allow wildcard but is also denied; the
    // deny entry wins.
    let us = UrlSafe {
        allow_domains: Some(vec!["*.example.com".to_string()]),
        deny_domains: Some(vec!["evil.example.com".to_string()]),
        ..UrlSafe::new()
    };
    let cases = [
        ("https://good.example.com/", true),
        ("https://evil.example.com/", false),
    ];
    for (url, expected) in cases {
        assert_eq!(us.is_safe(url).unwrap(), expected);
    }
}
#[test]
fn test_url_safe_deny_domains_attenuation() {
    // Builds a default UrlSafe whose deny list is exactly `domains`.
    let with_denied = |domains: &[&str]| UrlSafe {
        deny_domains: Some(domains.iter().map(|d| d.to_string()).collect()),
        ..UrlSafe::new()
    };
    let parent = with_denied(&["evil.com"]);

    // A deny list that is a superset of the parent's is a valid narrowing.
    let child_ok = with_denied(&["evil.com", "also-bad.com"]);
    assert!(parent.validate_attenuation(&child_ok).is_ok());

    // Dropping the deny list entirely loses the parent's restriction.
    let child_bad = UrlSafe::new();
    assert!(parent.validate_attenuation(&child_bad).is_err());

    // So does a deny list that omits "evil.com".
    let child_bad2 = with_denied(&["other.com"]);
    assert!(parent.validate_attenuation(&child_bad2).is_err());
}
#[test]
fn test_shlex_basic_allow() {
    // Commands are accepted only when their binary is on the allow list.
    let sh = Shlex::new(vec!["npm", "docker"]);
    let cases = [
        ("npm install express", true),
        ("docker run alpine", true),
        ("rm -rf /", false),
    ];
    for (command, expected) in cases {
        let value = ConstraintValue::String(command.into());
        assert_eq!(sh.matches(&value).unwrap(), expected);
    }
}
#[test]
fn test_shlex_rejects_metacharacters() {
    // Chaining, piping, substitution and redirection are all rejected even
    // though the leading binary is allowed.
    let sh = Shlex::new(vec!["npm"]);
    let injections = [
        "npm install; rm -rf /",
        "npm install | cat",
        "npm install && echo pwned",
        "npm install $(whoami)",
        "npm install `whoami`",
        "npm > /etc/passwd",
        "npm < /etc/shadow",
    ];
    for command in injections {
        let value = ConstraintValue::String(command.into());
        assert!(!sh.matches(&value).unwrap());
    }
}
#[test]
fn test_shlex_rejects_control_chars() {
    // Newline, NUL and carriage return are rejected.
    let sh = Shlex::new(vec!["npm"]);
    for command in ["npm\n rm -rf /", "npm\0evil", "npm\rinstall"] {
        let value = ConstraintValue::String(command.into());
        assert!(!sh.matches(&value).unwrap());
    }
}
#[test]
fn test_shlex_empty_and_non_string() {
    // An empty command and non-string values never match.
    let sh = Shlex::new(vec!["npm"]);
    let rejected = [
        ConstraintValue::String("".into()),
        ConstraintValue::Integer(42),
        ConstraintValue::Null,
    ];
    for value in &rejected {
        assert!(!sh.matches(value).unwrap());
    }
}
#[test]
fn test_shlex_path_binary() {
    // A full-path allow entry matches only that exact path, not the bare name.
    let sh = Shlex::new(vec!["npm", "/usr/bin/git"]);
    let cases = [
        ("/usr/bin/git status", true),
        ("npm install", true),
        ("git status", false),
    ];
    for (command, expected) in cases {
        let value = ConstraintValue::String(command.into());
        assert_eq!(sh.matches(&value).unwrap(), expected);
    }
}
#[test]
fn test_shlex_roundtrip_cbor() {
    let constraint = Constraint::Shlex(Shlex::new(vec!["npm", "docker"]));

    let mut encoded: Vec<u8> = Vec::new();
    ciborium::ser::into_writer(&constraint, &mut encoded).unwrap();

    // CBOR header: 0x82 = array of length 2; 0x18 = unsigned int in the next
    // byte; 0x80 = 128, the SHLEX type id.
    assert_eq!(&encoded[..3], &[0x82u8, 0x18, 0x80]);

    let decoded: Constraint = ciborium::de::from_reader(&encoded[..]).unwrap();
    assert_eq!(constraint, decoded);
}
#[test]
fn test_shlex_attenuation_valid() {
    // A strict subset and an identical allow list are both valid children.
    let parent = Shlex::new(vec!["npm", "docker", "git"]);

    let subset = Shlex::new(vec!["npm"]);
    assert!(parent.validate_attenuation(&subset).is_ok());

    let identical = Shlex::new(vec!["npm", "docker", "git"]);
    assert!(parent.validate_attenuation(&identical).is_ok());
}
#[test]
fn test_shlex_attenuation_invalid_expansion() {
    // A child cannot add a binary (`rm`) that the parent does not allow.
    let parent = Shlex::new(vec!["npm"]);
    let widened = Shlex::new(vec!["npm", "rm"]);
    let verdict = parent.validate_attenuation(&widened);
    assert!(verdict.is_err());
}
#[test]
fn test_shlex_constraint_attenuation_dispatch() {
    // Through the Constraint enum, a Shlex parent accepts narrower Shlex
    // children and Exact children whose command it would allow.
    let parent = Constraint::Shlex(Shlex::new(vec!["npm", "docker"]));
    let cases = [
        (Constraint::Shlex(Shlex::new(vec!["npm"])), true),
        (Constraint::Shlex(Shlex::new(vec!["npm", "rm"])), false),
        (Constraint::Exact(Exact::new("npm install express")), true),
        (Constraint::Exact(Exact::new("rm -rf /")), false),
    ];
    for (child, should_pass) in &cases {
        assert_eq!(parent.validate_attenuation(child).is_ok(), *should_pass);
    }
}
#[test]
fn test_shlex_quoted_operators_rejected_by_rust() {
    // The Rust-side check is conservative: shell operators are rejected even
    // when they appear inside quotes, with an allow-listed binary.
    let sh = Shlex::new(vec!["ls"]);
    assert!(!sh
        .matches(&ConstraintValue::String(r#"ls "foo; bar""#.into()))
        .unwrap());
    assert!(!sh
        .matches(&ConstraintValue::String(r#"ls 'foo|bar'"#.into()))
        .unwrap());
    // Use the allow-listed `ls` so the quoted `$(...)` is what gets tested.
    // The old `echo "$(date)"` input was rejected just because `echo` is
    // not on the allow list, so it never exercised this check.
    assert!(!sh
        .matches(&ConstraintValue::String(r#"ls "$(date)""#.into()))
        .unwrap());
}
#[test]
fn test_shlex_whitespace_only() {
    // Strings containing only spaces or tabs have no binary and never match.
    let sh = Shlex::new(vec!["npm"]);
    for blank in [" ", "\t\t"] {
        let value = ConstraintValue::String(blank.into());
        assert!(!sh.matches(&value).unwrap());
    }
}
#[test]
fn test_shlex_unicode_binary_names() {
    // Non-ASCII binary names are compared like any other name.
    let sh = Shlex::new(vec!["café"]);
    let cases = [("café --help", true), ("evil --help", false)];
    for (command, expected) in cases {
        let value = ConstraintValue::String(command.into());
        assert_eq!(sh.matches(&value).unwrap(), expected);
    }
}
#[test]
fn test_shlex_path_traversal_in_binary() {
    // Path forms that resolve toward `npm` do not match the bare `npm` entry.
    let sh = Shlex::new(vec!["npm"]);
    for command in ["../npm install", "/usr/../bin/npm install"] {
        let value = ConstraintValue::String(command.into());
        assert!(!sh.matches(&value).unwrap());
    }
}
#[test]
fn test_shlex_binary_with_spaces_in_arguments() {
    // Space-separated arguments after an allowed binary are fine.
    let sh = Shlex::new(vec!["npm"]);
    let command = ConstraintValue::String("npm install express".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_tab_separated() {
    // Tabs are accepted as argument separators.
    let sh = Shlex::new(vec!["npm"]);
    let command = ConstraintValue::String("npm\tinstall\texpress".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_backslash_not_rejected() {
    // A backslash-escaped space is not treated as dangerous.
    let sh = Shlex::new(vec!["ls"]);
    let command = ConstraintValue::String(r"ls foo\ bar".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_tilde_not_rejected() {
    // A `~` in an argument is not treated as dangerous.
    let sh = Shlex::new(vec!["ls"]);
    let command = ConstraintValue::String("ls ~/Documents".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_hash_not_rejected() {
    // A `#` in an argument is not treated as dangerous.
    let sh = Shlex::new(vec!["echo"]);
    let command = ConstraintValue::String("echo hello #world".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_exclamation_not_rejected() {
    // A trailing `!` in an argument is not treated as dangerous.
    let sh = Shlex::new(vec!["echo"]);
    let command = ConstraintValue::String("echo hello!".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_double_dash_not_dangerous() {
    // The `--` end-of-options marker is an ordinary argument.
    let sh = Shlex::new(vec!["npm"]);
    let command = ConstraintValue::String("npm install -- --save".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(accepted);
}
#[test]
fn test_shlex_empty_allow_list() {
    // With nothing on the allow list, no command can match.
    let sh = Shlex { allow: Vec::new() };
    let command = ConstraintValue::String("npm install".into());
    let accepted = sh.matches(&command).unwrap();
    assert!(!accepted);
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_equal() {
    // An identical expression is a valid (non-widening) attenuation.
    let expr = "amount < 10000";
    let parent = CelConstraint::new(expr);
    let child = CelConstraint::new(expr);
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_valid_conjunction() {
    // Child = (parent) && (extra clause), both sides parenthesised.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("(amount < 10000) && (amount > 0)");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_valid_conjunction_whitespace() {
    // Same conjunction shape, with extra whitespace inside the parentheses.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("( amount < 10000 ) && ( amount > 0 )");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_valid_nested_conjunction() {
    // The extra clause may itself be a parenthesised conjunction.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("(amount < 10000) && (amount > 0 && currency == 'USD')");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_or_bypass_blocked() {
    // A trailing top-level `||` would widen the parent and must be refused.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("(amount < 10000) && true || amount < 1000000");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_bare_predicate_blocked() {
    // The extra clause must be parenthesised; a bare predicate is refused.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("(amount < 10000) && amount > 0");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_or_inside_parens_ok() {
    // An `||` enclosed in the added clause's parentheses is acceptable.
    let parent = CelConstraint::new("amount < 10000");
    let child_expr = "(amount < 10000) && (currency == 'USD' || currency == 'EUR')";
    let child = CelConstraint::new(child_expr);
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_ok());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuation_different_expression_rejected() {
    // An unrelated expression is refused, even though it is semantically
    // tighter, because it is not of the (parent) && (...) form.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::new("amount < 5000");
    let verdict = parent.validate_attenuation(&child);
    assert!(verdict.is_err());
}
#[cfg(feature = "cel")]
#[test]
fn test_cel_attenuate_helper_uses_parens() {
    // `attenuate` wraps both sides in parentheses, and its output validates.
    let parent = CelConstraint::new("amount < 10000");
    let child = CelConstraint::attenuate(&parent, "amount > 0");

    let expected = "(amount < 10000) && (amount > 0)";
    assert_eq!(child.expression, expected);
    assert!(parent.validate_attenuation(&child).is_ok());
}
#[test]
fn test_balanced_outer_parens() {
    // A single pair of parentheses wraps the whole expression, including
    // cases with nested groups and function-call parentheses.
    let wrapped = [
        "(x>0)",
        "(x>0&&y<10)",
        "(x>0||(y<10&&z==1))",
        "(f(x)||g(x))",
        "(f(x,y)||g(a,b)&&h(c))",
    ];
    for expr in wrapped {
        assert!(has_balanced_outer_parens(expr));
    }

    // The outer pair closes early, is missing, or the input is empty.
    let not_wrapped = ["(x>0)&&(y<10)", "(x>0)||evil", "(x>0)extra", "x>0", ""];
    for expr in not_wrapped {
        assert!(!has_balanced_outer_parens(expr));
    }
}
#[test]
fn test_not_attenuation_rejected() {
    // Even an identical Not constraint is refused as a child.
    let forbid_admin = || Constraint::Not(Not::new(Constraint::Exact(Exact::new("admin"))));
    let parent = forbid_admin();
    let child = forbid_admin();

    let verdict = parent.validate_attenuation(&child);
    assert!(
        verdict.is_err(),
        "Not -> Not attenuation must be rejected (direction is unsound)"
    );
}
#[test]
fn test_any_attenuation_rejected() {
    // A child Any with fewer alternatives is still refused.
    let parent_options = vec![
        Constraint::Exact(Exact::new("a")),
        Constraint::Exact(Exact::new("b")),
    ];
    let parent = Constraint::Any(Any::new(parent_options));
    let child_options = vec![Constraint::Exact(Exact::new("a"))];
    let child = Constraint::Any(Any::new(child_options));

    let verdict = parent.validate_attenuation(&child);
    assert!(
        verdict.is_err(),
        "Any -> Any attenuation must be rejected (not implemented)"
    );
}
}