use super::error::{ExecutionError, ExecutionErrorCategory};
use super::ids::{ExecutionId, StepId, TenantId};
use crate::context::ResourceLimits;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Live usage counters for a single execution.
///
/// All counters are atomics updated with `SeqCst`; only `last_activity`
/// sits behind a `tokio::sync::RwLock` (see `touch` / `idle_duration`).
#[derive(Debug)]
pub struct ExecutionUsage {
/// Execution this record belongs to.
pub execution_id: ExecutionId,
/// Tenant owning the execution; used for per-tenant concurrency accounting.
pub tenant_id: TenantId,
/// Number of steps recorded via `record_step`.
pub steps: AtomicU32,
/// Accumulated input tokens from `record_tokens`.
pub input_tokens: AtomicU32,
/// Accumulated output tokens from `record_tokens`.
pub output_tokens: AtomicU32,
/// Creation time; `wall_time` is measured from here.
pub started_at: Instant,
/// Time of the most recent `touch`; `idle_duration` is measured from here.
pub last_activity: RwLock<Instant>,
/// Number of steps recorded via `record_discovered_step`.
pub discovered_steps: AtomicU32,
/// Current discovery nesting depth (`push_discovery_depth` / `pop_discovery_depth`).
pub discovery_depth: AtomicU32,
/// Highest value `discovery_depth` has reached.
pub max_discovery_depth_reached: AtomicU32,
/// Accumulated cost in whole US cents (`record_cost_usd` scales dollars by 100).
pub cost_cents: AtomicU64,
}
impl ExecutionUsage {
    /// Creates a zeroed usage record for `execution_id` owned by `tenant_id`.
    ///
    /// `started_at` and `last_activity` share the same instant, so a fresh
    /// record reports zero wall time and zero idle time.
    pub fn new(execution_id: ExecutionId, tenant_id: TenantId) -> Self {
        let now = Instant::now();
        Self {
            execution_id,
            tenant_id,
            steps: AtomicU32::new(0),
            input_tokens: AtomicU32::new(0),
            output_tokens: AtomicU32::new(0),
            started_at: now,
            last_activity: RwLock::new(now),
            discovered_steps: AtomicU32::new(0),
            discovery_depth: AtomicU32::new(0),
            max_discovery_depth_reached: AtomicU32::new(0),
            cost_cents: AtomicU64::new(0),
        }
    }

    /// Increments the executed-step counter.
    pub fn record_step(&self) {
        self.steps.fetch_add(1, Ordering::SeqCst);
    }

    /// Increments the discovered-step counter.
    pub fn record_discovered_step(&self) {
        self.discovered_steps.fetch_add(1, Ordering::SeqCst);
    }

    /// Adds `input` and `output` to the token counters.
    ///
    /// Atomic `fetch_add` wraps on overflow; readers use saturating sums.
    pub fn record_tokens(&self, input: u32, output: u32) {
        self.input_tokens.fetch_add(input, Ordering::SeqCst);
        self.output_tokens.fetch_add(output, Ordering::SeqCst);
    }

    /// Adds `cost_usd` to the accumulated cost, stored in whole cents.
    ///
    /// The scaled amount is rounded to the nearest cent. Truncation
    /// under-counted values whose `* 100.0` lands just below an integer
    /// (e.g. `0.29 * 100.0 == 28.999…` became 28 cents). Negative and NaN
    /// inputs become 0 through Rust's saturating float-to-int cast.
    ///
    /// NOTE(review): amounts below half a cent still round to 0, so many
    /// tiny charges are not accumulated; a finer-grained unit would fix that.
    pub fn record_cost_usd(&self, cost_usd: f64) {
        let cents = (cost_usd * 100.0).round() as u64;
        self.cost_cents.fetch_add(cents, Ordering::SeqCst);
    }

    /// Enters one more level of discovery nesting and updates the high-water mark.
    pub fn push_discovery_depth(&self) {
        let new_depth = self.discovery_depth.fetch_add(1, Ordering::SeqCst) + 1;
        // `fetch_max` makes the high-water-mark update atomic. A separate
        // load + store let a concurrent push overwrite a larger value.
        self.max_discovery_depth_reached
            .fetch_max(new_depth, Ordering::SeqCst);
    }

