use regex::Regex;
use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;
/// Maximum accepted length of a binding pattern, measured in bytes (`str::len`).
const MAX_PATTERN_LENGTH: usize = 512;
/// Maximum number of unescaped wildcard characters allowed in one pattern.
/// Every wildcard character counts separately, so `**` uses two.
const MAX_WILDCARDS: usize = 8;
/// Characters that introduce wildcard syntax: `*`/`**` globs and the opening
/// character of Flask `<name>`, Express `:name` and Laravel `{name}` placeholders.
/// An odd run of preceding backslashes makes any of them literal.
const WILDCARD_CHARS: &[char] = &['*', '<', ':', '{'];
/// Returns `true` when the character at index `i` of `chars` is escaped, i.e.
/// immediately preceded by an odd number of consecutive backslashes.
///
/// `\*` is escaped, whereas in `\\*` the two backslashes escape each other and
/// the `*` is not escaped.
///
/// # Panics
/// Panics if `i > chars.len()`.
fn ash_is_wildcard_escaped(chars: &[char], i: usize) -> bool {
    let preceding_backslashes = chars[..i]
        .iter()
        .rev()
        .take_while(|&&c| c == '\\')
        .count();
    preceding_backslashes % 2 == 1
}
/// Returns `true` if `pattern` contains at least one character from
/// `WILDCARD_CHARS` that is not backslash-escaped.
fn ash_has_unescaped_wildcard(pattern: &str) -> bool {
    let chars: Vec<char> = pattern.chars().collect();
    chars
        .iter()
        .enumerate()
        .filter(|(_, c)| WILDCARD_CHARS.contains(c))
        .any(|(idx, _)| !ash_is_wildcard_escaped(&chars, idx))
}
/// Counts the characters of `pattern` that belong to `WILDCARD_CHARS` and are
/// not backslash-escaped. Each character counts on its own, so `**` yields 2.
fn ash_count_unescaped_wildcards(pattern: &str) -> usize {
    let chars: Vec<char> = pattern.chars().collect();
    (0..chars.len())
        .filter(|&pos| {
            WILDCARD_CHARS.contains(&chars[pos]) && !ash_is_wildcard_escaped(&chars, pos)
        })
        .count()
}
/// Removes escape backslashes from `pattern`, producing the literal text an
/// exact (wildcard-free) pattern should equal.
///
/// `\\` becomes `\`, and a backslash before a `WILDCARD_CHARS` character is
/// dropped. Any other backslash, including a trailing one, is kept verbatim.
fn ash_unescape_pattern(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    let mut chars = pattern.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.peek() {
            Some(&next) if next == '\\' || WILDCARD_CHARS.contains(&next) => {
                out.push(next);
                chars.next();
            }
            // Not an escape sequence: keep the backslash and let the following
            // character be processed normally.
            _ => out.push(c),
        }
    }
    out
}
/// A binding pattern pre-processed for matching.
///
/// Wildcard-free patterns (`is_exact == true`) match by string equality against
/// their unescaped text. All other patterns match through an anchored regex.
#[derive(Debug, Clone)]
struct CompiledPattern {
    /// Unescaped text an exact pattern must equal; `None` for wildcard patterns.
    /// Computed once in `compile` so `matches` does not allocate per call.
    literal: Option<String>,
    /// Anchored matcher for wildcard patterns; `None` for exact patterns.
    regex: Option<Regex>,
    /// `true` when the pattern contains no unescaped wildcard characters.
    is_exact: bool,
}

impl CompiledPattern {
    /// Compiles `pattern` for matching.
    ///
    /// Returns `None` in three cases:
    /// - the pattern is longer than `MAX_PATTERN_LENGTH` bytes;
    /// - it contains more than `MAX_WILDCARDS` unescaped wildcard characters;
    /// - its regex translation fails to build.
    fn compile(pattern: &str) -> Option<Self> {
        if pattern.len() > MAX_PATTERN_LENGTH {
            return None;
        }
        if !ash_has_unescaped_wildcard(pattern) {
            return Some(CompiledPattern {
                literal: Some(ash_unescape_pattern(pattern)),
                regex: None,
                is_exact: true,
            });
        }
        if ash_count_unescaped_wildcards(pattern) > MAX_WILDCARDS {
            return None;
        }
        let regex = ash_build_safe_regex(pattern)?;
        Some(CompiledPattern {
            literal: None,
            regex: Some(regex),
            is_exact: false,
        })
    }

