use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::metadata::NodeId;
/// Complete safety configuration applied by a `SafetyEnforcer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyEnvelope {
    /// Identifier for this envelope.
    pub id: String,
    /// Caps on resources reserved through `SafetyEnforcer::acquire`.
    pub resource_limits: ResourceEnvelope,
    /// Per-minute request and token rate limits.
    pub rate_limits: RateEnvelope,
    /// Content rules evaluated by `SafetyEnforcer::check`, in order.
    pub content_policies: Vec<ContentPolicy>,
    /// Audit-log filtering and retention settings.
    pub audit: AuditConfig,
    /// Kill-switch settings.
    pub kill_switch: KillSwitchConfig,
    /// How strictly violations are acted upon.
    pub mode: EnforcementMode,
}

impl Default for SafetyEnvelope {
    /// Envelope with id `"default"`, default sub-configurations, no content
    /// policies, and `EnforcementMode::Enforce`.
    fn default() -> Self {
        let id = String::from("default");
        SafetyEnvelope {
            id,
            mode: EnforcementMode::Enforce,
            content_policies: Vec::new(),
            resource_limits: Default::default(),
            rate_limits: Default::default(),
            audit: Default::default(),
            kill_switch: Default::default(),
        }
    }
}
/// Upper bounds on resources that may be reserved at the same time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceEnvelope {
    /// Maximum concurrent slots held across all guards.
    pub max_concurrent: u32,
    /// Maximum tokens a single claim may request.
    pub max_tokens_per_request: u32,
    /// Maximum memory (GB) held across all guards.
    pub max_memory_gb: u32,
    /// Declared per-request time limit in milliseconds.
    /// NOTE(review): not enforced by the code in this module.
    pub max_time_ms: u32,
    /// Cap on the hourly cost counter, in cents.
    pub max_cost_per_hour_cents: u32,
}

impl Default for ResourceEnvelope {
    fn default() -> Self {
        ResourceEnvelope {
            max_concurrent: 1_000,
            max_tokens_per_request: 128_000,
            max_memory_gb: 16,
            // Five minutes.
            max_time_ms: 5 * 60 * 1_000,
            // 10_000 cents per hour.
            max_cost_per_hour_cents: 100 * 100,
        }
    }
}
/// Request and token rate limits, evaluated per one-minute window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateEnvelope {
    /// Requests per minute across all sources.
    pub global_rpm: u32,
    /// Requests per minute for each individual source node.
    pub per_source_rpm: u32,
    /// Tokens per minute across all sources.
    pub tokens_per_minute: u64,
    /// Factor applied to each limit above to obtain the effective ceiling.
    pub burst_multiplier: f32,
}

impl Default for RateEnvelope {
    fn default() -> Self {
        let burst_multiplier = 2.0;
        RateEnvelope {
            burst_multiplier,
            tokens_per_minute: 10_000_000,
            per_source_rpm: 1_000,
            global_rpm: 10_000,
        }
    }
}
/// A single content rule evaluated against each request by the enforcer's
/// content-policy check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentPolicy {
    /// Identifier reported in `SafetyViolation::ContentPolicyViolation`.
    pub id: String,
    /// The condition to evaluate.
    pub check: ContentCheck,
    /// What to do when the condition fails.
    pub action: PolicyAction,
    /// Disabled policies are skipped entirely.
    pub enabled: bool,
}
/// Condition evaluated by a `ContentPolicy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContentCheck {
    /// Fails if the request content contains any of these substrings
    /// (plain substring match; requests without content pass).
    BlockPatterns(Vec<String>),
    /// Fails if the request content lacks any of these substrings
    /// (requests without content pass).
    RequirePatterns(Vec<String>),
    /// Fails if the request's `content_size` exceeds this many bytes.
    MaxSize(usize),
    /// Delegates to an external validator identified by `validator_id`.
    /// NOTE(review): the enforcer does not evaluate this yet; it always passes.
    Custom {
        validator_id: String,
    },
}
/// Response to a failed `ContentPolicy` check.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PolicyAction {
    /// Reject the request (only when the envelope mode is `Enforce`).
    Block,
    /// Record an audit warning and let the request through.
    Warn,
    /// Record an audit warning and let the request through (currently
    /// handled exactly like `Warn`).
    Log,
    /// Intended to redact offending content. NOTE(review): currently only
    /// records an audit warning; the content itself is not modified.
    Redact,
}
/// Controls which audit events are retained and how many are kept.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditConfig {
    /// Master switch; when `false`, nothing is logged.
    pub enabled: bool,
    /// Keep entries whose outcome is `Success`.
    pub log_success: bool,
    /// Keep entries whose outcome is `Blocked`.
    pub log_blocked: bool,
    /// Keep entries whose outcome is `Warning`.
    pub log_warnings: bool,
    /// Capacity of the in-memory ring buffer.
    pub max_entries: usize,
    /// Flush interval in milliseconds.
    /// NOTE(review): not read by the code in this module.
    pub flush_interval_ms: u64,
}

impl Default for AuditConfig {
    /// Enabled; records blocked requests and warnings but not successes;
    /// keeps 10 000 entries.
    fn default() -> Self {
        let (log_blocked, log_warnings, log_success) = (true, true, false);
        AuditConfig {
            enabled: true,
            log_success,
            log_blocked,
            log_warnings,
            max_entries: 10 * 1_000,
            flush_interval_ms: 5 * 1_000,
        }
    }
}
/// Kill-switch settings carried in a `SafetyEnvelope`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct KillSwitchConfig {
    /// When `true`, an enforcer constructed from this envelope starts with
    /// the kill switch engaged.
    pub enabled: bool,
    /// Optional human-readable reason for the configured kill switch.
    pub reason: Option<String>,
    /// If set, an engaged kill switch is cleared on the next check once this
    /// many seconds have passed since it was engaged.
    pub auto_reset_secs: Option<u32>,
}
/// How strictly a `SafetyEnvelope` is applied.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum EnforcementMode {
    /// Violations reject the request. This is the default.
    #[default]
    Enforce,
    /// Checks run, but rate, resource and content violations do not reject
    /// the request. The kill switch still applies.
    AuditOnly,
    /// Checking and acquisition succeed without evaluating anything,
    /// including the kill switch.
    Disabled,
}
/// Reason a request was refused by the safety layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SafetyViolation {
    /// The kill switch is engaged.
    KillSwitchActive {
        reason: String,
    },
    /// A per-minute rate limit would be exceeded.
    RateLimitExceeded {
        limit_type: RateLimitType,
        current: u64,
        limit: u64,
    },
    /// A resource cap would be exceeded.
    ResourceLimitExceeded {
        resource: ResourceType,
        requested: u64,
        available: u64,
    },
    /// A content policy with a blocking action failed.
    ContentPolicyViolation {
        policy_id: String,
        details: String,
    },
    /// An operation ran past its time limit.
    Timeout {
        elapsed_ms: u64,
        limit_ms: u64,
    },
}

impl std::fmt::Display for SafetyViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use SafetyViolation as V;
        match self {
            V::KillSwitchActive { reason: why } => write!(f, "kill switch active: {}", why),
            V::RateLimitExceeded {
                limit_type: kind,
                current: seen,
                limit: cap,
            } => write!(f, "rate limit exceeded: {:?} ({}/{})", kind, seen, cap),
            V::ResourceLimitExceeded {
                resource: what,
                requested: wanted,
                available: left,
            } => write!(
                f,
                "resource limit exceeded: {:?} (requested {}, available {})",
                what, wanted, left
            ),
            V::ContentPolicyViolation {
                policy_id: id,
                details: info,
            } => write!(f, "content policy violation [{}]: {}", id, info),
            V::Timeout {
                elapsed_ms: took,
                limit_ms: budget,
            } => write!(f, "timeout: {}ms elapsed, limit {}ms", took, budget),
        }
    }
}

impl std::error::Error for SafetyViolation {}

/// Which rate limit was hit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitType {
    GlobalRpm,
    PerSourceRpm,
    TokensPerMinute,
}

/// Which reserved resource ran out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResourceType {
    Concurrent,
    Tokens,
    Memory,
    Time,
    Cost,
}
/// Amounts of each resource a caller wants to reserve.
///
/// Built with `ResourceClaim::new()` followed by the `with_*` setters; unset
/// amounts are zero.
#[derive(Debug, Clone, Default)]
pub struct ResourceClaim {
    pub concurrent_slots: u32,
    pub tokens: u32,
    pub memory_gb: u32,
    pub time_ms: u32,
    pub cost_cents: u32,
}

impl ResourceClaim {
    /// An empty claim with every amount set to zero.
    pub fn new() -> Self {
        ResourceClaim {
            concurrent_slots: 0,
            tokens: 0,
            memory_gb: 0,
            time_ms: 0,
            cost_cents: 0,
        }
    }

    /// Sets the number of concurrent slots.
    pub fn with_concurrent(self, slots: u32) -> Self {
        Self { concurrent_slots: slots, ..self }
    }

    /// Sets the token count.
    pub fn with_tokens(self, tokens: u32) -> Self {
        Self { tokens, ..self }
    }

    /// Sets the memory amount in GB.
    pub fn with_memory_gb(self, gb: u32) -> Self {
        Self { memory_gb: gb, ..self }
    }

    /// Sets the time budget in milliseconds.
    pub fn with_time_ms(self, ms: u32) -> Self {
        Self { time_ms: ms, ..self }
    }

