use regex::Regex;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use crate::ConnectionTrait;
use crate::error::{ClientError, Result};
/// Decision returned by a policy handler.
///
/// `PolicyEngine` stops evaluating a hook at the first `Deny` or `Modify`; `Allow` and
/// `Log` let evaluation continue with the next matching policy.
#[derive(Debug, Clone, PartialEq)]
pub enum PolicyAction {
/// Let the operation proceed.
Allow,
/// Reject the operation (before-hooks) or hide the value that was read (`AfterRead`).
Deny,
/// Substitute these bytes for the value being written (`BeforeWrite`) or returned (`AfterRead`).
Modify(Vec<u8>),
/// Let the operation proceed; `CompiledPolicySet` maps this to `PolicyOutcome::AllowWithLog`.
Log,
}
/// Point in an operation's lifecycle at which a policy runs (used as a map key, hence `Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyTrigger {
/// Before a key is read; a `Deny` rejects the read.
BeforeRead,
/// After a value was read; the handler sees it in `PolicyContext::value`, and `Modify`/`Deny`
/// replace or hide what is returned.
AfterRead,
/// Before a write; a `Deny` rejects it and a `Modify` replaces the bytes that get stored.
BeforeWrite,
/// After a write has been stored; the handler's returned action is not acted upon.
AfterWrite,
/// Before a key is deleted; a `Deny` rejects the delete.
BeforeDelete,
/// After-delete hook; rules can be attached to it in a `CompiledPolicySet`.
AfterDelete,
}
/// Description of a single store operation, handed to every policy handler.
#[derive(Debug, Clone)]
pub struct PolicyContext {
    /// Operation name (e.g. `"read"`, `"write"`, `"delete"`).
    pub operation: String,
    /// Raw key being accessed.
    pub key: Vec<u8>,
    /// Value associated with the operation, when one is known.
    pub value: Option<Vec<u8>>,
    /// Identity of the calling agent, if supplied.
    pub agent_id: Option<String>,
    /// Identity of the calling session, if supplied.
    pub session_id: Option<String>,
    /// Moment this context was created.
    pub timestamp: Instant,
    /// Free-form string attributes (for example a `"namespace"` entry).
    pub custom: HashMap<String, String>,
}

impl PolicyContext {
    /// Creates a context for `operation` on `key`, with no value, identities or attributes.
    pub fn new(operation: &str, key: &[u8]) -> Self {
        let created_at = Instant::now();
        let attributes = HashMap::new();
        Self {
            operation: String::from(operation),
            key: key.to_owned(),
            value: None,
            agent_id: None,
            session_id: None,
            timestamp: created_at,
            custom: attributes,
        }
    }

    /// Looks up a custom attribute by name.
    pub fn get(&self, key: &str) -> Option<&String> {
        self.custom.get(key)
    }

    /// Inserts a custom attribute, overwriting any previous value under the same name.
    pub fn set(&mut self, key: &str, value: &str) {
        self.custom.insert(key.to_owned(), value.to_owned());
    }

    /// Returns this context with `agent_id` set.
    pub fn with_agent_id(self, agent_id: &str) -> Self {
        Self {
            agent_id: Some(agent_id.to_owned()),
            ..self
        }
    }

    /// Returns this context with `session_id` set.
    pub fn with_session_id(self, session_id: &str) -> Self {
        Self {
            session_id: Some(session_id.to_owned()),
            ..self
        }
    }
}
/// Shared, thread-safe policy callback: receives an operation's context and returns the action to take.
pub type PolicyHandler = Arc<dyn Fn(&PolicyContext) -> PolicyAction + Send + Sync>;
struct PatternPolicy {
pattern: String,
trigger: PolicyTrigger,
handler: PolicyHandler,
regex: Regex,
}
impl PatternPolicy {
fn new(pattern: &str, trigger: PolicyTrigger, handler: PolicyHandler) -> Self {
let regex_str = pattern
.replace(".", "\\.")
.replace("**", ".*")
.replace("*", "[^/]*");
let regex_str = format!("^{}$", regex_str);
Self {
pattern: pattern.to_string(),
trigger,
handler,
regex: Regex::new(®ex_str).unwrap_or_else(|_| Regex::new("^$").unwrap()),
}
}
fn matches(&self, key: &[u8]) -> bool {
if let Ok(key_str) = std::str::from_utf8(key) {
self.regex.is_match(key_str)
} else {
false
}
}
}
/// Token bucket holding up to `max_per_minute` tokens, refilled at `max_per_minute` per minute.
struct RateLimiter {
    max_per_minute: u32,
    /// Token count and refill clock share one lock so they are always updated together.
    state: Mutex<BucketState>,
}

