riglr_core/util/
rate_limit_strategy.rs1use crate::ToolError;
7use std::time::{Duration, Instant};
8
9pub trait RateLimitStrategy: Send + Sync {
14 fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError>;
23
24 fn reset_client(&self, client_id: &str);
26
27 fn clear_all(&self);
29
30 fn get_request_count(&self, client_id: &str) -> usize;
32
33 fn strategy_name(&self) -> &str;
35}
36
/// Per-client record holding a request history and a token balance.
///
/// NOTE(review): not used by `FixedWindowStrategy`, which keeps its own
/// `FixedWindowClientInfo`. Presumably consumed by sliding-window / token-bucket
/// strategies elsewhere; confirm against callers.
#[derive(Clone, Debug)]
pub struct ClientRateInfo {
    /// Timestamps of recorded requests; empty on creation.
    pub request_times: Vec<Instant>,
    /// Current burst-token balance.
    pub burst_tokens: f64,
    /// Last time the token balance was refilled; `new` sets it to the creation instant.
    pub last_refill: Instant,
}

impl ClientRateInfo {
    /// Creates a record with no request history, `initial_tokens` burst tokens,
    /// and `last_refill` stamped with the current instant.
    pub fn new(initial_tokens: f64) -> Self {
        let created_at = Instant::now();
        Self {
            last_refill: created_at,
            burst_tokens: initial_tokens,
            request_times: Vec::new(),
        }
    }
}
58
59#[derive(Debug)]
64pub struct FixedWindowStrategy {
65 pub max_requests: usize,
67 pub window_duration: Duration,
69 pub clients: dashmap::DashMap<String, FixedWindowClientInfo>,
71}
72
/// Window state that `FixedWindowStrategy` tracks for a single client.
#[derive(Clone, Debug)]
pub struct FixedWindowClientInfo {
    /// Instant at which the client's current window opened.
    pub window_start: Instant,
    /// Number of requests admitted since `window_start`.
    pub request_count: usize,
}
81
82impl RateLimitStrategy for FixedWindowStrategy {
83 fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
84 let now = Instant::now();
85 let mut entry = self
86 .clients
87 .entry(client_id.to_string())
88 .or_insert_with(|| FixedWindowClientInfo {
89 window_start: now,
90 request_count: 0,
91 });
92
93 if now.duration_since(entry.window_start) >= self.window_duration {
95 entry.window_start = now;
97 entry.request_count = 0;
98 }
99
100 if entry.request_count >= self.max_requests {
102 let time_until_reset = self
103 .window_duration
104 .saturating_sub(now.duration_since(entry.window_start));
105
106 return Err(ToolError::RateLimited {
107 source: None,
108 source_message: format!(
109 "Fixed window rate limit: {} requests per {:?}",
110 self.max_requests, self.window_duration
111 ),
112 context: format!("Exceeded {} requests in current window", self.max_requests),
113 retry_after: Some(time_until_reset),
114 });
115 }
116
117 entry.request_count += 1;
118 Ok(())
119 }
120
121 fn reset_client(&self, client_id: &str) {
122 self.clients.remove(client_id);
123 }
124
125 fn clear_all(&self) {
126 self.clients.clear();
127 }
128
129 fn get_request_count(&self, client_id: &str) -> usize {
130 self.clients
131 .get(client_id)
132 .map(|entry| entry.request_count)
133 .unwrap_or(0)
134 }
135
136 fn strategy_name(&self) -> &str {
137 "FixedWindow"
138 }
139}
140
141impl FixedWindowStrategy {
142 pub fn new(max_requests: usize, window_duration: Duration) -> Self {
144 Self {
145 max_requests,
146 window_duration,
147 clients: dashmap::DashMap::new(),
148 }
149 }
150}