    /// Sets the cost in cents.
    pub fn with_cost_cents(self, cents: u32) -> Self {
        Self { cost_cents: cents, ..self }
    }
}
/// Guard returned by `SafetyEnforcer::acquire` that holds a reserved
/// `ResourceClaim`.
///
/// Dropping the guard hands the claimed amounts back to the enforcer.
pub struct ResourceGuard {
    enforcer: Arc<SafetyEnforcer>,
    claim: ResourceClaim,
    acquired_at: Instant,
}

impl ResourceGuard {
    /// Time elapsed since the guard was created.
    pub fn elapsed(&self) -> Duration {
        Instant::elapsed(&self.acquired_at)
    }

    /// The amounts currently held by this guard.
    pub fn claim(&self) -> &ResourceClaim {
        &self.claim
    }

    /// Replaces the reserved token estimate with `actual_tokens`.
    ///
    /// The enforcer's in-flight token counter moves by the difference and
    /// never goes below zero. The release on drop then subtracts the
    /// corrected amount.
    pub fn update_tokens(&mut self, actual_tokens: u32) {
        let reserved = self.claim.tokens;
        let in_flight = &self.enforcer.usage.tokens;
        if actual_tokens > reserved {
            let extra = u64::from(actual_tokens - reserved);
            in_flight.fetch_add(extra, Ordering::Relaxed);
        } else if actual_tokens < reserved {
            let surplus = u64::from(reserved - actual_tokens);
            let _ = in_flight.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |now| {
                Some(now.saturating_sub(surplus))
            });
        }
        self.claim.tokens = actual_tokens;
    }
}

impl Drop for ResourceGuard {
    fn drop(&mut self) {
        let held = &self.claim;
        self.enforcer.release(held);
    }
}
/// Lock-free per-window request counter for a single source.
///
/// A single `AtomicU64` packs two values: the window index ("floor") in the
/// high 32 bits and the count within that window in the low 32 bits. Both
/// are therefore read and updated by one compare-and-swap.
struct RateBucket {
    packed: AtomicU64,
}

impl RateBucket {
    const FLOOR_SHIFT: u64 = 32;
    const COUNT_MASK: u64 = 0xFFFF_FFFF;

    /// Creates an empty bucket positioned at window `initial_floor`.
    fn new(initial_floor: u32) -> Self {
        Self {
            packed: AtomicU64::new((initial_floor as u64) << Self::FLOOR_SHIFT),
        }
    }

    /// Splits a packed word into `(floor, count)`.
    #[inline]
    fn split(packed: u64) -> (u32, u32) {
        let floor = (packed >> Self::FLOOR_SHIFT) as u32;
        let count = (packed & Self::COUNT_MASK) as u32;
        (floor, count)
    }

    /// Packs `(floor, count)` into one word.
    #[inline]
    fn pack(floor: u32, count: u32) -> u64 {
        ((floor as u64) << Self::FLOOR_SHIFT) | (count as u64)
    }

    /// Resolves the window a caller at `current_floor` should count against.
    ///
    /// Returns `(floor, live_count)`:
    /// - If the stored window is older than the caller's, the result is the
    ///   caller's window with a count of 0 (lazy rollover).
    /// - If the stored window is the same or newer, the result is the stored
    ///   window and its count. A caller that computed its floor before
    ///   another thread rolled the bucket forward is stale; it must not
    ///   rewind the bucket and discard the newer window's count.
    #[inline]
    fn resolve(packed: u64, current_floor: u32) -> (u32, u32) {
        let (stored_floor, stored_count) = Self::split(packed);
        if stored_floor >= current_floor {
            (stored_floor, stored_count)
        } else {
            (current_floor, 0)
        }
    }

    /// Attempts to count one request in the active window.
    ///
    /// Returns `Ok(new_count)` if the request fits under `effective_limit`.
    /// Otherwise returns `Err(observed_count)` and leaves the bucket unchanged.
    /// The limit is applied on every path, including a freshly rolled-over
    /// window, so `effective_limit == 0` admits nothing.
    fn try_acquire(&self, current_floor: u32, effective_limit: u64) -> Result<u32, u32> {
        // Count seen by the latest closure run. `fetch_update` may retry;
        // only its final attempt decides the outcome.
        let mut observed = 0u32;
        let outcome = self
            .packed
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |packed| {
                let (floor, live) = Self::resolve(packed, current_floor);
                observed = live;
                if u64::from(live) >= effective_limit {
                    None
                } else {
                    Some(Self::pack(floor, live.saturating_add(1)))
                }
            });
        match outcome {
            Ok(_) => Ok(observed.saturating_add(1)),
            Err(_) => Err(observed),
        }
    }

    /// Number of requests counted in the active window for a caller at
    /// `current_floor`. A stored window older than `current_floor` counts as 0.
    fn current_count(&self, current_floor: u32) -> u32 {
        Self::resolve(self.packed.load(Ordering::Acquire), current_floor).1
    }

    /// Undoes one `try_acquire` made in `current_floor`.
    ///
    /// Does nothing if the bucket has since moved to a different window or
    /// the count is already zero.
    fn rollback(&self, current_floor: u32) {
        let _ = self
            .packed
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let (cur_floor, cur_count) = Self::split(current);
                if cur_floor == current_floor && cur_count > 0 {
                    Some(Self::pack(cur_floor, cur_count - 1))
                } else {
                    None
                }
            });
    }

    /// Window index last written to the bucket (used for garbage collection).
    fn floor(&self) -> u32 {
        Self::split(self.packed.load(Ordering::Relaxed)).0
    }
}
/// Fixed-window limiter for global requests, global tokens and per-source
/// requests.
///
/// `maybe_reset` zeroes the global counters once `reset_interval` has elapsed
/// since `last_reset`. Per-source counts live in `RateBucket`s tagged with a
/// window index derived from `created_at`, so they roll over lazily on their
/// own.
struct RateLimiter {
    /// Requests admitted in the current global window.
    global_requests: AtomicU64,
    /// Tokens admitted in the current global window.
    global_tokens: AtomicU64,
    /// Per-source request buckets.
    per_source: DashMap<NodeId, RateBucket>,
    /// When the global counters were last zeroed.
    last_reset: RwLock<Instant>,
    /// Epoch for per-source window indices.
    created_at: Instant,
    /// Length of one rate window.
    reset_interval: Duration,
}

impl RateLimiter {
    /// Creates a limiter with empty counters and a 60-second window.
    fn new() -> Self {
        let now = Instant::now();
        Self {
            global_requests: AtomicU64::new(0),
            global_tokens: AtomicU64::new(0),
            per_source: DashMap::new(),
            last_reset: RwLock::new(now),
            created_at: now,
            reset_interval: Duration::from_secs(60),
        }
    }

    /// Scales a per-window `limit` by the burst multiplier.
    ///
    /// The product is computed in `f64` for every limit kind. Previously the
    /// request limits used `f32`, whose 24-bit mantissa rounds large limits.
    /// The final `as` cast saturates: a negative or NaN burst yields 0.
    #[inline]
    fn effective_limit(limit: u64, burst: f32) -> u64 {
        (limit as f64 * f64::from(burst)) as u64
    }

    /// Index of the current window, counted from `created_at`.
    #[inline]
    fn current_floor(&self) -> u32 {
        let secs = self.created_at.elapsed().as_secs();
        let interval_secs = self.reset_interval.as_secs().max(1);
        u32::try_from(secs / interval_secs).unwrap_or(u32::MAX)
    }

    /// Zeroes the global counters and prunes stale per-source buckets once
    /// `reset_interval` has passed.
    ///
    /// The write-lock path re-checks elapsed time, so only one of several
    /// racing callers performs the reset.
    fn maybe_reset(&self) {
        let should_reset = {
            let last = self.last_reset.read();
            last.elapsed() >= self.reset_interval
        };
        if should_reset {
            let mut last = self.last_reset.write();
            if last.elapsed() >= self.reset_interval {
                self.global_requests.store(0, Ordering::Relaxed);
                self.global_tokens.store(0, Ordering::Relaxed);
                self.gc_per_source_stale();
                *last = Instant::now();
            }
        }
    }

    /// Drops per-source buckets whose window is more than five windows old.
    fn gc_per_source_stale(&self) {
        let cur = self.current_floor();
        const GC_AGE_WINDOWS: u32 = 5;
        let cutoff = cur.saturating_sub(GC_AGE_WINDOWS);
        self.per_source.retain(|_, bucket| bucket.floor() >= cutoff);
    }