/// Mutable part of a [`RateLimiter`].
struct BucketState {
    tokens: u32,
    /// Point in time up to which elapsed time has already been converted into tokens.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket.
    fn new(max_per_minute: u32) -> Self {
        Self {
            max_per_minute,
            state: Mutex::new(BucketState {
                tokens: max_per_minute,
                last_refill: Instant::now(),
            }),
        }
    }

    /// Takes one token if available; returns whether the caller may proceed.
    fn try_acquire(&self) -> bool {
        self.try_acquire_at(Instant::now())
    }

    /// [`try_acquire`](Self::try_acquire) evaluated at an explicit instant (used for testing).
    fn try_acquire_at(&self, now: Instant) -> bool {
        // The state is two plain values that are never left half-updated, so a poisoned lock
        // is safe to keep using.
        let mut state = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        self.refill(&mut state, now);
        if state.tokens > 0 {
            state.tokens -= 1;
            true
        } else {
            false
        }
    }

    /// Converts time elapsed since `last_refill` into tokens without losing partial intervals.
    fn refill(&self, state: &mut BucketState, now: Instant) {
        if self.max_per_minute == 0 {
            return;
        }
        // Non-zero for every u32 rate: 60 s / u32::MAX is ~13 ns.
        let per_token = Duration::from_secs(60) / self.max_per_minute;
        let elapsed = now.saturating_duration_since(state.last_refill);
        let earned = elapsed.as_nanos() / per_token.as_nanos();
        if earned == 0 {
            return;
        }
        let missing = self.max_per_minute - state.tokens;
        if earned >= u128::from(missing) {
            // Bucket is full; leftover time is meaningless once capped.
            state.tokens = self.max_per_minute;
            state.last_refill = now;
        } else {
            // earned < missing <= u32::MAX, so this cast is lossless.
            let earned = earned as u32;
            state.tokens += earned;
            // Advance only by the time actually spent on these tokens, keeping the remainder.
            state.last_refill += per_token * earned;
        }
    }
}
/// One configured rate limit, as registered through `PolicyEngine::add_rate_limit`.
struct RateLimitConfig {
/// Operation it applies to (`"read"`, `"write"`, `"delete"`), or `"all"` for every operation.
operation: String,
/// Token budget per minute for each distinct scope value.
max_per_minute: u32,
/// How callers are bucketed: `"global"`, `"agent_id"`, `"session_id"`, or the name of a
/// `PolicyContext::custom` attribute. Callers lacking that identity share an `"unknown"` bucket.
scope: String,
}
/// One record in the `PolicyEngine` audit log.
#[derive(Debug, Clone)]
pub struct AuditEntry {
/// When the entry was recorded.
pub timestamp: Instant,
/// Operation name (`"read"`, `"write"` or `"delete"`).
pub operation: String,
/// Accessed key, decoded as UTF-8 with invalid bytes replaced.
pub key: String,
/// Agent id copied from the operation's context.
pub agent_id: Option<String>,
/// Session id copied from the operation's context.
pub session_id: Option<String>,
/// Outcome label such as `"allowed"`, `"denied"`, `"rate_limited"` or `"redacted"`.
pub result: String,
}
/// Error returned when an operation is blocked by a policy, rate-limited, or fails in the store.
#[derive(Debug)]
pub struct PolicyViolationError {
    /// Human-readable reason for the failure.
    pub message: String,
}

impl std::fmt::Display for PolicyViolationError {
    /// Renders as `PolicyViolation: <message>`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("PolicyViolation: ")?;
        formatter.write_str(&self.message)
    }
}