    /// Leaves one level of discovery nesting; never goes below zero.
    pub fn pop_discovery_depth(&self) {
        // An unbalanced pop used to wrap 0 -> u32::MAX, after which every
        // depth check reported a violation. `checked_sub` returns `None` at 0,
        // so `fetch_update` leaves the value unchanged; that `Err` is expected.
        let _ = self
            .discovery_depth
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |depth| {
                depth.checked_sub(1)
            });
    }

    /// Sets `last_activity` to now, resetting the idle timer.
    pub async fn touch(&self) {
        let mut last = self.last_activity.write().await;
        *last = Instant::now();
    }

    /// Number of recorded steps.
    pub fn step_count(&self) -> u32 {
        self.steps.load(Ordering::SeqCst)
    }

    /// Number of recorded discovered steps.
    pub fn discovered_step_count(&self) -> u32 {
        self.discovered_steps.load(Ordering::SeqCst)
    }

    /// Current discovery nesting depth.
    pub fn current_discovery_depth(&self) -> u32 {
        self.discovery_depth.load(Ordering::SeqCst)
    }

    /// Input plus output tokens, saturating at `u32::MAX`.
    ///
    /// A plain `+` panicked in debug builds once the sum exceeded `u32::MAX`.
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens
            .load(Ordering::SeqCst)
            .saturating_add(self.output_tokens.load(Ordering::SeqCst))
    }

    /// Accumulated cost in US dollars.
    pub fn cost_usd(&self) -> f64 {
        self.cost_cents.load(Ordering::SeqCst) as f64 / 100.0
    }

    /// Time since the record was created.
    pub fn wall_time(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// `wall_time` in whole milliseconds.
    pub fn wall_time_ms(&self) -> u64 {
        self.wall_time().as_millis() as u64
    }

    /// Time since the last `touch` (or creation).
    pub async fn idle_duration(&self) -> Duration {
        let last = self.last_activity.read().await;
        last.elapsed()
    }

    /// `idle_duration` in whole seconds.
    pub async fn idle_seconds(&self) -> u64 {
        self.idle_duration().await.as_secs()
    }
}
/// Serializable point-in-time copy of an `ExecutionUsage`.
///
/// Each counter is loaded separately, so fields are not guaranteed to be
/// mutually consistent if the execution is being updated concurrently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSnapshot {
/// Execution id as a string.
pub execution_id: String,
/// Tenant id as a string.
pub tenant_id: String,
/// Recorded steps.
pub steps: u32,
/// Accumulated input tokens.
pub input_tokens: u32,
/// Accumulated output tokens.
pub output_tokens: u32,
/// `input_tokens + output_tokens`.
pub total_tokens: u32,
/// Elapsed time since the execution was registered, in milliseconds.
pub wall_time_ms: u64,
/// Recorded discovered steps.
pub discovered_steps: u32,
/// Discovery nesting depth at snapshot time.
pub discovery_depth: u32,
/// Highest discovery depth reached so far.
pub max_discovery_depth: u32,
/// Accumulated cost in US dollars.
pub cost_usd: f64,
}
impl From<&ExecutionUsage> for UsageSnapshot {
    /// Copies the current counter values out of `usage`.
    ///
    /// Counters are loaded one at a time, so the snapshot is not atomic
    /// across fields.
    fn from(usage: &ExecutionUsage) -> Self {
        let input = usage.input_tokens.load(Ordering::SeqCst);
        let output = usage.output_tokens.load(Ordering::SeqCst);
        Self {
            execution_id: usage.execution_id.as_str().to_string(),
            tenant_id: usage.tenant_id.as_str().to_string(),
            steps: usage.steps.load(Ordering::SeqCst),
            input_tokens: input,
            output_tokens: output,
            // Saturate: a plain `+` panicked in debug builds on overflow.
            total_tokens: input.saturating_add(output),
            wall_time_ms: usage.wall_time_ms(),
            discovered_steps: usage.discovered_steps.load(Ordering::SeqCst),
            discovery_depth: usage.discovery_depth.load(Ordering::SeqCst),
            max_discovery_depth: usage.max_discovery_depth_reached.load(Ordering::SeqCst),
            cost_usd: usage.cost_usd(),
        }
    }
}
/// Outcome of an enforcement check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnforcementResult {
/// Usage is within limits and below the warning threshold.
Allowed,
/// A limit is exceeded; `is_allowed` returns `false` for this variant.
Blocked(EnforcementViolation),
/// Usage crossed the policy's warning threshold; still counts as allowed.
Warning(EnforcementWarning),
}
impl EnforcementResult {
    /// Returns `true` for `Allowed` and `Warning`; a warning does not block.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Blocked(_) => false,
            Self::Allowed | Self::Warning(_) => true,
        }
    }

    /// Returns `true` only for `Blocked`.
    pub fn is_blocked(&self) -> bool {
        match self {
            Self::Blocked(_) => true,
            Self::Allowed | Self::Warning(_) => false,
        }
    }

    /// Converts a `Blocked` result into its error; other variants yield `None`.
    pub fn to_error(&self) -> Option<ExecutionError> {
        if let Self::Blocked(violation) = self {
            Some(violation.to_error())
        } else {
            None
        }
    }
}
/// Kind of limit a violation or warning refers to.
///
/// `Display` renders each variant as a snake_case code (e.g. `step_limit`).
/// `MemoryLimit`, `RateLimit`, `NetworkViolation` and `SameStepLoop` are not
/// produced by any check in this file — presumably raised elsewhere; verify.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ViolationType {
/// `ResourceLimits::max_steps`.
StepLimit,
/// `ResourceLimits::max_tokens`.
TokenLimit,
/// `ResourceLimits::max_wall_time_ms` (plus the policy's grace period).
WallTimeLimit,
MemoryLimit,
/// `ResourceLimits::max_concurrent_executions` per tenant.
ConcurrencyLimit,
RateLimit,
NetworkViolation,
/// `LongRunningExecutionPolicy::max_discovered_steps`.
DiscoveredStepLimit,
/// `LongRunningExecutionPolicy::max_discovery_depth`.
DiscoveryDepthLimit,
/// `LongRunningExecutionPolicy::cost_alert_threshold_usd`.
CostThreshold,
/// `LongRunningExecutionPolicy::idle_timeout_seconds`.
IdleTimeout,
SameStepLoop,
}
impl std::fmt::Display for ViolationType {
    /// Writes the snake_case code used as the error code for this violation kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::StepLimit => "step_limit",
            Self::TokenLimit => "token_limit",
            Self::WallTimeLimit => "wall_time_limit",
            Self::MemoryLimit => "memory_limit",
            Self::ConcurrencyLimit => "concurrency_limit",
            Self::RateLimit => "rate_limit",
            Self::NetworkViolation => "network_violation",
            Self::DiscoveredStepLimit => "discovered_step_limit",
            Self::DiscoveryDepthLimit => "discovery_depth_limit",
            Self::CostThreshold => "cost_threshold",
            Self::IdleTimeout => "idle_timeout",
            Self::SameStepLoop => "same_step_loop",
        };
        f.write_str(label)
    }
}
/// A limit breach: the observed value, the configured limit and a message.
///
/// Units of `current`/`limit` depend on `violation_type` as produced by the
/// `EnforcementMiddleware` checks: steps, tokens, milliseconds (wall time;
/// `limit` is reported without the grace period), cents (cost), seconds
/// (idle), or execution count (concurrency).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnforcementViolation {
/// Which limit was breached.
pub violation_type: ViolationType,
/// Observed (or projected) usage.
pub current: u64,
/// Configured limit.
pub limit: u64,
/// Human-readable summary built by `EnforcementViolation::new`.
pub message: String,
}
impl EnforcementViolation {
    /// Builds a violation and its human-readable message.
    ///
    /// The message reads `"<type> exceeded: <current> / <limit> (<pct>%)"`.
    /// When `limit` is zero the percentage is omitted. Dividing by zero used
    /// to render nonsense such as `4294967295%` (or `0%` for `0 / 0`).
    pub fn new(violation_type: ViolationType, current: u64, limit: u64) -> Self {
        let message = if limit == 0 {
            format!("{} exceeded: {} / {}", violation_type, current, limit)
        } else {
            format!(
                "{} exceeded: {} / {} ({}%)",
                violation_type,
                current,
                limit,
                (current as f64 / limit as f64 * 100.0) as u32
            )
        };
        Self {
            violation_type,
            current,
            limit,
            message,
        }
    }