    /// Read-only check: fails if the global request count has reached the
    /// effective limit. Consumes nothing.
    fn check_global_rpm(&self, limit: u32, burst: f32) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let current = self.global_requests.load(Ordering::Relaxed);
        let effective_limit = Self::effective_limit(u64::from(limit), burst);
        if current >= effective_limit {
            return Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::GlobalRpm,
                current,
                limit: effective_limit,
            });
        }
        Ok(())
    }

    /// Read-only check: fails if `source` has reached its per-window limit.
    /// Consumes nothing.
    fn check_source_rpm(
        &self,
        source: &NodeId,
        limit: u32,
        burst: f32,
    ) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let cur_floor = self.current_floor();
        let effective_limit = Self::effective_limit(u64::from(limit), burst);
        let current = self
            .per_source
            .get(source)
            .map_or(0, |bucket| u64::from(bucket.current_count(cur_floor)));
        if current >= effective_limit {
            return Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::PerSourceRpm,
                current,
                limit: effective_limit,
            });
        }
        Ok(())
    }

    /// Read-only check: fails if adding `tokens` would exceed the effective
    /// token limit, or would overflow the counter. Consumes nothing.
    fn check_tokens(&self, tokens: u64, limit: u64, burst: f32) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let current = self.global_tokens.load(Ordering::Relaxed);
        let effective_limit = Self::effective_limit(limit, burst);
        let would_be = match current.checked_add(tokens) {
            Some(sum) => sum,
            None => {
                return Err(SafetyViolation::RateLimitExceeded {
                    limit_type: RateLimitType::TokensPerMinute,
                    current: u64::MAX,
                    limit: effective_limit,
                });
            }
        };
        if would_be > effective_limit {
            return Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::TokensPerMinute,
                current: would_be,
                limit: effective_limit,
            });
        }
        Ok(())
    }

    /// Unconditionally records one request (and `tokens`) without limit checks.
    // Kept for callers that account usage outside the acquire path.
    #[allow(dead_code)]
    fn record_request(&self, source: Option<&NodeId>, tokens: u64) {
        self.global_requests.fetch_add(1, Ordering::Relaxed);
        // Saturate rather than wrap, matching the unenforced path in `acquire`.
        let _ = self
            .global_tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                Some(v.saturating_add(tokens))
            });
        if let Some(src) = source {
            let cur_floor = self.current_floor();
            let bucket = self
                .per_source
                .entry(*src)
                .or_insert_with(|| RateBucket::new(cur_floor));
            let _ = bucket.try_acquire(cur_floor, u64::MAX);
        }
    }

    /// Atomically admits one request if the global count is under the limit.
    fn try_acquire_global_rpm(&self, limit: u32, burst: f32) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let effective_limit = Self::effective_limit(u64::from(limit), burst);
        match self
            .global_requests
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                if current >= effective_limit {
                    None
                } else {
                    Some(current + 1)
                }
            }) {
            Ok(_) => Ok(()),
            Err(current) => Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::GlobalRpm,
                current,
                limit: effective_limit,
            }),
        }
    }

    /// Atomically admits one request for `source` if its window has room.
    fn try_acquire_source_rpm(
        &self,
        source: &NodeId,
        limit: u32,
        burst: f32,
    ) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let cur_floor = self.current_floor();
        let bucket = self
            .per_source
            .entry(*source)
            .or_insert_with(|| RateBucket::new(cur_floor));
        let effective_limit = Self::effective_limit(u64::from(limit), burst);
        match bucket.try_acquire(cur_floor, effective_limit) {
            Ok(_) => Ok(()),
            Err(current) => Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::PerSourceRpm,
                current: u64::from(current),
                limit: effective_limit,
            }),
        }
    }

    /// Atomically admits `tokens` if the window total stays within the limit.
    fn try_acquire_tokens(
        &self,
        tokens: u64,
        limit: u64,
        burst: f32,
    ) -> Result<(), SafetyViolation> {
        self.maybe_reset();
        let effective_limit = Self::effective_limit(limit, burst);
        match self
            .global_tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let next = current.checked_add(tokens)?;
                if next > effective_limit {
                    None
                } else {
                    Some(next)
                }
            }) {
            Ok(_) => Ok(()),
            Err(current) => Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::TokensPerMinute,
                current,
                limit: effective_limit,
            }),
        }
    }

    /// Undoes one `try_acquire_global_rpm`.
    ///
    /// Saturates at zero. `maybe_reset` on another thread may already have
    /// zeroed the counter. A plain `fetch_sub` would then wrap to `u64::MAX`
    /// and block every request until the next window.
    fn rollback_global_rpm(&self) {
        let _ = self
            .global_requests
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(1))
            });
    }

    /// Undoes one `try_acquire_source_rpm` made in the current window.
    fn rollback_source_rpm(&self, source: &NodeId) {
        if let Some(bucket) = self.per_source.get(source) {
            bucket.rollback(self.current_floor());
        }
    }

    /// Undoes a `try_acquire_tokens`; saturates at zero for the same reason
    /// as `rollback_global_rpm`.
    #[allow(dead_code)]
    fn rollback_tokens(&self, tokens: u64) {
        let _ = self
            .global_tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(tokens))
            });
    }
}
/// Lock-free counters for resources currently reserved through the enforcer,
/// plus an hourly window for the cost counter.
struct AtomicResourceUsage {
    concurrent: AtomicU32,
    tokens: AtomicU64,
    memory_gb: AtomicU32,
    cost_cents_per_hour: AtomicU32,
    /// Start of the current cost window.
    hour_start: RwLock<Instant>,
}

impl AtomicResourceUsage {
    /// Length of the cost accounting window (one hour).
    const COST_WINDOW: Duration = Duration::from_secs(3600);

    /// All counters start at zero; the cost window starts now.
    fn new() -> Self {
        let started = Instant::now();
        AtomicResourceUsage {
            concurrent: AtomicU32::new(0),
            tokens: AtomicU64::new(0),
            memory_gb: AtomicU32::new(0),
            cost_cents_per_hour: AtomicU32::new(0),
            hour_start: RwLock::new(started),
        }
    }

    /// Zeroes `cost_cents_per_hour` once the cost window has elapsed.
    ///
    /// A shared read lock is used for the common "not yet" case. The
    /// condition is re-checked under the write lock, so only one of several
    /// racing callers resets.
    fn maybe_reset_hourly(&self) {
        let window_over = self.hour_start.read().elapsed() >= Self::COST_WINDOW;
        if !window_over {
            return;
        }
        let mut start = self.hour_start.write();
        if start.elapsed() < Self::COST_WINDOW {
            return;
        }
        self.cost_cents_per_hour.store(0, Ordering::Relaxed);
        *start = Instant::now();
    }
}
/// Point-in-time snapshot of the enforcer's counters.
///
/// Each field is loaded separately, so the snapshot is not atomic across
/// fields.
#[derive(Debug, Clone, Default)]
pub struct UsageStats {
    /// Concurrent slots currently reserved.
    pub concurrent: u32,
    /// Tokens currently reserved by live guards.
    pub tokens: u64,
    /// Memory (GB) currently reserved.
    pub memory_gb: u32,
    /// Value of the hourly cost counter, in cents.
    pub cost_cents_per_hour: u32,
    /// Requests counted in the current global rate window.
    pub requests_per_minute: u64,
    /// Tokens counted in the current global rate window.
    pub tokens_per_minute: u64,
}
/// One record in the audit trail.
#[derive(Debug, Clone, Serialize)]
pub struct AuditEntry {
    /// Wall-clock time in nanoseconds since the UNIX epoch (0 if the clock
    /// is before the epoch).
    pub timestamp_ns: u64,
    /// What happened.
    pub event_type: AuditEventType,
    /// Node the triggering request came from, if known.
    pub source_node: Option<NodeId>,
    /// Correlation id of the triggering request, if recorded.
    pub request_id: Option<u128>,
    /// Free-form key/value context (e.g. a kill-switch reason).
    pub details: HashMap<String, String>,
    /// Result classification; drives audit filtering.
    pub outcome: AuditOutcome,
}
/// Kind of event recorded in an `AuditEntry`.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
pub enum AuditEventType {
    /// A request arrived.
    RequestReceived,
    /// A request passed its checks.
    RequestAllowed,
    /// A request was rejected.
    RequestBlocked,
    /// A rate limit was reached.
    RateLimitHit,
    /// A resource cap was reached.
    ResourceLimitHit,
    /// A content policy check failed.
    ContentPolicyViolation,
    /// The kill switch was engaged.
    KillSwitchTriggered,
    /// The kill switch was cleared.
    KillSwitchReset,
    /// The active envelope was replaced.
    EnvelopeUpdated,
}
/// Result classification of an audit event.
///
/// The audit log keeps `Success`, `Blocked` and `Warning` entries only when
/// the matching `AuditConfig` flag is set; `Error` entries are always kept
/// while auditing is enabled.
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
pub enum AuditOutcome {
    Success,
    Blocked,
    Warning,
    Error,
}
/// External destination for audit entries (e.g. a file or log shipper).
///
/// Must be `Send + Sync` because it is shared by every thread that uses the
/// enforcer.
pub trait AuditSink: Send + Sync {
    /// Receives each entry that passed the audit filter, before it is
    /// buffered in memory.
    fn write(&self, entry: &AuditEntry);
    /// Requests that buffered output be persisted.
    fn flush(&self);
}
/// Bounded in-memory audit trail with an optional external sink.
struct AuditLog {
    /// Most recent entries, oldest at the front; at most `config.max_entries`.
    entries: RwLock<VecDeque<AuditEntry>>,
    /// Filtering and retention settings captured at construction.
    config: AuditConfig,
    /// Receives every entry that passes filtering.
    /// NOTE(review): nothing in this module sets it.
    sink: Option<Box<dyn AuditSink>>,
}

impl AuditLog {
    /// Upper bound on up-front buffer allocation.
    const MAX_PREALLOCATED_ENTRIES: usize = 1024;

    /// Creates an empty log governed by `config`.
    fn new(config: AuditConfig) -> Self {
        // `max_entries` can come from deserialized configuration. Preallocating
        // it verbatim can panic on capacity overflow or reserve memory that is
        // never used. The deque grows on demand up to `max_entries` instead.
        let initial_capacity = config.max_entries.min(Self::MAX_PREALLOCATED_ENTRIES);
        Self {
            entries: RwLock::new(VecDeque::with_capacity(initial_capacity)),
            config,
            sink: None,
        }
    }

