Skip to main content

enact_core/kernel/
enforcement.rs

1//! Enforcement - Kernel-owned limits, quotas, and rate limiting
2//!
3//! This module provides the enforcement layer that ensures executions
4//! respect their resource boundaries. All limit enforcement happens
5//! in the kernel, not in providers.
6//!
7//! ## Design Principles
8//!
9//! 1. **Kernel Owns Enforcement**: Providers are dumb adapters
10//! 2. **Hard Limits**: Quota exceeded = execution halts immediately
11//! 3. **Deterministic**: Same limits → same enforcement behavior
12//! 4. **Observable**: All enforcement decisions are logged/events
13//!
14//! ## Key Components
15//!
16//! - `UsageTracker`: Tracks resource consumption per execution
17//! - `EnforcementPolicy`: Defines limits and enforcement rules
18//! - `EnforcementResult`: Outcome of limit checks
19//!
20//! @see docs/feat-03-limits-quotas.md
21
22use super::error::{ExecutionError, ExecutionErrorCategory};
23use super::ids::{ExecutionId, StepId, TenantId};
24use crate::context::ResourceLimits;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use tokio::sync::RwLock;
31
32// =============================================================================
33// Usage Tracking
34// =============================================================================
35
/// Tracks resource usage for a single execution.
///
/// Every counter is an atomic, so the recording methods only need `&self`;
/// `last_activity` is the one field guarded by an async `RwLock`.
#[derive(Debug)]
pub struct ExecutionUsage {
    /// Execution ID this usage record belongs to
    pub execution_id: ExecutionId,
    /// Tenant ID associated with the execution
    pub tenant_id: TenantId,
    /// Number of steps executed
    pub steps: AtomicU32,
    /// Total input tokens consumed (32-bit counter)
    pub input_tokens: AtomicU32,
    /// Total output tokens consumed (32-bit counter)
    pub output_tokens: AtomicU32,
    /// Wall clock start time (captured when the record is created)
    pub started_at: Instant,
    /// Last activity timestamp, refreshed by `touch()`
    pub last_activity: RwLock<Instant>,
    // === Long-running execution tracking ===
    /// Number of dynamically discovered steps (StepSource::Discovered)
    pub discovered_steps: AtomicU32,
    /// Current discovery chain depth (how deep in the discovery tree)
    pub discovery_depth: AtomicU32,
    /// Maximum discovery depth reached during execution
    pub max_discovery_depth_reached: AtomicU32,
    /// Cumulative cost in cents (USD * 100 for integer precision)
    pub cost_cents: AtomicU64,
}
63
64impl ExecutionUsage {
65    /// Create a new usage tracker for an execution
66    pub fn new(execution_id: ExecutionId, tenant_id: TenantId) -> Self {
67        let now = Instant::now();
68        Self {
69            execution_id,
70            tenant_id,
71            steps: AtomicU32::new(0),
72            input_tokens: AtomicU32::new(0),
73            output_tokens: AtomicU32::new(0),
74            started_at: now,
75            last_activity: RwLock::new(now),
76            discovered_steps: AtomicU32::new(0),
77            discovery_depth: AtomicU32::new(0),
78            max_discovery_depth_reached: AtomicU32::new(0),
79            cost_cents: AtomicU64::new(0),
80        }
81    }
82
83    /// Record step execution
84    pub fn record_step(&self) {
85        self.steps.fetch_add(1, Ordering::SeqCst);
86    }
87
88    /// Record a discovered step (dynamically added to DAG)
89    pub fn record_discovered_step(&self) {
90        self.discovered_steps.fetch_add(1, Ordering::SeqCst);
91    }
92
93    /// Record token usage
94    pub fn record_tokens(&self, input: u32, output: u32) {
95        self.input_tokens.fetch_add(input, Ordering::SeqCst);
96        self.output_tokens.fetch_add(output, Ordering::SeqCst);
97    }
98
99    /// Record cost in USD (converted to cents for storage)
100    pub fn record_cost_usd(&self, cost_usd: f64) {
101        let cents = (cost_usd * 100.0) as u64;
102        self.cost_cents.fetch_add(cents, Ordering::SeqCst);
103    }
104
105    /// Push discovery depth (entering a discovered step)
106    pub fn push_discovery_depth(&self) {
107        let new_depth = self.discovery_depth.fetch_add(1, Ordering::SeqCst) + 1;
108        // Update max if this is deeper than before
109        let current_max = self.max_discovery_depth_reached.load(Ordering::SeqCst);
110        if new_depth > current_max {
111            self.max_discovery_depth_reached
112                .store(new_depth, Ordering::SeqCst);
113        }
114    }
115
116    /// Pop discovery depth (exiting a discovered step)
117    pub fn pop_discovery_depth(&self) {
118        self.discovery_depth.fetch_sub(1, Ordering::SeqCst);
119    }
120
121    /// Update last activity timestamp
122    pub async fn touch(&self) {
123        let mut last = self.last_activity.write().await;
124        *last = Instant::now();
125    }
126
127    /// Get current step count
128    pub fn step_count(&self) -> u32 {
129        self.steps.load(Ordering::SeqCst)
130    }
131
132    /// Get discovered step count
133    pub fn discovered_step_count(&self) -> u32 {
134        self.discovered_steps.load(Ordering::SeqCst)
135    }
136
137    /// Get current discovery depth
138    pub fn current_discovery_depth(&self) -> u32 {
139        self.discovery_depth.load(Ordering::SeqCst)
140    }
141
142    /// Get total token count
143    pub fn total_tokens(&self) -> u32 {
144        self.input_tokens.load(Ordering::SeqCst) + self.output_tokens.load(Ordering::SeqCst)
145    }
146
147    /// Get cumulative cost in USD
148    pub fn cost_usd(&self) -> f64 {
149        self.cost_cents.load(Ordering::SeqCst) as f64 / 100.0
150    }
151
152    /// Get wall clock duration
153    pub fn wall_time(&self) -> Duration {
154        self.started_at.elapsed()
155    }
156
157    /// Get wall time in milliseconds
158    pub fn wall_time_ms(&self) -> u64 {
159        self.wall_time().as_millis() as u64
160    }
161
162    /// Get idle duration (time since last activity)
163    pub async fn idle_duration(&self) -> Duration {
164        let last = self.last_activity.read().await;
165        last.elapsed()
166    }
167
168    /// Get idle duration in seconds
169    pub async fn idle_seconds(&self) -> u64 {
170        self.idle_duration().await.as_secs()
171    }
172}
173
/// Serializable snapshot of execution usage.
///
/// A plain-data copy of the counters held by `ExecutionUsage`, taken at a
/// single point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSnapshot {
    /// Execution ID rendered as a string
    pub execution_id: String,
    /// Tenant ID rendered as a string
    pub tenant_id: String,
    /// Number of steps executed
    pub steps: u32,
    /// Input tokens consumed
    pub input_tokens: u32,
    /// Output tokens consumed
    pub output_tokens: u32,
    /// Sum of input and output tokens
    pub total_tokens: u32,
    /// Elapsed wall time in milliseconds when the snapshot was taken
    pub wall_time_ms: u64,
    // Long-running execution metrics
    /// Number of dynamically discovered steps
    pub discovered_steps: u32,
    /// Discovery depth at the moment of the snapshot
    pub discovery_depth: u32,
    /// Deepest discovery depth reached so far
    pub max_discovery_depth: u32,
    /// Cumulative cost in USD
    pub cost_usd: f64,
}
190
191impl From<&ExecutionUsage> for UsageSnapshot {
192    fn from(usage: &ExecutionUsage) -> Self {
193        let input = usage.input_tokens.load(Ordering::SeqCst);
194        let output = usage.output_tokens.load(Ordering::SeqCst);
195        Self {
196            execution_id: usage.execution_id.as_str().to_string(),
197            tenant_id: usage.tenant_id.as_str().to_string(),
198            steps: usage.steps.load(Ordering::SeqCst),
199            input_tokens: input,
200            output_tokens: output,
201            total_tokens: input + output,
202            wall_time_ms: usage.wall_time_ms(),
203            discovered_steps: usage.discovered_steps.load(Ordering::SeqCst),
204            discovery_depth: usage.discovery_depth.load(Ordering::SeqCst),
205            max_discovery_depth: usage.max_discovery_depth_reached.load(Ordering::SeqCst),
206            cost_usd: usage.cost_usd(),
207        }
208    }
209}
210
211// =============================================================================
212// Enforcement Results
213// =============================================================================
214
/// Result of an enforcement check.
///
/// `Allowed` and `Warning` both let the operation continue; only `Blocked`
/// stops it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnforcementResult {
    /// Operation is allowed to proceed
    Allowed,
    /// Operation is blocked due to limit exceeded
    Blocked(EnforcementViolation),
    /// Operation is allowed but near limit (warning)
    Warning(EnforcementWarning),
}
225
226impl EnforcementResult {
227    /// Check if the result allows the operation
228    pub fn is_allowed(&self) -> bool {
229        matches!(self, Self::Allowed | Self::Warning(_))
230    }
231
232    /// Check if the result blocks the operation
233    pub fn is_blocked(&self) -> bool {
234        matches!(self, Self::Blocked(_))
235    }
236
237    /// Convert to an ExecutionError if blocked
238    pub fn to_error(&self) -> Option<ExecutionError> {
239        match self {
240            Self::Blocked(violation) => Some(violation.to_error()),
241            _ => None,
242        }
243    }
244}
245
/// Type of enforcement violation.
///
/// Also used as the `warning_type` of an `EnforcementWarning`, so each
/// variant names a limit rather than strictly a failure.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ViolationType {
    /// Maximum steps exceeded
    StepLimit,
    /// Maximum tokens exceeded
    TokenLimit,
    /// Wall clock timeout exceeded
    WallTimeLimit,
    /// Memory limit exceeded
    MemoryLimit,
    /// Concurrent execution limit exceeded
    ConcurrencyLimit,
    /// Rate limit exceeded
    RateLimit,
    /// Network access denied in air-gapped mode
    NetworkViolation,
    // === Long-running execution controls ===
    /// Maximum discovered steps exceeded (agentic DAG)
    DiscoveredStepLimit,
    /// Discovery chain depth exceeded (prevents infinite discovery)
    DiscoveryDepthLimit,
    /// Cost threshold exceeded (USD-based alerting)
    CostThreshold,
    /// No activity for too long (idle timeout)
    IdleTimeout,
    /// Agent repeating same methodology (semantic loop)
    SameStepLoop,
}
275
276impl std::fmt::Display for ViolationType {
277    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278        match self {
279            Self::StepLimit => write!(f, "step_limit"),
280            Self::TokenLimit => write!(f, "token_limit"),
281            Self::WallTimeLimit => write!(f, "wall_time_limit"),
282            Self::MemoryLimit => write!(f, "memory_limit"),
283            Self::ConcurrencyLimit => write!(f, "concurrency_limit"),
284            Self::RateLimit => write!(f, "rate_limit"),
285            Self::NetworkViolation => write!(f, "network_violation"),
286            Self::DiscoveredStepLimit => write!(f, "discovered_step_limit"),
287            Self::DiscoveryDepthLimit => write!(f, "discovery_depth_limit"),
288            Self::CostThreshold => write!(f, "cost_threshold"),
289            Self::IdleTimeout => write!(f, "idle_timeout"),
290            Self::SameStepLoop => write!(f, "same_step_loop"),
291        }
292    }
293}
294
/// Details of an enforcement violation.
///
/// `current` and `limit` are raw counts in whatever unit the violated limit
/// uses (steps, tokens, milliseconds, cents, ...).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnforcementViolation {
    /// Type of violation
    pub violation_type: ViolationType,
    /// Current value
    pub current: u64,
    /// Limit value
    pub limit: u64,
    /// Human-readable message
    pub message: String,
}
307
308impl EnforcementViolation {
309    /// Create a new violation
310    pub fn new(violation_type: ViolationType, current: u64, limit: u64) -> Self {
311        let message = format!(
312            "{} exceeded: {} / {} ({}%)",
313            violation_type,
314            current,
315            limit,
316            (current as f64 / limit as f64 * 100.0) as u32
317        );
318        Self {
319            violation_type,
320            current,
321            limit,
322            message,
323        }
324    }
325
326    /// Convert to an ExecutionError
327    pub fn to_error(&self) -> ExecutionError {
328        let category = match self.violation_type {
329            ViolationType::WallTimeLimit => ExecutionErrorCategory::Timeout,
330            ViolationType::RateLimit => ExecutionErrorCategory::LlmError, // Rate limits are retryable
331            ViolationType::NetworkViolation => ExecutionErrorCategory::PolicyViolation, // Non-retryable policy
332            _ => ExecutionErrorCategory::QuotaExceeded,
333        };
334
335        ExecutionError::new(category, self.message.clone())
336            .with_code(self.violation_type.to_string())
337            .with_details(serde_json::json!({
338                "current": self.current,
339                "limit": self.limit,
340                "violation_type": self.violation_type.to_string(),
341            }))
342    }
343}
344
/// Warning about approaching limits.
///
/// Carried by `EnforcementResult::Warning`; the operation is still allowed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnforcementWarning {
    /// Type of limit being approached
    pub warning_type: ViolationType,
    /// Current usage percentage (0-100)
    pub usage_percent: u32,
    /// Human-readable message
    pub message: String,
}
355
356impl EnforcementWarning {
357    /// Create a new warning
358    pub fn new(warning_type: ViolationType, current: u64, limit: u64) -> Self {
359        let percent = (current as f64 / limit as f64 * 100.0) as u32;
360        let message = format!("{} at {}%: {} / {}", warning_type, percent, current, limit);
361        Self {
362            warning_type,
363            usage_percent: percent,
364            message,
365        }
366    }
367}
368
369// =============================================================================
370// Enforcement Policy
371// =============================================================================
372
/// Configuration for enforcement behavior.
///
/// Controls when limit checks start returning warnings and how much slack is
/// added to wall-time limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnforcementPolicy {
    /// Warning threshold (percentage of limit, 0-100)
    pub warning_threshold: u32,
    /// Whether to emit events on warnings
    pub emit_warning_events: bool,
    /// Whether to emit events on blocks
    pub emit_block_events: bool,
    /// Grace period for timeouts (milliseconds)
    pub timeout_grace_ms: u64,
}
385
386impl Default for EnforcementPolicy {
387    fn default() -> Self {
388        Self {
389            warning_threshold: 80, // Warn at 80% usage
390            emit_warning_events: true,
391            emit_block_events: true,
392            timeout_grace_ms: 1000, // 1 second grace period
393        }
394    }
395}
396
397// =============================================================================
398// Long-Running Execution Policy
399// =============================================================================
400
/// Policy configuration for long-running agentic executions
///
/// These controls prevent runaway costs, infinite discovery loops, and idle
/// executions from consuming resources in the Agentic DAG model.
///
/// Every field is optional; `None` disables the corresponding control.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongRunningExecutionPolicy {
    /// Maximum number of dynamically discovered steps before intervention
    /// (Steps with StepSource::Discovered)
    pub max_discovered_steps: Option<u32>,
    /// Maximum depth of discovery chains (prevents infinite discovery)
    /// e.g., agent discovers step A, which discovers step B, which discovers C...
    pub max_discovery_depth: Option<u32>,
    /// Alert threshold for cumulative cost in USD
    /// When exceeded, execution pauses for approval
    // NOTE(review): the pause-for-approval flow is not visible in this file section — confirm.
    pub cost_alert_threshold_usd: Option<f64>,
    /// Maximum time without activity before idle timeout (seconds)
    pub idle_timeout_seconds: Option<u64>,
    /// Maximum repetitions of same methodology before loop detection
    /// (the `standard()` preset uses 3)
    pub max_same_step_repetitions: Option<u32>,
}
421
422impl Default for LongRunningExecutionPolicy {
423    fn default() -> Self {
424        Self::standard()
425    }
426}
427
428impl LongRunningExecutionPolicy {
429    /// Standard preset - balanced limits for typical long-running executions
430    /// - Max duration: ~30 minutes (via idle timeout)
431    /// - Discovered steps: 50
432    /// - Discovery depth: 5
433    /// - Cost alert: $5.00 USD
434    pub fn standard() -> Self {
435        Self {
436            max_discovered_steps: Some(50),
437            max_discovery_depth: Some(5),
438            cost_alert_threshold_usd: Some(5.0),
439            idle_timeout_seconds: Some(1800), // 30 minutes
440            max_same_step_repetitions: Some(3),
441        }
442    }
443
444    /// Extended preset - higher limits for complex, supervised workflows
445    /// - Max duration: ~4 hours
446    /// - Discovered steps: 300
447    /// - Discovery depth: 10
448    /// - Cost alert: $50.00 USD
449    pub fn extended() -> Self {
450        Self {
451            max_discovered_steps: Some(300),
452            max_discovery_depth: Some(10),
453            cost_alert_threshold_usd: Some(50.0),
454            idle_timeout_seconds: Some(14400), // 4 hours
455            max_same_step_repetitions: Some(5),
456        }
457    }
458
459    /// Unlimited preset - no discovery limits, but requires cost monitoring
460    /// - No step/depth limits
461    /// - Cost alert: $100.00 USD (mandatory safety net)
462    /// - Idle timeout: 24 hours
463    pub fn unlimited() -> Self {
464        Self {
465            max_discovered_steps: None,
466            max_discovery_depth: None,
467            cost_alert_threshold_usd: Some(100.0), // Required safety net
468            idle_timeout_seconds: Some(86400),     // 24 hours
469            max_same_step_repetitions: None,
470        }
471    }
472
473    /// Disabled - no long-running controls (use with caution)
474    pub fn disabled() -> Self {
475        Self {
476            max_discovered_steps: None,
477            max_discovery_depth: None,
478            cost_alert_threshold_usd: None,
479            idle_timeout_seconds: None,
480            max_same_step_repetitions: None,
481        }
482    }
483}
484
485// =============================================================================
486// Enforcement Middleware
487// =============================================================================
488
/// Enforcement middleware for checking limits before operations.
///
/// Holds the registry of tracked executions plus a per-tenant count of
/// active executions; both maps sit behind async `RwLock`s.
#[derive(Debug)]
pub struct EnforcementMiddleware {
    /// Active executions and their usage
    executions: RwLock<HashMap<ExecutionId, Arc<ExecutionUsage>>>,
    /// Active execution count per tenant
    tenant_executions: RwLock<HashMap<TenantId, AtomicU32>>,
    /// Global rate limiter state
    // NOTE(review): only written at construction in the methods visible here,
    // which is presumably why the dead-code lint is silenced — confirm.
    #[allow(dead_code)]
    rate_limiter: RwLock<RateLimiterState>,
    /// Enforcement policy
    policy: EnforcementPolicy,
}
502
503impl EnforcementMiddleware {
504    /// Create a new enforcement middleware
505    pub fn new() -> Self {
506        Self::with_policy(EnforcementPolicy::default())
507    }
508
509    /// Create with custom policy
510    pub fn with_policy(policy: EnforcementPolicy) -> Self {
511        Self {
512            executions: RwLock::new(HashMap::new()),
513            tenant_executions: RwLock::new(HashMap::new()),
514            rate_limiter: RwLock::new(RateLimiterState::new()),
515            policy,
516        }
517    }
518
519    /// Whether warning events should be emitted when limits are near
520    pub fn emit_warning_events_enabled(&self) -> bool {
521        self.policy.emit_warning_events
522    }
523
524    /// Register a new execution
525    pub async fn register_execution(
526        &self,
527        execution_id: ExecutionId,
528        tenant_id: TenantId,
529    ) -> Arc<ExecutionUsage> {
530        let usage = Arc::new(ExecutionUsage::new(execution_id.clone(), tenant_id.clone()));
531
532        // Register in executions map
533        {
534            let mut executions = self.executions.write().await;
535            executions.insert(execution_id, Arc::clone(&usage));
536        }
537
538        // Increment tenant execution count
539        {
540            let mut tenant_execs = self.tenant_executions.write().await;
541            tenant_execs
542                .entry(tenant_id)
543                .or_insert_with(|| AtomicU32::new(0))
544                .fetch_add(1, Ordering::SeqCst);
545        }
546
547        usage
548    }
549
550    /// Unregister an execution
551    pub async fn unregister_execution(&self, execution_id: &ExecutionId) {
552        let tenant_id = {
553            let mut executions = self.executions.write().await;
554            executions.remove(execution_id).map(|u| u.tenant_id.clone())
555        };
556
557        // Decrement tenant execution count
558        if let Some(tenant_id) = tenant_id {
559            let tenant_execs = self.tenant_executions.read().await;
560            if let Some(count) = tenant_execs.get(&tenant_id) {
561                count.fetch_sub(1, Ordering::SeqCst);
562            }
563        }
564    }
565
566    /// Get usage for an execution
567    pub async fn get_usage(&self, execution_id: &ExecutionId) -> Option<Arc<ExecutionUsage>> {
568        let executions = self.executions.read().await;
569        executions.get(execution_id).cloned()
570    }
571
572    /// Get usage snapshot for an execution
573    pub async fn get_usage_snapshot(&self, execution_id: &ExecutionId) -> Option<UsageSnapshot> {
574        self.get_usage(execution_id)
575            .await
576            .map(|u| UsageSnapshot::from(u.as_ref()))
577    }
578
579    /// Check if a new step can be started
580    pub async fn check_step_allowed(
581        &self,
582        execution_id: &ExecutionId,
583        limits: &ResourceLimits,
584    ) -> EnforcementResult {
585        let usage = match self.get_usage(execution_id).await {
586            Some(u) => u,
587            None => return EnforcementResult::Allowed, // No tracking = allowed
588        };
589
590        let current = usage.step_count() as u64 + 1; // +1 for the step we're about to start
591        let limit = limits.max_steps as u64;
592
593        if current > limit {
594            return EnforcementResult::Blocked(EnforcementViolation::new(
595                ViolationType::StepLimit,
596                current,
597                limit,
598            ));
599        }
600
601        let percent = (current as f64 / limit as f64 * 100.0) as u32;
602        if percent >= self.policy.warning_threshold {
603            return EnforcementResult::Warning(EnforcementWarning::new(
604                ViolationType::StepLimit,
605                current,
606                limit,
607            ));
608        }
609
610        EnforcementResult::Allowed
611    }
612
613    /// Check if token usage is within limits
614    pub async fn check_tokens_allowed(
615        &self,
616        execution_id: &ExecutionId,
617        limits: &ResourceLimits,
618        additional_tokens: u32,
619    ) -> EnforcementResult {
620        let usage = match self.get_usage(execution_id).await {
621            Some(u) => u,
622            None => return EnforcementResult::Allowed,
623        };
624
625        let current = usage.total_tokens() as u64 + additional_tokens as u64;
626        let limit = limits.max_tokens as u64;
627
628        if current > limit {
629            return EnforcementResult::Blocked(EnforcementViolation::new(
630                ViolationType::TokenLimit,
631                current,
632                limit,
633            ));
634        }
635
636        let percent = (current as f64 / limit as f64 * 100.0) as u32;
637        if percent >= self.policy.warning_threshold {
638            return EnforcementResult::Warning(EnforcementWarning::new(
639                ViolationType::TokenLimit,
640                current,
641                limit,
642            ));
643        }
644
645        EnforcementResult::Allowed
646    }
647
648    /// Check if wall time is within limits
649    pub async fn check_wall_time_allowed(
650        &self,
651        execution_id: &ExecutionId,
652        limits: &ResourceLimits,
653    ) -> EnforcementResult {
654        let usage = match self.get_usage(execution_id).await {
655            Some(u) => u,
656            None => return EnforcementResult::Allowed,
657        };
658
659        let current = usage.wall_time_ms();
660        let limit = limits.max_wall_time_ms;
661
662        // Add grace period
663        let effective_limit = limit + self.policy.timeout_grace_ms;
664
665        if current > effective_limit {
666            return EnforcementResult::Blocked(EnforcementViolation::new(
667                ViolationType::WallTimeLimit,
668                current,
669                limit,
670            ));
671        }
672
673        let percent = (current as f64 / limit as f64 * 100.0) as u32;
674        if percent >= self.policy.warning_threshold {
675            return EnforcementResult::Warning(EnforcementWarning::new(
676                ViolationType::WallTimeLimit,
677                current,
678                limit,
679            ));
680        }
681
682        EnforcementResult::Allowed
683    }
684
685    /// Check if concurrent execution limit is respected
686    pub async fn check_concurrency_allowed(
687        &self,
688        tenant_id: &TenantId,
689        limits: &ResourceLimits,
690    ) -> EnforcementResult {
691        let max_concurrent = match limits.max_concurrent_executions {
692            Some(max) => max,
693            None => return EnforcementResult::Allowed, // No limit set
694        };
695
696        let current = {
697            let tenant_execs = self.tenant_executions.read().await;
698            tenant_execs
699                .get(tenant_id)
700                .map(|c| c.load(Ordering::SeqCst))
701                .unwrap_or(0) as u64
702        };
703
704        let limit = max_concurrent as u64;
705
706        if current >= limit {
707            return EnforcementResult::Blocked(EnforcementViolation::new(
708                ViolationType::ConcurrencyLimit,
709                current + 1, // +1 for the execution we're about to start
710                limit,
711            ));
712        }
713
714        EnforcementResult::Allowed
715    }
716
717    /// Perform all limit checks before starting a step
718    pub async fn check_all_limits(
719        &self,
720        execution_id: &ExecutionId,
721        limits: &ResourceLimits,
722    ) -> EnforcementResult {
723        // Check wall time first (most likely to timeout)
724        let wall_check = self.check_wall_time_allowed(execution_id, limits).await;
725        if wall_check.is_blocked() {
726            return wall_check;
727        }
728
729        // Check step limit
730        let step_check = self.check_step_allowed(execution_id, limits).await;
731        if step_check.is_blocked() {
732            return step_check;
733        }
734
735        // Check token limit
736        let token_check = self.check_tokens_allowed(execution_id, limits, 0).await;
737        if token_check.is_blocked() {
738            return token_check;
739        }
740
741        // Return warnings if any
742        if let EnforcementResult::Warning(w) = wall_check {
743            return EnforcementResult::Warning(w);
744        }
745        if let EnforcementResult::Warning(w) = step_check {
746            return EnforcementResult::Warning(w);
747        }
748        if let EnforcementResult::Warning(w) = token_check {
749            return EnforcementResult::Warning(w);
750        }
751
752        EnforcementResult::Allowed
753    }
754
755    /// Record step completion and update usage
756    pub async fn record_step(&self, execution_id: &ExecutionId) {
757        if let Some(usage) = self.get_usage(execution_id).await {
758            usage.record_step();
759            usage.touch().await;
760        }
761    }
762
763    /// Record token usage
764    pub async fn record_tokens(&self, execution_id: &ExecutionId, input: u32, output: u32) {
765        if let Some(usage) = self.get_usage(execution_id).await {
766            usage.record_tokens(input, output);
767            usage.touch().await;
768        }
769    }
770
771    /// Record a discovered step and update usage
772    pub async fn record_discovered_step(&self, execution_id: &ExecutionId) {
773        if let Some(usage) = self.get_usage(execution_id).await {
774            usage.record_discovered_step();
775            usage.touch().await;
776        }
777    }
778
779    /// Record cost in USD
780    pub async fn record_cost(&self, execution_id: &ExecutionId, cost_usd: f64) {
781        if let Some(usage) = self.get_usage(execution_id).await {
782            usage.record_cost_usd(cost_usd);
783            usage.touch().await;
784        }
785    }
786
787    /// Push discovery depth (entering a discovered step's sub-execution)
788    pub async fn push_discovery_depth(&self, execution_id: &ExecutionId) {
789        if let Some(usage) = self.get_usage(execution_id).await {
790            usage.push_discovery_depth();
791        }
792    }
793
794    /// Pop discovery depth (exiting a discovered step's sub-execution)
795    pub async fn pop_discovery_depth(&self, execution_id: &ExecutionId) {
796        if let Some(usage) = self.get_usage(execution_id).await {
797            usage.pop_discovery_depth();
798        }
799    }
800
801    // =========================================================================
802    // Long-Running Execution Checks
803    // =========================================================================
804
805    /// Check if discovered step limit is within bounds
806    pub async fn check_discovered_step_limit(
807        &self,
808        execution_id: &ExecutionId,
809        policy: &LongRunningExecutionPolicy,
810    ) -> EnforcementResult {
811        let max_discovered = match policy.max_discovered_steps {
812            Some(max) => max,
813            None => return EnforcementResult::Allowed,
814        };
815
816        let usage = match self.get_usage(execution_id).await {
817            Some(u) => u,
818            None => return EnforcementResult::Allowed,
819        };
820
821        let current = usage.discovered_step_count() as u64 + 1; // +1 for step we're about to discover
822        let limit = max_discovered as u64;
823
824        if current > limit {
825            return EnforcementResult::Blocked(EnforcementViolation::new(
826                ViolationType::DiscoveredStepLimit,
827                current,
828                limit,
829            ));
830        }
831
832        let percent = (current as f64 / limit as f64 * 100.0) as u32;
833        if percent >= self.policy.warning_threshold {
834            return EnforcementResult::Warning(EnforcementWarning::new(
835                ViolationType::DiscoveredStepLimit,
836                current,
837                limit,
838            ));
839        }
840
841        EnforcementResult::Allowed
842    }
843
844    /// Check if discovery depth is within bounds
845    pub async fn check_discovery_depth_limit(
846        &self,
847        execution_id: &ExecutionId,
848        policy: &LongRunningExecutionPolicy,
849    ) -> EnforcementResult {
850        let max_depth = match policy.max_discovery_depth {
851            Some(max) => max,
852            None => return EnforcementResult::Allowed,
853        };
854
855        let usage = match self.get_usage(execution_id).await {
856            Some(u) => u,
857            None => return EnforcementResult::Allowed,
858        };
859
860        let current = usage.current_discovery_depth() as u64 + 1; // +1 for depth we're about to enter
861        let limit = max_depth as u64;
862
863        if current > limit {
864            return EnforcementResult::Blocked(EnforcementViolation::new(
865                ViolationType::DiscoveryDepthLimit,
866                current,
867                limit,
868            ));
869        }
870
871        // No warning for depth - it's either allowed or not
872        EnforcementResult::Allowed
873    }
874
875    /// Check if cost threshold has been exceeded
876    pub async fn check_cost_threshold(
877        &self,
878        execution_id: &ExecutionId,
879        policy: &LongRunningExecutionPolicy,
880    ) -> EnforcementResult {
881        let threshold = match policy.cost_alert_threshold_usd {
882            Some(t) => t,
883            None => return EnforcementResult::Allowed,
884        };
885
886        let usage = match self.get_usage(execution_id).await {
887            Some(u) => u,
888            None => return EnforcementResult::Allowed,
889        };
890
891        let current_cents = usage.cost_cents.load(Ordering::SeqCst);
892        let current_usd = current_cents as f64 / 100.0;
893        let limit_cents = (threshold * 100.0) as u64;
894
895        if current_usd >= threshold {
896            return EnforcementResult::Blocked(EnforcementViolation::new(
897                ViolationType::CostThreshold,
898                current_cents,
899                limit_cents,
900            ));
901        }
902
903        let percent = (current_usd / threshold * 100.0) as u32;
904        if percent >= self.policy.warning_threshold {
905            return EnforcementResult::Warning(EnforcementWarning::new(
906                ViolationType::CostThreshold,
907                current_cents,
908                limit_cents,
909            ));
910        }
911
912        EnforcementResult::Allowed
913    }
914
915    /// Check if idle timeout has been exceeded
916    pub async fn check_idle_timeout(
917        &self,
918        execution_id: &ExecutionId,
919        policy: &LongRunningExecutionPolicy,
920    ) -> EnforcementResult {
921        let timeout_secs = match policy.idle_timeout_seconds {
922            Some(t) => t,
923            None => return EnforcementResult::Allowed,
924        };
925
926        let usage = match self.get_usage(execution_id).await {
927            Some(u) => u,
928            None => return EnforcementResult::Allowed,
929        };
930
931        let idle_secs = usage.idle_seconds().await;
932
933        if idle_secs >= timeout_secs {
934            return EnforcementResult::Blocked(EnforcementViolation::new(
935                ViolationType::IdleTimeout,
936                idle_secs,
937                timeout_secs,
938            ));
939        }
940
941        // Warn at 80% of idle timeout
942        let percent = (idle_secs as f64 / timeout_secs as f64 * 100.0) as u32;
943        if percent >= self.policy.warning_threshold {
944            return EnforcementResult::Warning(EnforcementWarning::new(
945                ViolationType::IdleTimeout,
946                idle_secs,
947                timeout_secs,
948            ));
949        }
950
951        EnforcementResult::Allowed
952    }
953
954    /// Perform all long-running execution checks
955    pub async fn check_long_running_limits(
956        &self,
957        execution_id: &ExecutionId,
958        policy: &LongRunningExecutionPolicy,
959    ) -> EnforcementResult {
960        // Check cost threshold first (most critical for runaway costs)
961        let cost_check = self.check_cost_threshold(execution_id, policy).await;
962        if cost_check.is_blocked() {
963            return cost_check;
964        }
965
966        // Check discovery depth (prevents infinite discovery)
967        let depth_check = self.check_discovery_depth_limit(execution_id, policy).await;
968        if depth_check.is_blocked() {
969            return depth_check;
970        }
971
972        // Check discovered step count
973        let discovered_check = self.check_discovered_step_limit(execution_id, policy).await;
974        if discovered_check.is_blocked() {
975            return discovered_check;
976        }
977
978        // Check idle timeout
979        let idle_check = self.check_idle_timeout(execution_id, policy).await;
980        if idle_check.is_blocked() {
981            return idle_check;
982        }
983
984        // Return warnings if any
985        if let EnforcementResult::Warning(w) = cost_check {
986            return EnforcementResult::Warning(w);
987        }
988        if let EnforcementResult::Warning(w) = discovered_check {
989            return EnforcementResult::Warning(w);
990        }
991        if let EnforcementResult::Warning(w) = idle_check {
992            return EnforcementResult::Warning(w);
993        }
994
995        EnforcementResult::Allowed
996    }
997}
998
999impl Default for EnforcementMiddleware {
1000    fn default() -> Self {
1001        Self::new()
1002    }
1003}
1004
1005// =============================================================================
1006// Rate Limiter
1007// =============================================================================
1008
1009/// Rate limiter state using token bucket algorithm
1010#[derive(Debug)]
1011struct RateLimiterState {
1012    /// Tokens per provider
1013    #[allow(dead_code)]
1014    provider_tokens: HashMap<String, TokenBucket>,
1015}
1016
1017impl RateLimiterState {
1018    fn new() -> Self {
1019        Self {
1020            provider_tokens: HashMap::new(),
1021        }
1022    }
1023}
1024
1025/// Token bucket for rate limiting
1026#[derive(Debug)]
1027struct TokenBucket {
1028    /// Current token count
1029    tokens: AtomicU64,
1030    /// Maximum tokens (bucket size)
1031    max_tokens: u64,
1032    /// Tokens added per second
1033    refill_rate: u64,
1034    /// Last refill timestamp
1035    last_refill: RwLock<Instant>,
1036}
1037
1038impl TokenBucket {
1039    /// Create a new token bucket
1040    #[allow(dead_code)]
1041    fn new(max_tokens: u64, refill_rate: u64) -> Self {
1042        Self {
1043            tokens: AtomicU64::new(max_tokens),
1044            max_tokens,
1045            refill_rate,
1046            last_refill: RwLock::new(Instant::now()),
1047        }
1048    }
1049
1050    /// Try to acquire tokens
1051    #[allow(dead_code)]
1052    async fn try_acquire(&self, count: u64) -> bool {
1053        // Refill tokens based on elapsed time
1054        {
1055            let mut last = self.last_refill.write().await;
1056            let elapsed = last.elapsed();
1057            let new_tokens = (elapsed.as_secs_f64() * self.refill_rate as f64) as u64;
1058            if new_tokens > 0 {
1059                let current = self.tokens.load(Ordering::SeqCst);
1060                let new_total = std::cmp::min(current + new_tokens, self.max_tokens);
1061                self.tokens.store(new_total, Ordering::SeqCst);
1062                *last = Instant::now();
1063            }
1064        }
1065
1066        // Try to acquire
1067        let current = self.tokens.load(Ordering::SeqCst);
1068        if current >= count {
1069            self.tokens.fetch_sub(count, Ordering::SeqCst);
1070            true
1071        } else {
1072            false
1073        }
1074    }
1075}
1076
1077// =============================================================================
1078// Step Timeout Guard
1079// =============================================================================
1080
1081/// Guard for enforcing step timeouts
1082pub struct StepTimeoutGuard {
1083    step_id: StepId,
1084    timeout: Duration,
1085    started_at: Instant,
1086}
1087
1088impl StepTimeoutGuard {
1089    /// Create a new timeout guard
1090    pub fn new(step_id: StepId, timeout: Duration) -> Self {
1091        Self {
1092            step_id,
1093            timeout,
1094            started_at: Instant::now(),
1095        }
1096    }
1097
1098    /// Check if the timeout has been exceeded
1099    pub fn is_timed_out(&self) -> bool {
1100        self.started_at.elapsed() > self.timeout
1101    }
1102
1103    /// Get remaining time
1104    pub fn remaining(&self) -> Duration {
1105        self.timeout.saturating_sub(self.started_at.elapsed())
1106    }
1107
1108    /// Get elapsed time
1109    pub fn elapsed(&self) -> Duration {
1110        self.started_at.elapsed()
1111    }
1112
1113    /// Check and return an error if timed out
1114    #[allow(clippy::result_large_err)]
1115    pub fn check(&self) -> Result<(), ExecutionError> {
1116        if self.is_timed_out() {
1117            Err(ExecutionError::timeout(format!(
1118                "Step {} timed out after {:?}",
1119                self.step_id, self.timeout
1120            ))
1121            .with_step_id(self.step_id.clone()))
1122        } else {
1123            Ok(())
1124        }
1125    }
1126}
1127
1128// =============================================================================
1129// Tests
1130// =============================================================================
1131
1132#[cfg(test)]
1133mod tests {
1134    use super::*;
1135
1136    #[tokio::test]
1137    async fn test_usage_tracking() {
1138        let exec_id = ExecutionId::new();
1139        let tenant_id = TenantId::from("tenant_test123456789012345");
1140        let usage = ExecutionUsage::new(exec_id, tenant_id);
1141
1142        usage.record_step();
1143        usage.record_step();
1144        assert_eq!(usage.step_count(), 2);
1145
1146        usage.record_tokens(100, 50);
1147        assert_eq!(usage.total_tokens(), 150);
1148    }
1149
1150    #[tokio::test]
1151    async fn test_step_limit_enforcement() {
1152        let middleware = EnforcementMiddleware::new();
1153        let exec_id = ExecutionId::new();
1154        let tenant_id = TenantId::from("tenant_test123456789012345");
1155
1156        let limits = ResourceLimits {
1157            max_steps: 5,
1158            max_tokens: 1000,
1159            max_wall_time_ms: 60000,
1160            max_memory_mb: None,
1161            max_concurrent_executions: None,
1162        };
1163
1164        let usage = middleware
1165            .register_execution(exec_id.clone(), tenant_id)
1166            .await;
1167
1168        // First 5 steps should be allowed
1169        for _ in 0..5 {
1170            let result = middleware.check_step_allowed(&exec_id, &limits).await;
1171            assert!(result.is_allowed(), "Step should be allowed");
1172            usage.record_step();
1173        }
1174
1175        // 6th step should be blocked
1176        let result = middleware.check_step_allowed(&exec_id, &limits).await;
1177        assert!(result.is_blocked(), "Step should be blocked");
1178    }
1179
1180    #[tokio::test]
1181    async fn test_token_limit_enforcement() {
1182        let middleware = EnforcementMiddleware::new();
1183        let exec_id = ExecutionId::new();
1184        let tenant_id = TenantId::from("tenant_test123456789012345");
1185
1186        let limits = ResourceLimits {
1187            max_steps: 100,
1188            max_tokens: 100,
1189            max_wall_time_ms: 60000,
1190            max_memory_mb: None,
1191            max_concurrent_executions: None,
1192        };
1193
1194        let usage = middleware
1195            .register_execution(exec_id.clone(), tenant_id)
1196            .await;
1197
1198        // Record some tokens
1199        usage.record_tokens(50, 30);
1200
1201        // Check with additional tokens that would exceed
1202        let result = middleware.check_tokens_allowed(&exec_id, &limits, 25).await;
1203        assert!(
1204            result.is_blocked(),
1205            "Should be blocked when exceeding limit"
1206        );
1207
1208        // Check with tokens that would stay within limit
1209        let result = middleware.check_tokens_allowed(&exec_id, &limits, 10).await;
1210        assert!(result.is_allowed(), "Should be allowed within limit");
1211    }
1212
1213    #[tokio::test]
1214    async fn test_warning_threshold() {
1215        let policy = EnforcementPolicy {
1216            warning_threshold: 80,
1217            ..Default::default()
1218        };
1219        let middleware = EnforcementMiddleware::with_policy(policy);
1220        let exec_id = ExecutionId::new();
1221        let tenant_id = TenantId::from("tenant_test123456789012345");
1222
1223        let limits = ResourceLimits {
1224            max_steps: 10,
1225            max_tokens: 1000,
1226            max_wall_time_ms: 60000,
1227            max_memory_mb: None,
1228            max_concurrent_executions: None,
1229        };
1230
1231        let usage = middleware
1232            .register_execution(exec_id.clone(), tenant_id)
1233            .await;
1234
1235        // Record 8 steps (80% = warning threshold)
1236        for _ in 0..7 {
1237            usage.record_step();
1238        }
1239
1240        // 8th step should trigger warning
1241        let result = middleware.check_step_allowed(&exec_id, &limits).await;
1242        assert!(matches!(result, EnforcementResult::Warning(_)));
1243    }
1244
/// A guard with a 100 ms budget is fine right after creation and reports a
/// timeout after sleeping 150 ms.
#[test]
fn test_step_timeout_guard() {
    let budget = Duration::from_millis(100);
    let guard = StepTimeoutGuard::new(StepId::new(), budget);

    // Freshly created: still within budget.
    assert!(!guard.is_timed_out());
    assert!(guard.check().is_ok());

    // Wait until the budget is exceeded.
    std::thread::sleep(Duration::from_millis(150));

    assert!(guard.is_timed_out());
    assert!(guard.check().is_err());
}
1259
1260    #[tokio::test]
1261    async fn test_concurrency_limit() {
1262        let middleware = EnforcementMiddleware::new();
1263        let tenant_id = TenantId::from("tenant_test123456789012345");
1264
1265        let limits = ResourceLimits {
1266            max_steps: 100,
1267            max_tokens: 1000,
1268            max_wall_time_ms: 60000,
1269            max_memory_mb: None,
1270            max_concurrent_executions: Some(2),
1271        };
1272
1273        // Register 2 executions
1274        let exec1 = ExecutionId::new();
1275        let exec2 = ExecutionId::new();
1276        middleware
1277            .register_execution(exec1.clone(), tenant_id.clone())
1278            .await;
1279        middleware
1280            .register_execution(exec2.clone(), tenant_id.clone())
1281            .await;
1282
1283        // Third should be blocked
1284        let result = middleware
1285            .check_concurrency_allowed(&tenant_id, &limits)
1286            .await;
1287        assert!(result.is_blocked());
1288
1289        // Unregister one
1290        middleware.unregister_execution(&exec1).await;
1291
1292        // Now should be allowed
1293        let result = middleware
1294            .check_concurrency_allowed(&tenant_id, &limits)
1295            .await;
1296        assert!(result.is_allowed());
1297    }
1298
/// A `NetworkViolation` converts to a fatal, non-retryable policy-violation error.
#[test]
fn test_network_violation_type() {
    let error = EnforcementViolation::new(ViolationType::NetworkViolation, 0, 0).to_error();

    assert_eq!(error.category, ExecutionErrorCategory::PolicyViolation);
    assert!(!error.is_retryable());
    assert!(error.is_fatal());
}
1309
/// `ViolationType::NetworkViolation` displays as `network_violation`.
#[test]
fn test_violation_type_display_network() {
    let rendered = ViolationType::NetworkViolation.to_string();
    assert_eq!(rendered, "network_violation");
}
1317}