    /// Returns `true` if `binding` is matched by this pattern.
    fn matches(&self, binding: &str) -> bool {
        if self.is_exact {
            return self.literal.as_deref() == Some(binding);
        }
        match &self.regex {
            Some(regex) => regex.is_match(binding),
            None => false,
        }
    }
}
/// Translates a wildcard binding pattern into an anchored (`^…$`), size-limited [`Regex`].
///
/// Supported syntax:
/// - `**` matches any run of characters except `|` (may be empty).
/// - `*` matches any run of characters except `|` and `/` (may be empty).
/// - Flask `<name>`, Express `:name` and Laravel `{name}` placeholders match
///   one or more characters except `|` and `/`.
/// - `\\`, `\*`, `\<`, `\:` and `\{` match the literal character.
///
/// Every other character matches itself literally. Returns `None` if the regex
/// fails to build, which includes exceeding the 10 KiB compiled-size limit.
fn ash_build_safe_regex(pattern: &str) -> Option<Regex> {
    lazy_static::lazy_static! {
        static ref FLASK_RE: Regex = Regex::new(r"<[a-zA-Z_][a-zA-Z0-9_]*>").unwrap();
        static ref EXPRESS_RE: Regex = Regex::new(r":[a-zA-Z_][a-zA-Z0-9_]*").unwrap();
        // Matches `{name}` after `regex::escape` has turned it into `\{name\}`.
        static ref LARAVEL_RE: Regex = Regex::new(r"\\\{[a-zA-Z_][a-zA-Z0-9_]*\\\}").unwrap();
    }
    // Caps the compiled program size so user-supplied patterns stay cheap.
    const REGEX_SIZE_LIMIT: usize = 10 * 1024;
    // (escaped char, placeholder, regex text restoring it as a literal).
    // Escaped characters are hidden behind NUL-delimited placeholders so the
    // wildcard rewrites below cannot see them. The placeholders contain no regex
    // metacharacters, so they pass through `regex::escape` unchanged.
    // NOTE(review): a pattern that literally contains one of these placeholder
    // strings would be restored as the corresponding literal; presumably
    // harmless for route bindings.
    static ESCAPES: [(char, &str, &str); 5] = [
        ('\\', "\x00ESCAPED_BACKSLASH\x00", r"\\"),
        ('*', "\x00ESCAPED_STAR\x00", r"\*"),
        ('<', "\x00ESCAPED_LT\x00", "<"),
        (':', "\x00ESCAPED_COLON\x00", ":"),
        ('{', "\x00ESCAPED_LBRACE\x00", r"\{"),
    ];

    let mut temp = String::with_capacity(pattern.len() * 2);
    let mut chars = pattern.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\\' {
            let escaped = chars
                .peek()
                .and_then(|next| ESCAPES.iter().find(|(ch, _, _)| ch == next));
            if let Some(&(_, placeholder, _)) = escaped {
                temp.push_str(placeholder);
                chars.next();
                continue;
            }
        }
        // A backslash that does not start an escape sequence stays a literal.
        temp.push(c);
    }

    // `**` must be rewritten before `*` so it is not consumed as two single stars.
    let mut regex_str = regex::escape(&temp)
        .replace(r"\*\*", "[^|]*")
        .replace(r"\*", "[^|/]*");
    regex_str = FLASK_RE.replace_all(&regex_str, "[^|/]+").into_owned();
    regex_str = EXPRESS_RE.replace_all(&regex_str, "[^|/]+").into_owned();
    regex_str = LARAVEL_RE.replace_all(&regex_str, "[^|/]+").into_owned();
    for &(_, placeholder, literal) in ESCAPES.iter() {
        regex_str = regex_str.replace(&regex::escape(placeholder), literal);
    }

    regex::RegexBuilder::new(&format!("^{}$", regex_str))
        .size_limit(REGEX_SIZE_LIMIT)
        .build()
        .ok()
}
/// Ordered registry mapping binding patterns to the field names in scope for them.
///
/// Lookup order:
/// 1. The verbatim binding is checked against the exact-match index. If two
///    exact patterns unescape to the same text, the most recently registered
///    one wins.
/// 2. Otherwise patterns are tried in registration order, and the first match
///    wins.
#[derive(Debug, Default)]
pub struct ScopePolicyRegistry {
    /// `(raw pattern, compiled pattern, fields)` in registration order.
    /// Re-registering a pattern replaces its entry in place.
    policies_ordered: Vec<(String, CompiledPattern, Vec<String>)>,
    /// Unescaped text of exact (wildcard-free) patterns → index into `policies_ordered`.
    exact_matches: HashMap<String, usize>,
}