impl std::error::Error for PolicyViolationError {}
/// Wraps a connection and applies registered policies, rate limits and optional auditing
/// to every `put`, `get` and `delete`.
pub struct PolicyEngine<C: ConnectionTrait> {
/// Underlying store that permitted operations are forwarded to.
conn: C,
/// Pattern policies per trigger, evaluated in registration order.
policies: RwLock<HashMap<PolicyTrigger, Vec<PatternPolicy>>>,
/// Token buckets keyed by `"<operation>:<scope>"`, then by the caller's scope value.
rate_limiters: RwLock<HashMap<String, HashMap<String, Arc<RateLimiter>>>>,
/// Rate limits registered via `add_rate_limit`.
rate_limit_configs: RwLock<Vec<RateLimitConfig>>,
/// Most recent audit entries, oldest first, trimmed to `max_audit_entries`.
audit_log: RwLock<Vec<AuditEntry>>,
/// Whether operations are recorded in `audit_log`; `false` after `new`.
audit_enabled: RwLock<bool>,
/// Upper bound on retained audit entries.
max_audit_entries: usize,
}
impl<C: ConnectionTrait> PolicyEngine<C> {
pub fn new(conn: C) -> Self {
let mut policies = HashMap::new();
policies.insert(PolicyTrigger::BeforeRead, Vec::new());
policies.insert(PolicyTrigger::AfterRead, Vec::new());
policies.insert(PolicyTrigger::BeforeWrite, Vec::new());
policies.insert(PolicyTrigger::AfterWrite, Vec::new());
policies.insert(PolicyTrigger::BeforeDelete, Vec::new());
policies.insert(PolicyTrigger::AfterDelete, Vec::new());
Self {
conn,
policies: RwLock::new(policies),
rate_limiters: RwLock::new(HashMap::new()),
rate_limit_configs: RwLock::new(Vec::new()),
audit_log: RwLock::new(Vec::new()),
audit_enabled: RwLock::new(false),
max_audit_entries: 10000,
}
}
pub fn before_write<F>(&self, pattern: &str, handler: F)
where
F: Fn(&PolicyContext) -> PolicyAction + Send + Sync + 'static,
{
let mut policies = self.policies.write().unwrap();
policies
.get_mut(&PolicyTrigger::BeforeWrite)
.unwrap()
.push(PatternPolicy::new(pattern, PolicyTrigger::BeforeWrite, Arc::new(handler)));
}
pub fn after_write<F>(&self, pattern: &str, handler: F)
where
F: Fn(&PolicyContext) -> PolicyAction + Send + Sync + 'static,
{
let mut policies = self.policies.write().unwrap();
policies
.get_mut(&PolicyTrigger::AfterWrite)
.unwrap()
.push(PatternPolicy::new(pattern, PolicyTrigger::AfterWrite, Arc::new(handler)));
}
pub fn before_read<F>(&self, pattern: &str, handler: F)
where
F: Fn(&PolicyContext) -> PolicyAction + Send + Sync + 'static,
{
let mut policies = self.policies.write().unwrap();
policies
.get_mut(&PolicyTrigger::BeforeRead)
.unwrap()
.push(PatternPolicy::new(pattern, PolicyTrigger::BeforeRead, Arc::new(handler)));
}
pub fn after_read<F>(&self, pattern: &str, handler: F)
where
F: Fn(&PolicyContext) -> PolicyAction + Send + Sync + 'static,
{
let mut policies = self.policies.write().unwrap();
policies
.get_mut(&PolicyTrigger::AfterRead)
.unwrap()
.push(PatternPolicy::new(pattern, PolicyTrigger::AfterRead, Arc::new(handler)));
}
pub fn before_delete<F>(&self, pattern: &str, handler: F)
where
F: Fn(&PolicyContext) -> PolicyAction + Send + Sync + 'static,
{
let mut policies = self.policies.write().unwrap();
policies
.get_mut(&PolicyTrigger::BeforeDelete)
.unwrap()
.push(PatternPolicy::new(pattern, PolicyTrigger::BeforeDelete, Arc::new(handler)));
}
pub fn add_rate_limit(&self, operation: &str, max_per_minute: u32, scope: &str) {
let mut configs = self.rate_limit_configs.write().unwrap();
configs.push(RateLimitConfig {
operation: operation.to_string(),
max_per_minute,
scope: scope.to_string(),
});
}
pub fn enable_audit(&self) {
let mut enabled = self.audit_enabled.write().unwrap();
*enabled = true;
}
pub fn disable_audit(&self) {
let mut enabled = self.audit_enabled.write().unwrap();
*enabled = false;
}
pub fn get_audit_log(&self, limit: usize) -> Vec<AuditEntry> {
let log = self.audit_log.read().unwrap();
let start = log.len().saturating_sub(limit);
log[start..].to_vec()
}
fn check_rate_limit(&self, operation: &str, ctx: &PolicyContext) -> bool {
let configs = self.rate_limit_configs.read().unwrap();
let mut limiters = self.rate_limiters.write().unwrap();
for config in configs.iter() {
if config.operation != operation && config.operation != "all" {
continue;
}
let scope_key = match config.scope.as_str() {
"global" => "global".to_string(),
"agent_id" => ctx.agent_id.clone().unwrap_or_else(|| "unknown".to_string()),
"session_id" => ctx.session_id.clone().unwrap_or_else(|| "unknown".to_string()),
_ => ctx.get(&config.scope).cloned().unwrap_or_else(|| "unknown".to_string()),
};
let limiter_key = format!("{}:{}", config.operation, config.scope);
let scope_limiters = limiters.entry(limiter_key).or_insert_with(HashMap::new);
let limiter = scope_limiters
.entry(scope_key)
.or_insert_with(|| Arc::new(RateLimiter::new(config.max_per_minute)));
if !limiter.try_acquire() {
return false;
}
}
true
}
fn evaluate_policies(&self, trigger: PolicyTrigger, ctx: &PolicyContext) -> PolicyAction {
let policies = self.policies.read().unwrap();
if let Some(trigger_policies) = policies.get(&trigger) {
for policy in trigger_policies {
if policy.matches(&ctx.key) {
let action = (policy.handler)(ctx);
match &action {
PolicyAction::Deny | PolicyAction::Modify(_) => return action,
_ => {}
}
}
}
}
PolicyAction::Allow
}
fn audit(&self, operation: &str, key: &[u8], ctx: &PolicyContext, result: &str) {
let enabled = self.audit_enabled.read().unwrap();
if !*enabled {
return;
}
let mut log = self.audit_log.write().unwrap();
log.push(AuditEntry {
timestamp: Instant::now(),
operation: operation.to_string(),
key: String::from_utf8_lossy(key).to_string(),
agent_id: ctx.agent_id.clone(),
session_id: ctx.session_id.clone(),
result: result.to_string(),
});
if log.len() > self.max_audit_entries {
let start = log.len() - self.max_audit_entries;
*log = log[start..].to_vec();
}
}
pub fn put(
&self,
key: &[u8],
value: &[u8],
ctx: Option<&PolicyContext>,
) -> std::result::Result<(), PolicyViolationError> {
let default_ctx = PolicyContext::new("write", key);
let ctx = ctx.unwrap_or(&default_ctx);
if !self.check_rate_limit("write", ctx) {
self.audit("write", key, ctx, "rate_limited");
return Err(PolicyViolationError {
message: "Rate limit exceeded".to_string(),
});
}
match self.evaluate_policies(PolicyTrigger::BeforeWrite, ctx) {
PolicyAction::Deny => {
self.audit("write", key, ctx, "denied");
return Err(PolicyViolationError {
message: "Write blocked by policy".to_string(),
});
}
PolicyAction::Modify(modified) => {
self.conn.put(key, &modified).map_err(|_| PolicyViolationError {
message: "Write failed".to_string(),
})?;
}
_ => {
self.conn.put(key, value).map_err(|_| PolicyViolationError {
message: "Write failed".to_string(),
})?;
}
}
self.evaluate_policies(PolicyTrigger::AfterWrite, ctx);
self.audit("write", key, ctx, "allowed");
Ok(())
}
pub fn get(&self, key: &[u8], ctx: Option<&PolicyContext>) -> std::result::Result<Option<Vec<u8>>, PolicyViolationError> {
let default_ctx = PolicyContext::new("read", key);
let ctx = ctx.unwrap_or(&default_ctx);
if !self.check_rate_limit("read", ctx) {
self.audit("read", key, ctx, "rate_limited");
return Err(PolicyViolationError {
message: "Rate limit exceeded".to_string(),
});
}
if let PolicyAction::Deny = self.evaluate_policies(PolicyTrigger::BeforeRead, ctx) {
self.audit("read", key, ctx, "denied");
return Err(PolicyViolationError {
message: "Read blocked by policy".to_string(),
});
}
let value = self.conn.get(key).map_err(|_| PolicyViolationError {
message: "Read failed".to_string(),
})?;
if let Some(ref val) = value {
let mut read_ctx = ctx.clone();
read_ctx.value = Some(val.clone());
match self.evaluate_policies(PolicyTrigger::AfterRead, &read_ctx) {
PolicyAction::Modify(modified) => {
self.audit("read", key, ctx, "allowed");
return Ok(Some(modified));
}
PolicyAction::Deny => {
self.audit("read", key, ctx, "redacted");
return Ok(None);
}
_ => {}
}
}
self.audit("read", key, ctx, "allowed");
Ok(value)
}
pub fn delete(&self, key: &[u8], ctx: Option<&PolicyContext>) -> std::result::Result<(), PolicyViolationError> {
let default_ctx = PolicyContext::new("delete", key);
let ctx = ctx.unwrap_or(&default_ctx);
if !self.check_rate_limit("delete", ctx) {
self.audit("delete", key, ctx, "rate_limited");
return Err(PolicyViolationError {
message: "Rate limit exceeded".to_string(),
});
}
if let PolicyAction::Deny = self.evaluate_policies(PolicyTrigger::BeforeDelete, ctx) {
self.audit("delete", key, ctx, "denied");
return Err(PolicyViolationError {
message: "Delete blocked by policy".to_string(),
});
}
self.conn.delete(key).map_err(|_| PolicyViolationError {
message: "Delete failed".to_string(),
})?;
self.audit("delete", key, ctx, "allowed");
Ok(())
}
}
/// Handler that rejects every operation it is attached to.
pub fn deny_all() -> impl Fn(&PolicyContext) -> PolicyAction {
    |_ctx: &PolicyContext| PolicyAction::Deny
}
/// Handler that permits every operation it is attached to.
pub fn allow_all() -> impl Fn(&PolicyContext) -> PolicyAction {
    |_ctx: &PolicyContext| PolicyAction::Allow
}
/// Handler that permits an operation only when its context carries an `agent_id`.
pub fn require_agent_id() -> impl Fn(&PolicyContext) -> PolicyAction {
    |ctx: &PolicyContext| match ctx.agent_id {
        Some(_) => PolicyAction::Allow,
        None => PolicyAction::Deny,
    }
}
/// Handler that always answers `Modify` with a fresh copy of `replacement`.
pub fn redact_value(replacement: Vec<u8>) -> impl Fn(&PolicyContext) -> PolicyAction {
    move |_ctx: &PolicyContext| {
        let substitute = replacement.to_vec();
        PolicyAction::Modify(substitute)
    }
}
/// Severity-ordered rule outcome; when combined, the more severe variant wins.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PolicyOutcome {
    /// Permit silently.
    Allow = 0,
    /// Permit, flagged for logging.
    AllowWithLog = 1,
    /// Permit with a replaced payload.
    Modify = 2,
    /// Reject.
    Deny = 3,
}