    /// Converts the violation into an [`ExecutionError`].
    ///
    /// The error code is the violation type's snake_case name. The details
    /// carry `current`, `limit` and `violation_type`.
    pub fn to_error(&self) -> ExecutionError {
        // Exhaustive on purpose: a new ViolationType must get an explicit
        // category instead of silently falling into QuotaExceeded.
        let category = match self.violation_type {
            ViolationType::WallTimeLimit => ExecutionErrorCategory::Timeout,
            // NOTE(review): rate limits are reported as LlmError, presumably to
            // reuse provider-error handling; confirm this is intended.
            ViolationType::RateLimit => ExecutionErrorCategory::LlmError,
            ViolationType::NetworkViolation => ExecutionErrorCategory::PolicyViolation,
            ViolationType::StepLimit
            | ViolationType::TokenLimit
            | ViolationType::MemoryLimit
            | ViolationType::ConcurrencyLimit
            | ViolationType::DiscoveredStepLimit
            | ViolationType::DiscoveryDepthLimit
            | ViolationType::CostThreshold
            | ViolationType::IdleTimeout
            | ViolationType::SameStepLoop => ExecutionErrorCategory::QuotaExceeded,
        };
        let code = self.violation_type.to_string();
        ExecutionError::new(category, self.message.clone())
            .with_code(code.clone())
            .with_details(serde_json::json!({
                "current": self.current,
                "limit": self.limit,
                "violation_type": code,
            }))
    }
}
/// Usage has crossed the policy's warning threshold for some limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnforcementWarning {
/// Which limit is being approached.
pub warning_type: ViolationType,
/// `current / limit * 100`, truncated toward zero.
pub usage_percent: u32,
/// Human-readable summary built by `EnforcementWarning::new`.
pub message: String,
}
impl EnforcementWarning {
    /// Builds a warning with message `"<type> at <pct>%: <current> / <limit>"`.
    ///
    /// The percentage is truncated toward zero. With a zero `limit`, the
    /// saturating float cast yields 0 for `0 / 0` and `u32::MAX` otherwise.
    pub fn new(warning_type: ViolationType, current: u64, limit: u64) -> Self {
        let ratio = current as f64 / limit as f64;
        let usage_percent = (ratio * 100.0) as u32;
        Self {
            message: format!(
                "{} at {}%: {} / {}",
                warning_type, usage_percent, current, limit
            ),
            warning_type,
            usage_percent,
        }
    }
}
/// Tunables shared by all `EnforcementMiddleware` checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnforcementPolicy {
/// Usage percentage at or above which checks return `Warning`.
pub warning_threshold: u32,
/// Exposed via `EnforcementMiddleware::emit_warning_events_enabled`.
pub emit_warning_events: bool,
/// Not read within this file; presumably consumed by callers — verify.
pub emit_block_events: bool,
/// Milliseconds added to `max_wall_time_ms` before a wall-time check blocks.
pub timeout_grace_ms: u64,
}
impl Default for EnforcementPolicy {
fn default() -> Self {
Self {
warning_threshold: 80, emit_warning_events: true,
emit_block_events: true,
timeout_grace_ms: 1000, }
}
}
/// Extra guard rails for long-running executions; `None` disables a check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongRunningExecutionPolicy {
/// Maximum number of discovered steps.
pub max_discovered_steps: Option<u32>,
/// Maximum discovery nesting depth.
pub max_discovery_depth: Option<u32>,
/// Accumulated cost (USD) at which the cost check blocks.
pub cost_alert_threshold_usd: Option<f64>,
/// Seconds without activity after which the idle check blocks.
pub idle_timeout_seconds: Option<u64>,
/// Not consulted by any check in this file — verify where it is enforced.
pub max_same_step_repetitions: Option<u32>,
}
impl Default for LongRunningExecutionPolicy {
fn default() -> Self {
Self::standard()
}
}
impl LongRunningExecutionPolicy {
    /// Moderate preset: 50 discovered steps, depth 5, $5 cost threshold,
    /// 30-minute idle timeout, 3 repetitions of the same step.
    pub fn standard() -> Self {
        Self {
            max_discovered_steps: Some(50),
            max_discovery_depth: Some(5),
            cost_alert_threshold_usd: Some(5.0),
            idle_timeout_seconds: Some(30 * 60),
            max_same_step_repetitions: Some(3),
        }
    }