impl ScopePolicyRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            policies_ordered: Vec::new(),
            exact_matches: HashMap::new(),
        }
    }

    /// Registers `fields` for `binding`, replacing any policy previously
    /// registered under the identical raw pattern.
    ///
    /// Returns `false`, leaving the registry untouched, when
    /// `CompiledPattern::compile` rejects the pattern (too long, too many
    /// wildcards, or regex build failure).
    pub fn register(&mut self, binding: &str, fields: &[&str]) -> bool {
        let compiled = match CompiledPattern::compile(binding) {
            Some(compiled) => compiled,
            None => return false,
        };
        let is_exact = compiled.is_exact;
        let entry = (
            binding.to_string(),
            compiled,
            fields.iter().map(|s| s.to_string()).collect::<Vec<String>>(),
        );
        let idx = match self.policies_ordered.iter().position(|(p, _, _)| p == binding) {
            Some(idx) => {
                self.policies_ordered[idx] = entry;
                idx
            }
            None => {
                self.policies_ordered.push(entry);
                self.policies_ordered.len() - 1
            }
        };
        // Only exact patterns own an index entry, and `insert` overwrites a
        // stale one. A wildcard pattern must never remove the key: its raw text
        // can unescape to the key of a *different* exact pattern (`a*` vs `a\*`).
        // Dropping that entry would change which policy the literal binding
        // resolves to.
        if is_exact {
            self.exact_matches.insert(ash_unescape_pattern(binding), idx);
        }
        true
    }

    /// Registers every entry of `policies_map` (in key order) and returns how
    /// many were accepted.
    pub fn register_many(&mut self, policies_map: &BTreeMap<&str, Vec<&str>>) -> usize {
        policies_map
            .iter()
            .filter(|(binding, fields)| self.register(binding, fields))
            .count()
    }

    /// Returns the fields of the policy governing `binding`, or an empty `Vec`
    /// if no pattern matches.
    pub fn get(&self, binding: &str) -> Vec<String> {
        self.lookup(binding)
            .map(<[String]>::to_vec)
            .unwrap_or_default()
    }

    /// Returns `true` if some registered pattern governs `binding`.
    pub fn has(&self, binding: &str) -> bool {
        self.lookup(binding).is_some()
    }

    /// Returns every registered raw pattern with its fields.
    pub fn get_all(&self) -> BTreeMap<String, Vec<String>> {
        self.policies_ordered
            .iter()
            .map(|(pattern, _, fields)| (pattern.clone(), fields.clone()))
            .collect()
    }

    /// Removes every registered policy.
    pub fn clear(&mut self) {
        self.policies_ordered.clear();
        self.exact_matches.clear();
    }

    /// Shared lookup for `get`/`has`: exact index first, then first match in
    /// registration order.
    fn lookup(&self, binding: &str) -> Option<&[String]> {
        let exact = self
            .exact_matches
            .get(binding)
            .and_then(|&idx| self.policies_ordered.get(idx));
        exact
            .or_else(|| {
                self.policies_ordered
                    .iter()
                    .find(|(_, compiled, _)| compiled.matches(binding))
            })
            .map(|(_, _, fields)| fields.as_slice())
    }
}
lazy_static::lazy_static! {
    /// Process-wide registry shared by the `ash_*_scope_policy` free functions.
    /// It starts empty and is created on first access.
    static ref GLOBAL_REGISTRY: RwLock<ScopePolicyRegistry> =
        RwLock::new(ScopePolicyRegistry::default());
}
/// Acquires the global registry for writing.
///
/// Lock poisoning is deliberately ignored: the guard is recovered from the
/// `PoisonError`, so a panic in an earlier lock holder does not make the
/// registry unusable.
fn ash_get_write_lock() -> std::sync::RwLockWriteGuard<'static, ScopePolicyRegistry> {
    match GLOBAL_REGISTRY.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}

/// Acquires the global registry for reading, recovering from lock poisoning
/// in the same way as `ash_get_write_lock`.
fn ash_get_read_lock() -> std::sync::RwLockReadGuard<'static, ScopePolicyRegistry> {
    match GLOBAL_REGISTRY.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Registers `fields` as the scope policy for `binding` in the global registry.
///
/// Returns `false` if the pattern was rejected.
pub fn ash_register_scope_policy(binding: &str, fields: &[&str]) -> bool {
    ash_get_write_lock().register(binding, fields)
}

