deepstrike_core/governance/
rate_limit.rs1use std::collections::{HashMap, VecDeque};
2
3use compact_str::CompactString;
4
5use crate::types::message::ToolCall;
6use crate::types::policy::GovernanceVerdict;
7
/// Call budget for a single tool: at most `max_calls` calls per window of
/// `window_ms` milliseconds.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Maximum number of calls accepted within one window.
    pub max_calls: u32,
    /// Length of the window, in milliseconds.
    pub window_ms: u64,
}
14
15impl Default for RateLimit {
16 fn default() -> Self {
17 Self {
18 max_calls: 60,
19 window_ms: 60_000,
20 }
21 }
22}
23
24pub struct RateLimiter {
26 windows: HashMap<CompactString, VecDeque<u64>>,
27 limits: HashMap<CompactString, RateLimit>,
28 default_limit: RateLimit,
29 current_time_ms: u64,
31}
32
33impl RateLimiter {
34 pub fn new(default_limit: RateLimit) -> Self {
35 Self {
36 windows: HashMap::new(),
37 limits: HashMap::new(),
38 default_limit,
39 current_time_ms: 0,
40 }
41 }
42
43 pub fn set_limit(&mut self, tool_name: impl Into<CompactString>, limit: RateLimit) {
44 self.limits.insert(tool_name.into(), limit);
45 }
46
47 pub fn limit_count(&self) -> usize {
48 self.limits.len()
49 }
50
51 pub fn set_time(&mut self, now_ms: u64) {
53 self.current_time_ms = now_ms;
54 }
55
56 pub fn check(&mut self, call: &ToolCall) -> Option<GovernanceVerdict> {
57 let limit = self.limits.get(&call.name).unwrap_or(&self.default_limit);
60 let window = self.windows.entry(call.name.clone()).or_default();
61
62 let cutoff = self.current_time_ms.saturating_sub(limit.window_ms);
64 while window.front().is_some_and(|&t| t < cutoff) {
65 window.pop_front();
66 }
67
68 if window.len() as u32 >= limit.max_calls {
69 let oldest = window.front().copied().unwrap_or(self.current_time_ms);
70 let retry_after = oldest + limit.window_ms - self.current_time_ms;
71 return Some(GovernanceVerdict::RateLimited {
72 retry_after_ms: retry_after,
73 });
74 }
75
76 window.push_back(self.current_time_ms);
77 None
78 }
79}
80
81impl Default for RateLimiter {
82 fn default() -> Self {
83 Self::new(RateLimit::default())
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `ToolCall` for the tool called `name`.
    fn make_call(name: &str) -> ToolCall {
        let id = CompactString::new("c");
        let name = CompactString::new(name);
        ToolCall {
            id,
            name,
            arguments: serde_json::Value::Null,
        }
    }

    #[test]
    fn allows_within_limit() {
        let budget = RateLimit {
            max_calls: 3,
            window_ms: 1000,
        };
        let mut limiter = RateLimiter::new(budget);
        limiter.set_time(100);

        // The first three calls at the same instant fit in the budget.
        for _ in 0..3 {
            assert!(limiter.check(&make_call("foo")).is_none());
        }
        // The fourth call exceeds it.
        assert!(limiter.check(&make_call("foo")).is_some());
    }

    #[test]
    fn expires_old_entries() {
        let budget = RateLimit {
            max_calls: 1,
            window_ms: 100,
        };
        let mut limiter = RateLimiter::new(budget);

        limiter.set_time(0);
        let first = limiter.check(&make_call("bar"));
        let second = limiter.check(&make_call("bar"));
        assert!(first.is_none());
        assert!(second.is_some());

        // Well past the window, the earlier call no longer counts.
        limiter.set_time(200);
        let after_expiry = limiter.check(&make_call("bar"));
        assert!(after_expiry.is_none());
    }
}