    /// Records `entry` if auditing is enabled and its outcome is selected by
    /// the config.
    ///
    /// The entry goes to the sink (if any). It is then appended to the
    /// in-memory buffer, evicting the oldest entries to stay within
    /// `max_entries`. With `max_entries == 0` nothing is retained in memory.
    fn log(&self, entry: AuditEntry) {
        if !self.config.enabled {
            return;
        }
        let should_log = match entry.outcome {
            AuditOutcome::Success => self.config.log_success,
            AuditOutcome::Blocked => self.config.log_blocked,
            AuditOutcome::Warning => self.config.log_warnings,
            AuditOutcome::Error => true,
        };
        if !should_log {
            return;
        }
        if let Some(sink) = &self.sink {
            sink.write(&entry);
        }
        let capacity = self.config.max_entries;
        // Previously a zero capacity still kept one entry (pop on empty, then push).
        if capacity == 0 {
            return;
        }
        let mut entries = self.entries.write();
        while entries.len() >= capacity {
            entries.pop_front();
        }
        entries.push_back(entry);
    }

    /// Returns up to `limit` entries, newest first.
    fn get_entries(&self, limit: usize) -> Vec<AuditEntry> {
        let entries = self.entries.read();
        entries.iter().rev().take(limit).cloned().collect()
    }

    /// Discards all buffered entries.
    fn clear(&self) {
        self.entries.write().clear();
    }
}
/// Description of an incoming request as seen by the safety checks.
///
/// Built with `SafetyRequest::new()` followed by the `with_*` setters.
#[derive(Debug, Clone, Default)]
pub struct SafetyRequest {
    /// Originating node; used for per-source rate limits and audit entries.
    pub source_node: Option<NodeId>,
    /// Caller-supplied correlation id.
    pub request_id: Option<u128>,
    /// Request text checked by pattern policies.
    pub content: Option<String>,
    /// Size in bytes checked by `MaxSize` policies.
    pub content_size: usize,
    /// Estimated token count checked against the token rate limit.
    pub estimated_tokens: u32,
    /// Arbitrary caller metadata.
    pub metadata: HashMap<String, String>,
}

impl SafetyRequest {
    /// An empty request.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the originating node.
    pub fn with_source(self, node: NodeId) -> Self {
        Self {
            source_node: Some(node),
            ..self
        }
    }

    /// Sets the correlation id.
    pub fn with_request_id(self, id: u128) -> Self {
        Self {
            request_id: Some(id),
            ..self
        }
    }

    /// Sets the content and sets `content_size` to its UTF-8 byte length.
    pub fn with_content(self, content: impl Into<String>) -> Self {
        let text: String = content.into();
        Self {
            content_size: text.len(),
            content: Some(text),
            ..self
        }
    }

    /// Overrides `content_size` without touching `content`.
    pub fn with_content_size(self, size: usize) -> Self {
        Self {
            content_size: size,
            ..self
        }
    }