    /// Generous preset: 300 discovered steps, depth 10, $50 cost threshold,
    /// 4-hour idle timeout, 5 repetitions of the same step.
    pub fn extended() -> Self {
        Self {
            max_discovered_steps: Some(300),
            max_discovery_depth: Some(10),
            cost_alert_threshold_usd: Some(50.0),
            idle_timeout_seconds: Some(4 * 60 * 60),
            max_same_step_repetitions: Some(5),
        }
    }

    /// No step or depth caps, but still a $100 cost threshold and a 24-hour
    /// idle timeout.
    pub fn unlimited() -> Self {
        Self {
            cost_alert_threshold_usd: Some(100.0),
            idle_timeout_seconds: Some(24 * 60 * 60),
            ..Self::disabled()
        }
    }

    /// Every long-running check turned off.
    pub fn disabled() -> Self {
        Self {
            max_discovered_steps: None,
            max_discovery_depth: None,
            cost_alert_threshold_usd: None,
            idle_timeout_seconds: None,
            max_same_step_repetitions: None,
        }
    }
}
/// Tracks per-execution usage and per-tenant concurrency, and checks them
/// against `ResourceLimits` and `LongRunningExecutionPolicy`.
#[derive(Debug)]
pub struct EnforcementMiddleware {
/// Usage records of registered executions, keyed by execution id.
executions: RwLock<HashMap<ExecutionId, Arc<ExecutionUsage>>>,
/// Count of currently registered executions per tenant.
tenant_executions: RwLock<HashMap<TenantId, AtomicU32>>,
/// Constructed in `with_policy` but not consulted by any method in this
/// file, hence the `dead_code` allowance.
#[allow(dead_code)]
rate_limiter: RwLock<RateLimiterState>,
/// Thresholds and grace periods applied by the checks.
policy: EnforcementPolicy,
}
impl EnforcementMiddleware {
pub fn new() -> Self {
Self::with_policy(EnforcementPolicy::default())
}
pub fn with_policy(policy: EnforcementPolicy) -> Self {
Self {
executions: RwLock::new(HashMap::new()),
tenant_executions: RwLock::new(HashMap::new()),
rate_limiter: RwLock::new(RateLimiterState::new()),
policy,
}
}
pub fn emit_warning_events_enabled(&self) -> bool {
self.policy.emit_warning_events
}
pub async fn register_execution(
&self,
execution_id: ExecutionId,
tenant_id: TenantId,
) -> Arc<ExecutionUsage> {
let usage = Arc::new(ExecutionUsage::new(execution_id.clone(), tenant_id.clone()));
{
let mut executions = self.executions.write().await;
executions.insert(execution_id, Arc::clone(&usage));
}
{
let mut tenant_execs = self.tenant_executions.write().await;
tenant_execs
.entry(tenant_id)
.or_insert_with(|| AtomicU32::new(0))
.fetch_add(1, Ordering::SeqCst);
}
usage
}
pub async fn unregister_execution(&self, execution_id: &ExecutionId) {
let tenant_id = {
let mut executions = self.executions.write().await;
executions.remove(execution_id).map(|u| u.tenant_id.clone())
};
if let Some(tenant_id) = tenant_id {
let tenant_execs = self.tenant_executions.read().await;
if let Some(count) = tenant_execs.get(&tenant_id) {
count.fetch_sub(1, Ordering::SeqCst);
}
}
}
pub async fn get_usage(&self, execution_id: &ExecutionId) -> Option<Arc<ExecutionUsage>> {
let executions = self.executions.read().await;
executions.get(execution_id).cloned()
}
pub async fn get_usage_snapshot(&self, execution_id: &ExecutionId) -> Option<UsageSnapshot> {
self.get_usage(execution_id)
.await
.map(|u| UsageSnapshot::from(u.as_ref()))
}
pub async fn check_step_allowed(
&self,
execution_id: &ExecutionId,
limits: &ResourceLimits,
) -> EnforcementResult {
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed, };
let current = usage.step_count() as u64 + 1; let limit = limits.max_steps as u64;
if current > limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::StepLimit,
current,
limit,
));
}
let percent = (current as f64 / limit as f64 * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::StepLimit,
current,
limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_tokens_allowed(
&self,
execution_id: &ExecutionId,
limits: &ResourceLimits,
additional_tokens: u32,
) -> EnforcementResult {
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let current = usage.total_tokens() as u64 + additional_tokens as u64;
let limit = limits.max_tokens as u64;
if current > limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::TokenLimit,
current,
limit,
));
}
let percent = (current as f64 / limit as f64 * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::TokenLimit,
current,
limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_wall_time_allowed(
&self,
execution_id: &ExecutionId,
limits: &ResourceLimits,
) -> EnforcementResult {
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let current = usage.wall_time_ms();
let limit = limits.max_wall_time_ms;
let effective_limit = limit + self.policy.timeout_grace_ms;
if current > effective_limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::WallTimeLimit,
current,
limit,
));
}
let percent = (current as f64 / limit as f64 * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::WallTimeLimit,
current,
limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_concurrency_allowed(
&self,
tenant_id: &TenantId,
limits: &ResourceLimits,
) -> EnforcementResult {
let max_concurrent = match limits.max_concurrent_executions {
Some(max) => max,
None => return EnforcementResult::Allowed, };
let current = {
let tenant_execs = self.tenant_executions.read().await;
tenant_execs
.get(tenant_id)
.map(|c| c.load(Ordering::SeqCst))
.unwrap_or(0) as u64
};
let limit = max_concurrent as u64;
if current >= limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::ConcurrencyLimit,
current + 1, limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_all_limits(
&self,
execution_id: &ExecutionId,
limits: &ResourceLimits,
) -> EnforcementResult {
let wall_check = self.check_wall_time_allowed(execution_id, limits).await;
if wall_check.is_blocked() {
return wall_check;
}
let step_check = self.check_step_allowed(execution_id, limits).await;
if step_check.is_blocked() {
return step_check;
}
let token_check = self.check_tokens_allowed(execution_id, limits, 0).await;
if token_check.is_blocked() {
return token_check;
}
if let EnforcementResult::Warning(w) = wall_check {
return EnforcementResult::Warning(w);
}
if let EnforcementResult::Warning(w) = step_check {
return EnforcementResult::Warning(w);
}
if let EnforcementResult::Warning(w) = token_check {
return EnforcementResult::Warning(w);
}
EnforcementResult::Allowed
}
pub async fn record_step(&self, execution_id: &ExecutionId) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.record_step();
usage.touch().await;
}
}
pub async fn record_tokens(&self, execution_id: &ExecutionId, input: u32, output: u32) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.record_tokens(input, output);
usage.touch().await;
}
}
pub async fn record_discovered_step(&self, execution_id: &ExecutionId) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.record_discovered_step();
usage.touch().await;
}
}
pub async fn record_cost(&self, execution_id: &ExecutionId, cost_usd: f64) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.record_cost_usd(cost_usd);
usage.touch().await;
}
}
pub async fn push_discovery_depth(&self, execution_id: &ExecutionId) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.push_discovery_depth();
}
}
pub async fn pop_discovery_depth(&self, execution_id: &ExecutionId) {
if let Some(usage) = self.get_usage(execution_id).await {
usage.pop_discovery_depth();
}
}
pub async fn check_discovered_step_limit(
&self,
execution_id: &ExecutionId,
policy: &LongRunningExecutionPolicy,
) -> EnforcementResult {
let max_discovered = match policy.max_discovered_steps {
Some(max) => max,
None => return EnforcementResult::Allowed,
};
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let current = usage.discovered_step_count() as u64 + 1; let limit = max_discovered as u64;
if current > limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::DiscoveredStepLimit,
current,
limit,
));
}
let percent = (current as f64 / limit as f64 * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::DiscoveredStepLimit,
current,
limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_discovery_depth_limit(
&self,
execution_id: &ExecutionId,
policy: &LongRunningExecutionPolicy,
) -> EnforcementResult {
let max_depth = match policy.max_discovery_depth {
Some(max) => max,
None => return EnforcementResult::Allowed,
};
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let current = usage.current_discovery_depth() as u64 + 1; let limit = max_depth as u64;
if current > limit {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::DiscoveryDepthLimit,
current,
limit,
));
}
EnforcementResult::Allowed
}
pub async fn check_cost_threshold(
&self,
execution_id: &ExecutionId,
policy: &LongRunningExecutionPolicy,
) -> EnforcementResult {
let threshold = match policy.cost_alert_threshold_usd {
Some(t) => t,
None => return EnforcementResult::Allowed,
};
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let current_cents = usage.cost_cents.load(Ordering::SeqCst);
let current_usd = current_cents as f64 / 100.0;
let limit_cents = (threshold * 100.0) as u64;
if current_usd >= threshold {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::CostThreshold,
current_cents,
limit_cents,
));
}
let percent = (current_usd / threshold * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::CostThreshold,
current_cents,
limit_cents,
));
}
EnforcementResult::Allowed
}
pub async fn check_idle_timeout(
&self,
execution_id: &ExecutionId,
policy: &LongRunningExecutionPolicy,
) -> EnforcementResult {
let timeout_secs = match policy.idle_timeout_seconds {
Some(t) => t,
None => return EnforcementResult::Allowed,
};
let usage = match self.get_usage(execution_id).await {
Some(u) => u,
None => return EnforcementResult::Allowed,
};
let idle_secs = usage.idle_seconds().await;
if idle_secs >= timeout_secs {
return EnforcementResult::Blocked(EnforcementViolation::new(
ViolationType::IdleTimeout,
idle_secs,
timeout_secs,
));
}
let percent = (idle_secs as f64 / timeout_secs as f64 * 100.0) as u32;
if percent >= self.policy.warning_threshold {
return EnforcementResult::Warning(EnforcementWarning::new(
ViolationType::IdleTimeout,
idle_secs,
timeout_secs,
));
}
EnforcementResult::Allowed
}
pub async fn check_long_running_limits(
&self,
execution_id: &ExecutionId,
policy: &LongRunningExecutionPolicy,
) -> EnforcementResult {
let cost_check = self.check_cost_threshold(execution_id, policy).await;
if cost_check.is_blocked() {
return cost_check;
}
let depth_check = self.check_discovery_depth_limit(execution_id, policy).await;
if depth_check.is_blocked() {
return depth_check;
}
let discovered_check = self.check_discovered_step_limit(execution_id, policy).await;
if discovered_check.is_blocked() {
return discovered_check;
}
let idle_check = self.check_idle_timeout(execution_id, policy).await;
if idle_check.is_blocked() {
return idle_check;
}
if let EnforcementResult::Warning(w) = cost_check {
return EnforcementResult::Warning(w);
}
if let EnforcementResult::Warning(w) = discovered_check {
return EnforcementResult::Warning(w);
}
if let EnforcementResult::Warning(w) = idle_check {
return EnforcementResult::Warning(w);
}
EnforcementResult::Allowed
}
}
impl Default for EnforcementMiddleware {
fn default() -> Self {
Self::new()
}
}
/// Per-provider token buckets for rate limiting (not yet consulted).
#[derive(Debug)]
struct RateLimiterState {
/// Buckets keyed by a string — presumably the provider name; verify.
#[allow(dead_code)]
provider_tokens: HashMap<String, TokenBucket>,
}
impl RateLimiterState {
fn new() -> Self {
Self {
provider_tokens: HashMap::new(),
}
}
}
/// Token bucket: up to `max_tokens`, refilled at `refill_rate` tokens per second.
#[derive(Debug)]
struct TokenBucket {
/// Tokens currently available.
tokens: AtomicU64,
/// Capacity; refills never exceed this.
max_tokens: u64,
/// Tokens added per elapsed second (see `try_acquire`).
refill_rate: u64,
/// When tokens were last added.
last_refill: RwLock<Instant>,
}
impl TokenBucket {
    /// Creates a full bucket with capacity `max_tokens` refilling at
    /// `refill_rate` tokens per second.
    #[allow(dead_code)] // rate limiting is not wired into any caller yet
    fn new(max_tokens: u64, refill_rate: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_tokens),
            max_tokens,
            refill_rate,
            last_refill: RwLock::new(Instant::now()),
        }
    }

    /// Refills by elapsed time, then tries to take `count` tokens.
    ///
    /// Returns `true` and deducts the tokens if enough are available,
    /// otherwise returns `false` and leaves the bucket unchanged.
    ///
    /// The refill lock is held for the whole operation. Previously the
    /// availability check and the `fetch_sub` ran outside it. Two callers
    /// could both pass the check and wrap the counter to a huge value, and a
    /// concurrent refill's load/store could undo a deduction.
    #[allow(dead_code)] // rate limiting is not wired into any caller yet
    async fn try_acquire(&self, count: u64) -> bool {
        let mut last = self.last_refill.write().await;
        let new_tokens = (last.elapsed().as_secs_f64() * self.refill_rate as f64) as u64;
        if new_tokens > 0 {
            let current = self.tokens.load(Ordering::SeqCst);
            let refilled = current.saturating_add(new_tokens).min(self.max_tokens);
            self.tokens.store(refilled, Ordering::SeqCst);
            // Only advance the refill clock when whole tokens were added, so
            // frequent calls still accumulate time toward the next token.
            *last = Instant::now();
        }
        // `checked_sub` fails (leaving the count untouched) when fewer than
        // `count` tokens remain.
        self.tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |available| {
                available.checked_sub(count)
            })
            .is_ok()
    }
}
/// Measures one step's elapsed time against a fixed timeout.
///
/// Timing starts when the guard is created.
pub struct StepTimeoutGuard {
/// Step being timed; included in the timeout error.
step_id: StepId,
/// Maximum allowed duration.
timeout: Duration,
/// Creation time of the guard.
started_at: Instant,
}
impl StepTimeoutGuard {
    /// Starts timing `step_id` against `timeout`.
    pub fn new(step_id: StepId, timeout: Duration) -> Self {
        let started_at = Instant::now();
        Self {
            step_id,
            timeout,
            started_at,
        }
    }