impl PolicyOutcome {
    /// Combines two outcomes, keeping the more severe (`Deny` > `Modify` > `AllowWithLog` > `Allow`).
    pub fn join(self, other: PolicyOutcome) -> PolicyOutcome {
        std::cmp::max(self, other)
    }
}
impl From<&PolicyAction> for PolicyOutcome {
    /// Maps a handler's [`PolicyAction`] onto its severity level.
    fn from(action: &PolicyAction) -> Self {
        match *action {
            PolicyAction::Deny => Self::Deny,
            PolicyAction::Modify(..) => Self::Modify,
            PolicyAction::Log => Self::AllowWithLog,
            PolicyAction::Allow => Self::Allow,
        }
    }
}
/// Declarative rule to be added to a `CompiledPolicySet`.
#[derive(Debug, Clone)]
pub struct PolicyRule {
/// Identifier reported in `EvaluationResult::applied_rules` when the rule applies.
pub id: String,
/// Human-readable description (not read during evaluation).
pub description: String,
/// Hook the rule is attached to.
pub trigger: PolicyTrigger,
/// Key glob using `*`/`**` wildcards; compiled to an anchored regex by `CompiledPolicySet::add_rule`.
pub pattern: String,
/// Rules with higher priority are evaluated first.
pub priority: i32,
/// If set, the rule is skipped when the context's `"namespace"` attribute is present and differs.
pub namespace: Option<String>,
/// Outcome used when the rule was added without a handler.
pub outcome: PolicyOutcome,
}
pub struct CompiledPolicySet {
rules_by_trigger: HashMap<PolicyTrigger, Vec<CompiledRule>>,
}
struct CompiledRule {
rule: PolicyRule,
regex: Regex,
handler: Option<PolicyHandler>,
}
impl CompiledPolicySet {
pub fn new() -> Self {
let mut rules_by_trigger = HashMap::new();
for trigger in [
PolicyTrigger::BeforeRead,
PolicyTrigger::AfterRead,
PolicyTrigger::BeforeWrite,
PolicyTrigger::AfterWrite,
PolicyTrigger::BeforeDelete,
PolicyTrigger::AfterDelete,
] {
rules_by_trigger.insert(trigger, Vec::new());
}
Self { rules_by_trigger }
}
pub fn add_rule(&mut self, rule: PolicyRule, handler: Option<PolicyHandler>) {
let regex_str = rule.pattern
.replace(".", "\\.")
.replace("**", ".*")
.replace("*", "[^/]*");
let regex_str = format!("^{}$", regex_str);
let regex = Regex::new(®ex_str).unwrap_or_else(|_| Regex::new("^$").unwrap());
let compiled = CompiledRule {
rule: rule.clone(),
regex,
handler,
};
if let Some(rules) = self.rules_by_trigger.get_mut(&rule.trigger) {
rules.push(compiled);
rules.sort_by(|a, b| b.rule.priority.cmp(&a.rule.priority));
}
}
pub fn evaluate(&self, trigger: PolicyTrigger, ctx: &PolicyContext) -> EvaluationResult {
let mut final_outcome = PolicyOutcome::Allow;
let mut applied_rules = Vec::new();
let mut modifications = Vec::new();
if let Some(rules) = self.rules_by_trigger.get(&trigger) {
for compiled in rules {
if let Some(ref ns) = compiled.rule.namespace {
if let Some(ctx_ns) = ctx.custom.get("namespace") {
if ns != ctx_ns.as_str() {
continue;
}
}
}
let key_str = String::from_utf8_lossy(&ctx.key);
if !compiled.regex.is_match(&key_str) {
continue;
}
let outcome = if let Some(ref handler) = compiled.handler {
let action = handler(ctx);
match &action {
PolicyAction::Modify(data) => {
modifications.push(data.clone());
}
_ => {}
}
PolicyOutcome::from(&action)
} else {
compiled.rule.outcome.clone()
};
applied_rules.push(compiled.rule.id.clone());
final_outcome = final_outcome.join(outcome);
if final_outcome == PolicyOutcome::Deny {
break;
}
}
}
EvaluationResult {
outcome: final_outcome,
applied_rules,
modifications,
}
}
}
impl Default for CompiledPolicySet {
fn default() -> Self {
Self::new()
}
}
/// Aggregate result of `CompiledPolicySet::evaluate`.
#[derive(Debug)]
pub struct EvaluationResult {
    /// Most severe outcome among the rules that were applied.
    pub outcome: PolicyOutcome,
    /// Ids of the applied rules, in evaluation order.
    pub applied_rules: Vec<String>,
    /// Payloads from handlers that answered `Modify`, in evaluation order.
    pub modifications: Vec<Vec<u8>>,
}

impl EvaluationResult {
    /// `true` unless the combined outcome is `Deny`.
    pub fn is_allowed(&self) -> bool {
        self.outcome != PolicyOutcome::Deny
    }

    /// The last recorded modification (later in evaluation order wins), if any.
    pub fn get_modification(&self) -> Option<Vec<u8>> {
        self.modifications.last().cloned()
    }
}
impl<C: ConnectionTrait> PolicyEngine<C> {
pub fn get_denied_ids(
&self,
trigger: PolicyTrigger,
candidate_ids: &[Vec<u8>],
base_ctx: &PolicyContext,
) -> Vec<Vec<u8>> {
let mut denied = Vec::new();
for key in candidate_ids {
let mut ctx = base_ctx.clone();
ctx.key = key.clone();
if let PolicyAction::Deny = self.evaluate_policies(trigger, &ctx) {
denied.push(key.clone());
}
}
denied
}
}