    /// Sets the estimated token count.
    pub fn with_tokens(self, tokens: u32) -> Self {
        Self {
            estimated_tokens: tokens,
            ..self
        }
    }
}
/// Gatekeeper that applies a `SafetyEnvelope` to requests.
///
/// All mutable state is behind atomics or locks, so methods take `&self`
/// (or `&Arc<Self>` for `acquire`, whose guard keeps the enforcer alive).
pub struct SafetyEnforcer {
    /// Active configuration; replaced by `update_envelope`.
    envelope: RwLock<SafetyEnvelope>,
    /// Currently reserved resources.
    usage: AtomicResourceUsage,
    /// Per-window request and token counters.
    rate_limiter: RateLimiter,
    /// Filtered audit trail.
    audit_log: AuditLog,
    /// Whether the kill switch is engaged.
    kill_switch: AtomicBool,
    /// Time the kill switch was engaged; drives `auto_reset_secs`.
    kill_switch_at: RwLock<Option<Instant>>,
    /// Reason reported in `SafetyViolation::KillSwitchActive`.
    kill_switch_reason: RwLock<Option<String>>,
    // Reserved for precompiled content-policy regexes. Nothing in this module
    // reads it yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    compiled_patterns: RwLock<Vec<(String, regex::Regex)>>,
}
impl SafetyEnforcer {
pub fn new() -> Self {
Self::with_envelope(SafetyEnvelope::default())
}
pub fn with_envelope(envelope: SafetyEnvelope) -> Self {
let audit_log = AuditLog::new(envelope.audit.clone());
let kill_switch = envelope.kill_switch.enabled;
Self {
envelope: RwLock::new(envelope),
usage: AtomicResourceUsage::new(),
rate_limiter: RateLimiter::new(),
audit_log,
kill_switch: AtomicBool::new(kill_switch),
kill_switch_at: RwLock::new(None),
kill_switch_reason: RwLock::new(None),
compiled_patterns: RwLock::new(Vec::new()),
}
}
pub fn update_envelope(&self, envelope: SafetyEnvelope) {
*self.envelope.write() = envelope;
self.log_event(AuditEventType::EnvelopeUpdated, None, AuditOutcome::Success);
}
pub fn check(&self, req: &SafetyRequest) -> Result<(), SafetyViolation> {
let envelope = self.envelope.read();
if envelope.mode == EnforcementMode::Disabled {
return Ok(());
}
self.check_kill_switch()?;
self.check_rate_limits(req, &envelope)?;
self.check_content_policies(req, &envelope)?;
if envelope.mode == EnforcementMode::AuditOnly {
self.log_event(
AuditEventType::RequestAllowed,
req.source_node,
AuditOutcome::Success,
);
}
Ok(())
}
pub fn acquire(
self: &Arc<Self>,
req: &SafetyRequest,
claim: ResourceClaim,
) -> Result<ResourceGuard, SafetyViolation> {
let envelope = self.envelope.read();
if envelope.mode == EnforcementMode::Disabled {
return Ok(ResourceGuard {
enforcer: Arc::clone(self),
claim,
acquired_at: Instant::now(),
});
}
self.check_kill_switch()?;
let limits = &envelope.resource_limits;
let enforce = envelope.mode == EnforcementMode::Enforce;
if enforce && claim.tokens > limits.max_tokens_per_request {
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Tokens,
requested: claim.tokens as u64,
available: limits.max_tokens_per_request as u64,
});
}
self.usage.maybe_reset_hourly();
fn try_fetch_add_capped_u32(
counter: &std::sync::atomic::AtomicU32,
add: u32,
max: u32,
) -> Result<(), u32> {
counter
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
let next = current.saturating_add(add);
if next > max {
None
} else {
Some(next)
}
})
.map(|_| ())
}
if enforce {
if let Err(cur) = try_fetch_add_capped_u32(
&self.usage.concurrent,
claim.concurrent_slots,
limits.max_concurrent,
) {
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Concurrent,
requested: claim.concurrent_slots as u64,
available: limits.max_concurrent.saturating_sub(cur) as u64,
});
}
} else {
self.usage
.concurrent
.fetch_add(claim.concurrent_slots, Ordering::Relaxed);
}
if enforce {
if let Err(cur) = try_fetch_add_capped_u32(
&self.usage.memory_gb,
claim.memory_gb,
limits.max_memory_gb,
) {
self.usage
.concurrent
.fetch_sub(claim.concurrent_slots, Ordering::Relaxed);
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Memory,
requested: claim.memory_gb as u64,
available: limits.max_memory_gb.saturating_sub(cur) as u64,
});
}
} else {
self.usage
.memory_gb
.fetch_add(claim.memory_gb, Ordering::Relaxed);
}
if enforce {
if let Err(cur) = try_fetch_add_capped_u32(
&self.usage.cost_cents_per_hour,
claim.cost_cents,
limits.max_cost_per_hour_cents,
) {
self.usage
.concurrent
.fetch_sub(claim.concurrent_slots, Ordering::Relaxed);
self.usage
.memory_gb
.fetch_sub(claim.memory_gb, Ordering::Relaxed);
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Cost,
requested: claim.cost_cents as u64,
available: limits.max_cost_per_hour_cents.saturating_sub(cur) as u64,
});
}
} else {
self.usage
.cost_cents_per_hour
.fetch_add(claim.cost_cents, Ordering::Relaxed);
}
let rate = &envelope.rate_limits;
let rate_burst = rate.burst_multiplier;
if enforce {
if let Err(e) = self
.rate_limiter
.try_acquire_global_rpm(rate.global_rpm, rate_burst)
{
self.usage
.concurrent
.fetch_sub(claim.concurrent_slots, Ordering::Relaxed);
self.usage
.memory_gb
.fetch_sub(claim.memory_gb, Ordering::Relaxed);
self.usage
.cost_cents_per_hour
.fetch_sub(claim.cost_cents, Ordering::Relaxed);
self.log_event(
AuditEventType::RateLimitHit,
req.source_node,
AuditOutcome::Blocked,
);
return Err(e);
}
}
if enforce {
if let Some(ref source) = req.source_node {
if let Err(e) = self.rate_limiter.try_acquire_source_rpm(
source,
rate.per_source_rpm,
rate_burst,
) {
self.rate_limiter.rollback_global_rpm();
self.usage
.concurrent
.fetch_sub(claim.concurrent_slots, Ordering::Relaxed);
self.usage
.memory_gb
.fetch_sub(claim.memory_gb, Ordering::Relaxed);
self.usage
.cost_cents_per_hour
.fetch_sub(claim.cost_cents, Ordering::Relaxed);
self.log_event(
AuditEventType::RateLimitHit,
req.source_node,
AuditOutcome::Blocked,
);
return Err(e);
}
}
}
if enforce {
if let Err(e) = self.rate_limiter.try_acquire_tokens(
claim.tokens as u64,
rate.tokens_per_minute,
rate_burst,
) {
if let Some(ref source) = req.source_node {
self.rate_limiter.rollback_source_rpm(source);
}
self.rate_limiter.rollback_global_rpm();
self.usage
.concurrent
.fetch_sub(claim.concurrent_slots, Ordering::Relaxed);
self.usage
.memory_gb
.fetch_sub(claim.memory_gb, Ordering::Relaxed);
self.usage
.cost_cents_per_hour
.fetch_sub(claim.cost_cents, Ordering::Relaxed);
self.log_event(
AuditEventType::RateLimitHit,
req.source_node,
AuditOutcome::Blocked,
);
return Err(e);
}
} else {
let _ = self.rate_limiter.global_tokens.fetch_update(
Ordering::Relaxed,
Ordering::Relaxed,
|v| Some(v.saturating_add(claim.tokens as u64)),
);
}
self.usage
.tokens
.fetch_add(claim.tokens as u64, Ordering::Relaxed);
if !enforce {
self.rate_limiter
.global_requests
.fetch_add(1, Ordering::Relaxed);
if let Some(ref source) = req.source_node {
let cur_floor = self.rate_limiter.current_floor();
let bucket = self
.rate_limiter
.per_source
.entry(*source)
.or_insert_with(|| RateBucket::new(cur_floor));
let _ = bucket.try_acquire(cur_floor, u64::MAX);
}
}
self.log_event(
AuditEventType::RequestAllowed,
req.source_node,
AuditOutcome::Success,
);
Ok(ResourceGuard {
enforcer: Arc::clone(self),
claim,
acquired_at: Instant::now(),
})
}
fn release(&self, claim: &ResourceClaim) {
let _ =
self.usage
.concurrent
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
Some(current.saturating_sub(claim.concurrent_slots))
});
let _ = self
.usage
.memory_gb
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
Some(current.saturating_sub(claim.memory_gb))
});
let _ = self
.usage
.tokens
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
Some(current.saturating_sub(claim.tokens as u64))
});
let _ = self.usage.cost_cents_per_hour.fetch_update(
Ordering::AcqRel,
Ordering::Acquire,
|current| Some(current.saturating_sub(claim.cost_cents)),
);
}
pub fn kill(&self, reason: impl Into<String>) {
let reason = reason.into();
self.kill_switch.store(true, Ordering::SeqCst);
*self.kill_switch_at.write() = Some(Instant::now());
*self.kill_switch_reason.write() = Some(reason.clone());
self.log_event_with_details(
AuditEventType::KillSwitchTriggered,
None,
AuditOutcome::Success,
[("reason".to_string(), reason)].into_iter().collect(),
);
}
pub fn reset(&self) {
self.kill_switch.store(false, Ordering::SeqCst);
*self.kill_switch_at.write() = None;
*self.kill_switch_reason.write() = None;
self.log_event(AuditEventType::KillSwitchReset, None, AuditOutcome::Success);
}
pub fn is_killed(&self) -> bool {
self.kill_switch.load(Ordering::Relaxed)
}
pub fn usage(&self) -> UsageStats {
UsageStats {
concurrent: self.usage.concurrent.load(Ordering::Relaxed),
tokens: self.usage.tokens.load(Ordering::Relaxed),
memory_gb: self.usage.memory_gb.load(Ordering::Relaxed),
cost_cents_per_hour: self.usage.cost_cents_per_hour.load(Ordering::Relaxed),
requests_per_minute: self.rate_limiter.global_requests.load(Ordering::Relaxed),
tokens_per_minute: self.rate_limiter.global_tokens.load(Ordering::Relaxed),
}
}
pub fn audit_entries(&self, limit: usize) -> Vec<AuditEntry> {
self.audit_log.get_entries(limit)
}
pub fn clear_audit(&self) {
self.audit_log.clear();
}
pub fn envelope(&self) -> SafetyEnvelope {
self.envelope.read().clone()
}
fn check_kill_switch(&self) -> Result<(), SafetyViolation> {
if !self.kill_switch.load(Ordering::Relaxed) {
return Ok(());
}
let envelope = self.envelope.read();
if let Some(auto_reset_secs) = envelope.kill_switch.auto_reset_secs {
if let Some(killed_at) = *self.kill_switch_at.read() {
if killed_at.elapsed() >= Duration::from_secs(auto_reset_secs as u64) {
drop(envelope);
self.reset();
return Ok(());
}
}
}
let reason = self
.kill_switch_reason
.read()
.clone()
.unwrap_or_else(|| "kill switch active".to_string());
Err(SafetyViolation::KillSwitchActive { reason })
}
fn check_rate_limits(
&self,
req: &SafetyRequest,
envelope: &SafetyEnvelope,
) -> Result<(), SafetyViolation> {
let limits = &envelope.rate_limits;
let burst = limits.burst_multiplier;
let outcome = match envelope.mode {
EnforcementMode::Enforce => AuditOutcome::Blocked,
EnforcementMode::AuditOnly => AuditOutcome::Warning,
EnforcementMode::Disabled => AuditOutcome::Warning,
};
if let Err(e) = self.rate_limiter.check_global_rpm(limits.global_rpm, burst) {
self.log_event(AuditEventType::RateLimitHit, req.source_node, outcome);
if envelope.mode == EnforcementMode::Enforce {
return Err(e);
}
}
if let Some(ref source) = req.source_node {
if let Err(e) = self
.rate_limiter
.check_source_rpm(source, limits.per_source_rpm, burst)
{
self.log_event(AuditEventType::RateLimitHit, req.source_node, outcome);
if envelope.mode == EnforcementMode::Enforce {
return Err(e);
}
}
}
if let Err(e) = self.rate_limiter.check_tokens(
req.estimated_tokens as u64,
limits.tokens_per_minute,
burst,
) {
self.log_event(AuditEventType::RateLimitHit, req.source_node, outcome);
if envelope.mode == EnforcementMode::Enforce {
return Err(e);
}
}
Ok(())
}
#[allow(dead_code)]
fn check_resource_limits(
&self,
claim: &ResourceClaim,
envelope: &SafetyEnvelope,
) -> Result<(), SafetyViolation> {
let limits = &envelope.resource_limits;
let current_concurrent = self.usage.concurrent.load(Ordering::Relaxed);
if current_concurrent.saturating_add(claim.concurrent_slots) > limits.max_concurrent
&& envelope.mode == EnforcementMode::Enforce
{
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Concurrent,
requested: claim.concurrent_slots as u64,
available: limits.max_concurrent.saturating_sub(current_concurrent) as u64,
});
}
if claim.tokens > limits.max_tokens_per_request && envelope.mode == EnforcementMode::Enforce
{
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Tokens,
requested: claim.tokens as u64,
available: limits.max_tokens_per_request as u64,
});
}
let current_memory = self.usage.memory_gb.load(Ordering::Relaxed);
if current_memory.saturating_add(claim.memory_gb) > limits.max_memory_gb
&& envelope.mode == EnforcementMode::Enforce
{
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Memory,
requested: claim.memory_gb as u64,
available: limits.max_memory_gb.saturating_sub(current_memory) as u64,
});
}
self.usage.maybe_reset_hourly();
let current_cost = self.usage.cost_cents_per_hour.load(Ordering::Relaxed);
if current_cost.saturating_add(claim.cost_cents) > limits.max_cost_per_hour_cents
&& envelope.mode == EnforcementMode::Enforce
{
return Err(SafetyViolation::ResourceLimitExceeded {
resource: ResourceType::Cost,
requested: claim.cost_cents as u64,
available: limits.max_cost_per_hour_cents.saturating_sub(current_cost) as u64,
});
}
Ok(())
}
fn check_content_policies(
&self,
req: &SafetyRequest,
envelope: &SafetyEnvelope,
) -> Result<(), SafetyViolation> {
for policy in &envelope.content_policies {
if !policy.enabled {
continue;
}
if let Err(violation) = self.check_policy(req, policy) {
match policy.action {
PolicyAction::Block => {
if envelope.mode == EnforcementMode::Enforce {
self.log_event(
AuditEventType::ContentPolicyViolation,
req.source_node,
AuditOutcome::Blocked,
);
return Err(violation);
}
}
PolicyAction::Warn => {
self.log_event(
AuditEventType::ContentPolicyViolation,
req.source_node,
AuditOutcome::Warning,
);
}
PolicyAction::Log => {
self.log_event(
AuditEventType::ContentPolicyViolation,
req.source_node,
AuditOutcome::Warning,
);
}
PolicyAction::Redact => {
self.log_event(
AuditEventType::ContentPolicyViolation,
req.source_node,
AuditOutcome::Warning,
);
}
}
}
}
Ok(())
}
fn check_policy(
&self,
req: &SafetyRequest,
policy: &ContentPolicy,
) -> Result<(), SafetyViolation> {
match &policy.check {
ContentCheck::MaxSize(max_size) => {
if req.content_size > *max_size {
return Err(SafetyViolation::ContentPolicyViolation {
policy_id: policy.id.clone(),
details: format!(
"content size {} exceeds max {}",
req.content_size, max_size
),
});
}
}
ContentCheck::BlockPatterns(patterns) => {
if let Some(ref content) = req.content {
for pattern in patterns {
if content.contains(pattern) {
return Err(SafetyViolation::ContentPolicyViolation {
policy_id: policy.id.clone(),
details: format!("blocked pattern found: {}", pattern),
});
}
}
}
}
ContentCheck::RequirePatterns(patterns) => {
if let Some(ref content) = req.content {
for pattern in patterns {
if !content.contains(pattern) {
return Err(SafetyViolation::ContentPolicyViolation {
policy_id: policy.id.clone(),
details: format!("required pattern not found: {}", pattern),
});
}
}
}
}
ContentCheck::Custom { validator_id } => {
let _ = validator_id;
}
}
Ok(())
}
fn log_event(
&self,
event_type: AuditEventType,
source_node: Option<NodeId>,
outcome: AuditOutcome,
) {
self.log_event_with_details(event_type, source_node, outcome, HashMap::new());
}
fn log_event_with_details(
&self,
event_type: AuditEventType,
source_node: Option<NodeId>,
outcome: AuditOutcome,
details: HashMap<String, String>,
) {
let entry = AuditEntry {
timestamp_ns: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0),
event_type,
source_node,
request_id: None,
details,
outcome,
};
self.audit_log.log(entry);
}
}
impl Default for SafetyEnforcer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic 32-byte `NodeId` whose first byte is `n` and whose
    /// remaining bytes are zero, so distinct `n` values give distinct sources.
    fn make_node_id(n: u8) -> NodeId {
        let mut id = [0u8; 32];
        id[0] = n;
        id
    }

    /// The default envelope is in `Enforce` mode and carries the stock limits from
    /// `ResourceEnvelope::default` / `RateEnvelope::default`.
    #[test]
    fn test_default_envelope() {
        let envelope = SafetyEnvelope::default();
        assert_eq!(envelope.mode, EnforcementMode::Enforce);
        assert_eq!(envelope.resource_limits.max_concurrent, 1000);
        assert_eq!(envelope.rate_limits.global_rpm, 10_000);
    }

    /// A small request passes `check` under the default envelope.
    #[test]
    fn test_safety_enforcer_check_passes() {
        let enforcer = SafetyEnforcer::new();
        let req = SafetyRequest::new().with_tokens(100);
        let result = enforcer.check(&req);
        assert!(result.is_ok());
    }

    /// `kill` makes `check` fail with `KillSwitchActive`; `reset` clears the switch
    /// and lets the same request through again.
    #[test]
    fn test_kill_switch() {
        let enforcer = SafetyEnforcer::new();
        let req = SafetyRequest::new();
        assert!(enforcer.check(&req).is_ok());
        assert!(!enforcer.is_killed());

        enforcer.kill("test kill");
        assert!(enforcer.is_killed());
        let result = enforcer.check(&req);
        assert!(matches!(
            result,
            Err(SafetyViolation::KillSwitchActive { .. })
        ));

        enforcer.reset();
        assert!(!enforcer.is_killed());
        assert!(enforcer.check(&req).is_ok());
    }

    /// Acquiring a claim raises the concurrent and token usage counters, and
    /// dropping the guard gives the concurrent slot back.
    #[test]
    fn test_resource_acquisition() {
        let enforcer = Arc::new(SafetyEnforcer::new());
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new().with_concurrent(1).with_tokens(100);

        let guard = enforcer.acquire(&req, claim).unwrap();
        assert_eq!(enforcer.usage().concurrent, 1);
        assert_eq!(enforcer.usage().tokens, 100);

        drop(guard);
        assert_eq!(enforcer.usage().concurrent, 0);
    }

    /// With `max_concurrent = 2`, a third acquire is rejected with
    /// `ResourceLimitExceeded { resource: Concurrent }` while the first two guards
    /// are still held.
    #[test]
    fn test_concurrent_limit() {
        let envelope = SafetyEnvelope {
            resource_limits: ResourceEnvelope {
                max_concurrent: 2,
                ..Default::default()
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new().with_concurrent(1);

        let _guard1 = enforcer.acquire(&req, claim.clone()).unwrap();
        let _guard2 = enforcer.acquire(&req, claim.clone()).unwrap();
        let result = enforcer.acquire(&req, claim);
        assert!(matches!(
            result,
            Err(SafetyViolation::ResourceLimitExceeded {
                resource: ResourceType::Concurrent,
                ..
            })
        ));
    }

    /// A blocking `MaxSize(100)` policy admits a content size of 50 and rejects 200.
    #[test]
    fn test_content_policy_max_size() {
        let envelope = SafetyEnvelope {
            content_policies: vec![ContentPolicy {
                id: "max-size".to_string(),
                check: ContentCheck::MaxSize(100),
                action: PolicyAction::Block,
                enabled: true,
            }],
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);

        let req = SafetyRequest::new().with_content_size(50);
        assert!(enforcer.check(&req).is_ok());

        let req = SafetyRequest::new().with_content_size(200);
        assert!(matches!(
            enforcer.check(&req),
            Err(SafetyViolation::ContentPolicyViolation { .. })
        ));
    }

    /// A blocking `BlockPatterns` policy admits content without the pattern and
    /// rejects content that contains it.
    #[test]
    fn test_content_policy_block_patterns() {
        let envelope = SafetyEnvelope {
            content_policies: vec![ContentPolicy {
                id: "block-bad".to_string(),
                check: ContentCheck::BlockPatterns(vec!["bad_word".to_string()]),
                action: PolicyAction::Block,
                enabled: true,
            }],
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);

        let req = SafetyRequest::new().with_content("hello world");
        assert!(enforcer.check(&req).is_ok());

        let req = SafetyRequest::new().with_content("this has a bad_word in it");
        assert!(matches!(
            enforcer.check(&req),
            Err(SafetyViolation::ContentPolicyViolation { .. })
        ));
    }

    /// In `AuditOnly` mode a violated blocking content policy does not make `check` fail.
    #[test]
    fn test_audit_only_mode() {
        let envelope = SafetyEnvelope {
            mode: EnforcementMode::AuditOnly,
            content_policies: vec![ContentPolicy {
                id: "max-size".to_string(),
                check: ContentCheck::MaxSize(100),
                action: PolicyAction::Block,
                enabled: true,
            }],
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);
        let req = SafetyRequest::new().with_content_size(200);
        assert!(enforcer.check(&req).is_ok());
    }

    /// In `Disabled` mode even an engaged kill switch does not make `check` fail.
    #[test]
    fn test_disabled_mode() {
        let envelope = SafetyEnvelope {
            mode: EnforcementMode::Disabled,
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);
        enforcer.kill("test");
        let req = SafetyRequest::new();
        assert!(enforcer.check(&req).is_ok());
    }

    /// `usage()` reports the concurrent, token and memory amounts of a held claim.
    #[test]
    fn test_usage_stats() {
        let enforcer = Arc::new(SafetyEnforcer::new());
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new()
            .with_concurrent(5)
            .with_tokens(1000)
            .with_memory_gb(8);

        let _guard = enforcer.acquire(&req, claim).unwrap();
        let stats = enforcer.usage();
        assert_eq!(stats.concurrent, 5);
        assert_eq!(stats.tokens, 1000);
        assert_eq!(stats.memory_gb, 8);
    }

    /// With auditing enabled (including `log_success`), an acquire/release cycle
    /// leaves at least one audit entry behind.
    #[test]
    fn test_audit_entries() {
        let envelope = SafetyEnvelope {
            audit: AuditConfig {
                enabled: true,
                log_success: true,
                log_blocked: true,
                log_warnings: true,
                max_entries: 100,
                flush_interval_ms: 5000,
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new().with_concurrent(1);

        // The guard is used (explicitly dropped), so it gets a plain binding name
        // rather than an "unused" underscore prefix.
        let guard = enforcer.acquire(&req, claim).unwrap();
        drop(guard);

        let entries = enforcer.audit_entries(10);
        assert!(!entries.is_empty());
    }

    /// Once the source's per-minute budget (1) has been recorded, `check` fails with
    /// `RateLimitExceeded { limit_type: PerSourceRpm }`.
    #[test]
    fn test_rate_limiting() {
        let envelope = SafetyEnvelope {
            rate_limits: RateEnvelope {
                global_rpm: 2,
                per_source_rpm: 1,
                tokens_per_minute: 1000,
                burst_multiplier: 1.0,
            },
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);
        let source = make_node_id(1);
        let req = SafetyRequest::new().with_source(source).with_tokens(100);

        assert!(enforcer.check(&req).is_ok());
        // The request is recorded explicitly before re-checking.
        enforcer.rate_limiter.record_request(Some(&source), 100);

        let result = enforcer.check(&req);
        assert!(matches!(
            result,
            Err(SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::PerSourceRpm,
                ..
            })
        ));
    }

    /// In `AuditOnly` mode a rate-limit violation passes `check` but is still
    /// audited as a `RateLimitHit` with a `Warning` outcome.
    #[test]
    fn audit_only_mode_logs_rate_limit_violations_as_warnings() {
        let envelope = SafetyEnvelope {
            mode: EnforcementMode::AuditOnly,
            rate_limits: RateEnvelope {
                global_rpm: 1,
                per_source_rpm: 1,
                tokens_per_minute: 1000,
                burst_multiplier: 1.0,
            },
            audit: AuditConfig {
                enabled: true,
                log_success: false,
                log_blocked: true,
                log_warnings: true,
                max_entries: 100,
                flush_interval_ms: 5000,
            },
            ..Default::default()
        };
        let enforcer = SafetyEnforcer::with_envelope(envelope);
        let source = make_node_id(7);
        let req = SafetyRequest::new().with_source(source).with_tokens(100);

        assert!(enforcer.check(&req).is_ok());
        enforcer.rate_limiter.record_request(Some(&source), 100);
        assert!(
            enforcer.check(&req).is_ok(),
            "AuditOnly must not block the request"
        );

        let entries = enforcer.audit_entries(100);
        let hits: Vec<_> = entries
            .iter()
            .filter(|e| e.event_type == AuditEventType::RateLimitHit)
            .collect();
        assert!(
            !hits.is_empty(),
            "AuditOnly mode must emit a RateLimitHit audit entry on violation; \
             pre-fix the entry was suppressed because logging was gated on \
             Enforce mode. Entries: {:?}",
            entries,
        );
        assert!(
            hits.iter().all(|e| e.outcome == AuditOutcome::Warning),
            "AuditOnly violations must be logged with Warning outcome \
             (Blocked is reserved for the Enforce path that actually \
             returns Err). Outcomes: {:?}",
            hits.iter().map(|e| e.outcome).collect::<Vec<_>>(),
        );
    }

    /// Releasing a guard obtained in `Disabled` mode leaves the concurrent and
    /// memory counters at 0 (not wrapped). A later `Enforce`-mode acquire on the
    /// same enforcer still succeeds.
    #[test]
    fn release_does_not_underflow_concurrent_or_memory_in_disabled_mode() {
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(SafetyEnvelope {
            mode: EnforcementMode::Disabled,
            ..Default::default()
        }));
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new().with_concurrent(5).with_memory_gb(100);
        let guard = enforcer.acquire(&req, claim).unwrap();
        drop(guard);

        let stats = enforcer.usage();
        assert_eq!(
            stats.concurrent, 0,
            "concurrent must stay clamped at 0 when releasing in \
             Disabled mode (pre-fix this wrapped to u32::MAX-4)"
        );
        assert_eq!(
            stats.memory_gb, 0,
            "memory_gb must stay clamped at 0 when releasing in \
             Disabled mode (pre-fix this wrapped to u32::MAX-99)"
        );

        let new_envelope = SafetyEnvelope {
            mode: EnforcementMode::Enforce,
            ..Default::default()
        };
        enforcer.update_envelope(new_envelope);
        let req2 = SafetyRequest::new();
        let claim2 = ResourceClaim::new().with_concurrent(1);
        let guard2 = enforcer
            .acquire(&req2, claim2)
            .expect("Enforce-mode acquire after a Disabled-mode release must succeed");
        drop(guard2);
    }

    /// Dropping a guard returns the token and cost counters to zero.
    #[test]
    fn test_regression_release_decrements_tokens_and_cost() {
        let enforcer = Arc::new(SafetyEnforcer::new());
        let source = make_node_id(1);
        let req = SafetyRequest::new().with_source(source).with_tokens(500);
        let claim = ResourceClaim {
            tokens: 500,
            concurrent_slots: 1,
            memory_gb: 8,
            time_ms: 0,
            cost_cents: 50,
        };

        let guard = enforcer.acquire(&req, claim).unwrap();
        assert!(enforcer.usage.tokens.load(Ordering::Relaxed) >= 500);
        assert!(enforcer.usage.cost_cents_per_hour.load(Ordering::Relaxed) >= 50);

        drop(guard);
        assert_eq!(
            enforcer.usage.tokens.load(Ordering::Relaxed),
            0,
            "tokens should be released on drop"
        );
        assert_eq!(
            enforcer.usage.cost_cents_per_hour.load(Ordering::Relaxed),
            0,
            "cost should be released on drop"
        );
    }

    /// Lowering a held guard's token count with `update_tokens` does not wrap the
    /// shared token counter, and the counter is back at 0 after the drop.
    #[test]
    fn test_regression_update_tokens_no_underflow() {
        let enforcer = Arc::new(SafetyEnforcer::new());
        let source = make_node_id(1);
        let req = SafetyRequest::new().with_source(source).with_tokens(100);
        let claim = ResourceClaim {
            tokens: 100,
            concurrent_slots: 1,
            memory_gb: 10,
            time_ms: 0,
            cost_cents: 0,
        };

        let mut guard = enforcer.acquire(&req, claim).unwrap();
        guard.update_tokens(30);
        let tokens = enforcer.usage.tokens.load(Ordering::Relaxed);
        assert!(
            tokens < u64::MAX / 2,
            "token counter should not have underflowed (got {})",
            tokens
        );

        drop(guard);
        let final_tokens = enforcer.usage.tokens.load(Ordering::Relaxed);
        assert_eq!(
            final_tokens, 0,
            "tokens should be 0 after release, not underflowed"
        );
    }

    /// `check_tokens` rejects with `TokensPerMinute` instead of wrapping when the
    /// global token counter is near `u64::MAX`.
    #[test]
    fn test_regression_check_tokens_overflow_is_rejected() {
        let limiter = RateLimiter::new();
        limiter
            .global_tokens
            .store(u64::MAX - 10, Ordering::Relaxed);

        let result = limiter.check_tokens(100, 1_000_000, 1.0);
        assert!(
            matches!(
                result,
                Err(SafetyViolation::RateLimitExceeded {
                    limit_type: RateLimitType::TokensPerMinute,
                    ..
                })
            ),
            "overflow must be rejected, got {:?}",
            result
        );
    }

    /// 100 threads released together by a barrier race to acquire one slot each
    /// against a concurrency cap of 5. No more than 5 may succeed, and the shared
    /// counter must not exceed the cap while the winners hold their guards.
    #[test]
    fn acquire_concurrent_cap_is_atomic_under_contention() {
        use std::sync::Barrier;
        use std::thread;

        const CAP: u32 = 5;
        const ATTEMPTS: usize = 100;

        let limits = ResourceEnvelope {
            max_concurrent: CAP,
            max_tokens_per_request: 1_000_000,
            max_memory_gb: 1_000_000,
            max_time_ms: 1_000_000,
            max_cost_per_hour_cents: u32::MAX,
        };
        let envelope = SafetyEnvelope {
            mode: EnforcementMode::Enforce,
            resource_limits: limits,
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let barrier = Arc::new(Barrier::new(ATTEMPTS));

        let handles: Vec<_> = (0..ATTEMPTS)
            .map(|_| {
                let enf = Arc::clone(&enforcer);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    let req = SafetyRequest::new();
                    let claim = ResourceClaim {
                        concurrent_slots: 1,
                        tokens: 1,
                        memory_gb: 0,
                        time_ms: 0,
                        cost_cents: 0,
                    };
                    enf.acquire(&req, claim)
                })
            })
            .collect();

        // Keep the successful guards alive so the counter check below observes
        // all committed slots.
        let successes: Vec<_> = handles
            .into_iter()
            .map(|h| h.join().expect("acquire thread panicked"))
            .filter_map(Result::ok)
            .collect();

        assert!(
            successes.len() as u32 <= CAP,
            "TOCTOU regression (#8): {} concurrent acquires committed against \
             cap of {}",
            successes.len(),
            CAP
        );
        assert!(
            enforcer.usage.concurrent.load(Ordering::Relaxed) <= CAP,
            "concurrent counter exceeds cap"
        );
    }

    /// 100 threads released together by a barrier race against a global RPM cap
    /// of 5. No more than 5 acquires may commit, and the global request counter
    /// must not exceed the cap.
    #[test]
    fn acquire_global_rpm_cap_is_atomic_under_contention() {
        use std::sync::Barrier;
        use std::thread;

        const RPM_CAP: u32 = 5;
        const ATTEMPTS: usize = 100;

        let envelope = SafetyEnvelope {
            mode: EnforcementMode::Enforce,
            resource_limits: ResourceEnvelope {
                max_concurrent: u32::MAX,
                max_tokens_per_request: 1_000_000,
                max_memory_gb: u32::MAX,
                max_time_ms: u32::MAX,
                max_cost_per_hour_cents: u32::MAX,
            },
            rate_limits: RateEnvelope {
                global_rpm: RPM_CAP,
                per_source_rpm: u32::MAX,
                tokens_per_minute: u64::MAX,
                burst_multiplier: 1.0,
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let barrier = Arc::new(Barrier::new(ATTEMPTS));

        let handles: Vec<_> = (0..ATTEMPTS)
            .map(|_| {
                let enf = Arc::clone(&enforcer);
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    b.wait();
                    let req = SafetyRequest::new();
                    let claim = ResourceClaim {
                        concurrent_slots: 1,
                        tokens: 1,
                        memory_gb: 0,
                        time_ms: 0,
                        cost_cents: 0,
                    };
                    enf.acquire(&req, claim)
                })
            })
            .collect();

        let successes: Vec<_> = handles
            .into_iter()
            .map(|h| h.join().expect("acquire thread panicked"))
            .filter_map(Result::ok)
            .collect();

        assert!(
            successes.len() as u32 <= RPM_CAP,
            "RPM TOCTOU regression (#8): {} acquires committed against cap {}",
            successes.len(),
            RPM_CAP,
        );
        assert!(
            enforcer
                .rate_limiter
                .global_requests
                .load(Ordering::Relaxed)
                <= RPM_CAP as u64,
            "global_requests counter exceeds RPM cap",
        );
    }

    /// A `RateBucket` counts acquires within a window and rejects past the cap
    /// with `Err(cap)`. When a new window index is used it starts again from zero.
    /// After moving to window 1 it reports 0 for window 0.
    #[test]
    fn rate_bucket_self_resets_on_window_rollover() {
        let bucket = RateBucket::new(0);
        assert_eq!(bucket.current_count(0), 0);
        assert!(matches!(bucket.try_acquire(0, 5), Ok(1)));
        assert_eq!(bucket.current_count(0), 1);
        for _ in 0..4 {
            assert!(bucket.try_acquire(0, 5).is_ok());
        }
        assert_eq!(bucket.current_count(0), 5);
        assert!(matches!(bucket.try_acquire(0, 5), Err(5)));

        assert_eq!(bucket.current_count(1), 0);
        assert!(matches!(bucket.try_acquire(1, 5), Ok(1)));
        assert_eq!(bucket.current_count(1), 1);
        assert_eq!(bucket.current_count(0), 0);
    }

    /// 16 blocking tasks each make 100 per-source acquire attempts for one source.
    /// With a per-source cap of 50, total successes must never exceed 50.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn per_source_cap_respected_under_contention_no_clear_race() {
        const PER_SOURCE_CAP: u32 = 50;
        const N_THREADS: u32 = 16;
        const N_PER_THREAD: u32 = 100;

        let mut envelope = SafetyEnvelope::default();
        envelope.rate_limits.per_source_rpm = PER_SOURCE_CAP;
        envelope.rate_limits.global_rpm = u32::MAX;
        envelope.rate_limits.tokens_per_minute = u64::MAX;
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let source: NodeId = [0xAA; 32];

        let mut handles = Vec::new();
        let success_count = Arc::new(AtomicU64::new(0));
        for _ in 0..N_THREADS {
            let e = enforcer.clone();
            let sc = success_count.clone();
            handles.push(tokio::task::spawn_blocking(move || {
                for _ in 0..N_PER_THREAD {
                    if e
                        .rate_limiter
                        .try_acquire_source_rpm(&source, PER_SOURCE_CAP, 1.0)
                        .is_ok()
                    {
                        sc.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(Ordering::Relaxed);
        assert!(
            total <= PER_SOURCE_CAP as u64,
            "per-source cap regression (#125): {} acquires committed against cap {} \
             (no global clear race should let any over-commit happen — bucket \
             self-resets via packed atomic CAS)",
            total,
            PER_SOURCE_CAP,
        );
    }

    /// A claim that fails the memory limit leaves no concurrent or memory usage behind.
    #[test]
    fn memory_limit_failure_rolls_back_concurrent() {
        let envelope = SafetyEnvelope {
            resource_limits: ResourceEnvelope {
                max_concurrent: 100,
                max_memory_gb: 1,
                ..Default::default()
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new().with_concurrent(1).with_memory_gb(2);

        // `match` instead of `unwrap_err`, which would require the guard type to be `Debug`.
        let err = match enforcer.acquire(&req, claim) {
            Err(e) => e,
            Ok(_) => panic!("expected memory limit failure"),
        };
        assert!(matches!(
            err,
            SafetyViolation::ResourceLimitExceeded {
                resource: ResourceType::Memory,
                ..
            }
        ));
        assert_eq!(enforcer.usage().concurrent, 0);
        assert_eq!(enforcer.usage().memory_gb, 0);
    }

    /// A claim that fails the cost limit leaves no concurrent, memory or cost usage behind.
    #[test]
    fn cost_limit_failure_rolls_back_concurrent_and_memory() {
        let envelope = SafetyEnvelope {
            resource_limits: ResourceEnvelope {
                max_concurrent: 100,
                max_memory_gb: 100,
                max_cost_per_hour_cents: 10,
                ..Default::default()
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let req = SafetyRequest::new();
        let claim = ResourceClaim::new()
            .with_concurrent(1)
            .with_memory_gb(1)
            .with_cost_cents(100);

        let err = match enforcer.acquire(&req, claim) {
            Err(e) => e,
            Ok(_) => panic!("expected cost limit failure"),
        };
        assert!(matches!(
            err,
            SafetyViolation::ResourceLimitExceeded {
                resource: ResourceType::Cost,
                ..
            }
        ));
        assert_eq!(enforcer.usage().concurrent, 0);
        assert_eq!(enforcer.usage().memory_gb, 0);
        assert_eq!(enforcer.usage().cost_cents_per_hour, 0);
    }

    /// A second acquire for the same source fails on `PerSourceRpm`. Afterwards,
    /// usage reflects only the first (still held) guard, so the failed attempt's
    /// global request and resource reservations were rolled back.
    #[test]
    fn per_source_rpm_failure_rolls_back_global_and_resources() {
        let envelope = SafetyEnvelope {
            rate_limits: RateEnvelope {
                global_rpm: 100,
                per_source_rpm: 1,
                tokens_per_minute: 1_000_000,
                burst_multiplier: 1.0,
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let source = make_node_id(11);
        let req = SafetyRequest::new().with_source(source).with_tokens(10);
        let claim = ResourceClaim::new().with_concurrent(1).with_memory_gb(1);

        let _guard = enforcer.acquire(&req, claim.clone()).unwrap();
        let err = match enforcer.acquire(&req, claim) {
            Err(e) => e,
            Ok(_) => panic!("expected per-source RPM failure"),
        };
        assert!(matches!(
            err,
            SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::PerSourceRpm,
                ..
            }
        ));

        let usage = enforcer.usage();
        assert_eq!(usage.concurrent, 1);
        assert_eq!(usage.memory_gb, 1);
        assert_eq!(usage.requests_per_minute, 1);
    }

    /// A claim that fails the tokens-per-minute limit leaves no concurrent,
    /// memory, cost or global request usage behind.
    #[test]
    fn tokens_failure_rolls_back_source_and_global_and_resources() {
        let envelope = SafetyEnvelope {
            rate_limits: RateEnvelope {
                global_rpm: 100,
                per_source_rpm: 100,
                tokens_per_minute: 10,
                burst_multiplier: 1.0,
            },
            ..Default::default()
        };
        let enforcer = Arc::new(SafetyEnforcer::with_envelope(envelope));
        let source = make_node_id(22);
        let req = SafetyRequest::new().with_source(source);
        let claim = ResourceClaim::new()
            .with_concurrent(1)
            .with_memory_gb(1)
            .with_tokens(100);

        let err = match enforcer.acquire(&req, claim) {
            Err(e) => e,
            Ok(_) => panic!("expected tokens-per-minute failure"),
        };
        assert!(matches!(
            err,
            SafetyViolation::RateLimitExceeded {
                limit_type: RateLimitType::TokensPerMinute,
                ..
            }
        ));

        let usage = enforcer.usage();
        assert_eq!(usage.concurrent, 0);
        assert_eq!(usage.memory_gb, 0);
        assert_eq!(usage.cost_cents_per_hour, 0);
        assert_eq!(usage.requests_per_minute, 0);
    }

    /// Builds an envelope with one enabled `BlockPatterns(["bad"])` policy using
    /// `action`, plus auditing that records blocked and warning events.
    fn content_policy_envelope(action: PolicyAction) -> SafetyEnvelope {
        SafetyEnvelope {
            content_policies: vec![ContentPolicy {
                id: "warn-on-bad".into(),
                check: ContentCheck::BlockPatterns(vec!["bad".into()]),
                action,
                enabled: true,
            }],
            audit: AuditConfig {
                enabled: true,
                log_success: false,
                log_blocked: true,
                log_warnings: true,
                max_entries: 100,
                flush_interval_ms: 5000,
            },
            ..Default::default()
        }
    }

    /// Counts entries returned by `audit_entries(100)` that are
    /// `ContentPolicyViolation` events with a `Warning` outcome.
    fn content_policy_warning_count(enforcer: &SafetyEnforcer) -> usize {
        enforcer
            .audit_entries(100)
            .into_iter()
            .filter(|e| {
                e.event_type == AuditEventType::ContentPolicyViolation
                    && e.outcome == AuditOutcome::Warning
            })
            .count()
    }

    /// A `Warn` policy match passes `check` and is audited as a warning.
    #[test]
    fn content_policy_warn_logs_warning_without_blocking() {
        let enforcer = SafetyEnforcer::with_envelope(content_policy_envelope(PolicyAction::Warn));
        let req = SafetyRequest::new().with_content("this is bad");
        assert!(enforcer.check(&req).is_ok(), "Warn must not block");
        assert!(
            content_policy_warning_count(&enforcer) > 0,
            "Warn action must log a Warning-outcome audit entry",
        );
    }

    /// A `Log` policy match passes `check` and is audited as a warning.
    #[test]
    fn content_policy_log_logs_warning_without_blocking() {
        let enforcer = SafetyEnforcer::with_envelope(content_policy_envelope(PolicyAction::Log));
        let req = SafetyRequest::new().with_content("this is bad");
        assert!(enforcer.check(&req).is_ok(), "Log must not block");
        assert!(content_policy_warning_count(&enforcer) > 0);
    }

    /// A `Redact` policy match passes `check` and is audited as a warning.
    #[test]
    fn content_policy_redact_logs_warning_without_blocking() {
        let enforcer =
            SafetyEnforcer::with_envelope(content_policy_envelope(PolicyAction::Redact));
        let req = SafetyRequest::new().with_content("this is bad");
        assert!(enforcer.check(&req).is_ok(), "Redact must not block");
        assert!(content_policy_warning_count(&enforcer) > 0);
    }
}