/// Registers every entry of `policies_map` in the global registry under a
/// single write lock. Returns how many entries were accepted.
pub fn ash_register_scope_policies(policies_map: &BTreeMap<&str, Vec<&str>>) -> usize {
    ash_get_write_lock().register_many(policies_map)
}

/// Returns the fields governing `binding` in the global registry, or an empty
/// `Vec` when nothing matches.
pub fn ash_get_scope_policy(binding: &str) -> Vec<String> {
    ash_get_read_lock().get(binding)
}

/// Returns `true` if any policy in the global registry governs `binding`.
pub fn ash_has_scope_policy(binding: &str) -> bool {
    ash_get_read_lock().has(binding)
}

/// Returns a snapshot of every pattern and its fields in the global registry.
pub fn ash_get_all_scope_policies() -> BTreeMap<String, Vec<String>> {
    ash_get_read_lock().get_all()
}

/// Removes every policy from the global registry.
pub fn ash_clear_scope_policies() {
    ash_get_write_lock().clear();
}
#[deprecated(since = "2.4.0", note = "Use ash_register_scope_policy instead")]
pub fn register_scope_policy(binding: &str, fields: &[&str]) -> bool {
ash_register_scope_policy(binding, fields)
}
#[deprecated(since = "2.4.0", note = "Use ash_register_scope_policies instead")]
pub fn register_scope_policies(policies_map: &BTreeMap<&str, Vec<&str>>) -> usize {
ash_register_scope_policies(policies_map)
}
#[deprecated(since = "2.4.0", note = "Use ash_get_scope_policy instead")]
pub fn get_scope_policy(binding: &str) -> Vec<String> {
ash_get_scope_policy(binding)
}
#[deprecated(since = "2.4.0", note = "Use ash_has_scope_policy instead")]
pub fn has_scope_policy(binding: &str) -> bool {
ash_has_scope_policy(binding)
}
#[deprecated(since = "2.4.0", note = "Use ash_get_all_scope_policies instead")]
pub fn get_all_scope_policies() -> BTreeMap<String, Vec<String>> {
ash_get_all_scope_policies()
}
#[deprecated(since = "2.4.0", note = "Use ash_clear_scope_policies instead")]
pub fn clear_scope_policies() {
ash_clear_scope_policies()
}
// Unit tests for the registry, its pattern syntax and the escape helpers.
// Every test builds its own `ScopePolicyRegistry`, so none touch the global one.
#[cfg(test)]
mod tests {
use super::*;
// Exact (wildcard-free) pattern: registered fields come back verbatim.
#[test]
fn test_registry_register_and_get() {
let mut registry = ScopePolicyRegistry::new();
assert!(registry.register("POST|/api/transfer|", &["amount", "recipient"]));
let scope = registry.get("POST|/api/transfer|");
assert_eq!(scope, vec!["amount", "recipient"]);
}
// An empty registry yields an empty field list rather than an error.
#[test]
fn test_registry_get_no_match() {
let registry = ScopePolicyRegistry::new();
let scope = registry.get("GET|/api/users|");
assert!(scope.is_empty());
}
#[test]
fn test_registry_has() {
let mut registry = ScopePolicyRegistry::new();
registry.register("POST|/api/transfer|", &["amount"]);
assert!(registry.has("POST|/api/transfer|"));
assert!(!registry.has("GET|/api/users|"));
}
// Flask-style `<name>` placeholder matches a single path segment.
#[test]
fn test_registry_pattern_matching_flask_style() {
let mut registry = ScopePolicyRegistry::new();
registry.register("PUT|/api/users/<id>|", &["role", "permissions"]);
let scope = registry.get("PUT|/api/users/123|");
assert_eq!(scope, vec!["role", "permissions"]);
}
// Express-style `:name` placeholder matches a single path segment.
#[test]
fn test_registry_pattern_matching_express_style() {
let mut registry = ScopePolicyRegistry::new();
registry.register("PUT|/api/users/:id|", &["role"]);
let scope = registry.get("PUT|/api/users/456|");
assert_eq!(scope, vec!["role"]);
}
// Laravel-style `{name}` placeholder matches a single path segment.
#[test]
fn test_registry_pattern_matching_laravel_style() {
let mut registry = ScopePolicyRegistry::new();
registry.register("PUT|/api/users/{id}|", &["email"]);
let scope = registry.get("PUT|/api/users/789|");
assert_eq!(scope, vec!["email"]);
}
// Single `*` matches within one segment.
#[test]
fn test_registry_pattern_matching_wildcard() {
let mut registry = ScopePolicyRegistry::new();
registry.register("POST|/api/*/transfer|", &["amount"]);
let scope = registry.get("POST|/api/v1/transfer|");
assert_eq!(scope, vec!["amount"]);
}
// `**` may span several `/`-separated segments.
#[test]
fn test_registry_pattern_matching_double_wildcard() {
let mut registry = ScopePolicyRegistry::new();
registry.register("POST|/api/**/transfer|", &["amount"]);
let scope = registry.get("POST|/api/v1/users/transfer|");
assert_eq!(scope, vec!["amount"]);
}
#[test]
fn test_registry_clear() {
let mut registry = ScopePolicyRegistry::new();
registry.register("POST|/api/transfer|", &["amount"]);
assert!(registry.has("POST|/api/transfer|"));
registry.clear();
assert!(!registry.has("POST|/api/transfer|"));
}
// `register_many` reports how many entries were accepted.
#[test]
fn test_registry_register_many() {
let mut registry = ScopePolicyRegistry::new();
let mut policies = BTreeMap::new();
policies.insert("POST|/api/transfer|", vec!["amount"]);
policies.insert("POST|/api/payment|", vec!["card"]);
let count = registry.register_many(&policies);
assert_eq!(count, 2);
assert!(registry.has("POST|/api/transfer|"));
assert!(registry.has("POST|/api/payment|"));
}
#[test]
fn test_registry_get_all() {
let mut registry = ScopePolicyRegistry::new();
registry.register("POST|/api/transfer|", &["amount"]);
registry.register("POST|/api/payment|", &["card"]);
let all = registry.get_all();
assert_eq!(all.len(), 2);
}
// Patterns longer than MAX_PATTERN_LENGTH (512 bytes) are rejected.
#[test]
fn test_rejects_pattern_too_long() {
let mut registry = ScopePolicyRegistry::new();
let long_pattern = "POST|/api/".to_string() + &"a".repeat(600) + "|";
assert!(!registry.register(&long_pattern, &["field"]));
}
// Nine unescaped `*` exceed MAX_WILDCARDS (8).
#[test]
fn test_rejects_too_many_wildcards() {
let mut registry = ScopePolicyRegistry::new();
let many_wildcards = "POST|/*/*/*/*/*/*/*/*/*|";
assert!(!registry.register(many_wildcards, &["field"]));
}
// Seven wildcards stay within the limit.
#[test]
fn test_accepts_valid_wildcards() {
let mut registry = ScopePolicyRegistry::new();
let valid_pattern = "POST|/*/*/*/*/*/*/*|";
assert!(registry.register(valid_pattern, &["field"]));
}
// `\\*` is a literal backslash followed by an active `*` wildcard.
#[test]
fn test_escaped_backslash_before_wildcard() {
let mut registry = ScopePolicyRegistry::new();
assert!(registry.register(r"POST|/api\\*|", &["field"]));
let scope = registry.get(r"POST|/api\test|");
assert_eq!(scope, vec!["field"]);
let scope2 = registry.get(r"POST|/api\foo|");
assert_eq!(scope2, vec!["field"]);
}
// `\*` is a literal asterisk: the pattern is exact and matches only `*`.
#[test]
fn test_escaped_asterisk_exact_match() {
let mut registry = ScopePolicyRegistry::new();
assert!(registry.register(r"POST|/api/\*|", &["field"]));
let scope = registry.get("POST|/api/*|");
assert_eq!(scope, vec!["field"]);
let scope2 = registry.get("POST|/api/test|");
assert!(scope2.is_empty());
}
// Four backslashes in the pattern unescape to two in the binding.
#[test]
fn test_double_escaped_backslash() {
let mut registry = ScopePolicyRegistry::new();
assert!(registry.register(r"POST|/api/\\\\test|", &["field"]));
let scope = registry.get(r"POST|/api/\\test|");
assert_eq!(scope, vec!["field"]);
}
// Escape parity: an odd run of preceding backslashes escapes the character.
#[test]
fn test_is_wildcard_escaped_helper() {
let chars1: Vec<char> = r"\*".chars().collect();
assert!(ash_is_wildcard_escaped(&chars1, 1));
let chars2: Vec<char> = r"\\*".chars().collect();
assert!(!ash_is_wildcard_escaped(&chars2, 2));
let chars3: Vec<char> = r"\\\*".chars().collect();
assert!(ash_is_wildcard_escaped(&chars3, 3));
let chars4: Vec<char> = "*".chars().collect();
assert!(!ash_is_wildcard_escaped(&chars4, 0)); }
}