    /// `true` once strictly more than `timeout` has elapsed.
    pub fn is_timed_out(&self) -> bool {
        self.elapsed() > self.timeout
    }

    /// Time left before the timeout, or zero once it has passed.
    pub fn remaining(&self) -> Duration {
        self.timeout.saturating_sub(self.elapsed())
    }

    /// Time since the guard was created.
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Returns a timeout error tagged with the step id once timed out.
    // ExecutionError is flagged by clippy as a large `Err` type; changing the
    // return type would alter the public signature.
    #[allow(clippy::result_large_err)]
    pub fn check(&self) -> Result<(), ExecutionError> {
        if !self.is_timed_out() {
            return Ok(());
        }
        let message = format!(
            "Step {} timed out after {:?}",
            self.step_id, self.timeout
        );
        Err(ExecutionError::timeout(message).with_step_id(self.step_id.clone()))
    }
}
// Unit tests for usage tracking and the enforcement checks.
#[cfg(test)]
mod tests {
use super::*;
// Step and token counters accumulate across calls.
#[tokio::test]
async fn test_usage_tracking() {
let exec_id = ExecutionId::new();
let tenant_id = TenantId::from("tenant_test123456789012345");
let usage = ExecutionUsage::new(exec_id, tenant_id);
usage.record_step();
usage.record_step();
assert_eq!(usage.step_count(), 2);
usage.record_tokens(100, 50);
assert_eq!(usage.total_tokens(), 150);
}
// With max_steps = 5, steps 1..=5 pass (4 and 5 as warnings, which still
// count as allowed) and the 6th is blocked.
#[tokio::test]
async fn test_step_limit_enforcement() {
let middleware = EnforcementMiddleware::new();
let exec_id = ExecutionId::new();
let tenant_id = TenantId::from("tenant_test123456789012345");
let limits = ResourceLimits {
max_steps: 5,
max_tokens: 1000,
max_wall_time_ms: 60000,
max_memory_mb: None,
max_concurrent_executions: None,
};
let usage = middleware
.register_execution(exec_id.clone(), tenant_id)
.await;
for _ in 0..5 {
let result = middleware.check_step_allowed(&exec_id, &limits).await;
assert!(result.is_allowed(), "Step should be allowed");
usage.record_step();
}
let result = middleware.check_step_allowed(&exec_id, &limits).await;
assert!(result.is_blocked(), "Step should be blocked");
}
// 80 used + 25 requested = 105 > 100 blocks; 80 + 10 = 90 only warns,
// which is_allowed() treats as allowed.
#[tokio::test]
async fn test_token_limit_enforcement() {
let middleware = EnforcementMiddleware::new();
let exec_id = ExecutionId::new();
let tenant_id = TenantId::from("tenant_test123456789012345");
let limits = ResourceLimits {
max_steps: 100,
max_tokens: 100,
max_wall_time_ms: 60000,
max_memory_mb: None,
max_concurrent_executions: None,
};
let usage = middleware
.register_execution(exec_id.clone(), tenant_id)
.await;
usage.record_tokens(50, 30);
let result = middleware.check_tokens_allowed(&exec_id, &limits, 25).await;
assert!(
result.is_blocked(),
"Should be blocked when exceeding limit"
);
let result = middleware.check_tokens_allowed(&exec_id, &limits, 10).await;
assert!(result.is_allowed(), "Should be allowed within limit");
}
// 7 recorded steps: the check counts the next step (8 of 10 = 80%), which
// meets the 80% warning threshold.
#[tokio::test]
async fn test_warning_threshold() {
let policy = EnforcementPolicy {
warning_threshold: 80,
..Default::default()
};
let middleware = EnforcementMiddleware::with_policy(policy);
let exec_id = ExecutionId::new();
let tenant_id = TenantId::from("tenant_test123456789012345");
let limits = ResourceLimits {
max_steps: 10,
max_tokens: 1000,
max_wall_time_ms: 60000,
max_memory_mb: None,
max_concurrent_executions: None,
};
let usage = middleware
.register_execution(exec_id.clone(), tenant_id)
.await;
for _ in 0..7 {
usage.record_step();
}
let result = middleware.check_step_allowed(&exec_id, &limits).await;
assert!(matches!(result, EnforcementResult::Warning(_)));
}
// The guard flips to timed-out after its duration (uses a real 150 ms sleep).
#[test]
fn test_step_timeout_guard() {
let step_id = StepId::new();
let guard = StepTimeoutGuard::new(step_id, Duration::from_millis(100));
assert!(!guard.is_timed_out());
assert!(guard.check().is_ok());
std::thread::sleep(Duration::from_millis(150));
assert!(guard.is_timed_out());
assert!(guard.check().is_err());
}
// Two registered executions saturate a limit of 2; unregistering one frees a slot.
#[tokio::test]
async fn test_concurrency_limit() {
let middleware = EnforcementMiddleware::new();
let tenant_id = TenantId::from("tenant_test123456789012345");
let limits = ResourceLimits {
max_steps: 100,
max_tokens: 1000,
max_wall_time_ms: 60000,
max_memory_mb: None,
max_concurrent_executions: Some(2),
};
let exec1 = ExecutionId::new();
let exec2 = ExecutionId::new();
middleware
.register_execution(exec1.clone(), tenant_id.clone())
.await;
middleware
.register_execution(exec2.clone(), tenant_id.clone())
.await;
let result = middleware
.check_concurrency_allowed(&tenant_id, &limits)
.await;
assert!(result.is_blocked());
middleware.unregister_execution(&exec1).await;
let result = middleware
.check_concurrency_allowed(&tenant_id, &limits)
.await;
assert!(result.is_allowed());
}
// Network violations map to PolicyViolation; retryability/fatality come from
// ExecutionError's own category rules (defined in the error module).
#[test]
fn test_network_violation_type() {
let violation = EnforcementViolation::new(ViolationType::NetworkViolation, 0, 0);
let error = violation.to_error();
assert_eq!(error.category, ExecutionErrorCategory::PolicyViolation);
assert!(!error.is_retryable());
assert!(error.is_fatal());
}
// Display renders the snake_case code.
#[test]
fn test_violation_type_display_network() {
assert_eq!(
format!("{}", ViolationType::NetworkViolation),
"network_violation"
);
}
}