Skip to main content

deepstrike_core/governance/
rate_limit.rs

1use std::collections::{HashMap, VecDeque};
2
3use compact_str::CompactString;
4
5use crate::types::message::ToolCall;
6use crate::types::policy::GovernanceVerdict;
7
/// Rate limit configuration for a tool: at most `max_calls` calls within one `window_ms`
/// sliding window.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of calls allowed inside one window. `0` rejects every call.
    pub max_calls: u32,
    /// Length of the sliding window, in milliseconds.
    pub window_ms: u64,
}

impl Default for RateLimit {
    /// 60 calls per 60 000 ms window (one call per second on average).
    fn default() -> Self {
        Self {
            max_calls: 60,
            window_ms: 60_000,
        }
    }
}
23
24/// Sliding-window rate limiter per tool.
25pub struct RateLimiter {
26    windows: HashMap<CompactString, VecDeque<u64>>,
27    limits: HashMap<CompactString, RateLimit>,
28    default_limit: RateLimit,
29    /// Current timestamp in ms — injected by SDK layer (no I/O in kernel).
30    current_time_ms: u64,
31}
32
33impl RateLimiter {
34    pub fn new(default_limit: RateLimit) -> Self {
35        Self {
36            windows: HashMap::new(),
37            limits: HashMap::new(),
38            default_limit,
39            current_time_ms: 0,
40        }
41    }
42
43    pub fn set_limit(&mut self, tool_name: impl Into<CompactString>, limit: RateLimit) {
44        self.limits.insert(tool_name.into(), limit);
45    }
46
47    pub fn limit_count(&self) -> usize {
48        self.limits.len()
49    }
50
51    /// Must be called before each check to provide current time.
52    pub fn set_time(&mut self, now_ms: u64) {
53        self.current_time_ms = now_ms;
54    }
55
56    pub fn check(&mut self, call: &ToolCall) -> Option<GovernanceVerdict> {
57        // current_time_ms defaults to 0; SDK is expected to call set_time() before check.
58        // We don't debug_assert here because 0 is a valid monotonic-clock origin.
59        let limit = self.limits.get(&call.name).unwrap_or(&self.default_limit);
60        let window = self.windows.entry(call.name.clone()).or_default();
61
62        // Evict expired entries
63        let cutoff = self.current_time_ms.saturating_sub(limit.window_ms);
64        while window.front().is_some_and(|&t| t < cutoff) {
65            window.pop_front();
66        }
67
68        if window.len() as u32 >= limit.max_calls {
69            let oldest = window.front().copied().unwrap_or(self.current_time_ms);
70            let retry_after = oldest + limit.window_ms - self.current_time_ms;
71            return Some(GovernanceVerdict::RateLimited {
72                retry_after_ms: retry_after,
73            });
74        }
75
76        window.push_back(self.current_time_ms);
77        None
78    }
79}
80
81impl Default for RateLimiter {
82    fn default() -> Self {
83        Self::new(RateLimit::default())
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `ToolCall` for `name`; only the name matters to the limiter.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: CompactString::new("c"),
            name: CompactString::new(name),
            arguments: serde_json::Value::Null,
        }
    }

    #[test]
    fn allows_within_limit() {
        let mut rl = RateLimiter::new(RateLimit {
            max_calls: 3,
            window_ms: 1000,
        });
        rl.set_time(100);
        assert!(rl.check(&make_call("foo")).is_none());
        assert!(rl.check(&make_call("foo")).is_none());
        assert!(rl.check(&make_call("foo")).is_none());
        // 4th call should be limited
        assert!(rl.check(&make_call("foo")).is_some());
    }

    #[test]
    fn expires_old_entries() {
        let mut rl = RateLimiter::new(RateLimit {
            max_calls: 1,
            window_ms: 100,
        });
        rl.set_time(0);
        assert!(rl.check(&make_call("bar")).is_none());
        assert!(rl.check(&make_call("bar")).is_some());

        rl.set_time(200); // window expired
        assert!(rl.check(&make_call("bar")).is_none());
    }

    #[test]
    fn reports_retry_after_for_oldest_call() {
        let mut rl = RateLimiter::new(RateLimit {
            max_calls: 1,
            window_ms: 100,
        });
        rl.set_time(0);
        assert!(rl.check(&make_call("baz")).is_none());

        // The call at t=0 leaves the window at t=100, i.e. 70 ms after t=30.
        rl.set_time(30);
        assert!(matches!(
            rl.check(&make_call("baz")),
            Some(GovernanceVerdict::RateLimited {
                retry_after_ms: 70,
                ..
            })
        ));
    }

    #[test]
    fn tools_have_independent_windows() {
        let mut rl = RateLimiter::new(RateLimit {
            max_calls: 1,
            window_ms: 1000,
        });
        rl.set_time(0);
        assert!(rl.check(&make_call("a")).is_none());
        assert!(rl.check(&make_call("a")).is_some());
        // Exhausting "a" must not affect "b".
        assert!(rl.check(&make_call("b")).is_none());
    }

    #[test]
    fn per_tool_limit_overrides_default() {
        let mut rl = RateLimiter::new(RateLimit {
            max_calls: 3,
            window_ms: 1000,
        });
        rl.set_limit(
            "strict",
            RateLimit {
                max_calls: 1,
                window_ms: 1000,
            },
        );
        assert_eq!(rl.limit_count(), 1);

        rl.set_time(0);
        assert!(rl.check(&make_call("strict")).is_none());
        assert!(rl.check(&make_call("strict")).is_some());

        // Tools without an override fall back to the default of 3 calls.
        for _ in 0..3 {
            assert!(rl.check(&make_call("loose")).is_none());
        }
        assert!(rl.check(&make_call("loose")).is_some());
    }
}