1use super::error::{ExecutionError, ExecutionErrorCategory};
23use super::ids::{ExecutionId, StepId, TenantId};
24use crate::context::ResourceLimits;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use tokio::sync::RwLock;
31
32#[derive(Debug)]
38pub struct ExecutionUsage {
39 pub execution_id: ExecutionId,
41 pub tenant_id: TenantId,
43 pub steps: AtomicU32,
45 pub input_tokens: AtomicU32,
47 pub output_tokens: AtomicU32,
49 pub started_at: Instant,
51 pub last_activity: RwLock<Instant>,
53 pub discovered_steps: AtomicU32,
56 pub discovery_depth: AtomicU32,
58 pub max_discovery_depth_reached: AtomicU32,
60 pub cost_cents: AtomicU64,
62}
63
64impl ExecutionUsage {
65 pub fn new(execution_id: ExecutionId, tenant_id: TenantId) -> Self {
67 let now = Instant::now();
68 Self {
69 execution_id,
70 tenant_id,
71 steps: AtomicU32::new(0),
72 input_tokens: AtomicU32::new(0),
73 output_tokens: AtomicU32::new(0),
74 started_at: now,
75 last_activity: RwLock::new(now),
76 discovered_steps: AtomicU32::new(0),
77 discovery_depth: AtomicU32::new(0),
78 max_discovery_depth_reached: AtomicU32::new(0),
79 cost_cents: AtomicU64::new(0),
80 }
81 }
82
83 pub fn record_step(&self) {
85 self.steps.fetch_add(1, Ordering::SeqCst);
86 }
87
88 pub fn record_discovered_step(&self) {
90 self.discovered_steps.fetch_add(1, Ordering::SeqCst);
91 }
92
93 pub fn record_tokens(&self, input: u32, output: u32) {
95 self.input_tokens.fetch_add(input, Ordering::SeqCst);
96 self.output_tokens.fetch_add(output, Ordering::SeqCst);
97 }
98
99 pub fn record_cost_usd(&self, cost_usd: f64) {
101 let cents = (cost_usd * 100.0) as u64;
102 self.cost_cents.fetch_add(cents, Ordering::SeqCst);
103 }
104
105 pub fn push_discovery_depth(&self) {
107 let new_depth = self.discovery_depth.fetch_add(1, Ordering::SeqCst) + 1;
108 let current_max = self.max_discovery_depth_reached.load(Ordering::SeqCst);
110 if new_depth > current_max {
111 self.max_discovery_depth_reached
112 .store(new_depth, Ordering::SeqCst);
113 }
114 }
115
116 pub fn pop_discovery_depth(&self) {
118 self.discovery_depth.fetch_sub(1, Ordering::SeqCst);
119 }
120
121 pub async fn touch(&self) {
123 let mut last = self.last_activity.write().await;
124 *last = Instant::now();
125 }
126
127 pub fn step_count(&self) -> u32 {
129 self.steps.load(Ordering::SeqCst)
130 }
131
132 pub fn discovered_step_count(&self) -> u32 {
134 self.discovered_steps.load(Ordering::SeqCst)
135 }
136
137 pub fn current_discovery_depth(&self) -> u32 {
139 self.discovery_depth.load(Ordering::SeqCst)
140 }
141
142 pub fn total_tokens(&self) -> u32 {
144 self.input_tokens.load(Ordering::SeqCst) + self.output_tokens.load(Ordering::SeqCst)
145 }
146
147 pub fn cost_usd(&self) -> f64 {
149 self.cost_cents.load(Ordering::SeqCst) as f64 / 100.0
150 }
151
152 pub fn wall_time(&self) -> Duration {
154 self.started_at.elapsed()
155 }
156
157 pub fn wall_time_ms(&self) -> u64 {
159 self.wall_time().as_millis() as u64
160 }
161
162 pub async fn idle_duration(&self) -> Duration {
164 let last = self.last_activity.read().await;
165 last.elapsed()
166 }
167
168 pub async fn idle_seconds(&self) -> u64 {
170 self.idle_duration().await.as_secs()
171 }
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct UsageSnapshot {
177 pub execution_id: String,
178 pub tenant_id: String,
179 pub steps: u32,
180 pub input_tokens: u32,
181 pub output_tokens: u32,
182 pub total_tokens: u32,
183 pub wall_time_ms: u64,
184 pub discovered_steps: u32,
186 pub discovery_depth: u32,
187 pub max_discovery_depth: u32,
188 pub cost_usd: f64,
189}
190
191impl From<&ExecutionUsage> for UsageSnapshot {
192 fn from(usage: &ExecutionUsage) -> Self {
193 let input = usage.input_tokens.load(Ordering::SeqCst);
194 let output = usage.output_tokens.load(Ordering::SeqCst);
195 Self {
196 execution_id: usage.execution_id.as_str().to_string(),
197 tenant_id: usage.tenant_id.as_str().to_string(),
198 steps: usage.steps.load(Ordering::SeqCst),
199 input_tokens: input,
200 output_tokens: output,
201 total_tokens: input + output,
202 wall_time_ms: usage.wall_time_ms(),
203 discovered_steps: usage.discovered_steps.load(Ordering::SeqCst),
204 discovery_depth: usage.discovery_depth.load(Ordering::SeqCst),
205 max_discovery_depth: usage.max_discovery_depth_reached.load(Ordering::SeqCst),
206 cost_usd: usage.cost_usd(),
207 }
208 }
209}
210
211#[derive(Debug, Clone, PartialEq, Eq)]
217pub enum EnforcementResult {
218 Allowed,
220 Blocked(EnforcementViolation),
222 Warning(EnforcementWarning),
224}
225
226impl EnforcementResult {
227 pub fn is_allowed(&self) -> bool {
229 matches!(self, Self::Allowed | Self::Warning(_))
230 }
231
232 pub fn is_blocked(&self) -> bool {
234 matches!(self, Self::Blocked(_))
235 }
236
237 pub fn to_error(&self) -> Option<ExecutionError> {
239 match self {
240 Self::Blocked(violation) => Some(violation.to_error()),
241 _ => None,
242 }
243 }
244}
245
246#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
248pub enum ViolationType {
249 StepLimit,
251 TokenLimit,
253 WallTimeLimit,
255 MemoryLimit,
257 ConcurrencyLimit,
259 RateLimit,
261 NetworkViolation,
263 DiscoveredStepLimit,
266 DiscoveryDepthLimit,
268 CostThreshold,
270 IdleTimeout,
272 SameStepLoop,
274}
275
276impl std::fmt::Display for ViolationType {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 match self {
279 Self::StepLimit => write!(f, "step_limit"),
280 Self::TokenLimit => write!(f, "token_limit"),
281 Self::WallTimeLimit => write!(f, "wall_time_limit"),
282 Self::MemoryLimit => write!(f, "memory_limit"),
283 Self::ConcurrencyLimit => write!(f, "concurrency_limit"),
284 Self::RateLimit => write!(f, "rate_limit"),
285 Self::NetworkViolation => write!(f, "network_violation"),
286 Self::DiscoveredStepLimit => write!(f, "discovered_step_limit"),
287 Self::DiscoveryDepthLimit => write!(f, "discovery_depth_limit"),
288 Self::CostThreshold => write!(f, "cost_threshold"),
289 Self::IdleTimeout => write!(f, "idle_timeout"),
290 Self::SameStepLoop => write!(f, "same_step_loop"),
291 }
292 }
293}
294
295#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
297pub struct EnforcementViolation {
298 pub violation_type: ViolationType,
300 pub current: u64,
302 pub limit: u64,
304 pub message: String,
306}
307
308impl EnforcementViolation {
309 pub fn new(violation_type: ViolationType, current: u64, limit: u64) -> Self {
311 let message = format!(
312 "{} exceeded: {} / {} ({}%)",
313 violation_type,
314 current,
315 limit,
316 (current as f64 / limit as f64 * 100.0) as u32
317 );
318 Self {
319 violation_type,
320 current,
321 limit,
322 message,
323 }
324 }
325
326 pub fn to_error(&self) -> ExecutionError {
328 let category = match self.violation_type {
329 ViolationType::WallTimeLimit => ExecutionErrorCategory::Timeout,
330 ViolationType::RateLimit => ExecutionErrorCategory::LlmError, ViolationType::NetworkViolation => ExecutionErrorCategory::PolicyViolation, _ => ExecutionErrorCategory::QuotaExceeded,
333 };
334
335 ExecutionError::new(category, self.message.clone())
336 .with_code(self.violation_type.to_string())
337 .with_details(serde_json::json!({
338 "current": self.current,
339 "limit": self.limit,
340 "violation_type": self.violation_type.to_string(),
341 }))
342 }
343}
344
345#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
347pub struct EnforcementWarning {
348 pub warning_type: ViolationType,
350 pub usage_percent: u32,
352 pub message: String,
354}
355
356impl EnforcementWarning {
357 pub fn new(warning_type: ViolationType, current: u64, limit: u64) -> Self {
359 let percent = (current as f64 / limit as f64 * 100.0) as u32;
360 let message = format!("{} at {}%: {} / {}", warning_type, percent, current, limit);
361 Self {
362 warning_type,
363 usage_percent: percent,
364 message,
365 }
366 }
367}
368
369#[derive(Debug, Clone, Serialize, Deserialize)]
375pub struct EnforcementPolicy {
376 pub warning_threshold: u32,
378 pub emit_warning_events: bool,
380 pub emit_block_events: bool,
382 pub timeout_grace_ms: u64,
384}
385
386impl Default for EnforcementPolicy {
387 fn default() -> Self {
388 Self {
389 warning_threshold: 80, emit_warning_events: true,
391 emit_block_events: true,
392 timeout_grace_ms: 1000, }
394 }
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct LongRunningExecutionPolicy {
407 pub max_discovered_steps: Option<u32>,
410 pub max_discovery_depth: Option<u32>,
413 pub cost_alert_threshold_usd: Option<f64>,
416 pub idle_timeout_seconds: Option<u64>,
418 pub max_same_step_repetitions: Option<u32>,
420}
421
422impl Default for LongRunningExecutionPolicy {
423 fn default() -> Self {
424 Self::standard()
425 }
426}
427
428impl LongRunningExecutionPolicy {
429 pub fn standard() -> Self {
435 Self {
436 max_discovered_steps: Some(50),
437 max_discovery_depth: Some(5),
438 cost_alert_threshold_usd: Some(5.0),
439 idle_timeout_seconds: Some(1800), max_same_step_repetitions: Some(3),
441 }
442 }
443
444 pub fn extended() -> Self {
450 Self {
451 max_discovered_steps: Some(300),
452 max_discovery_depth: Some(10),
453 cost_alert_threshold_usd: Some(50.0),
454 idle_timeout_seconds: Some(14400), max_same_step_repetitions: Some(5),
456 }
457 }
458
459 pub fn unlimited() -> Self {
464 Self {
465 max_discovered_steps: None,
466 max_discovery_depth: None,
467 cost_alert_threshold_usd: Some(100.0), idle_timeout_seconds: Some(86400), max_same_step_repetitions: None,
470 }
471 }
472
473 pub fn disabled() -> Self {
475 Self {
476 max_discovered_steps: None,
477 max_discovery_depth: None,
478 cost_alert_threshold_usd: None,
479 idle_timeout_seconds: None,
480 max_same_step_repetitions: None,
481 }
482 }
483}
484
485#[derive(Debug)]
491pub struct EnforcementMiddleware {
492 executions: RwLock<HashMap<ExecutionId, Arc<ExecutionUsage>>>,
494 tenant_executions: RwLock<HashMap<TenantId, AtomicU32>>,
496 #[allow(dead_code)]
498 rate_limiter: RwLock<RateLimiterState>,
499 policy: EnforcementPolicy,
501}
502
503impl EnforcementMiddleware {
504 pub fn new() -> Self {
506 Self::with_policy(EnforcementPolicy::default())
507 }
508
509 pub fn with_policy(policy: EnforcementPolicy) -> Self {
511 Self {
512 executions: RwLock::new(HashMap::new()),
513 tenant_executions: RwLock::new(HashMap::new()),
514 rate_limiter: RwLock::new(RateLimiterState::new()),
515 policy,
516 }
517 }
518
519 pub fn emit_warning_events_enabled(&self) -> bool {
521 self.policy.emit_warning_events
522 }
523
524 pub async fn register_execution(
526 &self,
527 execution_id: ExecutionId,
528 tenant_id: TenantId,
529 ) -> Arc<ExecutionUsage> {
530 let usage = Arc::new(ExecutionUsage::new(execution_id.clone(), tenant_id.clone()));
531
532 {
534 let mut executions = self.executions.write().await;
535 executions.insert(execution_id, Arc::clone(&usage));
536 }
537
538 {
540 let mut tenant_execs = self.tenant_executions.write().await;
541 tenant_execs
542 .entry(tenant_id)
543 .or_insert_with(|| AtomicU32::new(0))
544 .fetch_add(1, Ordering::SeqCst);
545 }
546
547 usage
548 }
549
550 pub async fn unregister_execution(&self, execution_id: &ExecutionId) {
552 let tenant_id = {
553 let mut executions = self.executions.write().await;
554 executions.remove(execution_id).map(|u| u.tenant_id.clone())
555 };
556
557 if let Some(tenant_id) = tenant_id {
559 let tenant_execs = self.tenant_executions.read().await;
560 if let Some(count) = tenant_execs.get(&tenant_id) {
561 count.fetch_sub(1, Ordering::SeqCst);
562 }
563 }
564 }
565
566 pub async fn get_usage(&self, execution_id: &ExecutionId) -> Option<Arc<ExecutionUsage>> {
568 let executions = self.executions.read().await;
569 executions.get(execution_id).cloned()
570 }
571
572 pub async fn get_usage_snapshot(&self, execution_id: &ExecutionId) -> Option<UsageSnapshot> {
574 self.get_usage(execution_id)
575 .await
576 .map(|u| UsageSnapshot::from(u.as_ref()))
577 }
578
579 pub async fn check_step_allowed(
581 &self,
582 execution_id: &ExecutionId,
583 limits: &ResourceLimits,
584 ) -> EnforcementResult {
585 let usage = match self.get_usage(execution_id).await {
586 Some(u) => u,
587 None => return EnforcementResult::Allowed, };
589
590 let current = usage.step_count() as u64 + 1; let limit = limits.max_steps as u64;
592
593 if current > limit {
594 return EnforcementResult::Blocked(EnforcementViolation::new(
595 ViolationType::StepLimit,
596 current,
597 limit,
598 ));
599 }
600
601 let percent = (current as f64 / limit as f64 * 100.0) as u32;
602 if percent >= self.policy.warning_threshold {
603 return EnforcementResult::Warning(EnforcementWarning::new(
604 ViolationType::StepLimit,
605 current,
606 limit,
607 ));
608 }
609
610 EnforcementResult::Allowed
611 }
612
613 pub async fn check_tokens_allowed(
615 &self,
616 execution_id: &ExecutionId,
617 limits: &ResourceLimits,
618 additional_tokens: u32,
619 ) -> EnforcementResult {
620 let usage = match self.get_usage(execution_id).await {
621 Some(u) => u,
622 None => return EnforcementResult::Allowed,
623 };
624
625 let current = usage.total_tokens() as u64 + additional_tokens as u64;
626 let limit = limits.max_tokens as u64;
627
628 if current > limit {
629 return EnforcementResult::Blocked(EnforcementViolation::new(
630 ViolationType::TokenLimit,
631 current,
632 limit,
633 ));
634 }
635
636 let percent = (current as f64 / limit as f64 * 100.0) as u32;
637 if percent >= self.policy.warning_threshold {
638 return EnforcementResult::Warning(EnforcementWarning::new(
639 ViolationType::TokenLimit,
640 current,
641 limit,
642 ));
643 }
644
645 EnforcementResult::Allowed
646 }
647
648 pub async fn check_wall_time_allowed(
650 &self,
651 execution_id: &ExecutionId,
652 limits: &ResourceLimits,
653 ) -> EnforcementResult {
654 let usage = match self.get_usage(execution_id).await {
655 Some(u) => u,
656 None => return EnforcementResult::Allowed,
657 };
658
659 let current = usage.wall_time_ms();
660 let limit = limits.max_wall_time_ms;
661
662 let effective_limit = limit + self.policy.timeout_grace_ms;
664
665 if current > effective_limit {
666 return EnforcementResult::Blocked(EnforcementViolation::new(
667 ViolationType::WallTimeLimit,
668 current,
669 limit,
670 ));
671 }
672
673 let percent = (current as f64 / limit as f64 * 100.0) as u32;
674 if percent >= self.policy.warning_threshold {
675 return EnforcementResult::Warning(EnforcementWarning::new(
676 ViolationType::WallTimeLimit,
677 current,
678 limit,
679 ));
680 }
681
682 EnforcementResult::Allowed
683 }
684
685 pub async fn check_concurrency_allowed(
687 &self,
688 tenant_id: &TenantId,
689 limits: &ResourceLimits,
690 ) -> EnforcementResult {
691 let max_concurrent = match limits.max_concurrent_executions {
692 Some(max) => max,
693 None => return EnforcementResult::Allowed, };
695
696 let current = {
697 let tenant_execs = self.tenant_executions.read().await;
698 tenant_execs
699 .get(tenant_id)
700 .map(|c| c.load(Ordering::SeqCst))
701 .unwrap_or(0) as u64
702 };
703
704 let limit = max_concurrent as u64;
705
706 if current >= limit {
707 return EnforcementResult::Blocked(EnforcementViolation::new(
708 ViolationType::ConcurrencyLimit,
709 current + 1, limit,
711 ));
712 }
713
714 EnforcementResult::Allowed
715 }
716
717 pub async fn check_all_limits(
719 &self,
720 execution_id: &ExecutionId,
721 limits: &ResourceLimits,
722 ) -> EnforcementResult {
723 let wall_check = self.check_wall_time_allowed(execution_id, limits).await;
725 if wall_check.is_blocked() {
726 return wall_check;
727 }
728
729 let step_check = self.check_step_allowed(execution_id, limits).await;
731 if step_check.is_blocked() {
732 return step_check;
733 }
734
735 let token_check = self.check_tokens_allowed(execution_id, limits, 0).await;
737 if token_check.is_blocked() {
738 return token_check;
739 }
740
741 if let EnforcementResult::Warning(w) = wall_check {
743 return EnforcementResult::Warning(w);
744 }
745 if let EnforcementResult::Warning(w) = step_check {
746 return EnforcementResult::Warning(w);
747 }
748 if let EnforcementResult::Warning(w) = token_check {
749 return EnforcementResult::Warning(w);
750 }
751
752 EnforcementResult::Allowed
753 }
754
755 pub async fn record_step(&self, execution_id: &ExecutionId) {
757 if let Some(usage) = self.get_usage(execution_id).await {
758 usage.record_step();
759 usage.touch().await;
760 }
761 }
762
763 pub async fn record_tokens(&self, execution_id: &ExecutionId, input: u32, output: u32) {
765 if let Some(usage) = self.get_usage(execution_id).await {
766 usage.record_tokens(input, output);
767 usage.touch().await;
768 }
769 }
770
771 pub async fn record_discovered_step(&self, execution_id: &ExecutionId) {
773 if let Some(usage) = self.get_usage(execution_id).await {
774 usage.record_discovered_step();
775 usage.touch().await;
776 }
777 }
778
779 pub async fn record_cost(&self, execution_id: &ExecutionId, cost_usd: f64) {
781 if let Some(usage) = self.get_usage(execution_id).await {
782 usage.record_cost_usd(cost_usd);
783 usage.touch().await;
784 }
785 }
786
787 pub async fn push_discovery_depth(&self, execution_id: &ExecutionId) {
789 if let Some(usage) = self.get_usage(execution_id).await {
790 usage.push_discovery_depth();
791 }
792 }
793
794 pub async fn pop_discovery_depth(&self, execution_id: &ExecutionId) {
796 if let Some(usage) = self.get_usage(execution_id).await {
797 usage.pop_discovery_depth();
798 }
799 }
800
801 pub async fn check_discovered_step_limit(
807 &self,
808 execution_id: &ExecutionId,
809 policy: &LongRunningExecutionPolicy,
810 ) -> EnforcementResult {
811 let max_discovered = match policy.max_discovered_steps {
812 Some(max) => max,
813 None => return EnforcementResult::Allowed,
814 };
815
816 let usage = match self.get_usage(execution_id).await {
817 Some(u) => u,
818 None => return EnforcementResult::Allowed,
819 };
820
821 let current = usage.discovered_step_count() as u64 + 1; let limit = max_discovered as u64;
823
824 if current > limit {
825 return EnforcementResult::Blocked(EnforcementViolation::new(
826 ViolationType::DiscoveredStepLimit,
827 current,
828 limit,
829 ));
830 }
831
832 let percent = (current as f64 / limit as f64 * 100.0) as u32;
833 if percent >= self.policy.warning_threshold {
834 return EnforcementResult::Warning(EnforcementWarning::new(
835 ViolationType::DiscoveredStepLimit,
836 current,
837 limit,
838 ));
839 }
840
841 EnforcementResult::Allowed
842 }
843
844 pub async fn check_discovery_depth_limit(
846 &self,
847 execution_id: &ExecutionId,
848 policy: &LongRunningExecutionPolicy,
849 ) -> EnforcementResult {
850 let max_depth = match policy.max_discovery_depth {
851 Some(max) => max,
852 None => return EnforcementResult::Allowed,
853 };
854
855 let usage = match self.get_usage(execution_id).await {
856 Some(u) => u,
857 None => return EnforcementResult::Allowed,
858 };
859
860 let current = usage.current_discovery_depth() as u64 + 1; let limit = max_depth as u64;
862
863 if current > limit {
864 return EnforcementResult::Blocked(EnforcementViolation::new(
865 ViolationType::DiscoveryDepthLimit,
866 current,
867 limit,
868 ));
869 }
870
871 EnforcementResult::Allowed
873 }
874
875 pub async fn check_cost_threshold(
877 &self,
878 execution_id: &ExecutionId,
879 policy: &LongRunningExecutionPolicy,
880 ) -> EnforcementResult {
881 let threshold = match policy.cost_alert_threshold_usd {
882 Some(t) => t,
883 None => return EnforcementResult::Allowed,
884 };
885
886 let usage = match self.get_usage(execution_id).await {
887 Some(u) => u,
888 None => return EnforcementResult::Allowed,
889 };
890
891 let current_cents = usage.cost_cents.load(Ordering::SeqCst);
892 let current_usd = current_cents as f64 / 100.0;
893 let limit_cents = (threshold * 100.0) as u64;
894
895 if current_usd >= threshold {
896 return EnforcementResult::Blocked(EnforcementViolation::new(
897 ViolationType::CostThreshold,
898 current_cents,
899 limit_cents,
900 ));
901 }
902
903 let percent = (current_usd / threshold * 100.0) as u32;
904 if percent >= self.policy.warning_threshold {
905 return EnforcementResult::Warning(EnforcementWarning::new(
906 ViolationType::CostThreshold,
907 current_cents,
908 limit_cents,
909 ));
910 }
911
912 EnforcementResult::Allowed
913 }
914
915 pub async fn check_idle_timeout(
917 &self,
918 execution_id: &ExecutionId,
919 policy: &LongRunningExecutionPolicy,
920 ) -> EnforcementResult {
921 let timeout_secs = match policy.idle_timeout_seconds {
922 Some(t) => t,
923 None => return EnforcementResult::Allowed,
924 };
925
926 let usage = match self.get_usage(execution_id).await {
927 Some(u) => u,
928 None => return EnforcementResult::Allowed,
929 };
930
931 let idle_secs = usage.idle_seconds().await;
932
933 if idle_secs >= timeout_secs {
934 return EnforcementResult::Blocked(EnforcementViolation::new(
935 ViolationType::IdleTimeout,
936 idle_secs,
937 timeout_secs,
938 ));
939 }
940
941 let percent = (idle_secs as f64 / timeout_secs as f64 * 100.0) as u32;
943 if percent >= self.policy.warning_threshold {
944 return EnforcementResult::Warning(EnforcementWarning::new(
945 ViolationType::IdleTimeout,
946 idle_secs,
947 timeout_secs,
948 ));
949 }
950
951 EnforcementResult::Allowed
952 }
953
954 pub async fn check_long_running_limits(
956 &self,
957 execution_id: &ExecutionId,
958 policy: &LongRunningExecutionPolicy,
959 ) -> EnforcementResult {
960 let cost_check = self.check_cost_threshold(execution_id, policy).await;
962 if cost_check.is_blocked() {
963 return cost_check;
964 }
965
966 let depth_check = self.check_discovery_depth_limit(execution_id, policy).await;
968 if depth_check.is_blocked() {
969 return depth_check;
970 }
971
972 let discovered_check = self.check_discovered_step_limit(execution_id, policy).await;
974 if discovered_check.is_blocked() {
975 return discovered_check;
976 }
977
978 let idle_check = self.check_idle_timeout(execution_id, policy).await;
980 if idle_check.is_blocked() {
981 return idle_check;
982 }
983
984 if let EnforcementResult::Warning(w) = cost_check {
986 return EnforcementResult::Warning(w);
987 }
988 if let EnforcementResult::Warning(w) = discovered_check {
989 return EnforcementResult::Warning(w);
990 }
991 if let EnforcementResult::Warning(w) = idle_check {
992 return EnforcementResult::Warning(w);
993 }
994
995 EnforcementResult::Allowed
996 }
997}
998
999impl Default for EnforcementMiddleware {
1000 fn default() -> Self {
1001 Self::new()
1002 }
1003}
1004
1005#[derive(Debug)]
1011struct RateLimiterState {
1012 #[allow(dead_code)]
1014 provider_tokens: HashMap<String, TokenBucket>,
1015}
1016
1017impl RateLimiterState {
1018 fn new() -> Self {
1019 Self {
1020 provider_tokens: HashMap::new(),
1021 }
1022 }
1023}
1024
1025#[derive(Debug)]
1027struct TokenBucket {
1028 tokens: AtomicU64,
1030 max_tokens: u64,
1032 refill_rate: u64,
1034 last_refill: RwLock<Instant>,
1036}
1037
1038impl TokenBucket {
1039 #[allow(dead_code)]
1041 fn new(max_tokens: u64, refill_rate: u64) -> Self {
1042 Self {
1043 tokens: AtomicU64::new(max_tokens),
1044 max_tokens,
1045 refill_rate,
1046 last_refill: RwLock::new(Instant::now()),
1047 }
1048 }
1049
1050 #[allow(dead_code)]
1052 async fn try_acquire(&self, count: u64) -> bool {
1053 {
1055 let mut last = self.last_refill.write().await;
1056 let elapsed = last.elapsed();
1057 let new_tokens = (elapsed.as_secs_f64() * self.refill_rate as f64) as u64;
1058 if new_tokens > 0 {
1059 let current = self.tokens.load(Ordering::SeqCst);
1060 let new_total = std::cmp::min(current + new_tokens, self.max_tokens);
1061 self.tokens.store(new_total, Ordering::SeqCst);
1062 *last = Instant::now();
1063 }
1064 }
1065
1066 let current = self.tokens.load(Ordering::SeqCst);
1068 if current >= count {
1069 self.tokens.fetch_sub(count, Ordering::SeqCst);
1070 true
1071 } else {
1072 false
1073 }
1074 }
1075}
1076
1077pub struct StepTimeoutGuard {
1083 step_id: StepId,
1084 timeout: Duration,
1085 started_at: Instant,
1086}
1087
1088impl StepTimeoutGuard {
1089 pub fn new(step_id: StepId, timeout: Duration) -> Self {
1091 Self {
1092 step_id,
1093 timeout,
1094 started_at: Instant::now(),
1095 }
1096 }
1097
1098 pub fn is_timed_out(&self) -> bool {
1100 self.started_at.elapsed() > self.timeout
1101 }
1102
1103 pub fn remaining(&self) -> Duration {
1105 self.timeout.saturating_sub(self.started_at.elapsed())
1106 }
1107
1108 pub fn elapsed(&self) -> Duration {
1110 self.started_at.elapsed()
1111 }
1112
1113 #[allow(clippy::result_large_err)]
1115 pub fn check(&self) -> Result<(), ExecutionError> {
1116 if self.is_timed_out() {
1117 Err(ExecutionError::timeout(format!(
1118 "Step {} timed out after {:?}",
1119 self.step_id, self.timeout
1120 ))
1121 .with_step_id(self.step_id.clone()))
1122 } else {
1123 Ok(())
1124 }
1125 }
1126}
1127
#[cfg(test)]
mod tests {
    use super::*;

    /// Tenant id shared by every test.
    fn test_tenant() -> TenantId {
        TenantId::from("tenant_test123456789012345")
    }

    /// Generous limits; each test tightens only the field it exercises.
    fn roomy_limits() -> ResourceLimits {
        ResourceLimits {
            max_steps: 100,
            max_tokens: 1000,
            max_wall_time_ms: 60000,
            max_memory_mb: None,
            max_concurrent_executions: None,
        }
    }

    #[tokio::test]
    async fn test_usage_tracking() {
        let usage = ExecutionUsage::new(ExecutionId::new(), test_tenant());

        usage.record_step();
        usage.record_step();
        assert_eq!(usage.step_count(), 2);

        usage.record_tokens(100, 50);
        assert_eq!(usage.total_tokens(), 150);
    }

    #[tokio::test]
    async fn test_step_limit_enforcement() {
        let middleware = EnforcementMiddleware::new();
        let exec_id = ExecutionId::new();
        let limits = ResourceLimits {
            max_steps: 5,
            ..roomy_limits()
        };

        let usage = middleware
            .register_execution(exec_id.clone(), test_tenant())
            .await;

        // Steps one through five fit the budget.
        for _ in 0..5 {
            let result = middleware.check_step_allowed(&exec_id, &limits).await;
            assert!(result.is_allowed(), "Step should be allowed");
            usage.record_step();
        }

        // The sixth does not.
        let result = middleware.check_step_allowed(&exec_id, &limits).await;
        assert!(result.is_blocked(), "Step should be blocked");
    }

    #[tokio::test]
    async fn test_token_limit_enforcement() {
        let middleware = EnforcementMiddleware::new();
        let exec_id = ExecutionId::new();
        let limits = ResourceLimits {
            max_tokens: 100,
            ..roomy_limits()
        };

        let usage = middleware
            .register_execution(exec_id.clone(), test_tenant())
            .await;

        // 80 of 100 tokens used.
        usage.record_tokens(50, 30);

        let result = middleware.check_tokens_allowed(&exec_id, &limits, 25).await;
        assert!(
            result.is_blocked(),
            "Should be blocked when exceeding limit"
        );

        let result = middleware.check_tokens_allowed(&exec_id, &limits, 10).await;
        assert!(result.is_allowed(), "Should be allowed within limit");
    }

    #[tokio::test]
    async fn test_warning_threshold() {
        let policy = EnforcementPolicy {
            warning_threshold: 80,
            ..Default::default()
        };
        let middleware = EnforcementMiddleware::with_policy(policy);
        let exec_id = ExecutionId::new();
        let limits = ResourceLimits {
            max_steps: 10,
            ..roomy_limits()
        };

        let usage = middleware
            .register_execution(exec_id.clone(), test_tenant())
            .await;

        // Seven recorded steps make the next one the eighth of ten (80 %).
        (0..7).for_each(|_| usage.record_step());

        let result = middleware.check_step_allowed(&exec_id, &limits).await;
        assert!(matches!(result, EnforcementResult::Warning(_)));
    }

    #[test]
    fn test_step_timeout_guard() {
        let guard = StepTimeoutGuard::new(StepId::new(), Duration::from_millis(100));

        assert!(!guard.is_timed_out());
        assert!(guard.check().is_ok());

        // Sleep past the 100 ms budget.
        std::thread::sleep(Duration::from_millis(150));

        assert!(guard.is_timed_out());
        assert!(guard.check().is_err());
    }

    #[tokio::test]
    async fn test_concurrency_limit() {
        let middleware = EnforcementMiddleware::new();
        let tenant_id = test_tenant();
        let limits = ResourceLimits {
            max_concurrent_executions: Some(2),
            ..roomy_limits()
        };

        // Fill both slots.
        let exec1 = ExecutionId::new();
        let exec2 = ExecutionId::new();
        middleware
            .register_execution(exec1.clone(), tenant_id.clone())
            .await;
        middleware
            .register_execution(exec2.clone(), tenant_id.clone())
            .await;

        let result = middleware
            .check_concurrency_allowed(&tenant_id, &limits)
            .await;
        assert!(result.is_blocked());

        // Freeing one slot lets the tenant start another execution.
        middleware.unregister_execution(&exec1).await;

        let result = middleware
            .check_concurrency_allowed(&tenant_id, &limits)
            .await;
        assert!(result.is_allowed());
    }

    #[test]
    fn test_network_violation_type() {
        let violation = EnforcementViolation::new(ViolationType::NetworkViolation, 0, 0);

        let error = violation.to_error();
        assert_eq!(error.category, ExecutionErrorCategory::PolicyViolation);
        assert!(!error.is_retryable());
        assert!(error.is_fatal());
    }

    #[test]
    fn test_violation_type_display_network() {
        let rendered = format!("{}", ViolationType::NetworkViolation);
        assert_eq!(rendered, "network_violation");
    }
}